Skip to main content

shiplog_cache/
sqlite.rs

1//! SQLite-backed implementation of `ApiCache` for shiplog API responses.
2
3use anyhow::{Context, Result};
4use chrono::Duration;
5use rusqlite::{Connection, OptionalExtension, params};
6use serde::Serialize;
7use serde::de::DeserializeOwned;
8use std::path::Path;
9
10use crate::expiry::{CacheExpiryWindow, now_rfc3339};
11use crate::stats::CacheStats;
12
13/// Cache for API responses backed by a local SQLite database.
14#[derive(Debug)]
15pub struct ApiCache {
16    conn: Connection,
17    default_ttl: Duration,
18    #[allow(dead_code)]
19    max_size_bytes: Option<u64>,
20}
21
22impl ApiCache {
23    /// Open or create cache at the given path.
24    pub fn open(path: impl AsRef<Path>) -> Result<Self> {
25        let conn = Connection::open(path).context("open cache database")?;
26
27        conn.execute(
28            "CREATE TABLE IF NOT EXISTS cache_entries (
29                key TEXT PRIMARY KEY,
30                data TEXT NOT NULL,
31                cached_at TEXT NOT NULL,
32                expires_at TEXT NOT NULL
33            )",
34            [],
35        )?;
36
37        conn.execute(
38            "CREATE INDEX IF NOT EXISTS idx_expires ON cache_entries(expires_at)",
39            [],
40        )?;
41
42        Ok(Self {
43            conn,
44            default_ttl: Duration::hours(24),
45            max_size_bytes: None,
46        })
47    }
48
49    /// Create an in-memory cache (for testing).
50    pub fn open_in_memory() -> Result<Self> {
51        let conn = Connection::open_in_memory().context("open in-memory cache")?;
52
53        conn.execute(
54            "CREATE TABLE cache_entries (
55                key TEXT PRIMARY KEY,
56                data TEXT NOT NULL,
57                cached_at TEXT NOT NULL,
58                expires_at TEXT NOT NULL
59            )",
60            [],
61        )?;
62
63        Ok(Self {
64            conn,
65            default_ttl: Duration::hours(24),
66            max_size_bytes: None,
67        })
68    }
69
70    /// Set the default TTL for cache entries.
71    pub fn with_ttl(mut self, ttl: Duration) -> Self {
72        self.default_ttl = ttl;
73        self
74    }
75
76    /// Create a cache with a maximum size limit.
77    pub fn with_max_size(mut self, max_size_bytes: u64) -> Self {
78        self.max_size_bytes = Some(max_size_bytes);
79        self
80    }
81
82    /// Get a cached value if it exists and hasn't expired.
83    pub fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
84        let now = now_rfc3339();
85
86        let row: Option<String> = self
87            .conn
88            .query_row(
89                "SELECT data FROM cache_entries WHERE key = ?1 AND expires_at > ?2",
90                params![key, now],
91                |row| row.get(0),
92            )
93            .optional()?;
94
95        match row {
96            Some(data) => {
97                let value: T = serde_json::from_str(&data)
98                    .with_context(|| format!("deserialize cached value for key: {key}"))?;
99                Ok(Some(value))
100            }
101            None => Ok(None),
102        }
103    }
104
105    /// Store a value in the cache.
106    pub fn set<T: Serialize>(&self, key: &str, value: &T) -> Result<()> {
107        self.set_with_ttl(key, value, self.default_ttl)
108    }
109
110    /// Store a value with a custom TTL.
111    pub fn set_with_ttl<T: Serialize>(&self, key: &str, value: &T, ttl: Duration) -> Result<()> {
112        let window = CacheExpiryWindow::from_now(ttl);
113        let data = serde_json::to_string(value)
114            .with_context(|| format!("serialize value for key: {key}"))?;
115
116        self.conn.execute(
117            "INSERT OR REPLACE INTO cache_entries (key, data, cached_at, expires_at) VALUES (?1, ?2, ?3, ?4)",
118            params![
119                key,
120                data,
121                window.cached_at_rfc3339(),
122                window.expires_at_rfc3339(),
123            ],
124        )?;
125
126        Ok(())
127    }
128
129    /// Check if a key exists and hasn't expired.
130    pub fn contains(&self, key: &str) -> Result<bool> {
131        let now = now_rfc3339();
132
133        let count: i64 = self.conn.query_row(
134            "SELECT COUNT(*) FROM cache_entries WHERE key = ?1 AND expires_at > ?2",
135            params![key, now],
136            |row| row.get(0),
137        )?;
138
139        Ok(count > 0)
140    }
141
142    /// Remove expired entries from the cache.
143    pub fn cleanup_expired(&self) -> Result<usize> {
144        let now = now_rfc3339();
145
146        let deleted = self.conn.execute(
147            "DELETE FROM cache_entries WHERE expires_at <= ?1",
148            params![now],
149        )?;
150
151        Ok(deleted)
152    }
153
154    /// Clear all entries from the cache.
155    pub fn clear(&self) -> Result<()> {
156        self.conn.execute("DELETE FROM cache_entries", [])?;
157        Ok(())
158    }
159
160    /// Get cache statistics.
161    pub fn stats(&self) -> Result<CacheStats> {
162        let now = now_rfc3339();
163
164        let total: i64 = self
165            .conn
166            .query_row("SELECT COUNT(*) FROM cache_entries", [], |row| row.get(0))?;
167
168        let expired: i64 = self.conn.query_row(
169            "SELECT COUNT(*) FROM cache_entries WHERE expires_at <= ?1",
170            params![now],
171            |row| row.get(0),
172        )?;
173
174        let size_bytes: i64 =
175            self.conn
176                .query_row("SELECT SUM(LENGTH(data)) FROM cache_entries", [], |row| {
177                    Ok(row.get::<_, Option<i64>>(0).unwrap_or(Some(0)).unwrap_or(0))
178                })?;
179
180        Ok(CacheStats::from_raw_counts(total, expired, size_bytes))
181    }
182}
183
#[cfg(test)]
mod tests {
    use super::*;
    use crate::CacheKey;

