use crate::config::OssIndexConfig;
use crate::error::AdvisoryError;
use crate::models::{
    Advisory, Affected, Event, Package, Range, RangeType, Reference, ReferenceType, Severity,
};
use crate::purl::Purl;
use anyhow::Result;
use reqwest::Client;
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
use reqwest_retry::{RetryTransientMiddleware, policies::ExponentialBackoff};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::env;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Semaphore;
use tracing::{debug, warn};
54
/// Largest number of coordinates this client places in one component-report request.
const MAX_BATCH_SIZE: usize = 128;

/// Number of component-report requests allowed in flight at once by default.
const DEFAULT_CONCURRENCY: usize = 4;

/// Total per-request timeout handed to reqwest's `ClientBuilder::timeout` (30 s).
const REQUEST_TIMEOUT: Duration = Duration::from_millis(30_000);

/// Connection-phase timeout handed to reqwest's `ClientBuilder::connect_timeout` (10 s).
const CONNECT_TIMEOUT: Duration = Duration::from_millis(10_000);

/// Root URL of the OSS Index v3 REST API; endpoint paths are appended to it.
const API_BASE_URL: &str = "https://ossindex.sonatype.org/api/v3";
69
70#[derive(Debug, Serialize)]
72struct ComponentReportRequest {
73 coordinates: Vec<String>,
74}
75
76#[derive(Debug, Deserialize)]
78pub struct ComponentReport {
79 pub coordinates: String,
80 #[serde(default)]
81 pub description: Option<String>,
82 #[serde(default)]
83 pub reference: Option<String>,
84 #[serde(default)]
85 pub vulnerabilities: Vec<OssVulnerability>,
86}
87
88#[derive(Debug, Deserialize, Clone)]
90pub struct OssVulnerability {
91 pub id: String,
92 #[serde(rename = "displayName")]
93 pub display_name: Option<String>,
94 pub title: String,
95 pub description: String,
96 #[serde(rename = "cvssScore")]
97 pub cvss_score: Option<f64>,
98 #[serde(rename = "cvssVector")]
99 pub cvss_vector: Option<String>,
100 #[serde(default)]
101 pub cwe: Option<String>,
102 #[serde(default)]
103 pub cve: Option<String>,
104 pub reference: String,
105 #[serde(rename = "versionRanges")]
106 pub version_ranges: Option<Vec<String>>,
107 #[serde(rename = "externalReferences")]
108 pub external_references: Option<Vec<String>>,
109}
110
111pub struct OssIndexSource {
116 client: ClientWithMiddleware,
117 config: OssIndexConfig,
118 semaphore: Arc<Semaphore>,
119}
120
121impl OssIndexSource {
122 pub fn new(config: Option<OssIndexConfig>) -> Result<Self> {
126 let config = config.unwrap_or_else(Self::config_from_env);
127
128 let raw_client = Client::builder()
129 .timeout(REQUEST_TIMEOUT)
130 .connect_timeout(CONNECT_TIMEOUT)
131 .build()
132 .unwrap_or_default();
133
134 let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
135 let client = ClientBuilder::new(raw_client)
136 .with(RetryTransientMiddleware::new_with_policy(retry_policy))
137 .build();
138
139 Ok(Self {
140 client,
141 semaphore: Arc::new(Semaphore::new(DEFAULT_CONCURRENCY)),
142 config,
143 })
144 }
145
146 fn config_from_env() -> OssIndexConfig {
148 OssIndexConfig {
149 user: env::var("OSSINDEX_USER").ok(),
150 token: env::var("OSSINDEX_TOKEN").ok(),
151 batch_size: 128,
152 }
153 }
154
155 pub fn with_concurrency(config: Option<OssIndexConfig>, concurrency: usize) -> Result<Self> {
157 let mut source = Self::new(config)?;
158 source.semaphore = Arc::new(Semaphore::new(concurrency));
159 Ok(source)
160 }
161
162 pub async fn query_advisories(&self, purls: &[String]) -> Result<Vec<Advisory>> {
177 let reports = self.query_batch(purls).await?;
178 Ok(self.convert_reports_to_advisories(&reports))
179 }
180
181 pub async fn query_components(&self, purls: &[String]) -> Result<Vec<ComponentReport>> {
185 self.query_batch(purls).await
186 }
187
188 async fn query_batch(&self, purls: &[String]) -> Result<Vec<ComponentReport>> {
190 if purls.is_empty() {
191 return Ok(Vec::new());
192 }
193
194 let chunks: Vec<_> = purls.chunks(MAX_BATCH_SIZE).collect();
195 let mut handles = Vec::with_capacity(chunks.len());
196
197 for chunk in chunks {
198 let chunk_vec: Vec<String> = chunk.to_vec();
199 let client = self.client.clone();
200 let config = self.config.clone();
201 let semaphore = self.semaphore.clone();
202
203 handles.push(tokio::spawn(async move {
204 let _permit =
205 semaphore
206 .acquire()
207 .await
208 .map_err(|e| AdvisoryError::SourceFetch {
209 source_name: "ossindex".to_string(),
210 message: format!("Semaphore error: {}", e),
211 })?;
212
213 Self::query_chunk(&client, &config, &chunk_vec).await
214 }));
215 }
216
217 let mut all_reports = Vec::new();
218 for handle in handles {
219 match handle.await {
220 Ok(Ok(reports)) => all_reports.extend(reports),
221 Ok(Err(e)) => {
222 warn!("OSS Index batch query failed: {}", e);
223 return Err(e);
224 }
225 Err(e) => {
226 warn!("OSS Index task panicked: {}", e);
227 return Err(AdvisoryError::SourceFetch {
228 source_name: "ossindex".to_string(),
229 message: format!("Task panicked: {}", e),
230 }
231 .into());
232 }
233 }
234 }
235
236 Ok(all_reports)
237 }
238
239 async fn query_chunk(
241 client: &ClientWithMiddleware,
242 config: &OssIndexConfig,
243 purls: &[String],
244 ) -> Result<Vec<ComponentReport>> {
245 let url = format!("{}/component-report", API_BASE_URL);
246
247 let request = ComponentReportRequest {
248 coordinates: purls.to_vec(),
249 };
250
251 let mut req_builder = client
252 .post(&url)
253 .header("Content-Type", "application/json")
254 .header("Accept", "application/json");
255
256 if let (Some(user), Some(token)) = (&config.user, &config.token) {
258 req_builder = req_builder.basic_auth(user, Some(token));
259 }
260
261 let response = req_builder
262 .body(serde_json::to_string(&request)?)
263 .send()
264 .await
265 .map_err(|e| AdvisoryError::SourceFetch {
266 source_name: "ossindex".to_string(),
267 message: format!("Request failed: {}", e),
268 })?;
269
270 let status = response.status();
271
272 if status == reqwest::StatusCode::UNAUTHORIZED {
274 return Err(AdvisoryError::SourceFetch {
275 source_name: "ossindex".to_string(),
276 message: "Authentication required. Set OSSINDEX_USER and OSSINDEX_TOKEN environment variables.".to_string(),
277 }.into());
278 }
279
280 if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
281 return Err(AdvisoryError::SourceFetch {
282 source_name: "ossindex".to_string(),
283 message: "Rate limited by OSS Index. Please retry later.".to_string(),
284 }
285 .into());
286 }
287
288 if !status.is_success() {
289 let body = response.text().await.unwrap_or_default();
290 return Err(AdvisoryError::SourceFetch {
291 source_name: "ossindex".to_string(),
292 message: format!("HTTP {}: {}", status, body),
293 }
294 .into());
295 }
296
297 let reports: Vec<ComponentReport> =
298 response
299 .json()
300 .await
301 .map_err(|e| AdvisoryError::SourceFetch {
302 source_name: "ossindex".to_string(),
303 message: format!("Failed to parse response: {}", e),
304 })?;
305
306 debug!("OSS Index returned {} reports", reports.len());
307 Ok(reports)
308 }
309
310 fn convert_reports_to_advisories(&self, reports: &[ComponentReport]) -> Vec<Advisory> {
312 let mut advisories = Vec::new();
313 let mut seen_ids: HashSet<String> = HashSet::new();
314
315 for report in reports {
316 for vuln in &report.vulnerabilities {
317 let advisory_id = self.generate_advisory_id(vuln);
319
320 if seen_ids.contains(&advisory_id) {
322 if let Some(advisory) = advisories
324 .iter_mut()
325 .find(|a: &&mut Advisory| a.id == advisory_id)
326 {
327 if let Some(affected) = self.extract_affected(&report.coordinates, vuln) {
328 advisory.affected.push(affected);
329 }
330 }
331 continue;
332 }
333
334 seen_ids.insert(advisory_id.clone());
335
336 let advisory = self.convert_vulnerability(vuln, &report.coordinates);
337 advisories.push(advisory);
338 }
339 }
340
341 advisories
342 }
343
344 fn generate_advisory_id(&self, vuln: &OssVulnerability) -> String {
346 if let Some(ref cve) = vuln.cve {
348 if !cve.is_empty() {
349 return cve.clone();
350 }
351 }
352
353 if let Some(ref name) = vuln.display_name {
355 if name.starts_with("CVE-") {
356 return name.clone();
357 }
358 }
359
360 if let Some(cve) = Self::extract_cve_from_url(&vuln.reference) {
362 return cve;
363 }
364
365 vuln.id.clone()
367 }
368
369 fn extract_cve_from_url(url: &str) -> Option<String> {
371 let parts: Vec<&str> = url.split('/').collect();
373 parts
374 .last()
375 .filter(|id| id.starts_with("CVE-"))
376 .map(|s| s.to_string())
377 }
378
379 fn convert_vulnerability(&self, vuln: &OssVulnerability, coordinates: &str) -> Advisory {
381 let mut affected = Vec::new();
382
383 if let Some(aff) = self.extract_affected(coordinates, vuln) {
384 affected.push(aff);
385 }
386
387 let mut aliases = Vec::new();
389 if let Some(ref cve) = vuln.cve {
390 if !cve.is_empty() && !cve.starts_with("CVE-") {
391 aliases.push(format!("CVE-{}", cve));
392 } else if !cve.is_empty() {
393 aliases.push(cve.clone());
394 }
395 }
396
397 let advisory_id = self.generate_advisory_id(vuln);
399 if advisory_id.starts_with("CVE-") && !vuln.id.starts_with("CVE-") {
400 aliases.push(vuln.id.clone());
401 }
402
403 let mut references = vec![Reference {
405 reference_type: ReferenceType::Advisory,
406 url: vuln.reference.clone(),
407 }];
408 if let Some(ref ext_refs) = vuln.external_references {
409 for url in ext_refs {
410 references.push(Reference {
411 reference_type: ReferenceType::Web,
412 url: url.clone(),
413 });
414 }
415 }
416
417 let mut db_specific = serde_json::Map::new();
419 if let Some(score) = vuln.cvss_score {
420 db_specific.insert("cvss_score".to_string(), serde_json::json!(score));
421 db_specific.insert(
422 "severity".to_string(),
423 serde_json::json!(Self::cvss_to_severity(score)),
424 );
425 }
426 if let Some(ref vector) = vuln.cvss_vector {
427 db_specific.insert("cvss_vector".to_string(), serde_json::json!(vector));
428 }
429 if let Some(ref cwe) = vuln.cwe {
430 db_specific.insert("cwe_ids".to_string(), serde_json::json!([cwe]));
431 }
432 db_specific.insert("source".to_string(), serde_json::json!("ossindex"));
433
434 Advisory {
435 id: advisory_id,
436 summary: Some(vuln.title.clone()),
437 details: Some(vuln.description.clone()),
438 affected,
439 references,
440 published: None,
441 modified: None,
442 aliases: if aliases.is_empty() {
443 None
444 } else {
445 Some(aliases)
446 },
447 database_specific: Some(serde_json::Value::Object(db_specific)),
448 enrichment: None,
449 }
450 }
451
452 fn extract_affected(&self, coordinates: &str, vuln: &OssVulnerability) -> Option<Affected> {
454 let purl = Purl::parse(coordinates).ok()?;
455
456 let ranges = vuln
457 .version_ranges
458 .as_ref()
459 .map(|ranges| {
460 ranges
461 .iter()
462 .filter_map(|r| Self::parse_version_range(r))
463 .collect()
464 })
465 .unwrap_or_default();
466
467 Some(Affected {
468 package: Package {
469 ecosystem: purl.ecosystem(),
470 name: purl.name.clone(),
471 purl: Some(coordinates.to_string()),
472 },
473 ranges,
474 versions: Vec::new(),
475 ecosystem_specific: None,
476 database_specific: None,
477 })
478 }
479
480 fn parse_version_range(range: &str) -> Option<Range> {
488 let range = range.trim();
489 if range.is_empty() {
490 return None;
491 }
492
493 if !range.contains(',') && !range.starts_with('[') && !range.starts_with('(') {
495 return Some(Range {
496 range_type: RangeType::Semver,
497 events: vec![Event::LastAffected(range.to_string())],
498 repo: None,
499 });
500 }
501
502 let start_inclusive = range.starts_with('[');
504 let end_inclusive = range.ends_with(']');
505
506 let inner = range
508 .trim_start_matches(['[', '('])
509 .trim_end_matches([']', ')']);
510
511 let parts: Vec<&str> = inner.split(',').collect();
512 if parts.len() != 2 {
513 return None;
514 }
515
516 let start = parts[0].trim();
517 let end = parts[1].trim();
518
519 let mut events = Vec::new();
520
521 if !start.is_empty() {
523 if start_inclusive {
524 events.push(Event::Introduced(start.to_string()));
525 } else {
526 events.push(Event::Introduced(start.to_string()));
529 }
530 } else {
531 events.push(Event::Introduced("0".to_string()));
533 }
534
535 if !end.is_empty() {
537 if end_inclusive {
538 events.push(Event::LastAffected(end.to_string()));
541 } else {
542 events.push(Event::Fixed(end.to_string()));
544 }
545 }
546
547 Some(Range {
548 range_type: RangeType::Semver,
549 events,
550 repo: None,
551 })
552 }
553
554 fn cvss_to_severity(score: f64) -> &'static str {
556 match score {
557 s if s >= 9.0 => "CRITICAL",
558 s if s >= 7.0 => "HIGH",
559 s if s >= 4.0 => "MEDIUM",
560 s if s > 0.0 => "LOW",
561 _ => "NONE",
562 }
563 }
564
565 pub fn score_to_severity(score: f64) -> Severity {
567 Severity::from_cvss_score(score)
568 }
569
570 pub fn name(&self) -> &'static str {
572 "ossindex"
573 }
574}
575
#[cfg(test)]
mod tests {
    use super::*;

