1use std::fmt;
2use std::path::{Component, PathBuf};
3
4use crate::error::SemaError;
5
/// A set of sandbox capabilities, stored as a bitmask in a `u64`.
///
/// Each capability constant occupies its own bit; `ALL` and `STRICT` are
/// predefined combinations. Sets are combined with [`Caps::union`] and
/// queried with [`Caps::contains`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Caps(u64);

impl Caps {
    /// The empty set.
    pub const NONE: Caps = Caps(0);
    pub const FS_READ: Caps = Caps(1 << 0);
    pub const FS_WRITE: Caps = Caps(1 << 1);
    pub const SHELL: Caps = Caps(1 << 2);
    pub const NETWORK: Caps = Caps(1 << 3);
    pub const ENV_READ: Caps = Caps(1 << 4);
    pub const ENV_WRITE: Caps = Caps(1 << 5);
    pub const PROCESS: Caps = Caps(1 << 6);
    pub const LLM: Caps = Caps(1 << 7);
    pub const SERIAL: Caps = Caps(1 << 8);

    /// Union of all nine single-bit capabilities.
    pub const ALL: Caps = Caps(
        Self::FS_READ.0
            | Self::FS_WRITE.0
            | Self::SHELL.0
            | Self::NETWORK.0
            | Self::ENV_READ.0
            | Self::ENV_WRITE.0
            | Self::PROCESS.0
            | Self::LLM.0
            | Self::SERIAL.0,
    );

    /// Every capability except `FS_READ` and `ENV_READ`.
    pub const STRICT: Caps = Caps(
        Self::SHELL.0
            | Self::FS_WRITE.0
            | Self::NETWORK.0
            | Self::ENV_WRITE.0
            | Self::PROCESS.0
            | Self::LLM.0
            | Self::SERIAL.0,
    );

    /// Single table pairing every named set with its textual name, shared by
    /// [`Caps::name`], [`Caps::from_name`] and the `Display` impl so the
    /// mappings cannot drift apart.
    const NAMED: [(Caps, &'static str); 12] = [
        (Caps::NONE, "none"),
        (Caps::FS_READ, "fs-read"),
        (Caps::FS_WRITE, "fs-write"),
        (Caps::SHELL, "shell"),
        (Caps::NETWORK, "network"),
        (Caps::ENV_READ, "env-read"),
        (Caps::ENV_WRITE, "env-write"),
        (Caps::PROCESS, "process"),
        (Caps::LLM, "llm"),
        (Caps::SERIAL, "serial"),
        (Caps::ALL, "all"),
        (Caps::STRICT, "strict"),
    ];

    /// Returns `true` when every bit of `other` is also set in `self`.
    ///
    /// Because `NONE` has no bits, `x.contains(Caps::NONE)` is `true` for any `x`.
    pub const fn contains(self, other: Caps) -> bool {
        self.0 & other.0 == other.0
    }

    /// Returns the set containing the bits of both `self` and `other`.
    pub const fn union(self, other: Caps) -> Caps {
        Caps(self.0 | other.0)
    }

    /// Returns the name of this set if it is exactly one of the named
    /// constants, and `"unknown"` for any other bit pattern.
    pub fn name(self) -> &'static str {
        Self::NAMED
            .iter()
            .find(|entry| entry.0 == self)
            .map_or("unknown", |entry| entry.1)
    }

    /// Looks up a named set by the exact string produced by [`Caps::name`].
    ///
    /// Returns `None` for any unrecognised string, including `""` and `"unknown"`.
    pub fn from_name(s: &str) -> Option<Self> {
        Self::NAMED
            .iter()
            .find(|entry| entry.1 == s)
            .map(|entry| entry.0)
    }
}

/// Formats a named set as its name; any other set is written as a
/// comma-separated list of the single capabilities it contains, in bit order
/// (e.g. `shell,network`). Bits without a name are reported as a trailing
/// `unknown` entry.
impl fmt::Display for Caps {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = self.name();
        if label != "unknown" {
            return f.write_str(label);
        }

