Skip to main content

vulnera_advisor/
store.rs

1//! Storage backends for advisory data.
2//!
3//! This module provides the [`AdvisoryStore`] trait and implementations for
4//! persisting and querying vulnerability advisories.
5
6use crate::config::StoreConfig;
7use crate::ecosystem::normalize_package_key;
8use crate::error::{AdvisoryError, Result};
9use crate::models::Advisory;
10use async_stream::try_stream;
11use async_trait::async_trait;
12use futures_util::Stream;
13use redis::AsyncCommands;
14use serde::{Deserialize, Serialize};
15use std::io::Write;
16use std::pin::Pin;
17use std::time::Instant;
18use tracing::{info, instrument};
19
20/// Trait for advisory storage backends.
21#[async_trait]
22pub trait AdvisoryStore: Send + Sync {
23    /// Insert or update a batch of advisories.
24    async fn upsert_batch(&self, advisories: &[Advisory], source: &str) -> Result<()>;
25
26    /// Get a single advisory by ID.
27    async fn get(&self, id: &str) -> Result<Option<Advisory>>;
28
29    /// Get all advisories affecting a specific package.
30    async fn get_by_package(&self, ecosystem: &str, package: &str) -> Result<Vec<Advisory>>;
31
32    /// Get the timestamp of the last sync for a source.
33    async fn last_sync(&self, source: &str) -> Result<Option<String>>;
34
35    /// Check the health of the store connection.
36    async fn health_check(&self) -> Result<HealthStatus>;
37
38    /// Get advisories as a stream for memory-efficient processing.
39    async fn get_by_package_stream(
40        &self,
41        ecosystem: &str,
42        package: &str,
43    ) -> Result<Pin<Box<dyn Stream<Item = Result<Advisory>> + Send + '_>>>;
44
45    /// Get multiple advisories by IDs in a batch.
46    async fn get_batch(&self, ids: &[String]) -> Result<Vec<Advisory>>;
47
48    /// Store enrichment data (EPSS/KEV) for a CVE.
49    async fn store_enrichment(&self, cve_id: &str, data: &EnrichmentData) -> Result<()>;
50
51    /// Get enrichment data for a CVE.
52    async fn get_enrichment(&self, cve_id: &str) -> Result<Option<EnrichmentData>>;
53
54    /// Get enrichment data for multiple CVEs.
55    async fn get_enrichment_batch(
56        &self,
57        cve_ids: &[String],
58    ) -> Result<Vec<(String, EnrichmentData)>>;
59
60    /// Update the last sync timestamp for a source.
61    async fn update_sync_timestamp(&self, source: &str) -> Result<()>;
62
63    /// Reset (delete) the sync timestamp for a source, forcing a full re-sync.
64    async fn reset_sync_timestamp(&self, source: &str) -> Result<()>;
65
66    /// Get the count of stored advisories.
67    async fn advisory_count(&self) -> Result<u64>;
68
69    /// Store an OSS Index component report in cache.
70    ///
71    /// # Arguments
72    ///
73    /// * `purl` - The Package URL that was queried
74    /// * `cache` - The cached component report with metadata
75    async fn store_ossindex_cache(&self, purl: &str, cache: &OssIndexCache) -> Result<()>;
76
77    /// Get a cached OSS Index component report.
78    ///
79    /// Returns `None` if not cached or if the cache has expired.
80    async fn get_ossindex_cache(&self, purl: &str) -> Result<Option<OssIndexCache>>;
81
82    /// Invalidate (delete) a cached OSS Index component report.
83    async fn invalidate_ossindex_cache(&self, purl: &str) -> Result<()>;
84
85    /// Invalidate all OSS Index cache entries.
86    async fn invalidate_all_ossindex_cache(&self) -> Result<u64>;
87}
88
89/// Health status of the store.
90#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct HealthStatus {
92    /// Whether the connection is working.
93    pub connected: bool,
94    /// Round-trip latency in milliseconds.
95    pub latency_ms: u64,
96    /// Number of advisory keys (approximate).
97    pub advisory_count: u64,
98    /// Redis server info (version, etc.).
99    pub server_info: Option<String>,
100}
101
102/// Enrichment data stored separately for CVEs.
103#[derive(Debug, Clone, Serialize, Deserialize)]
104pub struct EnrichmentData {
105    /// EPSS score (0.0 - 1.0).
106    pub epss_score: Option<f64>,
107    /// EPSS percentile (0.0 - 1.0).
108    pub epss_percentile: Option<f64>,
109    /// Whether in CISA KEV catalog.
110    pub is_kev: bool,
111    /// KEV due date (RFC3339).
112    pub kev_due_date: Option<String>,
113    /// KEV date added (RFC3339).
114    pub kev_date_added: Option<String>,
115    /// Whether used in ransomware campaigns.
116    pub kev_ransomware: Option<bool>,
117    /// Last updated timestamp.
118    pub updated_at: String,
119}
120
121/// Cached OSS Index component report.
122///
123/// Stores advisories from OSS Index along with
124/// cache metadata for TTL management.
125#[derive(Debug, Clone, Serialize, Deserialize)]
126pub struct OssIndexCache {
127    /// The converted advisories from OSS Index.
128    pub advisories: Vec<crate::models::Advisory>,
129    /// When this was cached.
130    pub cached_at: chrono::DateTime<chrono::Utc>,
131    /// TTL in seconds from cache time.
132    pub ttl_seconds: u64,
133}
134
/// TTL, in seconds, applied to OSS Index cache entries by default (one hour).
const DEFAULT_OSSINDEX_CACHE_TTL: u64 = 60 * 60;
137
138impl OssIndexCache {
139    /// Create a new cache entry with default TTL.
140    pub fn new(advisories: Vec<crate::models::Advisory>) -> Self {
141        Self {
142            advisories,
143            cached_at: chrono::Utc::now(),
144            ttl_seconds: DEFAULT_OSSINDEX_CACHE_TTL,
145        }
146    }
147
148    /// Create a new cache entry with custom TTL.
149    pub fn with_ttl(advisories: Vec<crate::models::Advisory>, ttl_seconds: u64) -> Self {
150        Self {
151            advisories,
152            cached_at: chrono::Utc::now(),
153            ttl_seconds,
154        }
155    }
156
157    /// Check if this cache entry is still valid (not expired).
158    pub fn is_valid(&self) -> bool {
159        !self.is_expired()
160    }
161
162    /// Check if this cache entry has expired.
163    pub fn is_expired(&self) -> bool {
164        let age = chrono::Utc::now().signed_duration_since(self.cached_at);
165        age.num_seconds() >= self.ttl_seconds as i64
166    }
167
168    /// Get the remaining TTL in seconds.
169    pub fn remaining_ttl(&self) -> i64 {
170        let age = chrono::Utc::now().signed_duration_since(self.cached_at);
171        (self.ttl_seconds as i64) - age.num_seconds()
172    }
173}
174
175/// Redis/DragonflyDB storage implementation.
176pub struct DragonflyStore {
177    client: redis::Client,
178    config: StoreConfig,
179}
180
181impl DragonflyStore {
182    /// Create a new store with default configuration.
183    pub fn new(url: &str) -> Result<Self> {
184        Self::with_config(url, StoreConfig::default())
185    }
186
187    /// Create a new store with custom configuration.
188    pub fn with_config(url: &str, config: StoreConfig) -> Result<Self> {
189        let client = redis::Client::open(url)?;
190        Ok(Self { client, config })
191    }
192
193    /// Get the key prefix for this store.
194    pub fn key_prefix(&self) -> &str {
195        &self.config.key_prefix
196    }
197
198    /// Build a key with the configured prefix.
199    fn key(&self, suffix: &str) -> String {
200        format!("{}:{}", self.config.key_prefix, suffix)
201    }
202
203    fn compress(&self, data: &[u8]) -> Result<Vec<u8>> {
204        let mut encoder =
205            zstd::stream::write::Encoder::new(Vec::new(), self.config.compression_level)?;
206        encoder.write_all(data)?;
207        encoder
208            .finish()
209            .map_err(|e| AdvisoryError::compression(e.to_string()))
210    }
211
212    fn decompress(data: &[u8]) -> Result<Vec<u8>> {
213        let mut decoder = zstd::stream::read::Decoder::new(data)?;
214        let mut decoded = Vec::new();
215        std::io::Read::read_to_end(&mut decoder, &mut decoded)?;
216        Ok(decoded)
217    }
218
219    async fn get_connection(&self) -> Result<redis::aio::MultiplexedConnection> {
220        self.client
221            .get_multiplexed_async_connection()
222            .await
223            .map_err(AdvisoryError::from)
224    }
225}
226
227#[async_trait]
228impl AdvisoryStore for DragonflyStore {
229    #[instrument(skip(self, advisories), fields(count = advisories.len()))]
230    async fn upsert_batch(&self, advisories: &[Advisory], source: &str) -> Result<()> {
231        let mut conn = self.get_connection().await?;
232        let mut pipe = redis::pipe();
233
234        for advisory in advisories {
235            let json = serde_json::to_vec(advisory)?;
236            let compressed = self.compress(&json)?;
237
238            let data_key = self.key(&format!("data:{}", advisory.id));
239
240            // Store data with optional TTL
241            if let Some(ttl) = self.config.ttl_seconds {
242                pipe.cmd("SETEX").arg(&data_key).arg(ttl).arg(compressed);
243            } else {
244                pipe.set(&data_key, compressed);
245            }
246
247            // Update index
248            for affected in &advisory.affected {
249                let (ecosystem, package) =
250                    normalize_package_key(&affected.package.ecosystem, &affected.package.name);
251                let idx_key = self.key(&format!("idx:{}:{}", ecosystem, package));
252                pipe.sadd(&idx_key, &advisory.id);
253            }
254        }
255
256        // NOTE: Do NOT update meta timestamp here.
257        // The caller (sync_all) will update it explicitly after verifying success.
258
259        pipe.query_async::<()>(&mut conn).await?;
260        info!("Upserted {} advisories from {}", advisories.len(), source);
261        Ok(())
262    }
263
264    async fn get(&self, id: &str) -> Result<Option<Advisory>> {
265        let mut conn = self.get_connection().await?;
266        let data: Option<Vec<u8>> = conn.get(self.key(&format!("data:{}", id))).await?;
267
268        match data {
269            Some(bytes) => {
270                let decompressed = Self::decompress(&bytes)?;
271                let advisory = serde_json::from_slice(&decompressed)?;
272                Ok(Some(advisory))
273            }
274            None => Ok(None),
275        }
276    }
277
278    async fn get_by_package(&self, ecosystem: &str, package: &str) -> Result<Vec<Advisory>> {
279        let (ecosystem, package) = normalize_package_key(ecosystem, package);
280        let mut conn = self.get_connection().await?;
281        let ids: Vec<String> = conn
282            .smembers(self.key(&format!("idx:{}:{}", ecosystem, package)))
283            .await?;
284
285        // Batch fetch to avoid N+1 round trips and repeated decompress calls
286        self.get_batch(&ids).await
287    }
288
289    async fn last_sync(&self, source: &str) -> Result<Option<String>> {
290        let mut conn = self.get_connection().await?;
291        Ok(conn.get(self.key(&format!("meta:{}", source))).await?)
292    }
293
294    async fn health_check(&self) -> Result<HealthStatus> {
295        let start = Instant::now();
296
297        let mut conn = self.get_connection().await?;
298
299        // Ping to check connection
300        let pong: String = redis::cmd("PING").query_async(&mut conn).await?;
301        let connected = pong == "PONG";
302
303        let latency_ms = start.elapsed().as_millis() as u64;
304
305        // Get approximate key count
306        let advisory_count = self.advisory_count().await.unwrap_or(0);
307
308        // Get server info
309        let info: std::result::Result<String, _> = redis::cmd("INFO")
310            .arg("server")
311            .query_async(&mut conn)
312            .await;
313        let server_info = info.ok().and_then(|s| {
314            s.lines()
315                .find(|l| l.starts_with("redis_version:"))
316                .map(|l| l.to_string())
317        });
318
319        Ok(HealthStatus {
320            connected,
321            latency_ms,
322            advisory_count,
323            server_info,
324        })
325    }
326
327    async fn get_by_package_stream(
328        &self,
329        ecosystem: &str,
330        package: &str,
331    ) -> Result<Pin<Box<dyn Stream<Item = Result<Advisory>> + Send + '_>>> {
332        let (ecosystem, package) = normalize_package_key(ecosystem, package);
333        let idx_key = self.key(&format!("idx:{}:{}", ecosystem, package));
334
335        let stream = try_stream! {
336            let mut conn = self.get_connection().await?;
337
338            // Use SSCAN for memory-efficient iteration
339            let mut cursor = 0u64;
340            loop {
341                let (new_cursor, ids): (u64, Vec<String>) = redis::cmd("SSCAN")
342                    .arg(&idx_key)
343                    .arg(cursor)
344                    .arg("COUNT")
345                    .arg(100)
346                    .query_async(&mut conn)
347                    .await?;
348
349                for id in ids {
350                    if let Some(advisory) = self.get(&id).await? {
351                        yield advisory;
352                    }
353                }
354
355                cursor = new_cursor;
356                if cursor == 0 {
357                    break;
358                }
359            }
360        };
361
362        Ok(Box::pin(stream))
363    }
364
365    async fn get_batch(&self, ids: &[String]) -> Result<Vec<Advisory>> {
366        if ids.is_empty() {
367            return Ok(Vec::new());
368        }
369
370        let mut conn = self.get_connection().await?;
371        let keys: Vec<String> = ids
372            .iter()
373            .map(|id| self.key(&format!("data:{}", id)))
374            .collect();
375
376        let data: Vec<Option<Vec<u8>>> =
377            redis::cmd("MGET").arg(&keys).query_async(&mut conn).await?;
378
379        let mut advisories = Vec::new();
380        for bytes in data.into_iter().flatten() {
381            let decompressed = Self::decompress(&bytes)?;
382            let advisory: Advisory = serde_json::from_slice(&decompressed)?;
383            advisories.push(advisory);
384        }
385
386        Ok(advisories)
387    }
388
389    async fn store_enrichment(&self, cve_id: &str, data: &EnrichmentData) -> Result<()> {
390        let mut conn = self.get_connection().await?;
391        let key = self.key(&format!("enrich:{}", cve_id));
392        let json = serde_json::to_string(data)?;
393
394        if let Some(ttl) = self.config.ttl_seconds {
395            redis::cmd("SETEX")
396                .arg(&key)
397                .arg(ttl)
398                .arg(json)
399                .query_async::<()>(&mut conn)
400                .await?;
401        } else {
402            let _: () = conn.set(&key, json).await?;
403        }
404
405        Ok(())
406    }
407
408    async fn get_enrichment(&self, cve_id: &str) -> Result<Option<EnrichmentData>> {
409        let mut conn = self.get_connection().await?;
410        let key = self.key(&format!("enrich:{}", cve_id));
411        let data: Option<String> = conn.get(&key).await?;
412
413        match data {
414            Some(json) => Ok(Some(serde_json::from_str(&json)?)),
415            None => Ok(None),
416        }
417    }
418
419    async fn get_enrichment_batch(
420        &self,
421        cve_ids: &[String],
422    ) -> Result<Vec<(String, EnrichmentData)>> {
423        if cve_ids.is_empty() {
424            return Ok(Vec::new());
425        }
426
427        let mut conn = self.get_connection().await?;
428        let keys: Vec<String> = cve_ids
429            .iter()
430            .map(|id| self.key(&format!("enrich:{}", id)))
431            .collect();
432
433        let data: Vec<Option<String>> =
434            redis::cmd("MGET").arg(&keys).query_async(&mut conn).await?;
435
436        let mut results = Vec::new();
437        for (cve_id, json_opt) in cve_ids.iter().zip(data) {
438            if let Some(json) = json_opt {
439                if let Ok(enrichment) = serde_json::from_str(&json) {
440                    results.push((cve_id.clone(), enrichment));
441                }
442            }
443        }
444
445        Ok(results)
446    }
447
448    async fn update_sync_timestamp(&self, source: &str) -> Result<()> {
449        let mut conn = self.get_connection().await?;
450        let _: () = conn
451            .set(
452                self.key(&format!("meta:{}", source)),
453                chrono::Utc::now().to_rfc3339(),
454            )
455            .await?;
456        Ok(())
457    }
458
459    async fn reset_sync_timestamp(&self, source: &str) -> Result<()> {
460        let mut conn = self.get_connection().await?;
461        let _: () = conn.del(self.key(&format!("meta:{}", source))).await?;
462        info!("Reset sync timestamp for {}", source);
463        Ok(())
464    }
465
466    async fn advisory_count(&self) -> Result<u64> {
467        let mut conn = self.get_connection().await?;
468        let pattern = self.key("data:*");
469
470        // Use SCAN to count keys matching pattern
471        let mut count = 0u64;
472        let mut cursor = 0u64;
473
474        loop {
475            let (new_cursor, keys): (u64, Vec<String>) = redis::cmd("SCAN")
476                .arg(cursor)
477                .arg("MATCH")
478                .arg(&pattern)
479                .arg("COUNT")
480                .arg(1000)
481                .query_async(&mut conn)
482                .await?;
483
484            count += keys.len() as u64;
485            cursor = new_cursor;
486
487            if cursor == 0 {
488                break;
489            }
490        }
491
492        Ok(count)
493    }
494
495    async fn store_ossindex_cache(&self, purl: &str, cache: &OssIndexCache) -> Result<()> {
496        let mut conn = self.get_connection().await?;
497        let key = self.key(&format!("ossidx:{}", Self::hash_purl(purl)));
498        let json = serde_json::to_string(cache)?;
499
500        // Use the remaining TTL or the configured TTL
501        let ttl = cache.remaining_ttl().max(1) as u64;
502        redis::cmd("SETEX")
503            .arg(&key)
504            .arg(ttl)
505            .arg(json)
506            .query_async::<()>(&mut conn)
507            .await?;
508
509        Ok(())
510    }
511
512    async fn get_ossindex_cache(&self, purl: &str) -> Result<Option<OssIndexCache>> {
513        let mut conn = self.get_connection().await?;
514        let key = self.key(&format!("ossidx:{}", Self::hash_purl(purl)));
515        let data: Option<String> = conn.get(&key).await?;
516
517        match data {
518            Some(json) => {
519                let cache: OssIndexCache = serde_json::from_str(&json)?;
520                // Double-check validity (Redis TTL should handle this, but be safe)
521                if cache.is_valid() {
522                    Ok(Some(cache))
523                } else {
524                    // Cache expired, delete it
525                    let _: () = conn.del(&key).await?;
526                    Ok(None)
527                }
528            }
529            None => Ok(None),
530        }
531    }
532
533    async fn invalidate_ossindex_cache(&self, purl: &str) -> Result<()> {
534        let mut conn = self.get_connection().await?;
535        let key = self.key(&format!("ossidx:{}", Self::hash_purl(purl)));
536        let _: () = conn.del(&key).await?;
537        Ok(())
538    }
539
540    async fn invalidate_all_ossindex_cache(&self) -> Result<u64> {
541        let mut conn = self.get_connection().await?;
542        let pattern = self.key("ossidx:*");
543
544        // Use SCAN to find all OSS Index cache keys
545        let mut deleted = 0u64;
546        let mut cursor = 0u64;
547
548        loop {
549            let (new_cursor, keys): (u64, Vec<String>) = redis::cmd("SCAN")
550                .arg(cursor)
551                .arg("MATCH")
552                .arg(&pattern)
553                .arg("COUNT")
554                .arg(1000)
555                .query_async(&mut conn)
556                .await?;
557
558            if !keys.is_empty() {
559                let count: u64 = redis::cmd("DEL").arg(&keys).query_async(&mut conn).await?;
560                deleted += count;
561            }
562
563            cursor = new_cursor;
564            if cursor == 0 {
565                break;
566            }
567        }
568
569        Ok(deleted)
570    }
571}
572
573impl DragonflyStore {
574    /// Generate a hash key for a PURL string.
575    fn hash_purl(purl: &str) -> String {
576        use std::collections::hash_map::DefaultHasher;
577        use std::hash::{Hash, Hasher};
578
579        let mut hasher = DefaultHasher::new();
580        purl.hash(&mut hasher);
581        format!("{:x}", hasher.finish())
582    }
583}