1use std::fmt;
2use std::path::{Component, PathBuf};
3
4use crate::error::SemaError;
5
/// A set of capabilities, stored as a bit mask with one bit per capability.
///
/// The individual capabilities and the named presets are exposed as associated
/// constants; sets are combined with [`Caps::union`] and queried with
/// [`Caps::contains`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Caps(u64);

impl Caps {
    /// The empty set: no bits set.
    pub const NONE: Caps = Caps(0);

    // Individual capabilities, one bit each. Their textual names are the ones
    // returned by `name` and accepted by `from_name`.
    pub const FS_READ: Caps = Caps(1 << 0);
    pub const FS_WRITE: Caps = Caps(1 << 1);
    pub const SHELL: Caps = Caps(1 << 2);
    pub const NETWORK: Caps = Caps(1 << 3);
    pub const ENV_READ: Caps = Caps(1 << 4);
    pub const ENV_WRITE: Caps = Caps(1 << 5);
    pub const PROCESS: Caps = Caps(1 << 6);
    pub const LLM: Caps = Caps(1 << 7);

    /// Preset combining every individual capability.
    pub const ALL: Caps = Caps(
        Self::FS_READ.0
            | Self::FS_WRITE.0
            | Self::SHELL.0
            | Self::NETWORK.0
            | Self::ENV_READ.0
            | Self::ENV_WRITE.0
            | Self::PROCESS.0
            | Self::LLM.0,
    );

    /// Preset combining every individual capability except `FS_READ` and `ENV_READ`.
    pub const STRICT: Caps = Caps(
        Self::SHELL.0
            | Self::FS_WRITE.0
            | Self::NETWORK.0
            | Self::ENV_WRITE.0
            | Self::PROCESS.0
            | Self::LLM.0,
    );

    /// Returns `true` when every bit set in `other` is also set in `self`.
    ///
    /// Because [`Caps::NONE`] has no bits set, `x.contains(Caps::NONE)` is `true`
    /// for every `x`, including `NONE` itself.
    pub const fn contains(self, other: Caps) -> bool {
        self.0 & other.0 == other.0
    }

    /// Returns the set holding every capability present in `self` or in `other`.
    pub const fn union(self, other: Caps) -> Caps {
        Caps(self.0 | other.0)
    }

    /// Returns the textual name of this value.
    ///
    /// Only the empty set, the individual capabilities and the two presets have
    /// a name; any other combination of bits yields `"unknown"`.
    pub fn name(self) -> &'static str {
        match self {
            Caps::NONE => "none",
            Caps::FS_READ => "fs-read",
            Caps::FS_WRITE => "fs-write",
            Caps::SHELL => "shell",
            Caps::NETWORK => "network",
            Caps::ENV_READ => "env-read",
            Caps::ENV_WRITE => "env-write",
            Caps::PROCESS => "process",
            Caps::LLM => "llm",
            Caps::ALL => "all",
            Caps::STRICT => "strict",
            _ => "unknown",
        }
    }

    /// Parses a name produced by [`Caps::name`] back into its value.
    ///
    /// Returns `None` for any string that is not one of the known names
    /// (this includes `"unknown"` and the empty string). Matching is exact:
    /// no trimming or case folding is performed.
    pub fn from_name(s: &str) -> Option<Self> {
        match s {
            "none" => Some(Caps::NONE),
            "fs-read" => Some(Caps::FS_READ),
            "fs-write" => Some(Caps::FS_WRITE),
            "shell" => Some(Caps::SHELL),
            "network" => Some(Caps::NETWORK),
            "env-read" => Some(Caps::ENV_READ),
            "env-write" => Some(Caps::ENV_WRITE),
            "process" => Some(Caps::PROCESS),
            "llm" => Some(Caps::LLM),
            "all" => Some(Caps::ALL),
            "strict" => Some(Caps::STRICT),
            _ => None,
        }
    }
}

