Skip to main content

sema_core/
sandbox.rs

1use std::fmt;
2use std::path::{Component, PathBuf};
3
4use crate::error::SemaError;
5
/// A set of sandbox capabilities stored as a bit mask.
///
/// Each single-bit constant names one capability; `ALL` and `STRICT` are
/// preset combinations. Sets are combined with [`Caps::union`] and tested
/// with [`Caps::contains`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Caps(u64);

impl Caps {
    pub const NONE: Caps = Caps(0);
    pub const FS_READ: Caps = Caps(1 << 0);
    pub const FS_WRITE: Caps = Caps(1 << 1);
    pub const SHELL: Caps = Caps(1 << 2);
    pub const NETWORK: Caps = Caps(1 << 3);
    pub const ENV_READ: Caps = Caps(1 << 4);
    pub const ENV_WRITE: Caps = Caps(1 << 5);
    pub const PROCESS: Caps = Caps(1 << 6);
    pub const LLM: Caps = Caps(1 << 7);
    pub const SERIAL: Caps = Caps(1 << 8);

    /// Every capability defined above.
    pub const ALL: Caps = Caps(
        Self::FS_READ.0
            | Self::FS_WRITE.0
            | Self::SHELL.0
            | Self::NETWORK.0
            | Self::ENV_READ.0
            | Self::ENV_WRITE.0
            | Self::PROCESS.0
            | Self::LLM.0
            | Self::SERIAL.0,
    );

    /// Every capability except `FS_READ` and `ENV_READ`; this is the deny-set
    /// produced by `Sandbox::parse_cli("strict")`.
    pub const STRICT: Caps = Caps(
        Self::SHELL.0
            | Self::FS_WRITE.0
            | Self::NETWORK.0
            | Self::ENV_WRITE.0
            | Self::PROCESS.0
            | Self::LLM.0
            | Self::SERIAL.0,
    );

    /// Single source of truth for the spelling of every named set; both
    /// `name` and `from_name` look things up here. All values are distinct.
    const NAMES: [(Caps, &'static str); 12] = [
        (Self::NONE, "none"),
        (Self::FS_READ, "fs-read"),
        (Self::FS_WRITE, "fs-write"),
        (Self::SHELL, "shell"),
        (Self::NETWORK, "network"),
        (Self::ENV_READ, "env-read"),
        (Self::ENV_WRITE, "env-write"),
        (Self::PROCESS, "process"),
        (Self::LLM, "llm"),
        (Self::SERIAL, "serial"),
        (Self::ALL, "all"),
        (Self::STRICT, "strict"),
    ];

    /// Returns `true` when every bit of `other` is also set in `self`.
    ///
    /// Consequently every set, including `NONE`, contains `NONE`.
    pub fn contains(self, other: Caps) -> bool {
        let Caps(have) = self;
        let Caps(want) = other;
        (have & want) == want
    }

    /// Returns the set of capabilities present in `self`, `other`, or both.
    pub fn union(self, other: Caps) -> Caps {
        let Caps(lhs) = self;
        let Caps(rhs) = other;
        Caps(lhs | rhs)
    }

    /// Canonical name of a single capability or preset.
    ///
    /// Any other combination of bits has no name and yields `"unknown"`.
    pub fn name(self) -> &'static str {
        Self::NAMES
            .iter()
            .find(|(caps, _)| *caps == self)
            .map_or("unknown", |&(_, label)| label)
    }

