// sema_core/sandbox.rs

use std::fmt;
use std::path::{Component, Path, PathBuf};

use crate::error::SemaError;

/// A set of sandbox capabilities encoded as a bitmask.
///
/// Every named constant except `NONE`, `ALL` and `STRICT` is a single bit; sets
/// are combined with [`Caps::union`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Caps(u64);

impl Caps {
    pub const NONE: Caps = Caps(0);
    pub const FS_READ: Caps = Caps(1 << 0);
    pub const FS_WRITE: Caps = Caps(1 << 1);
    pub const SHELL: Caps = Caps(1 << 2);
    pub const NETWORK: Caps = Caps(1 << 3);
    pub const ENV_READ: Caps = Caps(1 << 4);
    pub const ENV_WRITE: Caps = Caps(1 << 5);
    pub const PROCESS: Caps = Caps(1 << 6);
    pub const LLM: Caps = Caps(1 << 7);

    /// Every capability.
    pub const ALL: Caps = Caps(
        Self::FS_READ.0
            | Self::FS_WRITE.0
            | Self::SHELL.0
            | Self::NETWORK.0
            | Self::ENV_READ.0
            | Self::ENV_WRITE.0
            | Self::PROCESS.0
            | Self::LLM.0,
    );

    /// Every capability except `FS_READ` and `ENV_READ`.
    pub const STRICT: Caps = Caps(
        Self::SHELL.0
            | Self::FS_WRITE.0
            | Self::NETWORK.0
            | Self::ENV_WRITE.0
            | Self::PROCESS.0
            | Self::LLM.0,
    );

    /// Every single-bit capability, in bit order; used by `Display` to render unions.
    const FLAGS: [Caps; 8] = [
        Caps::FS_READ,
        Caps::FS_WRITE,
        Caps::SHELL,
        Caps::NETWORK,
        Caps::ENV_READ,
        Caps::ENV_WRITE,
        Caps::PROCESS,
        Caps::LLM,
    ];

    /// Returns `true` if every capability in `other` is also in `self`.
    ///
    /// `NONE` is a subset of every set, so `x.contains(Caps::NONE)` is always `true`.
    pub fn contains(self, other: Caps) -> bool {
        self.0 & other.0 == other.0
    }

    /// Returns the set of capabilities present in `self` or `other`.
    pub fn union(self, other: Caps) -> Caps {
        Caps(self.0 | other.0)
    }

    /// Returns the canonical name of a single capability or preset, or `"unknown"`
    /// for any other combination of bits.
    pub fn name(self) -> &'static str {
        match self {
            Caps::NONE => "none",
            Caps::FS_READ => "fs-read",
            Caps::FS_WRITE => "fs-write",
            Caps::SHELL => "shell",
            Caps::NETWORK => "network",
            Caps::ENV_READ => "env-read",
            Caps::ENV_WRITE => "env-write",
            Caps::PROCESS => "process",
            Caps::LLM => "llm",
            Caps::ALL => "all",
            Caps::STRICT => "strict",
            _ => "unknown",
        }
    }

    /// Parses a canonical capability or preset name (the inverse of [`Caps::name`]).
    pub fn from_name(s: &str) -> Option<Self> {
        match s {
            "none" => Some(Caps::NONE),
            "fs-read" => Some(Caps::FS_READ),
            "fs-write" => Some(Caps::FS_WRITE),
            "shell" => Some(Caps::SHELL),
            "network" => Some(Caps::NETWORK),
            "env-read" => Some(Caps::ENV_READ),
            "env-write" => Some(Caps::ENV_WRITE),
            "process" => Some(Caps::PROCESS),
            "llm" => Some(Caps::LLM),
            "all" => Some(Caps::ALL),
            "strict" => Some(Caps::STRICT),
            _ => None,
        }
    }
}

impl fmt::Display for Caps {
    /// Writes the canonical name of a single capability or preset. Any other union
    /// of known flags is written as a comma-separated list of its members; a set
    /// containing bits outside `ALL` is written as `unknown`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.name() {
            // A union such as SHELL | NETWORK has no canonical name; list its members
            // in the comma-separated form `Sandbox::parse_cli` accepts instead of
            // printing the uninformative "unknown".
            "unknown" if self.0 & !Caps::ALL.0 == 0 => {
                let mut separator = "";
                for flag in Caps::FLAGS.iter().copied().filter(|&flag| self.contains(flag)) {
                    f.write_str(separator)?;
                    f.write_str(flag.name())?;
                    separator = ",";
                }
                Ok(())
            }
            name => f.write_str(name),
        }
    }
}

