1use crate::config::OssIndexConfig;
38use crate::error::AdvisoryError;
39use crate::models::{
40 Advisory, Affected, Event, Package, Range, RangeType, Reference, ReferenceType, Severity,
41};
42use crate::purl::Purl;
43use anyhow::Result;
44use reqwest::Client;
45use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
46use reqwest_retry::{RetryTransientMiddleware, policies::ExponentialBackoff};
47use serde::{Deserialize, Serialize};
48use std::collections::HashSet;
49use std::env;
50use std::sync::Arc;
51use tokio::sync::Semaphore;
52use tracing::{debug, warn};
53
/// Largest number of package coordinates placed in one component-report request.
const MAX_BATCH_SIZE: usize = 128;

/// Number of semaphore permits a source created with `new` starts with.
const DEFAULT_CONCURRENCY: usize = 4;

/// Root URL of the OSS Index REST API; endpoint paths are appended to it.
const API_BASE_URL: &str = "https://ossindex.sonatype.org/api/v3";
62
63#[derive(Debug, Serialize)]
65struct ComponentReportRequest {
66 coordinates: Vec<String>,
67}
68
69#[derive(Debug, Deserialize)]
71pub struct ComponentReport {
72 pub coordinates: String,
73 #[serde(default)]
74 pub description: Option<String>,
75 #[serde(default)]
76 pub reference: Option<String>,
77 #[serde(default)]
78 pub vulnerabilities: Vec<OssVulnerability>,
79}
80
81#[derive(Debug, Deserialize, Clone)]
83pub struct OssVulnerability {
84 pub id: String,
85 #[serde(rename = "displayName")]
86 pub display_name: Option<String>,
87 pub title: String,
88 pub description: String,
89 #[serde(rename = "cvssScore")]
90 pub cvss_score: Option<f64>,
91 #[serde(rename = "cvssVector")]
92 pub cvss_vector: Option<String>,
93 #[serde(default)]
94 pub cwe: Option<String>,
95 #[serde(default)]
96 pub cve: Option<String>,
97 pub reference: String,
98 #[serde(rename = "versionRanges")]
99 pub version_ranges: Option<Vec<String>>,
100 #[serde(rename = "externalReferences")]
101 pub external_references: Option<Vec<String>>,
102}
103
104pub struct OssIndexSource {
109 client: ClientWithMiddleware,
110 config: OssIndexConfig,
111 semaphore: Arc<Semaphore>,
112}
113
114impl OssIndexSource {
115 pub fn new(config: Option<OssIndexConfig>) -> Result<Self> {
119 let config = config.unwrap_or_else(Self::config_from_env);
120
121 let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
122 let client = ClientBuilder::new(Client::new())
123 .with(RetryTransientMiddleware::new_with_policy(retry_policy))
124 .build();
125
126 Ok(Self {
127 client,
128 semaphore: Arc::new(Semaphore::new(DEFAULT_CONCURRENCY)),
129 config,
130 })
131 }
132
133 fn config_from_env() -> OssIndexConfig {
135 OssIndexConfig {
136 user: env::var("OSSINDEX_USER").ok(),
137 token: env::var("OSSINDEX_TOKEN").ok(),
138 batch_size: 128,
139 }
140 }
141
142 pub fn with_concurrency(config: Option<OssIndexConfig>, concurrency: usize) -> Result<Self> {
144 let mut source = Self::new(config)?;
145 source.semaphore = Arc::new(Semaphore::new(concurrency));
146 Ok(source)
147 }
148
149 pub async fn query_advisories(&self, purls: &[String]) -> Result<Vec<Advisory>> {
164 let reports = self.query_batch(purls).await?;
165 Ok(self.convert_reports_to_advisories(&reports))
166 }
167
168 pub async fn query_components(&self, purls: &[String]) -> Result<Vec<ComponentReport>> {
172 self.query_batch(purls).await
173 }
174
175 async fn query_batch(&self, purls: &[String]) -> Result<Vec<ComponentReport>> {
177 if purls.is_empty() {
178 return Ok(Vec::new());
179 }
180
181 let chunks: Vec<_> = purls.chunks(MAX_BATCH_SIZE).collect();
182 let mut handles = Vec::with_capacity(chunks.len());
183
184 for chunk in chunks {
185 let chunk_vec: Vec<String> = chunk.to_vec();
186 let client = self.client.clone();
187 let config = self.config.clone();
188 let semaphore = self.semaphore.clone();
189
190 handles.push(tokio::spawn(async move {
191 let _permit =
192 semaphore
193 .acquire()
194 .await
195 .map_err(|e| AdvisoryError::SourceFetch {
196 source_name: "ossindex".to_string(),
197 message: format!("Semaphore error: {}", e),
198 })?;
199
200 Self::query_chunk(&client, &config, &chunk_vec).await
201 }));
202 }
203
204 let mut all_reports = Vec::new();
205 for handle in handles {
206 match handle.await {
207 Ok(Ok(reports)) => all_reports.extend(reports),
208 Ok(Err(e)) => {
209 warn!("OSS Index batch query failed: {}", e);
210 return Err(e);
211 }
212 Err(e) => {
213 warn!("OSS Index task panicked: {}", e);
214 return Err(AdvisoryError::SourceFetch {
215 source_name: "ossindex".to_string(),
216 message: format!("Task panicked: {}", e),
217 }
218 .into());
219 }
220 }
221 }
222
223 Ok(all_reports)
224 }
225
226 async fn query_chunk(
228 client: &ClientWithMiddleware,
229 config: &OssIndexConfig,
230 purls: &[String],
231 ) -> Result<Vec<ComponentReport>> {
232 let url = format!("{}/component-report", API_BASE_URL);
233
234 let request = ComponentReportRequest {
235 coordinates: purls.to_vec(),
236 };
237
238 let mut req_builder = client
239 .post(&url)
240 .header("Content-Type", "application/json")
241 .header("Accept", "application/json");
242
243 if let (Some(user), Some(token)) = (&config.user, &config.token) {
245 req_builder = req_builder.basic_auth(user, Some(token));
246 }
247
248 let response = req_builder
249 .body(serde_json::to_string(&request)?)
250 .send()
251 .await
252 .map_err(|e| AdvisoryError::SourceFetch {
253 source_name: "ossindex".to_string(),
254 message: format!("Request failed: {}", e),
255 })?;
256
257 let status = response.status();
258
259 if status == reqwest::StatusCode::UNAUTHORIZED {
261 return Err(AdvisoryError::SourceFetch {
262 source_name: "ossindex".to_string(),
263 message: "Authentication required. Set OSSINDEX_USER and OSSINDEX_TOKEN environment variables.".to_string(),
264 }.into());
265 }
266
267 if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
268 return Err(AdvisoryError::SourceFetch {
269 source_name: "ossindex".to_string(),
270 message: "Rate limited by OSS Index. Please retry later.".to_string(),
271 }
272 .into());
273 }
274
275 if !status.is_success() {
276 let body = response.text().await.unwrap_or_default();
277 return Err(AdvisoryError::SourceFetch {
278 source_name: "ossindex".to_string(),
279 message: format!("HTTP {}: {}", status, body),
280 }
281 .into());
282 }
283
284 let reports: Vec<ComponentReport> =
285 response
286 .json()
287 .await
288 .map_err(|e| AdvisoryError::SourceFetch {
289 source_name: "ossindex".to_string(),
290 message: format!("Failed to parse response: {}", e),
291 })?;
292
293 debug!("OSS Index returned {} reports", reports.len());
294 Ok(reports)
295 }
296
297 fn convert_reports_to_advisories(&self, reports: &[ComponentReport]) -> Vec<Advisory> {
299 let mut advisories = Vec::new();
300 let mut seen_ids: HashSet<String> = HashSet::new();
301
302 for report in reports {
303 for vuln in &report.vulnerabilities {
304 let advisory_id = self.generate_advisory_id(vuln);
306
307 if seen_ids.contains(&advisory_id) {
309 if let Some(advisory) = advisories
311 .iter_mut()
312 .find(|a: &&mut Advisory| a.id == advisory_id)
313 {
314 if let Some(affected) = self.extract_affected(&report.coordinates, vuln) {
315 advisory.affected.push(affected);
316 }
317 }
318 continue;
319 }
320
321 seen_ids.insert(advisory_id.clone());
322
323 let advisory = self.convert_vulnerability(vuln, &report.coordinates);
324 advisories.push(advisory);
325 }
326 }
327
328 advisories
329 }
330
331 fn generate_advisory_id(&self, vuln: &OssVulnerability) -> String {
333 if let Some(ref cve) = vuln.cve {
335 if !cve.is_empty() {
336 return cve.clone();
337 }
338 }
339
340 if let Some(ref name) = vuln.display_name {
342 if name.starts_with("CVE-") {
343 return name.clone();
344 }
345 }
346
347 if let Some(cve) = Self::extract_cve_from_url(&vuln.reference) {
349 return cve;
350 }
351
352 vuln.id.clone()
354 }
355
/// Returns the text after the last `/` of `url` when it starts with `CVE-`.
///
/// The segment is returned verbatim; a string without any `/` is treated as a
/// single segment.
fn extract_cve_from_url(url: &str) -> Option<String> {
    // `rsplit` always yields at least one item, so this is the final segment.
    let last_segment = url.rsplit('/').next()?;
    if last_segment.starts_with("CVE-") {
        Some(last_segment.to_owned())
    } else {
        None
    }
}
365
366 fn convert_vulnerability(&self, vuln: &OssVulnerability, coordinates: &str) -> Advisory {
368 let mut affected = Vec::new();
369
370 if let Some(aff) = self.extract_affected(coordinates, vuln) {
371 affected.push(aff);
372 }
373
374 let mut aliases = Vec::new();
376 if let Some(ref cve) = vuln.cve {
377 if !cve.is_empty() && !cve.starts_with("CVE-") {
378 aliases.push(format!("CVE-{}", cve));
379 } else if !cve.is_empty() {
380 aliases.push(cve.clone());
381 }
382 }
383
384 let advisory_id = self.generate_advisory_id(vuln);
386 if advisory_id.starts_with("CVE-") && !vuln.id.starts_with("CVE-") {
387 aliases.push(vuln.id.clone());
388 }
389
390 let mut references = vec![Reference {
392 reference_type: ReferenceType::Advisory,
393 url: vuln.reference.clone(),
394 }];
395 if let Some(ref ext_refs) = vuln.external_references {
396 for url in ext_refs {
397 references.push(Reference {
398 reference_type: ReferenceType::Web,
399 url: url.clone(),
400 });
401 }
402 }
403
404 let mut db_specific = serde_json::Map::new();
406 if let Some(score) = vuln.cvss_score {
407 db_specific.insert("cvss_score".to_string(), serde_json::json!(score));
408 db_specific.insert(
409 "severity".to_string(),
410 serde_json::json!(Self::cvss_to_severity(score)),
411 );
412 }
413 if let Some(ref vector) = vuln.cvss_vector {
414 db_specific.insert("cvss_vector".to_string(), serde_json::json!(vector));
415 }
416 if let Some(ref cwe) = vuln.cwe {
417 db_specific.insert("cwe_ids".to_string(), serde_json::json!([cwe]));
418 }
419 db_specific.insert("source".to_string(), serde_json::json!("ossindex"));
420
421 Advisory {
422 id: advisory_id,
423 summary: Some(vuln.title.clone()),
424 details: Some(vuln.description.clone()),
425 affected,
426 references,
427 published: None,
428 modified: None,
429 aliases: if aliases.is_empty() {
430 None
431 } else {
432 Some(aliases)
433 },
434 database_specific: Some(serde_json::Value::Object(db_specific)),
435 enrichment: None,
436 }
437 }
438
439 fn extract_affected(&self, coordinates: &str, vuln: &OssVulnerability) -> Option<Affected> {
441 let purl = Purl::parse(coordinates).ok()?;
442
443 let ranges = vuln
444 .version_ranges
445 .as_ref()
446 .map(|ranges| {
447 ranges
448 .iter()
449 .filter_map(|r| Self::parse_version_range(r))
450 .collect()
451 })
452 .unwrap_or_default();
453
454 Some(Affected {
455 package: Package {
456 ecosystem: purl.ecosystem(),
457 name: purl.name.clone(),
458 purl: Some(coordinates.to_string()),
459 },
460 ranges,
461 versions: Vec::new(),
462 ecosystem_specific: None,
463 database_specific: None,
464 })
465 }
466
467 fn parse_version_range(range: &str) -> Option<Range> {
475 let range = range.trim();
476 if range.is_empty() {
477 return None;
478 }
479
480 if !range.contains(',') && !range.starts_with('[') && !range.starts_with('(') {
482 return Some(Range {
483 range_type: RangeType::Semver,
484 events: vec![Event::LastAffected(range.to_string())],
485 repo: None,
486 });
487 }
488
489 let start_inclusive = range.starts_with('[');
491 let end_inclusive = range.ends_with(']');
492
493 let inner = range
495 .trim_start_matches(['[', '('])
496 .trim_end_matches([']', ')']);
497
498 let parts: Vec<&str> = inner.split(',').collect();
499 if parts.len() != 2 {
500 return None;
501 }
502
503 let start = parts[0].trim();
504 let end = parts[1].trim();
505
506 let mut events = Vec::new();
507
508 if !start.is_empty() {
510 if start_inclusive {
511 events.push(Event::Introduced(start.to_string()));
512 } else {
513 events.push(Event::Introduced(start.to_string()));
516 }
517 } else {
518 events.push(Event::Introduced("0".to_string()));
520 }
521
522 if !end.is_empty() {
524 if end_inclusive {
525 events.push(Event::LastAffected(end.to_string()));
528 } else {
529 events.push(Event::Fixed(end.to_string()));
531 }
532 }
533
534 Some(Range {
535 range_type: RangeType::Semver,
536 events,
537 repo: None,
538 })
539 }
540
/// Maps a CVSS score to a severity label.
///
/// `>= 9.0` is `CRITICAL`, `>= 7.0` is `HIGH`, `>= 4.0` is `MEDIUM`, any other
/// positive score is `LOW`. Everything else (zero, negative, NaN) is `NONE`.
fn cvss_to_severity(score: f64) -> &'static str {
    if score >= 9.0 {
        "CRITICAL"
    } else if score >= 7.0 {
        "HIGH"
    } else if score >= 4.0 {
        "MEDIUM"
    } else if score > 0.0 {
        "LOW"
    } else {
        "NONE"
    }
}
551
552 pub fn score_to_severity(score: f64) -> Severity {
554 Severity::from_cvss_score(score)
555 }
556
557 pub fn name(&self) -> &'static str {
559 "ossindex"
560 }
561}
562
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `input`, failing the test with the offending text when it is rejected.
    fn parsed(input: &str) -> Range {
        OssIndexSource::parse_version_range(input)
            .unwrap_or_else(|| panic!("range {:?} should parse", input))
    }

    #[test]
    fn test_parse_version_range_standard() {
        let range = parsed("[1.0.0,2.0.0)");
        assert_eq!(range.range_type, RangeType::Semver);
        assert!(matches!(
            range.events.as_slice(),
            [Event::Introduced(from), Event::Fixed(to)] if from == "1.0.0" && to == "2.0.0"
        ));
    }

    #[test]
    fn test_parse_version_range_inclusive_end() {
        let range = parsed("[1.0.0,2.0.0]");
        assert!(matches!(
            range.events.as_slice(),
            [Event::Introduced(from), Event::LastAffected(to)] if from == "1.0.0" && to == "2.0.0"
        ));
    }

    #[test]
    fn test_parse_version_range_open_start() {
        let range = parsed("(,1.0.0)");
        assert!(matches!(
            range.events.as_slice(),
            [Event::Introduced(from), Event::Fixed(to)] if from == "0" && to == "1.0.0"
        ));
    }

    #[test]
    fn test_parse_version_range_open_end() {
        let range = parsed("[1.0.0,)");
        assert!(matches!(
            range.events.as_slice(),
            [Event::Introduced(from)] if from == "1.0.0"
        ));
    }

    #[test]
    fn test_parse_version_range_exact() {
        let range = parsed("1.0.0");
        assert!(matches!(
            range.events.as_slice(),
            [Event::LastAffected(only)] if only == "1.0.0"
        ));
    }

    #[test]
    fn test_cvss_to_severity() {
        let expectations = [
            (9.5, "CRITICAL"),
            (7.5, "HIGH"),
            (5.0, "MEDIUM"),
            (2.0, "LOW"),
            (0.0, "NONE"),
        ];
        for (score, label) in expectations {
            assert_eq!(OssIndexSource::cvss_to_severity(score), label);
        }
    }

    #[test]
    fn test_extract_cve_from_url() {
        let cve_url = "https://ossindex.sonatype.org/vulnerability/CVE-2021-23337";
        assert_eq!(
            OssIndexSource::extract_cve_from_url(cve_url).as_deref(),
            Some("CVE-2021-23337")
        );

        let sonatype_url = "https://ossindex.sonatype.org/vulnerability/sonatype-2020-1234";
        assert_eq!(OssIndexSource::extract_cve_from_url(sonatype_url), None);
    }

    #[test]
    fn test_purl_integration() {
        let purl = Purl::new("npm", "lodash").with_version("4.17.20");
        assert_eq!(purl.to_string(), "pkg:npm/lodash@4.17.20");
        assert_eq!(purl.ecosystem(), "npm");
        assert_eq!(purl.name, "lodash");
        assert_eq!(purl.version, Some("4.17.20".to_string()));
    }

    #[test]
    fn test_score_to_severity() {
        let expectations = [
            (9.5, Severity::Critical),
            (7.5, Severity::High),
            (5.0, Severity::Medium),
            (2.0, Severity::Low),
            (0.0, Severity::None),
        ];
        for (score, severity) in expectations {
            assert_eq!(OssIndexSource::score_to_severity(score), severity);
        }
    }
}