impl fmt::Display for Caps {
    /// Writes the name returned by [`Caps::name`].
    ///
    /// `Formatter::pad` is used instead of `write_str` so that width, fill,
    /// alignment and precision given in the format spec (e.g. `{:>10}`) are
    /// applied; a plain `{}` produces exactly the name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(self.name())
    }
}
88
/// A sandbox configuration: a set of denied capabilities plus an optional
/// allow-list of filesystem paths.
#[derive(Debug, Clone)]
pub struct Sandbox {
    /// Capabilities that `check` rejects.
    pub denied: Caps,
    /// Paths accepted by `check_path`; `None` means paths are not restricted.
    allowed_paths: Option<Vec<PathBuf>>,
}
94
/// Resolves `.` and `..` components of `path` purely textually, without
/// touching the filesystem (so symlinks are not followed).
///
/// `.` components are dropped and each `..` removes the last component
/// accumulated so far. A `..` that has nothing left to remove (at the root, or
/// at the start of a relative path) is silently discarded.
fn normalize_lexical(path: &std::path::Path) -> PathBuf {
    path.components().fold(PathBuf::new(), |mut acc, part| {
        match part {
            Component::CurDir => {}
            Component::ParentDir => {
                // `pop` is a no-op once only the root (or nothing) remains.
                acc.pop();
            }
            kept => acc.push(kept),
        }
        acc
    })
}
108
109impl Sandbox {
110 pub fn allow_all() -> Self {
111 Sandbox {
112 denied: Caps::NONE,
113 allowed_paths: None,
114 }
115 }
116
117 pub fn deny(caps: Caps) -> Self {
118 Sandbox {
119 denied: caps,
120 allowed_paths: None,
121 }
122 }
123
124 pub fn with_allowed_paths(mut self, paths: Vec<PathBuf>) -> Self {
125 self.allowed_paths = Some(
126 paths
127 .into_iter()
128 .map(|p| std::fs::canonicalize(&p).unwrap_or(p))
129 .collect(),
130 );
131 self
132 }
133
134 pub fn is_unrestricted(&self) -> bool {
135 self.denied == Caps::NONE && self.allowed_paths.is_none()
136 }
137
138 pub fn check(&self, required: Caps, fn_name: &str) -> Result<(), SemaError> {
139 if self.denied.contains(required) {
140 Err(SemaError::PermissionDenied {
141 function: fn_name.to_string(),
142 capability: required.name().to_string(),
143 })
144 } else {
145 Ok(())
146 }
147 }
148
149 pub fn check_path(&self, path: &str, fn_name: &str) -> Result<(), SemaError> {
156 let allowed = match &self.allowed_paths {
157 Some(paths) => paths,
158 None => return Ok(()),
159 };
160 let p = std::path::Path::new(path);
161 let canonical = std::fs::canonicalize(p).unwrap_or_else(|_| {
162 if let Some(parent) = p.parent() {
163 if let Ok(canon_parent) = std::fs::canonicalize(parent) {
164 return canon_parent.join(p.file_name().unwrap_or_default());
165 }
166 }
167 let abs = if p.is_absolute() {
168 p.to_path_buf()
169 } else {
170 std::env::current_dir()
171 .unwrap_or_else(|_| PathBuf::from("."))
172 .join(p)
173 };
174 normalize_lexical(&abs)
175 });
176 for allowed_path in allowed {
177 if canonical.starts_with(allowed_path) {
178 return Ok(());
179 }
180 }
181 Err(SemaError::PathDenied {
182 function: fn_name.to_string(),
183 path: canonical.display().to_string(),
184 })
185 }
186
187 pub fn parse_allowed_paths(value: &str) -> Vec<PathBuf> {
188 value
189 .split(',')
190 .map(|s| s.trim())
191 .filter(|s| !s.is_empty())
192 .map(|s| {
193 let p = PathBuf::from(s);
194 std::fs::canonicalize(&p).unwrap_or(p)
195 })
196 .collect()
197 }
198
199 pub fn parse_cli(value: &str) -> Result<Self, String> {
200 match value {
201 "strict" => Ok(Sandbox::deny(Caps::STRICT)),
202 "all" => Ok(Sandbox::deny(Caps::ALL)),
203 other => {
204 let mut denied = Caps::NONE;
205 for part in other.split(',') {
206 let part = part.trim();
207 if part.is_empty() {
208 continue;
209 }
210 let name = part.strip_prefix("no-").unwrap_or(part);
211 match Caps::from_name(name) {
212 Some(cap) => denied = denied.union(cap),
213 None => return Err(format!("unknown capability: {name}")),
214 }
215 }
216 Ok(Sandbox::deny(denied))
217 }
218 }
219 }
220}
221
222impl Default for Sandbox {
223 fn default() -> Self {
224 Sandbox::allow_all()
225 }
226}
227
// Unit tests for `Caps` (bit-set operations, naming) and `Sandbox`
// (capability checks, CLI parsing, path allow-list checks).
//
// The path tests create and remove entries under the system temp directory and
// refer to Unix-style paths such as `/etc/hosts`.
#[cfg(test)]
mod tests {
    use super::*;

    // `contains` reports exactly the capabilities that were unioned in.
    #[test]
    fn test_caps_contains() {
        let all = Caps::FS_READ.union(Caps::FS_WRITE).union(Caps::SHELL);
        assert!(all.contains(Caps::FS_READ));
        assert!(all.contains(Caps::FS_WRITE));
        assert!(all.contains(Caps::SHELL));
        assert!(!all.contains(Caps::NETWORK));
        assert!(!all.contains(Caps::LLM));
    }

