1use crate::diff::result::VulnerabilityDetail;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
12pub enum VulnGroupStatus {
13 Introduced,
15 Resolved,
17 Persistent,
19}
20
21impl std::fmt::Display for VulnGroupStatus {
22 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23 match self {
24 VulnGroupStatus::Introduced => write!(f, "Introduced"),
25 VulnGroupStatus::Resolved => write!(f, "Resolved"),
26 VulnGroupStatus::Persistent => write!(f, "Persistent"),
27 }
28 }
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct VulnerabilityGroup {
34 pub component_id: String,
36 pub component_name: String,
38 pub component_version: Option<String>,
40 pub vulnerabilities: Vec<VulnerabilityDetail>,
42 pub max_severity: String,
44 pub max_cvss: Option<f32>,
46 pub severity_counts: HashMap<String, usize>,
48 pub status: VulnGroupStatus,
50 pub has_kev: bool,
52 pub has_ransomware_kev: bool,
54}
55
56impl VulnerabilityGroup {
57 pub fn new(
59 component_id: String,
60 component_name: String,
61 status: VulnGroupStatus,
62 ) -> Self {
63 Self {
64 component_id,
65 component_name,
66 component_version: None,
67 vulnerabilities: Vec::new(),
68 max_severity: "Unknown".to_string(),
69 max_cvss: None,
70 severity_counts: HashMap::new(),
71 status,
72 has_kev: false,
73 has_ransomware_kev: false,
74 }
75 }
76
77 pub fn add_vulnerability(&mut self, vuln: VulnerabilityDetail) {
79 *self.severity_counts.entry(vuln.severity.clone()).or_insert(0) += 1;
81
82 let vuln_priority = severity_priority(&vuln.severity);
84 let current_priority = severity_priority(&self.max_severity);
85 if vuln_priority < current_priority {
86 self.max_severity = vuln.severity.clone();
87 }
88
89 if let Some(score) = vuln.cvss_score {
91 self.max_cvss = Some(self.max_cvss.map(|c| c.max(score)).unwrap_or(score));
92 }
93
94 if self.component_version.is_none() {
96 self.component_version = vuln.version.clone();
97 }
98
99 self.vulnerabilities.push(vuln);
100 }
101
102 pub fn vuln_count(&self) -> usize {
104 self.vulnerabilities.len()
105 }
106
107 pub fn has_critical(&self) -> bool {
109 self.severity_counts.get("Critical").copied().unwrap_or(0) > 0
110 }
111
112 pub fn has_high(&self) -> bool {
114 self.severity_counts.get("High").copied().unwrap_or(0) > 0
115 }
116
117 pub fn summary_line(&self) -> String {
119 let version_str = self
120 .component_version
121 .as_ref()
122 .map(|v| format!("@{}", v))
123 .unwrap_or_default();
124
125 let severity_badges: Vec<String> = ["Critical", "High", "Medium", "Low"]
126 .iter()
127 .filter_map(|sev| {
128 self.severity_counts.get(*sev).and_then(|&count| {
129 if count > 0 {
130 Some(format!("{}:{}", &sev[..1], count))
131 } else {
132 None
133 }
134 })
135 })
136 .collect();
137
138 format!(
139 "{}{}: {} CVEs [{}]",
140 self.component_name,
141 version_str,
142 self.vuln_count(),
143 severity_badges.join(" ")
144 )
145 }
146}
147
148fn severity_priority(severity: &str) -> u8 {
150 match severity.to_lowercase().as_str() {
151 "critical" => 0,
152 "high" => 1,
153 "medium" => 2,
154 "low" => 3,
155 "info" => 4,
156 "none" => 5,
157 _ => 6,
158 }
159}
160
161pub fn group_vulnerabilities(
163 vulns: &[VulnerabilityDetail],
164 status: VulnGroupStatus,
165) -> Vec<VulnerabilityGroup> {
166 let mut groups: HashMap<String, VulnerabilityGroup> = HashMap::new();
167
168 for vuln in vulns {
169 let group = groups
170 .entry(vuln.component_id.clone())
171 .or_insert_with(|| {
172 VulnerabilityGroup::new(
173 vuln.component_id.clone(),
174 vuln.component_name.clone(),
175 status,
176 )
177 });
178
179 group.add_vulnerability(vuln.clone());
180 }
181
182 let mut result: Vec<_> = groups.into_values().collect();
184 result.sort_by(|a, b| {
185 let sev_cmp = severity_priority(&a.max_severity).cmp(&severity_priority(&b.max_severity));
186 if sev_cmp == std::cmp::Ordering::Equal {
187 b.vuln_count().cmp(&a.vuln_count())
188 } else {
189 sev_cmp
190 }
191 });
192
193 result
194}
195
196#[derive(Debug, Clone, Default, Serialize, Deserialize)]
198pub struct VulnerabilityGroupedView {
199 pub introduced_groups: Vec<VulnerabilityGroup>,
201 pub resolved_groups: Vec<VulnerabilityGroup>,
203 pub persistent_groups: Vec<VulnerabilityGroup>,
205}
206
207impl VulnerabilityGroupedView {
208 pub fn from_changes(
210 introduced: &[VulnerabilityDetail],
211 resolved: &[VulnerabilityDetail],
212 persistent: &[VulnerabilityDetail],
213 ) -> Self {
214 Self {
215 introduced_groups: group_vulnerabilities(introduced, VulnGroupStatus::Introduced),
216 resolved_groups: group_vulnerabilities(resolved, VulnGroupStatus::Resolved),
217 persistent_groups: group_vulnerabilities(persistent, VulnGroupStatus::Persistent),
218 }
219 }
220
221 pub fn total_groups(&self) -> usize {
223 self.introduced_groups.len() + self.resolved_groups.len() + self.persistent_groups.len()
224 }
225
226 pub fn total_vulns(&self) -> usize {
228 self.introduced_groups.iter().map(|g| g.vuln_count()).sum::<usize>()
229 + self.resolved_groups.iter().map(|g| g.vuln_count()).sum::<usize>()
230 + self.persistent_groups.iter().map(|g| g.vuln_count()).sum::<usize>()
231 }
232
233 pub fn has_any_kev(&self) -> bool {
235 self.introduced_groups.iter().any(|g| g.has_kev)
236 || self.resolved_groups.iter().any(|g| g.has_kev)
237 || self.persistent_groups.iter().any(|g| g.has_kev)
238 }
239}
240
241#[cfg(test)]
242mod tests {
243 use super::*;
244
245 fn make_vuln(id: &str, component_id: &str, severity: &str) -> VulnerabilityDetail {
246 VulnerabilityDetail {
247 id: id.to_string(),
248 source: "OSV".to_string(),
249 severity: severity.to_string(),
250 cvss_score: None,
251 component_id: component_id.to_string(),
252 component_canonical_id: None,
253 component_ref: None,
254 component_name: format!("{}-pkg", component_id),
255 version: Some("1.0.0".to_string()),
256 description: None,
257 remediation: None,
258 is_kev: false,
259 cwes: Vec::new(),
260 component_depth: None,
261 published_date: None,
262 kev_due_date: None,
263 days_since_published: None,
264 days_until_due: None,
265 vex_state: None,
266 }
267 }
268
269 #[test]
270 fn test_group_vulnerabilities() {
271 let vulns = vec![
272 make_vuln("CVE-2024-0001", "lodash", "Critical"),
273 make_vuln("CVE-2024-0002", "lodash", "High"),
274 make_vuln("CVE-2024-0003", "lodash", "High"),
275 make_vuln("CVE-2024-0004", "express", "Medium"),
276 ];
277
278 let groups = group_vulnerabilities(&vulns, VulnGroupStatus::Introduced);
279
280 assert_eq!(groups.len(), 2);
281
282 assert_eq!(groups[0].component_id, "lodash");
284 assert_eq!(groups[0].vuln_count(), 3);
285 assert_eq!(groups[0].max_severity, "Critical");
286 assert_eq!(groups[0].severity_counts.get("Critical"), Some(&1));
287 assert_eq!(groups[0].severity_counts.get("High"), Some(&2));
288
289 assert_eq!(groups[1].component_id, "express");
291 assert_eq!(groups[1].vuln_count(), 1);
292 }
293
294 #[test]
295 fn test_grouped_view() {
296 let introduced = vec![
297 make_vuln("CVE-2024-0001", "lodash", "High"),
298 make_vuln("CVE-2024-0002", "lodash", "Medium"),
299 ];
300 let resolved = vec![make_vuln("CVE-2024-0003", "old-dep", "Critical")];
301 let persistent = vec![];
302
303 let view = VulnerabilityGroupedView::from_changes(&introduced, &resolved, &persistent);
304
305 assert_eq!(view.total_groups(), 2);
306 assert_eq!(view.total_vulns(), 3);
307 assert_eq!(view.introduced_groups.len(), 1);
308 assert_eq!(view.resolved_groups.len(), 1);
309 }
310
311 #[test]
312 fn test_summary_line() {
313 let mut group = VulnerabilityGroup::new(
314 "lodash".to_string(),
315 "lodash".to_string(),
316 VulnGroupStatus::Introduced,
317 );
318 group.add_vulnerability(make_vuln("CVE-1", "lodash", "Critical"));
319 group.add_vulnerability(make_vuln("CVE-2", "lodash", "High"));
320 group.add_vulnerability(make_vuln("CVE-3", "lodash", "High"));
321
322 let summary = group.summary_line();
323 assert!(summary.contains("lodash"));
324 assert!(summary.contains("3 CVEs"));
325 assert!(summary.contains("C:1"));
326 assert!(summary.contains("H:2"));
327 }
328}