Skip to main content

sbom_tools/diff/changes/
vuln_grouping.rs

1//! Vulnerability grouping by root cause component.
2//!
3//! This module provides functionality to group vulnerabilities by the component
4//! that introduces them, reducing noise and showing the true scope of security issues.
5
6use crate::diff::result::VulnerabilityDetail;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10/// Status of a vulnerability group
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
12pub enum VulnGroupStatus {
13    /// Newly introduced vulnerabilities
14    Introduced,
15    /// Resolved vulnerabilities
16    Resolved,
17    /// Persistent vulnerabilities (present in both old and new)
18    Persistent,
19}
20
21impl std::fmt::Display for VulnGroupStatus {
22    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23        match self {
24            Self::Introduced => write!(f, "Introduced"),
25            Self::Resolved => write!(f, "Resolved"),
26            Self::Persistent => write!(f, "Persistent"),
27        }
28    }
29}
30
31/// A group of vulnerabilities sharing the same root cause component
32#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct VulnerabilityGroup {
34    /// Root cause component ID
35    pub component_id: String,
36    /// Component name
37    pub component_name: String,
38    /// Component version (if available)
39    pub component_version: Option<String>,
40    /// Vulnerabilities in this group
41    pub vulnerabilities: Vec<VulnerabilityDetail>,
42    /// Maximum severity in the group
43    pub max_severity: String,
44    /// Maximum CVSS score in the group
45    pub max_cvss: Option<f32>,
46    /// Count by severity level
47    pub severity_counts: HashMap<String, usize>,
48    /// Group status (Introduced, Resolved, Persistent)
49    pub status: VulnGroupStatus,
50    /// Whether any vulnerability is in KEV catalog
51    pub has_kev: bool,
52    /// Whether any vulnerability is ransomware-related
53    pub has_ransomware_kev: bool,
54}
55
56impl VulnerabilityGroup {
57    /// Create a new empty group for a component
58    #[must_use]
59    pub fn new(component_id: String, component_name: String, status: VulnGroupStatus) -> Self {
60        Self {
61            component_id,
62            component_name,
63            component_version: None,
64            vulnerabilities: Vec::new(),
65            max_severity: "Unknown".to_string(),
66            max_cvss: None,
67            severity_counts: HashMap::new(),
68            status,
69            has_kev: false,
70            has_ransomware_kev: false,
71        }
72    }
73
74    /// Add a vulnerability to the group
75    pub fn add_vulnerability(&mut self, vuln: VulnerabilityDetail) {
76        // Update severity counts
77        *self
78            .severity_counts
79            .entry(vuln.severity.clone())
80            .or_insert(0) += 1;
81
82        // Update max severity (priority: Critical > High > Medium > Low > Unknown)
83        let vuln_priority = severity_priority(&vuln.severity);
84        let current_priority = severity_priority(&self.max_severity);
85        if vuln_priority < current_priority {
86            self.max_severity.clone_from(&vuln.severity);
87        }
88
89        // Update max CVSS
90        if let Some(score) = vuln.cvss_score {
91            self.max_cvss = Some(self.max_cvss.map_or(score, |c| c.max(score)));
92        }
93
94        // Update version from first vulnerability with version
95        if self.component_version.is_none() {
96            self.component_version.clone_from(&vuln.version);
97        }
98
99        // Propagate the KEV (actively exploited) flag to the group.
100        if vuln.is_kev {
101            self.has_kev = true;
102        }
103
104        self.vulnerabilities.push(vuln);
105    }
106
107    /// Get total vulnerability count
108    #[must_use]
109    pub fn vuln_count(&self) -> usize {
110        self.vulnerabilities.len()
111    }
112
113    /// Check if group has any critical vulnerabilities
114    #[must_use]
115    pub fn has_critical(&self) -> bool {
116        self.severity_counts.get("Critical").copied().unwrap_or(0) > 0
117    }
118
119    /// Check if group has any high severity vulnerabilities
120    #[must_use]
121    pub fn has_high(&self) -> bool {
122        self.severity_counts.get("High").copied().unwrap_or(0) > 0
123    }
124
125    /// Get summary line for display
126    #[must_use]
127    pub fn summary_line(&self) -> String {
128        let version_str = self
129            .component_version
130            .as_ref()
131            .map(|v| format!("@{v}"))
132            .unwrap_or_default();
133
134        let severity_badges: Vec<String> = ["Critical", "High", "Medium", "Low"]
135            .iter()
136            .filter_map(|sev| {
137                self.severity_counts.get(*sev).and_then(|&count| {
138                    if count > 0 {
139                        Some(format!("{}:{}", &sev[..1], count))
140                    } else {
141                        None
142                    }
143                })
144            })
145            .collect();
146
147        format!(
148            "{}{}: {} CVEs [{}]",
149            self.component_name,
150            version_str,
151            self.vuln_count(),
152            severity_badges.join(" ")
153        )
154    }
155}
156
157/// Get priority value for severity (lower = more severe)
158fn severity_priority(severity: &str) -> u8 {
159    match severity.to_lowercase().as_str() {
160        "critical" => 0,
161        "high" => 1,
162        "medium" => 2,
163        "low" => 3,
164        "info" => 4,
165        "none" => 5,
166        _ => 6,
167    }
168}
169
170/// Group vulnerabilities by component
171#[must_use]
172pub fn group_vulnerabilities(
173    vulns: &[VulnerabilityDetail],
174    status: VulnGroupStatus,
175) -> Vec<VulnerabilityGroup> {
176    let mut groups: HashMap<String, VulnerabilityGroup> = HashMap::new();
177
178    for vuln in vulns {
179        let group = groups.entry(vuln.component_id.clone()).or_insert_with(|| {
180            VulnerabilityGroup::new(
181                vuln.component_id.clone(),
182                vuln.component_name.clone(),
183                status,
184            )
185        });
186
187        group.add_vulnerability(vuln.clone());
188    }
189
190    // Sort groups by severity (most severe first), then by count
191    let mut result: Vec<_> = groups.into_values().collect();
192    result.sort_by(|a, b| {
193        let sev_cmp = severity_priority(&a.max_severity).cmp(&severity_priority(&b.max_severity));
194        if sev_cmp == std::cmp::Ordering::Equal {
195            b.vuln_count().cmp(&a.vuln_count())
196        } else {
197            sev_cmp
198        }
199    });
200
201    result
202}
203
204/// Grouped view of vulnerability changes
205#[derive(Debug, Clone, Default, Serialize, Deserialize)]
206pub struct VulnerabilityGroupedView {
207    /// Groups of introduced vulnerabilities
208    pub introduced_groups: Vec<VulnerabilityGroup>,
209    /// Groups of resolved vulnerabilities
210    pub resolved_groups: Vec<VulnerabilityGroup>,
211    /// Groups of persistent vulnerabilities
212    pub persistent_groups: Vec<VulnerabilityGroup>,
213}
214
215impl VulnerabilityGroupedView {
216    /// Create grouped view from vulnerability lists
217    #[must_use]
218    pub fn from_changes(
219        introduced: &[VulnerabilityDetail],
220        resolved: &[VulnerabilityDetail],
221        persistent: &[VulnerabilityDetail],
222    ) -> Self {
223        Self {
224            introduced_groups: group_vulnerabilities(introduced, VulnGroupStatus::Introduced),
225            resolved_groups: group_vulnerabilities(resolved, VulnGroupStatus::Resolved),
226            persistent_groups: group_vulnerabilities(persistent, VulnGroupStatus::Persistent),
227        }
228    }
229
230    /// Get total group count
231    #[must_use]
232    pub fn total_groups(&self) -> usize {
233        self.introduced_groups.len() + self.resolved_groups.len() + self.persistent_groups.len()
234    }
235
236    /// Get total vulnerability count across all groups
237    pub fn total_vulns(&self) -> usize {
238        self.introduced_groups
239            .iter()
240            .map(VulnerabilityGroup::vuln_count)
241            .sum::<usize>()
242            + self
243                .resolved_groups
244                .iter()
245                .map(VulnerabilityGroup::vuln_count)
246                .sum::<usize>()
247            + self
248                .persistent_groups
249                .iter()
250                .map(VulnerabilityGroup::vuln_count)
251                .sum::<usize>()
252    }
253
254    /// Check if any group has KEV vulnerabilities
255    #[must_use]
256    pub fn has_any_kev(&self) -> bool {
257        self.introduced_groups.iter().any(|g| g.has_kev)
258            || self.resolved_groups.iter().any(|g| g.has_kev)
259            || self.persistent_groups.iter().any(|g| g.has_kev)
260    }
261}
262
263#[cfg(test)]
264mod tests {
265    use super::*;
266
267    fn make_vuln(id: &str, component_id: &str, severity: &str) -> VulnerabilityDetail {
268        VulnerabilityDetail {
269            id: id.to_string(),
270            source: "OSV".to_string(),
271            severity: severity.to_string(),
272            cvss_score: None,
273            component_id: component_id.to_string(),
274            component_canonical_id: None,
275            component_ref: None,
276            component_name: format!("{}-pkg", component_id),
277            version: Some("1.0.0".to_string()),
278            description: None,
279            remediation: None,
280            is_kev: false,
281            epss_score: None,
282            cwes: Vec::new(),
283            component_depth: None,
284            published_date: None,
285            kev_due_date: None,
286            days_since_published: None,
287            days_until_due: None,
288            vex_state: None,
289            vex_justification: None,
290            vex_impact_statement: None,
291        }
292    }
293
294    #[test]
295    fn test_group_vulnerabilities() {
296        let vulns = vec![
297            make_vuln("CVE-2024-0001", "lodash", "Critical"),
298            make_vuln("CVE-2024-0002", "lodash", "High"),
299            make_vuln("CVE-2024-0003", "lodash", "High"),
300            make_vuln("CVE-2024-0004", "express", "Medium"),
301        ];
302
303        let groups = group_vulnerabilities(&vulns, VulnGroupStatus::Introduced);
304
305        assert_eq!(groups.len(), 2);
306
307        // lodash should be first (Critical severity)
308        assert_eq!(groups[0].component_id, "lodash");
309        assert_eq!(groups[0].vuln_count(), 3);
310        assert_eq!(groups[0].max_severity, "Critical");
311        assert_eq!(groups[0].severity_counts.get("Critical"), Some(&1));
312        assert_eq!(groups[0].severity_counts.get("High"), Some(&2));
313
314        // express should be second
315        assert_eq!(groups[1].component_id, "express");
316        assert_eq!(groups[1].vuln_count(), 1);
317    }
318
319    #[test]
320    fn test_group_propagates_kev_flag() {
321        let mut kev_vuln = make_vuln("CVE-2021-44228", "log4j", "Critical");
322        kev_vuln.is_kev = true;
323        let vulns = vec![kev_vuln, make_vuln("CVE-2024-0009", "lodash", "High")];
324
325        let groups = group_vulnerabilities(&vulns, VulnGroupStatus::Introduced);
326
327        let log4j = groups
328            .iter()
329            .find(|g| g.component_id == "log4j")
330            .expect("log4j group present");
331        assert!(log4j.has_kev, "group with a KEV vuln must report has_kev");
332
333        let lodash = groups
334            .iter()
335            .find(|g| g.component_id == "lodash")
336            .expect("lodash group present");
337        assert!(
338            !lodash.has_kev,
339            "group without KEV vulns must not report has_kev"
340        );
341    }
342
343    #[test]
344    fn test_grouped_view() {
345        let introduced = vec![
346            make_vuln("CVE-2024-0001", "lodash", "High"),
347            make_vuln("CVE-2024-0002", "lodash", "Medium"),
348        ];
349        let resolved = vec![make_vuln("CVE-2024-0003", "old-dep", "Critical")];
350        let persistent = vec![];
351
352        let view = VulnerabilityGroupedView::from_changes(&introduced, &resolved, &persistent);
353
354        assert_eq!(view.total_groups(), 2);
355        assert_eq!(view.total_vulns(), 3);
356        assert_eq!(view.introduced_groups.len(), 1);
357        assert_eq!(view.resolved_groups.len(), 1);
358    }
359
360    #[test]
361    fn test_summary_line() {
362        let mut group = VulnerabilityGroup::new(
363            "lodash".to_string(),
364            "lodash".to_string(),
365            VulnGroupStatus::Introduced,
366        );
367        group.add_vulnerability(make_vuln("CVE-1", "lodash", "Critical"));
368        group.add_vulnerability(make_vuln("CVE-2", "lodash", "High"));
369        group.add_vulnerability(make_vuln("CVE-3", "lodash", "High"));
370
371        let summary = group.summary_line();
372        assert!(summary.contains("lodash"));
373        assert!(summary.contains("3 CVEs"));
374        assert!(summary.contains("C:1"));
375        assert!(summary.contains("H:2"));
376    }
377}