        let mut leftover = self.0;
        let mut wrote_any = false;
        for &(atom, atom_name) in Caps::NAMED.iter() {
            // Only single-bit entries are list members; skip NONE and the presets.
            if atom.0.count_ones() != 1 || !self.contains(atom) {
                continue;
            }
            if wrote_any {
                f.write_str(",")?;
            }
            f.write_str(atom_name)?;
            wrote_any = true;
            leftover &= !atom.0;
        }
        if leftover != 0 {
            if wrote_any {
                f.write_str(",")?;
            }
            f.write_str("unknown")?;
        }
        Ok(())
    }
}
93
94#[derive(Debug, Clone)]
95pub struct Sandbox {
96 pub denied: Caps,
97 allowed_paths: Option<Vec<PathBuf>>,
98}
99
/// Normalizes `path` purely lexically, without touching the filesystem.
///
/// `.` components are dropped and each `..` cancels the preceding normal
/// component. A `..` directly under the root (or a prefix) is ignored, since
/// the root has no parent. A `..` that cannot be cancelled on a relative path
/// is kept rather than discarded, so `../x` stays `../x` instead of turning
/// into `x` and looking like a path below the starting directory.
fn normalize_lexical(path: &std::path::Path) -> PathBuf {
    let mut result = PathBuf::new();
    for component in path.components() {
        match component {
            Component::CurDir => {}
            Component::ParentDir => match result.components().next_back() {
                Some(Component::Normal(_)) => {
                    result.pop();
                }
                // Nothing above the root / drive prefix to go up to.
                Some(Component::RootDir) | Some(Component::Prefix(_)) => {}
                // Relative path with nothing left to cancel: preserve the escape.
                Some(Component::ParentDir) | Some(Component::CurDir) | None => {
                    result.push(component);
                }
            },
            other => result.push(other),
        }
    }
    result
}
113
114impl Sandbox {
115 pub fn allow_all() -> Self {
116 Sandbox {
117 denied: Caps::NONE,
118 allowed_paths: None,
119 }
120 }
121
122 pub fn deny(caps: Caps) -> Self {
123 Sandbox {
124 denied: caps,
125 allowed_paths: None,
126 }
127 }
128
129 pub fn with_allowed_paths(mut self, paths: Vec<PathBuf>) -> Self {
130 self.allowed_paths = Some(
131 paths
132 .into_iter()
133 .map(|p| std::fs::canonicalize(&p).unwrap_or(p))
134 .collect(),
135 );
136 self
137 }
138
139 pub fn is_unrestricted(&self) -> bool {
140 self.denied == Caps::NONE && self.allowed_paths.is_none()
141 }
142
143 pub fn check(&self, required: Caps, fn_name: &str) -> Result<(), SemaError> {
144 if required == Caps::NONE {
148 return Ok(());
149 }
150 if self.denied.contains(required) {
151 Err(SemaError::PermissionDenied {
152 function: fn_name.to_string(),
153 capability: required.name().to_string(),
154 })
155 } else {
156 Ok(())
157 }
158 }
159
160 pub fn check_path(&self, path: &str, fn_name: &str) -> Result<(), SemaError> {
167 let allowed = match &self.allowed_paths {
168 Some(paths) => paths,
169 None => return Ok(()),
170 };
171 let p = std::path::Path::new(path);
172 let canonical = std::fs::canonicalize(p).unwrap_or_else(|_| {
173 if let Some(parent) = p.parent() {
174 if let Ok(canon_parent) = std::fs::canonicalize(parent) {
175 return canon_parent.join(p.file_name().unwrap_or_default());
176 }
177 }
178 let abs = if p.is_absolute() {
179 p.to_path_buf()
180 } else {
181 std::env::current_dir()
182 .unwrap_or_else(|_| PathBuf::from("."))
183 .join(p)
184 };
185 normalize_lexical(&abs)
186 });
187 for allowed_path in allowed {
188 if canonical.starts_with(allowed_path) {
189 return Ok(());
190 }
191 }
192 Err(SemaError::PathDenied {
193 function: fn_name.to_string(),
194 path: canonical.display().to_string(),
195 })
196 }
197
198 pub fn parse_allowed_paths(value: &str) -> Vec<PathBuf> {
199 value
200 .split(',')
201 .map(|s| s.trim())
202 .filter(|s| !s.is_empty())
203 .map(|s| {
204 let p = PathBuf::from(s);
205 std::fs::canonicalize(&p).unwrap_or(p)
206 })
207 .collect()
208 }
209
210 pub fn parse_cli(value: &str) -> Result<Self, String> {
211 match value {
212 "strict" => Ok(Sandbox::deny(Caps::STRICT)),
213 "all" => Ok(Sandbox::deny(Caps::ALL)),
214 other => {
215 let mut denied = Caps::NONE;
216 for part in other.split(',') {
217 let part = part.trim();
218 if part.is_empty() {
219 continue;
220 }
221 let name = part.strip_prefix("no-").unwrap_or(part);
222 match Caps::from_name(name) {
223 Some(cap) => denied = denied.union(cap),
224 None => return Err(format!("unknown capability: {name}")),
225 }
226 }
227 Ok(Sandbox::deny(denied))
228 }
229 }
230 }
231}
232
233impl Default for Sandbox {
234 fn default() -> Self {
235 Sandbox::allow_all()
236 }
237}
238
#[cfg(test)]
mod tests {
    use super::*;

