1use std::fmt;
2
3use crate::error::SemaError;
4
/// A set of sandbox capabilities stored as a `u64` bit mask.
///
/// The associated constants are either single capability bits or presets
/// built from several bits. Combine sets with [`Caps::union`]; test them with
/// [`Caps::contains`] (subset) or [`Caps::intersects`] (any overlap).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct Caps(u64);

impl Caps {
    /// The empty set. Also the value produced by `Caps::default()`.
    pub const NONE: Caps = Caps(0);
    /// Capability bit named `fs-read`.
    pub const FS_READ: Caps = Caps(1 << 0);
    /// Capability bit named `fs-write`.
    pub const FS_WRITE: Caps = Caps(1 << 1);
    /// Capability bit named `shell`.
    pub const SHELL: Caps = Caps(1 << 2);
    /// Capability bit named `network`.
    pub const NETWORK: Caps = Caps(1 << 3);
    /// Capability bit named `env-read`.
    pub const ENV_READ: Caps = Caps(1 << 4);
    /// Capability bit named `env-write`.
    pub const ENV_WRITE: Caps = Caps(1 << 5);
    /// Capability bit named `process`.
    pub const PROCESS: Caps = Caps(1 << 6);
    /// Capability bit named `llm`.
    pub const LLM: Caps = Caps(1 << 7);

    /// Every single-capability bit defined above.
    pub const ALL: Caps = Caps(
        Self::FS_READ.0
            | Self::FS_WRITE.0
            | Self::SHELL.0
            | Self::NETWORK.0
            | Self::ENV_READ.0
            | Self::ENV_WRITE.0
            | Self::PROCESS.0
            | Self::LLM.0,
    );

    /// Every single-capability bit except `FS_READ` and `ENV_READ`.
    /// `Sandbox::parse_cli("strict")` denies exactly this set.
    pub const STRICT: Caps = Caps(
        Self::SHELL.0
            | Self::FS_WRITE.0
            | Self::NETWORK.0
            | Self::ENV_WRITE.0
            | Self::PROCESS.0
            | Self::LLM.0,
    );

    /// `(value, canonical name)` for every named constant.
    ///
    /// Both [`Caps::name`] and [`Caps::from_name`] read this single table, so
    /// the two lookup directions cannot drift apart. Values are pairwise
    /// distinct, as are names.
    const NAMED: [(Caps, &'static str); 11] = [
        (Self::NONE, "none"),
        (Self::FS_READ, "fs-read"),
        (Self::FS_WRITE, "fs-write"),
        (Self::SHELL, "shell"),
        (Self::NETWORK, "network"),
        (Self::ENV_READ, "env-read"),
        (Self::ENV_WRITE, "env-write"),
        (Self::PROCESS, "process"),
        (Self::LLM, "llm"),
        (Self::ALL, "all"),
        (Self::STRICT, "strict"),
    ];

    /// Returns `true` if every capability in `other` is also in `self`.
    ///
    /// The empty set (`Caps::NONE`) is contained in every set.
    #[must_use]
    pub fn contains(self, other: Caps) -> bool {
        self.0 & other.0 == other.0
    }

    /// Returns `true` if `self` and `other` share at least one capability.
    #[must_use]
    pub fn intersects(self, other: Caps) -> bool {
        self.0 & other.0 != 0
    }

    /// Returns the set of capabilities present in `self` or `other`.
    #[must_use]
    pub fn union(self, other: Caps) -> Caps {
        Caps(self.0 | other.0)
    }

    /// Returns the canonical name of this exact value, or `"unknown"` when the
    /// value (e.g. an ad-hoc combination) has no entry in the name table.
    #[must_use]
    pub fn name(self) -> &'static str {
        Self::NAMED
            .iter()
            .find(|(caps, _)| *caps == self)
            .map_or("unknown", |&(_, name)| name)
    }

    /// Looks up a value by its canonical name (exact, case-sensitive match).
    ///
    /// Returns `None` for any name not in the table, including `""`.
    #[must_use]
    pub fn from_name(s: &str) -> Option<Self> {
        Self::NAMED
            .iter()
            .find(|(_, name)| *name == s)
            .map(|&(caps, _)| caps)
    }
}

impl fmt::Display for Caps {
    /// Writes the canonical name for named values.
    ///
    /// A combination with no canonical name is written as its single-capability
    /// names joined by `,` (the list format `Sandbox::parse_cli` accepts)
    /// instead of the uninformative `"unknown"`. Any bits outside every named
    /// capability are reported as a trailing `unknown`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = self.name();
        if name != "unknown" {
            return f.write_str(name);
        }