    /// Inverse of [`Caps::name`]: looks up a set by its exact (case-sensitive,
    /// untrimmed) name, returning `None` for anything unrecognised.
    pub fn from_name(s: &str) -> Option<Self> {
        Self::NAMES
            .iter()
            .find(|(_, label)| *label == s)
            .map(|&(caps, _)| caps)
    }
}
87
88impl fmt::Display for Caps {
89    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
90        f.write_str(self.name())
91    }
92}
93
94#[derive(Debug, Clone)]
95pub struct Sandbox {
96    pub denied: Caps,
97    allowed_paths: Option<Vec<PathBuf>>,
98}
99
/// Normalizes `path` purely textually, without touching the file system.
///
/// - `.` components are dropped.
/// - Each `..` cancels the preceding directory name.
/// - A `..` directly after the root (or a bare prefix) is discarded, since
///   nothing is above the root.
/// - In a relative path, a `..` with no preceding name to cancel is kept. For
///   example `"a/../../b"` becomes `"../b"`, not `"b"`, so the result never
///   claims to be inside the starting directory when it is not.
///
/// Symlinks are NOT resolved, so the result can differ from where the OS would
/// actually land if any component is a symlink.
fn normalize_lexical(path: &std::path::Path) -> PathBuf {
    let mut result = PathBuf::new();
    for component in path.components() {
        match component {
            Component::ParentDir => match result.components().next_back() {
                // A real directory name: `..` cancels it out.
                Some(Component::Normal(_)) => {
                    result.pop();
                }
                // `..` at the root stays at the root.
                Some(Component::RootDir) | Some(Component::Prefix(_)) => {}
                // Nothing left to cancel in a relative path: keep the `..`.
                Some(Component::ParentDir) | Some(Component::CurDir) | None => {
                    result.push("..");
                }
            },
            Component::CurDir => {}
            other => result.push(other),
        }
    }
    result
}
113
114impl Sandbox {
115    pub fn allow_all() -> Self {
116        Sandbox {
117            denied: Caps::NONE,
118            allowed_paths: None,
119        }
120    }
121
122    pub fn deny(caps: Caps) -> Self {
123        Sandbox {
124            denied: caps,
125            allowed_paths: None,
126        }
127    }
128
129    pub fn with_allowed_paths(mut self, paths: Vec<PathBuf>) -> Self {
130        self.allowed_paths = Some(
131            paths
132                .into_iter()
133                .map(|p| std::fs::canonicalize(&p).unwrap_or(p))
134                .collect(),
135        );
136        self
137    }
138
139    pub fn is_unrestricted(&self) -> bool {
140        self.denied == Caps::NONE && self.allowed_paths.is_none()
141    }
142
143    pub fn check(&self, required: Caps, fn_name: &str) -> Result<(), SemaError> {
144        // Requesting zero capabilities always succeeds. Without this guard,
145        // `denied.contains(Caps::NONE)` is vacuously true, so a restricted
146        // sandbox would wrongly deny a `Caps::NONE` request (CORE-3).
147        if required == Caps::NONE {
148            return Ok(());
149        }
150        if self.denied.contains(required) {
151            Err(SemaError::PermissionDenied {
152                function: fn_name.to_string(),
153                capability: required.name().to_string(),
154            })
155        } else {
156            Ok(())
157        }
158    }
159
160    // NOTE: check_path validates the path before the file operation, not during it.
161    // This means a TOCTOU (time-of-check-to-time-of-use) window exists where an external
162    // process could swap a symlink between the check and the actual fs operation. Mitigating
163    // this properly requires OS-specific secure open patterns (openat with O_NOFOLLOW, dirfds)
164    // which is a significantly larger change. The current approach is on par with most
165    // scripting language sandboxes.
166    pub fn check_path(&self, path: &str, fn_name: &str) -> Result<(), SemaError> {
167        let allowed = match &self.allowed_paths {
168            Some(paths) => paths,
169            None => return Ok(()),
170        };
171        let p = std::path::Path::new(path);
172        let canonical = std::fs::canonicalize(p).unwrap_or_else(|_| {
173            if let Some(parent) = p.parent() {
174                if let Ok(canon_parent) = std::fs::canonicalize(parent) {
175                    return canon_parent.join(p.file_name().unwrap_or_default());
176                }
177            }
178            let abs = if p.is_absolute() {
179                p.to_path_buf()
180            } else {
181                std::env::current_dir()
182                    .unwrap_or_else(|_| PathBuf::from("."))
183                    .join(p)
184            };
185            normalize_lexical(&abs)
186        });
187        for allowed_path in allowed {
188            if canonical.starts_with(allowed_path) {
189                return Ok(());
190            }
191        }
192        Err(SemaError::PathDenied {
193            function: fn_name.to_string(),
194            path: canonical.display().to_string(),
195        })
196    }
197
198    pub fn parse_allowed_paths(value: &str) -> Vec<PathBuf> {
199        value
200            .split(',')
201            .map(|s| s.trim())
202            .filter(|s| !s.is_empty())
203            .map(|s| {
204                let p = PathBuf::from(s);
205                std::fs::canonicalize(&p).unwrap_or(p)
206            })
207            .collect()
208    }
209
210    pub fn parse_cli(value: &str) -> Result<Self, String> {
211        match value {
212            "strict" => Ok(Sandbox::deny(Caps::STRICT)),
213            "all" => Ok(Sandbox::deny(Caps::ALL)),
214            other => {
215                let mut denied = Caps::NONE;
216                for part in other.split(',') {
217                    let part = part.trim();
218                    if part.is_empty() {
219                        continue;
220                    }
221                    let name = part.strip_prefix("no-").unwrap_or(part);
222                    match Caps::from_name(name) {
223                        Some(cap) => denied = denied.union(cap),
224                        None => return Err(format!("unknown capability: {name}")),
225                    }
226                }
227                Ok(Sandbox::deny(denied))
228            }
229        }
230    }
231}
232
233impl Default for Sandbox {
234    fn default() -> Self {
235        Sandbox::allow_all()
236    }
237}
238
#[cfg(test)]
mod tests {
    // Unit tests for `Caps`, `Sandbox::check`, CLI parsing and path allow-lists.
    // Path tests create fixed-name files/directories under the system temp dir
    // and remove them afterwards; each test uses its own name to avoid clashes.
    use super::*;

