Skip to main content

zlayer_agent/
cdi.rs

1//! Container Device Interface (CDI) support.
2//!
3//! CDI is a vendor-neutral mechanism for declaring and injecting devices
4//! into OCI containers. `ZLayer` discovers CDI specs from standard locations
5//! (`/etc/cdi/`, `/var/run/cdi/`) and applies them to container specs
6//! as an alternative to manual device passthrough.
7//!
8//! See: <https://github.com/cncf-tags/container-device-interface>
9
10use std::collections::HashMap;
11use std::path::{Path, PathBuf};
12
13use serde::{Deserialize, Serialize};
14use tracing::{debug, info, warn};
15
/// Standard CDI spec discovery directories, listed in the order they are scanned.
const CDI_SPEC_DIRS: &[&str] = &["/etc/cdi", "/var/run/cdi"];

/// Environment variable that extends the default CDI spec search path.
///
/// When set, its value is interpreted as a list of directories separated by
/// the platform path separator (`:` on Unix, `;` on Windows). The listed
/// directories do not replace the standard locations: they are scanned in
/// addition to them, after `/etc/cdi` and `/var/run/cdi`. Because a spec
/// loaded later replaces an earlier spec of the same kind, a spec found in
/// one of these directories takes precedence over a standard-location spec
/// of that kind.
pub const CDI_SPEC_DIRS_ENV: &str = "CDI_SPEC_DIRS";
25
/// Translate a short `GpuSpec.vendor` alias into its CDI kind.
///
/// `GpuSpec.vendor` holds a short alias (`"nvidia"`, `"amd"`, `"intel"`)
/// whereas CDI kinds are fully qualified (`vendor.tld/class`). The lookup is
/// an exact, case-sensitive comparison; any alias not in the table yields
/// `None`.
#[must_use]
pub fn vendor_to_cdi_kind(vendor: &str) -> Option<&'static str> {
    // (alias, fully-qualified CDI kind) pairs for the supported vendors.
    const VENDOR_KINDS: [(&str, &str); 3] = [
        ("nvidia", "nvidia.com/gpu"),
        ("amd", "amd.com/gpu"),
        ("intel", "intel.com/gpu"),
    ];

    VENDOR_KINDS
        .iter()
        .find(|(alias, _)| *alias == vendor)
        .map(|&(_, kind)| kind)
}
40
41/// A parsed CDI specification file
42#[derive(Debug, Clone, Serialize, Deserialize)]
43#[serde(rename_all = "camelCase")]
44pub struct CdiSpec {
45    /// CDI spec version (e.g. "0.6.0")
46    pub cdi_version: String,
47    /// Device vendor and class (e.g. "nvidia.com/gpu")
48    pub kind: String,
49    /// Devices declared by this spec
50    #[serde(default)]
51    pub devices: Vec<CdiDevice>,
52    /// Container edits applied to all devices of this kind
53    #[serde(default)]
54    pub container_edits: Option<CdiContainerEdits>,
55}
56
57/// A device within a CDI spec
58#[derive(Debug, Clone, Serialize, Deserialize)]
59#[serde(rename_all = "camelCase")]
60pub struct CdiDevice {
61    /// Device name (e.g. "0" for GPU 0)
62    pub name: String,
63    /// Container edits specific to this device
64    #[serde(default)]
65    pub container_edits: Option<CdiContainerEdits>,
66}
67
68/// Modifications to apply to the OCI container spec
69#[derive(Debug, Clone, Serialize, Deserialize, Default)]
70#[serde(rename_all = "camelCase")]
71pub struct CdiContainerEdits {
72    /// Environment variables to add
73    #[serde(default)]
74    pub env: Vec<String>,
75    /// Device nodes to create in the container
76    #[serde(default)]
77    pub device_nodes: Vec<CdiDeviceNode>,
78    /// Mounts to add
79    #[serde(default)]
80    pub mounts: Vec<CdiMount>,
81    /// Hooks to run
82    #[serde(default)]
83    pub hooks: Option<CdiHooks>,
84}
85
86/// A device node to create in the container
87#[derive(Debug, Clone, Serialize, Deserialize)]
88#[serde(rename_all = "camelCase")]
89pub struct CdiDeviceNode {
90    /// Path inside the container
91    pub path: String,
92    /// Host path (defaults to container path)
93    pub host_path: Option<String>,
94    /// Device type: "b" (block) or "c" (char)
95    #[serde(rename = "type")]
96    pub device_type: Option<String>,
97    /// Major device number
98    pub major: Option<i64>,
99    /// Minor device number
100    pub minor: Option<i64>,
101    /// File mode (e.g. 0o666)
102    #[serde(default)]
103    pub file_mode: Option<u32>,
104    /// Owner UID
105    pub uid: Option<u32>,
106    /// Owner GID
107    pub gid: Option<u32>,
108    /// Device permissions ("rwm")
109    pub permissions: Option<String>,
110}
111
112/// A mount to add to the container
113#[derive(Debug, Clone, Serialize, Deserialize)]
114#[serde(rename_all = "camelCase")]
115pub struct CdiMount {
116    /// Container path
117    pub container_path: String,
118    /// Host path
119    pub host_path: String,
120    /// Mount options
121    #[serde(default)]
122    pub options: Vec<String>,
123}
124
125/// OCI lifecycle hooks
126#[derive(Debug, Clone, Serialize, Deserialize, Default)]
127#[serde(rename_all = "camelCase")]
128pub struct CdiHooks {
129    #[serde(default)]
130    pub prestart: Vec<CdiHook>,
131    #[serde(default)]
132    pub create_runtime: Vec<CdiHook>,
133    #[serde(default)]
134    pub create_container: Vec<CdiHook>,
135    #[serde(default)]
136    pub start_container: Vec<CdiHook>,
137    #[serde(default)]
138    pub poststart: Vec<CdiHook>,
139    #[serde(default)]
140    pub poststop: Vec<CdiHook>,
141}
142
143/// A single OCI hook
144#[derive(Debug, Clone, Serialize, Deserialize)]
145pub struct CdiHook {
146    pub path: String,
147    #[serde(default)]
148    pub args: Vec<String>,
149    #[serde(default)]
150    pub env: Vec<String>,
151}
152
153/// Registry of all discovered CDI specs, indexed by fully-qualified device name.
154///
155/// A fully-qualified CDI device name has the format `vendor.com/class=device`,
156/// e.g. `nvidia.com/gpu=0`.
157#[derive(Debug, Default)]
158pub struct CdiRegistry {
159    /// All discovered specs, keyed by kind (e.g. "nvidia.com/gpu")
160    specs: HashMap<String, CdiSpec>,
161}
162
163impl CdiRegistry {
164    /// Discover and load CDI specs from the standard directories.
165    ///
166    /// Scans `/etc/cdi/` and `/var/run/cdi/` for `*.json` and `*.yaml` files,
167    /// parses them, and indexes them by kind. Honors the `CDI_SPEC_DIRS`
168    /// environment variable for an additional override search path.
169    pub fn discover() -> Self {
170        let mut dirs: Vec<PathBuf> = CDI_SPEC_DIRS.iter().map(PathBuf::from).collect();
171        if let Ok(env_dirs) = std::env::var(CDI_SPEC_DIRS_ENV) {
172            for entry in std::env::split_paths(&env_dirs) {
173                if !entry.as_os_str().is_empty() {
174                    dirs.push(entry);
175                }
176            }
177        }
178        Self::discover_from(&dirs)
179    }
180
181    /// Discover and load CDI specs from an explicit list of directories.
182    ///
183    /// This is primarily useful for tests where the standard system paths
184    /// are not appropriate. Missing directories are silently skipped.
185    pub fn discover_from<P: AsRef<Path>>(dirs: &[P]) -> Self {
186        let mut registry = Self::default();
187
188        for dir in dirs {
189            let dir_path = dir.as_ref();
190            if !dir_path.is_dir() {
191                debug!(dir = %dir_path.display(), "CDI spec directory does not exist, skipping");
192                continue;
193            }
194
195            let entries = match std::fs::read_dir(dir_path) {
196                Ok(e) => e,
197                Err(e) => {
198                    warn!(dir = %dir_path.display(), error = %e, "Failed to read CDI spec directory");
199                    continue;
200                }
201            };
202
203            for entry in entries.flatten() {
204                let path = entry.path();
205                let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
206                if ext != "json" && ext != "yaml" && ext != "yml" {
207                    continue;
208                }
209
210                match Self::load_spec(&path) {
211                    Ok(spec) => {
212                        info!(
213                            kind = %spec.kind,
214                            devices = spec.devices.len(),
215                            path = %path.display(),
216                            "Loaded CDI spec"
217                        );
218                        registry.specs.insert(spec.kind.clone(), spec);
219                    }
220                    Err(e) => {
221                        warn!(path = %path.display(), error = %e, "Failed to parse CDI spec");
222                    }
223                }
224            }
225        }
226
227        registry
228    }
229
230    /// Load a single CDI spec file.
231    fn load_spec(path: &Path) -> Result<CdiSpec, CdiError> {
232        let content = std::fs::read_to_string(path)
233            .map_err(|e| CdiError::Io(format!("{}: {e}", path.display())))?;
234
235        let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
236        if ext == "json" {
237            serde_json::from_str(&content)
238                .map_err(|e| CdiError::Parse(format!("{}: {e}", path.display())))
239        } else {
240            serde_yaml::from_str(&content)
241                .map_err(|e| CdiError::Parse(format!("{}: {e}", path.display())))
242        }
243    }
244
245    /// Look up a CDI spec by kind (e.g. "nvidia.com/gpu").
246    #[must_use]
247    pub fn get_spec(&self, kind: &str) -> Option<&CdiSpec> {
248        self.specs.get(kind)
249    }
250
251    /// Resolve a fully-qualified CDI device name to its container edits.
252    ///
253    /// Format: `vendor.com/class=device` (e.g. `nvidia.com/gpu=0`)
254    ///
255    /// Returns the merged container edits (global + device-specific) or None
256    /// if the device is not found.
257    #[must_use]
258    pub fn resolve_device(&self, qualified_name: &str) -> Option<CdiContainerEdits> {
259        let (kind, device_name) = qualified_name.split_once('=')?;
260        let spec = self.specs.get(kind)?;
261        let device = spec.devices.iter().find(|d| d.name == device_name)?;
262
263        // Merge global and device-specific edits
264        let mut merged = spec.container_edits.clone().unwrap_or_default();
265
266        if let Some(ref dev_edits) = device.container_edits {
267            merged.env.extend(dev_edits.env.clone());
268            merged.device_nodes.extend(dev_edits.device_nodes.clone());
269            merged.mounts.extend(dev_edits.mounts.clone());
270            if let Some(ref dev_hooks) = dev_edits.hooks {
271                let merged_hooks = merged.hooks.get_or_insert_with(CdiHooks::default);
272                merged_hooks.prestart.extend(dev_hooks.prestart.clone());
273                merged_hooks
274                    .create_runtime
275                    .extend(dev_hooks.create_runtime.clone());
276                merged_hooks
277                    .create_container
278                    .extend(dev_hooks.create_container.clone());
279                merged_hooks
280                    .start_container
281                    .extend(dev_hooks.start_container.clone());
282                merged_hooks.poststart.extend(dev_hooks.poststart.clone());
283                merged_hooks.poststop.extend(dev_hooks.poststop.clone());
284            }
285        }
286
287        Some(merged)
288    }
289
290    /// Resolve one or more device names for a given vendor kind into a
291    /// flat list of per-device container edits.
292    ///
293    /// The special device name `"all"` expands to every device declared in
294    /// the spec for the requested vendor — this mirrors the semantics of
295    /// `NVIDIA_VISIBLE_DEVICES=all` and matches the behavior of `nvidia-ctk`'s
296    /// CDI implementation.
297    ///
298    /// # Errors
299    ///
300    /// Returns [`CdiError::SpecMissing`] if no spec is loaded for the
301    /// requested kind. Returns [`CdiError::DeviceMissing`] if any of the
302    /// requested device names is not declared in the spec. Returns
303    /// [`CdiError::NoDevices`] if the request asks for `"all"` but the
304    /// spec contains no devices.
305    pub fn resolve_for_kind(
306        &self,
307        kind: &str,
308        device_names: &[String],
309    ) -> std::result::Result<Vec<CdiContainerEdits>, CdiError> {
310        let spec = self
311            .specs
312            .get(kind)
313            .ok_or_else(|| CdiError::SpecMissing(kind.to_string()))?;
314
315        // Expand the "all" alias to every declared device name (excluding
316        // any device that itself is literally named "all" — that's a
317        // sentinel device used by some vendors to express "use any GPU").
318        let expanded: Vec<String> = if device_names.iter().any(|n| n == "all") {
319            let names: Vec<String> = spec
320                .devices
321                .iter()
322                .filter(|d| d.name != "all")
323                .map(|d| d.name.clone())
324                .collect();
325            if names.is_empty() {
326                return Err(CdiError::NoDevices(kind.to_string()));
327            }
328            names
329        } else {
330            device_names.to_vec()
331        };
332
333        let mut out = Vec::with_capacity(expanded.len());
334        for name in &expanded {
335            let qualified = format!("{kind}={name}");
336            let edits = self
337                .resolve_device(&qualified)
338                .ok_or_else(|| CdiError::DeviceMissing {
339                    kind: kind.to_string(),
340                    device: name.clone(),
341                })?;
342            out.push(edits);
343        }
344        Ok(out)
345    }
346
347    /// Check if any CDI specs are available.
348    #[must_use]
349    pub fn is_empty(&self) -> bool {
350        self.specs.is_empty()
351    }
352
353    /// Get all available kinds.
354    pub fn kinds(&self) -> impl Iterator<Item = &str> {
355        self.specs.keys().map(String::as_str)
356    }
357
358    /// Generate a CDI spec for NVIDIA GPUs using nvidia-ctk.
359    ///
360    /// Runs `nvidia-ctk cdi generate` and returns the resulting spec,
361    /// or None if nvidia-ctk is not available.
362    pub async fn generate_nvidia_spec() -> Option<CdiSpec> {
363        let output = tokio::process::Command::new("nvidia-ctk")
364            .args(["cdi", "generate"])
365            .output()
366            .await
367            .ok()?;
368
369        if !output.status.success() {
370            let stderr = String::from_utf8_lossy(&output.stderr);
371            warn!("nvidia-ctk cdi generate failed: {stderr}");
372            return None;
373        }
374
375        let stdout = String::from_utf8_lossy(&output.stdout);
376        match serde_yaml::from_str(&stdout) {
377            Ok(spec) => {
378                info!("Generated NVIDIA CDI spec via nvidia-ctk");
379                Some(spec)
380            }
381            Err(e) => {
382                warn!("Failed to parse nvidia-ctk output: {e}");
383                None
384            }
385        }
386    }
387}
388
389/// Errors from CDI operations
390#[derive(Debug, thiserror::Error)]
391pub enum CdiError {
392    /// I/O error reading a CDI spec file
393    #[error("CDI I/O error: {0}")]
394    Io(String),
395    /// Failed to parse a CDI spec file
396    #[error("CDI parse error: {0}")]
397    Parse(String),
398    /// No CDI spec installed for the requested vendor/kind.
399    ///
400    /// Typically means the vendor's CDI generator has not been run on this
401    /// host (e.g. `nvidia-ctk cdi generate --output=/etc/cdi/nvidia.json`).
402    #[error("no CDI spec installed for kind '{0}' (run the vendor's CDI generator)")]
403    SpecMissing(String),
404    /// A requested device name is not declared in the CDI spec for the kind.
405    #[error("CDI device '{device}' not declared in spec for kind '{kind}'")]
406    DeviceMissing {
407        /// CDI kind (e.g. `nvidia.com/gpu`).
408        kind: String,
409        /// Device name that was requested but not found.
410        device: String,
411    },
412    /// A request for `"all"` devices resolved to an empty list because the
413    /// installed CDI spec declares no devices for the kind.
414    #[error("CDI spec for kind '{0}' declares no devices (host has no compatible hardware)")]
415    NoDevices(String),
416}
417
#[cfg(test)]
mod tests {
    use super::*;