    /// A half-open interval yields `introduced` plus `fixed`.
    #[test]
    fn test_parse_version_range_standard() {
        let parsed = OssIndexSource::parse_version_range("[1.0.0,2.0.0)")
            .expect("half-open interval should parse");
        assert_eq!(parsed.range_type, RangeType::Semver);
        assert!(matches!(
            parsed.events.as_slice(),
            [Event::Introduced(from), Event::Fixed(to)] if from == "1.0.0" && to == "2.0.0"
        ));
    }

    /// An inclusive upper bound becomes `last_affected`.
    #[test]
    fn test_parse_version_range_inclusive_end() {
        let parsed = OssIndexSource::parse_version_range("[1.0.0,2.0.0]")
            .expect("closed interval should parse");
        assert!(matches!(
            parsed.events.as_slice(),
            [Event::Introduced(from), Event::LastAffected(to)]
                if from == "1.0.0" && to == "2.0.0"
        ));
    }

    /// A missing lower bound means "introduced at 0".
    #[test]
    fn test_parse_version_range_open_start() {
        let parsed = OssIndexSource::parse_version_range("(,1.0.0)")
            .expect("open-start interval should parse");
        assert!(matches!(
            parsed.events.as_slice(),
            [Event::Introduced(from), Event::Fixed(to)] if from == "0" && to == "1.0.0"
        ));
    }

