Skip to main content

sbom_tools/enrichment/
cache.rs

1//! File-based cache for vulnerability data.
2
3use crate::error::Result;
4use crate::model::VulnerabilityRef;
5use sha2::{Digest, Sha256};
6use std::fs;
7use std::path::PathBuf;
8use std::time::Duration;
9
/// Identity of a single vulnerability lookup, used to address cache entries.
///
/// Every field participates in equality, hashing, and the on-disk filename
/// produced by [`CacheKey::to_filename`].
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct CacheKey {
    /// Package URL; the preferred identifier when one is available.
    pub purl: Option<String>,
    /// Component name.
    pub name: String,
    /// Package ecosystem (e.g. npm, pypi).
    pub ecosystem: Option<String>,
    /// Component version.
    pub version: Option<String>,
}
22
23impl CacheKey {
24    /// Create a cache key from component data.
25    #[must_use]
26    pub const fn new(
27        purl: Option<String>,
28        name: String,
29        ecosystem: Option<String>,
30        version: Option<String>,
31    ) -> Self {
32        Self {
33            purl,
34            name,
35            ecosystem,
36            version,
37        }
38    }
39
40    /// Convert to a filesystem-safe filename using SHA256 hash.
41    #[must_use]
42    pub fn to_filename(&self) -> String {
43        let mut hasher = Sha256::new();
44        hasher.update(format!(
45            "purl:{:?}|name:{}|eco:{:?}|ver:{:?}",
46            self.purl, self.name, self.ecosystem, self.version
47        ));
48        let hash = hasher.finalize();
49        format!("{hash:x}.json")
50    }
51
52    /// Check if this key can be used for an OSV query.
53    #[must_use]
54    pub const fn is_queryable(&self) -> bool {
55        // Need either a PURL or name + ecosystem + version
56        self.purl.is_some() || (self.ecosystem.is_some() && self.version.is_some())
57    }
58}
59
/// Directory-backed cache of vulnerability lookups whose entries go stale
/// after a fixed time-to-live.
pub struct FileCache {
    /// Directory holding one JSON file per cached key.
    cache_dir: PathBuf,
    /// Maximum age an entry may reach before it is treated as expired.
    ttl: Duration,
}
67
68impl FileCache {
69    /// Create a new file cache.
70    pub fn new(cache_dir: PathBuf, ttl: Duration) -> Result<Self> {
71        // Ensure cache directory exists
72        if !cache_dir.exists() {
73            fs::create_dir_all(&cache_dir)?;
74        }
75        Ok(Self { cache_dir, ttl })
76    }
77
78    /// Get cached vulnerabilities for a key.
79    ///
80    /// Returns None if not cached or cache is expired.
81    #[must_use]
82    pub fn get(&self, key: &CacheKey) -> Option<Vec<VulnerabilityRef>> {
83        let path = self.cache_dir.join(key.to_filename());
84
85        // Check if file exists
86        let metadata = fs::metadata(&path).ok()?;
87
88        // Check TTL
89        let modified = metadata.modified().ok()?;
90        let age = modified.elapsed().ok()?;
91        if age > self.ttl {
92            // Cache expired, remove it
93            let _ = fs::remove_file(&path);
94            return None;
95        }
96
97        // Read and parse
98        let data = fs::read_to_string(&path).ok()?;
99        serde_json::from_str(&data).ok()
100    }
101
102    /// Store vulnerabilities in the cache.
103    pub fn set(&self, key: &CacheKey, vulns: &[VulnerabilityRef]) -> Result<()> {
104        let path = self.cache_dir.join(key.to_filename());
105        let data = serde_json::to_string(vulns)?;
106        fs::write(path, data)?;
107        Ok(())
108    }
109
110    /// Remove a cached entry.
111    pub fn remove(&self, key: &CacheKey) -> Result<()> {
112        let path = self.cache_dir.join(key.to_filename());
113        if path.exists() {
114            fs::remove_file(path)?;
115        }
116        Ok(())
117    }
118
119    /// Clear all cached entries.
120    pub fn clear(&self) -> Result<()> {
121        if self.cache_dir.exists() {
122            for entry in fs::read_dir(&self.cache_dir)? {
123                let entry = entry?;
124                if entry.path().extension().is_some_and(|e| e == "json") {
125                    let _ = fs::remove_file(entry.path());
126                }
127            }
128        }
129        Ok(())
130    }
131
132    /// Get cache statistics.
133    #[must_use]
134    pub fn stats(&self) -> CacheStats {
135        let mut stats = CacheStats::default();
136
137        if let Ok(entries) = fs::read_dir(&self.cache_dir) {
138            for entry in entries.flatten() {
139                if entry.path().extension().is_some_and(|e| e == "json") {
140                    stats.total_entries += 1;
141                    if let Ok(metadata) = entry.metadata() {
142                        stats.total_size += metadata.len();
143
144                        // Check if expired
145                        if let Ok(modified) = metadata.modified()
146                            && let Ok(age) = modified.elapsed()
147                            && age > self.ttl
148                        {
149                            stats.expired_entries += 1;
150                        }
151                    }
152                }
153            }
154        }
155
156        stats
157    }
158}
159
/// Cache statistics, as reported by [`FileCache::stats`].
///
/// This is plain data, so it is `Copy` and comparable. That lets callers keep
/// and compare snapshots, e.g. before and after a `clear`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct CacheStats {
    /// Number of cache entry files found (files with a `.json` extension).
    pub total_entries: usize,
    /// How many of those entries are older than the cache TTL.
    pub expired_entries: usize,
    /// Combined size of the entry files whose metadata could be read, in bytes.
    pub total_size: u64,
}
170
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `CacheKey` from borrowed string data.
    fn make_key(purl: Option<&str>, name: &str, eco: Option<&str>, ver: Option<&str>) -> CacheKey {
        CacheKey::new(
            purl.map(String::from),
            name.to_string(),
            eco.map(String::from),
            ver.map(String::from),
        )
    }

    #[test]
    fn test_cache_key_filename_deterministic() {
        let key = make_key(Some("pkg:npm/foo@1.0"), "foo", Some("npm"), Some("1.0"));
        let f1 = key.to_filename();
        let f2 = key.to_filename();
        assert_eq!(f1, f2);
        assert!(f1.ends_with(".json"));
    }

    #[test]
    fn test_cache_key_filename_different() {
        let k1 = make_key(Some("pkg:npm/foo@1.0"), "foo", Some("npm"), Some("1.0"));
        let k2 = make_key(Some("pkg:npm/bar@1.0"), "bar", Some("npm"), Some("1.0"));
        assert_ne!(k1.to_filename(), k2.to_filename());
    }

    #[test]
    fn test_cache_key_filename_distinguishes_none_and_empty() {
        // A missing version and an empty version must not share a cache entry.
        let none = make_key(None, "foo", Some("npm"), None);
        let empty = make_key(None, "foo", Some("npm"), Some(""));
        assert_ne!(none.to_filename(), empty.to_filename());
    }

    #[test]
    fn test_cache_key_is_queryable_purl() {
        let key = make_key(Some("pkg:npm/foo@1.0"), "foo", None, None);
        assert!(key.is_queryable());
    }

    #[test]
    fn test_cache_key_is_queryable_eco_ver() {
        let key = make_key(None, "foo", Some("npm"), Some("1.0"));
        assert!(key.is_queryable());
    }

    #[test]
    fn test_cache_key_is_queryable_name_only() {
        let key = make_key(None, "foo", None, None);
        assert!(!key.is_queryable());
    }

    #[test]
    fn test_cache_key_is_queryable_eco_without_version() {
        let key = make_key(None, "foo", Some("npm"), None);
        assert!(!key.is_queryable());
    }

    #[test]
    fn test_file_cache_new_creates_dir() {
        let tmp = tempfile::tempdir().unwrap();
        let cache_dir = tmp.path().join("vuln_cache");
        assert!(!cache_dir.exists());
        let _cache = FileCache::new(cache_dir.clone(), Duration::from_secs(3600)).unwrap();
        assert!(cache_dir.exists());
    }

    #[test]
    fn test_file_cache_set_get_roundtrip() {
        let tmp = tempfile::tempdir().unwrap();
        let cache = FileCache::new(tmp.path().to_path_buf(), Duration::from_secs(3600)).unwrap();
        let key = make_key(Some("pkg:npm/foo@1.0"), "foo", Some("npm"), Some("1.0"));

        let vulns = vec![VulnerabilityRef::new(
            "CVE-2024-0001".to_string(),
            crate::model::VulnerabilitySource::Osv,
        )];

        cache.set(&key, &vulns).unwrap();
        let result = cache.get(&key);
        assert!(result.is_some());
        let retrieved = result.unwrap();
        assert_eq!(retrieved.len(), 1);
        assert_eq!(retrieved[0].id, "CVE-2024-0001");
    }

    #[test]
    fn test_file_cache_set_overwrites_existing_entry() {
        let tmp = tempfile::tempdir().unwrap();
        let cache = FileCache::new(tmp.path().to_path_buf(), Duration::from_secs(3600)).unwrap();
        let key = make_key(Some("pkg:npm/ow@1.0"), "ow", Some("npm"), Some("1.0"));

        let vulns = vec![VulnerabilityRef::new(
            "CVE-2024-0002".to_string(),
            crate::model::VulnerabilitySource::Osv,
        )];
        cache.set(&key, &vulns).unwrap();
        cache.set(&key, &[]).unwrap();
        assert!(cache.get(&key).unwrap().is_empty());
    }

    #[test]
    fn test_file_cache_get_miss() {
        let tmp = tempfile::tempdir().unwrap();
        let cache = FileCache::new(tmp.path().to_path_buf(), Duration::from_secs(3600)).unwrap();
        let key = make_key(Some("pkg:npm/nope@1.0"), "nope", Some("npm"), Some("1.0"));
        assert!(cache.get(&key).is_none());
    }

    #[test]
    fn test_file_cache_get_corrupt_entry_is_miss() {
        let tmp = tempfile::tempdir().unwrap();
        let cache = FileCache::new(tmp.path().to_path_buf(), Duration::from_secs(3600)).unwrap();
        let key = make_key(Some("pkg:npm/bad@1.0"), "bad", Some("npm"), Some("1.0"));

        fs::write(tmp.path().join(key.to_filename()), "not json").unwrap();
        assert!(cache.get(&key).is_none());
        // The unparsable file still occupies the entry slot until replaced or cleared.
        assert_eq!(cache.stats().total_entries, 1);
    }

    #[test]
    fn test_file_cache_remove() {
        let tmp = tempfile::tempdir().unwrap();
        let cache = FileCache::new(tmp.path().to_path_buf(), Duration::from_secs(3600)).unwrap();
        let key = make_key(Some("pkg:npm/rm@1.0"), "rm", Some("npm"), Some("1.0"));

        cache.set(&key, &[]).unwrap();
        assert!(cache.get(&key).is_some());
        cache.remove(&key).unwrap();
        assert!(cache.get(&key).is_none());
    }

    #[test]
    fn test_file_cache_remove_missing_is_ok() {
        let tmp = tempfile::tempdir().unwrap();
        let cache = FileCache::new(tmp.path().to_path_buf(), Duration::from_secs(3600)).unwrap();
        let key = make_key(Some("pkg:npm/ghost@1.0"), "ghost", Some("npm"), Some("1.0"));

        cache.remove(&key).unwrap();
        assert!(cache.get(&key).is_none());
    }

    #[test]
    fn test_file_cache_clear() {
        let tmp = tempfile::tempdir().unwrap();
        let cache = FileCache::new(tmp.path().to_path_buf(), Duration::from_secs(3600)).unwrap();

        for i in 0..3 {
            let key = make_key(None, &format!("pkg{i}"), Some("npm"), Some("1.0"));
            cache.set(&key, &[]).unwrap();
        }

        assert_eq!(cache.stats().total_entries, 3);
        cache.clear().unwrap();
        assert_eq!(cache.stats().total_entries, 0);
    }

    #[test]
    fn test_file_cache_clear_keeps_non_json_files() {
        let tmp = tempfile::tempdir().unwrap();
        let cache = FileCache::new(tmp.path().to_path_buf(), Duration::from_secs(3600)).unwrap();

        let other = tmp.path().join("notes.txt");
        fs::write(&other, "keep me").unwrap();
        let key = make_key(None, "pkg", Some("npm"), Some("1.0"));
        cache.set(&key, &[]).unwrap();

        cache.clear().unwrap();
        assert!(other.exists());
        assert_eq!(cache.stats().total_entries, 0);
    }

    #[test]
    fn test_file_cache_stats_counts() {
        let tmp = tempfile::tempdir().unwrap();
        let cache = FileCache::new(tmp.path().to_path_buf(), Duration::from_secs(3600)).unwrap();

        for i in 0..3 {
            let key = make_key(None, &format!("stats{i}"), Some("npm"), Some("1.0"));
            cache.set(&key, &[]).unwrap();
        }

        let stats = cache.stats();
        assert_eq!(stats.total_entries, 3);
        assert_eq!(stats.expired_entries, 0);
    }
}