89#[derive(Debug, Clone)]
90pub struct Sandbox {
91    pub denied: Caps,
92    allowed_paths: Option<Vec<PathBuf>>,
93}
94
/// Lexically normalizes `path`: drops `.` components and resolves each `..`
/// against the preceding component, without touching the filesystem (symlinks
/// are *not* resolved).
///
/// A `..` directly under the root is discarded, because the root is its own
/// parent. In a relative path, a `..` with no preceding normal component is
/// kept, so `../x` stays `../x` instead of silently becoming `x`.
fn normalize_lexical(path: &Path) -> PathBuf {
    let mut result = PathBuf::new();
    for component in path.components() {
        match component {
            Component::CurDir => {}
            Component::ParentDir => match result.components().next_back() {
                Some(Component::Normal(_)) => {
                    result.pop();
                }
                // Cannot climb above the root.
                Some(Component::RootDir) => {}
                // Empty, already `..`, or a bare prefix: the `..` must be preserved.
                _ => result.push(Component::ParentDir),
            },
            other => result.push(other),
        }
    }
    result
}

109impl Sandbox {
110    pub fn allow_all() -> Self {
111        Sandbox {
112            denied: Caps::NONE,
113            allowed_paths: None,
114        }
115    }
116
117    pub fn deny(caps: Caps) -> Self {
118        Sandbox {
119            denied: caps,
120            allowed_paths: None,
121        }
122    }
123
124    pub fn with_allowed_paths(mut self, paths: Vec<PathBuf>) -> Self {
125        self.allowed_paths = Some(
126            paths
127                .into_iter()
128                .map(|p| std::fs::canonicalize(&p).unwrap_or(p))
129                .collect(),
130        );
131        self
132    }
133
134    pub fn is_unrestricted(&self) -> bool {
135        self.denied == Caps::NONE && self.allowed_paths.is_none()
136    }
137
138    pub fn check(&self, required: Caps, fn_name: &str) -> Result<(), SemaError> {
139        if self.denied.contains(required) {
140            Err(SemaError::PermissionDenied {
141                function: fn_name.to_string(),
142                capability: required.name().to_string(),
143            })
144        } else {
145            Ok(())
146        }
147    }
148
149    // NOTE: check_path validates the path before the file operation, not during it.
150    // This means a TOCTOU (time-of-check-to-time-of-use) window exists where an external
151    // process could swap a symlink between the check and the actual fs operation. Mitigating
152    // this properly requires OS-specific secure open patterns (openat with O_NOFOLLOW, dirfds)
153    // which is a significantly larger change. The current approach is on par with most
154    // scripting language sandboxes.
155    pub fn check_path(&self, path: &str, fn_name: &str) -> Result<(), SemaError> {
156        let allowed = match &self.allowed_paths {
157            Some(paths) => paths,
158            None => return Ok(()),
159        };
160        let p = std::path::Path::new(path);
161        let canonical = std::fs::canonicalize(p).unwrap_or_else(|_| {
162            if let Some(parent) = p.parent() {
163                if let Ok(canon_parent) = std::fs::canonicalize(parent) {
164                    return canon_parent.join(p.file_name().unwrap_or_default());
165                }
166            }
167            let abs = if p.is_absolute() {
168                p.to_path_buf()
169            } else {
170                std::env::current_dir()
171                    .unwrap_or_else(|_| PathBuf::from("."))
172                    .join(p)
173            };
174            normalize_lexical(&abs)
175        });
176        for allowed_path in allowed {
177            if canonical.starts_with(allowed_path) {
178                return Ok(());
179            }
180        }
181        Err(SemaError::PathDenied {
182            function: fn_name.to_string(),
183            path: canonical.display().to_string(),
184        })
185    }
186
187    pub fn parse_allowed_paths(value: &str) -> Vec<PathBuf> {
188        value
189            .split(',')
190            .map(|s| s.trim())
191            .filter(|s| !s.is_empty())
192            .map(|s| {
193                let p = PathBuf::from(s);
194                std::fs::canonicalize(&p).unwrap_or(p)
195            })
196            .collect()
197    }
198
199    pub fn parse_cli(value: &str) -> Result<Self, String> {
200        match value {
201            "strict" => Ok(Sandbox::deny(Caps::STRICT)),
202            "all" => Ok(Sandbox::deny(Caps::ALL)),
203            other => {
204                let mut denied = Caps::NONE;
205                for part in other.split(',') {
206                    let part = part.trim();
207                    if part.is_empty() {
208                        continue;
209                    }
210                    let name = part.strip_prefix("no-").unwrap_or(part);
211                    match Caps::from_name(name) {
212                        Some(cap) => denied = denied.union(cap),
213                        None => return Err(format!("unknown capability: {name}")),
214                    }
215                }
216                Ok(Sandbox::deny(denied))
217            }
218        }
219    }
220}
221
222impl Default for Sandbox {
223    fn default() -> Self {
224        Sandbox::allow_all()
225    }
226}
227
#[cfg(test)]
mod tests {
    // Unit tests for capability sets, capability checks, CLI parsing and
    // path allow-listing.
    use super::*;