    /// A directory under the system temp dir that is removed again on drop.
    ///
    /// The name includes the process id so concurrent test runs do not share
    /// fixtures, setup failures panic right away instead of being ignored, and
    /// cleanup also happens when an assertion fails.
    struct ScratchDir(std::path::PathBuf);

    impl ScratchDir {
        fn new(label: &str) -> Self {
            let dir = std::env::temp_dir()
                .join(format!("sema-sandbox-{label}-{}", std::process::id()));
            std::fs::create_dir_all(&dir).expect("failed to create scratch directory");
            ScratchDir(dir)
        }

        fn path(&self) -> &std::path::Path {
            &self.0
        }

        /// A sandbox whose only allowed path is this directory.
        fn sandbox(&self) -> Sandbox {
            Sandbox::allow_all().with_allowed_paths(vec![self.0.clone()])
        }
    }

    impl Drop for ScratchDir {
        fn drop(&mut self) {
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    #[test]
    fn test_caps_contains() {
        let all = Caps::FS_READ.union(Caps::FS_WRITE).union(Caps::SHELL);
        assert!(all.contains(Caps::FS_READ));
        assert!(all.contains(Caps::FS_WRITE));
        assert!(all.contains(Caps::SHELL));
        assert!(!all.contains(Caps::NETWORK));
        assert!(!all.contains(Caps::LLM));
    }

    #[test]
    fn test_caps_contains_none_is_always_true() {
        assert!(Caps::NONE.contains(Caps::NONE));
        assert!(Caps::ALL.contains(Caps::NONE));
    }

    #[test]
    fn test_caps_union() {
        let combined = Caps::SHELL.union(Caps::NETWORK);
        assert!(combined.contains(Caps::SHELL));
        assert!(combined.contains(Caps::NETWORK));
        assert!(!combined.contains(Caps::FS_READ));
    }

    #[test]
    fn test_caps_all_contains_every_cap() {
        assert!(Caps::ALL.contains(Caps::FS_READ));
        assert!(Caps::ALL.contains(Caps::FS_WRITE));
        assert!(Caps::ALL.contains(Caps::SHELL));
        assert!(Caps::ALL.contains(Caps::NETWORK));
        assert!(Caps::ALL.contains(Caps::ENV_READ));
        assert!(Caps::ALL.contains(Caps::ENV_WRITE));
        assert!(Caps::ALL.contains(Caps::PROCESS));
        assert!(Caps::ALL.contains(Caps::LLM));
        assert!(Caps::ALL.contains(Caps::SERIAL));
    }

    #[test]
    fn test_caps_strict_preset() {
        assert!(Caps::STRICT.contains(Caps::SHELL));
        assert!(Caps::STRICT.contains(Caps::FS_WRITE));
        assert!(Caps::STRICT.contains(Caps::NETWORK));
        assert!(Caps::STRICT.contains(Caps::ENV_WRITE));
        assert!(Caps::STRICT.contains(Caps::PROCESS));
        assert!(Caps::STRICT.contains(Caps::LLM));
        assert!(Caps::STRICT.contains(Caps::SERIAL));
        assert!(!Caps::STRICT.contains(Caps::FS_READ));
        assert!(!Caps::STRICT.contains(Caps::ENV_READ));
    }

    #[test]
    fn test_caps_name_roundtrip() {
        // Covers the presets as well as the single capabilities.
        let caps = [
            Caps::NONE,
            Caps::FS_READ,
            Caps::FS_WRITE,
            Caps::SHELL,
            Caps::NETWORK,
            Caps::ENV_READ,
            Caps::ENV_WRITE,
            Caps::PROCESS,
            Caps::LLM,
            Caps::SERIAL,
            Caps::ALL,
            Caps::STRICT,
        ];
        for cap in caps {
            let name = cap.name();
            assert_eq!(
                Caps::from_name(name),
                Some(cap),
                "roundtrip failed for {name}"
            );
        }
    }

    #[test]
    fn test_caps_from_name_unknown() {
        assert_eq!(Caps::from_name("garbage"), None);
        assert_eq!(Caps::from_name(""), None);
    }

    #[test]
    fn test_caps_display() {
        assert_eq!(format!("{}", Caps::SHELL), "shell");
        assert_eq!(format!("{}", Caps::NETWORK), "network");
        assert_eq!(format!("{}", Caps::FS_READ), "fs-read");
        assert_eq!(format!("{}", Caps::SERIAL), "serial");
    }

    #[test]
    fn test_sandbox_allow_all_is_unrestricted() {
        let sb = Sandbox::allow_all();
        assert!(sb.is_unrestricted());
    }

    #[test]
    fn test_sandbox_deny_is_restricted() {
        let sb = Sandbox::deny(Caps::SHELL);
        assert!(!sb.is_unrestricted());
    }

    #[test]
    fn test_sandbox_default_is_unrestricted() {
        let sb = Sandbox::default();
        assert!(sb.is_unrestricted());
    }

    #[test]
    fn test_sandbox_check_allowed() {
        let sb = Sandbox::deny(Caps::SHELL);
        assert!(sb.check(Caps::NETWORK, "http/get").is_ok());
        assert!(sb.check(Caps::FS_READ, "file/read").is_ok());
    }

    #[test]
    fn test_sandbox_check_denied() {
        let sb = Sandbox::deny(Caps::SHELL);
        let err = sb.check(Caps::SHELL, "shell").unwrap_err();
        assert!(err.to_string().contains("Permission denied"));
        assert!(err.to_string().contains("shell"));
    }

    #[test]
    fn test_sandbox_check_denied_error_format() {
        let sb = Sandbox::deny(Caps::NETWORK);
        let err = sb.check(Caps::NETWORK, "http/get").unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("http/get"),
            "should contain function name: {msg}"
        );
        assert!(
            msg.contains("network"),
            "should contain capability name: {msg}"
        );
    }

