1use crate::diff::result::VulnerabilityDetail;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
12pub enum VulnGroupStatus {
13 Introduced,
15 Resolved,
17 Persistent,
19}
20
21impl std::fmt::Display for VulnGroupStatus {
22 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23 match self {
24 Self::Introduced => write!(f, "Introduced"),
25 Self::Resolved => write!(f, "Resolved"),
26 Self::Persistent => write!(f, "Persistent"),
27 }
28 }
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct VulnerabilityGroup {
34 pub component_id: String,
36 pub component_name: String,
38 pub component_version: Option<String>,
40 pub vulnerabilities: Vec<VulnerabilityDetail>,
42 pub max_severity: String,
44 pub max_cvss: Option<f32>,
46 pub severity_counts: HashMap<String, usize>,
48 pub status: VulnGroupStatus,
50 pub has_kev: bool,
52 pub has_ransomware_kev: bool,
54}
55
56impl VulnerabilityGroup {
57 #[must_use]
59 pub fn new(
60 component_id: String,
61 component_name: String,
62 status: VulnGroupStatus,
63 ) -> Self {
64 Self {
65 component_id,
66 component_name,
67 component_version: None,
68 vulnerabilities: Vec::new(),
69 max_severity: "Unknown".to_string(),
70 max_cvss: None,
71 severity_counts: HashMap::new(),
72 status,
73 has_kev: false,
74 has_ransomware_kev: false,
75 }
76 }
77
78 pub fn add_vulnerability(&mut self, vuln: VulnerabilityDetail) {
80 *self.severity_counts.entry(vuln.severity.clone()).or_insert(0) += 1;
82
83 let vuln_priority = severity_priority(&vuln.severity);
85 let current_priority = severity_priority(&self.max_severity);
86 if vuln_priority < current_priority {
87 self.max_severity.clone_from(&vuln.severity);
88 }
89
90 if let Some(score) = vuln.cvss_score {
92 self.max_cvss = Some(self.max_cvss.map_or(score, |c| c.max(score)));
93 }
94
95 if self.component_version.is_none() {
97 self.component_version.clone_from(&vuln.version);
98 }
99
100 self.vulnerabilities.push(vuln);
101 }
102
103 #[must_use]
105 pub fn vuln_count(&self) -> usize {
106 self.vulnerabilities.len()
107 }
108
109 #[must_use]
111 pub fn has_critical(&self) -> bool {
112 self.severity_counts.get("Critical").copied().unwrap_or(0) > 0
113 }
114
115 #[must_use]
117 pub fn has_high(&self) -> bool {
118 self.severity_counts.get("High").copied().unwrap_or(0) > 0
119 }
120
121 #[must_use]
123 pub fn summary_line(&self) -> String {
124 let version_str = self
125 .component_version
126 .as_ref()
127 .map(|v| format!("@{v}"))
128 .unwrap_or_default();
129
130 let severity_badges: Vec<String> = ["Critical", "High", "Medium", "Low"]
131 .iter()
132 .filter_map(|sev| {
133 self.severity_counts.get(*sev).and_then(|&count| {
134 if count > 0 {
135 Some(format!("{}:{}", &sev[..1], count))
136 } else {
137 None
138 }
139 })
140 })
141 .collect();
142
143 format!(
144 "{}{}: {} CVEs [{}]",
145 self.component_name,
146 version_str,
147 self.vuln_count(),
148 severity_badges.join(" ")
149 )
150 }
151}
152
153fn severity_priority(severity: &str) -> u8 {
155 match severity.to_lowercase().as_str() {
156 "critical" => 0,
157 "high" => 1,
158 "medium" => 2,
159 "low" => 3,
160 "info" => 4,
161 "none" => 5,
162 _ => 6,
163 }
164}
165
166#[must_use]
168pub fn group_vulnerabilities(
169 vulns: &[VulnerabilityDetail],
170 status: VulnGroupStatus,
171) -> Vec<VulnerabilityGroup> {
172 let mut groups: HashMap<String, VulnerabilityGroup> = HashMap::new();
173
174 for vuln in vulns {
175 let group = groups
176 .entry(vuln.component_id.clone())
177 .or_insert_with(|| {
178 VulnerabilityGroup::new(
179 vuln.component_id.clone(),
180 vuln.component_name.clone(),
181 status,
182 )
183 });
184
185 group.add_vulnerability(vuln.clone());
186 }
187
188 let mut result: Vec<_> = groups.into_values().collect();
190 result.sort_by(|a, b| {
191 let sev_cmp = severity_priority(&a.max_severity).cmp(&severity_priority(&b.max_severity));
192 if sev_cmp == std::cmp::Ordering::Equal {
193 b.vuln_count().cmp(&a.vuln_count())
194 } else {
195 sev_cmp
196 }
197 });
198
199 result
200}
201
202#[derive(Debug, Clone, Default, Serialize, Deserialize)]
204pub struct VulnerabilityGroupedView {
205 pub introduced_groups: Vec<VulnerabilityGroup>,
207 pub resolved_groups: Vec<VulnerabilityGroup>,
209 pub persistent_groups: Vec<VulnerabilityGroup>,
211}
212
213impl VulnerabilityGroupedView {
214 #[must_use]
216 pub fn from_changes(
217 introduced: &[VulnerabilityDetail],
218 resolved: &[VulnerabilityDetail],
219 persistent: &[VulnerabilityDetail],
220 ) -> Self {
221 Self {
222 introduced_groups: group_vulnerabilities(introduced, VulnGroupStatus::Introduced),
223 resolved_groups: group_vulnerabilities(resolved, VulnGroupStatus::Resolved),
224 persistent_groups: group_vulnerabilities(persistent, VulnGroupStatus::Persistent),
225 }
226 }
227
228 #[must_use]
230 pub fn total_groups(&self) -> usize {
231 self.introduced_groups.len() + self.resolved_groups.len() + self.persistent_groups.len()
232 }
233
234 pub fn total_vulns(&self) -> usize {
236 self.introduced_groups.iter().map(VulnerabilityGroup::vuln_count).sum::<usize>()
237 + self.resolved_groups.iter().map(VulnerabilityGroup::vuln_count).sum::<usize>()
238 + self.persistent_groups.iter().map(VulnerabilityGroup::vuln_count).sum::<usize>()
239 }
240
241 #[must_use]
243 pub fn has_any_kev(&self) -> bool {
244 self.introduced_groups.iter().any(|g| g.has_kev)
245 || self.resolved_groups.iter().any(|g| g.has_kev)
246 || self.persistent_groups.iter().any(|g| g.has_kev)
247 }
248}
249
250#[cfg(test)]
251mod tests {
252 use super::*;
253
254 fn make_vuln(id: &str, component_id: &str, severity: &str) -> VulnerabilityDetail {
255 VulnerabilityDetail {
256 id: id.to_string(),
257 source: "OSV".to_string(),
258 severity: severity.to_string(),
259 cvss_score: None,
260 component_id: component_id.to_string(),
261 component_canonical_id: None,
262 component_ref: None,
263 component_name: format!("{}-pkg", component_id),
264 version: Some("1.0.0".to_string()),
265 description: None,
266 remediation: None,
267 is_kev: false,
268 cwes: Vec::new(),
269 component_depth: None,
270 published_date: None,
271 kev_due_date: None,
272 days_since_published: None,
273 days_until_due: None,
274 vex_state: None,
275 vex_justification: None,
276 vex_impact_statement: None,
277 }
278 }
279
280 #[test]
281 fn test_group_vulnerabilities() {
282 let vulns = vec![
283 make_vuln("CVE-2024-0001", "lodash", "Critical"),
284 make_vuln("CVE-2024-0002", "lodash", "High"),
285 make_vuln("CVE-2024-0003", "lodash", "High"),
286 make_vuln("CVE-2024-0004", "express", "Medium"),
287 ];
288
289 let groups = group_vulnerabilities(&vulns, VulnGroupStatus::Introduced);
290
291 assert_eq!(groups.len(), 2);
292
293 assert_eq!(groups[0].component_id, "lodash");
295 assert_eq!(groups[0].vuln_count(), 3);
296 assert_eq!(groups[0].max_severity, "Critical");
297 assert_eq!(groups[0].severity_counts.get("Critical"), Some(&1));
298 assert_eq!(groups[0].severity_counts.get("High"), Some(&2));
299
300 assert_eq!(groups[1].component_id, "express");
302 assert_eq!(groups[1].vuln_count(), 1);
303 }
304
305 #[test]
306 fn test_grouped_view() {
307 let introduced = vec![
308 make_vuln("CVE-2024-0001", "lodash", "High"),
309 make_vuln("CVE-2024-0002", "lodash", "Medium"),
310 ];
311 let resolved = vec![make_vuln("CVE-2024-0003", "old-dep", "Critical")];
312 let persistent = vec![];
313
314 let view = VulnerabilityGroupedView::from_changes(&introduced, &resolved, &persistent);
315
316 assert_eq!(view.total_groups(), 2);
317 assert_eq!(view.total_vulns(), 3);
318 assert_eq!(view.introduced_groups.len(), 1);
319 assert_eq!(view.resolved_groups.len(), 1);
320 }
321
322 #[test]
323 fn test_summary_line() {
324 let mut group = VulnerabilityGroup::new(
325 "lodash".to_string(),
326 "lodash".to_string(),
327 VulnGroupStatus::Introduced,
328 );
329 group.add_vulnerability(make_vuln("CVE-1", "lodash", "Critical"));
330 group.add_vulnerability(make_vuln("CVE-2", "lodash", "High"));
331 group.add_vulnerability(make_vuln("CVE-3", "lodash", "High"));
332
333 let summary = group.summary_line();
334 assert!(summary.contains("lodash"));
335 assert!(summary.contains("3 CVEs"));
336 assert!(summary.contains("C:1"));
337 assert!(summary.contains("H:2"));
338 }
339}