        let mut covered = 0u64;
        let mut sep = "";
        let singles = Self::NAMED
            .iter()
            .filter(|(caps, _)| caps.0.count_ones() == 1);
        for &(flag, flag_name) in singles {
            if self.contains(flag) {
                f.write_str(sep)?;
                f.write_str(flag_name)?;
                sep = ",";
                covered |= flag.0;
            }
        }
        // Bits not belonging to any named capability can only be built inside
        // this module, but report them rather than silently dropping them.
        if self.0 & !covered != 0 {
            f.write_str(sep)?;
            f.write_str("unknown")?;
        }
        Ok(())
    }
}
87
88#[derive(Debug, Clone)]
89pub struct Sandbox {
90 pub denied: Caps,
91}
92
93impl Sandbox {
94 pub fn allow_all() -> Self {
95 Sandbox { denied: Caps::NONE }
96 }
97
98 pub fn deny(caps: Caps) -> Self {
99 Sandbox { denied: caps }
100 }
101
102 pub fn is_unrestricted(&self) -> bool {
103 self.denied == Caps::NONE
104 }
105
106 pub fn check(&self, required: Caps, fn_name: &str) -> Result<(), SemaError> {
107 if self.denied.contains(required) {
108 Err(SemaError::PermissionDenied {
109 function: fn_name.to_string(),
110 capability: required.name().to_string(),
111 })
112 } else {
113 Ok(())
114 }
115 }
116
117 pub fn parse_cli(value: &str) -> Result<Self, String> {
118 match value {
119 "strict" => Ok(Sandbox::deny(Caps::STRICT)),
120 "all" => Ok(Sandbox::deny(Caps::ALL)),
121 other => {
122 let mut denied = Caps::NONE;
123 for part in other.split(',') {
124 let part = part.trim();
125 if part.is_empty() {
126 continue;
127 }
128 let name = part.strip_prefix("no-").unwrap_or(part);
129 match Caps::from_name(name) {
130 Some(cap) => denied = denied.union(cap),
131 None => return Err(format!("unknown capability: {name}")),
132 }
133 }
134 Ok(Sandbox::deny(denied))
135 }
136 }
137 }
138}
139
140impl Default for Sandbox {
141 fn default() -> Self {
142 Sandbox::allow_all()
143 }
144}
145
#[cfg(test)]
mod tests {
    use super::*;

    /// Every single-capability constant, in declaration order.
    const SINGLES: [Caps; 8] = [
        Caps::FS_READ,
        Caps::FS_WRITE,
        Caps::SHELL,
        Caps::NETWORK,
        Caps::ENV_READ,
        Caps::ENV_WRITE,
        Caps::PROCESS,
        Caps::LLM,
    ];

    /// True when `sandbox` refuses `cap` for a function called `func`.
    fn blocked(sandbox: &Sandbox, cap: Caps, func: &str) -> bool {
        sandbox.check(cap, func).is_err()
    }

    /// Parses a CLI spec that the test expects to be valid.
    fn parsed(spec: &str) -> Sandbox {
        Sandbox::parse_cli(spec).unwrap()
    }

    /// Shared expectation for specs that deny both shell and network.
    fn assert_shell_and_network_blocked(sandbox: &Sandbox) {
        assert!(blocked(sandbox, Caps::SHELL, "shell"));
        assert!(blocked(sandbox, Caps::NETWORK, "http/get"));
    }

    #[test]
    fn test_caps_contains() {
        let members = [Caps::FS_READ, Caps::FS_WRITE, Caps::SHELL];
        let set = members.iter().fold(Caps::NONE, |acc, &c| acc.union(c));
        for member in members {
            assert!(set.contains(member));
        }
        for outsider in [Caps::NETWORK, Caps::LLM] {
            assert!(!set.contains(outsider));
        }
    }

    #[test]
    fn test_caps_contains_none_is_always_true() {
        for holder in [Caps::NONE, Caps::ALL] {
            assert!(holder.contains(Caps::NONE));
        }
    }

    #[test]
    fn test_caps_union() {
        let pair = Caps::SHELL.union(Caps::NETWORK);
        assert!(pair.contains(Caps::SHELL) && pair.contains(Caps::NETWORK));
        assert!(!pair.contains(Caps::FS_READ));
    }

    #[test]
    fn test_caps_all_contains_every_cap() {
        for cap in SINGLES {
            assert!(Caps::ALL.contains(cap));
        }
    }

    #[test]
    fn test_caps_strict_preset() {
        let denied = [
            Caps::SHELL,
            Caps::FS_WRITE,
            Caps::NETWORK,
            Caps::ENV_WRITE,
            Caps::PROCESS,
            Caps::LLM,
        ];
        for cap in denied {
            assert!(Caps::STRICT.contains(cap));
        }
        for cap in [Caps::FS_READ, Caps::ENV_READ] {
            assert!(!Caps::STRICT.contains(cap));
        }
    }