    /// Returns `true` when `sb` rejects `cap` on behalf of `func`.
    fn denies(sb: &Sandbox, cap: Caps, func: &str) -> bool {
        sb.check(cap, func).is_err()
    }

    /// Returns `true` when `sb` accepts `cap` on behalf of `func`.
    fn permits(sb: &Sandbox, cap: Caps, func: &str) -> bool {
        sb.check(cap, func).is_ok()
    }

    /// Creates `name` under the system temp directory (best effort) and returns it.
    fn scratch_dir(name: &str) -> std::path::PathBuf {
        let dir = std::env::temp_dir().join(name);
        std::fs::create_dir_all(&dir).ok();
        dir
    }

    /// Parses `spec` and asserts that both shell and network access are denied.
    fn parsed_denies_shell_and_network(spec: &str) -> Sandbox {
        let sb = Sandbox::parse_cli(spec).unwrap();
        assert!(denies(&sb, Caps::SHELL, "shell"));
        assert!(denies(&sb, Caps::NETWORK, "http/get"));
        sb
    }

    #[test]
    fn test_caps_contains() {
        let members = [Caps::FS_READ, Caps::FS_WRITE, Caps::SHELL];
        let set = members.iter().copied().fold(Caps::NONE, Caps::union);
        for cap in members {
            assert!(set.contains(cap));
        }
        for cap in [Caps::NETWORK, Caps::LLM] {
            assert!(!set.contains(cap));
        }
    }

    #[test]
    fn test_caps_contains_none_is_always_true() {
        // NONE is a subset of every set (0 & x == 0), so `contains(NONE)` is trivially
        // true; any deny decision built on `contains` has to account for this.
        assert!(Caps::NONE.contains(Caps::NONE));
        assert!(Caps::ALL.contains(Caps::NONE));
    }

    #[test]
    fn test_caps_union() {
        let pair = Caps::SHELL.union(Caps::NETWORK);
        assert!(pair.contains(Caps::SHELL) && pair.contains(Caps::NETWORK));
        assert!(!pair.contains(Caps::FS_READ));
    }

    #[test]
    fn test_caps_all_contains_every_cap() {
        let every = [
            Caps::FS_READ,
            Caps::FS_WRITE,
            Caps::SHELL,
            Caps::NETWORK,
            Caps::ENV_READ,
            Caps::ENV_WRITE,
            Caps::PROCESS,
            Caps::LLM,
        ];
        assert!(every.iter().all(|&cap| Caps::ALL.contains(cap)));
    }

    #[test]
    fn test_caps_strict_preset() {
        let denied_by_strict = [
            Caps::SHELL,
            Caps::FS_WRITE,
            Caps::NETWORK,
            Caps::ENV_WRITE,
            Caps::PROCESS,
            Caps::LLM,
        ];
        for cap in denied_by_strict {
            assert!(Caps::STRICT.contains(cap), "strict should deny {cap}");
        }
        // strict leaves the read-only capabilities available
        for cap in [Caps::FS_READ, Caps::ENV_READ] {
            assert!(!Caps::STRICT.contains(cap), "strict should allow {cap}");
        }
    }

    #[test]
    fn test_caps_name_roundtrip() {
        let singles = [
            Caps::FS_READ,
            Caps::FS_WRITE,
            Caps::SHELL,
            Caps::NETWORK,
            Caps::ENV_READ,
            Caps::ENV_WRITE,
            Caps::PROCESS,
            Caps::LLM,
        ];
        for cap in singles {
            let name = cap.name();
            assert_eq!(
                Caps::from_name(name),
                Some(cap),
                "roundtrip failed for {name}"
            );
        }
    }

