vulnera_advisor/sources/
epss.rs

//! FIRST EPSS (Exploit Prediction Scoring System) source.
//!
//! This module fetches EPSS scores, which predict the probability that a
//! vulnerability will be exploited in the next 30 days.
//!
//! # Data Source
//!
//! - API: <https://api.first.org/data/v1/epss>
//! - Documentation: <https://www.first.org/epss/api>
//! - License: Free to use
11
12use crate::error::{AdvisoryError, Result};
13use chrono::{DateTime, NaiveDate, Utc};
14use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
15use reqwest_retry::{RetryTransientMiddleware, policies::ExponentialBackoff};
16use serde::{Deserialize, Serialize};
17use std::collections::HashMap;
18use std::time::Duration;
19use tracing::{debug, info};
20
/// Endpoint of the FIRST EPSS REST API.
///
/// Every request issued by this module targets this URL; the individual
/// lookups differ only in the query parameters they append.
pub const EPSS_API_URL: &str = "https://api.first.org/data/v1/epss";
23
/// Upper bound on the duration of a whole HTTP request (30 seconds).
const REQUEST_TIMEOUT: Duration = Duration::from_secs(30);
/// Upper bound on establishing the connection (10 seconds).
const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
28
29/// EPSS data source.
30///
31/// Provides exploit probability scores for CVEs. These scores help prioritize
32/// vulnerabilities based on likelihood of exploitation.
33pub struct EpssSource {
34    client: ClientWithMiddleware,
35}
36
37impl EpssSource {
38    /// Create a new EPSS source.
39    pub fn new() -> Self {
40        let raw_client = reqwest::Client::builder()
41            .timeout(REQUEST_TIMEOUT)
42            .connect_timeout(CONNECT_TIMEOUT)
43            .build()
44            .unwrap_or_default();
45
46        // Retry policy: 3 retries with exponential backoff
47        let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
48        let client = ClientBuilder::new(raw_client)
49            .with(RetryTransientMiddleware::new_with_policy(retry_policy))
50            .build();
51
52        Self { client }
53    }
54
55    /// Fetch EPSS scores for specific CVE IDs.
56    ///
57    /// # Arguments
58    ///
59    /// * `cve_ids` - List of CVE IDs to look up (e.g., ["CVE-2024-1234", "CVE-2024-5678"])
60    ///
61    /// # Returns
62    ///
63    /// A map of CVE ID to EPSS score data.
64    pub async fn fetch_scores(&self, cve_ids: &[&str]) -> Result<HashMap<String, EpssScore>> {
65        if cve_ids.is_empty() {
66            return Ok(HashMap::new());
67        }
68
69        // API accepts comma-separated CVE IDs
70        let cve_param = cve_ids.join(",");
71        let url = format!("{}?cve={}", EPSS_API_URL, cve_param);
72
73        debug!("Fetching EPSS scores for {} CVEs", cve_ids.len());
74
75        let response = self.client.get(&url).send().await?;
76
77        if !response.status().is_success() {
78            return Err(AdvisoryError::source_fetch(
79                "EPSS",
80                format!("HTTP {}", response.status()),
81            ));
82        }
83
84        let epss_response: EpssResponse = response.json().await?;
85
86        let scores: HashMap<String, EpssScore> = epss_response
87            .data
88            .into_iter()
89            .map(|s| (s.cve.clone(), s))
90            .collect();
91
92        debug!("Retrieved {} EPSS scores", scores.len());
93        Ok(scores)
94    }
95
96    /// Fetch a single CVE's EPSS score.
97    pub async fn fetch_score(&self, cve_id: &str) -> Result<Option<EpssScore>> {
98        let scores = self.fetch_scores(&[cve_id]).await?;
99        Ok(scores.get(cve_id).cloned())
100    }
101
102    /// Fetch all CVEs with EPSS score above a threshold.
103    ///
104    /// # Arguments
105    ///
106    /// * `min_epss` - Minimum EPSS probability (0.0 - 1.0)
107    /// * `limit` - Maximum number of results (default: 100)
108    pub async fn fetch_high_risk(
109        &self,
110        min_epss: f64,
111        limit: Option<u32>,
112    ) -> Result<Vec<EpssScore>> {
113        let limit = limit.unwrap_or(100);
114        let url = format!("{}?epss-gt={}&limit={}", EPSS_API_URL, min_epss, limit);
115
116        info!("Fetching CVEs with EPSS > {}", min_epss);
117
118        let response = self.client.get(&url).send().await?;
119
120        if !response.status().is_success() {
121            return Err(AdvisoryError::source_fetch(
122                "EPSS",
123                format!("HTTP {}", response.status()),
124            ));
125        }
126
127        let epss_response: EpssResponse = response.json().await?;
128        info!("Found {} high-risk CVEs", epss_response.data.len());
129
130        Ok(epss_response.data)
131    }
132
133    /// Fetch CVEs with EPSS percentile above a threshold.
134    ///
135    /// # Arguments
136    ///
137    /// * `min_percentile` - Minimum percentile (0.0 - 1.0, e.g., 0.95 for top 5%)
138    /// * `limit` - Maximum number of results
139    pub async fn fetch_top_percentile(
140        &self,
141        min_percentile: f64,
142        limit: Option<u32>,
143    ) -> Result<Vec<EpssScore>> {
144        let limit = limit.unwrap_or(100);
145        let url = format!(
146            "{}?percentile-gt={}&limit={}",
147            EPSS_API_URL, min_percentile, limit
148        );
149
150        info!(
151            "Fetching CVEs in top {} percentile",
152            (1.0 - min_percentile) * 100.0
153        );
154
155        let response = self.client.get(&url).send().await?;
156
157        if !response.status().is_success() {
158            return Err(AdvisoryError::source_fetch(
159                "EPSS",
160                format!("HTTP {}", response.status()),
161            ));
162        }
163
164        let epss_response: EpssResponse = response.json().await?;
165        Ok(epss_response.data)
166    }
167
168    /// Fetch EPSS scores in batches for a large list of CVEs.
169    ///
170    /// The API can handle many CVEs in a single request, but we batch
171    /// to avoid URL length limits.
172    pub async fn fetch_scores_batch(
173        &self,
174        cve_ids: &[String],
175        batch_size: usize,
176    ) -> Result<HashMap<String, EpssScore>> {
177        let mut all_scores = HashMap::new();
178
179        for chunk in cve_ids.chunks(batch_size) {
180            let refs: Vec<&str> = chunk.iter().map(|s| s.as_str()).collect();
181            let scores = self.fetch_scores(&refs).await?;
182            all_scores.extend(scores);
183        }
184
185        Ok(all_scores)
186    }
187}
188
189impl Default for EpssSource {
190    fn default() -> Self {
191        Self::new()
192    }
193}
194
195/// Response from the EPSS API.
196#[derive(Debug, Clone, Deserialize)]
197pub struct EpssResponse {
198    /// Status of the request.
199    pub status: String,
200    /// API version.
201    #[serde(rename = "status-code")]
202    pub status_code: Option<i32>,
203    /// API version string.
204    pub version: Option<String>,
205    /// Total number of CVEs with EPSS scores.
206    pub total: Option<u64>,
207    /// Offset for pagination.
208    pub offset: Option<u64>,
209    /// Limit used in the request.
210    pub limit: Option<u64>,
211    /// The EPSS score data.
212    pub data: Vec<EpssScore>,
213}
214
215/// EPSS score for a single CVE.
216#[derive(Debug, Clone, Serialize, Deserialize)]
217pub struct EpssScore {
218    /// CVE identifier.
219    pub cve: String,
220    /// EPSS probability score (0.0 - 1.0).
221    /// Represents the probability of exploitation in the next 30 days.
222    #[serde(deserialize_with = "deserialize_f64_from_string")]
223    pub epss: f64,
224    /// Percentile ranking (0.0 - 1.0).
225    /// Indicates how this CVE ranks compared to all others.
226    #[serde(deserialize_with = "deserialize_f64_from_string")]
227    pub percentile: f64,
228    /// Date when the score was calculated.
229    #[serde(default)]
230    pub date: Option<String>,
231}
232
233impl EpssScore {
234    /// Check if this CVE is in the top N percentile.
235    pub fn is_top_percentile(&self, threshold: f64) -> bool {
236        self.percentile >= threshold
237    }
238
239    /// Get a risk category based on EPSS score.
240    pub fn risk_category(&self) -> EpssRiskCategory {
241        match self.epss {
242            s if s >= 0.7 => EpssRiskCategory::Critical,
243            s if s >= 0.4 => EpssRiskCategory::High,
244            s if s >= 0.1 => EpssRiskCategory::Medium,
245            _ => EpssRiskCategory::Low,
246        }
247    }
248
249    /// Get the date as a parsed DateTime if available.
250    pub fn date_utc(&self) -> Option<DateTime<Utc>> {
251        self.date.as_ref().and_then(|d| {
252            NaiveDate::parse_from_str(d, "%Y-%m-%d")
253                .ok()
254                .map(|nd| nd.and_hms_opt(0, 0, 0).unwrap().and_utc())
255        })
256    }
257}
258
/// Coarse risk bucket derived from an EPSS probability.
///
/// The variants are listed from least to most likely to be exploited.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EpssRiskCategory {
    /// Probability below 0.1: exploitation is unlikely.
    Low,
    /// Probability from 0.1 up to, but not including, 0.4.
    Medium,
    /// Probability from 0.4 up to, but not including, 0.7.
    High,
    /// Probability of 0.7 or more: exploitation is very likely.
    Critical,
}
271
272/// Deserialize f64 from string (EPSS API returns numbers as strings).
273fn deserialize_f64_from_string<'de, D>(deserializer: D) -> std::result::Result<f64, D::Error>
274where
275    D: serde::Deserializer<'de>,
276{
277    let s: String = String::deserialize(deserializer)?;
278    s.parse().map_err(serde::de::Error::custom)
279}
280
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an undated score fixture.
    fn sample(cve: &str, epss: f64, percentile: f64) -> EpssScore {
        EpssScore {
            cve: cve.to_string(),
            epss,
            percentile,
            date: None,
        }
    }

    /// A probability of 0.75 is critical, and the 98th percentile clears a 0.95 bar.
    #[test]
    fn test_epss_risk_category() {
        let critical = sample("CVE-2024-1234", 0.75, 0.98);

        assert_eq!(critical.risk_category(), EpssRiskCategory::Critical);
        assert!(critical.is_top_percentile(0.95));
    }

    /// A probability of 0.05 is low, and the 30th percentile misses a 0.95 bar.
    #[test]
    fn test_epss_low_risk() {
        let low = sample("CVE-2024-5678", 0.05, 0.3);

        assert_eq!(low.risk_category(), EpssRiskCategory::Low);
        assert!(!low.is_top_percentile(0.95));
    }
}