Skip to main content

sema_core/
sandbox.rs

1use std::fmt;
2use std::path::{Component, PathBuf};
3
4use crate::error::SemaError;
5
/// Bit set of runtime capabilities that a [`Sandbox`] can deny.
///
/// Every individual capability occupies one bit of the inner `u64`; sets are
/// built with [`Caps::union`] and queried with [`Caps::contains`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Caps(u64);

impl Caps {
    /// The empty set.
    pub const NONE: Caps = Caps(0);
    /// The `fs-read` capability.
    pub const FS_READ: Caps = Caps(1 << 0);
    /// The `fs-write` capability.
    pub const FS_WRITE: Caps = Caps(1 << 1);
    /// The `shell` capability.
    pub const SHELL: Caps = Caps(1 << 2);
    /// The `network` capability.
    pub const NETWORK: Caps = Caps(1 << 3);
    /// The `env-read` capability.
    pub const ENV_READ: Caps = Caps(1 << 4);
    /// The `env-write` capability.
    pub const ENV_WRITE: Caps = Caps(1 << 5);
    /// The `process` capability.
    pub const PROCESS: Caps = Caps(1 << 6);
    /// The `llm` capability.
    pub const LLM: Caps = Caps(1 << 7);
    /// The `serial` capability.
    pub const SERIAL: Caps = Caps(1 << 8);

    /// Union of every individual capability declared above.
    pub const ALL: Caps = Caps(
        Self::FS_READ.0
            | Self::FS_WRITE.0
            | Self::SHELL.0
            | Self::NETWORK.0
            | Self::ENV_READ.0
            | Self::ENV_WRITE.0
            | Self::PROCESS.0
            | Self::LLM.0
            | Self::SERIAL.0,
    );

    /// Preset containing everything except the two read-only capabilities
    /// (`FS_READ` and `ENV_READ`).
    pub const STRICT: Caps = Caps(
        Self::SHELL.0
            | Self::FS_WRITE.0
            | Self::NETWORK.0
            | Self::ENV_WRITE.0
            | Self::PROCESS.0
            | Self::LLM.0
            | Self::SERIAL.0,
    );

    /// Returns `true` when every bit of `other` is also set in `self`.
    ///
    /// The empty set is contained in every set, so `x.contains(Caps::NONE)`
    /// is always `true`.
    pub const fn contains(self, other: Caps) -> bool {
        self.0 & other.0 == other.0
    }

    /// Returns the set holding the capabilities of both `self` and `other`.
    pub const fn union(self, other: Caps) -> Caps {
        Caps(self.0 | other.0)
    }

    /// Returns the textual name of this value.
    ///
    /// Only `NONE`, the individual capabilities and the `ALL` / `STRICT`
    /// presets have names; any other combination yields `"unknown"`.
    pub fn name(self) -> &'static str {
        match self {
            Caps::NONE => "none",
            Caps::FS_READ => "fs-read",
            Caps::FS_WRITE => "fs-write",
            Caps::SHELL => "shell",
            Caps::NETWORK => "network",
            Caps::ENV_READ => "env-read",
            Caps::ENV_WRITE => "env-write",
            Caps::PROCESS => "process",
            Caps::LLM => "llm",
            Caps::SERIAL => "serial",
            Caps::ALL => "all",
            Caps::STRICT => "strict",
            _ => "unknown",
        }
    }

    /// Inverse of [`Caps::name`]: parses one of the known names.
    ///
    /// Returns `None` for anything else, including `"unknown"` and `""`.
    pub fn from_name(s: &str) -> Option<Self> {
        match s {
            "none" => Some(Caps::NONE),
            "fs-read" => Some(Caps::FS_READ),
            "fs-write" => Some(Caps::FS_WRITE),
            "shell" => Some(Caps::SHELL),
            "network" => Some(Caps::NETWORK),
            "env-read" => Some(Caps::ENV_READ),
            "env-write" => Some(Caps::ENV_WRITE),
            "process" => Some(Caps::PROCESS),
            "llm" => Some(Caps::LLM),
            "serial" => Some(Caps::SERIAL),
            "all" => Some(Caps::ALL),
            "strict" => Some(Caps::STRICT),
            _ => None,
        }
    }
}

