Skip to content

Commit

Permalink
Merge branch 'upstream-add-redundant-device-info' into 'master'
Browse files Browse the repository at this point in the history
Optionally return list of device nodes in Allocate() call

See merge request nvidia/kubernetes/device-plugin!22
  • Loading branch information
klueska committed Apr 6, 2020
2 parents 412dcbd + ea604b2 commit 5a9f3c3
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 26 deletions.
3 changes: 3 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package main

import (
"flag"
"log"
"os"
"syscall"
Expand All @@ -37,6 +38,8 @@ func getAllPlugins() []*NvidiaDevicePlugin {
}

func main() {
flag.Parse()

log.Println("Loading NVML")
if err := nvml.Init(); err != nil {
log.Printf("Failed to initialize NVML: %s.", err)
Expand Down
29 changes: 17 additions & 12 deletions nvidia.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,14 @@ const (
allHealthChecks = "xids"
)

// Device wraps the kubelet-facing pluginapi.Device together with
// plugin-local data that is not part of the device plugin API message.
type Device struct {
// Embedded API device. buildDevice sets its ID and Health, and sets
// Topology when the NVML device reports a CPU affinity.
pluginapi.Device
// Path is copied from the NVML device's Path by buildDevice. It is used as
// both HostPath and ContainerPath of the DeviceSpec built in apiDeviceSpecs.
Path string
}

type ResourceManager interface {
Devices() []*pluginapi.Device
CheckHealth(stop <-chan interface{}, devices []*pluginapi.Device, unhealthy chan<- *pluginapi.Device)
Devices() []*Device
CheckHealth(stop <-chan interface{}, devices []*Device, unhealthy chan<- *Device)
}

// GpuDeviceManager is a ResourceManager with no fields of its own: its
// Devices method enumerates GPUs through NVML and its CheckHealth method
// forwards to checkHealth.
type GpuDeviceManager struct {}
Expand All @@ -48,29 +53,29 @@ func NewGpuDeviceManager() *GpuDeviceManager {
return &GpuDeviceManager{}
}

func (g *GpuDeviceManager) Devices() []*pluginapi.Device {
func (g *GpuDeviceManager) Devices() []*Device {
n, err := nvml.GetDeviceCount()
check(err)

var devs []*pluginapi.Device
var devs []*Device
for i := uint(0); i < n; i++ {
d, err := nvml.NewDeviceLite(i)
check(err)
devs = append(devs, buildPluginDevice(d))
devs = append(devs, buildDevice(d))
}

return devs
}

func (g *GpuDeviceManager) CheckHealth(stop <-chan interface{}, devices []*pluginapi.Device, unhealthy chan<- *pluginapi.Device) {
// CheckHealth implements ResourceManager by forwarding stop, devices and
// unhealthy unchanged to the package-level checkHealth function.
func (g *GpuDeviceManager) CheckHealth(stop <-chan interface{}, devices []*Device, unhealthy chan<- *Device) {
checkHealth(stop, devices, unhealthy)
}

func buildPluginDevice(d *nvml.Device) *pluginapi.Device {
dev := pluginapi.Device{
ID: d.UUID,
Health: pluginapi.Healthy,
}
func buildDevice(d *nvml.Device) *Device {
dev := Device{}
dev.ID = d.UUID
dev.Health = pluginapi.Healthy
dev.Path = d.Path
if d.CPUAffinity != nil {
dev.Topology = &pluginapi.TopologyInfo{
Nodes: []*pluginapi.NUMANode{
Expand All @@ -83,7 +88,7 @@ func buildPluginDevice(d *nvml.Device) *pluginapi.Device {
return &dev
}

func checkHealth(stop <-chan interface{}, devices []*pluginapi.Device, unhealthy chan<- *pluginapi.Device) {
func checkHealth(stop <-chan interface{}, devices []*Device, unhealthy chan<- *Device) {
disableHealthChecks := strings.ToLower(os.Getenv(envDisableHealthChecks))
if disableHealthChecks == "all" {
disableHealthChecks = allHealthChecks
Expand Down
77 changes: 63 additions & 14 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package main

import (
"flag"
"fmt"
"log"
"net"
Expand All @@ -30,6 +31,8 @@ import (
pluginapi "k8s.io/kubernetes/pkg/kubelet/apis/deviceplugin/v1beta1"
)

// passDeviceSpecs is the -pass-device-specs command-line flag (default false).
// When it is set, Allocate() also fills in the Devices field of each container
// response with the specs returned by apiDeviceSpecs. The value is only
// meaningful after flag.Parse() has run, which main() does on startup.
var passDeviceSpecs = flag.Bool("pass-device-specs", false, "pass the list of DeviceSpecs to the kubelet on Allocate()")

// NvidiaDevicePlugin implements the Kubernetes device plugin API
type NvidiaDevicePlugin struct {
ResourceManager
Expand All @@ -38,8 +41,8 @@ type NvidiaDevicePlugin struct {
socket string

server *grpc.Server
cachedDevices []*pluginapi.Device
health chan *pluginapi.Device
cachedDevices []*Device
health chan *Device
stop chan interface{}
}

Expand All @@ -57,14 +60,13 @@ func NewNvidiaDevicePlugin(resourceName string, resourceManager ResourceManager,
server: nil,
health: nil,
stop: nil,

}
}

func (m *NvidiaDevicePlugin) initialize() {
m.cachedDevices = m.Devices()
m.server = grpc.NewServer([]grpc.ServerOption{}...)
m.health = make(chan *pluginapi.Device)
m.health = make(chan *Device)
m.stop = make(chan interface{})
}

Expand Down Expand Up @@ -193,7 +195,7 @@ func (m *NvidiaDevicePlugin) GetDevicePluginOptions(context.Context, *pluginapi.

// ListAndWatch lists devices and updates that list according to the health status
func (m *NvidiaDevicePlugin) ListAndWatch(e *pluginapi.Empty, s pluginapi.DevicePlugin_ListAndWatchServer) error {
s.Send(&pluginapi.ListAndWatchResponse{Devices: m.cachedDevices})
s.Send(&pluginapi.ListAndWatchResponse{Devices: m.apiDevices()})

for {
select {
Expand All @@ -203,7 +205,7 @@ func (m *NvidiaDevicePlugin) ListAndWatch(e *pluginapi.Empty, s pluginapi.Device
// FIXME: there is no way to recover from the Unhealthy state.
d.Health = pluginapi.Unhealthy
log.Printf("'%s' device marked unhealthy: %s", m.resourceName, d.ID)
s.Send(&pluginapi.ListAndWatchResponse{Devices: m.cachedDevices})
s.Send(&pluginapi.ListAndWatchResponse{Devices: m.apiDevices()})
}
}
}
Expand All @@ -212,16 +214,19 @@ func (m *NvidiaDevicePlugin) ListAndWatch(e *pluginapi.Empty, s pluginapi.Device
func (m *NvidiaDevicePlugin) Allocate(ctx context.Context, reqs *pluginapi.AllocateRequest) (*pluginapi.AllocateResponse, error) {
responses := pluginapi.AllocateResponse{}
for _, req := range reqs.ContainerRequests {
for _, id := range req.DevicesIDs {
if !m.deviceExists(id) {
return nil, fmt.Errorf("invalid allocation request for '%s': unknown device: %s", m.resourceName, id)
}
}

response := pluginapi.ContainerAllocateResponse{
Envs: map[string]string{
m.allocateEnvvar: strings.Join(req.DevicesIDs, ","),
},
}

for _, id := range req.DevicesIDs {
if !m.deviceExists(m.cachedDevices, id) {
return nil, fmt.Errorf("invalid allocation request for '%s': unknown device: %s", m.resourceName, id)
}
if *passDeviceSpecs {
response.Devices = m.apiDeviceSpecs(req.DevicesIDs)
}

responses.ContainerResponses = append(responses.ContainerResponses, &response)
Expand Down Expand Up @@ -250,12 +255,56 @@ func (m *NvidiaDevicePlugin) dial(unixSocketPath string, timeout time.Duration)
return c, nil
}


func (m *NvidiaDevicePlugin) deviceExists(devs []*pluginapi.Device, id string) bool {
for _, d := range devs {
// deviceExists reports whether id equals the ID of any device currently held
// in m.cachedDevices.
func (m *NvidiaDevicePlugin) deviceExists(id string) bool {
	for i := range m.cachedDevices {
		if m.cachedDevices[i].ID == id {
			return true
		}
	}
	return false
}

// apiDevices returns, in cache order, a pointer to the pluginapi.Device
// embedded in every entry of m.cachedDevices. The pointers refer to the cached
// entries themselves (no copies are made), so a later change to a cached
// device's fields is visible through a previously returned slice. The result
// is nil when there are no cached devices.
func (m *NvidiaDevicePlugin) apiDevices() []*pluginapi.Device {
	var out []*pluginapi.Device
	for i := range m.cachedDevices {
		out = append(out, &m.cachedDevices[i].Device)
	}
	return out
}

// apiDeviceSpecs builds the list of DeviceSpecs for the device IDs in filter.
//
// The result starts with one spec for each entry of a fixed list of
// /dev/nvidia* nodes for which os.Stat succeeds; entries that cannot be
// stat'ed are left out. It then adds one spec per (cached device, filter
// entry) pair whose IDs match, walking m.cachedDevices in order. Every spec
// uses the same path on the host and in the container, with "rw" permissions.
func (m *NvidiaDevicePlugin) apiDeviceSpecs(filter []string) []*pluginapi.DeviceSpec {
	// newSpec exposes a host path at the identical path inside the container.
	newSpec := func(path string) *pluginapi.DeviceSpec {
		return &pluginapi.DeviceSpec{
			ContainerPath: path,
			HostPath:      path,
			Permissions:   "rw",
		}
	}

	var result []*pluginapi.DeviceSpec

	fixedPaths := []string{
		"/dev/nvidiactl",
		"/dev/nvidia-uvm",
		"/dev/nvidia-uvm-tools",
		"/dev/nvidia-modeset",
	}
	for _, node := range fixedPaths {
		// Any Stat error means the node is skipped.
		if _, err := os.Stat(node); err != nil {
			continue
		}
		result = append(result, newSpec(node))
	}

	for _, dev := range m.cachedDevices {
		for _, wanted := range filter {
			if dev.ID != wanted {
				continue
			}
			result = append(result, newSpec(dev.Path))
		}
	}

	return result
}

0 comments on commit 5a9f3c3

Please sign in to comment.