1use crate::error::{AdvisoryError, Result};
13use chrono::{DateTime, NaiveDate, Utc};
14use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
15use reqwest_retry::{RetryTransientMiddleware, policies::ExponentialBackoff};
16use serde::{Deserialize, Serialize};
17use std::collections::HashMap;
18use std::time::Duration;
19use tracing::{debug, info};
20
/// Base endpoint of the FIRST.org EPSS API; query parameters are appended per request.
pub const EPSS_API_URL: &str = "https://api.first.org/data/v1/epss";

/// Overall per-request timeout handed to reqwest's `ClientBuilder::timeout` (30 s).
const REQUEST_TIMEOUT: Duration = Duration::new(30, 0);

/// Connect-phase timeout handed to reqwest's `ClientBuilder::connect_timeout` (10 s).
const CONNECT_TIMEOUT: Duration = Duration::new(10, 0);
28
29pub struct EpssSource {
34 client: ClientWithMiddleware,
35}
36
37impl EpssSource {
38 pub fn new() -> Self {
40 let raw_client = reqwest::Client::builder()
41 .timeout(REQUEST_TIMEOUT)
42 .connect_timeout(CONNECT_TIMEOUT)
43 .build()
44 .unwrap_or_default();
45
46 let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
48 let client = ClientBuilder::new(raw_client)
49 .with(RetryTransientMiddleware::new_with_policy(retry_policy))
50 .build();
51
52 Self { client }
53 }
54
55 pub async fn fetch_scores(&self, cve_ids: &[&str]) -> Result<HashMap<String, EpssScore>> {
65 if cve_ids.is_empty() {
66 return Ok(HashMap::new());
67 }
68
69 let cve_param = cve_ids.join(",");
71 let url = format!("{}?cve={}", EPSS_API_URL, cve_param);
72
73 debug!("Fetching EPSS scores for {} CVEs", cve_ids.len());
74
75 let response = self.client.get(&url).send().await?;
76
77 if !response.status().is_success() {
78 return Err(AdvisoryError::source_fetch(
79 "EPSS",
80 format!("HTTP {}", response.status()),
81 ));
82 }
83
84 let epss_response: EpssResponse = response.json().await?;
85
86 let scores: HashMap<String, EpssScore> = epss_response
87 .data
88 .into_iter()
89 .map(|s| (s.cve.clone(), s))
90 .collect();
91
92 debug!("Retrieved {} EPSS scores", scores.len());
93 Ok(scores)
94 }
95
96 pub async fn fetch_score(&self, cve_id: &str) -> Result<Option<EpssScore>> {
98 let scores = self.fetch_scores(&[cve_id]).await?;
99 Ok(scores.get(cve_id).cloned())
100 }
101
102 pub async fn fetch_high_risk(
109 &self,
110 min_epss: f64,
111 limit: Option<u32>,
112 ) -> Result<Vec<EpssScore>> {
113 let limit = limit.unwrap_or(100);
114 let url = format!("{}?epss-gt={}&limit={}", EPSS_API_URL, min_epss, limit);
115
116 info!("Fetching CVEs with EPSS > {}", min_epss);
117
118 let response = self.client.get(&url).send().await?;
119
120 if !response.status().is_success() {
121 return Err(AdvisoryError::source_fetch(
122 "EPSS",
123 format!("HTTP {}", response.status()),
124 ));
125 }
126
127 let epss_response: EpssResponse = response.json().await?;
128 info!("Found {} high-risk CVEs", epss_response.data.len());
129
130 Ok(epss_response.data)
131 }
132
133 pub async fn fetch_top_percentile(
140 &self,
141 min_percentile: f64,
142 limit: Option<u32>,
143 ) -> Result<Vec<EpssScore>> {
144 let limit = limit.unwrap_or(100);
145 let url = format!(
146 "{}?percentile-gt={}&limit={}",
147 EPSS_API_URL, min_percentile, limit
148 );
149
150 info!(
151 "Fetching CVEs in top {} percentile",
152 (1.0 - min_percentile) * 100.0
153 );
154
155 let response = self.client.get(&url).send().await?;
156
157 if !response.status().is_success() {
158 return Err(AdvisoryError::source_fetch(
159 "EPSS",
160 format!("HTTP {}", response.status()),
161 ));
162 }
163
164 let epss_response: EpssResponse = response.json().await?;
165 Ok(epss_response.data)
166 }
167
168 pub async fn fetch_scores_batch(
173 &self,
174 cve_ids: &[String],
175 batch_size: usize,
176 ) -> Result<HashMap<String, EpssScore>> {
177 let mut all_scores = HashMap::new();
178
179 for chunk in cve_ids.chunks(batch_size) {
180 let refs: Vec<&str> = chunk.iter().map(|s| s.as_str()).collect();
181 let scores = self.fetch_scores(&refs).await?;
182 all_scores.extend(scores);
183 }
184
185 Ok(all_scores)
186 }
187}
188
189impl Default for EpssSource {
190 fn default() -> Self {
191 Self::new()
192 }
193}
194
/// Top-level JSON envelope returned by the EPSS API.
///
/// Only `status` and `data` are required when deserializing; the other
/// metadata fields are `Option`s, so responses that omit them still parse.
#[derive(Debug, Clone, Deserialize)]
pub struct EpssResponse {
    /// Status string reported by the API (exact values not checked here — TODO confirm).
    pub status: String,
    /// Numeric status code, read from the JSON key `status-code`.
    #[serde(rename = "status-code")]
    pub status_code: Option<i32>,
    /// API version string, if reported.
    pub version: Option<String>,
    /// Pagination metadata — presumably the total number of matching rows; verify
    /// against the FIRST API docs.
    pub total: Option<u64>,
    /// Pagination metadata — presumably the index of the first returned row.
    pub offset: Option<u64>,
    /// Pagination metadata — presumably the page size the server applied.
    pub limit: Option<u64>,
    /// Score rows returned by the request.
    pub data: Vec<EpssScore>,
}
214
/// EPSS data for a single CVE.
///
/// When deserializing, `epss` and `percentile` go through
/// `deserialize_f64_from_string` (the API supplies them as string values);
/// the derived `Serialize` writes them as plain numbers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EpssScore {
    /// CVE identifier; `EpssSource::fetch_scores` uses it as the map key.
    pub cve: String,
    /// EPSS score; `risk_category` buckets it at 0.1 / 0.4 / 0.7.
    /// Presumably a probability in [0, 1] — confirm against the API docs.
    #[serde(deserialize_with = "deserialize_f64_from_string")]
    pub epss: f64,
    /// Percentile rank, compared directly with the threshold in
    /// `is_top_percentile`; presumably a fraction in [0, 1] (the tests use e.g.
    /// 0.98) — confirm.
    #[serde(deserialize_with = "deserialize_f64_from_string")]
    pub percentile: f64,
    /// Score date; `None` when the JSON omits it. `date_utc` parses it as `%Y-%m-%d`.
    #[serde(default)]
    pub date: Option<String>,
}
232
233impl EpssScore {
234 pub fn is_top_percentile(&self, threshold: f64) -> bool {
236 self.percentile >= threshold
237 }
238
239 pub fn risk_category(&self) -> EpssRiskCategory {
241 match self.epss {
242 s if s >= 0.7 => EpssRiskCategory::Critical,
243 s if s >= 0.4 => EpssRiskCategory::High,
244 s if s >= 0.1 => EpssRiskCategory::Medium,
245 _ => EpssRiskCategory::Low,
246 }
247 }
248
249 pub fn date_utc(&self) -> Option<DateTime<Utc>> {
251 self.date.as_ref().and_then(|d| {
252 NaiveDate::parse_from_str(d, "%Y-%m-%d")
253 .ok()
254 .map(|nd| nd.and_hms_opt(0, 0, 0).unwrap().and_utc())
255 })
256 }
257}
258
/// Coarse risk bucket derived from an EPSS score by `EpssScore::risk_category`.
///
/// Variants are declared from least to most severe, so the derived `Ord`
/// orders them by severity (`Low < Medium < High < Critical`). Callers can
/// therefore filter with comparisons such as `category >= EpssRiskCategory::High`,
/// or use the category as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum EpssRiskCategory {
    /// EPSS below 0.1 (or NaN).
    Low,
    /// EPSS in [0.1, 0.4).
    Medium,
    /// EPSS in [0.4, 0.7).
    High,
    /// EPSS of 0.7 or more.
    Critical,
}
271
272fn deserialize_f64_from_string<'de, D>(deserializer: D) -> std::result::Result<f64, D::Error>
274where
275 D: serde::Deserializer<'de>,
276{
277 let s: String = String::deserialize(deserializer)?;
278 s.parse().map_err(serde::de::Error::custom)
279}
280
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a score with a fixed CVE ID for threshold/date tests.
    fn score(epss: f64, percentile: f64, date: Option<&str>) -> EpssScore {
        EpssScore {
            cve: "CVE-2024-0001".to_string(),
            epss,
            percentile,
            date: date.map(str::to_string),
        }
    }

    #[test]
    fn test_epss_risk_category() {
        let score = EpssScore {
            cve: "CVE-2024-1234".to_string(),
            epss: 0.75,
            percentile: 0.98,
            date: None,
        };

        assert_eq!(score.risk_category(), EpssRiskCategory::Critical);
        assert!(score.is_top_percentile(0.95));
    }

    #[test]
    fn test_epss_low_risk() {
        let score = EpssScore {
            cve: "CVE-2024-5678".to_string(),
            epss: 0.05,
            percentile: 0.3,
            date: None,
        };

        assert_eq!(score.risk_category(), EpssRiskCategory::Low);
        assert!(!score.is_top_percentile(0.95));
    }

    #[test]
    fn test_risk_category_thresholds_are_inclusive() {
        assert_eq!(score(0.7, 0.0, None).risk_category(), EpssRiskCategory::Critical);
        assert_eq!(score(0.69, 0.0, None).risk_category(), EpssRiskCategory::High);
        assert_eq!(score(0.4, 0.0, None).risk_category(), EpssRiskCategory::High);
        assert_eq!(score(0.39, 0.0, None).risk_category(), EpssRiskCategory::Medium);
        assert_eq!(score(0.1, 0.0, None).risk_category(), EpssRiskCategory::Medium);
        assert_eq!(score(0.09, 0.0, None).risk_category(), EpssRiskCategory::Low);
        assert_eq!(score(0.0, 0.0, None).risk_category(), EpssRiskCategory::Low);
    }

    #[test]
    fn test_risk_category_nan_is_low() {
        assert_eq!(score(f64::NAN, 0.0, None).risk_category(), EpssRiskCategory::Low);
    }

    #[test]
    fn test_is_top_percentile_boundaries() {
        assert!(score(0.0, 0.95, None).is_top_percentile(0.95));
        assert!(!score(0.0, 0.9499, None).is_top_percentile(0.95));
        assert!(!score(0.0, f64::NAN, None).is_top_percentile(0.95));
    }

    #[test]
    fn test_date_utc_parses_iso_date_as_midnight_utc() {
        let dt = score(0.0, 0.0, Some("2024-03-01"))
            .date_utc()
            .expect("valid ISO date should parse");
        // 2024-03-01T00:00:00Z
        assert_eq!(dt.timestamp(), 1_709_251_200);
    }

    #[test]
    fn test_date_utc_missing_or_malformed() {
        assert!(score(0.0, 0.0, None).date_utc().is_none());
        assert!(score(0.0, 0.0, Some("03/01/2024")).date_utc().is_none());
        assert!(score(0.0, 0.0, Some("")).date_utc().is_none());
    }
}