Skip to main content

sentri_core/
threat_model.rs

1//! Threat model defenses for Invar invariant enforcement system.
2//!
3//! This module implements comprehensive security hardening against:
4//! 1. Injection attacks (re-parsing verification)
5//! 2. Macro tampering (hash detection)
6//! 3. Uncertainty in mutation analysis (strict mode abort)
7//! 4. DSL sandbox escapes (expression validation)
8//! 5. Simulation side-effects (isolation verification)
9
10use crate::model::Expression;
11use std::collections::BTreeMap;
12
/// Threat model security configuration.
///
/// Each flag names one of the defenses implemented in this module. Nothing in
/// this file reads the flags itself; callers are expected to consult them
/// before running the corresponding defense.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ThreatModelConfig {
    /// Require strict mutation detection (abort if uncertain).
    pub strict_mode: bool,
    /// Validate all generated code by re-parsing.
    pub re_parse_verification: bool,
    /// Detect tamper attempts via hash checking.
    pub tamper_detection_enabled: bool,
    /// Validate DSL expressions for sandbox escapes.
    pub dsl_sandboxing_enabled: bool,
    /// Verify simulation isolation (no file mutations).
    pub isolation_verification: bool,
}

impl Default for ThreatModelConfig {
    /// The most restrictive configuration: every defense flag is `true`.
    fn default() -> Self {
        let enabled = true;
        Self {
            strict_mode: enabled,
            re_parse_verification: enabled,
            tamper_detection_enabled: enabled,
            dsl_sandboxing_enabled: enabled,
            isolation_verification: enabled,
        }
    }
}
39
/// Result type for threat model validation.
pub type ThreatResult<T> = Result<T, ThreatModelError>;

/// Threat model validation errors.
///
/// Each variant carries a human-readable detail message. The `Display`
/// implementation prefixes it with the category of the defense that failed,
/// except for [`ThreatModelError::Custom`], which is shown verbatim.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ThreatModelError {
    /// Code re-parse verification failed.
    ReParseVerificationFailed(String),
    /// Macro tampering detected.
    TamperDetected(String),
    /// DSL sandbox escape attempt detected.
    SandboxEscapeDetected(String),
    /// Mutation detection uncertainty in strict mode.
    MutationUncertaintyDetected(String),
    /// Simulation isolation violation detected.
    IsolationViolationDetected(String),
    /// Custom threat; displayed without a category prefix.
    Custom(String),
}

impl std::fmt::Display for ThreatModelError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ReParseVerificationFailed(msg) => {
                write!(f, "re-parse verification failed: {}", msg)
            }
            Self::TamperDetected(msg) => write!(f, "macro tampering detected: {}", msg),
            Self::SandboxEscapeDetected(msg) => write!(f, "DSL sandbox escape: {}", msg),
            Self::MutationUncertaintyDetected(msg) => {
                write!(f, "mutation uncertainty in strict mode: {}", msg)
            }
            Self::IsolationViolationDetected(msg) => {
                write!(f, "simulation isolation violation: {}", msg)
            }
            Self::Custom(msg) => write!(f, "{}", msg),
        }
    }
}

