zlayer_agent/gpu_sharing.rs

//! GPU sharing support via NVIDIA MPS (Multi-Process Service) and time-slicing.
//!
//! MPS lets multiple containers share a single GPU through a shared server
//! process, so their kernels can execute on the device concurrently.
//! Time-slicing provides simpler round-robin GPU sharing without concurrent
//! kernel execution.

use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;

use tokio::process::Command;
use tokio::sync::Mutex;
use tracing::{debug, error, info, warn};

/// Manages NVIDIA MPS daemon instances, one per physical GPU.
///
/// The MPS control daemon must run on the host for each GPU that uses MPS sharing.
/// This manager starts the daemon on first MPS container creation and stops it
/// when the last MPS container on that GPU exits.
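///
/// # Example
///
/// A minimal usage sketch, assuming the crate exposes this module as
/// `zlayer_agent::gpu_sharing` and that `nvidia-cuda-mps-control` is installed
/// on the host (illustrative only, not compiled as a doc test):
///
/// ```ignore
/// use zlayer_agent::gpu_sharing::MpsDaemonManager;
///
/// let manager = MpsDaemonManager::new("/var/run/zlayer/mps");
///
/// // First MPS container on GPU 0 starts the control daemon.
/// let pipe_dir = manager.acquire(0).await?;
/// // A second container on the same GPU only bumps the reference count.
/// let _ = manager.acquire(0).await?;
///
/// // The daemon is shut down once the last reference is released.
/// manager.release(0).await;
/// manager.release(0).await;
/// ```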
#[derive(Debug)]
pub struct MpsDaemonManager {
    /// Per-GPU state: GPU index -> daemon state
    daemons: Arc<Mutex<HashMap<u32, MpsDaemonState>>>,
    /// Base directory for MPS pipe and log files
    base_dir: PathBuf,
}

#[derive(Debug)]
struct MpsDaemonState {
    /// Number of containers currently using MPS on this GPU
    ref_count: u32,
    /// PID of the MPS control daemon process
    pid: Option<u32>,
}

impl MpsDaemonManager {
    /// Create a new MPS daemon manager.
    ///
    /// `base_dir` is the root directory for per-GPU MPS pipe and log directories
    /// (e.g. `/var/run/zlayer/mps`).
    #[must_use]
    pub fn new(base_dir: impl Into<PathBuf>) -> Self {
        Self {
            daemons: Arc::new(Mutex::new(HashMap::new())),
            base_dir: base_dir.into(),
        }
    }

    /// Pipe directory for a specific GPU index.
    fn pipe_dir(&self, gpu_index: u32) -> PathBuf {
        self.base_dir.join(format!("gpu{gpu_index}/pipe"))
    }

    /// Log directory for a specific GPU index.
    fn log_dir(&self, gpu_index: u32) -> PathBuf {
        self.base_dir.join(format!("gpu{gpu_index}/log"))
    }

    /// Start or attach to the MPS daemon for a GPU.
    ///
    /// If the daemon is already running for this GPU, increments the reference count.
    /// Otherwise, starts a new `nvidia-cuda-mps-control` daemon process.
    ///
    /// Returns the pipe directory path that should be injected as
    /// `CUDA_MPS_PIPE_DIRECTORY` into the container's environment.
    ///
    /// # Errors
    ///
    /// Returns an error if the MPS control daemon fails to start.
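    ///
    /// # Example
    ///
    /// A sketch of wiring the returned path into a container's environment
    /// (the `container_env` map here is illustrative, not part of this crate):
    ///
    /// ```ignore
    /// let pipe_dir = manager.acquire(gpu_index).await?;
    /// container_env.insert(
    ///     "CUDA_MPS_PIPE_DIRECTORY".to_string(),
    ///     pipe_dir.display().to_string(),
    /// );
    /// ```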
    pub async fn acquire(&self, gpu_index: u32) -> Result<PathBuf, MpsError> {
        let mut daemons = self.daemons.lock().await;

        let state = daemons.entry(gpu_index).or_insert(MpsDaemonState {
            ref_count: 0,
            pid: None,
        });

        if state.ref_count > 0 {
            // Daemon already running, just bump the reference count
            state.ref_count += 1;
            debug!(
                gpu_index,
                ref_count = state.ref_count,
                "MPS daemon ref count incremented"
            );
            return Ok(self.pipe_dir(gpu_index));
        }

        // Start a new MPS daemon for this GPU
        let pipe_dir = self.pipe_dir(gpu_index);
        let log_dir = self.log_dir(gpu_index);

        // Create directories
        tokio::fs::create_dir_all(&pipe_dir)
            .await
            .map_err(|e| MpsError::Setup(format!("Failed to create MPS pipe directory: {e}")))?;
        tokio::fs::create_dir_all(&log_dir)
            .await
            .map_err(|e| MpsError::Setup(format!("Failed to create MPS log directory: {e}")))?;

        info!(gpu_index, pipe_dir = %pipe_dir.display(), "Starting MPS control daemon");

        let child = Command::new("nvidia-cuda-mps-control")
            .arg("-d") // daemon mode
            .env("CUDA_VISIBLE_DEVICES", gpu_index.to_string())
            .env("CUDA_MPS_PIPE_DIRECTORY", &pipe_dir)
            .env("CUDA_MPS_LOG_DIRECTORY", &log_dir)
            .spawn()
            .map_err(|e| MpsError::Start(format!("Failed to spawn MPS daemon: {e}")))?;

        let pid = child.id();
        state.ref_count = 1;
        state.pid = pid;

        info!(gpu_index, ?pid, "MPS control daemon started");

        Ok(pipe_dir)
    }

    /// Release a reference to the MPS daemon for a GPU.
    ///
    /// Decrements the reference count. When it reaches zero, the MPS daemon
    /// is shut down by sending "quit" to the control pipe.
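    ///
    /// # Example
    ///
    /// Sketch: pair each successful [`acquire`](Self::acquire) with one
    /// `release` when the container exits.
    ///
    /// ```ignore
    /// manager.release(gpu_index).await; // stops the daemon if this was the last user
    /// ```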
    pub async fn release(&self, gpu_index: u32) {
        let mut daemons = self.daemons.lock().await;

        if let Some(state) = daemons.get_mut(&gpu_index) {
            state.ref_count = state.ref_count.saturating_sub(1);
            debug!(
                gpu_index,
                ref_count = state.ref_count,
                "MPS daemon ref count decremented"
            );

            if state.ref_count == 0 {
                info!(
                    gpu_index,
                    "Stopping MPS control daemon (last container released)"
                );
                if let Err(e) = self.stop_daemon(gpu_index).await {
                    error!(gpu_index, error = %e, "Failed to stop MPS daemon cleanly");
                }
                daemons.remove(&gpu_index);
            }
        }
    }

    /// Stop the MPS daemon by sending "quit" to the control pipe.
    async fn stop_daemon(&self, gpu_index: u32) -> Result<(), MpsError> {
        let pipe_dir = self.pipe_dir(gpu_index);

        let output = Command::new("sh")
            .arg("-c")
            .arg(format!(
                "echo quit | CUDA_MPS_PIPE_DIRECTORY={} nvidia-cuda-mps-control",
                pipe_dir.display()
            ))
            .output()
            .await
            .map_err(|e| MpsError::Stop(format!("Failed to send quit to MPS daemon: {e}")))?;

        if !output.status.success() {
            let stderr = String::from_utf8_lossy(&output.stderr);
            warn!(gpu_index, stderr = %stderr, "MPS daemon quit command returned non-zero");
        }

        Ok(())
    }

    /// Get the environment variables needed for a container using MPS on a specific GPU.
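    ///
    /// # Example
    ///
    /// Sketch: for GPU 0 with a base directory of `/var/run/zlayer/mps`, the
    /// returned pairs point at `/var/run/zlayer/mps/gpu0/pipe` and
    /// `/var/run/zlayer/mps/gpu0/log`.
    ///
    /// ```ignore
    /// let vars = manager.env_vars(0);
    /// assert_eq!(vars[0].0, "CUDA_MPS_PIPE_DIRECTORY");
    /// assert_eq!(vars[1].0, "CUDA_MPS_LOG_DIRECTORY");
    /// ```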
    #[must_use]
    pub fn env_vars(&self, gpu_index: u32) -> Vec<(String, String)> {
        vec![
            (
                "CUDA_MPS_PIPE_DIRECTORY".to_string(),
                self.pipe_dir(gpu_index).to_string_lossy().to_string(),
            ),
            (
                "CUDA_MPS_LOG_DIRECTORY".to_string(),
                self.log_dir(gpu_index).to_string_lossy().to_string(),
            ),
        ]
    }

    /// Configure time-slicing for a GPU via nvidia-smi.
    ///
    /// This sets the GPU to shared compute mode, allowing multiple processes
    /// to use it with round-robin scheduling.
    ///
    /// # Errors
    ///
    /// Returns an error if `nvidia-smi` fails.
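    ///
    /// # Example
    ///
    /// A sketch, assuming `nvidia-smi` is available on the host:
    ///
    /// ```ignore
    /// // Put GPU 1 into DEFAULT compute mode so multiple processes can share it.
    /// MpsDaemonManager::enable_time_slicing(1).await?;
    /// ```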
    pub async fn enable_time_slicing(gpu_index: u32) -> Result<(), MpsError> {
        let output = Command::new("nvidia-smi")
            .args(["-i", &gpu_index.to_string()])
            .args(["-c", "DEFAULT"]) // Set compute mode to default (shared)
            .output()
            .await
            .map_err(|e| MpsError::Setup(format!("Failed to set compute mode: {e}")))?;

        if !output.status.success() {
            let stderr = String::from_utf8_lossy(&output.stderr);
            return Err(MpsError::Setup(format!(
                "nvidia-smi set compute mode failed: {stderr}"
            )));
        }

        info!(gpu_index, "Time-slicing enabled (compute mode = DEFAULT)");
        Ok(())
    }
}

/// Errors from MPS daemon operations
#[derive(Debug, thiserror::Error)]
pub enum MpsError {
    /// Failed to set up MPS directories or configuration
    #[error("MPS setup failed: {0}")]
    Setup(String),
    /// Failed to start the MPS daemon
    #[error("MPS daemon start failed: {0}")]
    Start(String),
    /// Failed to stop the MPS daemon
    #[error("MPS daemon stop failed: {0}")]
    Stop(String),
}