impl fmt::Display for Caps {
    /// Writes [`Caps::name`], honouring the formatter's flags.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `pad` applies width / fill / alignment (and precision) the same way
        // `str` does; `write_str` would silently ignore e.g. `{:<10}`.
        f.pad(self.name())
    }
}
93
/// Runtime policy made of a capability deny-list and an optional list of
/// allowed path prefixes.
#[derive(Debug, Clone)]
pub struct Sandbox {
    /// Capabilities that `check` rejects.
    pub denied: Caps,
    /// Path prefixes accepted by `check_path`; `None` means paths are not
    /// restricted at all.
    allowed_paths: Option<Vec<PathBuf>>,
}
99
/// Collapses `.` and `..` components of `path` purely textually, without
/// touching the filesystem (symlinks are therefore not resolved).
///
/// * `.` components are dropped.
/// * `..` removes the preceding normal component.
/// * `..` directly after a root (or Windows prefix) is dropped, because the
///   parent of the root is the root itself.
/// * `..` with nothing left to cancel (only possible for relative paths) is
///   kept, so a relative path that climbs out of its starting directory is
///   never silently re-rooted inside it.
fn normalize_lexical(path: &std::path::Path) -> PathBuf {
    let mut result = PathBuf::new();
    for component in path.components() {
        match component {
            Component::CurDir => {}
            Component::ParentDir => {
                // Inspect the tail first and reduce it to plain booleans so the
                // borrow of `result` has ended before it is mutated.
                let tail = result.components().next_back();
                let cancels_dir = matches!(tail, Some(Component::Normal(_)));
                let at_root = matches!(
                    tail,
                    Some(Component::RootDir) | Some(Component::Prefix(_))
                );
                if cancels_dir {
                    result.pop();
                } else if !at_root {
                    // Empty so far, or already climbing: preserve the `..`.
                    result.push(component);
                }
            }
            other => result.push(other),
        }
    }
    result
}
113
114impl Sandbox {
115    pub fn allow_all() -> Self {
116        Sandbox {
117            denied: Caps::NONE,
118            allowed_paths: None,
119        }
120    }
121
122    pub fn deny(caps: Caps) -> Self {
123        Sandbox {
124            denied: caps,
125            allowed_paths: None,
126        }
127    }
128
129    pub fn with_allowed_paths(mut self, paths: Vec<PathBuf>) -> Self {
130        self.allowed_paths = Some(
131            paths
132                .into_iter()
133                .map(|p| std::fs::canonicalize(&p).unwrap_or(p))
134                .collect(),
135        );
136        self
137    }
138
139    pub fn is_unrestricted(&self) -> bool {
140        self.denied == Caps::NONE && self.allowed_paths.is_none()
141    }
142
143    pub fn check(&self, required: Caps, fn_name: &str) -> Result<(), SemaError> {
144        if self.denied.contains(required) {
145            Err(SemaError::PermissionDenied {
146                function: fn_name.to_string(),
147                capability: required.name().to_string(),
148            })
149        } else {
150            Ok(())
151        }
152    }
153
154    // NOTE: check_path validates the path before the file operation, not during it.
155    // This means a TOCTOU (time-of-check-to-time-of-use) window exists where an external
156    // process could swap a symlink between the check and the actual fs operation. Mitigating
157    // this properly requires OS-specific secure open patterns (openat with O_NOFOLLOW, dirfds)
158    // which is a significantly larger change. The current approach is on par with most
159    // scripting language sandboxes.
160    pub fn check_path(&self, path: &str, fn_name: &str) -> Result<(), SemaError> {
161        let allowed = match &self.allowed_paths {
162            Some(paths) => paths,
163            None => return Ok(()),
164        };
165        let p = std::path::Path::new(path);
166        let canonical = std::fs::canonicalize(p).unwrap_or_else(|_| {
167            if let Some(parent) = p.parent() {
168                if let Ok(canon_parent) = std::fs::canonicalize(parent) {
169                    return canon_parent.join(p.file_name().unwrap_or_default());
170                }
171            }
172            let abs = if p.is_absolute() {
173                p.to_path_buf()
174            } else {
175                std::env::current_dir()
176                    .unwrap_or_else(|_| PathBuf::from("."))
177                    .join(p)
178            };
179            normalize_lexical(&abs)
180        });
181        for allowed_path in allowed {
182            if canonical.starts_with(allowed_path) {
183                return Ok(());
184            }
185        }
186        Err(SemaError::PathDenied {
187            function: fn_name.to_string(),
188            path: canonical.display().to_string(),
189        })
190    }
191
192    pub fn parse_allowed_paths(value: &str) -> Vec<PathBuf> {
193        value
194            .split(',')
195            .map(|s| s.trim())
196            .filter(|s| !s.is_empty())
197            .map(|s| {
198                let p = PathBuf::from(s);
199                std::fs::canonicalize(&p).unwrap_or(p)
200            })
201            .collect()
202    }
203
204    pub fn parse_cli(value: &str) -> Result<Self, String> {
205        match value {
206            "strict" => Ok(Sandbox::deny(Caps::STRICT)),
207            "all" => Ok(Sandbox::deny(Caps::ALL)),
208            other => {
209                let mut denied = Caps::NONE;
210                for part in other.split(',') {
211                    let part = part.trim();
212                    if part.is_empty() {
213                        continue;
214                    }
215                    let name = part.strip_prefix("no-").unwrap_or(part);
216                    match Caps::from_name(name) {
217                        Some(cap) => denied = denied.union(cap),
218                        None => return Err(format!("unknown capability: {name}")),
219                    }
220                }
221                Ok(Sandbox::deny(denied))
222            }
223        }
224    }
225}
226
227impl Default for Sandbox {
228    fn default() -> Self {
229        Sandbox::allow_all()
230    }
231}
232
// Unit tests for `Caps` (bit-set operations, names, presets) and `Sandbox`
// (capability checks, CLI parsing, allowed-path checks).
//
// The path tests create and remove directories under `std::env::temp_dir()`
// and refer to `/etc/hosts` and `/etc/passwd`, so they presumably target
// Unix-like systems.
#[cfg(test)]
mod tests {
    use super::*;

