sandbox_rs/isolation/
seccomp.rs

1//! Seccomp filter building and management
2
3use crate::errors::{Result, SandboxError};
4use std::collections::HashSet;
5
/// Named presets that select which syscalls a seccomp filter starts out allowing.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SeccompProfile {
    /// Only the essential syscalls.
    Minimal,
    /// Essentials plus file operations, for IO-heavy workloads.
    IoHeavy,
    /// Essentials plus memory operations, for compute workloads.
    Compute,
    /// Essentials plus socket operations, for networked workloads.
    Network,
    /// The widest preset: allows most syscalls.
    Unrestricted,
}

impl SeccompProfile {
    /// Returns every profile, in declaration order.
    pub fn all() -> Vec<Self> {
        use SeccompProfile::{Compute, IoHeavy, Minimal, Network, Unrestricted};

        [Minimal, IoHeavy, Compute, Network, Unrestricted].to_vec()
    }

    /// Returns a short, human-readable description of the profile.
    pub fn description(&self) -> &'static str {
        match *self {
            Self::Unrestricted => "Allow most syscalls",
            Self::Network => "With socket operations",
            Self::Compute => "With memory operations",
            Self::IoHeavy => "With file I/O operations",
            Self::Minimal => "Minimal syscalls only",
        }
    }
}
44
45/// Seccomp filter builder
46#[derive(Debug, Clone)]
47pub struct SeccompFilter {
48    allowed: HashSet<String>,
49    blocked: HashSet<String>,
50    kill_on_violation: bool,
51    profile: SeccompProfile,
52}
53
54impl SeccompFilter {
55    /// Create filter from profile
56    pub fn from_profile(profile: SeccompProfile) -> Self {
57        let allowed = Self::syscalls_for_profile(&profile);
58        Self {
59            allowed,
60            blocked: HashSet::new(),
61            kill_on_violation: true,
62            profile,
63        }
64    }
65
66    /// Create minimal filter
67    pub fn minimal() -> Self {
68        Self::from_profile(SeccompProfile::Minimal)
69    }
70
71    /// Get syscalls for a profile
72    fn syscalls_for_profile(profile: &SeccompProfile) -> HashSet<String> {
73        let mut syscalls = HashSet::new();
74
75        // Always allowed
76        let always_allowed = vec![
77            // Process management
78            "exit",
79            "exit_group",
80            "clone",
81            "fork",
82            "vfork",
83            // Signal handling
84            "rt_sigaction",
85            "rt_sigprocmask",
86            "rt_sigpending",
87            "rt_sigtimedwait",
88            "rt_sigqueueinfo",
89            "rt_sigreturn",
90            "kill",
91            "tkill",
92            "tgkill",
93            "sigaltstack",
94            // Basic I/O
95            "read",
96            "write",
97            "readv",
98            "writev",
99            "pread64",
100            "pwrite64",
101            "access",
102            "faccessat",
103            // File operations
104            "open",
105            "openat",
106            "close",
107            "stat",
108            "fstat",
109            "lstat",
110            "fcntl",
111            "ioctl",
112            // Memory
113            "mmap",
114            "munmap",
115            "mremap",
116            "mprotect",
117            "madvise",
118            "brk",
119            "mlock",
120            "munlock",
121            "mlockall",
122            "munlockall",
123            // Process execution
124            "execve",
125            "execveat",
126            // Waiting
127            "wait4",
128            "waitpid",
129            "waitid",
130            // File descriptors
131            "dup",
132            "dup2",
133            "dup3",
134            "pipe",
135            "pipe2",
136            // Getting time
137            "clock_gettime",
138            "clock_getres",
139            "gettimeofday",
140            "time",
141            // Process info
142            "getpid",
143            "getppid",
144            "getuid",
145            "geteuid",
146            "getgid",
147            "getegid",
148            "uname",
149            "getpgrp",
150            "getpgid",
151            "setpgid",
152            "getsid",
153            "setsid",
154            // Limits
155            "getrlimit",
156            "setrlimit",
157            "getrusage",
158            // Misc allowed
159            "futex",
160            "rt_sigpending",
161            "set_tid_address",
162            "set_robust_list",
163            "get_robust_list",
164            "pselect6",
165            "ppoll",
166            "epoll_create1",
167            "epoll_ctl",
168            "epoll_wait",
169            "poll",
170            "select",
171            "getcwd",
172            "chdir",
173            "fchdir",
174            "getdents",
175            "getdents64",
176            "prctl",
177            "arch_prctl",
178            "rseq",
179            "newfstatat",
180            "getrandom",
181            "statx",
182            "prlimit64",
183        ];
184
185        for syscall in always_allowed {
186            syscalls.insert(syscall.to_string());
187        }
188
189        // Profile-specific syscalls
190        match profile {
191            SeccompProfile::Minimal => {
192                // Just the basics above
193            }
194            SeccompProfile::IoHeavy => {
195                for syscall in &[
196                    "mkdir",
197                    "mkdirat",
198                    "rmdir",
199                    "unlink",
200                    "unlinkat",
201                    "rename",
202                    "renameat",
203                    "link",
204                    "linkat",
205                    "symlink",
206                    "symlinkat",
207                    "readlink",
208                    "readlinkat",
209                    "chmod",
210                    "fchmod",
211                    "fchmodat",
212                    "chown",
213                    "fchown",
214                    "fchownat",
215                    "lchown",
216                    "utimes",
217                    "futimesat",
218                    "utime",
219                    "utimensat",
220                    "truncate",
221                    "ftruncate",
222                    "fallocate",
223                    "access",
224                    "faccessat",
225                    "sendfile",
226                    "splice",
227                    "tee",
228                    "vmsplice",
229                    "statfs",
230                    "fstatfs",
231                    "fsync",
232                    "fdatasync",
233                ] {
234                    syscalls.insert(syscall.to_string());
235                }
236            }
237            SeccompProfile::Compute => {
238                for syscall in &[
239                    "sigaltstack",
240                    "sched_yield",
241                    "sched_getscheduler",
242                    "sched_setscheduler",
243                    "sched_getparam",
244                    "sched_setparam",
245                    "sched_get_priority_max",
246                    "sched_get_priority_min",
247                    "sched_rr_get_interval",
248                    "sched_getaffinity",
249                    "sched_setaffinity",
250                    "mbind",
251                    "get_mempolicy",
252                    "set_mempolicy",
253                    "migrate_pages",
254                    "move_pages",
255                    "membarrier",
256                ] {
257                    syscalls.insert(syscall.to_string());
258                }
259            }
260            SeccompProfile::Network => {
261                for syscall in &[
262                    "socket",
263                    "socketpair",
264                    "bind",
265                    "listen",
266                    "accept",
267                    "accept4",
268                    "connect",
269                    "shutdown",
270                    "sendto",
271                    "recvfrom",
272                    "sendmsg",
273                    "recvmsg",
274                    "sendmmsg",
275                    "recvmmsg",
276                    "setsockopt",
277                    "getsockopt",
278                    "getsockname",
279                    "getpeername",
280                ] {
281                    syscalls.insert(syscall.to_string());
282                }
283                // Also include IoHeavy syscalls
284                for syscall in &["open", "openat", "read", "write", "close"] {
285                    syscalls.insert(syscall.to_string());
286                }
287            }
288            SeccompProfile::Unrestricted => {
289                // Add many more syscalls for unrestricted
290                for syscall in &[
291                    "ptrace",
292                    "process_vm_readv",
293                    "process_vm_writev",
294                    "perf_event_open",
295                    "bpf",
296                    "seccomp",
297                    "mount",
298                    "umount2",
299                    "pivot_root",
300                    "capget",
301                    "capset",
302                    "setuid",
303                    "setgid",
304                    "setreuid",
305                    "setregid",
306                    "setresuid",
307                    "setresgid",
308                    "getgroups",
309                    "setgroups",
310                    "setfsgid",
311                    "setfsuid",
312                ] {
313                    syscalls.insert(syscall.to_string());
314                }
315            }
316        }
317
318        syscalls
319    }
320
321    /// Add syscall to whitelist
322    pub fn allow_syscall(&mut self, name: impl Into<String>) {
323        self.allowed.insert(name.into());
324    }
325
326    /// Block a syscall (deny even if in whitelist)
327    pub fn block_syscall(&mut self, name: impl Into<String>) {
328        self.blocked.insert(name.into());
329    }
330
331    /// Check if syscall is allowed
332    pub fn is_allowed(&self, name: &str) -> bool {
333        if self.blocked.contains(name) {
334            return false;
335        }
336        self.allowed.contains(name)
337    }
338
339    /// Get allowed syscalls
340    pub fn allowed_syscalls(&self) -> &HashSet<String> {
341        &self.allowed
342    }
343
344    /// Get blocked syscalls
345    pub fn blocked_syscalls(&self) -> &HashSet<String> {
346        &self.blocked
347    }
348
349    /// Count allowed syscalls
350    pub fn allowed_count(&self) -> usize {
351        self.allowed.len() - self.blocked.len()
352    }
353
354    /// Check if killing on violation
355    pub fn is_kill_on_violation(&self) -> bool {
356        self.kill_on_violation
357    }
358
359    /// Set kill on violation
360    pub fn set_kill_on_violation(&mut self, kill: bool) {
361        self.kill_on_violation = kill;
362    }
363
364    /// Get the profile used to create this filter
365    pub fn profile(&self) -> SeccompProfile {
366        self.profile.clone()
367    }
368
369    /// Validate that filter is correct
370    pub fn validate(&self) -> Result<()> {
371        if self.allowed.is_empty() && self.profile != SeccompProfile::Unrestricted {
372            return Err(SandboxError::Seccomp(
373                "Filter has no allowed syscalls".to_string(),
374            ));
375        }
376        Ok(())
377    }
378
379    /// Export as BPF program (simplified - just returns syscall names)
380    pub fn export(&self) -> Result<Vec<String>> {
381        self.validate()?;
382        let mut list: Vec<_> = self.allowed.iter().cloned().collect();
383        list.sort();
384        Ok(list)
385    }
386}
387
#[cfg(test)]
mod tests {
    //! Unit tests for the seccomp profiles and the filter builder.