    #[derive(Debug, serde::Serialize, serde::Deserialize, PartialEq, Clone)]
    struct TestData {
        name: String,
        count: u32,
    }

    /// Canonical payload shared by the tests below.
    fn sample_data() -> TestData {
        TestData {
            name: "test".to_string(),
            count: 42,
        }
    }

    #[test]
    fn cache_basic_operations() {
        let cache = ApiCache::open_in_memory().unwrap();
        let data = sample_data();

        let result: Option<TestData> = cache.get("key1").unwrap();
        assert!(result.is_none());

        cache.set("key1", &data).unwrap();

        let result: Option<TestData> = cache.get("key1").unwrap();
        assert_eq!(result, Some(data));
    }

    #[test]
    fn cache_ttl_expiration() {
        // A negative default TTL makes entries expire as soon as they are
        // written. This exercises the expiry path deterministically instead of
        // sleeping on the wall clock.
        let cache = ApiCache::open_in_memory()
            .unwrap()
            .with_ttl(Duration::seconds(-1));

        cache.set("key1", &sample_data()).unwrap();

        let result: Option<TestData> = cache.get("key1").unwrap();
        assert!(result.is_none());

        // An explicit TTL overrides the default one.
        cache
            .set_with_ttl("key2", &sample_data(), Duration::hours(1))
            .unwrap();

        let result: Option<TestData> = cache.get("key2").unwrap();
        assert_eq!(result, Some(sample_data()));
    }

    #[test]
    fn cache_stats() {
        let cache = ApiCache::open_in_memory().unwrap();
        let data = sample_data();

        cache.set("key1", &data).unwrap();
        cache.set("key2", &data).unwrap();

        let stats = cache.stats().unwrap();
        assert_eq!(stats.total_entries, 2);
        assert_eq!(stats.valid_entries, 2);
        assert_eq!(stats.expired_entries, 0);
    }

    #[test]
    fn cache_cleanup() {
        let cache = ApiCache::open_in_memory().unwrap();

        cache
            .set_with_ttl("key1", &sample_data(), Duration::seconds(-1))
            .unwrap();

        let deleted = cache.cleanup_expired().unwrap();
        assert_eq!(deleted, 1);

        let stats = cache.stats().unwrap();
        assert_eq!(stats.expired_entries, 0);
        assert_eq!(stats.total_entries, 0);
    }

    #[test]
    fn cache_clear() {
        let cache = ApiCache::open_in_memory().unwrap();
        let data = sample_data();

        cache.set("key1", &data).unwrap();
        cache.set("key2", &data).unwrap();

        cache.clear().unwrap();

        let stats = cache.stats().unwrap();
        assert_eq!(stats.total_entries, 0);
    }

    #[test]
    fn cache_contains() {
        let cache = ApiCache::open_in_memory().unwrap();

        assert!(!cache.contains("key1").unwrap());

        cache.set("key1", &sample_data()).unwrap();
        assert!(cache.contains("key1").unwrap());

        // Expired entries are still stored but must not be reported as present.
        cache
            .set_with_ttl("stale", &sample_data(), Duration::seconds(-1))
            .unwrap();
        assert!(!cache.contains("stale").unwrap());
    }

    #[test]
    fn cache_key_reexport_matches_contract() {
        let details = CacheKey::pr_details("https://api.github.com/repos/o/r/pulls/1");
        let reviews = CacheKey::pr_reviews("https://api.github.com/repos/o/r/pulls/1", 2);
        let notes = CacheKey::mr_notes(12, 34, 1);

        assert_eq!(
            details,
            "pr:details:https://api.github.com/repos/o/r/pulls/1"
        );
        assert_eq!(
            reviews,
            "pr:reviews:https://api.github.com/repos/o/r/pulls/1:page2"
        );
        assert_eq!(notes, "gitlab:mr:notes:project12:mr34:page1");
    }

    #[test]
    fn cache_stats_reexport_matches_contract() {
        let stats = CacheStats::from_raw_counts(5, 2, 2 * 1024 * 1024 + 77);
        assert_eq!(stats.total_entries, 5);
        assert_eq!(stats.expired_entries, 2);
        assert_eq!(stats.valid_entries, 3);
        assert_eq!(stats.cache_size_mb, 2);
    }
}