1use std::collections::{HashMap, HashSet};
19use std::net::IpAddr;
20use std::sync::{Arc, RwLock};
21
/// Coarse classification attached to every [`SyscallEvent`].
///
/// Variants are declared in a fixed order; the enum is `Copy` and hashable,
/// so it can be used directly as a map/set key.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum SyscallCategory {
    /// File-related syscalls.
    File,
    /// Network-related syscalls.
    Network,
    /// Process-related syscalls (the tests classify `execve` here).
    Process,
    /// Memory-related syscalls.
    Memory,
}
38
39#[derive(Debug, Clone)]
45pub struct SyscallEvent {
46 pub syscall: String,
48 pub category: SyscallCategory,
50 pub pid: u32,
52 pub parent_pid: Option<u32>,
54 pub path: Option<String>,
56 pub host: Option<IpAddr>,
58 pub port: Option<u16>,
60 pub size: Option<u64>,
62 pub argv: Option<Vec<String>>,
64 pub denied: bool,
66}
67
68impl SyscallEvent {
69 pub fn path_contains(&self, s: &str) -> bool {
71 self.path.as_ref().map_or(false, |p| p.contains(s))
72 }
73
74 pub fn argv_contains(&self, s: &str) -> bool {
76 self.argv.as_ref().map_or(false, |args| args.iter().any(|a| a.contains(s)))
77 }
78}
79
/// Network and resource limits that a policy callback may adjust at runtime.
///
/// The same type is used both for the live, mutable policy and for the
/// immutable ceiling that bounds it (see `PolicyContext`).
#[derive(Clone, Debug)]
pub struct LivePolicy {
    /// IP addresses the sandbox may use — presumably outbound destinations; confirm.
    pub allowed_ips: HashSet<IpAddr>,
    /// Memory limit in bytes.
    pub max_memory_bytes: u64,
    /// Maximum number of processes.
    pub max_processes: u32,
}
97
98pub struct PolicyContext {
108 live: Arc<RwLock<LivePolicy>>,
109 ceiling: LivePolicy,
110 restricted: HashSet<&'static str>,
111 pid_overrides: Arc<RwLock<HashMap<u32, HashSet<IpAddr>>>>,
112 denied_paths: Arc<RwLock<HashSet<String>>>,
113}
114
115impl PolicyContext {
116 pub(crate) fn new(
117 live: Arc<RwLock<LivePolicy>>,
118 ceiling: LivePolicy,
119 pid_overrides: Arc<RwLock<HashMap<u32, HashSet<IpAddr>>>>,
120 denied_paths: Arc<RwLock<HashSet<String>>>,
121 ) -> Self {
122 Self {
123 live,
124 ceiling,
125 restricted: HashSet::new(),
126 pid_overrides,
127 denied_paths,
128 }
129 }
130
131 pub fn current(&self) -> LivePolicy {
133 self.live.read().unwrap().clone()
134 }
135
136 pub fn ceiling(&self) -> &LivePolicy {
138 &self.ceiling
139 }
140
141 pub fn grant_network(&mut self, ips: &[IpAddr]) -> Result<(), PolicyFnError> {
145 self.check_not_restricted("allowed_ips")?;
146 let mut live = self.live.write().unwrap();
147 for ip in ips {
148 if self.ceiling.allowed_ips.contains(ip) {
149 live.allowed_ips.insert(*ip);
150 }
151 }
152 Ok(())
153 }
154
155 pub fn grant_max_memory(&mut self, bytes: u64) -> Result<(), PolicyFnError> {
157 self.check_not_restricted("max_memory_bytes")?;
158 let mut live = self.live.write().unwrap();
159 live.max_memory_bytes = bytes.min(self.ceiling.max_memory_bytes);
160 Ok(())
161 }
162
163 pub fn grant_max_processes(&mut self, n: u32) -> Result<(), PolicyFnError> {
165 self.check_not_restricted("max_processes")?;
166 let mut live = self.live.write().unwrap();
167 live.max_processes = n.min(self.ceiling.max_processes);
168 Ok(())
169 }
170
171 pub fn restrict_network(&mut self, ips: &[IpAddr]) {
175 self.restricted.insert("allowed_ips");
176 let mut live = self.live.write().unwrap();
177 live.allowed_ips = ips.iter().copied().collect();
178 }
179
180 pub fn restrict_max_memory(&mut self, bytes: u64) {
182 self.restricted.insert("max_memory_bytes");
183 let mut live = self.live.write().unwrap();
184 live.max_memory_bytes = bytes;
185 }
186
187 pub fn restrict_max_processes(&mut self, n: u32) {
189 self.restricted.insert("max_processes");
190 let mut live = self.live.write().unwrap();
191 live.max_processes = n;
192 }
193
194 pub fn restrict_pid_network(&self, pid: u32, ips: &[IpAddr]) {
198 let mut overrides = self.pid_overrides.write().unwrap();
199 overrides.insert(pid, ips.iter().copied().collect());
200 }
201
202 pub fn clear_pid_override(&self, pid: u32) {
204 let mut overrides = self.pid_overrides.write().unwrap();
205 overrides.remove(&pid);
206 }
207
208 pub fn deny_path(&self, path: &str) {
213 let mut denied = self.denied_paths.write().unwrap();
214 denied.insert(path.to_string());
215 }
216
217 pub fn allow_path(&self, path: &str) {
219 let mut denied = self.denied_paths.write().unwrap();
220 denied.remove(path);
221 }
222
223 fn check_not_restricted(&self, field: &str) -> Result<(), PolicyFnError> {
226 if self.restricted.contains(field) {
227 Err(PolicyFnError::FieldRestricted(field.to_string()))
228 } else {
229 Ok(())
230 }
231 }
232}
233
234#[derive(Debug, thiserror::Error)]
240pub enum PolicyFnError {
241 #[error("cannot grant restricted field: {0}")]
242 FieldRestricted(String),
243}
244
/// Decision a policy callback returns for a single syscall event.
///
/// `Verdict::default()` is [`Verdict::Allow`].
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum Verdict {
    /// Let the syscall proceed. This is the default verdict.
    #[default]
    Allow,
    /// NOTE(review): presumably "allow, but record the event" — confirm with the enforcer.
    Audit,
    /// Refuse the syscall.
    Deny,
    /// Refuse the syscall with a specific code — presumably an errno value; confirm.
    DenyWith(i32),
}
265
266pub type PolicyCallback = Arc<dyn Fn(SyscallEvent, &mut PolicyContext) -> Verdict + Send + Sync + 'static>;
276
277pub struct PolicyEvent {
283 pub event: SyscallEvent,
284 pub gate: Option<tokio::sync::oneshot::Sender<Verdict>>,
288}
289
290pub(crate) fn spawn_policy_fn(
298 callback: PolicyCallback,
299 live: Arc<RwLock<LivePolicy>>,
300 ceiling: LivePolicy,
301 pid_overrides: Arc<RwLock<HashMap<u32, HashSet<IpAddr>>>>,
302 denied_paths: Arc<RwLock<HashSet<String>>>,
303) -> tokio::sync::mpsc::UnboundedSender<PolicyEvent> {
304 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<PolicyEvent>();
305
306 std::thread::Builder::new()
307 .name("sandlock-policy-fn".to_string())
308 .spawn(move || {
309 let mut ctx = PolicyContext::new(live, ceiling, pid_overrides, denied_paths);
310
311 while let Some(pe) = rx.blocking_recv() {
312 let verdict = callback(pe.event, &mut ctx);
313
314 if let Some(gate) = pe.gate {
317 let _ = gate.send(verdict);
318 }
319 }
320 })
321 .expect("failed to spawn policy-fn thread");
322
323 tx
324}
325
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared per-PID override map, as passed to `PolicyContext::new`.
    type SharedOverrides = Arc<RwLock<HashMap<u32, HashSet<IpAddr>>>>;

    /// Parses an IP literal (panics on malformed input; test-only).
    fn ip(literal: &str) -> IpAddr {
        literal.parse().unwrap()
    }

    /// Ceiling used by every test: two IPs, 1 GiB of memory, 64 processes.
    fn test_live() -> LivePolicy {
        LivePolicy {
            allowed_ips: ["127.0.0.1", "10.0.0.1"].iter().map(|s| ip(s)).collect(),
            max_memory_bytes: 1024 * 1024 * 1024,
            max_processes: 64,
        }
    }

    /// A live policy with nothing allowed.
    fn empty_live() -> LivePolicy {
        LivePolicy {
            allowed_ips: HashSet::new(),
            max_memory_bytes: 0,
            max_processes: 0,
        }
    }

    /// Builds a context whose ceiling is `test_live()` and whose live policy
    /// starts as `initial`. It returns the shared live policy and PID
    /// overrides so that tests can inspect them.
    fn build_ctx(
        initial: LivePolicy,
    ) -> (Arc<RwLock<LivePolicy>>, SharedOverrides, PolicyContext) {
        let live = Arc::new(RwLock::new(initial));
        let pid_overrides: SharedOverrides = Arc::new(RwLock::new(HashMap::new()));
        let denied_paths = Arc::new(RwLock::new(HashSet::new()));
        let ctx = PolicyContext::new(
            Arc::clone(&live),
            test_live(),
            Arc::clone(&pid_overrides),
            denied_paths,
        );
        (live, pid_overrides, ctx)
    }

    #[test]
    fn test_grant_within_ceiling() {
        let (live, _, mut ctx) = build_ctx(empty_live());
        let localhost = ip("127.0.0.1");
        ctx.grant_network(&[localhost]).unwrap();
        assert!(live.read().unwrap().allowed_ips.contains(&localhost));
    }

    #[test]
    fn test_grant_capped_to_ceiling() {
        let (live, _, mut ctx) = build_ctx(empty_live());
        // 8.8.8.8 is outside the ceiling, so the grant is silently dropped.
        let foreign = ip("8.8.8.8");
        ctx.grant_network(&[foreign]).unwrap();
        assert!(!live.read().unwrap().allowed_ips.contains(&foreign));
    }

    #[test]
    fn test_restrict_then_grant_fails() {
        let (_live, _, mut ctx) = build_ctx(test_live());
        ctx.restrict_network(&[]);
        assert!(ctx.grant_network(&[ip("127.0.0.1")]).is_err());
    }

    #[test]
    fn test_restrict_max_memory() {
        let (live, _, mut ctx) = build_ctx(test_live());
        ctx.restrict_max_memory(256 * 1024 * 1024);
        assert_eq!(live.read().unwrap().max_memory_bytes, 256 * 1024 * 1024);
    }

    #[test]
    fn test_pid_override() {
        let (_live, pid_overrides, ctx) = build_ctx(test_live());
        let localhost = ip("127.0.0.1");
        ctx.restrict_pid_network(1234, &[localhost]);

        let overrides = pid_overrides.read().unwrap();
        let pid_ips = overrides.get(&1234).unwrap();
        assert!(pid_ips.contains(&localhost));
        assert_eq!(pid_ips.len(), 1);
    }

    #[test]
    fn test_clear_pid_override() {
        let (_live, pid_overrides, ctx) = build_ctx(test_live());
        ctx.restrict_pid_network(1234, &[ip("127.0.0.1")]);
        ctx.clear_pid_override(1234);
        assert!(!pid_overrides.read().unwrap().contains_key(&1234));
    }

    #[test]
    fn test_event_path_contains() {
        let argv: Vec<String> = ["python3", "-c", "print(1)"]
            .iter()
            .map(|s| s.to_string())
            .collect();
        let event = SyscallEvent {
            syscall: "execve".to_string(),
            category: SyscallCategory::Process,
            pid: 1,
            parent_pid: Some(0),
            path: Some("/usr/bin/python3".to_string()),
            host: None,
            port: None,
            size: None,
            argv: Some(argv),
            denied: false,
        };
        assert!(event.argv_contains("python3"));
        assert!(event.argv_contains("-c"));
        assert!(!event.argv_contains("ruby"));
        assert_eq!(event.category, SyscallCategory::Process);
        assert!(event.path_contains("python"));
        assert!(!event.path_contains("ruby"));
    }
}