README.md (2 changes: 1 addition & 1 deletion)
@@ -15,7 +15,7 @@ this plugin deployed in your Kubernetes cluster, you will be able to run jobs
 * This plugin targets Kubernetes v1.18+.
 
 ## Deployment
-The device plugin needs to be run on all the nodes that are equipped with Confidential Computing devices (e.g. TPM). The simplest way of doing so is to create a Kubernetes [DaemonSet][dp], which runs a copy of a pod on all (or some) Nodes in the cluster. We have a pre-built Docker image on [Goolge Artifact Registry][release] that you can use with your DaemonSet. This repository also has a pre-defined YAML file named `cc-device-plugin.yaml`. You can create a DaemonSet in your Kubernetes cluster by running this command:
+The device plugin needs to be run on all the nodes that are equipped with Confidential Computing devices (e.g. TPM). The simplest way of doing so is to create a Kubernetes [DaemonSet][dp], which runs a copy of a pod on all (or some) Nodes in the cluster. We have a pre-built Docker image on [Google Artifact Registry][release] that you can use with your DaemonSet. This repository also has a pre-defined YAML file named `cc-device-plugin.yaml`. You can create a DaemonSet in your Kubernetes cluster by running this command:
 
 ```
 kubectl create -f manifests/cc-device-plugin.yaml
deviceplugin/ccdevice.go (246 changes: 152 additions & 94 deletions)
@@ -22,6 +22,7 @@ import (
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"time"

@@ -32,9 +33,18 @@ import (
 )
 
 const (
-	deviceCheckInterval = 5 * time.Second
-	// By default, GKE allows up to 110 Pods per node on Standard clusters. Standard clusters can be configured to allow up to 256 Pods per node.
-	workloadSharedLimit = 256
+	deviceCheckInterval        = 5 * time.Second
+	copiedEventLogDirectory    = "/run/cc-device-plugin"
+	copiedEventLogLocation    = "/run/cc-device-plugin/binary_bios_measurements"
+	containerEventLogDirectory = "/run/cc-device-plugin"
 )
 
+// AttestationType defines if the attestation is based on software emulation or hardware.
+type AttestationType string
+
+const (
+	SoftwareAttestation AttestationType = "software" // e.g., vTPM
+	HardwareAttestation AttestationType = "hardware" // e.g., Intel TDX, AMD SEV-SNP
+)
+
 var (
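
To make the two attestation types concrete, here is a minimal, self-contained sketch of how a software-attested and a hardware-attested resource might be described with the spec fields added in this change. The resource names and device paths are illustrative assumptions, not values taken from this diff:

```go
package main

import "fmt"

// Trimmed-down copies of the plugin's types so this sketch compiles on its own.
type AttestationType string

const (
	SoftwareAttestation AttestationType = "software"
	HardwareAttestation AttestationType = "hardware"
)

type CcDeviceSpec struct {
	Resource         string
	DevicePaths      []string
	MeasurementPaths []string
	DeviceLimit      int
	Type             AttestationType
}

func main() {
	// Hypothetical specs: the resource names and device paths below are
	// illustrative assumptions, not values defined in this change.
	vtpm := CcDeviceSpec{
		Resource:         "example.com/vtpm",
		DevicePaths:      []string{"/dev/tpm0", "/dev/tpmrm0"},
		MeasurementPaths: []string{"/sys/kernel/security/tpm0/binary_bios_measurements"},
		DeviceLimit:      256, // one shared vTPM advertised as many allocatable copies
		Type:             SoftwareAttestation,
	}
	tdx := CcDeviceSpec{
		Resource:    "example.com/tdx",
		DevicePaths: []string{"/dev/tdx_guest"},
		DeviceLimit: 1,
		Type:        HardwareAttestation, // no event log is copied into /run
	}
	fmt.Printf("%+v\n%+v\n", vtpm, tdx)
}
```

Only the software spec carries measurement paths, which matches the vTPM-only event-log handling introduced below.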
@@ -47,25 +57,25 @@ type CcDeviceSpec struct {
 	Resource         string
 	DevicePaths      []string
 	MeasurementPaths []string
+	DeviceLimit      int             // Number of allocatable instances of this resource
+	Type             AttestationType // New flag to explicitly define the device type
 }
 
 // CcDevice wraps the v1beta1.Device type, which has hostPath, containerPath and permission
 type CcDevice struct {
 	v1beta1.Device
 	DeviceSpecs []*v1beta1.DeviceSpec
 	Mounts      []*v1beta1.Mount
-	// Limit specifies the cap number of workloads sharing a worker node
-	Limit int
 }
 
 // CcDevicePlugin is a device plugin for cc devices
 type CcDevicePlugin struct {
-	cds                        *CcDeviceSpec
-	ccDevices                  map[string]CcDevice
-	copiedEventLogDirectory    string
-	copiedEventLogLocation     string
-	containerEventLogDirectory string
-	logger                     log.Logger
+	cds                        *CcDeviceSpec
+	ccDevices                  map[string]CcDevice
+	logger                     log.Logger
+	copiedEventLogDirectory    string
+	copiedEventLogLocation     string
+	containerEventLogDirectory string
 	// this lock prevents data race when kubelet sends multiple requests at the same time
 	mu sync.Mutex

@@ -79,14 +89,17 @@ func NewCcDevicePlugin(cds *CcDeviceSpec, devicePluginPath string, socket string
 	if logger == nil {
 		logger = log.NewNopLogger()
 	}
+	if cds.DeviceLimit <= 0 {
+		cds.DeviceLimit = 1 // Default to 1 if not specified
+	}
 
 	cdp := &CcDevicePlugin{
-		cds:                        cds,
-		ccDevices:                  make(map[string]CcDevice),
-		logger:                     logger,
-		copiedEventLogDirectory:    "/run/cc-device-plugin",
-		copiedEventLogLocation:     "/run/cc-device-plugin/binary_bios_measurements",
-		containerEventLogDirectory: "/run/cc-device-plugin",
+		cds:                        cds,
+		ccDevices:                  make(map[string]CcDevice),
+		logger:                     logger,
+		copiedEventLogDirectory:    copiedEventLogDirectory,
+		copiedEventLogLocation:     copiedEventLogLocation, // Note: This path is static, used only by the vTPM plugin instance.
+		containerEventLogDirectory: containerEventLogDirectory,
 		deviceGauge: prometheus.NewGauge(prometheus.GaugeOpts{
 			Name: "cc_device_plugin_devices",
 			Help: "The number of cc devices managed by this device plugin.",
@@ -97,16 +110,19 @@ func NewCcDevicePlugin(cds *CcDeviceSpec, devicePluginPath string, socket string
 		}),
 	}
 
-	// Check if the copiedEventLogDirectory directory exists
-	if _, err := os.Stat(cdp.copiedEventLogDirectory); os.IsNotExist(err) {
-		// Create the directory
-		err = os.Mkdir(cdp.copiedEventLogDirectory, 0755)
-		if err != nil {
-			return nil, err
+	// Only create the directory if the device type is software-based (e.g., vTPM),
+	// as hardware-based devices (TDX/SNP) do not require copying measurement files to /run.
+	if cdp.cds.Type == SoftwareAttestation {
+		if _, err := os.Stat(cdp.copiedEventLogDirectory); os.IsNotExist(err) {
+			// Create the directory
+			err = os.MkdirAll(cdp.copiedEventLogDirectory, 0755)
+			if err != nil {
+				return nil, err
+			}
+			level.Info(cdp.logger).Log("msg", "Directory created:"+cdp.copiedEventLogDirectory)
+		} else {
+			level.Info(cdp.logger).Log("msg", "Directory already exists:"+cdp.copiedEventLogDirectory)
 		}
-		level.Info(cdp.logger).Log("msg", "Directory created:"+cdp.copiedEventLogDirectory)
-	} else {
-		level.Info(cdp.logger).Log("msg", "Directory already exists:"+cdp.copiedEventLogDirectory)
 	}
 
 	if reg != nil {
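
A side effect of this hunk worth noting is the switch from `os.Mkdir` to `os.MkdirAll`. A short, runnable sketch of the difference, using the plugin's `/run` path (the error handling is illustrative):

```go
package main

import (
	"fmt"
	"os"
)

func main() {
	const dir = "/run/cc-device-plugin"

	// os.Mkdir fails when the parent directory is missing, and also when the
	// directory already exists.
	if err := os.Mkdir(dir, 0755); err != nil {
		fmt.Println("Mkdir:", err)
	}

	// os.MkdirAll creates any missing parents and returns nil if the directory
	// already exists, which suits a plugin that may restart on the same node.
	if err := os.MkdirAll(dir, 0755); err != nil {
		fmt.Println("MkdirAll:", err)
	}
}
```

The surrounding `os.Stat` check already guarded the "already exists" case; `MkdirAll` additionally tolerates a missing parent and a race with another instance creating the same directory.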
@@ -118,75 +134,109 @@ func NewCcDevicePlugin(cds *CcDeviceSpec, devicePluginPath string, socket string

 func (cdp *CcDevicePlugin) discoverCcDevices() ([]CcDevice, error) {
 	var ccDevices []CcDevice
-	cd := CcDevice{
-		Device: v1beta1.Device{
-			Health: v1beta1.Healthy,
-		},
-		// set cap
-		Limit: workloadSharedLimit,
-	}
-	h := sha1.New()
+	var foundDevicePaths []string
 
+	// We use foundDevicePaths as an accumulator because a single resource (like TDX)
+	// might be represented by multiple device path patterns.
 	for _, path := range cdp.cds.DevicePaths {
 		matches, err := filepath.Glob(path)
 		if err != nil {
 			return nil, err
 		}
-		for _, matchPath := range matches {
-			level.Info(cdp.logger).Log("msg", "device path found:"+matchPath)
-			cd.DeviceSpecs = append(cd.DeviceSpecs, &v1beta1.DeviceSpec{
-				HostPath:      matchPath,
-				ContainerPath: matchPath,
-				Permissions:   "mrw",
-			})
+		if len(matches) > 0 {
+			level.Info(cdp.logger).Log("msg", "found matching device path(s)", "pattern", path, "matches", strings.Join(matches, ","))
+			foundDevicePaths = append(foundDevicePaths, matches...)
 		}
 	}
 
-	for _, path := range cdp.cds.MeasurementPaths {
-		matches, err := filepath.Glob(path)
-		if err != nil {
-			return nil, err
-		}
+	// If no device paths were found for this resource type, simply return an empty list.
+	// This is not an error; the node just doesn't have this specific hardware.
+	if len(foundDevicePaths) == 0 {
+		return nil, nil
+	}
+
+	baseDevice := CcDevice{
+		Device: v1beta1.Device{
+			Health: v1beta1.Healthy,
+		},
+	}
+
+	for _, matchPath := range foundDevicePaths {
+		baseDevice.DeviceSpecs = append(baseDevice.DeviceSpecs, &v1beta1.DeviceSpec{
+			HostPath:      matchPath,
+			ContainerPath: matchPath,
+			Permissions:   "mrw",
+		})
+	}
+
+	// Measurement files are currently only expected for software-emulated devices (vTPM).
+	if cdp.cds.Type == SoftwareAttestation && len(cdp.cds.MeasurementPaths) > 0 {
+		var foundMeasurementPath string
+		for _, path := range cdp.cds.MeasurementPaths {
+			matches, err := filepath.Glob(path)
+			if err != nil {
+				return nil, err
+			}
+			if len(matches) > 0 {
+				// We only expect one measurement file
+				foundMeasurementPath = matches[0]
+				level.Info(cdp.logger).Log("msg", "measurement path found", "path", foundMeasurementPath)
+				break
+			}
+		}
-		for _, matchPath := range matches {
-			level.Info(cdp.logger).Log("msg", "measurement path found:"+matchPath)
-			cd.Mounts = append(cd.Mounts, &v1beta1.Mount{
+		if foundMeasurementPath != "" {
+			baseDevice.Mounts = append(baseDevice.Mounts, &v1beta1.Mount{
 				HostPath:      cdp.copiedEventLogDirectory,
 				ContainerPath: cdp.containerEventLogDirectory,
 				ReadOnly:      true,
 			})
 
 			// copy when no measurement file at copiedEventLogLocation
 			fileInfo, err := os.Stat(cdp.copiedEventLogLocation)
 			if errors.Is(err, os.ErrNotExist) {
-				err := copyMeasurementFile(matchPath, cdp.copiedEventLogLocation)
-				if err != nil {
+				if err := copyMeasurementFile(foundMeasurementPath, cdp.copiedEventLogLocation); err != nil {
+					level.Error(cdp.logger).Log("msg", "failed to copy measurement file", "error", err)
 					return nil, err
 				}
-			} else {
-				// copy when measurement file at /run was updated, but not by the current instance.
-				// measurementFileLastUpdate is init to 0.
-				// when file exists during first run, this instance deletes and creates a new file
-				if fileInfo.ModTime().After(measurementFileLastUpdate) {
-					err := copyMeasurementFile(matchPath, cdp.copiedEventLogLocation)
-					if err != nil {
-						return nil, err
-					}
-				}
+			} else if err == nil && fileInfo.ModTime().After(measurementFileLastUpdate) {
+				// Refresh the copy if the source file has been updated by the kernel since the last copy.
+				if err := copyMeasurementFile(foundMeasurementPath, cdp.copiedEventLogLocation); err != nil {
+					level.Error(cdp.logger).Log("msg", "failed to re-copy measurement file", "error", err)
+					return nil, err
+				}
+			} else if err != nil {
+				level.Error(cdp.logger).Log("msg", "failed to stat copied measurement file", "error", err)
+				return nil, err
 			}
+		} else {
+			level.Warn(cdp.logger).Log("msg", "MeasurementPaths specified but no measurement file found", "paths", strings.Join(cdp.cds.MeasurementPaths, ","))
 		}
 	}
-	if cd.DeviceSpecs != nil {
-		for i := 0; i < cd.Limit; i++ {
-			b := make([]byte, 1)
-			b[0] = byte(i)
-			cd.ID = fmt.Sprintf("%x", h.Sum(b))
-			ccDevices = append(ccDevices, cd)
-		}
-	}
+
+	// Create DeviceLimit instances of the device
+	h := sha1.New()
+	h.Write([]byte(cdp.cds.Resource))
+	baseID := fmt.Sprintf("%x", h.Sum(nil))
+
+	for i := 0; i < cdp.cds.DeviceLimit; i++ {
+		cd := baseDevice // Copy the base structure
+		// For single-limit devices, ID is baseID. For multi-limit, append index.
+		if cdp.cds.DeviceLimit > 1 {
+			cd.ID = fmt.Sprintf("%s-%d", baseID, i)
+		} else {
+			cd.ID = baseID
+		}
+		ccDevices = append(ccDevices, cd)
+	}
+
 	return ccDevices, nil
 }
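
The ID scheme above replaces the old per-index `h.Sum(b)` digests with a single SHA-1 of the resource name, so different resources get distinct, deterministic IDs, with an index suffix only when more than one copy is advertised. A self-contained sketch of the same scheme (the resource name is an assumed example):

```go
package main

import (
	"crypto/sha1"
	"fmt"
)

// deviceIDs mirrors the ID generation in discoverCcDevices: one stable SHA-1
// of the resource name, with "-<index>" appended only when limit > 1.
func deviceIDs(resource string, limit int) []string {
	h := sha1.New()
	h.Write([]byte(resource))
	baseID := fmt.Sprintf("%x", h.Sum(nil))

	ids := make([]string, 0, limit)
	for i := 0; i < limit; i++ {
		if limit > 1 {
			ids = append(ids, fmt.Sprintf("%s-%d", baseID, i))
		} else {
			ids = append(ids, baseID)
		}
	}
	return ids
}

func main() {
	// "example.com/vtpm" is an illustrative resource name.
	for _, id := range deviceIDs("example.com/vtpm", 3) {
		fmt.Println(id)
	}
}
```

Note that `cd := baseDevice` is a shallow copy: every advertised instance shares the same underlying `DeviceSpecs` and `Mounts` slices, which is fine as long as they are only read after discovery.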

 func copyMeasurementFile(src string, dest string) error {
+	// get time for src
+	sourceInfo, err := os.Stat(src)
+	if err != nil {
+		return err
+	}
 	// copy out measurement
 	eventlogFile, err := os.ReadFile(src)
 	if err != nil {
@@ -201,11 +251,7 @@ func copyMeasurementFile(src string, dest string) error {
 	if err != nil {
 		return err
 	}
-	fileInfo, err := os.Stat(dest)
-	if err != nil {
-		return err
-	}
-	measurementFileLastUpdate = fileInfo.ModTime()
+	measurementFileLastUpdate = sourceInfo.ModTime()
 	return nil
 }

@@ -235,18 +281,28 @@ func (cdp *CcDevicePlugin) refreshDevices() (bool, error) {
 			devicesUnchange = false
 		}
 	}
-	if !devicesUnchange {
-		return false, nil
+	if len(ccDevices) != len(old) {
+		devicesUnchange = false
 	}
 
-	// Check if devices were removed.
+	if devicesUnchange {
+		return true, nil
+	}
+
+	// Log if devices were removed
 	for k := range old {
 		if _, ok := cdp.ccDevices[k]; !ok {
-			level.Warn(cdp.logger).Log("msg", "devices removed")
-			return false, nil
+			level.Info(cdp.logger).Log("msg", "device removed", "id", k)
 		}
 	}
-	return true, nil
+	// Log if devices were added
+	for k := range cdp.ccDevices {
+		if _, ok := old[k]; !ok {
+			level.Info(cdp.logger).Log("msg", "device added", "id", k)
+		}
+	}
+
+	return false, nil
 }
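
The refresh now reports additions and removals separately instead of returning on the first missing key. A toy sketch of that map diff (the device IDs are made up; the real code also compares per-device health and the map lengths):

```go
package main

import "fmt"

// diffDevices mirrors the add/remove detection in refreshDevices: keys present
// only in old were removed, keys present only in cur were added.
func diffDevices(old, cur map[string]struct{}) (removed, added []string) {
	for k := range old {
		if _, ok := cur[k]; !ok {
			removed = append(removed, k)
		}
	}
	for k := range cur {
		if _, ok := old[k]; !ok {
			added = append(added, k)
		}
	}
	return removed, added
}

func main() {
	old := map[string]struct{}{"id-a": {}, "id-b": {}}
	cur := map[string]struct{}{"id-b": {}, "id-c": {}}
	removed, added := diffDevices(old, cur)
	fmt.Println("removed:", removed, "added:", added)
}
```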

// Allocate assigns cc devices to a Pod.
@@ -267,19 +323,18 @@ func (cdp *CcDevicePlugin) Allocate(_ context.Context, req *v1beta1.AllocateRequ
 			if ccDevice.Health != v1beta1.Healthy {
 				return nil, fmt.Errorf("requested cc device is not healthy %q", id)
 			}
-			level.Info(cdp.logger).Log("msg", "adding device and measurement to Pod, device id is:"+id)
+			level.Info(cdp.logger).Log("msg", "adding device and measurement to Pod", "device id", id)
 
 			for _, ds := range ccDevice.DeviceSpecs {
-				level.Info(cdp.logger).Log("msg", "added ccDevice.deviceSpecs is:"+ds.String())
+				level.Debug(cdp.logger).Log("msg", "added ccDevice.deviceSpecs", "spec", ds.String())
 			}
 
 			for _, dm := range ccDevice.Mounts {
-				level.Info(cdp.logger).Log("msg", "added ccDevice.mounts is:"+dm.String())
+				level.Debug(cdp.logger).Log("msg", "added ccDevice.mounts", "mount", dm.String())
 			}
 
 			resp.Devices = append(resp.Devices, ccDevice.DeviceSpecs...)
 			resp.Mounts = append(resp.Mounts, ccDevice.Mounts...)
-
 		}
 		res.ContainerResponses = append(res.ContainerResponses, resp)
 	}
@@ -298,23 +353,26 @@ func (cdp *CcDevicePlugin) ListAndWatch(_ *v1beta1.Empty, stream v1beta1.DeviceP
 	if _, err := cdp.refreshDevices(); err != nil {
 		return err
 	}
-	refreshComplete := false
-	var err error
 
 	for {
-		if !refreshComplete {
-			res := new(v1beta1.ListAndWatchResponse)
-			for _, dev := range cdp.ccDevices {
-				res.Devices = append(res.Devices, &v1beta1.Device{ID: dev.ID, Health: dev.Health})
-			}
-			if err := stream.Send(res); err != nil {
-				return err
-			}
+		res := new(v1beta1.ListAndWatchResponse)
+		cdp.mu.Lock()
+		for _, dev := range cdp.ccDevices {
+			res.Devices = append(res.Devices, &v1beta1.Device{ID: dev.ID, Health: dev.Health})
 		}
-		<-time.After(deviceCheckInterval)
-		refreshComplete, err = cdp.refreshDevices()
-		if err != nil {
+		cdp.mu.Unlock()
+
+		if err := stream.Send(res); err != nil {
+			level.Error(cdp.logger).Log("msg", "failed to send ListAndWatchResponse", "error", err)
 			return err
 		}
+
+		<-time.After(deviceCheckInterval)
+
+		if _, err := cdp.refreshDevices(); err != nil {
+			level.Error(cdp.logger).Log("msg", "error during device refresh", "error", err)
+			// Don't return error immediately, try to continue
+		}
 	}
 }

Expand Down
Loading