    // `contains` is true exactly for the capabilities that were unioned in.
    #[test]
    fn test_caps_contains() {
        let all = Caps::FS_READ.union(Caps::FS_WRITE).union(Caps::SHELL);
        assert!(all.contains(Caps::FS_READ));
        assert!(all.contains(Caps::FS_WRITE));
        assert!(all.contains(Caps::SHELL));
        assert!(!all.contains(Caps::NETWORK));
        assert!(!all.contains(Caps::LLM));
    }

    #[test]
    fn test_caps_contains_none_is_always_true() {
        // Every set contains the empty set: `x & 0 == 0` holds for any `x`,
        // so `contains(NONE)` is true even on `NONE` itself.
        assert!(Caps::NONE.contains(Caps::NONE));
        assert!(Caps::ALL.contains(Caps::NONE));
    }

    // `union` adds the bits of both operands and nothing else.
    #[test]
    fn test_caps_union() {
        let combined = Caps::SHELL.union(Caps::NETWORK);
        assert!(combined.contains(Caps::SHELL));
        assert!(combined.contains(Caps::NETWORK));
        assert!(!combined.contains(Caps::FS_READ));
    }

    // `ALL` must cover every individual capability constant.
    #[test]
    fn test_caps_all_contains_every_cap() {
        assert!(Caps::ALL.contains(Caps::FS_READ));
        assert!(Caps::ALL.contains(Caps::FS_WRITE));
        assert!(Caps::ALL.contains(Caps::SHELL));
        assert!(Caps::ALL.contains(Caps::NETWORK));
        assert!(Caps::ALL.contains(Caps::ENV_READ));
        assert!(Caps::ALL.contains(Caps::ENV_WRITE));
        assert!(Caps::ALL.contains(Caps::PROCESS));
        assert!(Caps::ALL.contains(Caps::LLM));
        assert!(Caps::ALL.contains(Caps::SERIAL));
    }

