1use std::collections::{HashMap, HashSet};
31use std::net::IpAddr;
32use std::sync::{Arc, RwLock};
33
/// Coarse classification attached to every [`SyscallEvent`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SyscallCategory {
    /// File-system syscalls (this module's tests tag `openat` as `File`).
    File,
    /// Network syscalls.
    Network,
    /// Process-lifecycle syscalls (this module's tests tag `execve` as `Process`).
    Process,
    /// Memory-management syscalls.
    Memory,
}

/// A single syscall observation handed to the policy callback.
#[derive(Debug, Clone)]
pub struct SyscallEvent {
    /// Syscall name, e.g. `"execve"` or `"openat"`.
    pub syscall: String,
    /// Coarse category of `syscall`.
    pub category: SyscallCategory,
    /// Id of the process that issued the syscall.
    pub pid: u32,
    /// Parent process id, when known.
    pub parent_pid: Option<u32>,
    /// Remote address — presumably only set for network syscalls; confirm with the producer.
    pub host: Option<IpAddr>,
    /// Remote port — presumably only set for network syscalls; confirm with the producer.
    pub port: Option<u16>,
    /// Byte count associated with the call (meaning is syscall-specific — TODO confirm).
    pub size: Option<u64>,
    /// Argument vector, when the producer captured one (e.g. for `execve`).
    pub argv: Option<Vec<String>>,
    /// NOTE(review): looks like it records whether the syscall was already denied
    /// before reaching the callback — confirm against the producer.
    pub denied: bool,
}

