1use crate::errors::{Result, SandboxError};
4use nix::unistd::Pid;
5use std::fs;
6use std::io::Write;
7use std::path::{Path, PathBuf};
8
/// Default mount point of the unified (cgroup v2) hierarchy under which sandbox
/// cgroups are created; the `SANDBOX_CGROUP_ROOT` environment variable can
/// override it at runtime (see `cgroup_root_path`).
const CGROUP_V2_ROOT: &str = "/sys/fs/cgroup";
10
11#[derive(Debug, Clone, Default)]
13pub struct CgroupConfig {
14 pub memory_limit: Option<u64>,
16 pub cpu_weight: Option<u32>,
18 pub cpu_quota: Option<u64>,
20 pub cpu_period: Option<u64>,
22 pub max_pids: Option<u32>,
24}
25
26impl CgroupConfig {
27 pub fn with_memory(limit: u64) -> Self {
29 Self {
30 memory_limit: Some(limit),
31 ..Default::default()
32 }
33 }
34
35 pub fn with_cpu_quota(quota: u64, period: u64) -> Self {
37 Self {
38 cpu_quota: Some(quota),
39 cpu_period: Some(period),
40 ..Default::default()
41 }
42 }
43
44 pub fn validate(&self) -> Result<()> {
46 if let Some(limit) = self.memory_limit
47 && limit == 0
48 {
49 return Err(SandboxError::InvalidConfig(
50 "Memory limit must be greater than 0".to_string(),
51 ));
52 }
53
54 if let Some(weight) = self.cpu_weight
55 && (!(100..=10000).contains(&weight))
56 {
57 return Err(SandboxError::InvalidConfig(
58 "CPU weight must be between 100-10000".to_string(),
59 ));
60 }
61
62 Ok(())
63 }
64}
65
66pub struct Cgroup {
68 path: PathBuf,
69 pid: Pid,
70}
71
72fn cgroup_root_path() -> PathBuf {
73 std::env::var("SANDBOX_CGROUP_ROOT")
74 .map(PathBuf::from)
75 .unwrap_or_else(|_| PathBuf::from(CGROUP_V2_ROOT))
76}
77
78impl Cgroup {
79 pub fn new(name: &str, pid: Pid) -> Result<Self> {
81 let cgroup_path = cgroup_root_path().join(name);
82
83 fs::create_dir_all(&cgroup_path).map_err(|e| {
85 SandboxError::Cgroup(format!(
86 "Failed to create cgroup directory {}: {}",
87 cgroup_path.display(),
88 e
89 ))
90 })?;
91
92 ensure_controller_files(&cgroup_path)?;
93
94 Ok(Self {
95 path: cgroup_path,
96 pid,
97 })
98 }
99
100 pub fn apply_config(&self, config: &CgroupConfig) -> Result<()> {
102 config.validate()?;
103
104 if let Some(memory) = config.memory_limit {
105 self.set_memory_limit(memory)?;
106 }
107
108 if let Some(weight) = config.cpu_weight {
109 self.set_cpu_weight(weight)?;
110 }
111
112 if let Some(quota) = config.cpu_quota {
113 let period = config.cpu_period.unwrap_or(100000);
114 self.set_cpu_quota(quota, period)?;
115 }
116
117 if let Some(max_pids) = config.max_pids {
118 self.set_max_pids(max_pids)?;
119 }
120
121 Ok(())
122 }
123
124 pub fn add_process(&self, pid: Pid) -> Result<()> {
126 let procs_file = self.path.join("cgroup.procs");
127 self.write_file(&procs_file, &pid.as_raw().to_string())
128 }
129
130 fn set_memory_limit(&self, limit: u64) -> Result<()> {
132 let mem_file = self.path.join("memory.max");
133 self.write_file(&mem_file, &limit.to_string())
134 }
135
136 fn set_cpu_weight(&self, weight: u32) -> Result<()> {
138 let cpu_file = self.path.join("cpu.weight");
139 self.write_file(&cpu_file, &weight.to_string())
140 }
141
142 fn set_cpu_quota(&self, quota: u64, period: u64) -> Result<()> {
144 let quota_file = self.path.join("cpu.max");
145 let quota_str = if quota == u64::MAX {
146 "max".to_string()
147 } else {
148 format!("{} {}", quota, period)
149 };
150 self.write_file("a_file, "a_str)
151 }
152
153 fn set_max_pids(&self, max_pids: u32) -> Result<()> {
155 let pids_file = self.path.join("pids.max");
156 self.write_file(&pids_file, &max_pids.to_string())
157 }
158
159 pub fn get_memory_usage(&self) -> Result<u64> {
161 let mem_file = self.path.join("memory.current");
162 self.read_file_u64(&mem_file)
163 }
164
165 pub fn get_memory_limit(&self) -> Result<u64> {
167 let mem_file = self.path.join("memory.max");
168 self.read_file_u64(&mem_file)
169 }
170
171 pub fn get_cpu_usage(&self) -> Result<u64> {
173 let cpu_file = self.path.join("cpu.stat");
174 let content = fs::read_to_string(&cpu_file).map_err(|e| {
175 SandboxError::Cgroup(format!("Failed to read {}: {}", cpu_file.display(), e))
176 })?;
177
178 for line in content.lines() {
180 if line.starts_with("usage_usec") {
181 let parts: Vec<&str> = line.split_whitespace().collect();
182 if parts.len() >= 2 {
183 return parts[1].parse::<u64>().map_err(|e| {
184 SandboxError::Cgroup(format!("Failed to parse CPU usage: {}", e))
185 });
186 }
187 }
188 }
189
190 Ok(0)
191 }
192
193 pub fn exists(&self) -> bool {
195 self.path.exists()
196 }
197
198 pub fn pid(&self) -> Pid {
200 self.pid
201 }
202
203 pub fn delete(&self) -> Result<()> {
205 match fs::remove_dir(&self.path) {
207 Ok(()) => Ok(()),
208 Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
209 Ok(())
211 }
212 Err(e) => Err(SandboxError::Cgroup(format!(
213 "Failed to delete cgroup {}: {}",
214 self.path.display(),
215 e
216 ))),
217 }
218 }
219
220 fn write_file(&self, path: &Path, content: &str) -> Result<()> {
221 let mut file = fs::OpenOptions::new().write(true).open(path).map_err(|e| {
222 SandboxError::Cgroup(format!("Failed to open {}: {}", path.display(), e))
223 })?;
224
225 write!(file, "{}", content).map_err(|e| {
226 SandboxError::Cgroup(format!("Failed to write to {}: {}", path.display(), e))
227 })?;
228
229 Ok(())
230 }
231
232 fn read_file_u64(&self, path: &Path) -> Result<u64> {
233 let content = fs::read_to_string(path).map_err(|e| {
234 SandboxError::Cgroup(format!("Failed to read {}: {}", path.display(), e))
235 })?;
236
237 content
238 .trim()
239 .parse::<u64>()
240 .map_err(|e| SandboxError::Cgroup(format!("Failed to parse value: {}", e)))
241 }
242
243 #[cfg(test)]
244 pub(crate) fn for_testing(path: PathBuf) -> Self {
245 Self {
246 path,
247 pid: Pid::from_raw(0),
248 }
249 }
250}
251
252fn ensure_controller_files(path: &Path) -> Result<()> {
253 let files = [
254 ("memory.max", "max"),
255 ("memory.current", "0"),
256 ("cpu.weight", "100"),
257 ("cpu.max", "max 100000"),
258 ("cpu.stat", "usage_usec 0\n"),
259 ("pids.max", "max"),
260 ("cgroup.procs", ""),
261 ];
262
263 for (name, default_content) in files {
264 let file_path = path.join(name);
265 if !file_path.exists() {
266 fs::write(&file_path, default_content).map_err(|e| {
267 SandboxError::Cgroup(format!("Failed to create {}: {}", file_path.display(), e))
268 })?;
269 }
270 }
271
272 Ok(())
273}
274
275impl Drop for Cgroup {
276 fn drop(&mut self) {
277 let _ = self.delete();
279 }
280}
281
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;
    use std::ffi::OsString;
    use std::fs;
    use std::sync::{Mutex, MutexGuard};
    use tempfile::tempdir;

    /// Environment variable consulted by `cgroup_root_path`.
    const ROOT_ENV: &str = "SANDBOX_CGROUP_ROOT";

    /// Serialises every test in this module that reads or writes `ROOT_ENV`.
    ///
    /// The test harness runs tests on parallel threads. Mutating the process
    /// environment while another thread reads it is a data race, which is why
    /// `env::set_var` / `env::remove_var` are `unsafe`.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires `ENV_LOCK`, ignoring poisoning so one failing test does not
    /// cascade into unrelated failures.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// On drop, restores `ROOT_ENV` to the value captured at construction.
    /// This repairs the environment even if the test panics.
    ///
    /// Create it while holding `ENV_LOCK`, and declare it *after* the lock
    /// guard. Locals drop in reverse order, so the variable is restored before
    /// the lock is released.
    struct EnvRestore {
        previous: Option<OsString>,
    }

    impl EnvRestore {
        fn capture() -> Self {
            Self {
                previous: env::var_os(ROOT_ENV),
            }
        }
    }

    impl Drop for EnvRestore {
        fn drop(&mut self) {
            // SAFETY: every test in this module that touches ROOT_ENV holds
            // ENV_LOCK, so no other test thread here accesses it concurrently.
            unsafe {
                match self.previous.take() {
                    Some(value) => env::set_var(ROOT_ENV, value),
                    None => env::remove_var(ROOT_ENV),
                }
            }
        }
    }

    fn prepare_cgroup_dir() -> (tempfile::TempDir, std::path::PathBuf) {
        let tmp = tempdir().unwrap();
        let path = tmp.path().join("cgroup-test");
        fs::create_dir_all(&path).unwrap();
        for file in &[
            "memory.max",
            "memory.current",
            "cpu.weight",
            "cpu.max",
            "cpu.stat",
            "pids.max",
            "cgroup.procs",
        ] {
            // `path` was created above, so every file's parent already exists.
            fs::write(path.join(file), "0").unwrap();
        }
        fs::write(path.join("cpu.stat"), "usage_usec 0\n").unwrap();
        fs::write(path.join("memory.current"), "0\n").unwrap();
        (tmp, path)
    }

    #[test]
    fn test_cgroup_config_default() {
        let config = CgroupConfig::default();
        assert!(config.memory_limit.is_none());
        assert!(config.cpu_weight.is_none());
    }

    #[test]
    fn test_cgroup_config_with_memory() {
        let config = CgroupConfig::with_memory(100 * 1024 * 1024);
        assert_eq!(config.memory_limit, Some(100 * 1024 * 1024));
    }

    #[test]
    fn test_cgroup_config_with_cpu_quota() {
        let config = CgroupConfig::with_cpu_quota(50000, 100000);
        assert_eq!(config.cpu_quota, Some(50000));
        assert_eq!(config.cpu_period, Some(100000));
    }

    #[test]
    fn test_cgroup_config_validate() {
        let config = CgroupConfig::default();
        assert!(config.validate().is_ok());

        let bad_config = CgroupConfig {
            memory_limit: Some(0),
            ..Default::default()
        };
        assert!(bad_config.validate().is_err());

        let bad_cpu_config = CgroupConfig {
            cpu_weight: Some(50),
            ..Default::default()
        };
        assert!(bad_cpu_config.validate().is_err());

        let good_cpu_config = CgroupConfig {
            cpu_weight: Some(100),
            ..Default::default()
        };
        assert!(good_cpu_config.validate().is_ok());
    }

    #[test]
    fn test_cgroup_path_creation() {
        // `Cgroup::new` reads ROOT_ENV; hold the lock so the override test
        // cannot change it underneath this one.
        let _env = lock_env();
        let test_path = Path::new(CGROUP_V2_ROOT);
        if test_path.exists() {
            // Smoke test only. Creation may legitimately fail (e.g. without
            // write access to the hierarchy), so the result is not asserted.
            let result = Cgroup::new(
                "sandbox-test-delete-me",
                Pid::from_raw(std::process::id() as i32),
            );
            let _ = result;
        }
    }

    #[test]
    fn test_cgroup_apply_config_writes_files() {
        let (_tmp, path) = prepare_cgroup_dir();
        let cgroup = Cgroup::for_testing(path.clone());

        let config = CgroupConfig {
            memory_limit: Some(2048),
            cpu_weight: Some(500),
            cpu_quota: Some(50_000),
            cpu_period: Some(100_000),
            max_pids: Some(32),
        };

        cgroup.apply_config(&config).unwrap();

        assert_eq!(
            fs::read_to_string(path.join("memory.max")).unwrap().trim(),
            "2048"
        );
        assert_eq!(
            fs::read_to_string(path.join("cpu.weight")).unwrap().trim(),
            "500"
        );
        assert_eq!(
            fs::read_to_string(path.join("cpu.max")).unwrap().trim(),
            "50000 100000"
        );
        assert_eq!(
            fs::read_to_string(path.join("pids.max")).unwrap().trim(),
            "32"
        );
    }

    #[test]
    fn test_cgroup_add_process_writes_pid() {
        let (_tmp, path) = prepare_cgroup_dir();
        let cgroup = Cgroup::for_testing(path.clone());

        cgroup.add_process(Pid::from_raw(1234)).unwrap();
        assert_eq!(
            fs::read_to_string(path.join("cgroup.procs")).unwrap(),
            "1234"
        );
    }

    #[test]
    fn test_cgroup_resource_readers() {
        let (_tmp, path) = prepare_cgroup_dir();
        fs::write(path.join("memory.current"), "4096").unwrap();
        fs::write(path.join("cpu.stat"), "usage_usec 900\n").unwrap();
        let cgroup = Cgroup::for_testing(path.clone());

        assert_eq!(cgroup.get_memory_usage().unwrap(), 4096);
        assert_eq!(cgroup.get_cpu_usage().unwrap(), 900);
    }

    #[test]
    fn test_cgroup_delete_removes_directory() {
        let (tmp, path) = prepare_cgroup_dir();
        let cgroup = Cgroup::for_testing(path.clone());
        assert!(path.exists());
        for entry in fs::read_dir(&path).unwrap() {
            let entry = entry.unwrap();
            if entry.path().is_file() {
                fs::remove_file(entry.path()).unwrap();
            }
        }
        cgroup.delete().unwrap();
        assert!(!path.exists());
        drop(tmp);
    }

    #[test]
    fn test_cgroup_new_uses_env_override() {
        // Drop order (reverse of declaration): cg, tmp, restore, lock.
        let _env = lock_env();
        let _restore = EnvRestore::capture();
        let tmp = tempdir().unwrap();
        // SAFETY: ENV_LOCK is held, so no other test in this module reads or
        // writes ROOT_ENV concurrently.
        unsafe {
            env::set_var(ROOT_ENV, tmp.path());
        }

        let cg = Cgroup::new("env-test", Pid::from_raw(0)).unwrap();
        assert!(cg.exists());
        assert!(tmp.path().join("env-test").exists());
    }
}