rust_network_scanner/
vulnerability.rs

//! Vulnerability detection module for network scanning (v2.0).
//!
//! Provides CVE matching and vulnerability assessment: a small built-in CVE
//! database, banner-based version extraction, and report generation.
4
5use chrono::{DateTime, Utc};
6use regex::Regex;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10/// CVE severity levels
11#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)]
12pub enum CVESeverity {
13    None,
14    Low,
15    Medium,
16    High,
17    Critical,
18}
19
20impl CVESeverity {
21    /// Get severity from CVSS score
22    pub fn from_cvss(score: f32) -> Self {
23        match score {
24            s if s >= 9.0 => CVESeverity::Critical,
25            s if s >= 7.0 => CVESeverity::High,
26            s if s >= 4.0 => CVESeverity::Medium,
27            s if s > 0.0 => CVESeverity::Low,
28            _ => CVESeverity::None,
29        }
30    }
31}
32
33/// Common Vulnerability and Exposure entry
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct CVE {
36    pub id: String,
37    pub severity: CVESeverity,
38    pub cvss_score: f32,
39    pub description: String,
40    pub affected_products: Vec<String>,
41    pub affected_versions: Vec<String>,
42    pub published_date: Option<DateTime<Utc>>,
43    pub references: Vec<String>,
44}
45
46impl CVE {
47    /// Create a new CVE entry
48    pub fn new(id: &str, severity: CVESeverity, cvss_score: f32, description: &str) -> Self {
49        Self {
50            id: id.to_string(),
51            severity,
52            cvss_score,
53            description: description.to_string(),
54            affected_products: Vec::new(),
55            affected_versions: Vec::new(),
56            published_date: None,
57            references: Vec::new(),
58        }
59    }
60
61    /// Check if a version is affected
62    pub fn affects_version(&self, version: &str) -> bool {
63        self.affected_versions.iter().any(|v| version.contains(v) || v.contains(version))
64    }
65
66    /// Check if a product is affected
67    pub fn affects_product(&self, product: &str) -> bool {
68        let product_lower = product.to_lowercase();
69        self.affected_products
70            .iter()
71            .any(|p| product_lower.contains(&p.to_lowercase()))
72    }
73}
74
75/// Vulnerability database
76pub struct VulnerabilityDatabase {
77    cves: HashMap<String, CVE>,
78    product_index: HashMap<String, Vec<String>>, // product -> CVE IDs
79}
80
81impl VulnerabilityDatabase {
82    /// Create a new empty database
83    pub fn new() -> Self {
84        let mut db = Self {
85            cves: HashMap::new(),
86            product_index: HashMap::new(),
87        };
88        db.load_default_cves();
89        db
90    }
91
92    /// Load default CVE entries for common services
93    fn load_default_cves(&mut self) {
94        // OpenSSH CVEs
95        let mut ssh_cve = CVE::new(
96            "CVE-2023-38408",
97            CVESeverity::High,
98            7.5,
99            "OpenSSH before 9.3p2 allows PKCS#11-hosted keys to be used without authorization",
100        );
101        ssh_cve.affected_products = vec!["openssh".to_string()];
102        ssh_cve.affected_versions = vec!["9.3p1".to_string(), "9.2".to_string(), "9.1".to_string()];
103        self.add_cve(ssh_cve);
104
105        // Apache CVEs
106        let mut apache_cve = CVE::new(
107            "CVE-2023-25690",
108            CVESeverity::Critical,
109            9.8,
110            "Apache HTTP Server mod_proxy HTTP request smuggling vulnerability",
111        );
112        apache_cve.affected_products = vec!["apache".to_string(), "httpd".to_string()];
113        apache_cve.affected_versions = vec!["2.4.55".to_string(), "2.4.54".to_string()];
114        self.add_cve(apache_cve);
115
116        // nginx CVEs
117        let mut nginx_cve = CVE::new(
118            "CVE-2022-41741",
119            CVESeverity::High,
120            7.8,
121            "NGINX ngx_http_mp4_module vulnerability allows local code execution",
122        );
123        nginx_cve.affected_products = vec!["nginx".to_string()];
124        nginx_cve.affected_versions = vec!["1.23.1".to_string(), "1.22.0".to_string()];
125        self.add_cve(nginx_cve);
126
127        // MySQL CVEs
128        let mut mysql_cve = CVE::new(
129            "CVE-2023-21980",
130            CVESeverity::Medium,
131            6.5,
132            "MySQL Server authentication bypass vulnerability",
133        );
134        mysql_cve.affected_products = vec!["mysql".to_string()];
135        mysql_cve.affected_versions = vec!["8.0.32".to_string(), "8.0.31".to_string()];
136        self.add_cve(mysql_cve);
137
138        // PostgreSQL CVEs
139        let mut postgres_cve = CVE::new(
140            "CVE-2023-2454",
141            CVESeverity::High,
142            8.8,
143            "PostgreSQL allows privilege escalation through CREATE SCHEMA ... AUTHORIZATION",
144        );
145        postgres_cve.affected_products = vec!["postgresql".to_string(), "postgres".to_string()];
146        postgres_cve.affected_versions = vec!["15.2".to_string(), "14.7".to_string()];
147        self.add_cve(postgres_cve);
148    }
149
150    /// Add a CVE to the database
151    pub fn add_cve(&mut self, cve: CVE) {
152        let cve_id = cve.id.clone();
153
154        // Index by product
155        for product in &cve.affected_products {
156            self.product_index
157                .entry(product.to_lowercase())
158                .or_default()
159                .push(cve_id.clone());
160        }
161
162        self.cves.insert(cve_id, cve);
163    }
164
165    /// Look up CVE by ID
166    pub fn get_cve(&self, id: &str) -> Option<&CVE> {
167        self.cves.get(id)
168    }
169
170    /// Find CVEs affecting a product
171    pub fn find_by_product(&self, product: &str) -> Vec<&CVE> {
172        let product_lower = product.to_lowercase();
173
174        self.product_index
175            .get(&product_lower)
176            .map(|ids| ids.iter().filter_map(|id| self.cves.get(id)).collect())
177            .unwrap_or_default()
178    }
179
180    /// Find CVEs affecting a product and version
181    pub fn find_by_product_version(&self, product: &str, version: &str) -> Vec<&CVE> {
182        self.find_by_product(product)
183            .into_iter()
184            .filter(|cve| cve.affects_version(version))
185            .collect()
186    }
187}
188
189impl Default for VulnerabilityDatabase {
190    fn default() -> Self {
191        Self::new()
192    }
193}
194
195/// Vulnerability scan result for a single finding
196#[derive(Debug, Clone, Serialize, Deserialize)]
197pub struct VulnerabilityFinding {
198    pub cve: CVE,
199    pub port: u16,
200    pub service: String,
201    pub version: Option<String>,
202    pub confidence: f32,
203    pub exploitability: String,
204    pub remediation: String,
205}
206
207/// Complete vulnerability report
208#[derive(Debug, Clone, Serialize, Deserialize)]
209pub struct VulnerabilityReport {
210    pub target: String,
211    pub scan_time: DateTime<Utc>,
212    pub findings: Vec<VulnerabilityFinding>,
213    pub risk_score: f32,
214    pub critical_count: usize,
215    pub high_count: usize,
216    pub medium_count: usize,
217    pub low_count: usize,
218}
219
220impl VulnerabilityReport {
221    /// Create a new report
222    pub fn new(target: &str) -> Self {
223        Self {
224            target: target.to_string(),
225            scan_time: Utc::now(),
226            findings: Vec::new(),
227            risk_score: 0.0,
228            critical_count: 0,
229            high_count: 0,
230            medium_count: 0,
231            low_count: 0,
232        }
233    }
234
235    /// Add a finding
236    pub fn add_finding(&mut self, finding: VulnerabilityFinding) {
237        match finding.cve.severity {
238            CVESeverity::Critical => self.critical_count += 1,
239            CVESeverity::High => self.high_count += 1,
240            CVESeverity::Medium => self.medium_count += 1,
241            CVESeverity::Low => self.low_count += 1,
242            CVESeverity::None => {}
243        }
244
245        self.risk_score += finding.cve.cvss_score;
246        self.findings.push(finding);
247    }
248
249    /// Get summary
250    pub fn summary(&self) -> String {
251        format!(
252            "Target: {} | Critical: {} | High: {} | Medium: {} | Low: {} | Risk Score: {:.1}",
253            self.target,
254            self.critical_count,
255            self.high_count,
256            self.medium_count,
257            self.low_count,
258            self.risk_score
259        )
260    }
261
262    /// Export as JSON
263    pub fn to_json(&self) -> Result<String, serde_json::Error> {
264        serde_json::to_string_pretty(self)
265    }
266
267    /// Sort findings by severity
268    pub fn sort_by_severity(&mut self) {
269        self.findings.sort_by(|a, b| b.cve.severity.cmp(&a.cve.severity));
270    }
271}
272
273/// Vulnerability scanner
274pub struct VulnerabilityScanner {
275    database: VulnerabilityDatabase,
276    version_patterns: HashMap<String, Regex>,
277}
278
279impl VulnerabilityScanner {
280    /// Create a new scanner
281    pub fn new() -> Self {
282        let mut scanner = Self {
283            database: VulnerabilityDatabase::new(),
284            version_patterns: HashMap::new(),
285        };
286        scanner.load_version_patterns();
287        scanner
288    }
289
290    /// Load regex patterns for version extraction
291    fn load_version_patterns(&mut self) {
292        self.version_patterns.insert(
293            "ssh".to_string(),
294            Regex::new(r"(?i)openssh[_\s]*([\d.p]+)").unwrap(),
295        );
296        self.version_patterns.insert(
297            "apache".to_string(),
298            Regex::new(r"(?i)apache[/\s]*([\d.]+)").unwrap(),
299        );
300        self.version_patterns.insert(
301            "nginx".to_string(),
302            Regex::new(r"(?i)nginx[/\s]*([\d.]+)").unwrap(),
303        );
304        self.version_patterns.insert(
305            "mysql".to_string(),
306            Regex::new(r"(?i)mysql[/\s]*([\d.]+)").unwrap(),
307        );
308        self.version_patterns.insert(
309            "postgresql".to_string(),
310            Regex::new(r"(?i)postgres(?:ql)?[/\s]*([\d.]+)").unwrap(),
311        );
312    }
313
314    /// Extract version from banner
315    pub fn extract_version(&self, service: &str, banner: &str) -> Option<String> {
316        let service_lower = service.to_lowercase();
317
318        if let Some(pattern) = self.version_patterns.get(&service_lower) {
319            if let Some(captures) = pattern.captures(banner) {
320                if let Some(version) = captures.get(1) {
321                    return Some(version.as_str().to_string());
322                }
323            }
324        }
325
326        // Try generic version pattern
327        let generic = Regex::new(r"([\d]+\.[\d]+(?:\.[\d]+)?)").ok()?;
328        generic.captures(banner)?.get(1).map(|m| m.as_str().to_string())
329    }
330
331    /// Scan a service for vulnerabilities
332    pub fn scan_service(
333        &self,
334        port: u16,
335        service: &str,
336        banner: Option<&str>,
337    ) -> Vec<VulnerabilityFinding> {
338        let mut findings = Vec::new();
339
340        // Extract version from banner if available
341        let version = banner.and_then(|b| self.extract_version(service, b));
342
343        // Find matching CVEs
344        let cves = if let Some(ref v) = version {
345            self.database.find_by_product_version(service, v)
346        } else {
347            self.database.find_by_product(service)
348        };
349
350        for cve in cves {
351            let confidence = if version.is_some() { 0.9 } else { 0.5 };
352
353            let remediation = format!(
354                "Update {} to the latest patched version. See {} references for details.",
355                service, cve.id
356            );
357
358            findings.push(VulnerabilityFinding {
359                cve: cve.clone(),
360                port,
361                service: service.to_string(),
362                version: version.clone(),
363                confidence,
364                exploitability: "Network".to_string(),
365                remediation,
366            });
367        }
368
369        findings
370    }
371
372    /// Generate a complete vulnerability report
373    pub fn generate_report(
374        &self,
375        target: &str,
376        services: &[(u16, String, Option<String>)], // port, service, banner
377    ) -> VulnerabilityReport {
378        let mut report = VulnerabilityReport::new(target);
379
380        for (port, service, banner) in services {
381            let findings = self.scan_service(*port, service, banner.as_deref());
382            for finding in findings {
383                report.add_finding(finding);
384            }
385        }
386
387        report.sort_by_severity();
388        report
389    }
390}
391
392impl Default for VulnerabilityScanner {
393    fn default() -> Self {
394        Self::new()
395    }
396}
397
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cve_severity_from_cvss() {
        assert_eq!(CVESeverity::from_cvss(9.5), CVESeverity::Critical);
        assert_eq!(CVESeverity::from_cvss(7.5), CVESeverity::High);
        assert_eq!(CVESeverity::from_cvss(5.0), CVESeverity::Medium);
        assert_eq!(CVESeverity::from_cvss(2.0), CVESeverity::Low);
        assert_eq!(CVESeverity::from_cvss(0.0), CVESeverity::None);
    }

    #[test]
    fn test_database_lookup() {
        let db = VulnerabilityDatabase::new();

        let cves = db.find_by_product("openssh");
        assert!(!cves.is_empty());

        let cves = db.find_by_product("apache");
        assert!(!cves.is_empty());

        // Product lookup is case-insensitive.
        assert!(!db.find_by_product("OpenSSH").is_empty());
        // Unknown products yield nothing.
        assert!(db.find_by_product("no-such-product").is_empty());

        assert!(db.get_cve("CVE-2023-2454").is_some());
        assert!(db.get_cve("CVE-0000-0000").is_none());
    }

    #[test]
    fn test_version_extraction() {
        let scanner = VulnerabilityScanner::new();

        let version = scanner.extract_version("ssh", "OpenSSH_8.9p1 Ubuntu-3ubuntu0.1");
        assert_eq!(version, Some("8.9p1".to_string()));

        let version = scanner.extract_version("nginx", "nginx/1.18.0");
        assert_eq!(version, Some("1.18.0".to_string()));

        let version = scanner.extract_version("apache", "Apache/2.4.52 (Ubuntu)");
        assert_eq!(version, Some("2.4.52".to_string()));
    }

    #[test]
    fn test_vulnerability_scan() {
        let scanner = VulnerabilityScanner::new();

        // OpenSSH 9.3p1 is listed as affected by CVE-2023-38408.
        let findings = scanner.scan_service(22, "openssh", Some("OpenSSH_9.3p1"));
        assert_eq!(findings.len(), 1);
        assert_eq!(findings[0].cve.id, "CVE-2023-38408");
        assert_eq!(findings[0].port, 22);
        assert_eq!(findings[0].service, "openssh");
        assert!(findings[0].version.is_some());
        assert!((findings[0].confidence - 0.9).abs() < f32::EPSILON);
    }

    #[test]
    fn test_scan_without_banner_uses_lower_confidence() {
        let scanner = VulnerabilityScanner::new();

        let findings = scanner.scan_service(3306, "mysql", None);
        assert_eq!(findings.len(), 1);
        assert_eq!(findings[0].cve.id, "CVE-2023-21980");
        assert!(findings[0].version.is_none());
        assert!((findings[0].confidence - 0.5).abs() < f32::EPSILON);
    }

    #[test]
    fn test_scan_unaffected_version() {
        let scanner = VulnerabilityScanner::new();

        // nginx 1.25.3 is not in the affected list of CVE-2022-41741.
        let findings = scanner.scan_service(80, "nginx", Some("nginx/1.25.3"));
        assert!(findings.is_empty());
    }

    #[test]
    fn test_report_generation() {
        let scanner = VulnerabilityScanner::new();

        let services = vec![
            (22, "openssh".to_string(), Some("OpenSSH_9.3p1".to_string())),
            (80, "apache".to_string(), Some("Apache/2.4.55".to_string())),
            (443, "nginx".to_string(), Some("nginx/1.22.0".to_string())),
        ];

        let report = scanner.generate_report("192.168.1.1", &services);
        assert_eq!(report.target, "192.168.1.1");
        assert_eq!(report.findings.len(), 3);
        assert_eq!(report.critical_count, 1);
        assert_eq!(report.high_count, 2);
        assert_eq!(report.medium_count, 0);
        assert_eq!(report.low_count, 0);
        // Sorted by descending severity: the critical Apache CVE comes first.
        assert_eq!(report.findings[0].cve.id, "CVE-2023-25690");
        assert!((report.risk_score - (7.5 + 9.8 + 7.8)).abs() < 1e-3);
    }

    #[test]
    fn test_report_summary() {
        let mut report = VulnerabilityReport::new("test-target");

        let cve = CVE::new("CVE-2023-0001", CVESeverity::Critical, 9.8, "Test CVE");
        report.add_finding(VulnerabilityFinding {
            cve,
            port: 22,
            service: "ssh".to_string(),
            version: Some("1.0".to_string()),
            confidence: 0.9,
            exploitability: "Network".to_string(),
            remediation: "Update".to_string(),
        });

        assert_eq!(report.critical_count, 1);
        assert!(report.summary().contains("Critical: 1"));

        let json = report.to_json().expect("report serializes to JSON");
        assert!(json.contains("CVE-2023-0001"));
    }
}