    #[test]
    fn test_caps_strict_preset() {
        assert!(Caps::STRICT.contains(Caps::SHELL));
        assert!(Caps::STRICT.contains(Caps::FS_WRITE));
        assert!(Caps::STRICT.contains(Caps::NETWORK));
        assert!(Caps::STRICT.contains(Caps::ENV_WRITE));
        assert!(Caps::STRICT.contains(Caps::PROCESS));
        assert!(Caps::STRICT.contains(Caps::LLM));
        assert!(Caps::STRICT.contains(Caps::SERIAL));
        // strict does NOT deny read-only operations
        assert!(!Caps::STRICT.contains(Caps::FS_READ));
        assert!(!Caps::STRICT.contains(Caps::ENV_READ));
    }

    // `from_name(name(cap))` must give back `cap` for every individual capability.
    #[test]
    fn test_caps_name_roundtrip() {
        let caps = [
            Caps::FS_READ,
            Caps::FS_WRITE,
            Caps::SHELL,
            Caps::NETWORK,
            Caps::ENV_READ,
            Caps::ENV_WRITE,
            Caps::PROCESS,
            Caps::LLM,
            Caps::SERIAL,
        ];
        for cap in caps {
            let name = cap.name();
            assert_eq!(
                Caps::from_name(name),
                Some(cap),
                "roundtrip failed for {name}"
            );
        }
    }

    // Unrecognised names, including the empty string, parse to `None`.
    #[test]
    fn test_caps_from_name_unknown() {
        assert_eq!(Caps::from_name("garbage"), None);
        assert_eq!(Caps::from_name(""), None);
    }

    // `Display` prints the same text as `name`.
    #[test]
    fn test_caps_display() {
        assert_eq!(format!("{}", Caps::SHELL), "shell");
        assert_eq!(format!("{}", Caps::NETWORK), "network");
        assert_eq!(format!("{}", Caps::FS_READ), "fs-read");
        assert_eq!(format!("{}", Caps::SERIAL), "serial");
    }

    #[test]
    fn test_sandbox_allow_all_is_unrestricted() {
        let sb = Sandbox::allow_all();
        assert!(sb.is_unrestricted());
    }

    #[test]
    fn test_sandbox_deny_is_restricted() {
        let sb = Sandbox::deny(Caps::SHELL);
        assert!(!sb.is_unrestricted());
    }

    #[test]
    fn test_sandbox_default_is_unrestricted() {
        let sb = Sandbox::default();
        assert!(sb.is_unrestricted());
    }

    // Capabilities that are not on the deny-list pass `check`.
    #[test]
    fn test_sandbox_check_allowed() {
        let sb = Sandbox::deny(Caps::SHELL);
        assert!(sb.check(Caps::NETWORK, "http/get").is_ok());
        assert!(sb.check(Caps::FS_READ, "file/read").is_ok());
    }

    // A denied capability yields an error whose text mentions the denial.
    #[test]
    fn test_sandbox_check_denied() {
        let sb = Sandbox::deny(Caps::SHELL);
        let err = sb.check(Caps::SHELL, "shell").unwrap_err();
        assert!(err.to_string().contains("Permission denied"));
        assert!(err.to_string().contains("shell"));
    }

