Skip to main content

talon_core/expansion/
cache.rs

1//! SQL-backed cache for LLM expansion responses with millisecond-precision TTL.
2
3use rusqlite::Connection;
4
5use crate::TalonError;
6use crate::indexing::change_tracking::now_ms;
7
8/// Persistent key-value cache backed by the `llm_cache` `SQLite` table.
9///
10/// Entries expire after the TTL supplied to [`put`](LlmCache::put) elapses.
11/// Expired entries are invisible to [`get`](LlmCache::get) but remain in the
12/// database until [`purge_expired`](LlmCache::purge_expired) is called.
13pub struct LlmCache<'conn> {
14    conn: &'conn Connection,
15}
16
17impl std::fmt::Debug for LlmCache<'_> {
18    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19        f.debug_struct("LlmCache").finish_non_exhaustive()
20    }
21}
22
23impl<'conn> LlmCache<'conn> {
24    /// Wraps a database connection.
25    pub const fn new(conn: &'conn Connection) -> Self {
26        Self { conn }
27    }
28
29    /// Returns the cached value for `key` if it exists and has not expired.
30    ///
31    /// # Errors
32    ///
33    /// Returns [`TalonError::Sqlite`] on a database error.
34    pub fn get(&self, key: &str) -> Result<Option<String>, TalonError> {
35        let now = now_ms().cast_signed();
36        let result = self.conn.query_row(
37            "SELECT value FROM llm_cache WHERE key = ?1 AND expires_at_ms > ?2",
38            rusqlite::params![key, now],
39            |row| row.get::<_, String>(0),
40        );
41        match result {
42            Ok(value) => Ok(Some(value)),
43            Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
44            Err(source) => Err(TalonError::Sqlite {
45                context: "llm_cache get",
46                source,
47            }),
48        }
49    }
50
51    /// Stores `value` under `key` with a time-to-live of `ttl_ms` milliseconds.
52    ///
53    /// Overwrites any existing entry for the same key.
54    ///
55    /// # Errors
56    ///
57    /// Returns [`TalonError::Sqlite`] on a database error.
58    pub fn put(&self, key: &str, value: &str, ttl_ms: u64) -> Result<(), TalonError> {
59        let expires_at_ms = (now_ms() + ttl_ms).cast_signed();
60        self.conn
61            .execute(
62                "INSERT OR REPLACE INTO llm_cache (key, value, expires_at_ms) \
63                 VALUES (?1, ?2, ?3)",
64                rusqlite::params![key, value, expires_at_ms],
65            )
66            .map_err(|source| TalonError::Sqlite {
67                context: "llm_cache put",
68                source,
69            })?;
70        Ok(())
71    }
72
73    /// Deletes all entries whose TTL has elapsed.
74    ///
75    /// Returns the number of rows removed.
76    ///
77    /// # Errors
78    ///
79    /// Returns [`TalonError::Sqlite`] on a database error.
80    pub fn purge_expired(&self) -> Result<usize, TalonError> {
81        let now = now_ms().cast_signed();
82        let count = self
83            .conn
84            .execute(
85                "DELETE FROM llm_cache WHERE expires_at_ms <= ?1",
86                rusqlite::params![now],
87            )
88            .map_err(|source| TalonError::Sqlite {
89                context: "llm_cache purge_expired",
90                source,
91            })?;
92        Ok(count)
93    }
94}
95
#[cfg(test)]
// Test failures are meant to panic, so `unwrap` is the clearest way to surface errors.
#[allow(clippy::unwrap_used)]
mod tests {
    use rusqlite::Connection;

    use super::*;
    use crate::indexing::migrations::run_migrations;

    /// Opens an in-memory database with every migration applied.
    fn fresh_db() -> Connection {
        let mut db = Connection::open_in_memory().unwrap();
        run_migrations(&mut db).unwrap();
        db
    }

    /// Counts every row in `llm_cache`, whether expired or not.
    fn row_count(db: &Connection) -> i64 {
        db.query_row("SELECT COUNT(*) FROM llm_cache", [], |row| row.get(0))
            .unwrap()
    }

    #[test]
    fn round_trip_stores_and_retrieves_value() {
        let db = fresh_db();
        let cache = LlmCache::new(&db);

        cache.put("k1", "hello world", 60_000).unwrap();

        assert_eq!(cache.get("k1").unwrap().as_deref(), Some("hello world"));
    }

    #[test]
    fn missing_key_returns_none() {
        let db = fresh_db();
        let cache = LlmCache::new(&db);

        assert!(cache.get("nonexistent").unwrap().is_none());
    }

    #[test]
    fn expired_entry_returns_none() {
        let db = fresh_db();
        let cache = LlmCache::new(&db);

        // A zero TTL stores `expires_at_ms == now_ms()` at write time. `get`
        // only returns rows whose expiry is strictly in the future.
        cache.put("k2", "stale", 0).unwrap();

        assert!(
            cache.get("k2").unwrap().is_none(),
            "expired entry must not be returned"
        );
    }

    #[test]
    fn put_overwrites_existing_key() {
        let db = fresh_db();
        let cache = LlmCache::new(&db);

        for value in ["first", "second"] {
            cache.put("k3", value, 60_000).unwrap();
        }

        assert_eq!(cache.get("k3").unwrap().as_deref(), Some("second"));
    }

    #[test]
    fn purge_expired_removes_stale_rows() {
        let db = fresh_db();
        let cache = LlmCache::new(&db);

        let entries = [("stale1", "v1", 0), ("stale2", "v2", 0), ("live", "v3", 60_000)];
        for (key, value, ttl_ms) in entries {
            cache.put(key, value, ttl_ms).unwrap();
        }

        let removed = cache.purge_expired().unwrap();
        assert_eq!(removed, 2, "two stale entries should be purged");

        assert_eq!(row_count(&db), 1, "only the live entry should remain");
    }
}