    /// A missing upper bound leaves only the `introduced` event.
    #[test]
    fn test_parse_version_range_open_end() {
        let parsed = OssIndexSource::parse_version_range("[1.0.0,)")
            .expect("open-end interval should parse");
        assert!(matches!(
            parsed.events.as_slice(),
            [Event::Introduced(from)] if from == "1.0.0"
        ));
    }

    /// A bare version is recorded as the last affected version.
    #[test]
    fn test_parse_version_range_exact() {
        let parsed =
            OssIndexSource::parse_version_range("1.0.0").expect("bare version should parse");
        assert!(matches!(
            parsed.events.as_slice(),
            [Event::LastAffected(version)] if version == "1.0.0"
        ));
    }

    #[test]
    fn test_cvss_to_severity() {
        let cases = [
            (9.5, "CRITICAL"),
            (7.5, "HIGH"),
            (5.0, "MEDIUM"),
            (2.0, "LOW"),
            (0.0, "NONE"),
        ];
        for (score, label) in cases {
            assert_eq!(OssIndexSource::cvss_to_severity(score), label);
        }
    }

    #[test]
    fn test_extract_cve_from_url() {
        let cve_url = "https://ossindex.sonatype.org/vulnerability/CVE-2021-23337";
        let sonatype_url = "https://ossindex.sonatype.org/vulnerability/sonatype-2020-1234";

        assert_eq!(
            OssIndexSource::extract_cve_from_url(cve_url).as_deref(),
            Some("CVE-2021-23337")
        );
        assert_eq!(OssIndexSource::extract_cve_from_url(sonatype_url), None);
    }

    #[test]
    fn test_purl_integration() {
        let purl = Purl::new("npm", "lodash").with_version("4.17.20");

        assert_eq!(purl.to_string(), "pkg:npm/lodash@4.17.20");
        assert_eq!(purl.ecosystem(), "npm");
        assert_eq!(purl.name, "lodash");
        assert_eq!(purl.version.as_deref(), Some("4.17.20"));
    }

    #[test]
    fn test_score_to_severity() {
        let cases = [
            (9.5, Severity::Critical),
            (7.5, Severity::High),
            (5.0, Severity::Medium),
            (2.0, Severity::Low),
            (0.0, Severity::None),
        ];
        for (score, expected) in cases {
            assert_eq!(OssIndexSource::score_to_severity(score), expected);
        }
    }
}