Skip to main content

sema_core/
sandbox.rs

1use std::fmt;
2
3use crate::error::SemaError;
4
/// A set of sandbox capabilities, stored as a bitmask.
///
/// Each single-bit constant (`FS_READ`, `SHELL`, ...) occupies its own bit;
/// `ALL` and `STRICT` are preset combinations. Sets are combined with
/// [`Caps::union`] and tested with [`Caps::contains`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Caps(u64);

impl Caps {
    /// The empty set.
    pub const NONE: Caps = Caps(0);

    // Single-capability flags, one bit each.
    pub const FS_READ: Caps = Caps(1 << 0);
    pub const FS_WRITE: Caps = Caps(1 << 1);
    pub const SHELL: Caps = Caps(1 << 2);
    pub const NETWORK: Caps = Caps(1 << 3);
    pub const ENV_READ: Caps = Caps(1 << 4);
    pub const ENV_WRITE: Caps = Caps(1 << 5);
    pub const PROCESS: Caps = Caps(1 << 6);
    pub const LLM: Caps = Caps(1 << 7);

    /// Every single-capability flag.
    pub const ALL: Caps = Caps(
        Self::FS_READ.0
            | Self::FS_WRITE.0
            | Self::SHELL.0
            | Self::NETWORK.0
            | Self::ENV_READ.0
            | Self::ENV_WRITE.0
            | Self::PROCESS.0
            | Self::LLM.0,
    );

    /// Every capability except `FS_READ` and `ENV_READ`.
    pub const STRICT: Caps = Caps(
        Self::SHELL.0
            | Self::FS_WRITE.0
            | Self::NETWORK.0
            | Self::ENV_WRITE.0
            | Self::PROCESS.0
            | Self::LLM.0,
    );

    /// The single-capability flags in bit order; used to spell out
    /// composite masks that have no canonical name.
    const SINGLES: [Caps; 8] = [
        Self::FS_READ,
        Self::FS_WRITE,
        Self::SHELL,
        Self::NETWORK,
        Self::ENV_READ,
        Self::ENV_WRITE,
        Self::PROCESS,
        Self::LLM,
    ];

    /// Returns `true` if every capability in `other` is also in `self`.
    ///
    /// As with any subset test, `x.contains(Caps::NONE)` is `true` for all `x`.
    pub const fn contains(self, other: Caps) -> bool {
        self.0 & other.0 == other.0
    }

    /// Returns the set of capabilities present in `self`, `other`, or both.
    pub const fn union(self, other: Caps) -> Caps {
        Caps(self.0 | other.0)
    }

    /// Returns the canonical name of this exact mask.
    ///
    /// Only the named constants have a name; any other mask (for example an
    /// ad-hoc union of two flags) yields `"unknown"`.
    pub fn name(self) -> &'static str {
        match self {
            Caps::NONE => "none",
            Caps::FS_READ => "fs-read",
            Caps::FS_WRITE => "fs-write",
            Caps::SHELL => "shell",
            Caps::NETWORK => "network",
            Caps::ENV_READ => "env-read",
            Caps::ENV_WRITE => "env-write",
            Caps::PROCESS => "process",
            Caps::LLM => "llm",
            Caps::ALL => "all",
            Caps::STRICT => "strict",
            _ => "unknown",
        }
    }

    /// Parses a canonical capability name, the inverse of [`Caps::name`].
    ///
    /// Returns `None` for any unrecognised string, including `"unknown"`.
    pub fn from_name(s: &str) -> Option<Self> {
        match s {
            "none" => Some(Caps::NONE),
            "fs-read" => Some(Caps::FS_READ),
            "fs-write" => Some(Caps::FS_WRITE),
            "shell" => Some(Caps::SHELL),
            "network" => Some(Caps::NETWORK),
            "env-read" => Some(Caps::ENV_READ),
            "env-write" => Some(Caps::ENV_WRITE),
            "process" => Some(Caps::PROCESS),
            "llm" => Some(Caps::LLM),
            "all" => Some(Caps::ALL),
            "strict" => Some(Caps::STRICT),
            _ => None,
        }
    }
}

