1use super::AdvisorySource;
2use crate::error::Result;
3use crate::models::{Advisory, Reference, ReferenceType};
4use async_trait::async_trait;
5use chrono::{DateTime, NaiveDateTime, Utc};
6use cpe::cpe::Cpe;
7use governor::clock::DefaultClock;
8use governor::middleware::NoOpMiddleware;
9use governor::state::{InMemoryState, NotKeyed};
10use governor::{Quota, RateLimiter};
11use once_cell::sync::Lazy;
12use regex_lite::Regex;
13use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
14use reqwest_retry::{RetryTransientMiddleware, policies::ExponentialBackoff};
15use serde::{Deserialize, Deserializer};
16use std::collections::HashSet;
17use std::num::NonZeroU32;
18use std::sync::Arc;
19use tracing::{debug, info, warn};
20
21fn deserialize_nvd_datetime<'de, D>(deserializer: D) -> std::result::Result<DateTime<Utc>, D::Error>
23where
24 D: Deserializer<'de>,
25{
26 let s = String::deserialize(deserializer)?;
27
28 if let Ok(naive) = NaiveDateTime::parse_from_str(&s, "%Y-%m-%dT%H:%M:%S%.3f") {
30 return Ok(naive.and_utc());
31 }
32
33 if let Ok(naive) = NaiveDateTime::parse_from_str(&s, "%Y-%m-%dT%H:%M:%S") {
35 return Ok(naive.and_utc());
36 }
37
38 if let Ok(dt) = DateTime::parse_from_rfc3339(&s) {
40 return Ok(dt.with_timezone(&Utc));
41 }
42
43 Err(serde::de::Error::custom(format!(
44 "Failed to parse NVD datetime: {}",
45 s
46 )))
47}
48
49static GHSA_REGEX: Lazy<Regex> =
50 Lazy::new(|| Regex::new(r"(?i)(GHSA-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{4})").unwrap());
51static OSV_REGEX: Lazy<Regex> =
52 Lazy::new(|| Regex::new(r"(?i)osv\.dev/vulnerability/([^/?#]+)").unwrap());
53static CVE_REGEX: Lazy<Regex> = Lazy::new(|| Regex::new(r"(?i)(CVE-\d{4}-\d{4,})").unwrap());
54
55pub struct NVDSource {
56 api_key: Option<String>,
57 client: ClientWithMiddleware,
58 limiter: Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock, NoOpMiddleware>>,
59 max_results: Option<u32>,
61 api_url: Option<String>,
63}
64
65impl NVDSource {
66 pub fn new(api_key: Option<String>) -> Self {
67 Self::with_max_results(api_key, None)
68 }
69
70 pub fn with_max_results(api_key: Option<String>, max_results: Option<u32>) -> Self {
75 let raw_client = reqwest::Client::builder()
77 .timeout(std::time::Duration::from_secs(60))
78 .connect_timeout(std::time::Duration::from_secs(30))
79 .build()
80 .unwrap_or_default();
81
82 let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
84 let client = ClientBuilder::new(raw_client)
85 .with(RetryTransientMiddleware::new_with_policy(retry_policy))
86 .build();
87
88 let (requests, seconds) = if api_key.is_some() { (50, 30) } else { (5, 30) };
90
91 let quota = Quota::with_period(std::time::Duration::from_secs(seconds))
92 .unwrap()
93 .allow_burst(NonZeroU32::new(requests).unwrap());
94
95 let limiter = Arc::new(RateLimiter::direct(quota));
96
97 Self {
98 api_key,
99 client,
100 limiter,
101 max_results,
102 api_url: None,
103 }
104 }
105
106 pub fn with_api_url(mut self, api_url: impl Into<String>) -> Self {
108 self.api_url = Some(api_url.into());
109 self
110 }
111}
112
113#[async_trait]
114impl AdvisorySource for NVDSource {
115 async fn fetch(&self, since: Option<DateTime<Utc>>) -> Result<Vec<Advisory>> {
116 let base_url = self
117 .api_url
118 .as_deref()
119 .unwrap_or("https://services.nvd.nist.gov/rest/json/cves/2.0");
120 let mut advisories = Vec::new();
121 let mut start_index = 0;
122 let results_per_page = 2000; loop {
125 let mut url = format!(
126 "{}?startIndex={}&resultsPerPage={}",
127 base_url, start_index, results_per_page
128 );
129
130 if let Some(since) = since {
131 let now = Utc::now();
133 let duration = now.signed_duration_since(since);
134 let max_days = 120;
135
136 let format_nvd_date = |dt: DateTime<Utc>| -> String {
138 dt.format("%Y-%m-%dT%H:%M:%S%.3f").to_string()
139 };
140
141 if duration.num_days() > max_days {
142 warn!(
144 "NVD sync: Last sync was {} days ago (max: {}). Only fetching last {} days.",
145 duration.num_days(),
146 max_days,
147 max_days
148 );
149 let start = now - chrono::Duration::days(max_days);
150 url.push_str(&format!(
151 "&lastModStartDate={}&lastModEndDate={}",
152 format_nvd_date(start),
153 format_nvd_date(now)
154 ));
155 } else {
156 url.push_str(&format!(
158 "&lastModStartDate={}&lastModEndDate={}",
159 format_nvd_date(since),
160 format_nvd_date(now)
161 ));
162 }
163 }
164 self.limiter.until_ready().await;
166
167 debug!("Fetching NVD data from startIndex={}", start_index);
168
169 let mut request = self.client.get(&url);
170 if let Some(key) = &self.api_key {
171 request = request.header("apiKey", key);
172 }
173
174 let response = request.send().await?;
175 if !response.status().is_success() {
176 let status = response.status();
177 let body = response.text().await.unwrap_or_default();
178 return Err(crate::error::AdvisoryError::source_fetch(
179 "NVD",
180 format!(
181 "HTTP {}: {}",
182 status,
183 body.chars().take(200).collect::<String>()
184 ),
185 ));
186 }
187
188 let nvd_response: NvdResponse = response.json().await?;
189 let total_results = nvd_response.total_results;
190 let count = nvd_response.vulnerabilities.len();
191
192 for item in nvd_response.vulnerabilities {
193 let cve = item.cve;
194
195 let mut affected = Vec::new();
196
197 if let Some(configurations) = cve.configurations {
199 for config in configurations {
200 for node in config.nodes {
201 for cpe_match in node.cpe_match {
202 if cpe_match.vulnerable {
203 if let Ok(cpe_uri) = cpe::uri::Uri::parse(&cpe_match.criteria) {
204 let vendor = cpe_uri.vendor().to_string();
205 let product = cpe_uri.product().to_string();
206 let version = cpe_uri.version().to_string();
207
208 let ecosystem = if vendor == "apache" {
210 "maven"
211 } else if vendor == "npm" {
212 "npm"
213 } else {
214 "generic"
215 };
216
217 let purl = packageurl::PackageUrl::new(ecosystem, &product)
218 .ok()
219 .map(|mut p| {
220 if !version.is_empty() && version != "*" {
221 let _ = p.with_version(version.clone());
222 }
223 if ecosystem == "maven" {
224 let _ = p.with_namespace(vendor.clone());
225 }
226 p.to_string()
227 });
228
229 affected.push(crate::models::Affected {
230 package: crate::models::Package {
231 ecosystem: ecosystem.to_string(),
232 name: product,
233 purl,
234 },
235 ranges: vec![], versions: vec![version],
237 ecosystem_specific: None,
238 database_specific: Some(serde_json::json!({
239 "cpe": cpe_match.criteria
240 })),
241 });
242 }
243 }
244 }
245 }
246 }
247 }
248
249 let references = cve
250 .references
251 .iter()
252 .map(|r| Reference {
253 reference_type: ReferenceType::Web,
254 url: r.url.clone(),
255 })
256 .collect();
257
258 let mut alias_set: HashSet<String> = HashSet::new();
260 for r in &cve.references {
261 if let Some(caps) = GHSA_REGEX.captures(&r.url) {
263 alias_set.insert(caps[1].to_uppercase());
264 }
265
266 if let Some(caps) = OSV_REGEX.captures(&r.url) {
268 let osv_id = caps[1].to_string();
269 if CVE_REGEX.captures(&osv_id).is_none() {
271 alias_set.insert(osv_id);
272 }
273 }
274 }
275
276 let aliases_field = if alias_set.is_empty() {
277 None
278 } else {
279 Some(alias_set.into_iter().collect())
280 };
281
282 advisories.push(Advisory {
283 id: cve.id,
284 summary: None,
285 details: cve.descriptions.first().map(|d| d.value.clone()),
286 affected,
287 references,
288 published: Some(cve.published),
289 modified: Some(cve.last_modified),
290 aliases: aliases_field,
291 database_specific: Some(serde_json::json!({
292 "source": "NVD",
293 "metrics": cve.metrics,
294 })),
295 enrichment: None,
296 });
297 }
298
299 start_index += count as u32;
300 if start_index >= total_results {
301 break;
302 }
303
304 if let Some(max) = self.max_results {
306 if start_index >= max {
307 info!(
308 "Stopping NVD sync at configured limit (fetched {} of {} items)",
309 start_index, total_results
310 );
311 break;
312 }
313 }
314 }
315
316 Ok(advisories)
317 }
318
319 fn name(&self) -> &str {
320 "NVD"
321 }
322}
323
324#[derive(Deserialize)]
326#[serde(rename_all = "camelCase")]
327struct NvdResponse {
328 total_results: u32,
329 vulnerabilities: Vec<NvdItem>,
330}
331
332#[derive(Deserialize)]
333struct NvdItem {
334 cve: Cve,
335}
336
337#[derive(Deserialize)]
338#[serde(rename_all = "camelCase")]
339struct Cve {
340 id: String,
341 #[serde(deserialize_with = "deserialize_nvd_datetime")]
342 published: DateTime<Utc>,
343 #[serde(deserialize_with = "deserialize_nvd_datetime")]
344 last_modified: DateTime<Utc>,
345 descriptions: Vec<Description>,
346 #[serde(default)]
347 references: Vec<NvdReference>,
348 #[serde(default)]
349 metrics: serde_json::Value,
350 #[serde(default)]
351 configurations: Option<Vec<Configuration>>,
352 }
354
355#[derive(Deserialize)]
356struct Configuration {
357 nodes: Vec<Node>,
358}
359
360#[derive(Deserialize)]
361#[serde(rename_all = "camelCase")]
362struct Node {
363 cpe_match: Vec<CpeMatch>,
364 }
366
367#[derive(Deserialize)]
368struct CpeMatch {
369 vulnerable: bool,
370 criteria: String,
371}
372
373#[derive(Deserialize)]
374struct Description {
375 value: String,
376}
377
378#[derive(Deserialize)]
379struct NvdReference {
380 url: String,
381}
382
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Wraps a single CVE record in the NVD response envelope, reporting one total result.
    fn single_cve_page(cve: serde_json::Value) -> serde_json::Value {
        json!({
            "totalResults": 1,
            "vulnerabilities": [ { "cve": cve } ]
        })
    }

    /// Serves `body` for `GET /` from a fresh mock server. Then runs one unfiltered fetch
    /// (no API key, limit of 1) against that server.
    async fn fetch_from_mock(body: serde_json::Value) -> Vec<Advisory> {
        let server = MockServer::start().await;
        let source = NVDSource::with_max_results(None, Some(1)).with_api_url(server.uri());

        Mock::given(method("GET"))
            .and(path("/"))
            .respond_with(ResponseTemplate::new(200).set_body_json(body))
            .mount(&server)
            .await;

        source.fetch(None).await.unwrap()
    }

    #[tokio::test]
    async fn test_nvd_parses_ghsa_and_osv_aliases() {
        let advisories = fetch_from_mock(single_cve_page(json!({
            "id": "CVE-2024-12345",
            "published": "2024-06-30T12:00:00.000",
            "lastModified": "2024-06-30T12:00:00.000",
            "descriptions": [ { "value": "This is a description" } ],
            "references": [
                { "url": "https://github.com/advisories/GHSA-1111-2222-3333" },
                { "url": "https://osv.dev/vulnerability/OSV-2024-1234" }
            ],
            "metrics": {},
            "configurations": []
        })))
        .await;

        assert_eq!(advisories.len(), 1);
        let adv = &advisories[0];
        assert_eq!(adv.id, "CVE-2024-12345");

        let aliases = adv.aliases.as_ref().unwrap();
        assert!(
            aliases
                .iter()
                .any(|a| a.eq_ignore_ascii_case("GHSA-1111-2222-3333"))
        );
        assert!(aliases.iter().any(|a| a == "OSV-2024-1234"));
    }

    #[tokio::test]
    async fn test_nvd_no_aliases_none() {
        let advisories = fetch_from_mock(single_cve_page(json!({
            "id": "CVE-2024-22222",
            "published": "2024-06-30T12:00:00.000",
            "lastModified": "2024-06-30T12:00:00.000",
            "descriptions": [ { "value": "No aliases here" } ],
            "references": [],
            "metrics": {},
            "configurations": []
        })))
        .await;

        assert_eq!(advisories.len(), 1);
        assert!(advisories[0].aliases.is_none());
    }
}