Skip to main content

sbom_tools/tui/
security.rs

//! Security analysis utilities for the TUI.
//!
//! Provides blast radius analysis, risk indicators, and other security-focused
//! utilities for analysts working with SBOMs.
5
6use std::collections::{HashMap, HashSet, VecDeque};
7
/// Per-component input to compliance checks.
///
/// Tuple layout: (name, optional version, license identifiers,
/// vulnerabilities as \[(id, severity)\] pairs).
pub type ComplianceComponentData = (
    String,
    Option<String>,
    Vec<String>,
    Vec<(String, String)>,
);
10
11/// Blast radius analysis result for a component
12#[derive(Debug, Clone, Default)]
13pub struct BlastRadius {
14    /// Direct dependents (components that directly depend on this)
15    pub direct_dependents: Vec<String>,
16    /// All transitive dependents (full blast radius)
17    pub transitive_dependents: HashSet<String>,
18    /// Maximum depth of impact
19    pub max_depth: usize,
20    /// Risk level based on impact
21    pub risk_level: RiskLevel,
22    /// Critical paths (paths to important components)
23    pub critical_paths: Vec<Vec<String>>,
24}
25
26impl BlastRadius {
27    /// Total number of affected components
28    pub fn total_affected(&self) -> usize {
29        self.transitive_dependents.len()
30    }
31
32    /// Compute blast radius impact description
33    pub fn impact_description(&self) -> &'static str {
34        match self.transitive_dependents.len() {
35            0 => "No downstream impact",
36            1..=5 => "Limited impact",
37            6..=20 => "Moderate impact",
38            21..=50 => "Significant impact",
39            _ => "Critical impact - affects many components",
40        }
41    }
42}
43
/// Coarse risk classification for a component.
///
/// Variants are declared from least to most severe; `Low` is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum RiskLevel {
    #[default]
    Low,
    Medium,
    High,
    Critical,
}

impl RiskLevel {
    /// Display label for this level.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Low => "Low",
            Self::Medium => "Medium",
            Self::High => "High",
            Self::Critical => "Critical",
        }
    }

    /// Single glyph used for compact rendering in the TUI.
    pub fn symbol(&self) -> &'static str {
        match *self {
            Self::Low => "○",
            Self::Medium => "◐",
            Self::High => "●",
            Self::Critical => "◉",
        }
    }
}
73
74/// Risk indicators for a component
75#[derive(Debug, Clone, Default)]
76pub struct RiskIndicators {
77    /// Vulnerability count
78    pub vuln_count: usize,
79    /// Highest vulnerability severity
80    pub highest_severity: Option<String>,
81    /// Number of direct dependents
82    pub direct_dependent_count: usize,
83    /// Number of transitive dependents (blast radius)
84    pub transitive_dependent_count: usize,
85    /// License risk (unknown, copyleft, etc.)
86    pub license_risk: LicenseRisk,
87    /// Is this a direct dependency (depth 1)
88    pub is_direct_dep: bool,
89    /// Dependency depth from root
90    pub depth: usize,
91    /// Overall risk score (0-100)
92    pub risk_score: u8,
93    /// Overall risk level
94    pub risk_level: RiskLevel,
95}
96
97impl RiskIndicators {
98    /// Calculate risk score based on various factors
99    pub fn calculate_risk_score(&mut self) {
100        let mut score: u32 = 0;
101
102        // Vulnerability contribution (0-40 points)
103        score += match self.vuln_count {
104            0 => 0,
105            1 => 15,
106            2..=5 => 25,
107            _ => 40,
108        };
109
110        // Severity contribution (0-30 points)
111        if let Some(ref sev) = self.highest_severity {
112            let sev_lower = sev.to_lowercase();
113            score += if sev_lower.contains("critical") {
114                30
115            } else if sev_lower.contains("high") {
116                20
117            } else if sev_lower.contains("medium") {
118                10
119            } else {
120                5
121            };
122        }
123
124        // Blast radius contribution (0-20 points)
125        score += match self.transitive_dependent_count {
126            0 => 0,
127            1..=5 => 5,
128            6..=20 => 10,
129            21..=50 => 15,
130            _ => 20,
131        };
132
133        // License risk contribution (0-10 points)
134        score += match self.license_risk {
135            LicenseRisk::None => 0,
136            LicenseRisk::Low => 2,
137            LicenseRisk::Medium => 5,
138            LicenseRisk::High => 10,
139        };
140
141        self.risk_score = score.min(100) as u8;
142
143        // Determine risk level
144        self.risk_level = match self.risk_score {
145            0..=25 => RiskLevel::Low,
146            26..=50 => RiskLevel::Medium,
147            51..=75 => RiskLevel::High,
148            _ => RiskLevel::Critical,
149        };
150    }
151}
152
/// License risk level.
///
/// Note that `None` (no recognized license family) is displayed as "Unknown"
/// by `as_str`, whereas a license string that literally says "unknown" is
/// classified as `High`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum LicenseRisk {
    /// No recognized license family.
    #[default]
    None,
    /// Permissive (MIT, Apache, BSD, ISC, CC0, Unlicense).
    Low,
    /// Weak copyleft (LGPL, MPL, CDDL).
    Medium,
    /// Strong copyleft (GPL, AGPL) or an explicitly unknown license.
    High,
}

impl LicenseRisk {
    /// Classify a license string (SPDX identifier, SPDX expression, or free text).
    ///
    /// The input is lowercased and split into alphanumeric tokens, and each
    /// family is recognized by whole tokens or token prefixes rather than raw
    /// substrings. This avoids false matches inside ordinary words ("Limited"
    /// contains "mit", "disclaimer" contains "isc", "simple" contains "mpl",
    /// "unlicensed" contains "unlicense").
    ///
    /// Families are tested in the order permissive, weak copyleft, strong
    /// copyleft, so an expression such as "MIT OR GPL-3.0" is classified by its
    /// most permissive option.
    pub fn from_license(license: &str) -> Self {
        /// True if some token equals one of `exact` or starts with one of `prefixes`.
        fn has_token(tokens: &[&str], exact: &[&str], prefixes: &[&str]) -> bool {
            tokens
                .iter()
                .any(|t| exact.contains(t) || prefixes.iter().any(|p| t.starts_with(*p)))
        }

        let lower = license.to_lowercase();
        let tokens: Vec<&str> = lower
            .split(|c: char| !c.is_ascii_alphanumeric())
            .filter(|t| !t.is_empty())
            .collect();

        if has_token(
            &tokens,
            &["mit", "isc", "unlicense"],
            &["apache", "bsd", "0bsd", "cc0"],
        ) {
            LicenseRisk::Low
        } else if has_token(&tokens, &[], &["lgpl", "mpl", "cddl"]) {
            LicenseRisk::Medium
        } else if has_token(&tokens, &["unknown"], &["gpl", "agpl"]) {
            LicenseRisk::High
        } else {
            LicenseRisk::None
        }
    }

