Skip to main content

sbom_tools/diff/changes/
vuln_grouping.rs

1//! Vulnerability grouping by root cause component.
2//!
3//! This module provides functionality to group vulnerabilities by the component
4//! that introduces them, reducing noise and showing the true scope of security issues.
5
6use crate::diff::result::VulnerabilityDetail;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10/// Status of a vulnerability group
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
12pub enum VulnGroupStatus {
13    /// Newly introduced vulnerabilities
14    Introduced,
15    /// Resolved vulnerabilities
16    Resolved,
17    /// Persistent vulnerabilities (present in both old and new)
18    Persistent,
19}
20
21impl std::fmt::Display for VulnGroupStatus {
22    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23        match self {
24            VulnGroupStatus::Introduced => write!(f, "Introduced"),
25            VulnGroupStatus::Resolved => write!(f, "Resolved"),
26            VulnGroupStatus::Persistent => write!(f, "Persistent"),
27        }
28    }
29}
30
31/// A group of vulnerabilities sharing the same root cause component
32#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct VulnerabilityGroup {
34    /// Root cause component ID
35    pub component_id: String,
36    /// Component name
37    pub component_name: String,
38    /// Component version (if available)
39    pub component_version: Option<String>,
40    /// Vulnerabilities in this group
41    pub vulnerabilities: Vec<VulnerabilityDetail>,
42    /// Maximum severity in the group
43    pub max_severity: String,
44    /// Maximum CVSS score in the group
45    pub max_cvss: Option<f32>,
46    /// Count by severity level
47    pub severity_counts: HashMap<String, usize>,
48    /// Group status (Introduced, Resolved, Persistent)
49    pub status: VulnGroupStatus,
50    /// Whether any vulnerability is in KEV catalog
51    pub has_kev: bool,
52    /// Whether any vulnerability is ransomware-related
53    pub has_ransomware_kev: bool,
54}
55
56impl VulnerabilityGroup {
57    /// Create a new empty group for a component
58    pub fn new(
59        component_id: String,
60        component_name: String,
61        status: VulnGroupStatus,
62    ) -> Self {
63        Self {
64            component_id,
65            component_name,
66            component_version: None,
67            vulnerabilities: Vec::new(),
68            max_severity: "Unknown".to_string(),
69            max_cvss: None,
70            severity_counts: HashMap::new(),
71            status,
72            has_kev: false,
73            has_ransomware_kev: false,
74        }
75    }
76
77    /// Add a vulnerability to the group
78    pub fn add_vulnerability(&mut self, vuln: VulnerabilityDetail) {
79        // Update severity counts
80        *self.severity_counts.entry(vuln.severity.clone()).or_insert(0) += 1;
81
82        // Update max severity (priority: Critical > High > Medium > Low > Unknown)
83        let vuln_priority = severity_priority(&vuln.severity);
84        let current_priority = severity_priority(&self.max_severity);
85        if vuln_priority < current_priority {
86            self.max_severity = vuln.severity.clone();
87        }
88
89        // Update max CVSS
90        if let Some(score) = vuln.cvss_score {
91            self.max_cvss = Some(self.max_cvss.map(|c| c.max(score)).unwrap_or(score));
92        }
93
94        // Update version from first vulnerability with version
95        if self.component_version.is_none() {
96            self.component_version = vuln.version.clone();
97        }
98
99        self.vulnerabilities.push(vuln);
100    }
101
102    /// Get total vulnerability count
103    pub fn vuln_count(&self) -> usize {
104        self.vulnerabilities.len()
105    }
106
107    /// Check if group has any critical vulnerabilities
108    pub fn has_critical(&self) -> bool {
109        self.severity_counts.get("Critical").copied().unwrap_or(0) > 0
110    }
111
112    /// Check if group has any high severity vulnerabilities
113    pub fn has_high(&self) -> bool {
114        self.severity_counts.get("High").copied().unwrap_or(0) > 0
115    }
116
117    /// Get summary line for display
118    pub fn summary_line(&self) -> String {
119        let version_str = self
120            .component_version
121            .as_ref()
122            .map(|v| format!("@{}", v))
123            .unwrap_or_default();
124
125        let severity_badges: Vec<String> = ["Critical", "High", "Medium", "Low"]
126            .iter()
127            .filter_map(|sev| {
128                self.severity_counts.get(*sev).and_then(|&count| {
129                    if count > 0 {
130                        Some(format!("{}:{}", &sev[..1], count))
131                    } else {
132                        None
133                    }
134                })
135            })
136            .collect();
137
138        format!(
139            "{}{}: {} CVEs [{}]",
140            self.component_name,
141            version_str,
142            self.vuln_count(),
143            severity_badges.join(" ")
144        )
145    }
146}
147
148/// Get priority value for severity (lower = more severe)
149fn severity_priority(severity: &str) -> u8 {
150    match severity.to_lowercase().as_str() {
151        "critical" => 0,
152        "high" => 1,
153        "medium" => 2,
154        "low" => 3,
155        "info" => 4,
156        "none" => 5,
157        _ => 6,
158    }
159}
160
161/// Group vulnerabilities by component
162pub fn group_vulnerabilities(
163    vulns: &[VulnerabilityDetail],
164    status: VulnGroupStatus,
165) -> Vec<VulnerabilityGroup> {
166    let mut groups: HashMap<String, VulnerabilityGroup> = HashMap::new();
167
168    for vuln in vulns {
169        let group = groups
170            .entry(vuln.component_id.clone())
171            .or_insert_with(|| {
172                VulnerabilityGroup::new(
173                    vuln.component_id.clone(),
174                    vuln.component_name.clone(),
175                    status,
176                )
177            });
178
179        group.add_vulnerability(vuln.clone());
180    }
181
182    // Sort groups by severity (most severe first), then by count
183    let mut result: Vec<_> = groups.into_values().collect();
184    result.sort_by(|a, b| {
185        let sev_cmp = severity_priority(&a.max_severity).cmp(&severity_priority(&b.max_severity));
186        if sev_cmp == std::cmp::Ordering::Equal {
187            b.vuln_count().cmp(&a.vuln_count())
188        } else {
189            sev_cmp
190        }
191    });
192
193    result
194}
195
196/// Grouped view of vulnerability changes
197#[derive(Debug, Clone, Default, Serialize, Deserialize)]
198pub struct VulnerabilityGroupedView {
199    /// Groups of introduced vulnerabilities
200    pub introduced_groups: Vec<VulnerabilityGroup>,
201    /// Groups of resolved vulnerabilities
202    pub resolved_groups: Vec<VulnerabilityGroup>,
203    /// Groups of persistent vulnerabilities
204    pub persistent_groups: Vec<VulnerabilityGroup>,
205}
206
207impl VulnerabilityGroupedView {
208    /// Create grouped view from vulnerability lists
209    pub fn from_changes(
210        introduced: &[VulnerabilityDetail],
211        resolved: &[VulnerabilityDetail],
212        persistent: &[VulnerabilityDetail],
213    ) -> Self {
214        Self {
215            introduced_groups: group_vulnerabilities(introduced, VulnGroupStatus::Introduced),
216            resolved_groups: group_vulnerabilities(resolved, VulnGroupStatus::Resolved),
217            persistent_groups: group_vulnerabilities(persistent, VulnGroupStatus::Persistent),
218        }
219    }
220
221    /// Get total group count
222    pub fn total_groups(&self) -> usize {
223        self.introduced_groups.len() + self.resolved_groups.len() + self.persistent_groups.len()
224    }
225
226    /// Get total vulnerability count across all groups
227    pub fn total_vulns(&self) -> usize {
228        self.introduced_groups.iter().map(|g| g.vuln_count()).sum::<usize>()
229            + self.resolved_groups.iter().map(|g| g.vuln_count()).sum::<usize>()
230            + self.persistent_groups.iter().map(|g| g.vuln_count()).sum::<usize>()
231    }
232
233    /// Check if any group has KEV vulnerabilities
234    pub fn has_any_kev(&self) -> bool {
235        self.introduced_groups.iter().any(|g| g.has_kev)
236            || self.resolved_groups.iter().any(|g| g.has_kev)
237            || self.persistent_groups.iter().any(|g| g.has_kev)
238    }
239}
240
241#[cfg(test)]
242mod tests {
243    use super::*;
244
245    fn make_vuln(id: &str, component_id: &str, severity: &str) -> VulnerabilityDetail {
246        VulnerabilityDetail {
247            id: id.to_string(),
248            source: "OSV".to_string(),
249            severity: severity.to_string(),
250            cvss_score: None,
251            component_id: component_id.to_string(),
252            component_canonical_id: None,
253            component_ref: None,
254            component_name: format!("{}-pkg", component_id),
255            version: Some("1.0.0".to_string()),
256            description: None,
257            remediation: None,
258            is_kev: false,
259            cwes: Vec::new(),
260            component_depth: None,
261            published_date: None,
262            kev_due_date: None,
263            days_since_published: None,
264            days_until_due: None,
265            vex_state: None,
266        }
267    }
268
269    #[test]
270    fn test_group_vulnerabilities() {
271        let vulns = vec![
272            make_vuln("CVE-2024-0001", "lodash", "Critical"),
273            make_vuln("CVE-2024-0002", "lodash", "High"),
274            make_vuln("CVE-2024-0003", "lodash", "High"),
275            make_vuln("CVE-2024-0004", "express", "Medium"),
276        ];
277
278        let groups = group_vulnerabilities(&vulns, VulnGroupStatus::Introduced);
279
280        assert_eq!(groups.len(), 2);
281
282        // lodash should be first (Critical severity)
283        assert_eq!(groups[0].component_id, "lodash");
284        assert_eq!(groups[0].vuln_count(), 3);
285        assert_eq!(groups[0].max_severity, "Critical");
286        assert_eq!(groups[0].severity_counts.get("Critical"), Some(&1));
287        assert_eq!(groups[0].severity_counts.get("High"), Some(&2));
288
289        // express should be second
290        assert_eq!(groups[1].component_id, "express");
291        assert_eq!(groups[1].vuln_count(), 1);
292    }
293
294    #[test]
295    fn test_grouped_view() {
296        let introduced = vec![
297            make_vuln("CVE-2024-0001", "lodash", "High"),
298            make_vuln("CVE-2024-0002", "lodash", "Medium"),
299        ];
300        let resolved = vec![make_vuln("CVE-2024-0003", "old-dep", "Critical")];
301        let persistent = vec![];
302
303        let view = VulnerabilityGroupedView::from_changes(&introduced, &resolved, &persistent);
304
305        assert_eq!(view.total_groups(), 2);
306        assert_eq!(view.total_vulns(), 3);
307        assert_eq!(view.introduced_groups.len(), 1);
308        assert_eq!(view.resolved_groups.len(), 1);
309    }
310
311    #[test]
312    fn test_summary_line() {
313        let mut group = VulnerabilityGroup::new(
314            "lodash".to_string(),
315            "lodash".to_string(),
316            VulnGroupStatus::Introduced,
317        );
318        group.add_vulnerability(make_vuln("CVE-1", "lodash", "Critical"));
319        group.add_vulnerability(make_vuln("CVE-2", "lodash", "High"));
320        group.add_vulnerability(make_vuln("CVE-3", "lodash", "High"));
321
322        let summary = group.summary_line();
323        assert!(summary.contains("lodash"));
324        assert!(summary.contains("3 CVEs"));
325        assert!(summary.contains("C:1"));
326        assert!(summary.contains("H:2"));
327    }
328}