vulnera_advisor/sources/
nvd.rs

1use super::AdvisorySource;
2use crate::error::Result;
3use crate::models::{Advisory, Reference, ReferenceType};
4use async_trait::async_trait;
5use chrono::{DateTime, NaiveDateTime, Utc};
6use cpe::cpe::Cpe;
7use governor::clock::DefaultClock;
8use governor::middleware::NoOpMiddleware;
9use governor::state::{InMemoryState, NotKeyed};
10use governor::{Quota, RateLimiter};
11use once_cell::sync::Lazy;
12use regex_lite::Regex;
13use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
14use reqwest_retry::{RetryTransientMiddleware, policies::ExponentialBackoff};
15use serde::{Deserialize, Deserializer};
16use std::collections::HashSet;
17use std::num::NonZeroU32;
18use std::sync::Arc;
19use tracing::{debug, info, warn};
20
21/// Custom deserializer for NVD datetime format (e.g., "2024-01-15T10:30:00.000")
22fn deserialize_nvd_datetime<'de, D>(deserializer: D) -> std::result::Result<DateTime<Utc>, D::Error>
23where
24    D: Deserializer<'de>,
25{
26    let s = String::deserialize(deserializer)?;
27
28    // O1
29    if let Ok(naive) = NaiveDateTime::parse_from_str(&s, "%Y-%m-%dT%H:%M:%S%.3f") {
30        return Ok(naive.and_utc());
31    }
32
33    // O2
34    if let Ok(naive) = NaiveDateTime::parse_from_str(&s, "%Y-%m-%dT%H:%M:%S") {
35        return Ok(naive.and_utc());
36    }
37
38    // O3
39    if let Ok(dt) = DateTime::parse_from_rfc3339(&s) {
40        return Ok(dt.with_timezone(&Utc));
41    }
42
43    Err(serde::de::Error::custom(format!(
44        "Failed to parse NVD datetime: {}",
45        s
46    )))
47}
48
49static GHSA_REGEX: Lazy<Regex> =
50    Lazy::new(|| Regex::new(r"(?i)(GHSA-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{4})").unwrap());
51static OSV_REGEX: Lazy<Regex> =
52    Lazy::new(|| Regex::new(r"(?i)osv\.dev/vulnerability/([^/?#]+)").unwrap());
53static CVE_REGEX: Lazy<Regex> = Lazy::new(|| Regex::new(r"(?i)(CVE-\d{4}-\d{4,})").unwrap());
54
55pub struct NVDSource {
56    api_key: Option<String>,
57    client: ClientWithMiddleware,
58    limiter: Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock, NoOpMiddleware>>,
59    /// Maximum number of CVEs to fetch (None = unlimited)
60    max_results: Option<u32>,
61    /// Optional API URL (useful for tests / mocks)
62    api_url: Option<String>,
63}
64
65impl NVDSource {
66    pub fn new(api_key: Option<String>) -> Self {
67        Self::with_max_results(api_key, None)
68    }
69
70    /// Create a new NVD source with a maximum result limit.
71    ///
72    /// Use `None` for unlimited results (will fetch all ~320k CVEs on full sync).
73    /// Use `Some(n)` to limit to n results (useful for testing).
74    pub fn with_max_results(api_key: Option<String>, max_results: Option<u32>) -> Self {
75        // Build raw client with timeout
76        let raw_client = reqwest::Client::builder()
77            .timeout(std::time::Duration::from_secs(60))
78            .connect_timeout(std::time::Duration::from_secs(30))
79            .build()
80            .unwrap_or_default();
81
82        // Retry policy: 3 retries with exponential backoff
83        let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
84        let client = ClientBuilder::new(raw_client)
85            .with(RetryTransientMiddleware::new_with_policy(retry_policy))
86            .build();
87
88        // Rate limiter: 50 req / 30 sec (with key) or 5 req / 30 sec (without)
89        let (requests, seconds) = if api_key.is_some() { (50, 30) } else { (5, 30) };
90
91        let quota = Quota::with_period(std::time::Duration::from_secs(seconds))
92            .unwrap()
93            .allow_burst(NonZeroU32::new(requests).unwrap());
94
95        let limiter = Arc::new(RateLimiter::direct(quota));
96
97        Self {
98            api_key,
99            client,
100            limiter,
101            max_results,
102            api_url: None,
103        }
104    }
105
106    /// Override the API base URL (useful for mock servers in tests)
107    pub fn with_api_url(mut self, api_url: impl Into<String>) -> Self {
108        self.api_url = Some(api_url.into());
109        self
110    }
111}
112
113#[async_trait]
114impl AdvisorySource for NVDSource {
115    async fn fetch(&self, since: Option<DateTime<Utc>>) -> Result<Vec<Advisory>> {
116        let base_url = self
117            .api_url
118            .as_deref()
119            .unwrap_or("https://services.nvd.nist.gov/rest/json/cves/2.0");
120        let mut advisories = Vec::new();
121        let mut start_index = 0;
122        let results_per_page = 2000; // Max allowed by NVD
123
124        loop {
125            let mut url = format!(
126                "{}?startIndex={}&resultsPerPage={}",
127                base_url, start_index, results_per_page
128            );
129
130            if let Some(since) = since {
131                // NVD has a 120-day maximum range restriction
132                let now = Utc::now();
133                let duration = now.signed_duration_since(since);
134                let max_days = 120;
135
136                // NVD requires ISO 8601 format: YYYY-MM-DDTHH:MM:SS.sss
137                let format_nvd_date = |dt: DateTime<Utc>| -> String {
138                    dt.format("%Y-%m-%dT%H:%M:%S%.3f").to_string()
139                };
140
141                if duration.num_days() > max_days {
142                    // If range exceeds 120 days, we need to chunk
143                    warn!(
144                        "NVD sync: Last sync was {} days ago (max: {}). Only fetching last {} days.",
145                        duration.num_days(),
146                        max_days,
147                        max_days
148                    );
149                    let start = now - chrono::Duration::days(max_days);
150                    url.push_str(&format!(
151                        "&lastModStartDate={}&lastModEndDate={}",
152                        format_nvd_date(start),
153                        format_nvd_date(now)
154                    ));
155                } else {
156                    // Normal case: range is within limit
157                    url.push_str(&format!(
158                        "&lastModStartDate={}&lastModEndDate={}",
159                        format_nvd_date(since),
160                        format_nvd_date(now)
161                    ));
162                }
163            }
164            // Wait for rate limiter
165            self.limiter.until_ready().await;
166
167            debug!("Fetching NVD data from startIndex={}", start_index);
168
169            let mut request = self.client.get(&url);
170            if let Some(key) = &self.api_key {
171                request = request.header("apiKey", key);
172            }
173
174            let response = request.send().await?;
175            if !response.status().is_success() {
176                let status = response.status();
177                let body = response.text().await.unwrap_or_default();
178                return Err(crate::error::AdvisoryError::source_fetch(
179                    "NVD",
180                    format!(
181                        "HTTP {}: {}",
182                        status,
183                        body.chars().take(200).collect::<String>()
184                    ),
185                ));
186            }
187
188            let nvd_response: NvdResponse = response.json().await?;
189            let total_results = nvd_response.total_results;
190            let count = nvd_response.vulnerabilities.len();
191
192            for item in nvd_response.vulnerabilities {
193                let cve = item.cve;
194
195                let mut affected = Vec::new();
196
197                // Parse configurations to find CPEs
198                if let Some(configurations) = cve.configurations {
199                    for config in configurations {
200                        for node in config.nodes {
201                            for cpe_match in node.cpe_match {
202                                if cpe_match.vulnerable {
203                                    if let Ok(cpe_uri) = cpe::uri::Uri::parse(&cpe_match.criteria) {
204                                        let vendor = cpe_uri.vendor().to_string();
205                                        let product = cpe_uri.product().to_string();
206                                        let version = cpe_uri.version().to_string();
207
208                                        // Very basic heuristic
209                                        let ecosystem = if vendor == "apache" {
210                                            "maven"
211                                        } else if vendor == "npm" {
212                                            "npm"
213                                        } else {
214                                            "generic"
215                                        };
216
217                                        let purl = packageurl::PackageUrl::new(ecosystem, &product)
218                                            .ok()
219                                            .map(|mut p| {
220                                                if !version.is_empty() && version != "*" {
221                                                    let _ = p.with_version(version.clone());
222                                                }
223                                                if ecosystem == "maven" {
224                                                    let _ = p.with_namespace(vendor.clone());
225                                                }
226                                                p.to_string()
227                                            });
228
229                                        affected.push(crate::models::Affected {
230                                            package: crate::models::Package {
231                                                ecosystem: ecosystem.to_string(),
232                                                name: product,
233                                                purl,
234                                            },
235                                            ranges: vec![], // NVD ranges are complex, skipping for now
236                                            versions: vec![version],
237                                            ecosystem_specific: None,
238                                            database_specific: Some(serde_json::json!({
239                                                "cpe": cpe_match.criteria
240                                            })),
241                                        });
242                                    }
243                                }
244                            }
245                        }
246                    }
247                }
248
249                let references = cve
250                    .references
251                    .iter()
252                    .map(|r| Reference {
253                        reference_type: ReferenceType::Web,
254                        url: r.url.clone(),
255                    })
256                    .collect();
257
258                // Build alias set from references (e.g., GHSA / OSV IDs) and dedupe
259                let mut alias_set: HashSet<String> = HashSet::new();
260                for r in &cve.references {
261                    // GHSA: https://github.com/advisories/GHSA-xxxx-xxxx-xxxx
262                    if let Some(caps) = GHSA_REGEX.captures(&r.url) {
263                        alias_set.insert(caps[1].to_uppercase());
264                    }
265
266                    // OSV: https://osv.dev/vulnerability/<id>
267                    if let Some(caps) = OSV_REGEX.captures(&r.url) {
268                        let osv_id = caps[1].to_string();
269                        // If the OSV id looks like a CVE, don't add it here (CVE already present)
270                        if CVE_REGEX.captures(&osv_id).is_none() {
271                            alias_set.insert(osv_id);
272                        }
273                    }
274                }
275
276                let aliases_field = if alias_set.is_empty() {
277                    None
278                } else {
279                    Some(alias_set.into_iter().collect())
280                };
281
282                advisories.push(Advisory {
283                    id: cve.id,
284                    summary: None,
285                    details: cve.descriptions.first().map(|d| d.value.clone()),
286                    affected,
287                    references,
288                    published: Some(cve.published),
289                    modified: Some(cve.last_modified),
290                    aliases: aliases_field,
291                    database_specific: Some(serde_json::json!({
292                        "source": "NVD",
293                        "metrics": cve.metrics,
294                    })),
295                    enrichment: None,
296                });
297            }
298
299            start_index += count as u32;
300            if start_index >= total_results {
301                break;
302            }
303
304            // Optional limit on results (useful for testing or incremental loading)
305            if let Some(max) = self.max_results {
306                if start_index >= max {
307                    info!(
308                        "Stopping NVD sync at configured limit (fetched {} of {} items)",
309                        start_index, total_results
310                    );
311                    break;
312                }
313            }
314        }
315
316        Ok(advisories)
317    }
318
319    fn name(&self) -> &str {
320        "NVD"
321    }
322}
323
324// Minimal NVD Structs
325#[derive(Deserialize)]
326#[serde(rename_all = "camelCase")]
327struct NvdResponse {
328    total_results: u32,
329    vulnerabilities: Vec<NvdItem>,
330}
331
332#[derive(Deserialize)]
333struct NvdItem {
334    cve: Cve,
335}
336
337#[derive(Deserialize)]
338#[serde(rename_all = "camelCase")]
339struct Cve {
340    id: String,
341    #[serde(deserialize_with = "deserialize_nvd_datetime")]
342    published: DateTime<Utc>,
343    #[serde(deserialize_with = "deserialize_nvd_datetime")]
344    last_modified: DateTime<Utc>,
345    descriptions: Vec<Description>,
346    #[serde(default)]
347    references: Vec<NvdReference>,
348    #[serde(default)]
349    metrics: serde_json::Value,
350    #[serde(default)]
351    configurations: Option<Vec<Configuration>>,
352    // Ignored fields: cveTags, sourceIdentifier, vulnStatus, weaknesses
353}
354
355#[derive(Deserialize)]
356struct Configuration {
357    nodes: Vec<Node>,
358}
359
360#[derive(Deserialize)]
361#[serde(rename_all = "camelCase")]
362struct Node {
363    cpe_match: Vec<CpeMatch>,
364    // Ignored: negate, operator
365}
366
367#[derive(Deserialize)]
368struct CpeMatch {
369    vulnerable: bool,
370    criteria: String,
371}
372
373#[derive(Deserialize)]
374struct Description {
375    value: String,
376}
377
378#[derive(Deserialize)]
379struct NvdReference {
380    url: String,
381}
382
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::{json, Value};
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Builds a one-CVE NVD page (`totalResults = 1`) with the given id, description
    /// and reference URLs, and no configurations.
    fn single_cve_page(id: &str, description: &str, reference_urls: &[&str]) -> Value {
        let references: Vec<Value> = reference_urls
            .iter()
            .map(|url| json!({ "url": url }))
            .collect();

        json!({
            "totalResults": 1,
            "vulnerabilities": [
                {
                    "cve": {
                        "id": id,
                        "published": "2024-06-30T12:00:00.000",
                        "lastModified": "2024-06-30T12:00:00.000",
                        "descriptions": [ { "value": description } ],
                        "references": references,
                        "metrics": {},
                        "configurations": []
                    }
                }
            ]
        })
    }

    /// Serves `page` for every `GET /` on a fresh mock server and fetches it through an
    /// NVD source limited to one result. The server stays alive until the fetch completes.
    async fn fetch_single_page(page: Value) -> Vec<Advisory> {
        let mock_server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/"))
            .respond_with(ResponseTemplate::new(200).set_body_json(page))
            .mount(&mock_server)
            .await;

        NVDSource::with_max_results(None, Some(1))
            .with_api_url(mock_server.uri())
            .fetch(None)
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn test_nvd_parses_ghsa_and_osv_aliases() {
        let page = single_cve_page(
            "CVE-2024-12345",
            "This is a description",
            &[
                "https://github.com/advisories/GHSA-1111-2222-3333",
                "https://osv.dev/vulnerability/OSV-2024-1234",
            ],
        );

        let advisories = fetch_single_page(page).await;
        assert_eq!(advisories.len(), 1);

        let adv = &advisories[0];
        assert_eq!(adv.id, "CVE-2024-12345");

        let aliases = adv.aliases.as_ref().unwrap();
        let has_ghsa = aliases
            .iter()
            .any(|a| a.eq_ignore_ascii_case("GHSA-1111-2222-3333"));
        assert!(has_ghsa);
        assert!(aliases.iter().any(|a| a == "OSV-2024-1234"));
    }

    #[tokio::test]
    async fn test_nvd_no_aliases_none() {
        let page = single_cve_page("CVE-2024-22222", "No aliases here", &[]);

        let advisories = fetch_single_page(page).await;
        assert_eq!(advisories.len(), 1);
        assert!(advisories[0].aliases.is_none());
    }
}