    /// Display label for this risk level.
    pub fn as_str(&self) -> &'static str {
        match self {
            LicenseRisk::None => "Unknown",
            LicenseRisk::Low => "Permissive",
            LicenseRisk::Medium => "Weak Copyleft",
            LicenseRisk::High => "Copyleft/Unknown",
        }
    }
}
193
/// A component the analyst has marked for follow-up.
#[derive(Debug, Clone)]
pub struct FlaggedItem {
    /// Identifier (ID or name) of the flagged component.
    pub component_id: String,
    /// Why the component was flagged.
    pub reason: String,
    /// Free-form analyst note, if one has been attached.
    pub note: Option<String>,
    /// Moment the flag was created.
    pub flagged_at: std::time::Instant,
}
206
207/// Security analysis cache for the TUI
208#[derive(Debug, Default)]
209pub struct SecurityAnalysisCache {
210    /// Cached blast radius for components
211    pub blast_radius_cache: HashMap<String, BlastRadius>,
212    /// Cached risk indicators
213    pub risk_indicators_cache: HashMap<String, RiskIndicators>,
214    /// Flagged items for follow-up
215    pub flagged_items: Vec<FlaggedItem>,
216    /// Components flagged (for quick lookup)
217    pub flagged_set: HashSet<String>,
218}
219
220impl SecurityAnalysisCache {
221    pub fn new() -> Self {
222        Self::default()
223    }
224
225    /// Compute blast radius for a component using reverse dependency graph
226    pub fn compute_blast_radius(
227        &mut self,
228        component_id: &str,
229        reverse_graph: &HashMap<String, Vec<String>>,
230    ) -> &BlastRadius {
231        if self.blast_radius_cache.contains_key(component_id) {
232            return &self.blast_radius_cache[component_id];
233        }
234
235        let mut blast = BlastRadius::default();
236
237        // Direct dependents
238        if let Some(direct) = reverse_graph.get(component_id) {
239            blast.direct_dependents = direct.clone();
240        }
241
242        // BFS to find all transitive dependents
243        let mut visited: HashSet<String> = HashSet::new();
244        let mut queue: VecDeque<(String, usize)> = VecDeque::new();
245        let mut max_depth = 0usize;
246
247        // Start with direct dependents
248        for dep in &blast.direct_dependents {
249            queue.push_back((dep.clone(), 1));
250        }
251
252        while let Some((node, depth)) = queue.pop_front() {
253            if visited.contains(&node) {
254                continue;
255            }
256            visited.insert(node.clone());
257            blast.transitive_dependents.insert(node.clone());
258            max_depth = max_depth.max(depth);
259
260            // Add this node's dependents
261            if let Some(dependents) = reverse_graph.get(&node) {
262                for dep in dependents {
263                    if !visited.contains(dep) {
264                        queue.push_back((dep.clone(), depth + 1));
265                    }
266                }
267            }
268        }
269
270        blast.max_depth = max_depth;
271
272        // Determine risk level based on blast radius
273        blast.risk_level = match blast.transitive_dependents.len() {
274            0 => RiskLevel::Low,
275            1..=5 => RiskLevel::Low,
276            6..=20 => RiskLevel::Medium,
277            21..=50 => RiskLevel::High,
278            _ => RiskLevel::Critical,
279        };
280
281        self.blast_radius_cache
282            .insert(component_id.to_string(), blast);
283        &self.blast_radius_cache[component_id]
284    }
285
286    /// Flag a component for follow-up
287    pub fn flag_component(&mut self, component_id: &str, reason: &str) {
288        if !self.flagged_set.contains(component_id) {
289            self.flagged_items.push(FlaggedItem {
290                component_id: component_id.to_string(),
291                reason: reason.to_string(),
292                note: None,
293                flagged_at: std::time::Instant::now(),
294            });
295            self.flagged_set.insert(component_id.to_string());
296        }
297    }
298
299    /// Unflag a component
300    pub fn unflag_component(&mut self, component_id: &str) {
301        self.flagged_items
302            .retain(|item| item.component_id != component_id);
303        self.flagged_set.remove(component_id);
304    }
305
306    /// Toggle flag status
307    pub fn toggle_flag(&mut self, component_id: &str, reason: &str) {
308        if self.flagged_set.contains(component_id) {
309            self.unflag_component(component_id);
310        } else {
311            self.flag_component(component_id, reason);
312        }
313    }
314
315    /// Check if a component is flagged
316    pub fn is_flagged(&self, component_id: &str) -> bool {
317        self.flagged_set.contains(component_id)
318    }
319
320    /// Add note to a flagged component
321    pub fn add_note(&mut self, component_id: &str, note: &str) {
322        for item in &mut self.flagged_items {
323            if item.component_id == component_id {
324                item.note = Some(note.to_string());
325                break;
326            }
327        }
328    }
329
330    /// Clear all caches
331    pub fn clear(&mut self) {
332        self.blast_radius_cache.clear();
333        self.risk_indicators_cache.clear();
334    }
335
336    /// Invalidate cache for a specific component
337    pub fn invalidate(&mut self, component_id: &str) {
338        self.blast_radius_cache.remove(component_id);
339        self.risk_indicators_cache.remove(component_id);
340    }
341
342    /// Get note for a flagged component
343    pub fn get_note(&self, component_id: &str) -> Option<&str> {
344        self.flagged_items
345            .iter()
346            .find(|item| item.component_id == component_id)
347            .and_then(|item| item.note.as_deref())
348    }
349}
350
// ============================================================================
// Vulnerability Prioritization
// ============================================================================

