1use crate::attack_patterns::{AttackPattern, AttackPatternDB};
6use std::path::Path;
7
8#[derive(Debug, Clone)]
10pub struct SecurityReport {
11 pub critical_issues: Vec<SecurityIssue>,
13 pub high_issues: Vec<SecurityIssue>,
15 pub medium_issues: Vec<SecurityIssue>,
17 pub low_issues: Vec<SecurityIssue>,
19 pub passed: bool,
21 pub risk_score: u32,
23}
24
25#[derive(Debug, Clone)]
27pub struct SecurityIssue {
28 pub attack_pattern: String,
30 pub location: String,
32 pub description: String,
34 pub suggested_fix: String,
36 pub severity: IssueSeverity,
38}
39
/// Severity level attached to a [`SecurityIssue`].
///
/// The explicit discriminants make the derived ordering follow severity, so
/// `IssueSeverity::Critical > IssueSeverity::High > ... > IssueSeverity::Low`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum IssueSeverity {
    /// Most severe level.
    Critical = 4,
    /// Second most severe level.
    High = 3,
    /// Intermediate level.
    Medium = 2,
    /// Least severe level.
    Low = 1,
}

impl std::fmt::Display for IssueSeverity {
    /// Writes the upper-case label of the severity (`CRITICAL`, `HIGH`,
    /// `MEDIUM` or `LOW`).
    ///
    /// The label goes through [`std::fmt::Formatter::pad`], so width, fill and
    /// alignment flags such as `{:<8}` are honoured when rendering tables.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Critical => "CRITICAL",
            Self::High => "HIGH",
            Self::Medium => "MEDIUM",
            Self::Low => "LOW",
        };
        f.pad(label)
    }
}
63
64pub struct SecurityValidator {
66 attack_db: AttackPatternDB,
67}
68
69impl SecurityValidator {
70 pub fn new() -> Self {
72 Self {
73 attack_db: AttackPatternDB::new(),
74 }
75 }
76
77 pub fn validate_file(&self, path: &Path, chain: &str) -> Result<SecurityReport, String> {
79 let code =
80 std::fs::read_to_string(path).map_err(|e| format!("Failed to read file: {}", e))?;
81 self.validate_code(&code, path.to_string_lossy().as_ref(), chain)
82 }
83
84 pub fn validate_code(
86 &self,
87 code: &str,
88 file_path: &str,
89 chain: &str,
90 ) -> Result<SecurityReport, String> {
91 let mut critical_issues = Vec::new();
92 let mut high_issues = Vec::new();
93 let mut medium_issues = Vec::new();
94 let mut low_issues = Vec::new();
95
96 let patterns = self.attack_db.patterns_for_chain(chain);
98
99 for pattern in patterns {
100 let issues = self.check_pattern(code, file_path, pattern);
101 for issue in issues {
102 match issue.severity {
103 IssueSeverity::Critical => critical_issues.push(issue),
104 IssueSeverity::High => high_issues.push(issue),
105 IssueSeverity::Medium => medium_issues.push(issue),
106 IssueSeverity::Low => low_issues.push(issue),
107 }
108 }
109 }
110
111 let risk_score = (critical_issues.len() as u32 * 25
113 + high_issues.len() as u32 * 15
114 + medium_issues.len() as u32 * 8
115 + low_issues.len() as u32 * 3)
116 .min(100);
117
118 let passed = critical_issues.is_empty() && high_issues.is_empty();
119
120 Ok(SecurityReport {
121 critical_issues,
122 high_issues,
123 medium_issues,
124 low_issues,
125 passed,
126 risk_score,
127 })
128 }
129
130 fn check_pattern(
132 &self,
133 code: &str,
134 file_path: &str,
135 pattern: &AttackPattern,
136 ) -> Vec<SecurityIssue> {
137 let mut issues = Vec::new();
138
139 if pattern.id == "reentrancy" {
141 issues.extend(self.check_reentrancy(code, file_path, pattern));
142 } else {
143 for (line_num, line) in code.lines().enumerate() {
145 for vulnerable_pattern in &pattern.vulnerable_patterns {
146 if line.contains(vulnerable_pattern.as_str()) {
147 let severity = match pattern.cvss_score {
148 s if s >= 9.0 => IssueSeverity::Critical,
149 s if s >= 7.0 => IssueSeverity::High,
150 s if s >= 5.0 => IssueSeverity::Medium,
151 _ => IssueSeverity::Low,
152 };
153
154 issues.push(SecurityIssue {
155 attack_pattern: pattern.name.clone(),
156 location: format!("{}:{}", file_path, line_num + 1),
157 description: format!(
158 "Potential {} vulnerability detected. {}",
159 pattern.name, pattern.description
160 ),
161 suggested_fix: format!(
162 "Apply defensive invariant: {}",
163 pattern
164 .defensive_invariants
165 .first()
166 .unwrap_or(&"Review code".to_string())
167 ),
168 severity,
169 });
170 }
171 }
172 }
173 }
174 issues
175 }
176
177 fn check_reentrancy(
179 &self,
180 code: &str,
181 file_path: &str,
182 pattern: &AttackPattern,
183 ) -> Vec<SecurityIssue> {
184 let mut issues = Vec::new();
185 let lines: Vec<&str> = code.lines().collect();
186
187 for (line_num, line) in lines.iter().enumerate() {
189 if line.contains("nonReentrant") {
191 continue;
192 }
193
194 let has_external_call =
196 line.contains("transfer(") || line.contains(".call(") || line.contains(".send(");
197
198 if !has_external_call {
199 continue;
200 }
201
202 let mut has_state_update_before = false;
204 let search_start = line_num.saturating_sub(50);
205
206 for prev_line in lines.iter().take(line_num).skip(search_start) {
207 if (prev_line.contains("balances[") && prev_line.contains("= 0"))
209 || (prev_line.contains("balance =") && prev_line.contains("= 0"))
210 {
211 has_state_update_before = true;
212 break;
213 }
214 }
215
216 if !has_state_update_before {
218 let severity = IssueSeverity::Critical;
219
220 issues.push(SecurityIssue {
221 attack_pattern: pattern.name.clone(),
222 location: format!("{}:{}", file_path, line_num + 1),
223 description: format!(
224 "Potential {} vulnerability detected. {}",
225 pattern.name, pattern.description
226 ),
227 suggested_fix: "Apply defensive invariant: state_update_before_external_call"
228 .to_string(),
229 severity,
230 });
231 }
232 }
233 issues
234 }
235}
236
237impl Default for SecurityValidator {
238 fn default() -> Self {
239 Self::new()
240 }
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    /// A new validator exposes the full built-in pattern set.
    #[test]
    fn test_security_validator_creation() {
        let subject = SecurityValidator::new();
        let pattern_count = subject.attack_db.all_patterns().len();
        assert_eq!(pattern_count, 8);
    }

    /// Code containing a `transfer(` call with no earlier balance reset must
    /// fail validation with at least one critical finding.
    #[test]
    fn test_vulnerable_code_detection() {
        let source = "fn transfer() { transfer_funds(); /* state update after */ }";
        let subject = SecurityValidator::new();
        let outcome = subject.validate_code(source, "test.rs", "evm").unwrap();
        assert!(!outcome.passed);
        assert!(!outcome.critical_issues.is_empty());
    }

    /// Harmless code passes and yields no critical findings.
    #[test]
    fn test_safe_code_passes() {
        let source = "fn safe_code() { let x = 1 + 1; println!(\"{}\", x); }";
        let subject = SecurityValidator::new();
        let outcome = subject.validate_code(source, "test.rs", "evm").unwrap();
        assert!(outcome.passed);
        assert!(outcome.critical_issues.is_empty());
    }

    /// Risky code produces a non-zero risk score.
    #[test]
    fn test_risk_score_calculation() {
        let source =
            "fn risky() { payable(msg.sender).transfer(amount); balances[msg.sender] = 0; }";
        let subject = SecurityValidator::new();
        let outcome = subject.validate_code(source, "test.rs", "evm").unwrap();
        assert_ne!(outcome.risk_score, 0);
    }

    /// The same snippet is validated for two chains; at least one of the two
    /// reports must pass.
    #[test]
    fn test_chain_specific_validation() {
        let source = "fn access() { require(is_owner()); }";
        let subject = SecurityValidator::new();

        let evm_outcome = subject.validate_code(source, "test.sol", "evm").unwrap();
        let solana_outcome = subject.validate_code(source, "test.rs", "solana").unwrap();

        let any_chain_passed = evm_outcome.passed || solana_outcome.passed;
        assert!(any_chain_passed);
    }
}