    #[test]
    fn test_caps_contains() {
        let all = Caps::FS_READ.union(Caps::FS_WRITE).union(Caps::SHELL);
        assert!(all.contains(Caps::FS_READ));
        assert!(all.contains(Caps::FS_WRITE));
        assert!(all.contains(Caps::SHELL));
        assert!(!all.contains(Caps::NETWORK));
        assert!(!all.contains(Caps::LLM));
    }

    #[test]
    fn test_caps_contains_none_is_always_true() {
        // Every set contains the empty set, since `x & 0 == 0`. `Sandbox::check`
        // must nevertheless allow a `Caps::NONE` request; that is pinned down by
        // `test_sandbox_check_none_always_allowed`.
        assert!(Caps::NONE.contains(Caps::NONE));
        assert!(Caps::ALL.contains(Caps::NONE));
    }

    #[test]
    fn test_caps_union() {
        let combined = Caps::SHELL.union(Caps::NETWORK);
        assert!(combined.contains(Caps::SHELL));
        assert!(combined.contains(Caps::NETWORK));
        assert!(!combined.contains(Caps::FS_READ));
    }

    #[test]
    fn test_caps_all_contains_every_cap() {
        assert!(Caps::ALL.contains(Caps::FS_READ));
        assert!(Caps::ALL.contains(Caps::FS_WRITE));
        assert!(Caps::ALL.contains(Caps::SHELL));
        assert!(Caps::ALL.contains(Caps::NETWORK));
        assert!(Caps::ALL.contains(Caps::ENV_READ));
        assert!(Caps::ALL.contains(Caps::ENV_WRITE));
        assert!(Caps::ALL.contains(Caps::PROCESS));
        assert!(Caps::ALL.contains(Caps::LLM));
        assert!(Caps::ALL.contains(Caps::SERIAL));
    }

    #[test]
    fn test_caps_strict_preset() {
        assert!(Caps::STRICT.contains(Caps::SHELL));
        assert!(Caps::STRICT.contains(Caps::FS_WRITE));
        assert!(Caps::STRICT.contains(Caps::NETWORK));
        assert!(Caps::STRICT.contains(Caps::ENV_WRITE));
        assert!(Caps::STRICT.contains(Caps::PROCESS));
        assert!(Caps::STRICT.contains(Caps::LLM));
        assert!(Caps::STRICT.contains(Caps::SERIAL));
        // strict does NOT deny read-only operations
        assert!(!Caps::STRICT.contains(Caps::FS_READ));
        assert!(!Caps::STRICT.contains(Caps::ENV_READ));
    }

