1use std::fmt;
2use std::path::{Component, PathBuf};
3
4use crate::error::SemaError;
5
/// A set of capability flags stored as a bitmask.
///
/// Every associated constant except [`Caps::NONE`], [`Caps::ALL`] and
/// [`Caps::STRICT`] occupies exactly one bit. Sets are combined with
/// [`Caps::union`] and queried with [`Caps::contains`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Caps(u64);

impl Caps {
    /// The empty set; named `none`.
    pub const NONE: Caps = Caps(0);
    /// Single-bit flag named `fs-read`.
    pub const FS_READ: Caps = Caps(1 << 0);
    /// Single-bit flag named `fs-write`.
    pub const FS_WRITE: Caps = Caps(1 << 1);
    /// Single-bit flag named `shell`.
    pub const SHELL: Caps = Caps(1 << 2);
    /// Single-bit flag named `network`.
    pub const NETWORK: Caps = Caps(1 << 3);
    /// Single-bit flag named `env-read`.
    pub const ENV_READ: Caps = Caps(1 << 4);
    /// Single-bit flag named `env-write`.
    pub const ENV_WRITE: Caps = Caps(1 << 5);
    /// Single-bit flag named `process`.
    pub const PROCESS: Caps = Caps(1 << 6);
    /// Single-bit flag named `llm`.
    pub const LLM: Caps = Caps(1 << 7);
    /// Single-bit flag named `serial`.
    pub const SERIAL: Caps = Caps(1 << 8);

    /// Every single-bit flag, in ascending bit order. Used to spell out
    /// combinations that have no preset name.
    const FLAGS: [Caps; 9] = [
        Self::FS_READ,
        Self::FS_WRITE,
        Self::SHELL,
        Self::NETWORK,
        Self::ENV_READ,
        Self::ENV_WRITE,
        Self::PROCESS,
        Self::LLM,
        Self::SERIAL,
    ];

    /// The union of all nine single-bit flags; named `all`.
    pub const ALL: Caps = Caps(
        Self::FS_READ.0
            | Self::FS_WRITE.0
            | Self::SHELL.0
            | Self::NETWORK.0
            | Self::ENV_READ.0
            | Self::ENV_WRITE.0
            | Self::PROCESS.0
            | Self::LLM.0
            | Self::SERIAL.0,
    );

    /// Every flag except [`Caps::FS_READ`] and [`Caps::ENV_READ`]; named
    /// `strict`.
    pub const STRICT: Caps = Caps(
        Self::SHELL.0
            | Self::FS_WRITE.0
            | Self::NETWORK.0
            | Self::ENV_WRITE.0
            | Self::PROCESS.0
            | Self::LLM.0
            | Self::SERIAL.0,
    );

    /// Returns `true` when every bit set in `other` is also set in `self`.
    ///
    /// Because the empty set has no bits, `contains(Caps::NONE)` is `true`
    /// for every `self`.
    pub const fn contains(self, other: Caps) -> bool {
        self.0 & other.0 == other.0
    }

    /// Returns the set holding every bit of `self` and every bit of `other`.
    pub const fn union(self, other: Caps) -> Caps {
        Caps(self.0 | other.0)
    }

    /// Looks up the fixed name of a single flag or preset, if `self` is
    /// exactly one of them.
    fn known_name(self) -> Option<&'static str> {
        let label = match self {
            Caps::NONE => "none",
            Caps::FS_READ => "fs-read",
            Caps::FS_WRITE => "fs-write",
            Caps::SHELL => "shell",
            Caps::NETWORK => "network",
            Caps::ENV_READ => "env-read",
            Caps::ENV_WRITE => "env-write",
            Caps::PROCESS => "process",
            Caps::LLM => "llm",
            Caps::SERIAL => "serial",
            Caps::ALL => "all",
            Caps::STRICT => "strict",
            _ => return None,
        };
        Some(label)
    }

    /// Returns the fixed name of a single flag or preset, or `"unknown"` for
    /// any other combination of bits.
    pub fn name(self) -> &'static str {
        self.known_name().unwrap_or("unknown")
    }

    /// Parses a name produced by [`Caps::name`] back into its value.
    ///
    /// Returns `None` for any string that is not one of the fixed names
    /// (including the empty string and `"unknown"`).
    pub fn from_name(s: &str) -> Option<Self> {
        match s {
            "none" => Some(Caps::NONE),
            "fs-read" => Some(Caps::FS_READ),
            "fs-write" => Some(Caps::FS_WRITE),
            "shell" => Some(Caps::SHELL),
            "network" => Some(Caps::NETWORK),
            "env-read" => Some(Caps::ENV_READ),
            "env-write" => Some(Caps::ENV_WRITE),
            "process" => Some(Caps::PROCESS),
            "llm" => Some(Caps::LLM),
            "serial" => Some(Caps::SERIAL),
            "all" => Some(Caps::ALL),
            "strict" => Some(Caps::STRICT),
            _ => None,
        }
    }
}