/// Vulnerability priority information for sorting.
///
/// The derived `Default` is the lowest priority: zero scores, no blast radius,
/// and not known-exploited.
#[derive(Debug, Clone, Default)]
pub struct VulnPriority {
    /// Parsed CVSS score (0.0-10.0).
    pub cvss_score: f32,
    /// Severity level as numeric value (for sorting).
    pub severity_rank: u8,
    /// Fix urgency score (0-100).
    pub fix_urgency: u8,
    /// Blast radius of the affected component.
    pub blast_radius: usize,
    /// Whether this is a known exploited vulnerability.
    pub is_known_exploited: bool,
}
381
/// Parse a CVSS score from a string (e.g., "9.8", "CVSS:3.1/AV:N/AC:L/...").
///
/// Accepted inputs:
/// - a bare number such as `"9.8"`; surrounding whitespace is ignored;
/// - a CVSS vector whose last `/`-separated segment is a number, e.g.
///   `"CVSS:3.1/AV:N/AC:L/PR:N/UI:N/S:U/C:H/I:H/A:H/9.8"`. A plain vector
///   without an appended score carries no number and yields `0.0`.
///
/// Results are clamped to `0.0..=10.0`. Anything unparseable, including `"NaN"`,
/// yields `0.0`, so the result is never NaN and always compares cleanly when sorting.
pub fn parse_cvss_score(score_str: &str) -> f32 {
    /// Parse a single number, rejecting NaN and clamping to the CVSS range.
    fn parse_number(s: &str) -> Option<f32> {
        s.trim()
            .parse::<f32>()
            .ok()
            .filter(|v| !v.is_nan())
            .map(|v| v.clamp(0.0, 10.0))
    }

    let trimmed = score_str.trim();
    if let Some(score) = parse_number(trimmed) {
        return score;
    }

    if trimmed.contains("CVSS:") {
        if let Some(score) = trimmed.rsplit('/').next().and_then(parse_number) {
            return score;
        }
    }

    0.0
}
403
/// Convert a severity label to a numeric rank for sorting.
///
/// Case-insensitive substring match, checked from most to least severe:
/// critical = 4, high = 3, medium/moderate = 2, low = 1, anything else = 0.
pub fn severity_to_rank(severity: &str) -> u8 {
    const RANKS: [(&[&str], u8); 4] = [
        (&["critical"], 4),
        (&["high"], 3),
        (&["medium", "moderate"], 2),
        (&["low"], 1),
    ];

    let lowered = severity.to_lowercase();
    RANKS
        .iter()
        .find(|(keywords, _)| keywords.iter().any(|kw| lowered.contains(*kw)))
        .map_or(0, |&(_, rank)| rank)
}
419
/// Calculate a fix urgency score (0-100) from severity, blast radius, and CVSS score.
///
/// Contributions, summed and then capped at 100:
/// - severity: `severity_rank * 10` (0-40 for ranks from `severity_to_rank`);
/// - CVSS: `cvss_score * 3`, truncated, with `cvss_score` clamped to 0.0-10.0
///   first so the contribution stays within 0-30 (NaN contributes 0);
/// - blast radius: 0 for none, 10 for 1-5, 20 for 6-20, 30 above that.
pub fn calculate_fix_urgency(severity_rank: u8, blast_radius: usize, cvss_score: f32) -> u8 {
    // Base score from severity (0-40 for normal ranks; at most 2550 for any u8).
    let severity_score = u32::from(severity_rank) * 10;

    // CVSS contribution (0-30). Clamping keeps out-of-range or infinite inputs from
    // saturating the cast to u32::MAX and overflowing the sum below.
    // `NaN.clamp(..)` stays NaN, and `NaN as u32` is 0.
    let cvss_contribution = (cvss_score.clamp(0.0, 10.0) * 3.0) as u32;

    // Blast radius contribution (0-30).
    let blast_score = match blast_radius {
        0 => 0,
        1..=5 => 10,
        6..=20 => 20,
        _ => 30,
    };

    // The sum is at most 2550 + 30 + 30, and it is capped at 100 before narrowing.
    (severity_score + cvss_contribution + blast_score).min(100) as u8
}
438
/// Check if a vulnerability ID might be a known exploited vulnerability (KEV).
///
/// This is a heuristic - real KEV checking would need an API call. Returns `true` when:
/// - `vuln_id` mentions one of a few well-known exploited CVEs. The match is
///   case-insensitive and must not be followed by another digit, so
///   `CVE-2024-3094` does not match `CVE-2024-30941`; or
/// - `severity` contains "critical" and `vuln_id` is a CVE from 2020 or later.
pub fn is_likely_known_exploited(vuln_id: &str, severity: &str) -> bool {
    // Known high-profile vulnerabilities (simplified check).
    const KNOWN_EXPLOITED: [&str; 5] = [
        "CVE-2021-44228", // Log4Shell
        "CVE-2021-45046", // Log4j
        "CVE-2022-22965", // Spring4Shell
        "CVE-2023-44487", // HTTP/2 Rapid Reset
        "CVE-2024-3094",  // XZ Utils
    ];

    let id = vuln_id.to_uppercase();

    // A known ID counts only as a whole CVE number: the next character must not be a digit.
    let mentions = |cve: &str| {
        id.match_indices(cve)
            .any(|(start, _)| !id[start + cve.len()..].starts_with(|c: char| c.is_ascii_digit()))
    };
    if KNOWN_EXPLOITED.iter().any(|&cve| mentions(cve)) {
        return true;
    }

    // Critical CVEs from 2020 onward are treated as more likely to be exploited.
    let is_critical = severity.to_lowercase().contains("critical");
    // The year is the four digits right after "CVE-" (e.g. CVE-2023-1234 -> 2023).
    let cve_year = id
        .strip_prefix("CVE-")
        .and_then(|rest| rest.get(..4))
        .filter(|year| year.bytes().all(|b| b.is_ascii_digit()))
        .and_then(|year| year.parse::<u32>().ok());
    let is_recent_cve = matches!(cve_year, Some(year) if year >= 2020);

    is_critical && is_recent_cve
}
458
// ============================================================================
// Version Downgrade Detection
// ============================================================================

