vulnera_advisor/
store.rs

1//! Storage backends for advisory data.
2//!
3//! This module provides the [`AdvisoryStore`] trait and implementations for
4//! persisting and querying vulnerability advisories.
5
6use crate::config::StoreConfig;
7use crate::error::{AdvisoryError, Result};
8use crate::models::Advisory;
9use async_stream::try_stream;
10use async_trait::async_trait;
11use futures_util::Stream;
12use redis::AsyncCommands;
13use serde::{Deserialize, Serialize};
14use std::io::Write;
15use std::pin::Pin;
16use std::time::Instant;
17use tracing::{info, instrument};
18
19/// Trait for advisory storage backends.
20#[async_trait]
21pub trait AdvisoryStore: Send + Sync {
22    /// Insert or update a batch of advisories.
23    async fn upsert_batch(&self, advisories: &[Advisory], source: &str) -> Result<()>;
24
25    /// Get a single advisory by ID.
26    async fn get(&self, id: &str) -> Result<Option<Advisory>>;
27
28    /// Get all advisories affecting a specific package.
29    async fn get_by_package(&self, ecosystem: &str, package: &str) -> Result<Vec<Advisory>>;
30
31    /// Get the timestamp of the last sync for a source.
32    async fn last_sync(&self, source: &str) -> Result<Option<String>>;
33
34    /// Check the health of the store connection.
35    async fn health_check(&self) -> Result<HealthStatus>;
36
37    /// Get advisories as a stream for memory-efficient processing.
38    async fn get_by_package_stream(
39        &self,
40        ecosystem: &str,
41        package: &str,
42    ) -> Result<Pin<Box<dyn Stream<Item = Result<Advisory>> + Send + '_>>>;
43
44    /// Get multiple advisories by IDs in a batch.
45    async fn get_batch(&self, ids: &[String]) -> Result<Vec<Advisory>>;
46
47    /// Store enrichment data (EPSS/KEV) for a CVE.
48    async fn store_enrichment(&self, cve_id: &str, data: &EnrichmentData) -> Result<()>;
49
50    /// Get enrichment data for a CVE.
51    async fn get_enrichment(&self, cve_id: &str) -> Result<Option<EnrichmentData>>;
52
53    /// Get enrichment data for multiple CVEs.
54    async fn get_enrichment_batch(
55        &self,
56        cve_ids: &[String],
57    ) -> Result<Vec<(String, EnrichmentData)>>;
58
59    /// Update the last sync timestamp for a source.
60    async fn update_sync_timestamp(&self, source: &str) -> Result<()>;
61
62    /// Reset (delete) the sync timestamp for a source, forcing a full re-sync.
63    async fn reset_sync_timestamp(&self, source: &str) -> Result<()>;
64
65    /// Get the count of stored advisories.
66    async fn advisory_count(&self) -> Result<u64>;
67
68    /// Store an OSS Index component report in cache.
69    ///
70    /// # Arguments
71    ///
72    /// * `purl` - The Package URL that was queried
73    /// * `cache` - The cached component report with metadata
74    async fn store_ossindex_cache(&self, purl: &str, cache: &OssIndexCache) -> Result<()>;
75
76    /// Get a cached OSS Index component report.
77    ///
78    /// Returns `None` if not cached or if the cache has expired.
79    async fn get_ossindex_cache(&self, purl: &str) -> Result<Option<OssIndexCache>>;
80
81    /// Invalidate (delete) a cached OSS Index component report.
82    async fn invalidate_ossindex_cache(&self, purl: &str) -> Result<()>;
83
84    /// Invalidate all OSS Index cache entries.
85    async fn invalidate_all_ossindex_cache(&self) -> Result<u64>;
86}
87
88/// Health status of the store.
89#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct HealthStatus {
91    /// Whether the connection is working.
92    pub connected: bool,
93    /// Round-trip latency in milliseconds.
94    pub latency_ms: u64,
95    /// Number of advisory keys (approximate).
96    pub advisory_count: u64,
97    /// Redis server info (version, etc.).
98    pub server_info: Option<String>,
99}
100
101/// Enrichment data stored separately for CVEs.
102#[derive(Debug, Clone, Serialize, Deserialize)]
103pub struct EnrichmentData {
104    /// EPSS score (0.0 - 1.0).
105    pub epss_score: Option<f64>,
106    /// EPSS percentile (0.0 - 1.0).
107    pub epss_percentile: Option<f64>,
108    /// Whether in CISA KEV catalog.
109    pub is_kev: bool,
110    /// KEV due date (RFC3339).
111    pub kev_due_date: Option<String>,
112    /// KEV date added (RFC3339).
113    pub kev_date_added: Option<String>,
114    /// Whether used in ransomware campaigns.
115    pub kev_ransomware: Option<bool>,
116    /// Last updated timestamp.
117    pub updated_at: String,
118}
119
120/// Cached OSS Index component report.
121///
122/// Stores advisories from OSS Index along with
123/// cache metadata for TTL management.
124#[derive(Debug, Clone, Serialize, Deserialize)]
125pub struct OssIndexCache {
126    /// The converted advisories from OSS Index.
127    pub advisories: Vec<crate::models::Advisory>,
128    /// When this was cached.
129    pub cached_at: chrono::DateTime<chrono::Utc>,
130    /// TTL in seconds from cache time.
131    pub ttl_seconds: u64,
132}
133
/// TTL, in seconds, given to OSS Index cache entries when the caller does not
/// pick one: one hour.
const DEFAULT_OSSINDEX_CACHE_TTL: u64 = 60 * 60;
136
137impl OssIndexCache {
138    /// Create a new cache entry with default TTL.
139    pub fn new(advisories: Vec<crate::models::Advisory>) -> Self {
140        Self {
141            advisories,
142            cached_at: chrono::Utc::now(),
143            ttl_seconds: DEFAULT_OSSINDEX_CACHE_TTL,
144        }
145    }
146
147    /// Create a new cache entry with custom TTL.
148    pub fn with_ttl(advisories: Vec<crate::models::Advisory>, ttl_seconds: u64) -> Self {
149        Self {
150            advisories,
151            cached_at: chrono::Utc::now(),
152            ttl_seconds,
153        }
154    }
155
156    /// Check if this cache entry is still valid (not expired).
157    pub fn is_valid(&self) -> bool {
158        !self.is_expired()
159    }
160
161    /// Check if this cache entry has expired.
162    pub fn is_expired(&self) -> bool {
163        let age = chrono::Utc::now().signed_duration_since(self.cached_at);
164        age.num_seconds() >= self.ttl_seconds as i64
165    }
166
167    /// Get the remaining TTL in seconds.
168    pub fn remaining_ttl(&self) -> i64 {
169        let age = chrono::Utc::now().signed_duration_since(self.cached_at);
170        (self.ttl_seconds as i64) - age.num_seconds()
171    }
172}
173
174/// Redis/DragonflyDB storage implementation.
175pub struct DragonflyStore {
176    client: redis::Client,
177    config: StoreConfig,
178}
179
180impl DragonflyStore {
181    /// Create a new store with default configuration.
182    pub fn new(url: &str) -> Result<Self> {
183        Self::with_config(url, StoreConfig::default())
184    }
185
186    /// Create a new store with custom configuration.
187    pub fn with_config(url: &str, config: StoreConfig) -> Result<Self> {
188        let client = redis::Client::open(url)?;
189        Ok(Self { client, config })
190    }
191
192    /// Get the key prefix for this store.
193    pub fn key_prefix(&self) -> &str {
194        &self.config.key_prefix
195    }
196
197    /// Build a key with the configured prefix.
198    fn key(&self, suffix: &str) -> String {
199        format!("{}:{}", self.config.key_prefix, suffix)
200    }
201
202    fn compress(&self, data: &[u8]) -> Result<Vec<u8>> {
203        let mut encoder =
204            zstd::stream::write::Encoder::new(Vec::new(), self.config.compression_level)?;
205        encoder.write_all(data)?;
206        encoder
207            .finish()
208            .map_err(|e| AdvisoryError::compression(e.to_string()))
209    }
210
211    fn decompress(data: &[u8]) -> Result<Vec<u8>> {
212        let mut decoder = zstd::stream::read::Decoder::new(data)?;
213        let mut decoded = Vec::new();
214        std::io::Read::read_to_end(&mut decoder, &mut decoded)?;
215        Ok(decoded)
216    }
217
218    async fn get_connection(&self) -> Result<redis::aio::MultiplexedConnection> {
219        self.client
220            .get_multiplexed_async_connection()
221            .await
222            .map_err(AdvisoryError::from)
223    }
224}
225
226#[async_trait]
227impl AdvisoryStore for DragonflyStore {
228    #[instrument(skip(self, advisories), fields(count = advisories.len()))]
229    async fn upsert_batch(&self, advisories: &[Advisory], source: &str) -> Result<()> {
230        let mut conn = self.get_connection().await?;
231        let mut pipe = redis::pipe();
232
233        for advisory in advisories {
234            let json = serde_json::to_vec(advisory)?;
235            let compressed = self.compress(&json)?;
236
237            let data_key = self.key(&format!("data:{}", advisory.id));
238
239            // Store data with optional TTL
240            if let Some(ttl) = self.config.ttl_seconds {
241                pipe.cmd("SETEX").arg(&data_key).arg(ttl).arg(compressed);
242            } else {
243                pipe.set(&data_key, compressed);
244            }
245
246            // Update index
247            for affected in &advisory.affected {
248                let idx_key = self.key(&format!(
249                    "idx:{}:{}",
250                    affected.package.ecosystem, affected.package.name
251                ));
252                pipe.sadd(&idx_key, &advisory.id);
253            }
254        }
255
256        // NOTE: Do NOT update meta timestamp here.
257        // The caller (sync_all) will update it explicitly after verifying success.
258
259        pipe.query_async::<()>(&mut conn).await?;
260        info!("Upserted {} advisories from {}", advisories.len(), source);
261        Ok(())
262    }
263
264    async fn get(&self, id: &str) -> Result<Option<Advisory>> {
265        let mut conn = self.get_connection().await?;
266        let data: Option<Vec<u8>> = conn.get(self.key(&format!("data:{}", id))).await?;
267
268        match data {
269            Some(bytes) => {
270                let decompressed = Self::decompress(&bytes)?;
271                let advisory = serde_json::from_slice(&decompressed)?;
272                Ok(Some(advisory))
273            }
274            None => Ok(None),
275        }
276    }
277
278    async fn get_by_package(&self, ecosystem: &str, package: &str) -> Result<Vec<Advisory>> {
279        let mut conn = self.get_connection().await?;
280        let ids: Vec<String> = conn
281            .smembers(self.key(&format!("idx:{}:{}", ecosystem, package)))
282            .await?;
283
284        let mut advisories = Vec::new();
285        for id in ids {
286            if let Some(advisory) = self.get(&id).await? {
287                advisories.push(advisory);
288            }
289        }
290        Ok(advisories)
291    }
292
293    async fn last_sync(&self, source: &str) -> Result<Option<String>> {
294        let mut conn = self.get_connection().await?;
295        Ok(conn.get(self.key(&format!("meta:{}", source))).await?)
296    }
297
298    async fn health_check(&self) -> Result<HealthStatus> {
299        let start = Instant::now();
300
301        let mut conn = self.get_connection().await?;
302
303        // Ping to check connection
304        let pong: String = redis::cmd("PING").query_async(&mut conn).await?;
305        let connected = pong == "PONG";
306
307        let latency_ms = start.elapsed().as_millis() as u64;
308
309        // Get approximate key count
310        let advisory_count = self.advisory_count().await.unwrap_or(0);
311
312        // Get server info
313        let info: std::result::Result<String, _> = redis::cmd("INFO")
314            .arg("server")
315            .query_async(&mut conn)
316            .await;
317        let server_info = info.ok().and_then(|s| {
318            s.lines()
319                .find(|l| l.starts_with("redis_version:"))
320                .map(|l| l.to_string())
321        });
322
323        Ok(HealthStatus {
324            connected,
325            latency_ms,
326            advisory_count,
327            server_info,
328        })
329    }
330
331    async fn get_by_package_stream(
332        &self,
333        ecosystem: &str,
334        package: &str,
335    ) -> Result<Pin<Box<dyn Stream<Item = Result<Advisory>> + Send + '_>>> {
336        let idx_key = self.key(&format!("idx:{}:{}", ecosystem, package));
337
338        let stream = try_stream! {
339            let mut conn = self.get_connection().await?;
340
341            // Use SSCAN for memory-efficient iteration
342            let mut cursor = 0u64;
343            loop {
344                let (new_cursor, ids): (u64, Vec<String>) = redis::cmd("SSCAN")
345                    .arg(&idx_key)
346                    .arg(cursor)
347                    .arg("COUNT")
348                    .arg(100)
349                    .query_async(&mut conn)
350                    .await?;
351
352                for id in ids {
353                    if let Some(advisory) = self.get(&id).await? {
354                        yield advisory;
355                    }
356                }
357
358                cursor = new_cursor;
359                if cursor == 0 {
360                    break;
361                }
362            }
363        };
364
365        Ok(Box::pin(stream))
366    }
367
368    async fn get_batch(&self, ids: &[String]) -> Result<Vec<Advisory>> {
369        if ids.is_empty() {
370            return Ok(Vec::new());
371        }
372
373        let mut conn = self.get_connection().await?;
374        let keys: Vec<String> = ids
375            .iter()
376            .map(|id| self.key(&format!("data:{}", id)))
377            .collect();
378
379        let data: Vec<Option<Vec<u8>>> =
380            redis::cmd("MGET").arg(&keys).query_async(&mut conn).await?;
381
382        let mut advisories = Vec::new();
383        for bytes in data.into_iter().flatten() {
384            let decompressed = Self::decompress(&bytes)?;
385            let advisory: Advisory = serde_json::from_slice(&decompressed)?;
386            advisories.push(advisory);
387        }
388
389        Ok(advisories)
390    }
391
392    async fn store_enrichment(&self, cve_id: &str, data: &EnrichmentData) -> Result<()> {
393        let mut conn = self.get_connection().await?;
394        let key = self.key(&format!("enrich:{}", cve_id));
395        let json = serde_json::to_string(data)?;
396
397        if let Some(ttl) = self.config.ttl_seconds {
398            redis::cmd("SETEX")
399                .arg(&key)
400                .arg(ttl)
401                .arg(json)
402                .query_async::<()>(&mut conn)
403                .await?;
404        } else {
405            let _: () = conn.set(&key, json).await?;
406        }
407
408        Ok(())
409    }
410
411    async fn get_enrichment(&self, cve_id: &str) -> Result<Option<EnrichmentData>> {
412        let mut conn = self.get_connection().await?;
413        let key = self.key(&format!("enrich:{}", cve_id));
414        let data: Option<String> = conn.get(&key).await?;
415
416        match data {
417            Some(json) => Ok(Some(serde_json::from_str(&json)?)),
418            None => Ok(None),
419        }
420    }
421
422    async fn get_enrichment_batch(
423        &self,
424        cve_ids: &[String],
425    ) -> Result<Vec<(String, EnrichmentData)>> {
426        if cve_ids.is_empty() {
427            return Ok(Vec::new());
428        }
429
430        let mut conn = self.get_connection().await?;
431        let keys: Vec<String> = cve_ids
432            .iter()
433            .map(|id| self.key(&format!("enrich:{}", id)))
434            .collect();
435
436        let data: Vec<Option<String>> =
437            redis::cmd("MGET").arg(&keys).query_async(&mut conn).await?;
438
439        let mut results = Vec::new();
440        for (cve_id, json_opt) in cve_ids.iter().zip(data) {
441            if let Some(json) = json_opt {
442                if let Ok(enrichment) = serde_json::from_str(&json) {
443                    results.push((cve_id.clone(), enrichment));
444                }
445            }
446        }
447
448        Ok(results)
449    }
450
451    async fn update_sync_timestamp(&self, source: &str) -> Result<()> {
452        let mut conn = self.get_connection().await?;
453        let _: () = conn
454            .set(
455                self.key(&format!("meta:{}", source)),
456                chrono::Utc::now().to_rfc3339(),
457            )
458            .await?;
459        Ok(())
460    }
461
462    async fn reset_sync_timestamp(&self, source: &str) -> Result<()> {
463        let mut conn = self.get_connection().await?;
464        let _: () = conn.del(self.key(&format!("meta:{}", source))).await?;
465        info!("Reset sync timestamp for {}", source);
466        Ok(())
467    }
468
469    async fn advisory_count(&self) -> Result<u64> {
470        let mut conn = self.get_connection().await?;
471        let pattern = self.key("data:*");
472
473        // Use SCAN to count keys matching pattern
474        let mut count = 0u64;
475        let mut cursor = 0u64;
476
477        loop {
478            let (new_cursor, keys): (u64, Vec<String>) = redis::cmd("SCAN")
479                .arg(cursor)
480                .arg("MATCH")
481                .arg(&pattern)
482                .arg("COUNT")
483                .arg(1000)
484                .query_async(&mut conn)
485                .await?;
486
487            count += keys.len() as u64;
488            cursor = new_cursor;
489
490            if cursor == 0 {
491                break;
492            }
493        }
494
495        Ok(count)
496    }
497
498    async fn store_ossindex_cache(&self, purl: &str, cache: &OssIndexCache) -> Result<()> {
499        let mut conn = self.get_connection().await?;
500        let key = self.key(&format!("ossidx:{}", Self::hash_purl(purl)));
501        let json = serde_json::to_string(cache)?;
502
503        // Use the remaining TTL or the configured TTL
504        let ttl = cache.remaining_ttl().max(1) as u64;
505        redis::cmd("SETEX")
506            .arg(&key)
507            .arg(ttl)
508            .arg(json)
509            .query_async::<()>(&mut conn)
510            .await?;
511
512        Ok(())
513    }
514
515    async fn get_ossindex_cache(&self, purl: &str) -> Result<Option<OssIndexCache>> {
516        let mut conn = self.get_connection().await?;
517        let key = self.key(&format!("ossidx:{}", Self::hash_purl(purl)));
518        let data: Option<String> = conn.get(&key).await?;
519
520        match data {
521            Some(json) => {
522                let cache: OssIndexCache = serde_json::from_str(&json)?;
523                // Double-check validity (Redis TTL should handle this, but be safe)
524                if cache.is_valid() {
525                    Ok(Some(cache))
526                } else {
527                    // Cache expired, delete it
528                    let _: () = conn.del(&key).await?;
529                    Ok(None)
530                }
531            }
532            None => Ok(None),
533        }
534    }
535
536    async fn invalidate_ossindex_cache(&self, purl: &str) -> Result<()> {
537        let mut conn = self.get_connection().await?;
538        let key = self.key(&format!("ossidx:{}", Self::hash_purl(purl)));
539        let _: () = conn.del(&key).await?;
540        Ok(())
541    }
542
543    async fn invalidate_all_ossindex_cache(&self) -> Result<u64> {
544        let mut conn = self.get_connection().await?;
545        let pattern = self.key("ossidx:*");
546
547        // Use SCAN to find all OSS Index cache keys
548        let mut deleted = 0u64;
549        let mut cursor = 0u64;
550
551        loop {
552            let (new_cursor, keys): (u64, Vec<String>) = redis::cmd("SCAN")
553                .arg(cursor)
554                .arg("MATCH")
555                .arg(&pattern)
556                .arg("COUNT")
557                .arg(1000)
558                .query_async(&mut conn)
559                .await?;
560
561            if !keys.is_empty() {
562                let count: u64 = redis::cmd("DEL").arg(&keys).query_async(&mut conn).await?;
563                deleted += count;
564            }
565
566            cursor = new_cursor;
567            if cursor == 0 {
568                break;
569            }
570        }
571
572        Ok(deleted)
573    }
574}
575
576impl DragonflyStore {
577    /// Generate a hash key for a PURL string.
578    fn hash_purl(purl: &str) -> String {
579        use std::collections::hash_map::DefaultHasher;
580        use std::hash::{Hash, Hasher};
581
582        let mut hasher = DefaultHasher::new();
583        purl.hash(&mut hasher);
584        format!("{:x}", hasher.finish())
585    }
586}