/// Formats a capability set for display.
///
/// A single flag or preset is written as its fixed name. Any other
/// combination is written as the comma-separated names of the flags it
/// contains, in ascending bit order (for example `shell,network`), so the
/// output identifies the members instead of collapsing to `unknown`.
impl fmt::Display for Caps {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if let Some(label) = self.known_name() {
            return f.write_str(label);
        }
        let mut wrote_any = false;
        for flag in Caps::FLAGS.iter() {
            if self.contains(*flag) {
                if wrote_any {
                    f.write_str(",")?;
                }
                f.write_str(flag.name())?;
                wrote_any = true;
            }
        }
        if !wrote_any {
            // Only bits outside the declared flags are set; nothing to list.
            f.write_str("unknown")?;
        }
        Ok(())
    }
}
93
/// Restrictions consulted by `Sandbox::check` and `Sandbox::check_path`: a
/// set of denied capabilities plus an optional allow-list of filesystem paths.
#[derive(Debug, Clone)]
pub struct Sandbox {
    /// Capabilities that `Sandbox::check` rejects.
    pub denied: Caps,
    /// Path prefixes accepted by `Sandbox::check_path`; `None` means no path
    /// restriction is applied.
    allowed_paths: Option<Vec<PathBuf>>,
}
99
/// Resolves `.` and `..` components of `path` purely textually, without
/// touching the filesystem (so symlinks are not followed).
///
/// * `.` components are dropped.
/// * `..` removes the preceding normal component when there is one.
/// * `..` directly after a root or prefix is dropped, since there is nothing
///   above the root.
/// * `..` with nothing before it (a relative path that climbs above its
///   starting point) is kept, so `../a` stays `../a` instead of silently
///   becoming `a`, which would name a different location.
fn normalize_lexical(path: &std::path::Path) -> PathBuf {
    let mut normalized = PathBuf::new();
    for component in path.components() {
        match component {
            Component::CurDir => {}
            Component::ParentDir => match normalized.components().next_back() {
                Some(Component::Normal(_)) => {
                    normalized.pop();
                }
                // Cannot climb above the root / drive prefix.
                Some(Component::RootDir) | Some(Component::Prefix(_)) => {}
                // Nothing left to cancel: preserve the upward step.
                Some(Component::ParentDir) | Some(Component::CurDir) | None => {
                    normalized.push(component);
                }
            },
            other => normalized.push(other),
        }
    }
    normalized
}
113
114impl Sandbox {
115 pub fn allow_all() -> Self {
116 Sandbox {
117 denied: Caps::NONE,
118 allowed_paths: None,
119 }
120 }
121
122 pub fn deny(caps: Caps) -> Self {
123 Sandbox {
124 denied: caps,
125 allowed_paths: None,
126 }
127 }
128
129 pub fn with_allowed_paths(mut self, paths: Vec<PathBuf>) -> Self {
130 self.allowed_paths = Some(
131 paths
132 .into_iter()
133 .map(|p| std::fs::canonicalize(&p).unwrap_or(p))
134 .collect(),
135 );
136 self
137 }
138
139 pub fn is_unrestricted(&self) -> bool {
140 self.denied == Caps::NONE && self.allowed_paths.is_none()
141 }
142
143 pub fn check(&self, required: Caps, fn_name: &str) -> Result<(), SemaError> {
144 if self.denied.contains(required) {
145 Err(SemaError::PermissionDenied {
146 function: fn_name.to_string(),
147 capability: required.name().to_string(),
148 })
149 } else {
150 Ok(())
151 }
152 }
153
154 pub fn check_path(&self, path: &str, fn_name: &str) -> Result<(), SemaError> {
161 let allowed = match &self.allowed_paths {
162 Some(paths) => paths,
163 None => return Ok(()),
164 };
165 let p = std::path::Path::new(path);
166 let canonical = std::fs::canonicalize(p).unwrap_or_else(|_| {
167 if let Some(parent) = p.parent() {
168 if let Ok(canon_parent) = std::fs::canonicalize(parent) {
169 return canon_parent.join(p.file_name().unwrap_or_default());
170 }
171 }
172 let abs = if p.is_absolute() {
173 p.to_path_buf()
174 } else {
175 std::env::current_dir()
176 .unwrap_or_else(|_| PathBuf::from("."))
177 .join(p)
178 };
179 normalize_lexical(&abs)
180 });
181 for allowed_path in allowed {
182 if canonical.starts_with(allowed_path) {
183 return Ok(());
184 }
185 }
186 Err(SemaError::PathDenied {
187 function: fn_name.to_string(),
188 path: canonical.display().to_string(),
189 })
190 }
191
192 pub fn parse_allowed_paths(value: &str) -> Vec<PathBuf> {
193 value
194 .split(',')
195 .map(|s| s.trim())
196 .filter(|s| !s.is_empty())
197 .map(|s| {
198 let p = PathBuf::from(s);
199 std::fs::canonicalize(&p).unwrap_or(p)
200 })
201 .collect()
202 }
203
204 pub fn parse_cli(value: &str) -> Result<Self, String> {
205 match value {
206 "strict" => Ok(Sandbox::deny(Caps::STRICT)),
207 "all" => Ok(Sandbox::deny(Caps::ALL)),
208 other => {
209 let mut denied = Caps::NONE;
210 for part in other.split(',') {
211 let part = part.trim();
212 if part.is_empty() {
213 continue;
214 }
215 let name = part.strip_prefix("no-").unwrap_or(part);
216 match Caps::from_name(name) {
217 Some(cap) => denied = denied.union(cap),
218 None => return Err(format!("unknown capability: {name}")),
219 }
220 }
221 Ok(Sandbox::deny(denied))
222 }
223 }
224 }
225}
226
227impl Default for Sandbox {
228 fn default() -> Self {
229 Sandbox::allow_all()
230 }
231}
232
// Unit tests for `Caps` (bit-set operations, naming) and `Sandbox`
// (capability checks, CLI parsing, path allow-list checks).
//
// NOTE(review): the path tests create fixed-name entries under
// `std::env::temp_dir()` and use Unix-style paths such as `/etc/hosts`;
// they presumably assume a Unix-like host — confirm before running elsewhere.
#[cfg(test)]
mod tests {
    use super::*;