    // The error text carries both the function name and the capability name.
    #[test]
    fn test_sandbox_check_denied_error_format() {
        let sb = Sandbox::deny(Caps::NETWORK);
        let err = sb.check(Caps::NETWORK, "http/get").unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("http/get"),
            "should contain function name: {msg}"
        );
        assert!(
            msg.contains("network"),
            "should contain capability name: {msg}"
        );
    }

    #[test]
    fn test_sandbox_check_multiple_denied() {
        let sb = Sandbox::deny(Caps::SHELL.union(Caps::NETWORK));
        assert!(sb.check(Caps::SHELL, "shell").is_err());
        assert!(sb.check(Caps::NETWORK, "http/get").is_err());
        assert!(sb.check(Caps::FS_READ, "file/read").is_ok());
    }

    #[test]
    fn test_sandbox_parse_cli_strict() {
        let sb = Sandbox::parse_cli("strict").unwrap();
        assert!(sb.check(Caps::SHELL, "shell").is_err());
        assert!(sb.check(Caps::FS_WRITE, "file/write").is_err());
        assert!(sb.check(Caps::NETWORK, "http/get").is_err());
        assert!(sb.check(Caps::SERIAL, "serial/list").is_err());
        // strict allows reads
        assert!(sb.check(Caps::FS_READ, "file/read").is_ok());
        assert!(sb.check(Caps::ENV_READ, "env").is_ok());
    }

    // "all" denies the read capabilities as well.
    #[test]
    fn test_sandbox_parse_cli_all() {
        let sb = Sandbox::parse_cli("all").unwrap();
        assert!(sb.check(Caps::SHELL, "shell").is_err());
        assert!(sb.check(Caps::FS_READ, "file/read").is_err());
        assert!(sb.check(Caps::ENV_READ, "env").is_err());
        assert!(sb.check(Caps::SERIAL, "serial/list").is_err());
    }

    // Entries may carry a `no-` prefix ...
    #[test]
    fn test_sandbox_parse_cli_no_prefix() {
        let sb = Sandbox::parse_cli("no-shell,no-network").unwrap();
        assert!(sb.check(Caps::SHELL, "shell").is_err());
        assert!(sb.check(Caps::NETWORK, "http/get").is_err());
        assert!(sb.check(Caps::FS_READ, "file/read").is_ok());
    }

    // ... or omit it; both spellings deny the named capability.
    #[test]
    fn test_sandbox_parse_cli_without_no_prefix() {
        let sb = Sandbox::parse_cli("shell,network").unwrap();
        assert!(sb.check(Caps::SHELL, "shell").is_err());
        assert!(sb.check(Caps::NETWORK, "http/get").is_err());
        assert!(sb.check(Caps::FS_READ, "file/read").is_ok());
    }

    #[test]
    fn test_sandbox_parse_cli_single() {
        let sb = Sandbox::parse_cli("no-fs-write").unwrap();
        assert!(sb.check(Caps::FS_WRITE, "file/write").is_err());
        assert!(sb.check(Caps::FS_READ, "file/read").is_ok());
    }

    // Whitespace around entries is ignored.
    #[test]
    fn test_sandbox_parse_cli_with_spaces() {
        let sb = Sandbox::parse_cli("no-shell, no-network").unwrap();
        assert!(sb.check(Caps::SHELL, "shell").is_err());
        assert!(sb.check(Caps::NETWORK, "http/get").is_err());
    }

    // Empty entries between commas are skipped.
    #[test]
    fn test_sandbox_parse_cli_empty_parts() {
        let sb = Sandbox::parse_cli("no-shell,,no-network").unwrap();
        assert!(sb.check(Caps::SHELL, "shell").is_err());
        assert!(sb.check(Caps::NETWORK, "http/get").is_err());
    }

    // One unknown entry makes the whole list invalid.
    #[test]
    fn test_sandbox_parse_cli_invalid() {
        assert!(Sandbox::parse_cli("no-bogus").is_err());
        assert!(Sandbox::parse_cli("no-shell,no-bogus").is_err());
    }

    // Without an allowed-path list every path passes.
    #[test]
    fn test_check_path_none_allows_everything() {
        let sb = Sandbox::allow_all();
        assert!(sb.check_path("/etc/passwd", "file/read").is_ok());
        assert!(sb.check_path("relative.txt", "file/read").is_ok());
    }

    // An existing file directly inside the allowed directory passes.
    #[test]
    fn test_check_path_inside_allowed_dir() {
        let tmp = std::env::temp_dir();
        let sb = Sandbox::allow_all().with_allowed_paths(vec![tmp.clone()]);
        let test_path = tmp.join("sema-test-file.txt");
        std::fs::write(&test_path, "test").ok();
        assert!(sb
            .check_path(test_path.to_str().unwrap(), "file/read")
            .is_ok());
        let _ = std::fs::remove_file(&test_path);
    }

    // A path outside the allowed directory is rejected with the expected error text.
    #[test]
    fn test_check_path_outside_allowed_dir() {
        let tmp = std::env::temp_dir().join("sema-sandbox-test-dir");
        std::fs::create_dir_all(&tmp).ok();
        let sb = Sandbox::allow_all().with_allowed_paths(vec![tmp.clone()]);
        let result = sb.check_path("/etc/hosts", "file/read");
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("Permission denied"), "{err}");
        assert!(
            err.to_string().contains("outside allowed directories"),
            "{err}"
        );
        let _ = std::fs::remove_dir_all(&tmp);
    }

    // `..` components that climb out of the allowed directory are rejected.
    #[test]
    fn test_check_path_traversal_attempt() {
        let tmp = std::env::temp_dir().join("sema-sandbox-traverse");
        std::fs::create_dir_all(&tmp).ok();
        let sb = Sandbox::allow_all().with_allowed_paths(vec![tmp.clone()]);
        let evil = format!("{}/../../../etc/passwd", tmp.display());
        let result = sb.check_path(&evil, "file/read");
        assert!(result.is_err(), "path traversal should be denied");
        let _ = std::fs::remove_dir_all(&tmp);
    }

    // A path is accepted when it lies inside any one of several allowed directories.
    #[test]
    fn test_check_path_multiple_allowed() {
        let dir_a = std::env::temp_dir().join("sema-sandbox-a");
        let dir_b = std::env::temp_dir().join("sema-sandbox-b");
        std::fs::create_dir_all(&dir_a).ok();
        std::fs::create_dir_all(&dir_b).ok();
        let sb = Sandbox::allow_all().with_allowed_paths(vec![dir_a.clone(), dir_b.clone()]);
        let file_a = dir_a.join("ok.txt");
        std::fs::write(&file_a, "a").ok();
        let file_b = dir_b.join("ok.txt");
        std::fs::write(&file_b, "b").ok();
        assert!(sb.check_path(file_a.to_str().unwrap(), "file/read").is_ok());
        assert!(sb.check_path(file_b.to_str().unwrap(), "file/read").is_ok());
        assert!(sb.check_path("/etc/hosts", "file/read").is_err());
        let _ = std::fs::remove_dir_all(&dir_a);
        let _ = std::fs::remove_dir_all(&dir_b);
    }

    // Only the number of parsed entries is checked here, not their values.
    #[test]
    fn test_parse_allowed_paths() {
        let paths = Sandbox::parse_allowed_paths("/tmp, /var");
        assert_eq!(paths.len(), 2);
    }

    // Empty entries (doubled or trailing commas) are dropped.
    #[test]
    fn test_parse_allowed_paths_empty_parts() {
        let paths = Sandbox::parse_allowed_paths("/tmp,,/var,");
        assert_eq!(paths.len(), 2);
    }

    // Setting allowed paths alone is enough to make a sandbox restricted.
    #[test]
    fn test_with_allowed_paths_makes_restricted() {
        let sb = Sandbox::allow_all().with_allowed_paths(vec![std::path::PathBuf::from("/tmp")]);
        assert!(!sb.is_unrestricted());
    }

    // Escape attempt through a component that does not exist: the path cannot
    // be canonicalized as a whole, yet the `..` components must still be honoured.
    #[test]
    fn test_check_path_nonexistent_component_escape() {
        let tmp = std::env::temp_dir().join("sema-sandbox-escape");
        std::fs::create_dir_all(&tmp).ok();
        let sb = Sandbox::allow_all().with_allowed_paths(vec![tmp.clone()]);
        let evil = format!("{}/nonexistent/../../etc/passwd", tmp.display());
        let result = sb.check_path(&evil, "file/write");
        assert!(
            result.is_err(),
            "nonexistent component escape should be denied"
        );
        let _ = std::fs::remove_dir_all(&tmp);
    }

    // Same idea with the allowed directory nested one level deeper and more
    // `..` components than the nonexistent part accounts for.
    #[test]
    fn test_check_path_relative_nonexistent_escape() {
        let tmp = std::env::temp_dir().join("sema-sandbox-rel-escape");
        let allowed = tmp.join("allowed");
        std::fs::create_dir_all(&allowed).ok();
        let sb = Sandbox::allow_all().with_allowed_paths(vec![allowed.clone()]);
        let evil = format!("{}/fake/../../../etc/hosts", allowed.display());
        let result = sb.check_path(&evil, "file/write");
        assert!(
            result.is_err(),
            "relative nonexistent escape should be denied"
        );
        let _ = std::fs::remove_dir_all(&tmp);
    }
}