1use crate::model::Expression;
11use std::collections::BTreeMap;
12
/// Feature switches for the threat-model hardening checks in this module.
///
/// Every flag defaults to `true` (see the `Default` impl below), so the default
/// configuration requests every check.
///
/// NOTE(review): nothing in this module reads these flags; presumably callers
/// consult them before invoking the matching verifier — confirm at the call sites.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ThreatModelConfig {
    /// Requests strict mode (presumably gating `StrictModeAnalyzer`; see NOTE above).
    pub strict_mode: bool,
    /// Requests re-parse verification of generated code.
    pub re_parse_verification: bool,
    /// Requests `SENTRI_HASH` tamper detection.
    pub tamper_detection_enabled: bool,
    /// Requests sandboxing of DSL expressions.
    pub dsl_sandboxing_enabled: bool,
    /// Requests simulation isolation verification.
    pub isolation_verification: bool,
}

impl Default for ThreatModelConfig {
    /// Enables every check.
    fn default() -> Self {
        Self {
            strict_mode: true,
            re_parse_verification: true,
            tamper_detection_enabled: true,
            dsl_sandboxing_enabled: true,
            isolation_verification: true,
        }
    }
}
39
/// Result alias returned by every check in this module.
pub type ThreatResult<T> = Result<T, ThreatModelError>;

/// Failure reported by one of the threat-model checks.
///
/// Each variant carries a human-readable detail message; the `Display` impl
/// prefixes it with a short description of the failing check.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ThreatModelError {
    /// Generated code failed re-parse verification: a missing invariant marker
    /// or a denylisted pattern (returned by `InjectionVerifier`).
    ReParseVerificationFailed(String),
    /// The expected `SENTRI_HASH` marker was not found (returned by `TamperDetector`).
    TamperDetected(String),
    /// A DSL expression used a forbidden identifier or function (returned by `DSLSandbox`).
    SandboxEscapeDetected(String),
    /// Strict mode saw uncertain mutations (returned by `StrictModeAnalyzer`).
    MutationUncertaintyDetected(String),
    /// A simulation context variable had a disallowed type (returned by `SimulationIsolation`).
    IsolationViolationDetected(String),
    /// Free-form error; displayed verbatim.
    Custom(String),
}

impl std::fmt::Display for ThreatModelError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ReParseVerificationFailed(msg) => {
                write!(f, "re-parse verification failed: {}", msg)
            }
            Self::TamperDetected(msg) => write!(f, "macro tampering detected: {}", msg),
            Self::SandboxEscapeDetected(msg) => write!(f, "DSL sandbox escape: {}", msg),
            Self::MutationUncertaintyDetected(msg) => {
                write!(f, "mutation uncertainty in strict mode: {}", msg)
            }
            Self::IsolationViolationDetected(msg) => {
                write!(f, "simulation isolation violation: {}", msg)
            }
            Self::Custom(msg) => write!(f, "{}", msg),
        }
    }
}