impl SyscallEvent {
    /// Returns `true` if any captured argument contains `s` as a substring.
    ///
    /// The match is a plain substring test, so `"python"` matches `"python3"`.
    /// Returns `false` when no `argv` was captured or it is empty.
    pub fn argv_contains(&self, s: &str) -> bool {
        self.argv
            .as_deref()
            .is_some_and(|args| args.iter().any(|a| a.contains(s)))
    }
}
106
/// The mutable subset of sandbox policy that a policy callback can adjust at runtime.
#[derive(Clone, Debug)]
pub struct LivePolicy {
    /// Remote addresses that are currently permitted.
    pub allowed_ips: HashSet<IpAddr>,
    /// Current memory limit, in bytes.
    pub max_memory_bytes: u64,
    /// Current limit on the number of processes.
    pub max_processes: u32,
}
124
125pub struct PolicyContext {
135 live: Arc<RwLock<LivePolicy>>,
136 ceiling: LivePolicy,
137 restricted: HashSet<&'static str>,
138 pid_overrides: Arc<RwLock<HashMap<u32, HashSet<IpAddr>>>>,
139 denied_paths: Arc<RwLock<HashSet<String>>>,
140}
141
142impl PolicyContext {
143 pub(crate) fn new(
144 live: Arc<RwLock<LivePolicy>>,
145 ceiling: LivePolicy,
146 pid_overrides: Arc<RwLock<HashMap<u32, HashSet<IpAddr>>>>,
147 denied_paths: Arc<RwLock<HashSet<String>>>,
148 ) -> Self {
149 Self {
150 live,
151 ceiling,
152 restricted: HashSet::new(),
153 pid_overrides,
154 denied_paths,
155 }
156 }
157
158 pub fn current(&self) -> LivePolicy {
160 self.live.read().unwrap().clone()
161 }
162
163 pub fn ceiling(&self) -> &LivePolicy {
165 &self.ceiling
166 }
167
168 pub fn grant_network(&mut self, ips: &[IpAddr]) -> Result<(), PolicyFnError> {
172 self.check_not_restricted("allowed_ips")?;
173 let mut live = self.live.write().unwrap();
174 for ip in ips {
175 if self.ceiling.allowed_ips.contains(ip) {
176 live.allowed_ips.insert(*ip);
177 }
178 }
179 Ok(())
180 }
181
182 pub fn grant_max_memory(&mut self, bytes: u64) -> Result<(), PolicyFnError> {
184 self.check_not_restricted("max_memory_bytes")?;
185 let mut live = self.live.write().unwrap();
186 live.max_memory_bytes = bytes.min(self.ceiling.max_memory_bytes);
187 Ok(())
188 }
189
190 pub fn grant_max_processes(&mut self, n: u32) -> Result<(), PolicyFnError> {
192 self.check_not_restricted("max_processes")?;
193 let mut live = self.live.write().unwrap();
194 live.max_processes = n.min(self.ceiling.max_processes);
195 Ok(())
196 }
197
198 pub fn restrict_network(&mut self, ips: &[IpAddr]) {
202 self.restricted.insert("allowed_ips");
203 let mut live = self.live.write().unwrap();
204 live.allowed_ips = ips.iter().copied().collect();
205 }
206
207 pub fn restrict_max_memory(&mut self, bytes: u64) {
209 self.restricted.insert("max_memory_bytes");
210 let mut live = self.live.write().unwrap();
211 live.max_memory_bytes = bytes;
212 }
213
214 pub fn restrict_max_processes(&mut self, n: u32) {
216 self.restricted.insert("max_processes");
217 let mut live = self.live.write().unwrap();
218 live.max_processes = n;
219 }
220
221 pub fn restrict_pid_network(&self, pid: u32, ips: &[IpAddr]) {
225 let mut overrides = self.pid_overrides.write().unwrap();
226 overrides.insert(pid, ips.iter().copied().collect());
227 }
228
229 pub fn clear_pid_override(&self, pid: u32) {
231 let mut overrides = self.pid_overrides.write().unwrap();
232 overrides.remove(&pid);
233 }
234
235 pub fn deny_path(&self, path: &str) {
240 let mut denied = self.denied_paths.write().unwrap();
241 denied.insert(path.to_string());
242 }
243
244 pub fn allow_path(&self, path: &str) {
246 let mut denied = self.denied_paths.write().unwrap();
247 denied.remove(path);
248 }
249
250 fn check_not_restricted(&self, field: &str) -> Result<(), PolicyFnError> {
253 if self.restricted.contains(field) {
254 Err(PolicyFnError::FieldRestricted(field.to_string()))
255 } else {
256 Ok(())
257 }
258 }
259}
260
261#[derive(Debug, thiserror::Error)]
267pub enum PolicyFnError {
268 #[error("cannot grant restricted field: {0}")]
269 FieldRestricted(String),
270}
271
/// Decision the policy callback returns for one syscall event.
///
/// `Allow` is the default verdict.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum Verdict {
    /// Let the syscall proceed.
    #[default]
    Allow,
    /// NOTE(review): presumably "allow but record" — confirm with the consumer of verdicts.
    Audit,
    /// Reject the syscall.
    Deny,
    /// Reject the syscall with a caller-chosen code (presumably an errno value — confirm).
    DenyWith(i32),
}
292
293pub type PolicyCallback = Arc<dyn Fn(SyscallEvent, &mut PolicyContext) -> Verdict + Send + Sync + 'static>;
303
304pub struct PolicyEvent {
310 pub event: SyscallEvent,
311 pub gate: Option<tokio::sync::oneshot::Sender<Verdict>>,
315}
316
317pub(crate) fn spawn_policy_fn(
325 callback: PolicyCallback,
326 live: Arc<RwLock<LivePolicy>>,
327 ceiling: LivePolicy,
328 pid_overrides: Arc<RwLock<HashMap<u32, HashSet<IpAddr>>>>,
329 denied_paths: Arc<RwLock<HashSet<String>>>,
330) -> tokio::sync::mpsc::UnboundedSender<PolicyEvent> {
331 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<PolicyEvent>();
332
333 std::thread::Builder::new()
334 .name("sandlock-policy-fn".to_string())
335 .spawn(move || {
336 let mut ctx = PolicyContext::new(live, ceiling, pid_overrides, denied_paths);
337
338 while let Some(pe) = rx.blocking_recv() {
339 let verdict = callback(pe.event, &mut ctx);
340
341 if let Some(gate) = pe.gate {
344 let _ = gate.send(verdict);
345 }
346 }
347 })
348 .expect("failed to spawn policy-fn thread");
349
350 tx
351}
352
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared handles kept alongside the context under test.
    struct Harness {
        live: Arc<RwLock<LivePolicy>>,
        pid_overrides: Arc<RwLock<HashMap<u32, HashSet<IpAddr>>>>,
        ctx: PolicyContext,
    }

    /// Wraps `initial` as the live policy and uses `test_live()` as the ceiling.
    /// Starts with no pid overrides and no denied paths.
    fn harness(initial: LivePolicy) -> Harness {
        let live = Arc::new(RwLock::new(initial));
        let pid_overrides = Arc::new(RwLock::new(HashMap::new()));
        let denied_paths = Arc::new(RwLock::new(HashSet::new()));
        let ctx = PolicyContext::new(
            Arc::clone(&live),
            test_live(),
            Arc::clone(&pid_overrides),
            denied_paths,
        );
        Harness { live, pid_overrides, ctx }
    }

    fn ip(s: &str) -> IpAddr {
        s.parse().unwrap()
    }

    /// Two allowed addresses, 1 GiB of memory, 64 processes.
    fn test_live() -> LivePolicy {
        LivePolicy {
            allowed_ips: ["127.0.0.1", "10.0.0.1"].iter().copied().map(ip).collect(),
            max_memory_bytes: 1024 * 1024 * 1024,
            max_processes: 64,
        }
    }

    /// A live policy with nothing granted yet.
    fn empty_live() -> LivePolicy {
        LivePolicy {
            allowed_ips: HashSet::new(),
            max_memory_bytes: 0,
            max_processes: 0,
        }
    }

    fn event(
        syscall: &str,
        category: SyscallCategory,
        parent_pid: Option<u32>,
        argv: Option<Vec<String>>,
    ) -> SyscallEvent {
        SyscallEvent {
            syscall: syscall.to_string(),
            category,
            pid: 1,
            parent_pid,
            host: None,
            port: None,
            size: None,
            argv,
            denied: false,
        }
    }

    #[test]
    fn test_grant_within_ceiling() {
        let mut h = harness(empty_live());
        let localhost = ip("127.0.0.1");
        h.ctx.grant_network(&[localhost]).unwrap();
        assert!(h.live.read().unwrap().allowed_ips.contains(&localhost));
    }

    #[test]
    fn test_grant_capped_to_ceiling() {
        let mut h = harness(empty_live());
        let foreign = ip("8.8.8.8");
        h.ctx.grant_network(&[foreign]).unwrap();
        assert!(!h.live.read().unwrap().allowed_ips.contains(&foreign));
    }

    #[test]
    fn test_restrict_then_grant_fails() {
        let mut h = harness(test_live());
        h.ctx.restrict_network(&[]);
        assert!(h.ctx.grant_network(&[ip("127.0.0.1")]).is_err());
    }

    #[test]
    fn test_restrict_max_memory() {
        let mut h = harness(test_live());
        h.ctx.restrict_max_memory(256 * 1024 * 1024);
        assert_eq!(h.live.read().unwrap().max_memory_bytes, 256 * 1024 * 1024);
    }

    #[test]
    fn test_pid_override() {
        let h = harness(test_live());
        let localhost = ip("127.0.0.1");
        h.ctx.restrict_pid_network(1234, &[localhost]);

        let overrides = h.pid_overrides.read().unwrap();
        let pid_ips = overrides.get(&1234).unwrap();
        assert!(pid_ips.contains(&localhost));
        assert_eq!(pid_ips.len(), 1);
    }

    #[test]
    fn test_clear_pid_override() {
        let h = harness(test_live());
        h.ctx.restrict_pid_network(1234, &[ip("127.0.0.1")]);
        h.ctx.clear_pid_override(1234);
        assert!(!h.pid_overrides.read().unwrap().contains_key(&1234));
    }

    #[test]
    fn test_event_argv_contains() {
        let argv = vec!["python3".into(), "-c".into(), "print(1)".into()];
        let ev = event("execve", SyscallCategory::Process, Some(0), Some(argv));
        assert!(ev.argv_contains("python3"));
        assert!(ev.argv_contains("-c"));
        assert!(!ev.argv_contains("ruby"));
        assert_eq!(ev.category, SyscallCategory::Process);
    }

    #[test]
    fn test_event_argv_contains_none() {
        let ev = event("openat", SyscallCategory::File, None, None);
        assert!(!ev.argv_contains("anything"));
    }
}