/// Outcome of comparing an old version string with a new one.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum VersionChange {
    /// The new version is higher (a normal upgrade).
    Upgrade,
    /// The new version is lower (a potential supply-chain attack).
    Downgrade,
    /// Both versions are the same.
    NoChange,
    /// The relationship could not be determined.
    Unknown,
}
475
476/// Detect if a version change is a downgrade (potential supply chain attack)
477pub fn detect_version_downgrade(old_version: &str, new_version: &str) -> VersionChange {
478    if old_version == new_version {
479        return VersionChange::NoChange;
480    }
481
482    // Try semver parsing first
483    if let (Some(old_parts), Some(new_parts)) = (
484        parse_version_parts(old_version),
485        parse_version_parts(new_version),
486    ) {
487        // Compare major.minor.patch
488        for (old, new) in old_parts.iter().zip(new_parts.iter()) {
489            if new > old {
490                return VersionChange::Upgrade;
491            } else if new < old {
492                return VersionChange::Downgrade;
493            }
494        }
495        // If we get here, versions are equal up to the compared parts
496        if new_parts.len() < old_parts.len() {
497            return VersionChange::Downgrade; // e.g., 1.2.3 -> 1.2
498        } else if new_parts.len() > old_parts.len() {
499            return VersionChange::Upgrade; // e.g., 1.2 -> 1.2.3
500        }
501        return VersionChange::NoChange;
502    }
503
504    // Fallback: lexicographic comparison (less reliable)
505    if new_version < old_version {
506        VersionChange::Downgrade
507    } else if new_version > old_version {
508        VersionChange::Upgrade
509    } else {
510        VersionChange::Unknown
511    }
512}
513
/// Parse a version string into its leading numeric dotted components.
///
/// Any leading non-digit prefix ("v", "V", "version-", ...) is skipped, and
/// parsing stops at the first character that is neither a digit nor a dot,
/// so `"v1.2.3-beta+4"` yields `[1, 2, 3]`. Empty components (`"1..2"`) are
/// ignored. A component too large for `u32` saturates to `u32::MAX` instead of
/// being dropped, so later components keep their positions.
///
/// Returns `None` if no numeric component is found.
fn parse_version_parts(version: &str) -> Option<Vec<u32>> {
    let numeric = version
        .trim_start_matches(|c: char| !c.is_ascii_digit())
        .split(|c: char| !c.is_ascii_digit() && c != '.')
        .next()
        .unwrap_or_default();

    let parts: Vec<u32> = numeric
        .split('.')
        .filter(|part| !part.is_empty())
        // Each part is all ASCII digits, so parsing can only fail on overflow.
        .map(|part| part.parse().unwrap_or(u32::MAX))
        .collect();

    if parts.is_empty() {
        None
    } else {
        Some(parts)
    }
}
534
/// A version downgrade detected for a single component.
#[derive(Debug, Clone)]
pub struct DowngradeWarning {
    /// Name of the component whose version went down.
    pub component_name: String,
    /// Version before the change.
    pub old_version: String,
    /// Version after the change.
    pub new_version: String,
    /// How concerning the downgrade is.
    pub severity: DowngradeSeverity,
}

/// How concerning a version downgrade is.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DowngradeSeverity {
    /// Non-major downgrade (e.g., 1.2.3 -> 1.2.2).
    Minor,
    /// The major version decreased (e.g., 2.0.0 -> 1.9.0).
    Major,
    /// Suspicious pattern (e.g., a security-patch marker was dropped).
    Suspicious,
}

impl DowngradeSeverity {
    /// Display label for this severity.
    pub fn as_str(&self) -> &'static str {
        match *self {
            DowngradeSeverity::Minor => "Minor Downgrade",
            DowngradeSeverity::Major => "Major Downgrade",
            DowngradeSeverity::Suspicious => "Suspicious",
        }
    }
}
563
564/// Analyze a version change for downgrade severity
565pub fn analyze_downgrade(old_version: &str, new_version: &str) -> Option<DowngradeSeverity> {
566    if detect_version_downgrade(old_version, new_version) != VersionChange::Downgrade {
567        return None;
568    }
569
570    let old_parts = parse_version_parts(old_version)?;
571    let new_parts = parse_version_parts(new_version)?;
572
573    // Check if major version decreased
574    if let (Some(&old_major), Some(&new_major)) = (old_parts.first(), new_parts.first()) {
575        if new_major < old_major {
576            return Some(DowngradeSeverity::Major);
577        }
578    }
579
580    // Check for suspicious patterns (security-related version strings)
581    let old_lower = old_version.to_lowercase();
582    let new_lower = new_version.to_lowercase();
583    if (old_lower.contains("security") || old_lower.contains("patch") || old_lower.contains("fix"))
584        && !new_lower.contains("security")
585        && !new_lower.contains("patch")
586        && !new_lower.contains("fix")
587    {
588        return Some(DowngradeSeverity::Suspicious);
589    }
590
591    Some(DowngradeSeverity::Minor)
592}
593
/// Strip a vulnerability ID down to safe characters.
///
/// Keeps ASCII alphanumerics plus `-`, `_`, `.` and `:`, which covers CVE,
/// GHSA, RUSTSEC, PYSEC, and other standard advisory IDs.
fn sanitize_vuln_id(id: &str) -> String {
    let mut cleaned = String::with_capacity(id.len());
    for ch in id.chars() {
        let allowed = ch.is_ascii_alphanumeric() || ch == '-' || ch == '_' || ch == '.' || ch == ':';
        if allowed {
            cleaned.push(ch);
        }
    }
    cleaned
}

/// Build a browser URL for a vulnerability ID.
///
/// The ID is sanitized first. CVE and GHSA prefixes are matched
/// case-insensitively and the ID is uppercased; RUSTSEC is matched
/// case-sensitively. PYSEC and all other IDs go to the generic OSV lookup.
pub fn cve_url(cve_id: &str) -> String {
    let safe_id = sanitize_vuln_id(cve_id);
    let upper = safe_id.to_uppercase();

    if upper.starts_with("CVE-") {
        return format!("https://nvd.nist.gov/vuln/detail/{}", upper);
    }
    if upper.starts_with("GHSA-") {
        return format!("https://github.com/advisories/{}", upper);
    }
    if safe_id.starts_with("RUSTSEC-") {
        return format!("https://rustsec.org/advisories/{}", safe_id);
    }
    format!("https://osv.dev/vulnerability/{}", safe_id)
}
619
/// Validate that a URL contains only characters from RFC 3986
/// (unreserved + reserved + percent-encoded).
///
/// Rejects control characters, spaces, backticks, pipes, and other non-URL
/// characters that could be misinterpreted by platform open commands.
fn is_safe_url(url: &str) -> bool {
    url.chars().all(|c| {
        c.is_ascii_alphanumeric()
            || matches!(
                c,
                ':' | '/' | '.' | '-' | '_' | '~' | '?' | '#' | '[' | ']' | '@' | '!' | '$'
                    | '&' | '\'' | '(' | ')' | '*' | '+' | ',' | ';' | '=' | '%'
            )
    })
}