impl fmt::Display for Caps {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = self.name();
        // Named masks print their canonical name. Masks carrying bits outside
        // `ALL` cannot be described, so they keep printing "unknown".
        if name != "unknown" || self.0 & !Caps::ALL.0 != 0 {
            return f.write_str(name);
        }
        // Composite of known flags: list members comma-separated in bit order,
        // e.g. "shell,network", instead of the uninformative "unknown".
        let members = Caps::SINGLES.iter().filter(|&&cap| self.contains(cap));
        for (i, cap) in members.enumerate() {
            if i > 0 {
                f.write_str(",")?;
            }
            f.write_str(cap.name())?;
        }
        Ok(())
    }
}
87
88#[derive(Debug, Clone)]
89pub struct Sandbox {
90    pub denied: Caps,
91}
92
93impl Sandbox {
94    pub fn allow_all() -> Self {
95        Sandbox { denied: Caps::NONE }
96    }
97
98    pub fn deny(caps: Caps) -> Self {
99        Sandbox { denied: caps }
100    }
101
102    pub fn is_unrestricted(&self) -> bool {
103        self.denied == Caps::NONE
104    }
105
106    pub fn check(&self, required: Caps, fn_name: &str) -> Result<(), SemaError> {
107        if self.denied.contains(required) {
108            Err(SemaError::PermissionDenied {
109                function: fn_name.to_string(),
110                capability: required.name().to_string(),
111            })
112        } else {
113            Ok(())
114        }
115    }
116
117    pub fn parse_cli(value: &str) -> Result<Self, String> {
118        match value {
119            "strict" => Ok(Sandbox::deny(Caps::STRICT)),
120            "all" => Ok(Sandbox::deny(Caps::ALL)),
121            other => {
122                let mut denied = Caps::NONE;
123                for part in other.split(',') {
124                    let part = part.trim();
125                    if part.is_empty() {
126                        continue;
127                    }
128                    let name = part.strip_prefix("no-").unwrap_or(part);
129                    match Caps::from_name(name) {
130                        Some(cap) => denied = denied.union(cap),
131                        None => return Err(format!("unknown capability: {name}")),
132                    }
133                }
134                Ok(Sandbox::deny(denied))
135            }
136        }
137    }
138}
139
140impl Default for Sandbox {
141    fn default() -> Self {
142        Sandbox::allow_all()
143    }
144}
145
#[cfg(test)]
mod tests {
    use super::*;

    /// Every single-bit capability, in declaration order.
    const SINGLE_CAPS: [Caps; 8] = [
        Caps::FS_READ,
        Caps::FS_WRITE,
        Caps::SHELL,
        Caps::NETWORK,
        Caps::ENV_READ,
        Caps::ENV_WRITE,
        Caps::PROCESS,
        Caps::LLM,
    ];

    // (capability, function name) pairs passed to `Sandbox::check`.
    const SHELL_CALL: (Caps, &str) = (Caps::SHELL, "shell");
    const HTTP_GET_CALL: (Caps, &str) = (Caps::NETWORK, "http/get");
    const FILE_READ_CALL: (Caps, &str) = (Caps::FS_READ, "file/read");
    const FILE_WRITE_CALL: (Caps, &str) = (Caps::FS_WRITE, "file/write");
    const ENV_CALL: (Caps, &str) = (Caps::ENV_READ, "env");

    /// Asserts that `sandbox` rejects every call in `denied` and accepts
    /// every call in `allowed`.
    fn assert_policy(sandbox: &Sandbox, denied: &[(Caps, &str)], allowed: &[(Caps, &str)]) {
        for &(cap, func) in denied {
            assert!(sandbox.check(cap, func).is_err());
        }
        for &(cap, func) in allowed {
            assert!(sandbox.check(cap, func).is_ok());
        }
    }

    #[test]
    fn test_caps_contains() {
        let set = Caps::FS_READ.union(Caps::FS_WRITE).union(Caps::SHELL);
        for member in [Caps::FS_READ, Caps::FS_WRITE, Caps::SHELL] {
            assert!(set.contains(member));
        }
        for outsider in [Caps::NETWORK, Caps::LLM] {
            assert!(!set.contains(outsider));
        }
    }

    #[test]
    fn test_caps_contains_none_is_always_true() {
        // The empty set is a subset of every set: (x & 0) == 0 for any mask.
        // NOTE(review): a permission check built on `contains` alone would
        // therefore deny an empty requirement; `Sandbox::check` must account
        // for that.
        assert!(Caps::NONE.contains(Caps::NONE));
        assert!(Caps::ALL.contains(Caps::NONE));
    }

    #[test]
    fn test_caps_union() {
        let joined = Caps::SHELL.union(Caps::NETWORK);
        assert!(joined.contains(Caps::SHELL) && joined.contains(Caps::NETWORK));
        assert!(!joined.contains(Caps::FS_READ));
    }

    #[test]
    fn test_caps_all_contains_every_cap() {
        for cap in SINGLE_CAPS {
            assert!(Caps::ALL.contains(cap));
        }
    }

    #[test]
    fn test_caps_strict_preset() {
        let blocked = [
            Caps::SHELL,
            Caps::FS_WRITE,
            Caps::NETWORK,
            Caps::ENV_WRITE,
            Caps::PROCESS,
            Caps::LLM,
        ];
        assert!(blocked.iter().all(|&cap| Caps::STRICT.contains(cap)));
        // Read-only capabilities are left out of the strict preset.
        for read_only in [Caps::FS_READ, Caps::ENV_READ] {
            assert!(!Caps::STRICT.contains(read_only));
        }
    }