    #[test]
    fn test_sandbox_check_multiple_denied() {
        let sb = Sandbox::deny(Caps::SHELL.union(Caps::NETWORK));
        assert!(sb.check(Caps::SHELL, "shell").is_err());
        assert!(sb.check(Caps::NETWORK, "http/get").is_err());
        assert!(sb.check(Caps::FS_READ, "file/read").is_ok());
    }

    #[test]
    fn test_sandbox_parse_cli_strict() {
        let sb = Sandbox::parse_cli("strict").unwrap();
        assert!(sb.check(Caps::SHELL, "shell").is_err());
        assert!(sb.check(Caps::FS_WRITE, "file/write").is_err());
        assert!(sb.check(Caps::NETWORK, "http/get").is_err());
        assert!(sb.check(Caps::SERIAL, "serial/list").is_err());
        assert!(sb.check(Caps::FS_READ, "file/read").is_ok());
        assert!(sb.check(Caps::ENV_READ, "env").is_ok());
    }

    #[test]
    fn test_sandbox_parse_cli_all() {
        let sb = Sandbox::parse_cli("all").unwrap();
        assert!(sb.check(Caps::SHELL, "shell").is_err());
        assert!(sb.check(Caps::FS_READ, "file/read").is_err());
        assert!(sb.check(Caps::ENV_READ, "env").is_err());
        assert!(sb.check(Caps::SERIAL, "serial/list").is_err());
    }

    #[test]
    fn test_sandbox_parse_cli_no_prefix() {
        let sb = Sandbox::parse_cli("no-shell,no-network").unwrap();
        assert!(sb.check(Caps::SHELL, "shell").is_err());
        assert!(sb.check(Caps::NETWORK, "http/get").is_err());
        assert!(sb.check(Caps::FS_READ, "file/read").is_ok());
    }

    #[test]
    fn test_sandbox_parse_cli_without_no_prefix() {
        let sb = Sandbox::parse_cli("shell,network").unwrap();
        assert!(sb.check(Caps::SHELL, "shell").is_err());
        assert!(sb.check(Caps::NETWORK, "http/get").is_err());
        assert!(sb.check(Caps::FS_READ, "file/read").is_ok());
    }

    #[test]
    fn test_sandbox_parse_cli_single() {
        let sb = Sandbox::parse_cli("no-fs-write").unwrap();
        assert!(sb.check(Caps::FS_WRITE, "file/write").is_err());
        assert!(sb.check(Caps::FS_READ, "file/read").is_ok());
    }

    #[test]
    fn test_sandbox_parse_cli_with_spaces() {
        let sb = Sandbox::parse_cli("no-shell, no-network").unwrap();
        assert!(sb.check(Caps::SHELL, "shell").is_err());
        assert!(sb.check(Caps::NETWORK, "http/get").is_err());
    }