    #[test]
    fn test_caps_name_roundtrip() {
        for cap in SINGLES {
            let name = cap.name();
            assert_eq!(
                Caps::from_name(name),
                Some(cap),
                "roundtrip failed for {name}"
            );
        }
    }

    #[test]
    fn test_caps_from_name_unknown() {
        for bad in ["garbage", ""] {
            assert_eq!(Caps::from_name(bad), None);
        }
    }

    #[test]
    fn test_caps_display() {
        let expected = [
            (Caps::SHELL, "shell"),
            (Caps::NETWORK, "network"),
            (Caps::FS_READ, "fs-read"),
        ];
        for (cap, text) in expected {
            assert_eq!(format!("{cap}"), text);
        }
    }

    #[test]
    fn test_sandbox_allow_all_is_unrestricted() {
        assert!(Sandbox::allow_all().is_unrestricted());
    }

    #[test]
    fn test_sandbox_deny_is_restricted() {
        assert!(!Sandbox::deny(Caps::SHELL).is_unrestricted());
    }

    #[test]
    fn test_sandbox_default_is_unrestricted() {
        assert!(Sandbox::default().is_unrestricted());
    }

    #[test]
    fn test_sandbox_check_allowed() {
        let sandbox = Sandbox::deny(Caps::SHELL);
        for (cap, func) in [(Caps::NETWORK, "http/get"), (Caps::FS_READ, "file/read")] {
            assert!(!blocked(&sandbox, cap, func));
        }
    }

    #[test]
    fn test_sandbox_check_denied() {
        let message = Sandbox::deny(Caps::SHELL)
            .check(Caps::SHELL, "shell")
            .unwrap_err()
            .to_string();
        assert!(message.contains("Permission denied"));
        assert!(message.contains("shell"));
    }

    #[test]
    fn test_sandbox_check_denied_error_format() {
        let msg = Sandbox::deny(Caps::NETWORK)
            .check(Caps::NETWORK, "http/get")
            .unwrap_err()
            .to_string();
        assert!(
            msg.contains("http/get"),
            "should contain function name: {msg}"
        );
        assert!(
            msg.contains("network"),
            "should contain capability name: {msg}"
        );
    }

    #[test]
    fn test_sandbox_check_multiple_denied() {
        let sandbox = Sandbox::deny(Caps::SHELL.union(Caps::NETWORK));
        assert_shell_and_network_blocked(&sandbox);
        assert!(!blocked(&sandbox, Caps::FS_READ, "file/read"));
    }

    #[test]
    fn test_sandbox_parse_cli_strict() {
        let sandbox = parsed("strict");
        assert!(blocked(&sandbox, Caps::SHELL, "shell"));
        assert!(blocked(&sandbox, Caps::FS_WRITE, "file/write"));
        assert!(blocked(&sandbox, Caps::NETWORK, "http/get"));
        assert!(!blocked(&sandbox, Caps::FS_READ, "file/read"));
        assert!(!blocked(&sandbox, Caps::ENV_READ, "env"));
    }

    #[test]
    fn test_sandbox_parse_cli_all() {
        let sandbox = parsed("all");
        assert!(blocked(&sandbox, Caps::SHELL, "shell"));
        assert!(blocked(&sandbox, Caps::FS_READ, "file/read"));
        assert!(blocked(&sandbox, Caps::ENV_READ, "env"));
    }

    #[test]
    fn test_sandbox_parse_cli_no_prefix() {
        let sandbox = parsed("no-shell,no-network");
        assert_shell_and_network_blocked(&sandbox);
        assert!(!blocked(&sandbox, Caps::FS_READ, "file/read"));
    }

    #[test]
    fn test_sandbox_parse_cli_without_no_prefix() {
        let sandbox = parsed("shell,network");
        assert_shell_and_network_blocked(&sandbox);
        assert!(!blocked(&sandbox, Caps::FS_READ, "file/read"));
    }

    #[test]
    fn test_sandbox_parse_cli_single() {
        let sandbox = parsed("no-fs-write");
        assert!(blocked(&sandbox, Caps::FS_WRITE, "file/write"));
        assert!(!blocked(&sandbox, Caps::FS_READ, "file/read"));
    }

    #[test]
    fn test_sandbox_parse_cli_with_spaces() {
        assert_shell_and_network_blocked(&parsed("no-shell, no-network"));
    }

    #[test]
    fn test_sandbox_parse_cli_empty_parts() {
        assert_shell_and_network_blocked(&parsed("no-shell,,no-network"));
    }

    #[test]
    fn test_sandbox_parse_cli_invalid() {
        for spec in ["no-bogus", "no-shell,no-bogus"] {
            assert!(Sandbox::parse_cli(spec).is_err());
        }
    }
}