    #[test]
    fn test_caps_from_name_unknown() {
        for bogus in ["garbage", ""] {
            assert_eq!(Caps::from_name(bogus), None);
        }
    }

    #[test]
    fn test_caps_display() {
        let expected = [
            (Caps::SHELL, "shell"),
            (Caps::NETWORK, "network"),
            (Caps::FS_READ, "fs-read"),
        ];
        for (cap, text) in expected {
            assert_eq!(cap.to_string(), text);
        }
    }

    #[test]
    fn test_sandbox_allow_all_is_unrestricted() {
        assert!(Sandbox::allow_all().is_unrestricted());
    }

    #[test]
    fn test_sandbox_deny_is_restricted() {
        assert!(!Sandbox::deny(Caps::SHELL).is_unrestricted());
    }

    #[test]
    fn test_sandbox_default_is_unrestricted() {
        assert!(Sandbox::default().is_unrestricted());
    }

    #[test]
    fn test_sandbox_check_allowed() {
        let sb = Sandbox::deny(Caps::SHELL);
        assert!(permits(&sb, Caps::NETWORK, "http/get"));
        assert!(permits(&sb, Caps::FS_READ, "file/read"));
    }

    #[test]
    fn test_sandbox_check_denied() {
        let message = Sandbox::deny(Caps::SHELL)
            .check(Caps::SHELL, "shell")
            .unwrap_err()
            .to_string();
        assert!(message.contains("Permission denied"));
        assert!(message.contains("shell"));
    }

    #[test]
    fn test_sandbox_check_denied_error_format() {
        let msg = Sandbox::deny(Caps::NETWORK)
            .check(Caps::NETWORK, "http/get")
            .unwrap_err()
            .to_string();
        assert!(
            msg.contains("http/get"),
            "should contain function name: {msg}"
        );
        assert!(
            msg.contains("network"),
            "should contain capability name: {msg}"
        );
    }

    #[test]
    fn test_sandbox_check_multiple_denied() {
        let sb = Sandbox::deny(Caps::SHELL.union(Caps::NETWORK));
        assert!(denies(&sb, Caps::SHELL, "shell"));
        assert!(denies(&sb, Caps::NETWORK, "http/get"));
        assert!(permits(&sb, Caps::FS_READ, "file/read"));
    }

    #[test]
    fn test_sandbox_parse_cli_strict() {
        let sb = Sandbox::parse_cli("strict").unwrap();
        assert!(denies(&sb, Caps::SHELL, "shell"));
        assert!(denies(&sb, Caps::FS_WRITE, "file/write"));
        assert!(denies(&sb, Caps::NETWORK, "http/get"));
        // strict allows reads
        assert!(permits(&sb, Caps::FS_READ, "file/read"));
        assert!(permits(&sb, Caps::ENV_READ, "env"));
    }

    #[test]
    fn test_sandbox_parse_cli_all() {
        let sb = Sandbox::parse_cli("all").unwrap();
        let probes = [
            (Caps::SHELL, "shell"),
            (Caps::FS_READ, "file/read"),
            (Caps::ENV_READ, "env"),
        ];
        for (cap, func) in probes {
            assert!(denies(&sb, cap, func));
        }
    }

    #[test]
    fn test_sandbox_parse_cli_no_prefix() {
        let sb = parsed_denies_shell_and_network("no-shell,no-network");
        assert!(permits(&sb, Caps::FS_READ, "file/read"));
    }

    #[test]
    fn test_sandbox_parse_cli_without_no_prefix() {
        let sb = parsed_denies_shell_and_network("shell,network");
        assert!(permits(&sb, Caps::FS_READ, "file/read"));
    }

    #[test]
    fn test_sandbox_parse_cli_single() {
        let sb = Sandbox::parse_cli("no-fs-write").unwrap();
        assert!(denies(&sb, Caps::FS_WRITE, "file/write"));
        assert!(permits(&sb, Caps::FS_READ, "file/read"));
    }

    #[test]
    fn test_sandbox_parse_cli_with_spaces() {
        parsed_denies_shell_and_network("no-shell, no-network");
    }

    #[test]
    fn test_sandbox_parse_cli_empty_parts() {
        parsed_denies_shell_and_network("no-shell,,no-network");
    }

