Skip to main content

sbom_tools/diff/changes/
vuln_grouping.rs

1//! Vulnerability grouping by root cause component.
2//!
3//! This module provides functionality to group vulnerabilities by the component
4//! that introduces them, reducing noise and showing the true scope of security issues.
5
6use crate::diff::result::VulnerabilityDetail;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10/// Status of a vulnerability group
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
12pub enum VulnGroupStatus {
13    /// Newly introduced vulnerabilities
14    Introduced,
15    /// Resolved vulnerabilities
16    Resolved,
17    /// Persistent vulnerabilities (present in both old and new)
18    Persistent,
19}
20
21impl std::fmt::Display for VulnGroupStatus {
22    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23        match self {
24            Self::Introduced => write!(f, "Introduced"),
25            Self::Resolved => write!(f, "Resolved"),
26            Self::Persistent => write!(f, "Persistent"),
27        }
28    }
29}
30
31/// A group of vulnerabilities sharing the same root cause component
32#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct VulnerabilityGroup {
34    /// Root cause component ID
35    pub component_id: String,
36    /// Component name
37    pub component_name: String,
38    /// Component version (if available)
39    pub component_version: Option<String>,
40    /// Vulnerabilities in this group
41    pub vulnerabilities: Vec<VulnerabilityDetail>,
42    /// Maximum severity in the group
43    pub max_severity: String,
44    /// Maximum CVSS score in the group
45    pub max_cvss: Option<f32>,
46    /// Count by severity level
47    pub severity_counts: HashMap<String, usize>,
48    /// Group status (Introduced, Resolved, Persistent)
49    pub status: VulnGroupStatus,
50    /// Whether any vulnerability is in KEV catalog
51    pub has_kev: bool,
52    /// Whether any vulnerability is ransomware-related
53    pub has_ransomware_kev: bool,
54}
55
56impl VulnerabilityGroup {
57    /// Create a new empty group for a component
58    #[must_use]
59    pub fn new(component_id: String, component_name: String, status: VulnGroupStatus) -> Self {
60        Self {
61            component_id,
62            component_name,
63            component_version: None,
64            vulnerabilities: Vec::new(),
65            max_severity: "Unknown".to_string(),
66            max_cvss: None,
67            severity_counts: HashMap::new(),
68            status,
69            has_kev: false,
70            has_ransomware_kev: false,
71        }
72    }
73
74    /// Add a vulnerability to the group
75    pub fn add_vulnerability(&mut self, vuln: VulnerabilityDetail) {
76        // Update severity counts
77        *self
78            .severity_counts
79            .entry(vuln.severity.clone())
80            .or_insert(0) += 1;
81
82        // Update max severity (priority: Critical > High > Medium > Low > Unknown)
83        let vuln_priority = severity_priority(&vuln.severity);
84        let current_priority = severity_priority(&self.max_severity);
85        if vuln_priority < current_priority {
86            self.max_severity.clone_from(&vuln.severity);
87        }
88
89        // Update max CVSS
90        if let Some(score) = vuln.cvss_score {
91            self.max_cvss = Some(self.max_cvss.map_or(score, |c| c.max(score)));
92        }
93
94        // Update version from first vulnerability with version
95        if self.component_version.is_none() {
96            self.component_version.clone_from(&vuln.version);
97        }
98
99        self.vulnerabilities.push(vuln);
100    }
101
102    /// Get total vulnerability count
103    #[must_use]
104    pub fn vuln_count(&self) -> usize {
105        self.vulnerabilities.len()
106    }
107
108    /// Check if group has any critical vulnerabilities
109    #[must_use]
110    pub fn has_critical(&self) -> bool {
111        self.severity_counts.get("Critical").copied().unwrap_or(0) > 0
112    }
113
114    /// Check if group has any high severity vulnerabilities
115    #[must_use]
116    pub fn has_high(&self) -> bool {
117        self.severity_counts.get("High").copied().unwrap_or(0) > 0
118    }
119
120    /// Get summary line for display
121    #[must_use]
122    pub fn summary_line(&self) -> String {
123        let version_str = self
124            .component_version
125            .as_ref()
126            .map(|v| format!("@{v}"))
127            .unwrap_or_default();
128
129        let severity_badges: Vec<String> = ["Critical", "High", "Medium", "Low"]
130            .iter()
131            .filter_map(|sev| {
132                self.severity_counts.get(*sev).and_then(|&count| {
133                    if count > 0 {
134                        Some(format!("{}:{}", &sev[..1], count))
135                    } else {
136                        None
137                    }
138                })
139            })
140            .collect();
141
142        format!(
143            "{}{}: {} CVEs [{}]",
144            self.component_name,
145            version_str,
146            self.vuln_count(),
147            severity_badges.join(" ")
148        )
149    }
150}
151
152/// Get priority value for severity (lower = more severe)
153fn severity_priority(severity: &str) -> u8 {
154    match severity.to_lowercase().as_str() {
155        "critical" => 0,
156        "high" => 1,
157        "medium" => 2,
158        "low" => 3,
159        "info" => 4,
160        "none" => 5,
161        _ => 6,
162    }
163}
164
165/// Group vulnerabilities by component
166#[must_use]
167pub fn group_vulnerabilities(
168    vulns: &[VulnerabilityDetail],
169    status: VulnGroupStatus,
170) -> Vec<VulnerabilityGroup> {
171    let mut groups: HashMap<String, VulnerabilityGroup> = HashMap::new();
172
173    for vuln in vulns {
174        let group = groups.entry(vuln.component_id.clone()).or_insert_with(|| {
175            VulnerabilityGroup::new(
176                vuln.component_id.clone(),
177                vuln.component_name.clone(),
178                status,
179            )
180        });
181
182        group.add_vulnerability(vuln.clone());
183    }
184
185    // Sort groups by severity (most severe first), then by count
186    let mut result: Vec<_> = groups.into_values().collect();
187    result.sort_by(|a, b| {
188        let sev_cmp = severity_priority(&a.max_severity).cmp(&severity_priority(&b.max_severity));
189        if sev_cmp == std::cmp::Ordering::Equal {
190            b.vuln_count().cmp(&a.vuln_count())
191        } else {
192            sev_cmp
193        }
194    });
195
196    result
197}
198
199/// Grouped view of vulnerability changes
200#[derive(Debug, Clone, Default, Serialize, Deserialize)]
201pub struct VulnerabilityGroupedView {
202    /// Groups of introduced vulnerabilities
203    pub introduced_groups: Vec<VulnerabilityGroup>,
204    /// Groups of resolved vulnerabilities
205    pub resolved_groups: Vec<VulnerabilityGroup>,
206    /// Groups of persistent vulnerabilities
207    pub persistent_groups: Vec<VulnerabilityGroup>,
208}
209
210impl VulnerabilityGroupedView {
211    /// Create grouped view from vulnerability lists
212    #[must_use]
213    pub fn from_changes(
214        introduced: &[VulnerabilityDetail],
215        resolved: &[VulnerabilityDetail],
216        persistent: &[VulnerabilityDetail],
217    ) -> Self {
218        Self {
219            introduced_groups: group_vulnerabilities(introduced, VulnGroupStatus::Introduced),
220            resolved_groups: group_vulnerabilities(resolved, VulnGroupStatus::Resolved),
221            persistent_groups: group_vulnerabilities(persistent, VulnGroupStatus::Persistent),
222        }
223    }
224
225    /// Get total group count
226    #[must_use]
227    pub fn total_groups(&self) -> usize {
228        self.introduced_groups.len() + self.resolved_groups.len() + self.persistent_groups.len()
229    }
230
231    /// Get total vulnerability count across all groups
232    pub fn total_vulns(&self) -> usize {
233        self.introduced_groups
234            .iter()
235            .map(VulnerabilityGroup::vuln_count)
236            .sum::<usize>()
237            + self
238                .resolved_groups
239                .iter()
240                .map(VulnerabilityGroup::vuln_count)
241                .sum::<usize>()
242            + self
243                .persistent_groups
244                .iter()
245                .map(VulnerabilityGroup::vuln_count)
246                .sum::<usize>()
247    }
248
249    /// Check if any group has KEV vulnerabilities
250    #[must_use]
251    pub fn has_any_kev(&self) -> bool {
252        self.introduced_groups.iter().any(|g| g.has_kev)
253            || self.resolved_groups.iter().any(|g| g.has_kev)
254            || self.persistent_groups.iter().any(|g| g.has_kev)
255    }
256}
257
258#[cfg(test)]
259mod tests {
260    use super::*;
261
262    fn make_vuln(id: &str, component_id: &str, severity: &str) -> VulnerabilityDetail {
263        VulnerabilityDetail {
264            id: id.to_string(),
265            source: "OSV".to_string(),
266            severity: severity.to_string(),
267            cvss_score: None,
268            component_id: component_id.to_string(),
269            component_canonical_id: None,
270            component_ref: None,
271            component_name: format!("{}-pkg", component_id),
272            version: Some("1.0.0".to_string()),
273            description: None,
274            remediation: None,
275            is_kev: false,
276            cwes: Vec::new(),
277            component_depth: None,
278            published_date: None,
279            kev_due_date: None,
280            days_since_published: None,
281            days_until_due: None,
282            vex_state: None,
283            vex_justification: None,
284            vex_impact_statement: None,
285        }
286    }
287
288    #[test]
289    fn test_group_vulnerabilities() {
290        let vulns = vec![
291            make_vuln("CVE-2024-0001", "lodash", "Critical"),
292            make_vuln("CVE-2024-0002", "lodash", "High"),
293            make_vuln("CVE-2024-0003", "lodash", "High"),
294            make_vuln("CVE-2024-0004", "express", "Medium"),
295        ];
296
297        let groups = group_vulnerabilities(&vulns, VulnGroupStatus::Introduced);
298
299        assert_eq!(groups.len(), 2);
300
301        // lodash should be first (Critical severity)
302        assert_eq!(groups[0].component_id, "lodash");
303        assert_eq!(groups[0].vuln_count(), 3);
304        assert_eq!(groups[0].max_severity, "Critical");
305        assert_eq!(groups[0].severity_counts.get("Critical"), Some(&1));
306        assert_eq!(groups[0].severity_counts.get("High"), Some(&2));
307
308        // express should be second
309        assert_eq!(groups[1].component_id, "express");
310        assert_eq!(groups[1].vuln_count(), 1);
311    }
312
313    #[test]
314    fn test_grouped_view() {
315        let introduced = vec![
316            make_vuln("CVE-2024-0001", "lodash", "High"),
317            make_vuln("CVE-2024-0002", "lodash", "Medium"),
318        ];
319        let resolved = vec![make_vuln("CVE-2024-0003", "old-dep", "Critical")];
320        let persistent = vec![];
321
322        let view = VulnerabilityGroupedView::from_changes(&introduced, &resolved, &persistent);
323
324        assert_eq!(view.total_groups(), 2);
325        assert_eq!(view.total_vulns(), 3);
326        assert_eq!(view.introduced_groups.len(), 1);
327        assert_eq!(view.resolved_groups.len(), 1);
328    }
329
330    #[test]
331    fn test_summary_line() {
332        let mut group = VulnerabilityGroup::new(
333            "lodash".to_string(),
334            "lodash".to_string(),
335            VulnGroupStatus::Introduced,
336        );
337        group.add_vulnerability(make_vuln("CVE-1", "lodash", "Critical"));
338        group.add_vulnerability(make_vuln("CVE-2", "lodash", "High"));
339        group.add_vulnerability(make_vuln("CVE-3", "lodash", "High"));
340
341        let summary = group.summary_line();
342        assert!(summary.contains("lodash"));
343        assert!(summary.contains("3 CVEs"));
344        assert!(summary.contains("C:1"));
345        assert!(summary.contains("H:2"));
346    }
347}