    #[test]
    fn test_sandbox_parse_cli_empty_parts() {
        let sb = Sandbox::parse_cli("no-shell,,no-network").unwrap();
        assert!(sb.check(Caps::SHELL, "shell").is_err());
        assert!(sb.check(Caps::NETWORK, "http/get").is_err());
    }

    #[test]
    fn test_sandbox_parse_cli_invalid() {
        assert!(Sandbox::parse_cli("no-bogus").is_err());
        assert!(Sandbox::parse_cli("no-shell,no-bogus").is_err());
    }

    #[test]
    fn test_check_path_none_allows_everything() {
        let sb = Sandbox::allow_all();
        assert!(sb.check_path("/etc/passwd", "file/read").is_ok());
        assert!(sb.check_path("relative.txt", "file/read").is_ok());
    }

    #[test]
    fn test_check_path_inside_allowed_dir() {
        let scratch = ScratchDir::new("inside");
        let sb = scratch.sandbox();
        let test_path = scratch.path().join("sema-test-file.txt");
        std::fs::write(&test_path, "test").expect("failed to write fixture file");
        assert!(sb
            .check_path(test_path.to_str().unwrap(), "file/read")
            .is_ok());
    }

    #[test]
    fn test_check_path_outside_allowed_dir() {
        let scratch = ScratchDir::new("outside");
        let sb = scratch.sandbox();
        let result = sb.check_path("/etc/hosts", "file/read");
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("Permission denied"), "{err}");
        assert!(
            err.to_string().contains("outside allowed directories"),
            "{err}"
        );
    }

    #[test]
    fn test_check_path_traversal_attempt() {
        let scratch = ScratchDir::new("traverse");
        let sb = scratch.sandbox();
        let evil = format!("{}/../../../etc/passwd", scratch.path().display());
        let result = sb.check_path(&evil, "file/read");
        assert!(result.is_err(), "path traversal should be denied");
    }

    #[test]
    fn test_check_path_multiple_allowed() {
        let dir_a = ScratchDir::new("a");
        let dir_b = ScratchDir::new("b");
        let sb = Sandbox::allow_all()
            .with_allowed_paths(vec![dir_a.path().to_path_buf(), dir_b.path().to_path_buf()]);
        let file_a = dir_a.path().join("ok.txt");
        std::fs::write(&file_a, "a").expect("failed to write fixture file");
        let file_b = dir_b.path().join("ok.txt");
        std::fs::write(&file_b, "b").expect("failed to write fixture file");
        assert!(sb.check_path(file_a.to_str().unwrap(), "file/read").is_ok());
        assert!(sb.check_path(file_b.to_str().unwrap(), "file/read").is_ok());
        assert!(sb.check_path("/etc/hosts", "file/read").is_err());
    }

    #[test]
    fn test_parse_allowed_paths() {
        let paths = Sandbox::parse_allowed_paths("/tmp, /var");
        assert_eq!(paths.len(), 2);
    }

    #[test]
    fn test_parse_allowed_paths_empty_parts() {
        let paths = Sandbox::parse_allowed_paths("/tmp,,/var,");
        assert_eq!(paths.len(), 2);
    }

    #[test]
    fn test_with_allowed_paths_makes_restricted() {
        let sb = Sandbox::allow_all().with_allowed_paths(vec![std::path::PathBuf::from("/tmp")]);
        assert!(!sb.is_unrestricted());
    }

    #[test]
    fn test_check_path_nonexistent_component_escape() {
        let scratch = ScratchDir::new("escape");
        let sb = scratch.sandbox();
        let evil = format!("{}/nonexistent/../../etc/passwd", scratch.path().display());
        let result = sb.check_path(&evil, "file/write");
        assert!(
            result.is_err(),
            "nonexistent component escape should be denied"
        );
    }

    #[test]
    fn test_check_path_relative_nonexistent_escape() {
        // The allowed directory is nested one level below the scratch root and
        // the path climbs out of it through a component that does not exist.
        let scratch = ScratchDir::new("rel-escape");
        let allowed = scratch.path().join("allowed");
        std::fs::create_dir_all(&allowed).expect("failed to create allowed directory");
        let sb = Sandbox::allow_all().with_allowed_paths(vec![allowed.clone()]);
        let evil = format!("{}/fake/../../../etc/hosts", allowed.display());
        let result = sb.check_path(&evil, "file/write");
        assert!(
            result.is_err(),
            "relative nonexistent escape should be denied"
        );
    }
}