/// Makes `ThreatModelError` usable wherever the standard error trait is
/// expected (e.g. `Box<dyn std::error::Error>` or `?` into generic error types).
/// There is no underlying cause, so the default `source()` (`None`) applies.
impl std::error::Error for ThreatModelError {}
78
79/// Defense 1: Injection verification via re-parsing.
80///
81/// After generating code, re-parse it to ensure:
82/// - Syntax is valid
83/// - No injection artifacts remain
84/// - All invariants are properly placed
85pub struct InjectionVerifier;
86
87impl InjectionVerifier {
88    /// Verify that generated code contains all expected invariant checks.
89    ///
90    /// # Security Property
91    /// Ensures 100% coverage of mutating functions with invariant checks.
92    pub fn verify_coverage(generated_code: &str, expected_checks: &[String]) -> ThreatResult<()> {
93        for check in expected_checks {
94            if !generated_code.contains(&format!("// Invariant: {}", check)) {
95                return Err(ThreatModelError::ReParseVerificationFailed(format!(
96                    "invariant check not found in generated code: {}",
97                    check
98                )));
99            }
100        }
101        Ok(())
102    }
103
104    /// Verify no injected code escapes the intended scope.
105    ///
106    /// # Security Property
107    /// Prevents code injection by ensuring all injected statements stay within
108    /// invariant check blocks.
109    pub fn verify_scope_containment(generated_code: &str) -> ThreatResult<()> {
110        // Check for dangerous patterns that indicate injection escape
111        let dangerous_patterns = [
112            "unsafe",
113            "extern",
114            "use std::process",
115            "std::fs",
116            "std::net",
117        ];
118
119        for pattern in &dangerous_patterns {
120            if generated_code.contains(pattern) {
121                return Err(ThreatModelError::ReParseVerificationFailed(format!(
122                    "dangerous pattern found in generated code: {}",
123                    pattern
124                )));
125            }
126        }
127
128        Ok(())
129    }
130}
131
132/// Defense 2: Macro tamper detection.
133///
134/// Detects modifications to injected invariant checks by comparing
135/// computed hash against embedded hash in generated code.
136pub struct TamperDetector;
137
138impl TamperDetector {
139    /// Compute deterministic hash for a set of invariant checks.
140    ///
141    /// # Determinism Property
142    /// Hash is computed from sorted check list, so order doesn't matter
143    /// (prevents timing attacks on check modifications).
144    pub fn compute_hash(checks: &[String]) -> String {
145        use std::collections::hash_map::DefaultHasher;
146        use std::hash::{Hash, Hasher};
147
148        let mut hasher = DefaultHasher::new();
149        let mut sorted_checks = checks.to_vec();
150        sorted_checks.sort();
151
152        for check in sorted_checks {
153            check.hash(&mut hasher);
154        }
155
156        format!("{:016x}", hasher.finish())
157    }
158
159    /// Verify that the embedded hash matches expected checks.
160    ///
161    /// # Security Property
162    /// Detects any tampering with invariant checks after macro expansion.
163    pub fn verify_tampering(generated_code: &str, expected_checks: &[String]) -> ThreatResult<()> {
164        let expected_hash = Self::compute_hash(expected_checks);
165
166        // Extract hash from generated code (look for SENTRI_HASH: pattern)
167        let hash_pattern = format!("SENTRI_HASH: {}", expected_hash);
168
169        if !generated_code.contains(&hash_pattern) {
170            return Err(ThreatModelError::TamperDetected(
171                "hash mismatch: generated code does not contain expected SENTRI_HASH".to_string(),
172            ));
173        }
174
175        Ok(())
176    }
177}
178
179/// Defense 3: DSL sandboxing.
180///
181/// Validates that invariant expressions cannot escape the sandbox
182/// (no file I/O, no external calls, no state mutations outside checks).
183pub struct DSLSandbox;
184
185impl DSLSandbox {
186    /// Validate an expression for sandbox violations.
187    ///
188    /// # Security Property
189    /// Ensures invariant expressions:
190    /// - Don't access files
191    /// - Don't call external code
192    /// - Are deterministic (no randomness)
193    /// - Have no side effects
194    pub fn validate_expression(expr: &Expression) -> ThreatResult<()> {
195        // Check for dangerous patterns in variable names (common injection vectors)
196        let forbidden_prefixes = ["file_", "io_", "extern_", "unsafe_"];
197
198        Self::check_expression_recursive(expr, &forbidden_prefixes)
199    }
200
201    fn check_expression_recursive(
202        expr: &Expression,
203        forbidden_prefixes: &[&str],
204    ) -> ThreatResult<()> {
205        match expr {
206            Expression::Var(name) => {
207                for prefix in forbidden_prefixes {
208                    if name.to_lowercase().starts_with(prefix) {
209                        return Err(ThreatModelError::SandboxEscapeDetected(format!(
210                            "forbidden variable name: {}",
211                            name
212                        )));
213                    }
214                }
215                Ok(())
216            }
217
218            Expression::LayerVar { layer, var } => {
219                // Check both layer and variable names against forbidden prefixes
220                for prefix in forbidden_prefixes {
221                    if layer.to_lowercase().starts_with(prefix)
222                        || var.to_lowercase().starts_with(prefix)
223                    {
224                        return Err(ThreatModelError::SandboxEscapeDetected(format!(
225                            "forbidden layer/variable name: {}::{}",
226                            layer, var
227                        )));
228                    }
229                }
230                Ok(())
231            }
232
233            Expression::FunctionCall { name, args } => {
234                // Whitelist of allowed functions (purely computational, no side effects)
235                let allowed_functions = [
236                    "sum", "len", "min", "max", "abs", "mod", "div", "add", "sub", "mul", "and",
237                    "or", "not",
238                ];
239
240                if !allowed_functions.contains(&name.as_str()) {
241                    return Err(ThreatModelError::SandboxEscapeDetected(format!(
242                        "forbidden function call: {}",
243                        name
244                    )));
245                }
246
247                // Recursively check all arguments
248                for arg in args {
249                    Self::check_expression_recursive(arg, forbidden_prefixes)?;
250                }
251                Ok(())
252            }
253
254            Expression::BinaryOp { left, op: _, right } => {
255                Self::check_expression_recursive(left, forbidden_prefixes)?;
256                Self::check_expression_recursive(right, forbidden_prefixes)?;
257                Ok(())
258            }
259
260            Expression::Logical { left, op: _, right } => {
261                Self::check_expression_recursive(left, forbidden_prefixes)?;
262                Self::check_expression_recursive(right, forbidden_prefixes)?;
263                Ok(())
264            }
265
266            Expression::Not(inner) => {
267                Self::check_expression_recursive(inner, forbidden_prefixes)?;
268                Ok(())
269            }
270
271            Expression::Tuple(exprs) => {
272                for e in exprs {
273                    Self::check_expression_recursive(e, forbidden_prefixes)?;
274                }
275                Ok(())
276            }
277
278            Expression::PhaseQualifiedVar { phase, layer, var } => {
279                // Check phase, layer, and variable names against forbidden prefixes
280                for prefix in forbidden_prefixes {
281                    if phase.to_lowercase().starts_with(prefix)
282                        || layer.to_lowercase().starts_with(prefix)
283                        || var.to_lowercase().starts_with(prefix)
284                    {
285                        return Err(ThreatModelError::SandboxEscapeDetected(format!(
286                            "forbidden phase/layer/variable name: {}::{}::{}",
287                            phase, layer, var
288                        )));
289                    }
290                }
291                Ok(())
292            }
293
294            Expression::PhaseConstraint {
295                phase: _,
296                constraint,
297            } => {
298                // Check the constraint expression recursively
299                Self::check_expression_recursive(constraint, forbidden_prefixes)
300            }
301
302            Expression::CrossPhaseRelation {
303                phase1: _,
304                expr1,
305                phase2: _,
306                expr2,
307                op: _,
308            } => {
309                // Check both phase expressions
310                Self::check_expression_recursive(expr1, forbidden_prefixes)?;
311                Self::check_expression_recursive(expr2, forbidden_prefixes)?;
312                Ok(())
313            }
314
315            Expression::Boolean(_) | Expression::Int(_) => Ok(()),
316        }
317    }
318}
319
320/// Defense 4: Analyzer strict mode.
321///
322/// In strict mode, the analyzer aborts if it cannot determine with certainty
323/// whether a particular mutation will violate an invariant.
324pub struct StrictModeAnalyzer {
325    enabled: bool,
326}
327
328impl StrictModeAnalyzer {
329    /// Create a new strict mode analyzer.
330    pub fn new(enabled: bool) -> Self {
331        Self { enabled }
332    }
333
334    /// Verify that all function mutations are accounted for.
335    ///
336    /// # Security Property
337    /// In strict mode, rejects functions with uncertain mutation detection.
338    /// This prevents invariant bypass via undetected mutations.
339    pub fn verify_mutation_coverage(
340        &self,
341        _analyzed_mutations: &[String],
342        uncertainty_warnings: &[String],
343    ) -> ThreatResult<()> {
344        if !self.enabled {
345            return Ok(());
346        }
347
348        if !uncertainty_warnings.is_empty() {
349            return Err(ThreatModelError::MutationUncertaintyDetected(format!(
350                "strict mode detected {} uncertain mutations: {}",
351                uncertainty_warnings.len(),
352                uncertainty_warnings.join(", ")
353            )));
354        }
355
356        Ok(())
357    }
358}
359
360/// Defense 5: Simulation isolation verification.
361///
362/// Ensures that simulation environments cannot mutate actual state or access files.
363pub struct SimulationIsolation;
364
365impl SimulationIsolation {
366    /// Verify that a simulation context is properly isolated.
367    ///
368    /// # Security Property
369    /// Ensures simulation:
370    /// - Only uses in-memory data structures (BTreeMap, Vec)
371    /// - Makes no file system calls
372    /// - Doesn't mutate external state
373    /// - Results are deterministic
374    pub fn verify_isolation(
375        context_vars: &BTreeMap<String, String>,
376        allowed_types: &[&str],
377    ) -> ThreatResult<()> {
378        for (name, type_str) in context_vars {
379            // Validate that only allowed types are used in simulation
380            let is_allowed = allowed_types
381                .iter()
382                .any(|&allowed| type_str.contains(allowed));
383
384            if !is_allowed {
385                return Err(ThreatModelError::IsolationViolationDetected(format!(
386                    "variable '{}' has disallowed type '{}' in simulation context",
387                    name, type_str
388                )));
389            }
390        }
391
392        Ok(())
393    }
394}
395
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_injection_verification() {
        let generated_code = r#"
        fn transfer(from: &mut Account, to: &mut Account, amount: u64) {
            from.balance -= amount;
            to.balance += amount;
            // Invariant: balance >= 0
            // SENTRI_HASH: abcd1234
        }
        "#;

        let checks = vec!["balance >= 0".to_string()];
        assert!(InjectionVerifier::verify_coverage(generated_code, &checks).is_ok());
    }

    #[test]
    fn test_injection_verification_missing_check() {
        let generated_code = "fn transfer() { /* no checks */ }";
        let checks = vec!["balance >= 0".to_string()];
        assert!(InjectionVerifier::verify_coverage(generated_code, &checks).is_err());
    }

    #[test]
    fn test_injection_verification_no_expected_checks() {
        assert!(InjectionVerifier::verify_coverage("", &[]).is_ok());
    }

    #[test]
    fn test_scope_containment() {
        let safe_code = "let x = a + b; assert!(x > 0);";
        assert!(InjectionVerifier::verify_scope_containment(safe_code).is_ok());

        let unsafe_code = "use std::fs; fs::write(\"file.txt\", \"\");";
        assert!(InjectionVerifier::verify_scope_containment(unsafe_code).is_err());
    }

    #[test]
    fn test_tamper_hash_deterministic() {
        let checks1 = vec!["a".to_string(), "b".to_string()];
        let checks2 = vec!["b".to_string(), "a".to_string()];

        let hash1 = TamperDetector::compute_hash(&checks1);
        let hash2 = TamperDetector::compute_hash(&checks2);

        assert_eq!(hash1, hash2);
    }

    #[test]
    fn test_tamper_hash_changes_with_checks() {
        let a = TamperDetector::compute_hash(&["balance >= 0".to_string()]);
        let b = TamperDetector::compute_hash(&["balance > 0".to_string()]);
        assert_ne!(a, b);
    }

    #[test]
    fn test_tamper_verification_roundtrip() {
        let checks = vec!["balance >= 0".to_string()];
        let code = format!(
            "// Invariant: balance >= 0\n// SENTRI_HASH: {}\n",
            TamperDetector::compute_hash(&checks)
        );
        assert!(TamperDetector::verify_tampering(&code, &checks).is_ok());
    }

    #[test]
    fn test_tamper_verification_missing_hash() {
        let checks = vec!["balance >= 0".to_string()];
        assert!(matches!(
            TamperDetector::verify_tampering("// Invariant: balance >= 0", &checks),
            Err(ThreatModelError::TamperDetected(_))
        ));
    }

    #[test]
    fn test_dsl_sandbox_forbidden_variable() {
        let expr = Expression::Var("file_handle".to_string());
        assert!(DSLSandbox::validate_expression(&expr).is_err());
    }

    #[test]
    fn test_dsl_sandbox_allowed_variable() {
        let expr = Expression::Var("balance".to_string());
        assert!(DSLSandbox::validate_expression(&expr).is_ok());
    }

    #[test]
    fn test_dsl_sandbox_forbidden_function() {
        let expr = Expression::FunctionCall {
            name: "system_call".to_string(),
            args: vec![],
        };
        assert!(DSLSandbox::validate_expression(&expr).is_err());
    }

    #[test]
    fn test_dsl_sandbox_allowed_function() {
        let expr = Expression::FunctionCall {
            name: "sum".to_string(),
            args: vec![Expression::Var("balances".to_string())],
        };
        assert!(DSLSandbox::validate_expression(&expr).is_ok());
    }

    #[test]
    fn test_dsl_sandbox_nested_forbidden_variable_is_case_insensitive() {
        let expr = Expression::FunctionCall {
            name: "sum".to_string(),
            args: vec![Expression::Var("IO_stream".to_string())],
        };
        assert!(matches!(
            DSLSandbox::validate_expression(&expr),
            Err(ThreatModelError::SandboxEscapeDetected(_))
        ));
    }

    #[test]
    fn test_strict_mode_with_uncertainty() {
        let analyzer = StrictModeAnalyzer::new(true);
        let mutations = vec!["balance -= amount".to_string()];
        let warnings = vec!["mutation from function pointer call (uncertain)".to_string()];

        assert!(analyzer
            .verify_mutation_coverage(&mutations, &warnings)
            .is_err());
    }

    #[test]
    fn test_strict_mode_disabled() {
        let analyzer = StrictModeAnalyzer::new(false);
        let mutations = vec!["balance -= amount".to_string()];
        let warnings = vec!["mutation from function pointer call (uncertain)".to_string()];

        // Strict mode off, so uncertainty is allowed
        assert!(analyzer
            .verify_mutation_coverage(&mutations, &warnings)
            .is_ok());
    }

    #[test]
    fn test_strict_mode_without_uncertainty() {
        let analyzer = StrictModeAnalyzer::new(true);
        assert!(analyzer
            .verify_mutation_coverage(&["x += 1".to_string()], &[])
            .is_ok());
    }

    #[test]
    fn test_simulation_isolation() {
        let mut vars = BTreeMap::new();
        vars.insert("balances".to_string(), "BTreeMap<String, u64>".to_string());
        assert!(SimulationIsolation::verify_isolation(&vars, &["BTreeMap"]).is_ok());

        vars.insert("log".to_string(), "std::fs::File".to_string());
        assert!(matches!(
            SimulationIsolation::verify_isolation(&vars, &["BTreeMap"]),
            Err(ThreatModelError::IsolationViolationDetected(_))
        ));
    }

    #[test]
    fn test_error_display() {
        assert_eq!(
            ThreatModelError::TamperDetected("x".to_string()).to_string(),
            "macro tampering detected: x"
        );
        assert_eq!(
            ThreatModelError::Custom("plain".to_string()).to_string(),
            "plain"
        );
    }

    #[test]
    fn test_default_config_enables_all_defenses() {
        let c = ThreatModelConfig::default();
        assert!(c.strict_mode);
        assert!(c.re_parse_verification);
        assert!(c.tamper_detection_enabled);
        assert!(c.dsl_sandboxing_enabled);
        assert!(c.isolation_verification);
    }
}