1use crate::config::StoreConfig;
7use crate::ecosystem::normalize_package_key;
8use crate::error::{AdvisoryError, Result};
9use crate::models::Advisory;
10use async_stream::try_stream;
11use async_trait::async_trait;
12use futures_util::Stream;
13use redis::AsyncCommands;
14use serde::{Deserialize, Serialize};
15use std::io::Write;
16use std::pin::Pin;
17use std::time::Instant;
18use tracing::{info, instrument};
19
20#[async_trait]
22pub trait AdvisoryStore: Send + Sync {
23 async fn upsert_batch(&self, advisories: &[Advisory], source: &str) -> Result<()>;
25
26 async fn get(&self, id: &str) -> Result<Option<Advisory>>;
28
29 async fn get_by_package(&self, ecosystem: &str, package: &str) -> Result<Vec<Advisory>>;
31
32 async fn last_sync(&self, source: &str) -> Result<Option<String>>;
34
35 async fn health_check(&self) -> Result<HealthStatus>;
37
38 async fn get_by_package_stream(
40 &self,
41 ecosystem: &str,
42 package: &str,
43 ) -> Result<Pin<Box<dyn Stream<Item = Result<Advisory>> + Send + '_>>>;
44
45 async fn get_batch(&self, ids: &[String]) -> Result<Vec<Advisory>>;
47
48 async fn store_enrichment(&self, cve_id: &str, data: &EnrichmentData) -> Result<()>;
50
51 async fn get_enrichment(&self, cve_id: &str) -> Result<Option<EnrichmentData>>;
53
54 async fn get_enrichment_batch(
56 &self,
57 cve_ids: &[String],
58 ) -> Result<Vec<(String, EnrichmentData)>>;
59
60 async fn update_sync_timestamp(&self, source: &str) -> Result<()>;
62
63 async fn reset_sync_timestamp(&self, source: &str) -> Result<()>;
65
66 async fn advisory_count(&self) -> Result<u64>;
68
69 async fn store_ossindex_cache(&self, purl: &str, cache: &OssIndexCache) -> Result<()>;
76
77 async fn get_ossindex_cache(&self, purl: &str) -> Result<Option<OssIndexCache>>;
81
82 async fn invalidate_ossindex_cache(&self, purl: &str) -> Result<()>;
84
85 async fn invalidate_all_ossindex_cache(&self) -> Result<u64>;
87}
88
89#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct HealthStatus {
92 pub connected: bool,
94 pub latency_ms: u64,
96 pub advisory_count: u64,
98 pub server_info: Option<String>,
100}
101
102#[derive(Debug, Clone, Serialize, Deserialize)]
104pub struct EnrichmentData {
105 pub epss_score: Option<f64>,
107 pub epss_percentile: Option<f64>,
109 pub is_kev: bool,
111 pub kev_due_date: Option<String>,
113 pub kev_date_added: Option<String>,
115 pub kev_ransomware: Option<bool>,
117 pub updated_at: String,
119}
120
121#[derive(Debug, Clone, Serialize, Deserialize)]
126pub struct OssIndexCache {
127 pub advisories: Vec<crate::models::Advisory>,
129 pub cached_at: chrono::DateTime<chrono::Utc>,
131 pub ttl_seconds: u64,
133}
134
/// Lifetime, in seconds, given to OSS Index cache entries built with
/// `OssIndexCache::new` (one hour).
const DEFAULT_OSSINDEX_CACHE_TTL: u64 = 60 * 60;
137
138impl OssIndexCache {
139 pub fn new(advisories: Vec<crate::models::Advisory>) -> Self {
141 Self {
142 advisories,
143 cached_at: chrono::Utc::now(),
144 ttl_seconds: DEFAULT_OSSINDEX_CACHE_TTL,
145 }
146 }
147
148 pub fn with_ttl(advisories: Vec<crate::models::Advisory>, ttl_seconds: u64) -> Self {
150 Self {
151 advisories,
152 cached_at: chrono::Utc::now(),
153 ttl_seconds,
154 }
155 }
156
157 pub fn is_valid(&self) -> bool {
159 !self.is_expired()
160 }
161
162 pub fn is_expired(&self) -> bool {
164 let age = chrono::Utc::now().signed_duration_since(self.cached_at);
165 age.num_seconds() >= self.ttl_seconds as i64
166 }
167
168 pub fn remaining_ttl(&self) -> i64 {
170 let age = chrono::Utc::now().signed_duration_since(self.cached_at);
171 (self.ttl_seconds as i64) - age.num_seconds()
172 }
173}
174
175pub struct DragonflyStore {
177 client: redis::Client,
178 config: StoreConfig,
179}
180
181impl DragonflyStore {
182 pub fn new(url: &str) -> Result<Self> {
184 Self::with_config(url, StoreConfig::default())
185 }
186
187 pub fn with_config(url: &str, config: StoreConfig) -> Result<Self> {
189 let client = redis::Client::open(url)?;
190 Ok(Self { client, config })
191 }
192
193 pub fn key_prefix(&self) -> &str {
195 &self.config.key_prefix
196 }
197
198 fn key(&self, suffix: &str) -> String {
200 format!("{}:{}", self.config.key_prefix, suffix)
201 }
202
203 fn compress(&self, data: &[u8]) -> Result<Vec<u8>> {
204 let mut encoder =
205 zstd::stream::write::Encoder::new(Vec::new(), self.config.compression_level)?;
206 encoder.write_all(data)?;
207 encoder
208 .finish()
209 .map_err(|e| AdvisoryError::compression(e.to_string()))
210 }
211
212 fn decompress(data: &[u8]) -> Result<Vec<u8>> {
213 let mut decoder = zstd::stream::read::Decoder::new(data)?;
214 let mut decoded = Vec::new();
215 std::io::Read::read_to_end(&mut decoder, &mut decoded)?;
216 Ok(decoded)
217 }
218
219 async fn get_connection(&self) -> Result<redis::aio::MultiplexedConnection> {
220 self.client
221 .get_multiplexed_async_connection()
222 .await
223 .map_err(AdvisoryError::from)
224 }
225}
226
227#[async_trait]
228impl AdvisoryStore for DragonflyStore {
229 #[instrument(skip(self, advisories), fields(count = advisories.len()))]
230 async fn upsert_batch(&self, advisories: &[Advisory], source: &str) -> Result<()> {
231 let mut conn = self.get_connection().await?;
232 let mut pipe = redis::pipe();
233
234 for advisory in advisories {
235 let json = serde_json::to_vec(advisory)?;
236 let compressed = self.compress(&json)?;
237
238 let data_key = self.key(&format!("data:{}", advisory.id));
239
240 if let Some(ttl) = self.config.ttl_seconds {
242 pipe.cmd("SETEX").arg(&data_key).arg(ttl).arg(compressed);
243 } else {
244 pipe.set(&data_key, compressed);
245 }
246
247 for affected in &advisory.affected {
249 let (ecosystem, package) =
250 normalize_package_key(&affected.package.ecosystem, &affected.package.name);
251 let idx_key = self.key(&format!("idx:{}:{}", ecosystem, package));
252 pipe.sadd(&idx_key, &advisory.id);
253 }
254 }
255
256 pipe.query_async::<()>(&mut conn).await?;
260 info!("Upserted {} advisories from {}", advisories.len(), source);
261 Ok(())
262 }
263
264 async fn get(&self, id: &str) -> Result<Option<Advisory>> {
265 let mut conn = self.get_connection().await?;
266 let data: Option<Vec<u8>> = conn.get(self.key(&format!("data:{}", id))).await?;
267
268 match data {
269 Some(bytes) => {
270 let decompressed = Self::decompress(&bytes)?;
271 let advisory = serde_json::from_slice(&decompressed)?;
272 Ok(Some(advisory))
273 }
274 None => Ok(None),
275 }
276 }
277
278 async fn get_by_package(&self, ecosystem: &str, package: &str) -> Result<Vec<Advisory>> {
279 let (ecosystem, package) = normalize_package_key(ecosystem, package);
280 let mut conn = self.get_connection().await?;
281 let ids: Vec<String> = conn
282 .smembers(self.key(&format!("idx:{}:{}", ecosystem, package)))
283 .await?;
284
285 self.get_batch(&ids).await
287 }
288
289 async fn last_sync(&self, source: &str) -> Result<Option<String>> {
290 let mut conn = self.get_connection().await?;
291 Ok(conn.get(self.key(&format!("meta:{}", source))).await?)
292 }
293
294 async fn health_check(&self) -> Result<HealthStatus> {
295 let start = Instant::now();
296
297 let mut conn = self.get_connection().await?;
298
299 let pong: String = redis::cmd("PING").query_async(&mut conn).await?;
301 let connected = pong == "PONG";
302
303 let latency_ms = start.elapsed().as_millis() as u64;
304
305 let advisory_count = self.advisory_count().await.unwrap_or(0);
307
308 let info: std::result::Result<String, _> = redis::cmd("INFO")
310 .arg("server")
311 .query_async(&mut conn)
312 .await;
313 let server_info = info.ok().and_then(|s| {
314 s.lines()
315 .find(|l| l.starts_with("redis_version:"))
316 .map(|l| l.to_string())
317 });
318
319 Ok(HealthStatus {
320 connected,
321 latency_ms,
322 advisory_count,
323 server_info,
324 })
325 }
326
327 async fn get_by_package_stream(
328 &self,
329 ecosystem: &str,
330 package: &str,
331 ) -> Result<Pin<Box<dyn Stream<Item = Result<Advisory>> + Send + '_>>> {
332 let (ecosystem, package) = normalize_package_key(ecosystem, package);
333 let idx_key = self.key(&format!("idx:{}:{}", ecosystem, package));
334
335 let stream = try_stream! {
336 let mut conn = self.get_connection().await?;
337
338 let mut cursor = 0u64;
340 loop {
341 let (new_cursor, ids): (u64, Vec<String>) = redis::cmd("SSCAN")
342 .arg(&idx_key)
343 .arg(cursor)
344 .arg("COUNT")
345 .arg(100)
346 .query_async(&mut conn)
347 .await?;
348
349 for id in ids {
350 if let Some(advisory) = self.get(&id).await? {
351 yield advisory;
352 }
353 }
354
355 cursor = new_cursor;
356 if cursor == 0 {
357 break;
358 }
359 }
360 };
361
362 Ok(Box::pin(stream))
363 }
364
365 async fn get_batch(&self, ids: &[String]) -> Result<Vec<Advisory>> {
366 if ids.is_empty() {
367 return Ok(Vec::new());
368 }
369
370 let mut conn = self.get_connection().await?;
371 let keys: Vec<String> = ids
372 .iter()
373 .map(|id| self.key(&format!("data:{}", id)))
374 .collect();
375
376 let data: Vec<Option<Vec<u8>>> =
377 redis::cmd("MGET").arg(&keys).query_async(&mut conn).await?;
378
379 let mut advisories = Vec::new();
380 for bytes in data.into_iter().flatten() {
381 let decompressed = Self::decompress(&bytes)?;
382 let advisory: Advisory = serde_json::from_slice(&decompressed)?;
383 advisories.push(advisory);
384 }
385
386 Ok(advisories)
387 }
388
389 async fn store_enrichment(&self, cve_id: &str, data: &EnrichmentData) -> Result<()> {
390 let mut conn = self.get_connection().await?;
391 let key = self.key(&format!("enrich:{}", cve_id));
392 let json = serde_json::to_string(data)?;
393
394 if let Some(ttl) = self.config.ttl_seconds {
395 redis::cmd("SETEX")
396 .arg(&key)
397 .arg(ttl)
398 .arg(json)
399 .query_async::<()>(&mut conn)
400 .await?;
401 } else {
402 let _: () = conn.set(&key, json).await?;
403 }
404
405 Ok(())
406 }
407
408 async fn get_enrichment(&self, cve_id: &str) -> Result<Option<EnrichmentData>> {
409 let mut conn = self.get_connection().await?;
410 let key = self.key(&format!("enrich:{}", cve_id));
411 let data: Option<String> = conn.get(&key).await?;
412
413 match data {
414 Some(json) => Ok(Some(serde_json::from_str(&json)?)),
415 None => Ok(None),
416 }
417 }
418
419 async fn get_enrichment_batch(
420 &self,
421 cve_ids: &[String],
422 ) -> Result<Vec<(String, EnrichmentData)>> {
423 if cve_ids.is_empty() {
424 return Ok(Vec::new());
425 }
426
427 let mut conn = self.get_connection().await?;
428 let keys: Vec<String> = cve_ids
429 .iter()
430 .map(|id| self.key(&format!("enrich:{}", id)))
431 .collect();
432
433 let data: Vec<Option<String>> =
434 redis::cmd("MGET").arg(&keys).query_async(&mut conn).await?;
435
436 let mut results = Vec::new();
437 for (cve_id, json_opt) in cve_ids.iter().zip(data) {
438 if let Some(json) = json_opt {
439 if let Ok(enrichment) = serde_json::from_str(&json) {
440 results.push((cve_id.clone(), enrichment));
441 }
442 }
443 }
444
445 Ok(results)
446 }
447
448 async fn update_sync_timestamp(&self, source: &str) -> Result<()> {
449 let mut conn = self.get_connection().await?;
450 let _: () = conn
451 .set(
452 self.key(&format!("meta:{}", source)),
453 chrono::Utc::now().to_rfc3339(),
454 )
455 .await?;
456 Ok(())
457 }
458
459 async fn reset_sync_timestamp(&self, source: &str) -> Result<()> {
460 let mut conn = self.get_connection().await?;
461 let _: () = conn.del(self.key(&format!("meta:{}", source))).await?;
462 info!("Reset sync timestamp for {}", source);
463 Ok(())
464 }
465
466 async fn advisory_count(&self) -> Result<u64> {
467 let mut conn = self.get_connection().await?;
468 let pattern = self.key("data:*");
469
470 let mut count = 0u64;
472 let mut cursor = 0u64;
473
474 loop {
475 let (new_cursor, keys): (u64, Vec<String>) = redis::cmd("SCAN")
476 .arg(cursor)
477 .arg("MATCH")
478 .arg(&pattern)
479 .arg("COUNT")
480 .arg(1000)
481 .query_async(&mut conn)
482 .await?;
483
484 count += keys.len() as u64;
485 cursor = new_cursor;
486
487 if cursor == 0 {
488 break;
489 }
490 }
491
492 Ok(count)
493 }
494
495 async fn store_ossindex_cache(&self, purl: &str, cache: &OssIndexCache) -> Result<()> {
496 let mut conn = self.get_connection().await?;
497 let key = self.key(&format!("ossidx:{}", Self::hash_purl(purl)));
498 let json = serde_json::to_string(cache)?;
499
500 let ttl = cache.remaining_ttl().max(1) as u64;
502 redis::cmd("SETEX")
503 .arg(&key)
504 .arg(ttl)
505 .arg(json)
506 .query_async::<()>(&mut conn)
507 .await?;
508
509 Ok(())
510 }
511
512 async fn get_ossindex_cache(&self, purl: &str) -> Result<Option<OssIndexCache>> {
513 let mut conn = self.get_connection().await?;
514 let key = self.key(&format!("ossidx:{}", Self::hash_purl(purl)));
515 let data: Option<String> = conn.get(&key).await?;
516
517 match data {
518 Some(json) => {
519 let cache: OssIndexCache = serde_json::from_str(&json)?;
520 if cache.is_valid() {
522 Ok(Some(cache))
523 } else {
524 let _: () = conn.del(&key).await?;
526 Ok(None)
527 }
528 }
529 None => Ok(None),
530 }
531 }
532
533 async fn invalidate_ossindex_cache(&self, purl: &str) -> Result<()> {
534 let mut conn = self.get_connection().await?;
535 let key = self.key(&format!("ossidx:{}", Self::hash_purl(purl)));
536 let _: () = conn.del(&key).await?;
537 Ok(())
538 }
539
540 async fn invalidate_all_ossindex_cache(&self) -> Result<u64> {
541 let mut conn = self.get_connection().await?;
542 let pattern = self.key("ossidx:*");
543
544 let mut deleted = 0u64;
546 let mut cursor = 0u64;
547
548 loop {
549 let (new_cursor, keys): (u64, Vec<String>) = redis::cmd("SCAN")
550 .arg(cursor)
551 .arg("MATCH")
552 .arg(&pattern)
553 .arg("COUNT")
554 .arg(1000)
555 .query_async(&mut conn)
556 .await?;
557
558 if !keys.is_empty() {
559 let count: u64 = redis::cmd("DEL").arg(&keys).query_async(&mut conn).await?;
560 deleted += count;
561 }
562
563 cursor = new_cursor;
564 if cursor == 0 {
565 break;
566 }
567 }
568
569 Ok(deleted)
570 }
571}
572
573impl DragonflyStore {
574 fn hash_purl(purl: &str) -> String {
576 use std::collections::hash_map::DefaultHasher;
577 use std::hash::{Hash, Hasher};
578
579 let mut hasher = DefaultHasher::new();
580 purl.hash(&mut hasher);
581 format!("{:x}", hasher.finish())
582 }
583}