    #[test]
    fn test_caps_name_roundtrip() {
        let caps = [
            Caps::FS_READ,
            Caps::FS_WRITE,
            Caps::SHELL,
            Caps::NETWORK,
            Caps::ENV_READ,
            Caps::ENV_WRITE,
            Caps::PROCESS,
            Caps::LLM,
            Caps::SERIAL,
        ];
        for cap in caps {
            let name = cap.name();
            assert_eq!(
                Caps::from_name(name),
                Some(cap),
                "roundtrip failed for {name}"
            );
        }
    }

    #[test]
    fn test_caps_from_name_unknown() {
        assert_eq!(Caps::from_name("garbage"), None);
        assert_eq!(Caps::from_name(""), None);
    }

    #[test]
    fn test_caps_display() {
        assert_eq!(format!("{}", Caps::SHELL), "shell");
        assert_eq!(format!("{}", Caps::NETWORK), "network");
        assert_eq!(format!("{}", Caps::FS_READ), "fs-read");
        assert_eq!(format!("{}", Caps::SERIAL), "serial");
    }

    #[test]
    fn test_sandbox_allow_all_is_unrestricted() {
        let sb = Sandbox::allow_all();
        assert!(sb.is_unrestricted());
    }

    #[test]
    fn test_sandbox_deny_is_restricted() {
        let sb = Sandbox::deny(Caps::SHELL);
        assert!(!sb.is_unrestricted());
    }

    #[test]
    fn test_sandbox_default_is_unrestricted() {
        let sb = Sandbox::default();
        assert!(sb.is_unrestricted());
    }

    #[test]
    fn test_sandbox_check_allowed() {
        let sb = Sandbox::deny(Caps::SHELL);
        assert!(sb.check(Caps::NETWORK, "http/get").is_ok());
        assert!(sb.check(Caps::FS_READ, "file/read").is_ok());
    }

    #[test]
    fn test_sandbox_check_none_always_allowed() {
        // CORE-3: a request for no capabilities must succeed even when the
        // sandbox denies everything.
        assert!(Sandbox::deny(Caps::ALL).check(Caps::NONE, "noop").is_ok());
        assert!(Sandbox::deny(Caps::STRICT).check(Caps::NONE, "noop").is_ok());
        assert!(Sandbox::allow_all().check(Caps::NONE, "noop").is_ok());
    }

    #[test]
    fn test_sandbox_check_denied() {
        let sb = Sandbox::deny(Caps::SHELL);
        let err = sb.check(Caps::SHELL, "shell").unwrap_err();
        assert!(err.to_string().contains("Permission denied"));
        assert!(err.to_string().contains("shell"));
    }