    // A union contains each of its members and nothing else.
    #[test]
    fn test_caps_contains() {
        let all = Caps::FS_READ.union(Caps::FS_WRITE).union(Caps::SHELL);
        assert!(all.contains(Caps::FS_READ));
        assert!(all.contains(Caps::FS_WRITE));
        assert!(all.contains(Caps::SHELL));
        assert!(!all.contains(Caps::NETWORK));
        assert!(!all.contains(Caps::LLM));
    }

    // `contains` tests `self & other == other`, which holds for every `self`
    // when `other` is the empty set.
    #[test]
    fn test_caps_contains_none_is_always_true() {
        assert!(Caps::NONE.contains(Caps::NONE));
        assert!(Caps::ALL.contains(Caps::NONE));
    }

    // `union` adds the second operand's flags without introducing others.
    #[test]
    fn test_caps_union() {
        let combined = Caps::SHELL.union(Caps::NETWORK);
        assert!(combined.contains(Caps::SHELL));
        assert!(combined.contains(Caps::NETWORK));
        assert!(!combined.contains(Caps::FS_READ));
    }

    // `ALL` must include every individually declared flag.
    #[test]
    fn test_caps_all_contains_every_cap() {
        assert!(Caps::ALL.contains(Caps::FS_READ));
        assert!(Caps::ALL.contains(Caps::FS_WRITE));
        assert!(Caps::ALL.contains(Caps::SHELL));
        assert!(Caps::ALL.contains(Caps::NETWORK));
        assert!(Caps::ALL.contains(Caps::ENV_READ));
        assert!(Caps::ALL.contains(Caps::ENV_WRITE));
        assert!(Caps::ALL.contains(Caps::PROCESS));
        assert!(Caps::ALL.contains(Caps::LLM));
        assert!(Caps::ALL.contains(Caps::SERIAL));
    }