/// Open an `http://` or `https://` URL in the default browser.
///
/// The URL is handed to `open` (macOS), `xdg-open` (Linux), or `explorer`
/// (Windows) as a single process argument; no shell is involved.
///
/// # Errors
/// Returns an error, without launching anything, if the URL contains characters
/// outside the RFC 3986 set or does not start with `http://` / `https://`.
/// The scheme check is needed because the platform openers accept more than web
/// URLs: a `file:` URL could launch a local program, and a value starting with
/// `-` could be parsed by `open`/`xdg-open` as a command-line option. Also
/// returns an error if the opener process cannot be spawned.
pub fn open_in_browser(url: &str) -> Result<(), String> {
    if !is_safe_url(url) {
        return Err("URL contains unsafe characters".to_string());
    }

    let lower = url.to_ascii_lowercase();
    if !(lower.starts_with("https://") || lower.starts_with("http://")) {
        return Err("URL must use http:// or https://".to_string());
    }

    #[cfg(target_os = "macos")]
    {
        std::process::Command::new("open")
            .arg(url)
            .spawn()
            .map_err(|e| format!("Failed to open browser: {}", e))?;
    }

    #[cfg(target_os = "linux")]
    {
        std::process::Command::new("xdg-open")
            .arg(url)
            .spawn()
            .map_err(|e| format!("Failed to open browser: {}", e))?;
    }

    #[cfg(target_os = "windows")]
    {
        // Use explorer.exe instead of cmd /C start to avoid shell
        // metacharacter interpretation (e.g. & | > would be dangerous
        // with cmd.exe). explorer.exe receives the URL as a direct
        // process argument with no shell involved.
        std::process::Command::new("explorer")
            .arg(url)
            .spawn()
            .map_err(|e| format!("Failed to open browser: {}", e))?;
    }

    Ok(())
}
671
/// Copy text to the system clipboard.
///
/// The text is written to a platform clipboard tool's stdin rather than passed
/// on a command line, so no shell ever interprets it:
/// - macOS: `pbcopy`
/// - Linux: `xclip -selection clipboard`, falling back to
///   `xsel --clipboard --input` if xclip cannot be started or exits with an error
/// - Windows: `clip`
///
/// On other targets this does nothing and returns `Ok(())`.
///
/// # Errors
/// Returns a description of the failure if the tool could not be spawned, the
/// text could not be written to it, or the tool exited unsuccessfully.
pub fn copy_to_clipboard(text: &str) -> Result<(), String> {
    #[cfg(target_os = "macos")]
    {
        pipe_to_clipboard_tool("pbcopy", &[], text)?;
    }

    #[cfg(target_os = "linux")]
    {
        // xclip can start successfully and then fail (e.g. no X display), so
        // fall back to xsel on any xclip error, not only a missing binary.
        if pipe_to_clipboard_tool("xclip", &["-selection", "clipboard"], text).is_err() {
            pipe_to_clipboard_tool("xsel", &["--clipboard", "--input"], text)?;
        }
    }

    #[cfg(target_os = "windows")]
    {
        // Use clip.exe with stdin to avoid command injection via string interpolation
        pipe_to_clipboard_tool("clip", &[], text)?;
    }

    Ok(())
}

/// Run `program` with `args`, write `text` to its stdin, and require a successful exit.
#[cfg(any(target_os = "macos", target_os = "linux", target_os = "windows"))]
fn pipe_to_clipboard_tool(program: &str, args: &[&str], text: &str) -> Result<(), String> {
    use std::io::Write;
    use std::process::{Command, Stdio};

    let mut child = Command::new(program)
        .args(args)
        .stdin(Stdio::piped())
        .spawn()
        .map_err(|e| format!("Failed to copy to clipboard: {}", e))?;

    if let Some(mut stdin) = child.stdin.take() {
        if let Err(e) = stdin.write_all(text.as_bytes()) {
            // Reap the child so a failed write does not leave a zombie process.
            let _ = child.kill();
            let _ = child.wait();
            return Err(format!("Failed to write to clipboard: {}", e));
        }
        // `stdin` is dropped at the end of this block, closing the pipe so the tool sees EOF.
    }

    let status = child
        .wait()
        .map_err(|e| format!("Clipboard command failed: {}", e))?;
    if status.success() {
        Ok(())
    } else {
        Err(format!(
            "Clipboard command failed: {} exited with {}",
            program, status
        ))
    }
}
741
// ============================================================================
// Attack Path Visualization
// ============================================================================

/// An attack path from an entry point to a vulnerable component.
#[derive(Debug, Clone)]
pub struct AttackPath {
    /// Component names from the entry point to the target, inclusive.
    pub path: Vec<String>,
    /// The entry point (root component).
    pub entry_point: String,
    /// The vulnerable component (target).
    pub target: String,
    /// Number of hops from entry point to target (0 when the entry point is the target).
    pub depth: usize,
    /// Risk score based on path characteristics.
    pub risk_score: u8,
}

impl AttackPath {
    /// Format the path as a readable string, e.g. `app → lib → vuln`.
    pub fn format(&self) -> String {
        self.path.join(" → ")
    }