    #[test]
    fn test_sandbox_check_denied_error_format() {
        let sb = Sandbox::deny(Caps::NETWORK);
        let err = sb.check(Caps::NETWORK, "http/get").unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("http/get"),
            "should contain function name: {msg}"
        );
        assert!(
            msg.contains("network"),
            "should contain capability name: {msg}"
        );
    }

    #[test]
    fn test_sandbox_check_multiple_denied() {
        let sb = Sandbox::deny(Caps::SHELL.union(Caps::NETWORK));
        assert!(sb.check(Caps::SHELL, "shell").is_err());
        assert!(sb.check(Caps::NETWORK, "http/get").is_err());
        assert!(sb.check(Caps::FS_READ, "file/read").is_ok());
    }

    #[test]
    fn test_sandbox_parse_cli_strict() {
        let sb = Sandbox::parse_cli("strict").unwrap();
        assert!(sb.check(Caps::SHELL, "shell").is_err());
        assert!(sb.check(Caps::FS_WRITE, "file/write").is_err());
        assert!(sb.check(Caps::NETWORK, "http/get").is_err());
        assert!(sb.check(Caps::SERIAL, "serial/list").is_err());
        // strict allows reads
        assert!(sb.check(Caps::FS_READ, "file/read").is_ok());
        assert!(sb.check(Caps::ENV_READ, "env").is_ok());
    }

    #[test]
    fn test_sandbox_parse_cli_all() {
        let sb = Sandbox::parse_cli("all").unwrap();
        assert!(sb.check(Caps::SHELL, "shell").is_err());
        assert!(sb.check(Caps::FS_READ, "file/read").is_err());
        assert!(sb.check(Caps::ENV_READ, "env").is_err());
        assert!(sb.check(Caps::SERIAL, "serial/list").is_err());
    }

    #[test]
    fn test_sandbox_parse_cli_no_prefix() {
        let sb = Sandbox::parse_cli("no-shell,no-network").unwrap();
        assert!(sb.check(Caps::SHELL, "shell").is_err());
        assert!(sb.check(Caps::NETWORK, "http/get").is_err());
        assert!(sb.check(Caps::FS_READ, "file/read").is_ok());
    }

    #[test]
    fn test_sandbox_parse_cli_without_no_prefix() {
        let sb = Sandbox::parse_cli("shell,network").unwrap();
        assert!(sb.check(Caps::SHELL, "shell").is_err());
        assert!(sb.check(Caps::NETWORK, "http/get").is_err());
        assert!(sb.check(Caps::FS_READ, "file/read").is_ok());
    }

    #[test]
    fn test_sandbox_parse_cli_single() {
        let sb = Sandbox::parse_cli("no-fs-write").unwrap();
        assert!(sb.check(Caps::FS_WRITE, "file/write").is_err());
        assert!(sb.check(Caps::FS_READ, "file/read").is_ok());
    }

    #[test]
    fn test_sandbox_parse_cli_with_spaces() {
        let sb = Sandbox::parse_cli("no-shell, no-network").unwrap();
        assert!(sb.check(Caps::SHELL, "shell").is_err());
        assert!(sb.check(Caps::NETWORK, "http/get").is_err());
    }

    #[test]
    fn test_sandbox_parse_cli_empty_parts() {
        let sb = Sandbox::parse_cli("no-shell,,no-network").unwrap();
        assert!(sb.check(Caps::SHELL, "shell").is_err());
        assert!(sb.check(Caps::NETWORK, "http/get").is_err());
    }

    #[test]
    fn test_sandbox_parse_cli_empty_is_unrestricted() {
        // An empty specification denies nothing.
        let sb = Sandbox::parse_cli("").unwrap();
        assert!(sb.is_unrestricted());
    }

    #[test]
    fn test_sandbox_parse_cli_invalid() {
        assert!(Sandbox::parse_cli("no-bogus").is_err());
        assert!(Sandbox::parse_cli("no-shell,no-bogus").is_err());
        // The message names the capability with any `no-` prefix stripped.
        assert_eq!(
            Sandbox::parse_cli("no-bogus").unwrap_err(),
            "unknown capability: bogus"
        );
    }

    #[test]
    fn test_check_path_none_allows_everything() {
        let sb = Sandbox::allow_all();
        assert!(sb.check_path("/etc/passwd", "file/read").is_ok());
        assert!(sb.check_path("relative.txt", "file/read").is_ok());
    }

    #[test]
    fn test_check_path_inside_allowed_dir() {
        let tmp = std::env::temp_dir();
        let sb = Sandbox::allow_all().with_allowed_paths(vec![tmp.clone()]);
        let test_path = tmp.join("sema-test-file.txt");
        std::fs::write(&test_path, "test").ok();
        assert!(sb
            .check_path(test_path.to_str().unwrap(), "file/read")
            .is_ok());
        let _ = std::fs::remove_file(&test_path);
    }

    #[test]
    fn test_check_path_new_file_inside_allowed_dir() {
        // A file that does not exist yet (e.g. the target of a write) must be
        // allowed when its directory is allowed.
        let tmp = std::env::temp_dir().join("sema-sandbox-new-file");
        std::fs::create_dir_all(&tmp).ok();
        let sb = Sandbox::allow_all().with_allowed_paths(vec![tmp.clone()]);
        let target = tmp.join("not-created-yet.txt");
        let _ = std::fs::remove_file(&target);
        assert!(sb
            .check_path(target.to_str().unwrap(), "file/write")
            .is_ok());
        let _ = std::fs::remove_dir_all(&tmp);
    }

    #[test]
    fn test_check_path_outside_allowed_dir() {
        let tmp = std::env::temp_dir().join("sema-sandbox-test-dir");
        std::fs::create_dir_all(&tmp).ok();
        let sb = Sandbox::allow_all().with_allowed_paths(vec![tmp.clone()]);
        let result = sb.check_path("/etc/hosts", "file/read");
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("Permission denied"), "{err}");
        assert!(
            err.to_string().contains("outside allowed directories"),
            "{err}"
        );
        let _ = std::fs::remove_dir_all(&tmp);
    }

    #[test]
    fn test_check_path_traversal_attempt() {
        let tmp = std::env::temp_dir().join("sema-sandbox-traverse");
        std::fs::create_dir_all(&tmp).ok();
        let sb = Sandbox::allow_all().with_allowed_paths(vec![tmp.clone()]);
        let evil = format!("{}/../../../etc/passwd", tmp.display());
        let result = sb.check_path(&evil, "file/read");
        assert!(result.is_err(), "path traversal should be denied");
        let _ = std::fs::remove_dir_all(&tmp);
    }

    #[test]
    fn test_check_path_multiple_allowed() {
        let dir_a = std::env::temp_dir().join("sema-sandbox-a");
        let dir_b = std::env::temp_dir().join("sema-sandbox-b");
        std::fs::create_dir_all(&dir_a).ok();
        std::fs::create_dir_all(&dir_b).ok();
        let sb = Sandbox::allow_all().with_allowed_paths(vec![dir_a.clone(), dir_b.clone()]);
        let file_a = dir_a.join("ok.txt");
        std::fs::write(&file_a, "a").ok();
        let file_b = dir_b.join("ok.txt");
        std::fs::write(&file_b, "b").ok();
        assert!(sb.check_path(file_a.to_str().unwrap(), "file/read").is_ok());
        assert!(sb.check_path(file_b.to_str().unwrap(), "file/read").is_ok());
        assert!(sb.check_path("/etc/hosts", "file/read").is_err());
        let _ = std::fs::remove_dir_all(&dir_a);
        let _ = std::fs::remove_dir_all(&dir_b);
    }

    #[test]
    fn test_check_path_empty_allow_list_denies_everything() {
        // `Some(vec![])` is "allow nothing", not "no restriction".
        let sb = Sandbox::allow_all().with_allowed_paths(Vec::new());
        assert!(!sb.is_unrestricted());
        let probe = std::env::temp_dir().join("sema-sandbox-empty-allow-list.txt");
        assert!(sb.check_path(probe.to_str().unwrap(), "file/read").is_err());
    }

    #[test]
    fn test_parse_allowed_paths() {
        let paths = Sandbox::parse_allowed_paths("/tmp, /var");
        assert_eq!(paths.len(), 2);
    }

    #[test]
    fn test_parse_allowed_paths_empty_parts() {
        let paths = Sandbox::parse_allowed_paths("/tmp,,/var,");
        assert_eq!(paths.len(), 2);
    }

    #[test]
    fn test_with_allowed_paths_makes_restricted() {
        let sb = Sandbox::allow_all().with_allowed_paths(vec![std::path::PathBuf::from("/tmp")]);
        assert!(!sb.is_unrestricted());
    }

    #[test]
    fn test_check_path_nonexistent_component_escape() {
        let tmp = std::env::temp_dir().join("sema-sandbox-escape");
        std::fs::create_dir_all(&tmp).ok();
        let sb = Sandbox::allow_all().with_allowed_paths(vec![tmp.clone()]);
        let evil = format!("{}/nonexistent/../../etc/passwd", tmp.display());
        let result = sb.check_path(&evil, "file/write");
        assert!(
            result.is_err(),
            "nonexistent component escape should be denied"
        );
        let _ = std::fs::remove_dir_all(&tmp);
    }

    #[test]
    fn test_check_path_relative_nonexistent_escape() {
        // The allowed dir is nested one level below the temp test dir, and the
        // `..` chain starts after a component that does not exist.
        let tmp = std::env::temp_dir().join("sema-sandbox-rel-escape");
        let allowed = tmp.join("allowed");
        std::fs::create_dir_all(&allowed).ok();
        let sb = Sandbox::allow_all().with_allowed_paths(vec![allowed.clone()]);
        let evil = format!("{}/fake/../../../etc/hosts", allowed.display());
        let result = sb.check_path(&evil, "file/write");
        assert!(
            result.is_err(),
            "relative nonexistent escape should be denied"
        );
        let _ = std::fs::remove_dir_all(&tmp);
    }
}