    use super::*;

    #[test]
    fn test_seccomp_profile_all() {
        let profiles = SeccompProfile::all();
        assert_eq!(profiles.len(), 5);
    }

    #[test]
    fn test_seccomp_profile_description() {
        assert!(!SeccompProfile::Minimal.description().is_empty());
        assert_ne!(
            SeccompProfile::Minimal.description(),
            SeccompProfile::Network.description()
        );
    }

    #[test]
    fn test_seccomp_filter_minimal() {
        let filter = SeccompFilter::minimal();
        assert!(filter.is_allowed("read"));
        assert!(filter.is_allowed("write"));
        assert!(filter.is_allowed("exit"));
        assert!(!filter.is_allowed("ptrace"));
        assert!(filter.allowed_count() > 20);
    }

    #[test]
    fn test_seccomp_filter_io_heavy() {
        let filter = SeccompFilter::from_profile(SeccompProfile::IoHeavy);
        assert!(filter.is_allowed("read"));
        assert!(filter.is_allowed("mkdir"));
        assert!(filter.is_allowed("unlink"));
        let io_count = filter.allowed_count();

        let minimal = SeccompFilter::minimal();
        assert!(io_count > minimal.allowed_count());
    }

    #[test]
    fn test_seccomp_filter_network() {
        let filter = SeccompFilter::from_profile(SeccompProfile::Network);
        assert!(filter.is_allowed("socket"));
        assert!(filter.is_allowed("connect"));
        assert!(filter.is_allowed("bind"));
    }