    /// Short description of the path length.
    ///
    /// Depth 0 means the entry point is itself the vulnerable component
    /// (`find_attack_paths` produces this when a root matches the target);
    /// it is described explicitly rather than as "0 hops".
    pub fn description(&self) -> String {
        match self.depth {
            0 => "Vulnerable entry point".to_string(),
            1 => "Direct dependency".to_string(),
            hops => format!("{} hops", hops),
        }
    }
}
776
777/// Find attack paths from root components to a vulnerable component
778pub fn find_attack_paths(
779    target: &str,
780    forward_graph: &HashMap<String, Vec<String>>,
781    root_components: &[String],
782    max_paths: usize,
783    max_depth: usize,
784) -> Vec<AttackPath> {
785    let mut paths = Vec::new();
786
787    // BFS from each root to find paths to target
788    for root in root_components {
789        if root == target {
790            // Direct hit - root is the vulnerable component
791            paths.push(AttackPath {
792                path: vec![root.clone()],
793                entry_point: root.clone(),
794                target: target.to_string(),
795                depth: 0,
796                risk_score: 100, // Highest risk - direct exposure
797            });
798            continue;
799        }
800
801        // BFS to find path from this root to target
802        let mut visited: HashSet<String> = HashSet::new();
803        let mut queue: VecDeque<(String, Vec<String>)> = VecDeque::new();
804        queue.push_back((root.clone(), vec![root.clone()]));
805        visited.insert(root.clone());
806
807        while let Some((current, path)) = queue.pop_front() {
808            if path.len() > max_depth {
809                continue;
810            }
811
812            // Check all dependencies of current node
813            if let Some(deps) = forward_graph.get(&current) {
814                for dep in deps {
815                    if dep == target {
816                        // Found a path!
817                        let mut full_path = path.clone();
818                        full_path.push(dep.clone());
819                        let depth = full_path.len() - 1;
820
821                        // Risk score decreases with depth
822                        let risk_score = match depth {
823                            1 => 90,
824                            2 => 70,
825                            3 => 50,
826                            4 => 30,
827                            _ => 10,
828                        };
829
830                        paths.push(AttackPath {
831                            path: full_path,
832                            entry_point: root.clone(),
833                            target: target.to_string(),
834                            depth,
835                            risk_score,
836                        });
837
838                        if paths.len() >= max_paths {
839                            // Sort by risk score before returning
840                            paths.sort_by(|a, b| b.risk_score.cmp(&a.risk_score));
841                            return paths;
842                        }
843                    } else if !visited.contains(dep) {
844                        visited.insert(dep.clone());
845                        let mut new_path = path.clone();
846                        new_path.push(dep.clone());
847                        queue.push_back((dep.clone(), new_path));
848                    }
849                }
850            }
851        }
852    }
853
854    // Sort by risk score (highest first), then by depth (shortest first)
855    paths.sort_by(|a, b| {
856        b.risk_score
857            .cmp(&a.risk_score)
858            .then_with(|| a.depth.cmp(&b.depth))
859    });
860    paths
861}
862
/// Identify root components: those with no dependents.
///
/// A component counts as a root when `reverse_graph` has no entry for it or
/// its entry is an empty list. Input order is preserved.
pub fn find_root_components(
    all_components: &[String],
    reverse_graph: &HashMap<String, Vec<String>>,
) -> Vec<String> {
    let mut roots = Vec::new();
    for component in all_components {
        let has_dependents = match reverse_graph.get(component) {
            Some(dependents) => !dependents.is_empty(),
            None => false,
        };
        if !has_dependents {
            roots.push(component.clone());
        }
    }
    roots
}
879
// ============================================================================
// Compliance / Policy Checking
// ============================================================================

/// A policy rule for compliance checking.
#[derive(Debug, Clone)]
pub enum PolicyRule {
    /// Ban specific licenses (e.g., GPL in proprietary projects).
    BannedLicense {
        pattern: String,
        reason: String,
    },
    /// Require licenses from an allowlist.
    RequiredLicense {
        allowed: Vec<String>,
        reason: String,
    },
    /// Ban specific components by name pattern.
    BannedComponent {
        pattern: String,
        reason: String,
    },
    /// Minimum version requirement for a component.
    MinimumVersion {
        component_pattern: String,
        min_version: String,
        reason: String,
    },
    /// No pre-release versions (0.x.x).
    NoPreRelease {
        reason: String,
    },
    /// Maximum vulnerability severity allowed.
    MaxVulnerabilitySeverity {
        max_severity: String,
        reason: String,
    },
    /// Require SBOM completeness (minimum fields).
    RequireFields {
        fields: Vec<String>,
        reason: String,
    },
}

impl PolicyRule {
    /// Display name of the rule kind.
    pub fn name(&self) -> &'static str {
        match self {
            Self::BannedLicense { .. } => "Banned License",
            Self::RequiredLicense { .. } => "License Allowlist",
            Self::BannedComponent { .. } => "Banned Component",
            Self::MinimumVersion { .. } => "Minimum Version",
            Self::NoPreRelease { .. } => "No Pre-Release",
            Self::MaxVulnerabilitySeverity { .. } => "Max Vulnerability Severity",
            Self::RequireFields { .. } => "Required Fields",
        }
    }

    /// Severity assigned to violations of this rule kind.
    pub fn severity(&self) -> PolicySeverity {
        match self {
            Self::BannedComponent { .. } => PolicySeverity::Critical,
            Self::BannedLicense { .. } | Self::MaxVulnerabilitySeverity { .. } => {
                PolicySeverity::High
            }
            Self::RequiredLicense { .. } | Self::MinimumVersion { .. } => PolicySeverity::Medium,
            Self::NoPreRelease { .. } | Self::RequireFields { .. } => PolicySeverity::Low,
        }
    }
}