    /// Spec with one real device ("0"), an "all" sentinel device, and
    /// spec-wide edits (env, device node, mount).
    fn sample_spec_json() -> &'static str {
        r#"{
            "cdiVersion": "0.6.0",
            "kind": "nvidia.com/gpu",
            "devices": [
                {
                    "name": "0",
                    "containerEdits": {
                        "env": ["NVIDIA_VISIBLE_DEVICES=0"],
                        "deviceNodes": [
                            {
                                "path": "/dev/nvidia0",
                                "hostPath": "/dev/nvidia0",
                                "type": "c",
                                "major": 195,
                                "minor": 0
                            }
                        ]
                    }
                },
                {
                    "name": "all",
                    "containerEdits": {
                        "env": ["NVIDIA_VISIBLE_DEVICES=all"]
                    }
                }
            ],
            "containerEdits": {
                "env": ["NVIDIA_DRIVER_CAPABILITIES=all"],
                "deviceNodes": [
                    {
                        "path": "/dev/nvidiactl",
                        "hostPath": "/dev/nvidiactl",
                        "type": "c"
                    }
                ],
                "mounts": [
                    {
                        "containerPath": "/usr/lib/x86_64-linux-gnu/libnvidia-ml.so.1",
                        "hostPath": "/usr/lib/x86_64-linux-gnu/libnvidia-ml.so.1",
                        "options": ["ro", "nosuid", "nodev", "bind"]
                    }
                ]
            }
        }"#
    }

    /// Build a registry that holds exactly the spec parsed from `json`,
    /// without touching the filesystem.
    fn registry_from_json(json: &str) -> CdiRegistry {
        let spec: CdiSpec = serde_json::from_str(json).unwrap();
        let mut registry = CdiRegistry::default();
        registry.specs.insert(spec.kind.clone(), spec);
        registry
    }

    #[test]
    fn parse_cdi_spec_json() {
        let spec: CdiSpec = serde_json::from_str(sample_spec_json()).unwrap();
        assert_eq!(spec.cdi_version, "0.6.0");
        assert_eq!(spec.kind, "nvidia.com/gpu");
        assert_eq!(spec.devices.len(), 2);
        assert_eq!(spec.devices[0].name, "0");

        let global_edits = spec.container_edits.as_ref().unwrap();
        assert_eq!(global_edits.env, vec!["NVIDIA_DRIVER_CAPABILITIES=all"]);
        assert_eq!(global_edits.device_nodes.len(), 1);
        assert_eq!(global_edits.mounts.len(), 1);
    }

    #[test]
    fn resolve_device_merges_edits() {
        let registry = registry_from_json(sample_spec_json());

        let edits = registry
            .resolve_device("nvidia.com/gpu=0")
            .expect("should resolve gpu 0");

        // Global env + device env
        assert!(edits
            .env
            .contains(&"NVIDIA_DRIVER_CAPABILITIES=all".to_string()));
        assert!(edits.env.contains(&"NVIDIA_VISIBLE_DEVICES=0".to_string()));

        // Global device node + device-specific device node
        assert_eq!(edits.device_nodes.len(), 2);

        // Global mount preserved
        assert_eq!(edits.mounts.len(), 1);
    }

    #[test]
    fn resolve_unknown_device_returns_none() {
        // The spec for the kind IS loaded, so this exercises the device
        // lookup rather than the missing-spec path.
        let registry = registry_from_json(sample_spec_json());
        assert!(registry.resolve_device("nvidia.com/gpu=99").is_none());
    }

    #[test]
    fn resolve_device_without_spec_returns_none() {
        let registry = CdiRegistry::default();
        assert!(registry.resolve_device("nvidia.com/gpu=0").is_none());
    }

    #[test]
    fn resolve_malformed_name_returns_none() {
        let registry = CdiRegistry::default();
        assert!(registry.resolve_device("no-equals-sign").is_none());
    }

    #[test]
    fn empty_registry() {
        let registry = CdiRegistry::default();
        assert!(registry.is_empty());
        assert_eq!(registry.kinds().count(), 0);
    }

    #[test]
    fn parse_cdi_spec_yaml() {
        let yaml = r#"
cdiVersion: "0.6.0"
kind: "vendor.com/net"
devices:
  - name: "eth0"
    containerEdits:
      env:
        - "NET_DEVICE=eth0"
"#;
        let spec: CdiSpec = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(spec.kind, "vendor.com/net");
        assert_eq!(spec.devices.len(), 1);
        assert_eq!(spec.devices[0].name, "eth0");
    }

    /// Spec with two devices ("0" and "1"), a device-level hook on "0", and
    /// no spec-wide edits.
    fn fixture_spec_with_hooks() -> &'static str {
        r#"{
            "cdiVersion": "0.6.0",
            "kind": "nvidia.com/gpu",
            "devices": [
                {
                    "name": "0",
                    "containerEdits": {
                        "env": ["NVIDIA_VISIBLE_DEVICES=0"],
                        "deviceNodes": [
                            {"path": "/dev/nvidia0", "type": "c", "major": 195, "minor": 0}
                        ],
                        "hooks": {
                            "createContainer": [{
                                "path": "/usr/bin/nvidia-container-runtime-hook",
                                "args": ["nvidia-container-runtime-hook", "prestart"]
                            }]
                        }
                    }
                },
                {
                    "name": "1",
                    "containerEdits": {
                        "env": ["NVIDIA_VISIBLE_DEVICES=1"],
                        "deviceNodes": [
                            {"path": "/dev/nvidia1", "type": "c", "major": 195, "minor": 1}
                        ]
                    }
                }
            ]
        }"#
    }

    /// Write the hooks fixture into a temp dir and discover it. The temp
    /// dir is returned so it stays alive for the duration of the test.
    fn registry_with_fixture_dir() -> (tempfile::TempDir, CdiRegistry) {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("nvidia.json");
        std::fs::write(&path, fixture_spec_with_hooks()).unwrap();
        let registry = CdiRegistry::discover_from(&[dir.path()]);
        (dir, registry)
    }

    #[test]
    fn discover_from_loads_specs() {
        let (_keep, registry) = registry_with_fixture_dir();
        assert_eq!(registry.kinds().count(), 1);
        assert!(registry.get_spec("nvidia.com/gpu").is_some());
    }

    #[test]
    fn discover_from_empty_dir_is_empty() {
        let dir = tempfile::tempdir().unwrap();
        let registry = CdiRegistry::discover_from(&[dir.path()]);
        assert!(registry.is_empty());
    }

    #[test]
    fn resolve_for_kind_returns_edits_per_device() {
        let (_keep, registry) = registry_with_fixture_dir();
        let edits = registry
            .resolve_for_kind("nvidia.com/gpu", &["0".to_string()])
            .expect("resolve gpu 0");
        assert_eq!(edits.len(), 1);
        assert!(edits[0].env.iter().any(|e| e == "NVIDIA_VISIBLE_DEVICES=0"));
        assert!(edits[0]
            .device_nodes
            .iter()
            .any(|d| d.path == "/dev/nvidia0"));
        let hooks = edits[0].hooks.as_ref().expect("hooks merged");
        assert_eq!(hooks.create_container.len(), 1);
    }

    #[test]
    fn resolve_for_kind_all_expands_to_every_device() {
        let (_keep, registry) = registry_with_fixture_dir();
        let edits = registry
            .resolve_for_kind("nvidia.com/gpu", &["all".to_string()])
            .expect("resolve all");
        assert_eq!(edits.len(), 2, "should expand to both '0' and '1'");
        let names: Vec<&str> = edits
            .iter()
            .flat_map(|e| e.env.iter())
            .filter(|s| s.starts_with("NVIDIA_VISIBLE_DEVICES="))
            .map(String::as_str)
            .collect();
        assert!(names.contains(&"NVIDIA_VISIBLE_DEVICES=0"));
        assert!(names.contains(&"NVIDIA_VISIBLE_DEVICES=1"));
    }

    #[test]
    fn resolve_for_kind_all_skips_sentinel_device() {
        // The sample spec declares "0" and a device literally named "all";
        // only the real device may appear in the expansion.
        let registry = registry_from_json(sample_spec_json());
        let edits = registry
            .resolve_for_kind("nvidia.com/gpu", &["all".to_string()])
            .expect("resolve all");
        assert_eq!(edits.len(), 1);
        assert!(edits[0].env.iter().any(|e| e == "NVIDIA_VISIBLE_DEVICES=0"));
        assert!(!edits[0]
            .env
            .iter()
            .any(|e| e == "NVIDIA_VISIBLE_DEVICES=all"));
    }

    #[test]
    fn resolve_for_kind_all_without_devices_errors() {
        let registry = registry_from_json(r#"{"cdiVersion": "0.6.0", "kind": "nvidia.com/gpu"}"#);
        let err = registry
            .resolve_for_kind("nvidia.com/gpu", &["all".to_string()])
            .unwrap_err();
        assert!(matches!(err, CdiError::NoDevices(ref k) if k == "nvidia.com/gpu"));
    }

    #[test]
    fn resolve_for_kind_missing_spec_errors() {
        let registry = CdiRegistry::default();
        let err = registry
            .resolve_for_kind("nvidia.com/gpu", &["0".to_string()])
            .unwrap_err();
        assert!(matches!(err, CdiError::SpecMissing(ref k) if k == "nvidia.com/gpu"));
    }

    #[test]
    fn resolve_for_kind_unknown_device_errors() {
        let (_keep, registry) = registry_with_fixture_dir();
        let err = registry
            .resolve_for_kind("nvidia.com/gpu", &["99".to_string()])
            .unwrap_err();
        assert!(matches!(
            err,
            CdiError::DeviceMissing { ref device, .. } if device == "99"
        ));
    }

    #[test]
    fn vendor_to_cdi_kind_maps_known_vendors() {
        assert_eq!(vendor_to_cdi_kind("nvidia"), Some("nvidia.com/gpu"));
        assert_eq!(vendor_to_cdi_kind("amd"), Some("amd.com/gpu"));
        assert_eq!(vendor_to_cdi_kind("intel"), Some("intel.com/gpu"));
        assert_eq!(vendor_to_cdi_kind("apple"), None);
    }
}