Skip to main content

sbom_tools/diff/changes/
vuln_grouping.rs

1//! Vulnerability grouping by root cause component.
2//!
3//! This module provides functionality to group vulnerabilities by the component
4//! that introduces them, reducing noise and showing the true scope of security issues.
5
6use crate::diff::result::VulnerabilityDetail;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10/// Status of a vulnerability group
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
12pub enum VulnGroupStatus {
13    /// Newly introduced vulnerabilities
14    Introduced,
15    /// Resolved vulnerabilities
16    Resolved,
17    /// Persistent vulnerabilities (present in both old and new)
18    Persistent,
19}
20
21impl std::fmt::Display for VulnGroupStatus {
22    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23        match self {
24            Self::Introduced => write!(f, "Introduced"),
25            Self::Resolved => write!(f, "Resolved"),
26            Self::Persistent => write!(f, "Persistent"),
27        }
28    }
29}
30
31/// A group of vulnerabilities sharing the same root cause component
32#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct VulnerabilityGroup {
34    /// Root cause component ID
35    pub component_id: String,
36    /// Component name
37    pub component_name: String,
38    /// Component version (if available)
39    pub component_version: Option<String>,
40    /// Vulnerabilities in this group
41    pub vulnerabilities: Vec<VulnerabilityDetail>,
42    /// Maximum severity in the group
43    pub max_severity: String,
44    /// Maximum CVSS score in the group
45    pub max_cvss: Option<f32>,
46    /// Count by severity level
47    pub severity_counts: HashMap<String, usize>,
48    /// Group status (Introduced, Resolved, Persistent)
49    pub status: VulnGroupStatus,
50    /// Whether any vulnerability is in KEV catalog
51    pub has_kev: bool,
52    /// Whether any vulnerability is ransomware-related
53    pub has_ransomware_kev: bool,
54}
55
56impl VulnerabilityGroup {
57    /// Create a new empty group for a component
58    #[must_use] 
59    pub fn new(
60        component_id: String,
61        component_name: String,
62        status: VulnGroupStatus,
63    ) -> Self {
64        Self {
65            component_id,
66            component_name,
67            component_version: None,
68            vulnerabilities: Vec::new(),
69            max_severity: "Unknown".to_string(),
70            max_cvss: None,
71            severity_counts: HashMap::new(),
72            status,
73            has_kev: false,
74            has_ransomware_kev: false,
75        }
76    }
77
78    /// Add a vulnerability to the group
79    pub fn add_vulnerability(&mut self, vuln: VulnerabilityDetail) {
80        // Update severity counts
81        *self.severity_counts.entry(vuln.severity.clone()).or_insert(0) += 1;
82
83        // Update max severity (priority: Critical > High > Medium > Low > Unknown)
84        let vuln_priority = severity_priority(&vuln.severity);
85        let current_priority = severity_priority(&self.max_severity);
86        if vuln_priority < current_priority {
87            self.max_severity.clone_from(&vuln.severity);
88        }
89
90        // Update max CVSS
91        if let Some(score) = vuln.cvss_score {
92            self.max_cvss = Some(self.max_cvss.map_or(score, |c| c.max(score)));
93        }
94
95        // Update version from first vulnerability with version
96        if self.component_version.is_none() {
97            self.component_version.clone_from(&vuln.version);
98        }
99
100        self.vulnerabilities.push(vuln);
101    }
102
103    /// Get total vulnerability count
104    #[must_use] 
105    pub fn vuln_count(&self) -> usize {
106        self.vulnerabilities.len()
107    }
108
109    /// Check if group has any critical vulnerabilities
110    #[must_use] 
111    pub fn has_critical(&self) -> bool {
112        self.severity_counts.get("Critical").copied().unwrap_or(0) > 0
113    }
114
115    /// Check if group has any high severity vulnerabilities
116    #[must_use] 
117    pub fn has_high(&self) -> bool {
118        self.severity_counts.get("High").copied().unwrap_or(0) > 0
119    }
120
121    /// Get summary line for display
122    #[must_use] 
123    pub fn summary_line(&self) -> String {
124        let version_str = self
125            .component_version
126            .as_ref()
127            .map(|v| format!("@{v}"))
128            .unwrap_or_default();
129
130        let severity_badges: Vec<String> = ["Critical", "High", "Medium", "Low"]
131            .iter()
132            .filter_map(|sev| {
133                self.severity_counts.get(*sev).and_then(|&count| {
134                    if count > 0 {
135                        Some(format!("{}:{}", &sev[..1], count))
136                    } else {
137                        None
138                    }
139                })
140            })
141            .collect();
142
143        format!(
144            "{}{}: {} CVEs [{}]",
145            self.component_name,
146            version_str,
147            self.vuln_count(),
148            severity_badges.join(" ")
149        )
150    }
151}
152
153/// Get priority value for severity (lower = more severe)
154fn severity_priority(severity: &str) -> u8 {
155    match severity.to_lowercase().as_str() {
156        "critical" => 0,
157        "high" => 1,
158        "medium" => 2,
159        "low" => 3,
160        "info" => 4,
161        "none" => 5,
162        _ => 6,
163    }
164}
165
166/// Group vulnerabilities by component
167#[must_use] 
168pub fn group_vulnerabilities(
169    vulns: &[VulnerabilityDetail],
170    status: VulnGroupStatus,
171) -> Vec<VulnerabilityGroup> {
172    let mut groups: HashMap<String, VulnerabilityGroup> = HashMap::new();
173
174    for vuln in vulns {
175        let group = groups
176            .entry(vuln.component_id.clone())
177            .or_insert_with(|| {
178                VulnerabilityGroup::new(
179                    vuln.component_id.clone(),
180                    vuln.component_name.clone(),
181                    status,
182                )
183            });
184
185        group.add_vulnerability(vuln.clone());
186    }
187
188    // Sort groups by severity (most severe first), then by count
189    let mut result: Vec<_> = groups.into_values().collect();
190    result.sort_by(|a, b| {
191        let sev_cmp = severity_priority(&a.max_severity).cmp(&severity_priority(&b.max_severity));
192        if sev_cmp == std::cmp::Ordering::Equal {
193            b.vuln_count().cmp(&a.vuln_count())
194        } else {
195            sev_cmp
196        }
197    });
198
199    result
200}
201
202/// Grouped view of vulnerability changes
203#[derive(Debug, Clone, Default, Serialize, Deserialize)]
204pub struct VulnerabilityGroupedView {
205    /// Groups of introduced vulnerabilities
206    pub introduced_groups: Vec<VulnerabilityGroup>,
207    /// Groups of resolved vulnerabilities
208    pub resolved_groups: Vec<VulnerabilityGroup>,
209    /// Groups of persistent vulnerabilities
210    pub persistent_groups: Vec<VulnerabilityGroup>,
211}
212
213impl VulnerabilityGroupedView {
214    /// Create grouped view from vulnerability lists
215    #[must_use] 
216    pub fn from_changes(
217        introduced: &[VulnerabilityDetail],
218        resolved: &[VulnerabilityDetail],
219        persistent: &[VulnerabilityDetail],
220    ) -> Self {
221        Self {
222            introduced_groups: group_vulnerabilities(introduced, VulnGroupStatus::Introduced),
223            resolved_groups: group_vulnerabilities(resolved, VulnGroupStatus::Resolved),
224            persistent_groups: group_vulnerabilities(persistent, VulnGroupStatus::Persistent),
225        }
226    }
227
228    /// Get total group count
229    #[must_use] 
230    pub fn total_groups(&self) -> usize {
231        self.introduced_groups.len() + self.resolved_groups.len() + self.persistent_groups.len()
232    }
233
234    /// Get total vulnerability count across all groups
235    pub fn total_vulns(&self) -> usize {
236        self.introduced_groups.iter().map(VulnerabilityGroup::vuln_count).sum::<usize>()
237            + self.resolved_groups.iter().map(VulnerabilityGroup::vuln_count).sum::<usize>()
238            + self.persistent_groups.iter().map(VulnerabilityGroup::vuln_count).sum::<usize>()
239    }
240
241    /// Check if any group has KEV vulnerabilities
242    #[must_use] 
243    pub fn has_any_kev(&self) -> bool {
244        self.introduced_groups.iter().any(|g| g.has_kev)
245            || self.resolved_groups.iter().any(|g| g.has_kev)
246            || self.persistent_groups.iter().any(|g| g.has_kev)
247    }
248}
249
250#[cfg(test)]
251mod tests {
252    use super::*;
253
254    fn make_vuln(id: &str, component_id: &str, severity: &str) -> VulnerabilityDetail {
255        VulnerabilityDetail {
256            id: id.to_string(),
257            source: "OSV".to_string(),
258            severity: severity.to_string(),
259            cvss_score: None,
260            component_id: component_id.to_string(),
261            component_canonical_id: None,
262            component_ref: None,
263            component_name: format!("{}-pkg", component_id),
264            version: Some("1.0.0".to_string()),
265            description: None,
266            remediation: None,
267            is_kev: false,
268            cwes: Vec::new(),
269            component_depth: None,
270            published_date: None,
271            kev_due_date: None,
272            days_since_published: None,
273            days_until_due: None,
274            vex_state: None,
275            vex_justification: None,
276            vex_impact_statement: None,
277        }
278    }
279
280    #[test]
281    fn test_group_vulnerabilities() {
282        let vulns = vec![
283            make_vuln("CVE-2024-0001", "lodash", "Critical"),
284            make_vuln("CVE-2024-0002", "lodash", "High"),
285            make_vuln("CVE-2024-0003", "lodash", "High"),
286            make_vuln("CVE-2024-0004", "express", "Medium"),
287        ];
288
289        let groups = group_vulnerabilities(&vulns, VulnGroupStatus::Introduced);
290
291        assert_eq!(groups.len(), 2);
292
293        // lodash should be first (Critical severity)
294        assert_eq!(groups[0].component_id, "lodash");
295        assert_eq!(groups[0].vuln_count(), 3);
296        assert_eq!(groups[0].max_severity, "Critical");
297        assert_eq!(groups[0].severity_counts.get("Critical"), Some(&1));
298        assert_eq!(groups[0].severity_counts.get("High"), Some(&2));
299
300        // express should be second
301        assert_eq!(groups[1].component_id, "express");
302        assert_eq!(groups[1].vuln_count(), 1);
303    }
304
305    #[test]
306    fn test_grouped_view() {
307        let introduced = vec![
308            make_vuln("CVE-2024-0001", "lodash", "High"),
309            make_vuln("CVE-2024-0002", "lodash", "Medium"),
310        ];
311        let resolved = vec![make_vuln("CVE-2024-0003", "old-dep", "Critical")];
312        let persistent = vec![];
313
314        let view = VulnerabilityGroupedView::from_changes(&introduced, &resolved, &persistent);
315
316        assert_eq!(view.total_groups(), 2);
317        assert_eq!(view.total_vulns(), 3);
318        assert_eq!(view.introduced_groups.len(), 1);
319        assert_eq!(view.resolved_groups.len(), 1);
320    }
321
322    #[test]
323    fn test_summary_line() {
324        let mut group = VulnerabilityGroup::new(
325            "lodash".to_string(),
326            "lodash".to_string(),
327            VulnGroupStatus::Introduced,
328        );
329        group.add_vulnerability(make_vuln("CVE-1", "lodash", "Critical"));
330        group.add_vulnerability(make_vuln("CVE-2", "lodash", "High"));
331        group.add_vulnerability(make_vuln("CVE-3", "lodash", "High"));
332
333        let summary = group.summary_line();
334        assert!(summary.contains("lodash"));
335        assert!(summary.contains("3 CVEs"));
336        assert!(summary.contains("C:1"));
337        assert!(summary.contains("H:2"));
338    }
339}