    // `STRICT` includes everything except the two read-only flags.
    #[test]
    fn test_caps_strict_preset() {
        assert!(Caps::STRICT.contains(Caps::SHELL));
        assert!(Caps::STRICT.contains(Caps::FS_WRITE));
        assert!(Caps::STRICT.contains(Caps::NETWORK));
        assert!(Caps::STRICT.contains(Caps::ENV_WRITE));
        assert!(Caps::STRICT.contains(Caps::PROCESS));
        assert!(Caps::STRICT.contains(Caps::LLM));
        assert!(Caps::STRICT.contains(Caps::SERIAL));
        // The read-only flags are deliberately left out of the preset.
        assert!(!Caps::STRICT.contains(Caps::FS_READ));
        assert!(!Caps::STRICT.contains(Caps::ENV_READ));
    }

    // `name` followed by `from_name` must return the original single flag.
    #[test]
    fn test_caps_name_roundtrip() {
        let caps = [
            Caps::FS_READ,
            Caps::FS_WRITE,
            Caps::SHELL,
            Caps::NETWORK,
            Caps::ENV_READ,
            Caps::ENV_WRITE,
            Caps::PROCESS,
            Caps::LLM,
            Caps::SERIAL,
        ];
        for cap in caps {
            let name = cap.name();
            assert_eq!(
                Caps::from_name(name),
                Some(cap),
                "roundtrip failed for {name}"
            );
        }
    }

    // Unrecognised and empty names are rejected.
    #[test]
    fn test_caps_from_name_unknown() {
        assert_eq!(Caps::from_name("garbage"), None);
        assert_eq!(Caps::from_name(""), None);
    }

    // `Display` for a single flag prints the same text as `name`.
    #[test]
    fn test_caps_display() {
        assert_eq!(format!("{}", Caps::SHELL), "shell");
        assert_eq!(format!("{}", Caps::NETWORK), "network");
        assert_eq!(format!("{}", Caps::FS_READ), "fs-read");
        assert_eq!(format!("{}", Caps::SERIAL), "serial");
    }

    #[test]
    fn test_sandbox_allow_all_is_unrestricted() {
        let sb = Sandbox::allow_all();
        assert!(sb.is_unrestricted());
    }

    #[test]
    fn test_sandbox_deny_is_restricted() {
        let sb = Sandbox::deny(Caps::SHELL);
        assert!(!sb.is_unrestricted());
    }

    #[test]
    fn test_sandbox_default_is_unrestricted() {
        let sb = Sandbox::default();
        assert!(sb.is_unrestricted());
    }

    // Capabilities that are not in the denied set pass the check.
    #[test]
    fn test_sandbox_check_allowed() {
        let sb = Sandbox::deny(Caps::SHELL);
        assert!(sb.check(Caps::NETWORK, "http/get").is_ok());
        assert!(sb.check(Caps::FS_READ, "file/read").is_ok());
    }

    // NOTE(review): the "Permission denied" wording comes from `SemaError`'s
    // `Display` impl, which is defined outside this file.
    #[test]
    fn test_sandbox_check_denied() {
        let sb = Sandbox::deny(Caps::SHELL);
        let err = sb.check(Caps::SHELL, "shell").unwrap_err();
        assert!(err.to_string().contains("Permission denied"));
        assert!(err.to_string().contains("shell"));
    }

