Skip to main content

sentri_core/
security_validator.rs

//! Pre-build security validation using attack pattern detection.
//!
//! This module scans source code before a build and reports matches against known
//! attack patterns, so vulnerable code can be rejected before it is built.
4
5use crate::attack_patterns::{AttackPattern, AttackPatternDB};
6use std::path::Path;
7
8/// Security validation report.
9#[derive(Debug, Clone)]
10pub struct SecurityReport {
11    /// Critical vulnerabilities found.
12    pub critical_issues: Vec<SecurityIssue>,
13    /// High-risk issues found.
14    pub high_issues: Vec<SecurityIssue>,
15    /// Medium-risk issues found.
16    pub medium_issues: Vec<SecurityIssue>,
17    /// Low-risk issues found.
18    pub low_issues: Vec<SecurityIssue>,
19    /// Pass/fail status.
20    pub passed: bool,
21    /// Overall risk score (0-100).
22    pub risk_score: u32,
23}
24
25/// A detected security issue.
26#[derive(Debug, Clone)]
27pub struct SecurityIssue {
28    /// Attack pattern involved.
29    pub attack_pattern: String,
30    /// Location in code (file:line).
31    pub location: String,
32    /// Description of the issue.
33    pub description: String,
34    /// Suggested fix.
35    pub suggested_fix: String,
36    /// Severity level.
37    pub severity: IssueSeverity,
38}
39
/// How serious a finding is.
///
/// The explicit discriminants drive the derived ordering, so
/// `Critical > High > Medium > Low`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum IssueSeverity {
    /// May lead to a total loss of funds.
    Critical = 4,
    /// May lead to a significant loss of funds.
    High = 3,
    /// Exploitable only under specific conditions.
    Medium = 2,
    /// Minor risk or a best-practice violation.
    Low = 1,
}
52
53impl std::fmt::Display for IssueSeverity {
54    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55        match self {
56            Self::Critical => write!(f, "CRITICAL"),
57            Self::High => write!(f, "HIGH"),
58            Self::Medium => write!(f, "MEDIUM"),
59            Self::Low => write!(f, "LOW"),
60        }
61    }
62}
63
64/// Security validator for code before building.
65pub struct SecurityValidator {
66    attack_db: AttackPatternDB,
67}
68
69impl SecurityValidator {
70    /// Create a new security validator.
71    pub fn new() -> Self {
72        Self {
73            attack_db: AttackPatternDB::new(),
74        }
75    }
76
77    /// Validate code from a file.
78    pub fn validate_file(&self, path: &Path, chain: &str) -> Result<SecurityReport, String> {
79        let code =
80            std::fs::read_to_string(path).map_err(|e| format!("Failed to read file: {}", e))?;
81        self.validate_code(&code, path.to_string_lossy().as_ref(), chain)
82    }
83
84    /// Validate code content.
85    pub fn validate_code(
86        &self,
87        code: &str,
88        file_path: &str,
89        chain: &str,
90    ) -> Result<SecurityReport, String> {
91        let mut critical_issues = Vec::new();
92        let mut high_issues = Vec::new();
93        let mut medium_issues = Vec::new();
94        let mut low_issues = Vec::new();
95
96        // Check each pattern relevant to the chain
97        let patterns = self.attack_db.patterns_for_chain(chain);
98
99        for pattern in patterns {
100            let issues = self.check_pattern(code, file_path, pattern);
101            for issue in issues {
102                match issue.severity {
103                    IssueSeverity::Critical => critical_issues.push(issue),
104                    IssueSeverity::High => high_issues.push(issue),
105                    IssueSeverity::Medium => medium_issues.push(issue),
106                    IssueSeverity::Low => low_issues.push(issue),
107                }
108            }
109        }
110
111        // Calculate risk score
112        let risk_score = (critical_issues.len() as u32 * 25
113            + high_issues.len() as u32 * 15
114            + medium_issues.len() as u32 * 8
115            + low_issues.len() as u32 * 3)
116            .min(100);
117
118        let passed = critical_issues.is_empty() && high_issues.is_empty();
119
120        Ok(SecurityReport {
121            critical_issues,
122            high_issues,
123            medium_issues,
124            low_issues,
125            passed,
126            risk_score,
127        })
128    }
129
130    /// Check code against a specific attack pattern.
131    fn check_pattern(
132        &self,
133        code: &str,
134        file_path: &str,
135        pattern: &AttackPattern,
136    ) -> Vec<SecurityIssue> {
137        let mut issues = Vec::new();
138
139        // Special handling for reentrancy: need to check state update AFTER external call
140        if pattern.id == "reentrancy" {
141            issues.extend(self.check_reentrancy(code, file_path, pattern));
142        } else {
143            // Generic pattern matching for other attacks
144            for (line_num, line) in code.lines().enumerate() {
145                for vulnerable_pattern in &pattern.vulnerable_patterns {
146                    if line.contains(vulnerable_pattern.as_str()) {
147                        let severity = match pattern.cvss_score {
148                            s if s >= 9.0 => IssueSeverity::Critical,
149                            s if s >= 7.0 => IssueSeverity::High,
150                            s if s >= 5.0 => IssueSeverity::Medium,
151                            _ => IssueSeverity::Low,
152                        };
153
154                        issues.push(SecurityIssue {
155                            attack_pattern: pattern.name.clone(),
156                            location: format!("{}:{}", file_path, line_num + 1),
157                            description: format!(
158                                "Potential {} vulnerability detected. {}",
159                                pattern.name, pattern.description
160                            ),
161                            suggested_fix: format!(
162                                "Apply defensive invariant: {}",
163                                pattern
164                                    .defensive_invariants
165                                    .first()
166                                    .unwrap_or(&"Review code".to_string())
167                            ),
168                            severity,
169                        });
170                    }
171                }
172            }
173        }
174        issues
175    }
176
177    /// Check for reentrancy by analyzing state update order.
178    fn check_reentrancy(
179        &self,
180        code: &str,
181        file_path: &str,
182        pattern: &AttackPattern,
183    ) -> Vec<SecurityIssue> {
184        let mut issues = Vec::new();
185        let lines: Vec<&str> = code.lines().collect();
186
187        // Find external calls (transfer, call, etc.)
188        for (line_num, line) in lines.iter().enumerate() {
189            // Skip if line has state update protection
190            if line.contains("nonReentrant") {
191                continue;
192            }
193
194            // Check if line has external call
195            let has_external_call =
196                line.contains("transfer(") || line.contains(".call(") || line.contains(".send(");
197
198            if !has_external_call {
199                continue;
200            }
201
202            // Look back up to 50 lines to find state updates
203            let mut has_state_update_before = false;
204            let search_start = line_num.saturating_sub(50);
205
206            for prev_line in lines.iter().take(line_num).skip(search_start) {
207                // Look for state updates: balance[X] = Y or balance = Z patterns
208                if (prev_line.contains("balances[") && prev_line.contains("= 0"))
209                    || (prev_line.contains("balance =") && prev_line.contains("= 0"))
210                {
211                    has_state_update_before = true;
212                    break;
213                }
214            }
215
216            // If NO state update before the external call, it's vulnerable
217            if !has_state_update_before {
218                let severity = IssueSeverity::Critical;
219
220                issues.push(SecurityIssue {
221                    attack_pattern: pattern.name.clone(),
222                    location: format!("{}:{}", file_path, line_num + 1),
223                    description: format!(
224                        "Potential {} vulnerability detected. {}",
225                        pattern.name, pattern.description
226                    ),
227                    suggested_fix: "Apply defensive invariant: state_update_before_external_call"
228                        .to_string(),
229                    severity,
230                });
231            }
232        }
233        issues
234    }
235}
236
237impl Default for SecurityValidator {
238    fn default() -> Self {
239        Self::new()
240    }
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    /// Run a fresh validator over `code`, reported as `file`, for `chain`.
    fn validate(code: &str, file: &str, chain: &str) -> SecurityReport {
        SecurityValidator::new()
            .validate_code(code, file, chain)
            .expect("validate_code does not fail for in-memory input")
    }

    #[test]
    fn test_security_validator_creation() {
        let validator = SecurityValidator::new();
        assert_eq!(validator.attack_db.all_patterns().len(), 8);
    }

    #[test]
    fn test_vulnerable_code_detection() {
        // External call first, balance reset afterwards: a checks-effects-interactions
        // violation. The call itself (not a function signature) must be what matches.
        let code = "payable(msg.sender).transfer(amount);\nbalances[msg.sender] = 0;";
        let report = validate(code, "test.sol", "evm");
        assert!(!report.passed);
        assert!(!report.critical_issues.is_empty());
    }

    #[test]
    fn test_safe_code_passes() {
        let code = "fn safe_code() { let x = 1 + 1; println!(\"{}\", x); }";
        let report = validate(code, "test.rs", "evm");
        assert!(report.passed);
        assert_eq!(report.critical_issues.len(), 0);
    }

    #[test]
    fn test_risk_score_calculation() {
        let code = "fn risky() { payable(msg.sender).transfer(amount); balances[msg.sender] = 0; }";
        let report = validate(code, "test.rs", "evm");
        assert!(report.risk_score > 0);
    }

    #[test]
    fn test_chain_specific_validation() {
        let code = "fn access() { require(is_owner()); }";

        let evm_report = validate(code, "test.sol", "evm");
        let solana_report = validate(code, "test.rs", "solana");

        // Validation must run for both chains, and this guarded snippet must not be
        // rejected (critical/high findings) on both of them.
        assert!(evm_report.passed || solana_report.passed);
    }
}
290}