    #[test]
    fn test_seccomp_filter_allow_syscall() {
        let mut filter = SeccompFilter::minimal();
        filter.allow_syscall("custom_syscall");
        assert!(filter.is_allowed("custom_syscall"));
    }

    #[test]
    fn test_seccomp_filter_block_syscall() {
        let mut filter = SeccompFilter::minimal();
        filter.block_syscall("read");
        assert!(!filter.is_allowed("read"));
    }

    #[test]
    fn test_seccomp_filter_block_overrides_allow() {
        let mut filter = SeccompFilter::minimal();
        assert!(filter.is_allowed("write"));
        filter.block_syscall("write");
        assert!(!filter.is_allowed("write"));
    }

    #[test]
    fn test_seccomp_filter_validate() {
        let filter = SeccompFilter::minimal();
        assert!(filter.validate().is_ok());

        // An empty allow-list is rejected for every profile except Unrestricted.
        let empty_filter = SeccompFilter {
            allowed: HashSet::new(),
            blocked: HashSet::new(),
            kill_on_violation: true,
            profile: SeccompProfile::Minimal,
        };
        assert!(empty_filter.validate().is_err());
    }

    #[test]
    fn test_seccomp_filter_export() {
        let filter = SeccompFilter::minimal();
        let syscalls = filter.export().unwrap();
        assert!(!syscalls.is_empty());
        assert!(syscalls.contains(&"read".to_string()));

        // Should be sorted
        let mut sorted = syscalls.clone();
        sorted.sort();
        assert_eq!(syscalls, sorted);
    }

    #[test]
    fn test_seccomp_kill_on_violation() {
        let mut filter = SeccompFilter::minimal();
        assert!(filter.is_kill_on_violation());

        filter.set_kill_on_violation(false);
        assert!(!filter.is_kill_on_violation());
    }

    #[test]
    fn test_seccomp_filter_comparison() {
        let minimal = SeccompFilter::minimal();

        // Every profile is built on top of the minimal baseline, so each one
        // must allow everything the minimal filter allows.
        for profile in SeccompProfile::all() {
            let filter = SeccompFilter::from_profile(profile.clone());
            for syscall in minimal.allowed_syscalls() {
                assert!(
                    filter.is_allowed(syscall),
                    "{:?} profile is missing baseline syscall `{}`",
                    profile,
                    syscall
                );
            }
        }

        // Profile-specific extras must not leak into the minimal filter.
        let compute = SeccompFilter::from_profile(SeccompProfile::Compute);
        assert!(compute.is_allowed("sched_yield"));
        assert!(!minimal.is_allowed("sched_yield"));
    }

    #[test]
    fn test_validate_unrestricted_with_no_allowed() {
        let filter = SeccompFilter {
            allowed: HashSet::new(),
            blocked: HashSet::new(),
            kill_on_violation: true,
            profile: SeccompProfile::Unrestricted,
        };
        assert!(filter.validate().is_ok());
    }
}
519}