// Implementing `Error` lets callers propagate these failures with `?` into
// `Box<dyn std::error::Error>` or any error type built on the `Error` trait.
impl std::error::Error for ThreatModelError {}
78
79pub struct InjectionVerifier;
86
87impl InjectionVerifier {
88 pub fn verify_coverage(generated_code: &str, expected_checks: &[String]) -> ThreatResult<()> {
93 for check in expected_checks {
94 if !generated_code.contains(&format!("// Invariant: {}", check)) {
95 return Err(ThreatModelError::ReParseVerificationFailed(format!(
96 "invariant check not found in generated code: {}",
97 check
98 )));
99 }
100 }
101 Ok(())
102 }
103
104 pub fn verify_scope_containment(generated_code: &str) -> ThreatResult<()> {
110 let dangerous_patterns = [
112 "unsafe",
113 "extern",
114 "use std::process",
115 "std::fs",
116 "std::net",
117 ];
118
119 for pattern in &dangerous_patterns {
120 if generated_code.contains(pattern) {
121 return Err(ThreatModelError::ReParseVerificationFailed(format!(
122 "dangerous pattern found in generated code: {}",
123 pattern
124 )));
125 }
126 }
127
128 Ok(())
129 }
130}
131
132pub struct TamperDetector;
137
138impl TamperDetector {
139 pub fn compute_hash(checks: &[String]) -> String {
145 use std::collections::hash_map::DefaultHasher;
146 use std::hash::{Hash, Hasher};
147
148 let mut hasher = DefaultHasher::new();
149 let mut sorted_checks = checks.to_vec();
150 sorted_checks.sort();
151
152 for check in sorted_checks {
153 check.hash(&mut hasher);
154 }
155
156 format!("{:016x}", hasher.finish())
157 }
158
159 pub fn verify_tampering(generated_code: &str, expected_checks: &[String]) -> ThreatResult<()> {
164 let expected_hash = Self::compute_hash(expected_checks);
165
166 let hash_pattern = format!("SENTRI_HASH: {}", expected_hash);
168
169 if !generated_code.contains(&hash_pattern) {
170 return Err(ThreatModelError::TamperDetected(
171 "hash mismatch: generated code does not contain expected SENTRI_HASH".to_string(),
172 ));
173 }
174
175 Ok(())
176 }
177}
178
179pub struct DSLSandbox;
184
185impl DSLSandbox {
186 pub fn validate_expression(expr: &Expression) -> ThreatResult<()> {
195 let forbidden_prefixes = ["file_", "io_", "extern_", "unsafe_"];
197
198 Self::check_expression_recursive(expr, &forbidden_prefixes)
199 }
200
201 fn check_expression_recursive(
202 expr: &Expression,
203 forbidden_prefixes: &[&str],
204 ) -> ThreatResult<()> {
205 match expr {
206 Expression::Var(name) => {
207 for prefix in forbidden_prefixes {
208 if name.to_lowercase().starts_with(prefix) {
209 return Err(ThreatModelError::SandboxEscapeDetected(format!(
210 "forbidden variable name: {}",
211 name
212 )));
213 }
214 }
215 Ok(())
216 }
217
218 Expression::LayerVar { layer, var } => {
219 for prefix in forbidden_prefixes {
221 if layer.to_lowercase().starts_with(prefix)
222 || var.to_lowercase().starts_with(prefix)
223 {
224 return Err(ThreatModelError::SandboxEscapeDetected(format!(
225 "forbidden layer/variable name: {}::{}",
226 layer, var
227 )));
228 }
229 }
230 Ok(())
231 }
232
233 Expression::FunctionCall { name, args } => {
234 let allowed_functions = [
236 "sum", "len", "min", "max", "abs", "mod", "div", "add", "sub", "mul", "and",
237 "or", "not",
238 ];
239
240 if !allowed_functions.contains(&name.as_str()) {
241 return Err(ThreatModelError::SandboxEscapeDetected(format!(
242 "forbidden function call: {}",
243 name
244 )));
245 }
246
247 for arg in args {
249 Self::check_expression_recursive(arg, forbidden_prefixes)?;
250 }
251 Ok(())
252 }
253
254 Expression::BinaryOp { left, op: _, right } => {
255 Self::check_expression_recursive(left, forbidden_prefixes)?;
256 Self::check_expression_recursive(right, forbidden_prefixes)?;
257 Ok(())
258 }
259
260 Expression::Logical { left, op: _, right } => {
261 Self::check_expression_recursive(left, forbidden_prefixes)?;
262 Self::check_expression_recursive(right, forbidden_prefixes)?;
263 Ok(())
264 }
265
266 Expression::Not(inner) => {
267 Self::check_expression_recursive(inner, forbidden_prefixes)?;
268 Ok(())
269 }
270
271 Expression::Tuple(exprs) => {
272 for e in exprs {
273 Self::check_expression_recursive(e, forbidden_prefixes)?;
274 }
275 Ok(())
276 }
277
278 Expression::PhaseQualifiedVar { phase, layer, var } => {
279 for prefix in forbidden_prefixes {
281 if phase.to_lowercase().starts_with(prefix)
282 || layer.to_lowercase().starts_with(prefix)
283 || var.to_lowercase().starts_with(prefix)
284 {
285 return Err(ThreatModelError::SandboxEscapeDetected(format!(
286 "forbidden phase/layer/variable name: {}::{}::{}",
287 phase, layer, var
288 )));
289 }
290 }
291 Ok(())
292 }
293
294 Expression::PhaseConstraint {
295 phase: _,
296 constraint,
297 } => {
298 Self::check_expression_recursive(constraint, forbidden_prefixes)
300 }
301
302 Expression::CrossPhaseRelation {
303 phase1: _,
304 expr1,
305 phase2: _,
306 expr2,
307 op: _,
308 } => {
309 Self::check_expression_recursive(expr1, forbidden_prefixes)?;
311 Self::check_expression_recursive(expr2, forbidden_prefixes)?;
312 Ok(())
313 }
314
315 Expression::Boolean(_) | Expression::Int(_) => Ok(()),
316 }
317 }
318}
319
320pub struct StrictModeAnalyzer {
325 enabled: bool,
326}
327
328impl StrictModeAnalyzer {
329 pub fn new(enabled: bool) -> Self {
331 Self { enabled }
332 }
333
334 pub fn verify_mutation_coverage(
340 &self,
341 _analyzed_mutations: &[String],
342 uncertainty_warnings: &[String],
343 ) -> ThreatResult<()> {
344 if !self.enabled {
345 return Ok(());
346 }
347
348 if !uncertainty_warnings.is_empty() {
349 return Err(ThreatModelError::MutationUncertaintyDetected(format!(
350 "strict mode detected {} uncertain mutations: {}",
351 uncertainty_warnings.len(),
352 uncertainty_warnings.join(", ")
353 )));
354 }
355
356 Ok(())
357 }
358}
359
360pub struct SimulationIsolation;
364
365impl SimulationIsolation {
366 pub fn verify_isolation(
375 context_vars: &BTreeMap<String, String>,
376 allowed_types: &[&str],
377 ) -> ThreatResult<()> {
378 for (name, type_str) in context_vars {
379 let is_allowed = allowed_types
381 .iter()
382 .any(|&allowed| type_str.contains(allowed));
383
384 if !is_allowed {
385 return Err(ThreatModelError::IsolationViolationDetected(format!(
386 "variable '{}' has disallowed type '{}' in simulation context",
387 name, type_str
388 )));
389 }
390 }
391
392 Ok(())
393 }
394}
395
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned `Vec<String>` from string literals.
    fn owned(items: &[&str]) -> Vec<String> {
        items.iter().map(|&item| item.to_owned()).collect()
    }

    #[test]
    fn test_injection_verification() {
        let source = r#"
    fn transfer(from: &mut Account, to: &mut Account, amount: u64) {
        from.balance -= amount;
        to.balance += amount;
        // Invariant: balance >= 0
        // SENTRI_HASH: abcd1234
    }
    "#;

        let verdict = InjectionVerifier::verify_coverage(source, &owned(&["balance >= 0"]));
        assert!(verdict.is_ok());
    }

    #[test]
    fn test_injection_verification_missing_check() {
        let source = "fn transfer() { /* no checks */ }";
        let verdict = InjectionVerifier::verify_coverage(source, &owned(&["balance >= 0"]));
        assert!(verdict.is_err());
    }

    #[test]
    fn test_scope_containment() {
        let benign = "let x = a + b; assert!(x > 0);";
        assert!(InjectionVerifier::verify_scope_containment(benign).is_ok());

        let hostile = "use std::fs; fs::write(\"file.txt\", \"\");";
        assert!(InjectionVerifier::verify_scope_containment(hostile).is_err());
    }

    #[test]
    fn test_tamper_hash_deterministic() {
        let forward = TamperDetector::compute_hash(&owned(&["a", "b"]));
        let reversed = TamperDetector::compute_hash(&owned(&["b", "a"]));
        assert_eq!(forward, reversed);
    }

    #[test]
    fn test_dsl_sandbox_forbidden_variable() {
        let file_var = Expression::Var(String::from("file_handle"));
        assert!(DSLSandbox::validate_expression(&file_var).is_err());
    }

    #[test]
    fn test_dsl_sandbox_allowed_variable() {
        let plain_var = Expression::Var(String::from("balance"));
        assert!(DSLSandbox::validate_expression(&plain_var).is_ok());
    }

    #[test]
    fn test_dsl_sandbox_forbidden_function() {
        let call = Expression::FunctionCall {
            name: String::from("system_call"),
            args: Vec::new(),
        };
        assert!(DSLSandbox::validate_expression(&call).is_err());
    }

    #[test]
    fn test_dsl_sandbox_allowed_function() {
        let call = Expression::FunctionCall {
            name: String::from("sum"),
            args: vec![Expression::Var(String::from("balances"))],
        };
        assert!(DSLSandbox::validate_expression(&call).is_ok());
    }

    #[test]
    fn test_strict_mode_with_uncertainty() {
        let mutations = owned(&["balance -= amount"]);
        let warnings = owned(&["mutation from function pointer call (uncertain)"]);

        let outcome = StrictModeAnalyzer::new(true).verify_mutation_coverage(&mutations, &warnings);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_strict_mode_disabled() {
        let mutations = owned(&["balance -= amount"]);
        let warnings = owned(&["mutation from function pointer call (uncertain)"]);

        let outcome =
            StrictModeAnalyzer::new(false).verify_mutation_coverage(&mutations, &warnings);
        assert!(outcome.is_ok());
    }
}