1use crate::diff::result::VulnerabilityDetail;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
12pub enum VulnGroupStatus {
13 Introduced,
15 Resolved,
17 Persistent,
19}
20
21impl std::fmt::Display for VulnGroupStatus {
22 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23 match self {
24 Self::Introduced => write!(f, "Introduced"),
25 Self::Resolved => write!(f, "Resolved"),
26 Self::Persistent => write!(f, "Persistent"),
27 }
28 }
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct VulnerabilityGroup {
34 pub component_id: String,
36 pub component_name: String,
38 pub component_version: Option<String>,
40 pub vulnerabilities: Vec<VulnerabilityDetail>,
42 pub max_severity: String,
44 pub max_cvss: Option<f32>,
46 pub severity_counts: HashMap<String, usize>,
48 pub status: VulnGroupStatus,
50 pub has_kev: bool,
52 pub has_ransomware_kev: bool,
54}
55
56impl VulnerabilityGroup {
57 #[must_use]
59 pub fn new(component_id: String, component_name: String, status: VulnGroupStatus) -> Self {
60 Self {
61 component_id,
62 component_name,
63 component_version: None,
64 vulnerabilities: Vec::new(),
65 max_severity: "Unknown".to_string(),
66 max_cvss: None,
67 severity_counts: HashMap::new(),
68 status,
69 has_kev: false,
70 has_ransomware_kev: false,
71 }
72 }
73
74 pub fn add_vulnerability(&mut self, vuln: VulnerabilityDetail) {
76 *self
78 .severity_counts
79 .entry(vuln.severity.clone())
80 .or_insert(0) += 1;
81
82 let vuln_priority = severity_priority(&vuln.severity);
84 let current_priority = severity_priority(&self.max_severity);
85 if vuln_priority < current_priority {
86 self.max_severity.clone_from(&vuln.severity);
87 }
88
89 if let Some(score) = vuln.cvss_score {
91 self.max_cvss = Some(self.max_cvss.map_or(score, |c| c.max(score)));
92 }
93
94 if self.component_version.is_none() {
96 self.component_version.clone_from(&vuln.version);
97 }
98
99 if vuln.is_kev {
101 self.has_kev = true;
102 }
103
104 self.vulnerabilities.push(vuln);
105 }
106
107 #[must_use]
109 pub fn vuln_count(&self) -> usize {
110 self.vulnerabilities.len()
111 }
112
113 #[must_use]
115 pub fn has_critical(&self) -> bool {
116 self.severity_counts.get("Critical").copied().unwrap_or(0) > 0
117 }
118
119 #[must_use]
121 pub fn has_high(&self) -> bool {
122 self.severity_counts.get("High").copied().unwrap_or(0) > 0
123 }
124
125 #[must_use]
127 pub fn summary_line(&self) -> String {
128 let version_str = self
129 .component_version
130 .as_ref()
131 .map(|v| format!("@{v}"))
132 .unwrap_or_default();
133
134 let severity_badges: Vec<String> = ["Critical", "High", "Medium", "Low"]
135 .iter()
136 .filter_map(|sev| {
137 self.severity_counts.get(*sev).and_then(|&count| {
138 if count > 0 {
139 Some(format!("{}:{}", &sev[..1], count))
140 } else {
141 None
142 }
143 })
144 })
145 .collect();
146
147 format!(
148 "{}{}: {} CVEs [{}]",
149 self.component_name,
150 version_str,
151 self.vuln_count(),
152 severity_badges.join(" ")
153 )
154 }
155}
156
157fn severity_priority(severity: &str) -> u8 {
159 match severity.to_lowercase().as_str() {
160 "critical" => 0,
161 "high" => 1,
162 "medium" => 2,
163 "low" => 3,
164 "info" => 4,
165 "none" => 5,
166 _ => 6,
167 }
168}
169
170#[must_use]
172pub fn group_vulnerabilities(
173 vulns: &[VulnerabilityDetail],
174 status: VulnGroupStatus,
175) -> Vec<VulnerabilityGroup> {
176 let mut groups: HashMap<String, VulnerabilityGroup> = HashMap::new();
177
178 for vuln in vulns {
179 let group = groups.entry(vuln.component_id.clone()).or_insert_with(|| {
180 VulnerabilityGroup::new(
181 vuln.component_id.clone(),
182 vuln.component_name.clone(),
183 status,
184 )
185 });
186
187 group.add_vulnerability(vuln.clone());
188 }
189
190 let mut result: Vec<_> = groups.into_values().collect();
192 result.sort_by(|a, b| {
193 let sev_cmp = severity_priority(&a.max_severity).cmp(&severity_priority(&b.max_severity));
194 if sev_cmp == std::cmp::Ordering::Equal {
195 b.vuln_count().cmp(&a.vuln_count())
196 } else {
197 sev_cmp
198 }
199 });
200
201 result
202}
203
204#[derive(Debug, Clone, Default, Serialize, Deserialize)]
206pub struct VulnerabilityGroupedView {
207 pub introduced_groups: Vec<VulnerabilityGroup>,
209 pub resolved_groups: Vec<VulnerabilityGroup>,
211 pub persistent_groups: Vec<VulnerabilityGroup>,
213}
214
215impl VulnerabilityGroupedView {
216 #[must_use]
218 pub fn from_changes(
219 introduced: &[VulnerabilityDetail],
220 resolved: &[VulnerabilityDetail],
221 persistent: &[VulnerabilityDetail],
222 ) -> Self {
223 Self {
224 introduced_groups: group_vulnerabilities(introduced, VulnGroupStatus::Introduced),
225 resolved_groups: group_vulnerabilities(resolved, VulnGroupStatus::Resolved),
226 persistent_groups: group_vulnerabilities(persistent, VulnGroupStatus::Persistent),
227 }
228 }
229
230 #[must_use]
232 pub fn total_groups(&self) -> usize {
233 self.introduced_groups.len() + self.resolved_groups.len() + self.persistent_groups.len()
234 }
235
236 pub fn total_vulns(&self) -> usize {
238 self.introduced_groups
239 .iter()
240 .map(VulnerabilityGroup::vuln_count)
241 .sum::<usize>()
242 + self
243 .resolved_groups
244 .iter()
245 .map(VulnerabilityGroup::vuln_count)
246 .sum::<usize>()
247 + self
248 .persistent_groups
249 .iter()
250 .map(VulnerabilityGroup::vuln_count)
251 .sum::<usize>()
252 }
253
254 #[must_use]
256 pub fn has_any_kev(&self) -> bool {
257 self.introduced_groups.iter().any(|g| g.has_kev)
258 || self.resolved_groups.iter().any(|g| g.has_kev)
259 || self.persistent_groups.iter().any(|g| g.has_kev)
260 }
261}
262
263#[cfg(test)]
264mod tests {
265 use super::*;
266
267 fn make_vuln(id: &str, component_id: &str, severity: &str) -> VulnerabilityDetail {
268 VulnerabilityDetail {
269 id: id.to_string(),
270 source: "OSV".to_string(),
271 severity: severity.to_string(),
272 cvss_score: None,
273 component_id: component_id.to_string(),
274 component_canonical_id: None,
275 component_ref: None,
276 component_name: format!("{}-pkg", component_id),
277 version: Some("1.0.0".to_string()),
278 description: None,
279 remediation: None,
280 is_kev: false,
281 epss_score: None,
282 cwes: Vec::new(),
283 component_depth: None,
284 published_date: None,
285 kev_due_date: None,
286 days_since_published: None,
287 days_until_due: None,
288 vex_state: None,
289 vex_justification: None,
290 vex_impact_statement: None,
291 }
292 }
293
294 #[test]
295 fn test_group_vulnerabilities() {
296 let vulns = vec![
297 make_vuln("CVE-2024-0001", "lodash", "Critical"),
298 make_vuln("CVE-2024-0002", "lodash", "High"),
299 make_vuln("CVE-2024-0003", "lodash", "High"),
300 make_vuln("CVE-2024-0004", "express", "Medium"),
301 ];
302
303 let groups = group_vulnerabilities(&vulns, VulnGroupStatus::Introduced);
304
305 assert_eq!(groups.len(), 2);
306
307 assert_eq!(groups[0].component_id, "lodash");
309 assert_eq!(groups[0].vuln_count(), 3);
310 assert_eq!(groups[0].max_severity, "Critical");
311 assert_eq!(groups[0].severity_counts.get("Critical"), Some(&1));
312 assert_eq!(groups[0].severity_counts.get("High"), Some(&2));
313
314 assert_eq!(groups[1].component_id, "express");
316 assert_eq!(groups[1].vuln_count(), 1);
317 }
318
319 #[test]
320 fn test_group_propagates_kev_flag() {
321 let mut kev_vuln = make_vuln("CVE-2021-44228", "log4j", "Critical");
322 kev_vuln.is_kev = true;
323 let vulns = vec![kev_vuln, make_vuln("CVE-2024-0009", "lodash", "High")];
324
325 let groups = group_vulnerabilities(&vulns, VulnGroupStatus::Introduced);
326
327 let log4j = groups
328 .iter()
329 .find(|g| g.component_id == "log4j")
330 .expect("log4j group present");
331 assert!(log4j.has_kev, "group with a KEV vuln must report has_kev");
332
333 let lodash = groups
334 .iter()
335 .find(|g| g.component_id == "lodash")
336 .expect("lodash group present");
337 assert!(
338 !lodash.has_kev,
339 "group without KEV vulns must not report has_kev"
340 );
341 }
342
343 #[test]
344 fn test_grouped_view() {
345 let introduced = vec![
346 make_vuln("CVE-2024-0001", "lodash", "High"),
347 make_vuln("CVE-2024-0002", "lodash", "Medium"),
348 ];
349 let resolved = vec![make_vuln("CVE-2024-0003", "old-dep", "Critical")];
350 let persistent = vec![];
351
352 let view = VulnerabilityGroupedView::from_changes(&introduced, &resolved, &persistent);
353
354 assert_eq!(view.total_groups(), 2);
355 assert_eq!(view.total_vulns(), 3);
356 assert_eq!(view.introduced_groups.len(), 1);
357 assert_eq!(view.resolved_groups.len(), 1);
358 }
359
360 #[test]
361 fn test_summary_line() {
362 let mut group = VulnerabilityGroup::new(
363 "lodash".to_string(),
364 "lodash".to_string(),
365 VulnGroupStatus::Introduced,
366 );
367 group.add_vulnerability(make_vuln("CVE-1", "lodash", "Critical"));
368 group.add_vulnerability(make_vuln("CVE-2", "lodash", "High"));
369 group.add_vulnerability(make_vuln("CVE-3", "lodash", "High"));
370
371 let summary = group.summary_line();
372 assert!(summary.contains("lodash"));
373 assert!(summary.contains("3 CVEs"));
374 assert!(summary.contains("C:1"));
375 assert!(summary.contains("H:2"));
376 }
377}