    // The empty set is contained in every set, including itself.
    #[test]
    fn test_caps_contains_none_is_always_true() {
        assert!(Caps::NONE.contains(Caps::NONE));
        assert!(Caps::ALL.contains(Caps::NONE));
    }

    // `union` adds both operands and nothing else.
    #[test]
    fn test_caps_union() {
        let combined = Caps::SHELL.union(Caps::NETWORK);
        assert!(combined.contains(Caps::SHELL));
        assert!(combined.contains(Caps::NETWORK));
        assert!(!combined.contains(Caps::FS_READ));
    }

    // `ALL` includes each individual capability.
    #[test]
    fn test_caps_all_contains_every_cap() {
        assert!(Caps::ALL.contains(Caps::FS_READ));
        assert!(Caps::ALL.contains(Caps::FS_WRITE));
        assert!(Caps::ALL.contains(Caps::SHELL));
        assert!(Caps::ALL.contains(Caps::NETWORK));
        assert!(Caps::ALL.contains(Caps::ENV_READ));
        assert!(Caps::ALL.contains(Caps::ENV_WRITE));
        assert!(Caps::ALL.contains(Caps::PROCESS));
        assert!(Caps::ALL.contains(Caps::LLM));
    }

    // `STRICT` includes everything except the two read capabilities.
    #[test]
    fn test_caps_strict_preset() {
        assert!(Caps::STRICT.contains(Caps::SHELL));
        assert!(Caps::STRICT.contains(Caps::FS_WRITE));
        assert!(Caps::STRICT.contains(Caps::NETWORK));
        assert!(Caps::STRICT.contains(Caps::ENV_WRITE));
        assert!(Caps::STRICT.contains(Caps::PROCESS));
        assert!(Caps::STRICT.contains(Caps::LLM));
        assert!(!Caps::STRICT.contains(Caps::FS_READ));
        assert!(!Caps::STRICT.contains(Caps::ENV_READ));
    }

    // `from_name(name(cap))` gives back `cap` for every individual capability.
    #[test]
    fn test_caps_name_roundtrip() {
        let caps = [
            Caps::FS_READ,
            Caps::FS_WRITE,
            Caps::SHELL,
            Caps::NETWORK,
            Caps::ENV_READ,
            Caps::ENV_WRITE,
            Caps::PROCESS,
            Caps::LLM,
        ];
        for cap in caps {
            let name = cap.name();
            assert_eq!(
                Caps::from_name(name),
                Some(cap),
                "roundtrip failed for {name}"
            );
        }
    }

    // Unrecognized and empty names are rejected.
    #[test]
    fn test_caps_from_name_unknown() {
        assert_eq!(Caps::from_name("garbage"), None);
        assert_eq!(Caps::from_name(""), None);
    }

    // `Display` prints the same text as `name`.
    #[test]
    fn test_caps_display() {
        assert_eq!(format!("{}", Caps::SHELL), "shell");
        assert_eq!(format!("{}", Caps::NETWORK), "network");
        assert_eq!(format!("{}", Caps::FS_READ), "fs-read");
    }

    #[test]
    fn test_sandbox_allow_all_is_unrestricted() {
        let sb = Sandbox::allow_all();
        assert!(sb.is_unrestricted());
    }

    #[test]
    fn test_sandbox_deny_is_restricted() {
        let sb = Sandbox::deny(Caps::SHELL);
        assert!(!sb.is_unrestricted());
    }

    #[test]
    fn test_sandbox_default_is_unrestricted() {
        let sb = Sandbox::default();
        assert!(sb.is_unrestricted());
    }

    // Capabilities that were not denied pass the check.
    #[test]
    fn test_sandbox_check_allowed() {
        let sb = Sandbox::deny(Caps::SHELL);
        assert!(sb.check(Caps::NETWORK, "http/get").is_ok());
        assert!(sb.check(Caps::FS_READ, "file/read").is_ok());
    }

    // A denied capability yields an error whose text mentions the denial.
    #[test]
    fn test_sandbox_check_denied() {
        let sb = Sandbox::deny(Caps::SHELL);
        let err = sb.check(Caps::SHELL, "shell").unwrap_err();
        assert!(err.to_string().contains("Permission denied"));
        assert!(err.to_string().contains("shell"));
    }

