1use crate::diff::result::VulnerabilityDetail;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
12pub enum VulnGroupStatus {
13 Introduced,
15 Resolved,
17 Persistent,
19}
20
21impl std::fmt::Display for VulnGroupStatus {
22 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23 match self {
24 Self::Introduced => write!(f, "Introduced"),
25 Self::Resolved => write!(f, "Resolved"),
26 Self::Persistent => write!(f, "Persistent"),
27 }
28 }
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct VulnerabilityGroup {
34 pub component_id: String,
36 pub component_name: String,
38 pub component_version: Option<String>,
40 pub vulnerabilities: Vec<VulnerabilityDetail>,
42 pub max_severity: String,
44 pub max_cvss: Option<f32>,
46 pub severity_counts: HashMap<String, usize>,
48 pub status: VulnGroupStatus,
50 pub has_kev: bool,
52 pub has_ransomware_kev: bool,
54}
55
56impl VulnerabilityGroup {
57 #[must_use]
59 pub fn new(component_id: String, component_name: String, status: VulnGroupStatus) -> Self {
60 Self {
61 component_id,
62 component_name,
63 component_version: None,
64 vulnerabilities: Vec::new(),
65 max_severity: "Unknown".to_string(),
66 max_cvss: None,
67 severity_counts: HashMap::new(),
68 status,
69 has_kev: false,
70 has_ransomware_kev: false,
71 }
72 }
73
74 pub fn add_vulnerability(&mut self, vuln: VulnerabilityDetail) {
76 *self
78 .severity_counts
79 .entry(vuln.severity.clone())
80 .or_insert(0) += 1;
81
82 let vuln_priority = severity_priority(&vuln.severity);
84 let current_priority = severity_priority(&self.max_severity);
85 if vuln_priority < current_priority {
86 self.max_severity.clone_from(&vuln.severity);
87 }
88
89 if let Some(score) = vuln.cvss_score {
91 self.max_cvss = Some(self.max_cvss.map_or(score, |c| c.max(score)));
92 }
93
94 if self.component_version.is_none() {
96 self.component_version.clone_from(&vuln.version);
97 }
98
99 self.vulnerabilities.push(vuln);
100 }
101
102 #[must_use]
104 pub fn vuln_count(&self) -> usize {
105 self.vulnerabilities.len()
106 }
107
108 #[must_use]
110 pub fn has_critical(&self) -> bool {
111 self.severity_counts.get("Critical").copied().unwrap_or(0) > 0
112 }
113
114 #[must_use]
116 pub fn has_high(&self) -> bool {
117 self.severity_counts.get("High").copied().unwrap_or(0) > 0
118 }
119
120 #[must_use]
122 pub fn summary_line(&self) -> String {
123 let version_str = self
124 .component_version
125 .as_ref()
126 .map(|v| format!("@{v}"))
127 .unwrap_or_default();
128
129 let severity_badges: Vec<String> = ["Critical", "High", "Medium", "Low"]
130 .iter()
131 .filter_map(|sev| {
132 self.severity_counts.get(*sev).and_then(|&count| {
133 if count > 0 {
134 Some(format!("{}:{}", &sev[..1], count))
135 } else {
136 None
137 }
138 })
139 })
140 .collect();
141
142 format!(
143 "{}{}: {} CVEs [{}]",
144 self.component_name,
145 version_str,
146 self.vuln_count(),
147 severity_badges.join(" ")
148 )
149 }
150}
151
152fn severity_priority(severity: &str) -> u8 {
154 match severity.to_lowercase().as_str() {
155 "critical" => 0,
156 "high" => 1,
157 "medium" => 2,
158 "low" => 3,
159 "info" => 4,
160 "none" => 5,
161 _ => 6,
162 }
163}
164
165#[must_use]
167pub fn group_vulnerabilities(
168 vulns: &[VulnerabilityDetail],
169 status: VulnGroupStatus,
170) -> Vec<VulnerabilityGroup> {
171 let mut groups: HashMap<String, VulnerabilityGroup> = HashMap::new();
172
173 for vuln in vulns {
174 let group = groups.entry(vuln.component_id.clone()).or_insert_with(|| {
175 VulnerabilityGroup::new(
176 vuln.component_id.clone(),
177 vuln.component_name.clone(),
178 status,
179 )
180 });
181
182 group.add_vulnerability(vuln.clone());
183 }
184
185 let mut result: Vec<_> = groups.into_values().collect();
187 result.sort_by(|a, b| {
188 let sev_cmp = severity_priority(&a.max_severity).cmp(&severity_priority(&b.max_severity));
189 if sev_cmp == std::cmp::Ordering::Equal {
190 b.vuln_count().cmp(&a.vuln_count())
191 } else {
192 sev_cmp
193 }
194 });
195
196 result
197}
198
199#[derive(Debug, Clone, Default, Serialize, Deserialize)]
201pub struct VulnerabilityGroupedView {
202 pub introduced_groups: Vec<VulnerabilityGroup>,
204 pub resolved_groups: Vec<VulnerabilityGroup>,
206 pub persistent_groups: Vec<VulnerabilityGroup>,
208}
209
210impl VulnerabilityGroupedView {
211 #[must_use]
213 pub fn from_changes(
214 introduced: &[VulnerabilityDetail],
215 resolved: &[VulnerabilityDetail],
216 persistent: &[VulnerabilityDetail],
217 ) -> Self {
218 Self {
219 introduced_groups: group_vulnerabilities(introduced, VulnGroupStatus::Introduced),
220 resolved_groups: group_vulnerabilities(resolved, VulnGroupStatus::Resolved),
221 persistent_groups: group_vulnerabilities(persistent, VulnGroupStatus::Persistent),
222 }
223 }
224
225 #[must_use]
227 pub fn total_groups(&self) -> usize {
228 self.introduced_groups.len() + self.resolved_groups.len() + self.persistent_groups.len()
229 }
230
231 pub fn total_vulns(&self) -> usize {
233 self.introduced_groups
234 .iter()
235 .map(VulnerabilityGroup::vuln_count)
236 .sum::<usize>()
237 + self
238 .resolved_groups
239 .iter()
240 .map(VulnerabilityGroup::vuln_count)
241 .sum::<usize>()
242 + self
243 .persistent_groups
244 .iter()
245 .map(VulnerabilityGroup::vuln_count)
246 .sum::<usize>()
247 }
248
249 #[must_use]
251 pub fn has_any_kev(&self) -> bool {
252 self.introduced_groups.iter().any(|g| g.has_kev)
253 || self.resolved_groups.iter().any(|g| g.has_kev)
254 || self.persistent_groups.iter().any(|g| g.has_kev)
255 }
256}
257
258#[cfg(test)]
259mod tests {
260 use super::*;
261
262 fn make_vuln(id: &str, component_id: &str, severity: &str) -> VulnerabilityDetail {
263 VulnerabilityDetail {
264 id: id.to_string(),
265 source: "OSV".to_string(),
266 severity: severity.to_string(),
267 cvss_score: None,
268 component_id: component_id.to_string(),
269 component_canonical_id: None,
270 component_ref: None,
271 component_name: format!("{}-pkg", component_id),
272 version: Some("1.0.0".to_string()),
273 description: None,
274 remediation: None,
275 is_kev: false,
276 cwes: Vec::new(),
277 component_depth: None,
278 published_date: None,
279 kev_due_date: None,
280 days_since_published: None,
281 days_until_due: None,
282 vex_state: None,
283 vex_justification: None,
284 vex_impact_statement: None,
285 }
286 }
287
288 #[test]
289 fn test_group_vulnerabilities() {
290 let vulns = vec![
291 make_vuln("CVE-2024-0001", "lodash", "Critical"),
292 make_vuln("CVE-2024-0002", "lodash", "High"),
293 make_vuln("CVE-2024-0003", "lodash", "High"),
294 make_vuln("CVE-2024-0004", "express", "Medium"),
295 ];
296
297 let groups = group_vulnerabilities(&vulns, VulnGroupStatus::Introduced);
298
299 assert_eq!(groups.len(), 2);
300
301 assert_eq!(groups[0].component_id, "lodash");
303 assert_eq!(groups[0].vuln_count(), 3);
304 assert_eq!(groups[0].max_severity, "Critical");
305 assert_eq!(groups[0].severity_counts.get("Critical"), Some(&1));
306 assert_eq!(groups[0].severity_counts.get("High"), Some(&2));
307
308 assert_eq!(groups[1].component_id, "express");
310 assert_eq!(groups[1].vuln_count(), 1);
311 }
312
313 #[test]
314 fn test_grouped_view() {
315 let introduced = vec![
316 make_vuln("CVE-2024-0001", "lodash", "High"),
317 make_vuln("CVE-2024-0002", "lodash", "Medium"),
318 ];
319 let resolved = vec![make_vuln("CVE-2024-0003", "old-dep", "Critical")];
320 let persistent = vec![];
321
322 let view = VulnerabilityGroupedView::from_changes(&introduced, &resolved, &persistent);
323
324 assert_eq!(view.total_groups(), 2);
325 assert_eq!(view.total_vulns(), 3);
326 assert_eq!(view.introduced_groups.len(), 1);
327 assert_eq!(view.resolved_groups.len(), 1);
328 }
329
330 #[test]
331 fn test_summary_line() {
332 let mut group = VulnerabilityGroup::new(
333 "lodash".to_string(),
334 "lodash".to_string(),
335 VulnGroupStatus::Introduced,
336 );
337 group.add_vulnerability(make_vuln("CVE-1", "lodash", "Critical"));
338 group.add_vulnerability(make_vuln("CVE-2", "lodash", "High"));
339 group.add_vulnerability(make_vuln("CVE-3", "lodash", "High"));
340
341 let summary = group.summary_line();
342 assert!(summary.contains("lodash"));
343 assert!(summary.contains("3 CVEs"));
344 assert!(summary.contains("C:1"));
345 assert!(summary.contains("H:2"));
346 }
347}