    #[test]
    fn test_caps_name_roundtrip() {
        for cap in SINGLE_CAPS {
            let name = cap.name();
            assert_eq!(
                Caps::from_name(name),
                Some(cap),
                "roundtrip failed for {name}"
            );
        }
    }

    #[test]
    fn test_caps_from_name_unknown() {
        for bogus in ["garbage", ""] {
            assert_eq!(Caps::from_name(bogus), None);
        }
    }

    #[test]
    fn test_caps_display() {
        let expectations = [
            (Caps::SHELL, "shell"),
            (Caps::NETWORK, "network"),
            (Caps::FS_READ, "fs-read"),
        ];
        for (cap, text) in expectations {
            assert_eq!(cap.to_string(), text);
        }
    }

    #[test]
    fn test_sandbox_allow_all_is_unrestricted() {
        assert!(Sandbox::allow_all().is_unrestricted());
    }

    #[test]
    fn test_sandbox_deny_is_restricted() {
        assert!(!Sandbox::deny(Caps::SHELL).is_unrestricted());
    }

    #[test]
    fn test_sandbox_default_is_unrestricted() {
        assert!(Sandbox::default().is_unrestricted());
    }

    #[test]
    fn test_sandbox_check_allowed() {
        assert_policy(
            &Sandbox::deny(Caps::SHELL),
            &[],
            &[HTTP_GET_CALL, FILE_READ_CALL],
        );
    }

    #[test]
    fn test_sandbox_check_denied() {
        let message = Sandbox::deny(Caps::SHELL)
            .check(Caps::SHELL, "shell")
            .unwrap_err()
            .to_string();
        assert!(message.contains("Permission denied"));
        assert!(message.contains("shell"));
    }

    #[test]
    fn test_sandbox_check_denied_error_format() {
        let msg = Sandbox::deny(Caps::NETWORK)
            .check(Caps::NETWORK, "http/get")
            .unwrap_err()
            .to_string();
        assert!(
            msg.contains("http/get"),
            "should contain function name: {msg}"
        );
        assert!(
            msg.contains("network"),
            "should contain capability name: {msg}"
        );
    }

    #[test]
    fn test_sandbox_check_multiple_denied() {
        assert_policy(
            &Sandbox::deny(Caps::SHELL.union(Caps::NETWORK)),
            &[SHELL_CALL, HTTP_GET_CALL],
            &[FILE_READ_CALL],
        );
    }

    #[test]
    fn test_sandbox_parse_cli_strict() {
        let sandbox = Sandbox::parse_cli("strict").unwrap();
        // strict still allows reads
        assert_policy(
            &sandbox,
            &[SHELL_CALL, FILE_WRITE_CALL, HTTP_GET_CALL],
            &[FILE_READ_CALL, ENV_CALL],
        );
    }

    #[test]
    fn test_sandbox_parse_cli_all() {
        let sandbox = Sandbox::parse_cli("all").unwrap();
        assert_policy(&sandbox, &[SHELL_CALL, FILE_READ_CALL, ENV_CALL], &[]);
    }

    #[test]
    fn test_sandbox_parse_cli_no_prefix() {
        let sandbox = Sandbox::parse_cli("no-shell,no-network").unwrap();
        assert_policy(&sandbox, &[SHELL_CALL, HTTP_GET_CALL], &[FILE_READ_CALL]);
    }

    #[test]
    fn test_sandbox_parse_cli_without_no_prefix() {
        let sandbox = Sandbox::parse_cli("shell,network").unwrap();
        assert_policy(&sandbox, &[SHELL_CALL, HTTP_GET_CALL], &[FILE_READ_CALL]);
    }

    #[test]
    fn test_sandbox_parse_cli_single() {
        let sandbox = Sandbox::parse_cli("no-fs-write").unwrap();
        assert_policy(&sandbox, &[FILE_WRITE_CALL], &[FILE_READ_CALL]);
    }

    #[test]
    fn test_sandbox_parse_cli_with_spaces() {
        let sandbox = Sandbox::parse_cli("no-shell, no-network").unwrap();
        assert_policy(&sandbox, &[SHELL_CALL, HTTP_GET_CALL], &[]);
    }

    #[test]
    fn test_sandbox_parse_cli_empty_parts() {
        let sandbox = Sandbox::parse_cli("no-shell,,no-network").unwrap();
        assert_policy(&sandbox, &[SHELL_CALL, HTTP_GET_CALL], &[]);
    }

    #[test]
    fn test_sandbox_parse_cli_invalid() {
        for spec in ["no-bogus", "no-shell,no-bogus"] {
            assert!(Sandbox::parse_cli(spec).is_err());
        }
    }
}