    #[test]
    fn test_sandbox_parse_cli_invalid() {
        for spec in ["no-bogus", "no-shell,no-bogus"] {
            assert!(Sandbox::parse_cli(spec).is_err());
        }
    }

    #[test]
    fn test_check_path_none_allows_everything() {
        let sb = Sandbox::allow_all();
        for path in ["/etc/passwd", "relative.txt"] {
            assert!(sb.check_path(path, "file/read").is_ok());
        }
    }

    #[test]
    fn test_check_path_inside_allowed_dir() {
        let tmp = std::env::temp_dir();
        let sb = Sandbox::allow_all().with_allowed_paths(vec![tmp.clone()]);
        let file = tmp.join("sema-test-file.txt");
        std::fs::write(&file, "test").ok();
        let verdict = sb.check_path(file.to_str().unwrap(), "file/read");
        let _ = std::fs::remove_file(&file);
        assert!(verdict.is_ok());
    }

    #[test]
    fn test_check_path_outside_allowed_dir() {
        let dir = scratch_dir("sema-sandbox-test-dir");
        let sb = Sandbox::allow_all().with_allowed_paths(vec![dir.clone()]);
        let outcome = sb.check_path("/etc/hosts", "file/read");
        let _ = std::fs::remove_dir_all(&dir);
        let err = outcome.expect_err("path outside the allow-list should be denied");
        let text = err.to_string();
        assert!(text.contains("Permission denied"), "{err}");
        assert!(text.contains("outside allowed directories"), "{err}");
    }

    #[test]
    fn test_check_path_traversal_attempt() {
        let dir = scratch_dir("sema-sandbox-traverse");
        let sb = Sandbox::allow_all().with_allowed_paths(vec![dir.clone()]);
        let evil = format!("{}/../../../etc/passwd", dir.display());
        let outcome = sb.check_path(&evil, "file/read");
        let _ = std::fs::remove_dir_all(&dir);
        assert!(outcome.is_err(), "path traversal should be denied");
    }

    #[test]
    fn test_check_path_multiple_allowed() {
        let dirs = [scratch_dir("sema-sandbox-a"), scratch_dir("sema-sandbox-b")];
        let sb = Sandbox::allow_all().with_allowed_paths(dirs.to_vec());
        for (dir, body) in dirs.iter().zip(["a", "b"]) {
            let file = dir.join("ok.txt");
            std::fs::write(&file, body).ok();
            assert!(sb.check_path(file.to_str().unwrap(), "file/read").is_ok());
        }
        assert!(sb.check_path("/etc/hosts", "file/read").is_err());
        for dir in &dirs {
            let _ = std::fs::remove_dir_all(dir);
        }
    }

    #[test]
    fn test_parse_allowed_paths() {
        assert_eq!(Sandbox::parse_allowed_paths("/tmp, /var").len(), 2);
    }

    #[test]
    fn test_parse_allowed_paths_empty_parts() {
        assert_eq!(Sandbox::parse_allowed_paths("/tmp,,/var,").len(), 2);
    }

    #[test]
    fn test_with_allowed_paths_makes_restricted() {
        let only_tmp = vec![std::path::PathBuf::from("/tmp")];
        assert!(!Sandbox::allow_all()
            .with_allowed_paths(only_tmp)
            .is_unrestricted());
    }

    #[test]
    fn test_check_path_nonexistent_component_escape() {
        let dir = scratch_dir("sema-sandbox-escape");
        let sb = Sandbox::allow_all().with_allowed_paths(vec![dir.clone()]);
        let evil = format!("{}/nonexistent/../../etc/passwd", dir.display());
        let outcome = sb.check_path(&evil, "file/write");
        let _ = std::fs::remove_dir_all(&dir);
        assert!(
            outcome.is_err(),
            "nonexistent component escape should be denied"
        );
    }

    #[test]
    fn test_check_path_relative_nonexistent_escape() {
        let root = std::env::temp_dir().join("sema-sandbox-rel-escape");
        let allowed = root.join("allowed");
        std::fs::create_dir_all(&allowed).ok();
        let sb = Sandbox::allow_all().with_allowed_paths(vec![allowed.clone()]);
        let evil = format!("{}/fake/../../../etc/hosts", allowed.display());
        let outcome = sb.check_path(&evil, "file/write");
        let _ = std::fs::remove_dir_all(&root);
        assert!(
            outcome.is_err(),
            "relative nonexistent escape should be denied"
        );
    }
}