    // The rendered error must mention both the function and the capability.
    #[test]
    fn test_sandbox_check_denied_error_format() {
        let sb = Sandbox::deny(Caps::NETWORK);
        let err = sb.check(Caps::NETWORK, "http/get").unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("http/get"),
            "should contain function name: {msg}"
        );
        assert!(
            msg.contains("network"),
            "should contain capability name: {msg}"
        );
    }

    #[test]
    fn test_sandbox_check_multiple_denied() {
        let sb = Sandbox::deny(Caps::SHELL.union(Caps::NETWORK));
        assert!(sb.check(Caps::SHELL, "shell").is_err());
        assert!(sb.check(Caps::NETWORK, "http/get").is_err());
        assert!(sb.check(Caps::FS_READ, "file/read").is_ok());
    }

    // "strict" denies the `STRICT` preset, so the read-only flags still pass.
    #[test]
    fn test_sandbox_parse_cli_strict() {
        let sb = Sandbox::parse_cli("strict").unwrap();
        assert!(sb.check(Caps::SHELL, "shell").is_err());
        assert!(sb.check(Caps::FS_WRITE, "file/write").is_err());
        assert!(sb.check(Caps::NETWORK, "http/get").is_err());
        assert!(sb.check(Caps::SERIAL, "serial/list").is_err());
        // Read-only flags are not part of `STRICT`.
        assert!(sb.check(Caps::FS_READ, "file/read").is_ok());
        assert!(sb.check(Caps::ENV_READ, "env").is_ok());
    }

    // "all" denies every flag, including the read-only ones.
    #[test]
    fn test_sandbox_parse_cli_all() {
        let sb = Sandbox::parse_cli("all").unwrap();
        assert!(sb.check(Caps::SHELL, "shell").is_err());
        assert!(sb.check(Caps::FS_READ, "file/read").is_err());
        assert!(sb.check(Caps::ENV_READ, "env").is_err());
        assert!(sb.check(Caps::SERIAL, "serial/list").is_err());
    }

    // Entries may carry a `no-` prefix...
    #[test]
    fn test_sandbox_parse_cli_no_prefix() {
        let sb = Sandbox::parse_cli("no-shell,no-network").unwrap();
        assert!(sb.check(Caps::SHELL, "shell").is_err());
        assert!(sb.check(Caps::NETWORK, "http/get").is_err());
        assert!(sb.check(Caps::FS_READ, "file/read").is_ok());
    }

    // ...or omit it; both spellings deny the named capability.
    #[test]
    fn test_sandbox_parse_cli_without_no_prefix() {
        let sb = Sandbox::parse_cli("shell,network").unwrap();
        assert!(sb.check(Caps::SHELL, "shell").is_err());
        assert!(sb.check(Caps::NETWORK, "http/get").is_err());
        assert!(sb.check(Caps::FS_READ, "file/read").is_ok());
    }

    #[test]
    fn test_sandbox_parse_cli_single() {
        let sb = Sandbox::parse_cli("no-fs-write").unwrap();
        assert!(sb.check(Caps::FS_WRITE, "file/write").is_err());
        assert!(sb.check(Caps::FS_READ, "file/read").is_ok());
    }

    // Whitespace around entries is trimmed.
    #[test]
    fn test_sandbox_parse_cli_with_spaces() {
        let sb = Sandbox::parse_cli("no-shell, no-network").unwrap();
        assert!(sb.check(Caps::SHELL, "shell").is_err());
        assert!(sb.check(Caps::NETWORK, "http/get").is_err());
    }

    // Empty entries between commas are skipped rather than rejected.
    #[test]
    fn test_sandbox_parse_cli_empty_parts() {
        let sb = Sandbox::parse_cli("no-shell,,no-network").unwrap();
        assert!(sb.check(Caps::SHELL, "shell").is_err());
        assert!(sb.check(Caps::NETWORK, "http/get").is_err());
    }

    // One unknown entry fails the whole parse.
    #[test]
    fn test_sandbox_parse_cli_invalid() {
        assert!(Sandbox::parse_cli("no-bogus").is_err());
        assert!(Sandbox::parse_cli("no-shell,no-bogus").is_err());
    }

    // Without an allow-list, `check_path` accepts any path.
    #[test]
    fn test_check_path_none_allows_everything() {
        let sb = Sandbox::allow_all();
        assert!(sb.check_path("/etc/passwd", "file/read").is_ok());
        assert!(sb.check_path("relative.txt", "file/read").is_ok());
    }

    #[test]
    fn test_check_path_inside_allowed_dir() {
        let tmp = std::env::temp_dir();
        let sb = Sandbox::allow_all().with_allowed_paths(vec![tmp.clone()]);
        let test_path = tmp.join("sema-test-file.txt");
        // Setup errors are ignored via `.ok()`.
        std::fs::write(&test_path, "test").ok();
        assert!(sb
            .check_path(test_path.to_str().unwrap(), "file/read")
            .is_ok());
        let _ = std::fs::remove_file(&test_path);
    }

    // NOTE(review): the asserted message fragments come from `SemaError`'s
    // `Display` impl, which is defined outside this file.
    #[test]
    fn test_check_path_outside_allowed_dir() {
        let tmp = std::env::temp_dir().join("sema-sandbox-test-dir");
        // The directory is created before the sandbox is built because
        // `with_allowed_paths` canonicalizes its entries.
        std::fs::create_dir_all(&tmp).ok();
        let sb = Sandbox::allow_all().with_allowed_paths(vec![tmp.clone()]);
        let result = sb.check_path("/etc/hosts", "file/read");
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("Permission denied"), "{err}");
        assert!(
            err.to_string().contains("outside allowed directories"),
            "{err}"
        );
        let _ = std::fs::remove_dir_all(&tmp);
    }

    // `..` components that climb out of the allowed directory are rejected.
    #[test]
    fn test_check_path_traversal_attempt() {
        let tmp = std::env::temp_dir().join("sema-sandbox-traverse");
        std::fs::create_dir_all(&tmp).ok();
        let sb = Sandbox::allow_all().with_allowed_paths(vec![tmp.clone()]);
        let evil = format!("{}/../../../etc/passwd", tmp.display());
        let result = sb.check_path(&evil, "file/read");
        assert!(result.is_err(), "path traversal should be denied");
        let _ = std::fs::remove_dir_all(&tmp);
    }

    // A path is accepted when it lies under any one of several entries.
    #[test]
    fn test_check_path_multiple_allowed() {
        let dir_a = std::env::temp_dir().join("sema-sandbox-a");
        let dir_b = std::env::temp_dir().join("sema-sandbox-b");
        std::fs::create_dir_all(&dir_a).ok();
        std::fs::create_dir_all(&dir_b).ok();
        let sb = Sandbox::allow_all().with_allowed_paths(vec![dir_a.clone(), dir_b.clone()]);
        let file_a = dir_a.join("ok.txt");
        std::fs::write(&file_a, "a").ok();
        let file_b = dir_b.join("ok.txt");
        std::fs::write(&file_b, "b").ok();
        assert!(sb.check_path(file_a.to_str().unwrap(), "file/read").is_ok());
        assert!(sb.check_path(file_b.to_str().unwrap(), "file/read").is_ok());
        assert!(sb.check_path("/etc/hosts", "file/read").is_err());
        let _ = std::fs::remove_dir_all(&dir_a);
        let _ = std::fs::remove_dir_all(&dir_b);
    }

    // Only the number of parsed entries is checked, so the result does not
    // depend on whether the paths exist on the host.
    #[test]
    fn test_parse_allowed_paths() {
        let paths = Sandbox::parse_allowed_paths("/tmp, /var");
        assert_eq!(paths.len(), 2);
    }

    // Empty entries (double or trailing commas) are dropped.
    #[test]
    fn test_parse_allowed_paths_empty_parts() {
        let paths = Sandbox::parse_allowed_paths("/tmp,,/var,");
        assert_eq!(paths.len(), 2);
    }

    // Setting an allow-list alone is enough to make a sandbox restricted.
    #[test]
    fn test_with_allowed_paths_makes_restricted() {
        let sb = Sandbox::allow_all().with_allowed_paths(vec![std::path::PathBuf::from("/tmp")]);
        assert!(!sb.is_unrestricted());
    }

    // A nonexistent intermediate directory prevents canonicalization, so this
    // exercises the fallback resolution inside `check_path`.
    #[test]
    fn test_check_path_nonexistent_component_escape() {
        let tmp = std::env::temp_dir().join("sema-sandbox-escape");
        std::fs::create_dir_all(&tmp).ok();
        let sb = Sandbox::allow_all().with_allowed_paths(vec![tmp.clone()]);
        let evil = format!("{}/nonexistent/../../etc/passwd", tmp.display());
        let result = sb.check_path(&evil, "file/write");
        assert!(
            result.is_err(),
            "nonexistent component escape should be denied"
        );
        let _ = std::fs::remove_dir_all(&tmp);
    }

    // NOTE(review): despite the name, the path is built from `allowed`, which
    // derives from `std::env::temp_dir()`, so it does not appear to exercise
    // a relative input — confirm intent.
    #[test]
    fn test_check_path_relative_nonexistent_escape() {
        let tmp = std::env::temp_dir().join("sema-sandbox-rel-escape");
        let allowed = tmp.join("allowed");
        std::fs::create_dir_all(&allowed).ok();
        let sb = Sandbox::allow_all().with_allowed_paths(vec![allowed.clone()]);
        let evil = format!("{}/fake/../../../etc/hosts", allowed.display());
        let result = sb.check_path(&evil, "file/write");
        assert!(
            result.is_err(),
            "relative nonexistent escape should be denied"
        );
        let _ = std::fs::remove_dir_all(&tmp);
    }
}