zlayer_agent/
gpu_sharing.rs1use std::collections::HashMap;
8use std::path::PathBuf;
9use std::sync::Arc;
10
11use tokio::process::Command;
12use tokio::sync::Mutex;
13use tracing::{debug, error, info, warn};
14
15#[derive(Debug)]
21pub struct MpsDaemonManager {
22 daemons: Arc<Mutex<HashMap<u32, MpsDaemonState>>>,
24 base_dir: PathBuf,
26}
27
/// Bookkeeping for a single GPU's MPS control daemon.
#[derive(Debug)]
struct MpsDaemonState {
    /// Number of `acquire` calls not yet balanced by a `release`.
    ref_count: u32,
    /// PID of the process spawned to start the daemon, when the OS reported one.
    pid: Option<u32>,
}
35
36impl MpsDaemonManager {
37 #[must_use]
42 pub fn new(base_dir: impl Into<PathBuf>) -> Self {
43 Self {
44 daemons: Arc::new(Mutex::new(HashMap::new())),
45 base_dir: base_dir.into(),
46 }
47 }
48
49 fn pipe_dir(&self, gpu_index: u32) -> PathBuf {
51 self.base_dir.join(format!("gpu{gpu_index}/pipe"))
52 }
53
54 fn log_dir(&self, gpu_index: u32) -> PathBuf {
56 self.base_dir.join(format!("gpu{gpu_index}/log"))
57 }
58
59 pub async fn acquire(&self, gpu_index: u32) -> Result<PathBuf, MpsError> {
71 let mut daemons = self.daemons.lock().await;
72
73 let state = daemons.entry(gpu_index).or_insert(MpsDaemonState {
74 ref_count: 0,
75 pid: None,
76 });
77
78 if state.ref_count > 0 {
79 state.ref_count += 1;
81 debug!(
82 gpu_index,
83 ref_count = state.ref_count,
84 "MPS daemon ref count incremented"
85 );
86 return Ok(self.pipe_dir(gpu_index));
87 }
88
89 let pipe_dir = self.pipe_dir(gpu_index);
91 let log_dir = self.log_dir(gpu_index);
92
93 tokio::fs::create_dir_all(&pipe_dir)
95 .await
96 .map_err(|e| MpsError::Setup(format!("Failed to create MPS pipe directory: {e}")))?;
97 tokio::fs::create_dir_all(&log_dir)
98 .await
99 .map_err(|e| MpsError::Setup(format!("Failed to create MPS log directory: {e}")))?;
100
101 info!(gpu_index, pipe_dir = %pipe_dir.display(), "Starting MPS control daemon");
102
103 let child = Command::new("nvidia-cuda-mps-control")
104 .arg("-d") .env("CUDA_VISIBLE_DEVICES", gpu_index.to_string())
106 .env("CUDA_MPS_PIPE_DIRECTORY", &pipe_dir)
107 .env("CUDA_MPS_LOG_DIRECTORY", &log_dir)
108 .spawn()
109 .map_err(|e| MpsError::Start(format!("Failed to spawn MPS daemon: {e}")))?;
110
111 let pid = child.id();
112 state.ref_count = 1;
113 state.pid = pid;
114
115 info!(gpu_index, ?pid, "MPS control daemon started");
116
117 Ok(pipe_dir)
118 }
119
120 pub async fn release(&self, gpu_index: u32) {
125 let mut daemons = self.daemons.lock().await;
126
127 if let Some(state) = daemons.get_mut(&gpu_index) {
128 state.ref_count = state.ref_count.saturating_sub(1);
129 debug!(
130 gpu_index,
131 ref_count = state.ref_count,
132 "MPS daemon ref count decremented"
133 );
134
135 if state.ref_count == 0 {
136 info!(
137 gpu_index,
138 "Stopping MPS control daemon (last container released)"
139 );
140 if let Err(e) = self.stop_daemon(gpu_index).await {
141 error!(gpu_index, error = %e, "Failed to stop MPS daemon cleanly");
142 }
143 daemons.remove(&gpu_index);
144 }
145 }
146 }
147
148 async fn stop_daemon(&self, gpu_index: u32) -> Result<(), MpsError> {
150 let pipe_dir = self.pipe_dir(gpu_index);
151
152 let output = Command::new("sh")
153 .arg("-c")
154 .arg(format!(
155 "echo quit | CUDA_MPS_PIPE_DIRECTORY={} nvidia-cuda-mps-control",
156 pipe_dir.display()
157 ))
158 .output()
159 .await
160 .map_err(|e| MpsError::Stop(format!("Failed to send quit to MPS daemon: {e}")))?;
161
162 if !output.status.success() {
163 let stderr = String::from_utf8_lossy(&output.stderr);
164 warn!(gpu_index, stderr = %stderr, "MPS daemon quit command returned non-zero");
165 }
166
167 Ok(())
168 }
169
170 #[must_use]
172 pub fn env_vars(&self, gpu_index: u32) -> Vec<(String, String)> {
173 vec![
174 (
175 "CUDA_MPS_PIPE_DIRECTORY".to_string(),
176 self.pipe_dir(gpu_index).to_string_lossy().to_string(),
177 ),
178 (
179 "CUDA_MPS_LOG_DIRECTORY".to_string(),
180 self.log_dir(gpu_index).to_string_lossy().to_string(),
181 ),
182 ]
183 }
184
185 pub async fn enable_time_slicing(gpu_index: u32) -> Result<(), MpsError> {
194 let output = Command::new("nvidia-smi")
195 .args(["-i", &gpu_index.to_string()])
196 .args(["-c", "DEFAULT"]) .output()
198 .await
199 .map_err(|e| MpsError::Setup(format!("Failed to set compute mode: {e}")))?;
200
201 if !output.status.success() {
202 let stderr = String::from_utf8_lossy(&output.stderr);
203 return Err(MpsError::Setup(format!(
204 "nvidia-smi set compute mode failed: {stderr}"
205 )));
206 }
207
208 info!(gpu_index, "Time-slicing enabled (compute mode = DEFAULT)");
209 Ok(())
210 }
211}
212
213#[derive(Debug, thiserror::Error)]
215pub enum MpsError {
216 #[error("MPS setup failed: {0}")]
218 Setup(String),
219 #[error("MPS daemon start failed: {0}")]
221 Start(String),
222 #[error("MPS daemon stop failed: {0}")]
224 Stop(String),
225}