/// Severity of a policy violation.
///
/// The derived ordering follows declaration order: Low < Medium < High < Critical.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum PolicySeverity {
    Low,
    Medium,
    High,
    Critical,
}
958
959impl PolicySeverity {
960    pub fn as_str(&self) -> &'static str {
961        match self {
962            PolicySeverity::Low => "Low",
963            PolicySeverity::Medium => "Medium",
964            PolicySeverity::High => "High",
965            PolicySeverity::Critical => "Critical",
966        }
967    }
968
969    pub fn symbol(&self) -> &'static str {
970        match self {
971            PolicySeverity::Low => "○",
972            PolicySeverity::Medium => "◐",
973            PolicySeverity::High => "●",
974            PolicySeverity::Critical => "◉",
975        }
976    }
977}
978
979/// A policy violation
980#[derive(Debug, Clone)]
981pub struct PolicyViolation {
982    /// The rule that was violated
983    pub rule_name: String,
984    /// Severity of the violation
985    pub severity: PolicySeverity,
986    /// Component that violated the rule (if applicable)
987    pub component: Option<String>,
988    /// Description of what violated the rule
989    pub description: String,
990    /// Suggested remediation
991    pub remediation: String,
992}
993
994/// Security policy configuration
995#[derive(Debug, Clone, Default)]
996pub struct SecurityPolicy {
997    /// Name of this policy
998    pub name: String,
999    /// Policy rules
1000    pub rules: Vec<PolicyRule>,
1001}
1002
1003impl SecurityPolicy {
1004    /// Create a default enterprise security policy
1005    pub fn enterprise_default() -> Self {
1006        Self {
1007            name: "Enterprise Security Policy".to_string(),
1008            rules: vec![
1009                PolicyRule::BannedLicense {
1010                    pattern: "GPL".to_string(),
1011                    reason: "GPL licenses incompatible with proprietary software".to_string(),
1012                },
1013                PolicyRule::BannedLicense {
1014                    pattern: "AGPL".to_string(),
1015                    reason: "AGPL requires source disclosure for network services".to_string(),
1016                },
1017                PolicyRule::MaxVulnerabilitySeverity {
1018                    max_severity: "High".to_string(),
1019                    reason: "Critical vulnerabilities must be remediated before deployment"
1020                        .to_string(),
1021                },
1022                PolicyRule::NoPreRelease {
1023                    reason: "Pre-release versions (0.x) may have unstable APIs".to_string(),
1024                },
1025            ],
1026        }
1027    }
1028
1029    /// Create a strict security policy
1030    pub fn strict() -> Self {
1031        Self {
1032            name: "Strict Security Policy".to_string(),
1033            rules: vec![
1034                PolicyRule::BannedLicense {
1035                    pattern: "GPL".to_string(),
1036                    reason: "GPL licenses not allowed".to_string(),
1037                },
1038                PolicyRule::BannedLicense {
1039                    pattern: "AGPL".to_string(),
1040                    reason: "AGPL licenses not allowed".to_string(),
1041                },
1042                PolicyRule::BannedLicense {
1043                    pattern: "LGPL".to_string(),
1044                    reason: "LGPL licenses not allowed".to_string(),
1045                },
1046                PolicyRule::MaxVulnerabilitySeverity {
1047                    max_severity: "Medium".to_string(),
1048                    reason: "High/Critical vulnerabilities not allowed".to_string(),
1049                },
1050                PolicyRule::NoPreRelease {
1051                    reason: "Pre-release versions not allowed in production".to_string(),
1052                },
1053                PolicyRule::BannedComponent {
1054                    pattern: "lodash".to_string(),
1055                    reason: "Use native JS methods or lighter alternatives".to_string(),
1056                },
1057            ],
1058        }
1059    }
1060
1061    /// Create a permissive policy (minimal checks)
1062    pub fn permissive() -> Self {
1063        Self {
1064            name: "Permissive Policy".to_string(),
1065            rules: vec![PolicyRule::MaxVulnerabilitySeverity {
1066                max_severity: "Critical".to_string(),
1067                reason: "Critical vulnerabilities should be reviewed".to_string(),
1068            }],
1069        }
1070    }
1071}
1072
1073/// Result of a compliance check
1074#[derive(Debug, Clone, Default)]
1075pub struct ComplianceResult {
1076    /// Policy name that was checked
1077    pub policy_name: String,
1078    /// Total components checked
1079    pub components_checked: usize,
1080    /// Violations found
1081    pub violations: Vec<PolicyViolation>,
1082    /// Compliance score (0-100)
1083    pub score: u8,
1084    /// Whether the SBOM passes the policy
1085    pub passes: bool,
1086}
1087
1088impl ComplianceResult {
1089    /// Count violations by severity
1090    pub fn count_by_severity(&self, severity: PolicySeverity) -> usize {
1091        self.violations
1092            .iter()
1093            .filter(|v| v.severity == severity)
1094            .count()
1095    }
1096
1097    /// Get summary of violations
1098    pub fn summary(&self) -> String {
1099        if self.violations.is_empty() {
1100            "All checks passed".to_string()
1101        } else {
1102            let critical = self.count_by_severity(PolicySeverity::Critical);
1103            let high = self.count_by_severity(PolicySeverity::High);
1104            let medium = self.count_by_severity(PolicySeverity::Medium);
1105            let low = self.count_by_severity(PolicySeverity::Low);
1106            format!(
1107                "{} critical, {} high, {} medium, {} low",
1108                critical, high, medium, low
1109            )
1110        }
1111    }
1112}
1113
1114/// Check compliance of components against a policy
1115pub fn check_compliance(
1116    policy: &SecurityPolicy,
1117    components: &[ComplianceComponentData],
1118) -> ComplianceResult {
1119    let mut result = ComplianceResult {
1120        policy_name: policy.name.clone(),
1121        components_checked: components.len(),
1122        violations: Vec::new(),
1123        score: 100,
1124        passes: true,
1125    };
1126
1127    for (name, version, licenses, vulns) in components {
1128        for rule in &policy.rules {
1129            match rule {
1130                PolicyRule::BannedLicense { pattern, reason } => {
1131                    for license in licenses {
1132                        if license.to_uppercase().contains(&pattern.to_uppercase()) {
1133                            result.violations.push(PolicyViolation {
1134                                rule_name: rule.name().to_string(),
1135                                severity: rule.severity(),
1136                                component: Some(name.clone()),
1137                                description: format!(
1138                                    "License '{}' matches banned pattern '{}'",
1139                                    license, pattern
1140                                ),
1141                                remediation: format!(
1142                                    "Replace with component using permissive license. {}",
1143                                    reason
1144                                ),
1145                            });
1146                        }
1147                    }
1148                }
1149                PolicyRule::RequiredLicense { allowed, reason } => {
1150                    let has_allowed = licenses
1151                        .iter()
1152                        .any(|l| allowed.iter().any(|a| l.to_uppercase().contains(&a.to_uppercase())));
1153                    if !licenses.is_empty() && !has_allowed {
1154                        result.violations.push(PolicyViolation {
1155                            rule_name: rule.name().to_string(),
1156                            severity: rule.severity(),
1157                            component: Some(name.clone()),
1158                            description: format!(
1159                                "License '{}' not in allowed list",
1160                                licenses.join(", ")
1161                            ),
1162                            remediation: format!(
1163                                "Use component with allowed license: {}. {}",
1164                                allowed.join(", "),
1165                                reason
1166                            ),
1167                        });
1168                    }
1169                }
1170                PolicyRule::BannedComponent { pattern, reason } => {
1171                    if name.to_lowercase().contains(&pattern.to_lowercase()) {
1172                        result.violations.push(PolicyViolation {
1173                            rule_name: rule.name().to_string(),
1174                            severity: rule.severity(),
1175                            component: Some(name.clone()),
1176                            description: format!(
1177                                "Component '{}' matches banned pattern '{}'",
1178                                name, pattern
1179                            ),
1180                            remediation: reason.clone(),
1181                        });
1182                    }
1183                }
1184                PolicyRule::MinimumVersion {
1185                    component_pattern,
1186                    min_version,
1187                    reason,
1188                } => {
1189                    if name.to_lowercase().contains(&component_pattern.to_lowercase()) {
1190                        if let Some(ver) = version {
1191                            if let (Some(current), Some(min)) =
1192                                (parse_version_parts(ver), parse_version_parts(min_version))
1193                            {
1194                                let is_below = current
1195                                    .iter()
1196                                    .zip(min.iter())
1197                                    .any(|(c, m)| c < m);
1198                                if is_below {
1199                                    result.violations.push(PolicyViolation {
1200                                        rule_name: rule.name().to_string(),
1201                                        severity: rule.severity(),
1202                                        component: Some(name.clone()),
1203                                        description: format!(
1204                                            "Version '{}' below minimum '{}'",
1205                                            ver, min_version
1206                                        ),
1207                                        remediation: format!(
1208                                            "Upgrade to version {} or higher. {}",
1209                                            min_version, reason
1210                                        ),
1211                                    });
1212                                }
1213                            }
1214                        }
1215                    }
1216                }
1217                PolicyRule::NoPreRelease { reason } => {
1218                    if let Some(ver) = version {
1219                        if let Some(parts) = parse_version_parts(ver) {
1220                            if parts.first() == Some(&0) {
1221                                result.violations.push(PolicyViolation {
1222                                    rule_name: rule.name().to_string(),
1223                                    severity: rule.severity(),
1224                                    component: Some(name.clone()),
1225                                    description: format!("Pre-release version '{}' (0.x.x)", ver),
1226                                    remediation: format!(
1227                                        "Upgrade to stable version (1.0+). {}",
1228                                        reason
1229                                    ),
1230                                });
1231                            }
1232                        }
1233                    }
1234                }
1235                PolicyRule::MaxVulnerabilitySeverity { max_severity, reason } => {
1236                    let max_rank = severity_to_rank(max_severity);
1237                    for (vuln_id, vuln_sev) in vulns {
1238                        let vuln_rank = severity_to_rank(vuln_sev);
1239                        if vuln_rank > max_rank {
1240                            result.violations.push(PolicyViolation {
1241                                rule_name: rule.name().to_string(),
1242                                severity: PolicySeverity::Critical,
1243                                component: Some(name.clone()),
1244                                description: format!(
1245                                    "{} has {} severity (max allowed: {})",
1246                                    vuln_id, vuln_sev, max_severity
1247                                ),
1248                                remediation: format!(
1249                                    "Remediate {} or upgrade component. {}",
1250                                    vuln_id, reason
1251                                ),
1252                            });
1253                        }
1254                    }
1255                }
1256                PolicyRule::RequireFields { .. } => {
1257                    // This would check SBOM-level fields, not per-component
1258                    // Handled separately
1259                }
1260            }
1261        }
1262    }
1263
1264    // Calculate score
1265    let violation_penalty: u32 = result
1266        .violations
1267        .iter()
1268        .map(|v| match v.severity {
1269            PolicySeverity::Critical => 25,
1270            PolicySeverity::High => 15,
1271            PolicySeverity::Medium => 8,
1272            PolicySeverity::Low => 3,
1273        })
1274        .sum();
1275
1276    result.score = 100u8.saturating_sub(violation_penalty.min(100) as u8);
1277    result.passes = result.count_by_severity(PolicySeverity::Critical) == 0
1278        && result.count_by_severity(PolicySeverity::High) == 0;
1279
1280    result
1281}
1282
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_blast_radius_impact_description() {
        let mut radius = BlastRadius::default();
        assert_eq!(radius.impact_description(), "No downstream impact");

        radius.transitive_dependents.insert("a".to_string());
        assert_eq!(radius.impact_description(), "Limited impact");

        // 1 existing + 25 new dependents = 26, which falls in the 21..=50 band.
        radius
            .transitive_dependents
            .extend((0..25).map(|i| format!("comp_{}", i)));
        assert_eq!(radius.impact_description(), "Significant impact");
    }

    #[test]
    fn test_risk_indicators_score() {
        let mut indicators = RiskIndicators {
            vuln_count: 3,
            highest_severity: Some("High".to_string()),
            transitive_dependent_count: 15,
            ..RiskIndicators::default()
        };
        indicators.calculate_risk_score();

        assert!(indicators.risk_score > 0);
        assert_ne!(indicators.risk_level, RiskLevel::Low);
    }

    #[test]
    fn test_license_risk() {
        let cases = [
            ("MIT", LicenseRisk::Low),
            ("Apache-2.0", LicenseRisk::Low),
            ("LGPL-3.0", LicenseRisk::Medium),
            ("GPL-3.0", LicenseRisk::High),
        ];
        for (license, expected) in cases {
            assert_eq!(
                LicenseRisk::from_license(license),
                expected,
                "license: {}",
                license
            );
        }
    }

    #[test]
    fn test_cve_url() {
        let cases = [
            ("CVE-2021-44228", "nvd.nist.gov"),
            ("GHSA-abcd-1234-efgh", "github.com"),
            ("RUSTSEC-2021-0001", "rustsec.org"),
        ];
        for (id, host) in cases {
            assert!(cve_url(id).contains(host), "{} should link to {}", id, host);
        }
    }

    #[test]
    fn test_sanitize_vuln_id_strips_shell_metacharacters() {
        let cases = [
            // Normal IDs pass through unchanged
            ("CVE-2021-44228", "CVE-2021-44228"),
            ("GHSA-abcd-1234-efgh", "GHSA-abcd-1234-efgh"),
            // Shell metacharacters are stripped
            ("CVE-2021&whoami", "CVE-2021whoami"),
            ("CVE|calc.exe", "CVEcalc.exe"),
            ("id;rm -rf /", "idrm-rf"),
            ("$(malicious)", "malicious"),
            ("foo`bar`", "foobar"),
        ];
        for (raw, expected) in cases {
            assert_eq!(sanitize_vuln_id(raw), expected, "input: {}", raw);
        }
    }

    #[test]
    fn test_cve_url_with_injected_id() {
        // Shell metacharacters in vuln IDs must not survive into the URL.
        let link = cve_url("CVE-2021-44228&calc");
        assert!(!link.contains('&'));
        // sanitize_vuln_id strips '&'; cve_url uppercases CVE IDs.
        assert!(link.contains("CVE-2021-44228CALC"));
    }

    #[test]
    fn test_is_safe_url() {
        let accepted = [
            "https://nvd.nist.gov/vuln/detail/CVE-2021-44228",
            "https://example.com/path?q=1&a=2",
        ];
        for url in accepted {
            assert!(is_safe_url(url), "should accept: {}", url);
        }

        // Shell injection attempts; backtick and pipe are not valid URL characters.
        let rejected = [
            "https://evil.com\"; rm -rf /",
            "https://x.com\nmalicious",
            "url`calc`",
            "url|cmd",
        ];
        for url in rejected {
            assert!(!is_safe_url(url), "should reject: {:?}", url);
        }
    }

    #[test]
    fn test_security_cache_flagging() {
        let mut analysis = SecurityAnalysisCache::new();

        assert!(!analysis.is_flagged("comp1"));
        analysis.flag_component("comp1", "Suspicious activity");
        assert!(analysis.is_flagged("comp1"));

        analysis.toggle_flag("comp1", "test");
        assert!(!analysis.is_flagged("comp1"));
    }
}