    // The error text includes both the function name and the capability name.
    #[test]
    fn test_sandbox_check_denied_error_format() {
        let sb = Sandbox::deny(Caps::NETWORK);
        let err = sb.check(Caps::NETWORK, "http/get").unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("http/get"),
            "should contain function name: {msg}"
        );
        assert!(
            msg.contains("network"),
            "should contain capability name: {msg}"
        );
    }

    // Each capability in a denied union is rejected individually.
    #[test]
    fn test_sandbox_check_multiple_denied() {
        let sb = Sandbox::deny(Caps::SHELL.union(Caps::NETWORK));
        assert!(sb.check(Caps::SHELL, "shell").is_err());
        assert!(sb.check(Caps::NETWORK, "http/get").is_err());
        assert!(sb.check(Caps::FS_READ, "file/read").is_ok());
    }

    // "strict" denies the STRICT preset, leaving the read capabilities usable.
    #[test]
    fn test_sandbox_parse_cli_strict() {
        let sb = Sandbox::parse_cli("strict").unwrap();
        assert!(sb.check(Caps::SHELL, "shell").is_err());
        assert!(sb.check(Caps::FS_WRITE, "file/write").is_err());
        assert!(sb.check(Caps::NETWORK, "http/get").is_err());
        assert!(sb.check(Caps::FS_READ, "file/read").is_ok());
        assert!(sb.check(Caps::ENV_READ, "env").is_ok());
    }

    // "all" denies the read capabilities as well.
    #[test]
    fn test_sandbox_parse_cli_all() {
        let sb = Sandbox::parse_cli("all").unwrap();
        assert!(sb.check(Caps::SHELL, "shell").is_err());
        assert!(sb.check(Caps::FS_READ, "file/read").is_err());
        assert!(sb.check(Caps::ENV_READ, "env").is_err());
    }

    // Entries may carry a "no-" prefix ...
    #[test]
    fn test_sandbox_parse_cli_no_prefix() {
        let sb = Sandbox::parse_cli("no-shell,no-network").unwrap();
        assert!(sb.check(Caps::SHELL, "shell").is_err());
        assert!(sb.check(Caps::NETWORK, "http/get").is_err());
        assert!(sb.check(Caps::FS_READ, "file/read").is_ok());
    }

    // ... or omit it; both spellings deny the named capability.
    #[test]
    fn test_sandbox_parse_cli_without_no_prefix() {
        let sb = Sandbox::parse_cli("shell,network").unwrap();
        assert!(sb.check(Caps::SHELL, "shell").is_err());
        assert!(sb.check(Caps::NETWORK, "http/get").is_err());
        assert!(sb.check(Caps::FS_READ, "file/read").is_ok());
    }

    #[test]
    fn test_sandbox_parse_cli_single() {
        let sb = Sandbox::parse_cli("no-fs-write").unwrap();
        assert!(sb.check(Caps::FS_WRITE, "file/write").is_err());
        assert!(sb.check(Caps::FS_READ, "file/read").is_ok());
    }

    // Whitespace around entries is ignored.
    #[test]
    fn test_sandbox_parse_cli_with_spaces() {
        let sb = Sandbox::parse_cli("no-shell, no-network").unwrap();
        assert!(sb.check(Caps::SHELL, "shell").is_err());
        assert!(sb.check(Caps::NETWORK, "http/get").is_err());
    }

    // Empty entries between commas are skipped.
    #[test]
    fn test_sandbox_parse_cli_empty_parts() {
        let sb = Sandbox::parse_cli("no-shell,,no-network").unwrap();
        assert!(sb.check(Caps::SHELL, "shell").is_err());
        assert!(sb.check(Caps::NETWORK, "http/get").is_err());
    }

    // One unknown entry makes the whole parse fail.
    #[test]
    fn test_sandbox_parse_cli_invalid() {
        assert!(Sandbox::parse_cli("no-bogus").is_err());
        assert!(Sandbox::parse_cli("no-shell,no-bogus").is_err());
    }

    // Without an allow-list every path passes.
    #[test]
    fn test_check_path_none_allows_everything() {
        let sb = Sandbox::allow_all();
        assert!(sb.check_path("/etc/passwd", "file/read").is_ok());
        assert!(sb.check_path("relative.txt", "file/read").is_ok());
    }

    // A file created directly inside an allowed directory passes.
    #[test]
    fn test_check_path_inside_allowed_dir() {
        let tmp = std::env::temp_dir();
        let sb = Sandbox::allow_all().with_allowed_paths(vec![tmp.clone()]);
        let test_path = tmp.join("sema-test-file.txt");
        std::fs::write(&test_path, "test").ok();
        assert!(sb
            .check_path(test_path.to_str().unwrap(), "file/read")
            .is_ok());
        let _ = std::fs::remove_file(&test_path);
    }

    // A path outside the allow-list is rejected, and the error text names the reason.
    #[test]
    fn test_check_path_outside_allowed_dir() {
        let tmp = std::env::temp_dir().join("sema-sandbox-test-dir");
        std::fs::create_dir_all(&tmp).ok();
        let sb = Sandbox::allow_all().with_allowed_paths(vec![tmp.clone()]);
        let result = sb.check_path("/etc/hosts", "file/read");
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("Permission denied"), "{err}");
        assert!(
            err.to_string().contains("outside allowed directories"),
            "{err}"
        );
        let _ = std::fs::remove_dir_all(&tmp);
    }

    // `..` components that climb out of the allowed directory are rejected.
    #[test]
    fn test_check_path_traversal_attempt() {
        let tmp = std::env::temp_dir().join("sema-sandbox-traverse");
        std::fs::create_dir_all(&tmp).ok();
        let sb = Sandbox::allow_all().with_allowed_paths(vec![tmp.clone()]);
        let evil = format!("{}/../../../etc/passwd", tmp.display());
        let result = sb.check_path(&evil, "file/read");
        assert!(result.is_err(), "path traversal should be denied");
        let _ = std::fs::remove_dir_all(&tmp);
    }

    // A path passes when it is inside any one of several allowed directories.
    #[test]
    fn test_check_path_multiple_allowed() {
        let dir_a = std::env::temp_dir().join("sema-sandbox-a");
        let dir_b = std::env::temp_dir().join("sema-sandbox-b");
        std::fs::create_dir_all(&dir_a).ok();
        std::fs::create_dir_all(&dir_b).ok();
        let sb = Sandbox::allow_all().with_allowed_paths(vec![dir_a.clone(), dir_b.clone()]);
        let file_a = dir_a.join("ok.txt");
        std::fs::write(&file_a, "a").ok();
        let file_b = dir_b.join("ok.txt");
        std::fs::write(&file_b, "b").ok();
        assert!(sb.check_path(file_a.to_str().unwrap(), "file/read").is_ok());
        assert!(sb.check_path(file_b.to_str().unwrap(), "file/read").is_ok());
        assert!(sb.check_path("/etc/hosts", "file/read").is_err());
        let _ = std::fs::remove_dir_all(&dir_a);
        let _ = std::fs::remove_dir_all(&dir_b);
    }

    // Entries are split on commas and trimmed.
    #[test]
    fn test_parse_allowed_paths() {
        let paths = Sandbox::parse_allowed_paths("/tmp, /var");
        assert_eq!(paths.len(), 2);
    }

    // Empty entries (doubled or trailing commas) are dropped.
    #[test]
    fn test_parse_allowed_paths_empty_parts() {
        let paths = Sandbox::parse_allowed_paths("/tmp,,/var,");
        assert_eq!(paths.len(), 2);
    }

    // Setting an allow-list alone is enough to make the sandbox restricted.
    #[test]
    fn test_with_allowed_paths_makes_restricted() {
        let sb = Sandbox::allow_all().with_allowed_paths(vec![std::path::PathBuf::from("/tmp")]);
        assert!(!sb.is_unrestricted());
    }

    // Escaping through a nonexistent directory followed by `..` is rejected;
    // this exercises the fallback used when canonicalization fails.
    #[test]
    fn test_check_path_nonexistent_component_escape() {
        let tmp = std::env::temp_dir().join("sema-sandbox-escape");
        std::fs::create_dir_all(&tmp).ok();
        let sb = Sandbox::allow_all().with_allowed_paths(vec![tmp.clone()]);
        let evil = format!("{}/nonexistent/../../etc/passwd", tmp.display());
        let result = sb.check_path(&evil, "file/write");
        assert!(
            result.is_err(),
            "nonexistent component escape should be denied"
        );
        let _ = std::fs::remove_dir_all(&tmp);
    }

    // Same escape attempt, starting from a nested allowed directory and
    // climbing further up with additional `..` components.
    #[test]
    fn test_check_path_relative_nonexistent_escape() {
        let tmp = std::env::temp_dir().join("sema-sandbox-rel-escape");
        let allowed = tmp.join("allowed");
        std::fs::create_dir_all(&allowed).ok();
        let sb = Sandbox::allow_all().with_allowed_paths(vec![allowed.clone()]);
        let evil = format!("{}/fake/../../../etc/hosts", allowed.display());
        let result = sb.check_path(&evil, "file/write");
        assert!(
            result.is_err(),
            "relative nonexistent escape should be denied"
        );
        let _ = std::fs::remove_dir_all(&tmp);
    }
}