sandbox_rs/execution/
init.rs

//! Minimal init process for the sandbox: mounts `/proc` and `/sys`, then execs the user program.
2
3use nix::sys::signal::{SigHandler, Signal, signal};
4use nix::unistd::execv;
5use std::ffi::CString;
6use std::process::exit;
7
/// Description of the sandbox's init step: the user program to `execv` and its arguments.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SandboxInit {
    /// Path of the user program; passed to `execv` and also used as `argv[0]`.
    pub program: String,
    /// Arguments passed to the user program after `argv[0]`.
    pub args: Vec<String>,
}
14
15impl SandboxInit {
16    /// Create new init process
17    pub fn new(program: String, args: Vec<String>) -> Self {
18        Self { program, args }
19    }
20
21    /// Run init process
22    /// This becomes PID 1 inside the sandbox
23    pub fn run(&self) -> ! {
24        // Setup signal handlers
25        Self::setup_signals();
26
27        Self::mount_procfs();
28        Self::mount_sysfs();
29
30        // Execute user program
31        self.exec_user_program();
32    }
33
34    /// Setup signal handlers for init
35    fn setup_signals() {
36        // Ignore SIGCHLD so we don't become zombie
37        unsafe {
38            let _ = signal(Signal::SIGCHLD, SigHandler::SigIgn);
39            let _ = signal(Signal::SIGTERM, SigHandler::SigDfl);
40        }
41    }
42
43    fn mount_procfs() {
44        use std::ffi::CString;
45
46        let _ = std::fs::create_dir("/proc");
47
48        // Use mount syscall instead of external mount command
49        let source = CString::new("proc").unwrap();
50        let target = CString::new("/proc").unwrap();
51        let fstype = CString::new("proc").unwrap();
52
53        unsafe {
54            libc::mount(
55                source.as_ptr(),
56                target.as_ptr(),
57                fstype.as_ptr(),
58                0,
59                std::ptr::null(),
60            );
61        }
62    }
63
64    fn mount_sysfs() {
65        use std::ffi::CString;
66
67        let _ = std::fs::create_dir("/sys");
68
69        // Use mount syscall instead of external mount command
70        let source = CString::new("sysfs").unwrap();
71        let target = CString::new("/sys").unwrap();
72        let fstype = CString::new("sysfs").unwrap();
73
74        unsafe {
75            libc::mount(
76                source.as_ptr(),
77                target.as_ptr(),
78                fstype.as_ptr(),
79                0,
80                std::ptr::null(),
81            );
82        }
83    }
84
85    /// Execute user program
86    fn exec_user_program(&self) -> ! {
87        let program_cstr = match CString::new(self.program.clone()) {
88            Ok(s) => s,
89            Err(_) => {
90                eprintln!("Invalid program name");
91                exit(1);
92            }
93        };
94
95        let args_cstr: Vec<CString> = self
96            .args
97            .iter()
98            .map(|arg| CString::new(arg.clone()).unwrap_or_else(|_| CString::new("").unwrap()))
99            .collect();
100
101        let args_refs: Vec<&CString> = vec![&program_cstr]
102            .into_iter()
103            .chain(args_cstr.iter())
104            .collect();
105
106        match execv(&program_cstr, &args_refs) {
107            Ok(_) => {
108                // execv replaces process, never returns on success
109                exit(0);
110            }
111            Err(e) => {
112                eprintln!("Failed to execute program: {}", e);
113                exit(1);
114            }
115        }
116    }
117
118    /// Reap zombie children
119    pub fn reap_children() {
120        use nix::sys::wait::{WaitStatus, waitpid};
121        use nix::unistd::Pid;
122
123        loop {
124            match waitpid(
125                Pid::from_raw(-1),
126                Some(nix::sys::wait::WaitPidFlag::WNOHANG),
127            ) {
128                Ok(WaitStatus::Exited(pid, _status)) => {
129                    eprintln!("[init] Child {} exited", pid);
130                }
131                Ok(WaitStatus::Signaled(pid, signal, _core)) => {
132                    eprintln!("[init] Child {} killed by {:?}", pid, signal);
133                }
134                Ok(WaitStatus::StillAlive) => break,
135                Ok(_) => continue,
136                Err(_) => break,
137            }
138        }
139    }
140}
141
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_init_creation() {
        let init = SandboxInit::new("/bin/echo".to_string(), vec!["hello".to_string()]);

        assert_eq!(init.program, "/bin/echo");
        assert_eq!(init.args.len(), 1);
        assert_eq!(init.args[0], "hello");
    }

    #[test]
    fn test_init_with_multiple_args() {
        let init = SandboxInit::new(
            "/bin/echo".to_string(),
            vec![
                "hello".to_string(),
                "world".to_string(),
                "from".to_string(),
                "init".to_string(),
            ],
        );

        assert_eq!(init.args.len(), 4);
    }

    #[test]
    fn test_init_empty_args() {
        let init = SandboxInit::new("/bin/sh".to_string(), Vec::new());

        assert!(init.args.is_empty());
    }

    #[test]
    fn test_mount_helpers_are_best_effort() {
        // As root these calls could really mount proc/sysfs over the host's /proc and /sys.
        // Only exercise them where the mounts are expected to fail harmlessly.
        // SAFETY: geteuid has no preconditions and cannot fail.
        if unsafe { libc::geteuid() } == 0 {
            return;
        }

        // Without mount privileges the syscalls fail; the helpers must swallow that.
        SandboxInit::mount_procfs();
        SandboxInit::mount_sysfs();
    }

    #[test]
    fn test_setup_signals_runs() {
        // signal() returns the previous disposition. Use that to save the originals now and,
        // later, to observe what setup_signals installed.
        // SAFETY: only SIG_DFL / SIG_IGN (or previously installed) dispositions are set here.
        let original_sigchld = unsafe { signal(Signal::SIGCHLD, SigHandler::SigDfl) };
        let original_sigterm = unsafe { signal(Signal::SIGTERM, SigHandler::SigDfl) };

        SandboxInit::setup_signals();

        // Restore the originals before asserting so a failure cannot leak into other tests.
        // SAFETY: restores the dispositions captured above.
        let installed_sigchld = unsafe {
            signal(
                Signal::SIGCHLD,
                original_sigchld.unwrap_or(SigHandler::SigDfl),
            )
        };
        let installed_sigterm = unsafe {
            signal(
                Signal::SIGTERM,
                original_sigterm.unwrap_or(SigHandler::SigDfl),
            )
        };

        assert_eq!(installed_sigchld.ok(), Some(SigHandler::SigIgn));
        assert_eq!(installed_sigterm.ok(), Some(SigHandler::SigDfl));
    }

    #[test]
    fn test_init_program_stored_correctly() {
        let program = "/usr/bin/python3".to_string();
        let init = SandboxInit::new(program.clone(), vec![]);

        assert_eq!(init.program, program);
    }

    #[test]
    fn test_init_args_stored_correctly() {
        let args = vec!["arg1".to_string(), "arg2".to_string(), "arg3".to_string()];
        let init = SandboxInit::new("/bin/test".to_string(), args.clone());

        assert_eq!(init.args, args);
    }

    #[test]
    fn test_init_clone() {
        let init1 = SandboxInit::new("/bin/echo".to_string(), vec!["test".to_string()]);

        let init2 = SandboxInit::new(init1.program.clone(), init1.args.clone());

        assert_eq!(init1.program, init2.program);
        assert_eq!(init1.args, init2.args);
    }
}