talon_core/expansion/
cache.rs1use rusqlite::Connection;
4
5use crate::TalonError;
6use crate::indexing::change_tracking::now_ms;
7
8pub struct LlmCache<'conn> {
14 conn: &'conn Connection,
15}
16
17impl std::fmt::Debug for LlmCache<'_> {
18 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19 f.debug_struct("LlmCache").finish_non_exhaustive()
20 }
21}
22
23impl<'conn> LlmCache<'conn> {
24 pub const fn new(conn: &'conn Connection) -> Self {
26 Self { conn }
27 }
28
29 pub fn get(&self, key: &str) -> Result<Option<String>, TalonError> {
35 let now = now_ms().cast_signed();
36 let result = self.conn.query_row(
37 "SELECT value FROM llm_cache WHERE key = ?1 AND expires_at_ms > ?2",
38 rusqlite::params![key, now],
39 |row| row.get::<_, String>(0),
40 );
41 match result {
42 Ok(value) => Ok(Some(value)),
43 Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
44 Err(source) => Err(TalonError::Sqlite {
45 context: "llm_cache get",
46 source,
47 }),
48 }
49 }
50
51 pub fn put(&self, key: &str, value: &str, ttl_ms: u64) -> Result<(), TalonError> {
59 let expires_at_ms = (now_ms() + ttl_ms).cast_signed();
60 self.conn
61 .execute(
62 "INSERT OR REPLACE INTO llm_cache (key, value, expires_at_ms) \
63 VALUES (?1, ?2, ?3)",
64 rusqlite::params![key, value, expires_at_ms],
65 )
66 .map_err(|source| TalonError::Sqlite {
67 context: "llm_cache put",
68 source,
69 })?;
70 Ok(())
71 }
72
73 pub fn purge_expired(&self) -> Result<usize, TalonError> {
81 let now = now_ms().cast_signed();
82 let count = self
83 .conn
84 .execute(
85 "DELETE FROM llm_cache WHERE expires_at_ms <= ?1",
86 rusqlite::params![now],
87 )
88 .map_err(|source| TalonError::Sqlite {
89 context: "llm_cache purge_expired",
90 source,
91 })?;
92 Ok(count)
93 }
94}
95
#[cfg(test)]
// Tests unwrap freely: a panic is exactly the failure signal we want here.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use rusqlite::Connection;

    use crate::indexing::migrations::run_migrations;

    /// Opens an in-memory database with every migration applied.
    fn fresh_db() -> Connection {
        let mut db = Connection::open_in_memory().unwrap();
        run_migrations(&mut db).unwrap();
        db
    }

    /// Counts all rows in `llm_cache`, expired or not.
    fn row_count(db: &Connection) -> i64 {
        db.query_row("SELECT COUNT(*) FROM llm_cache", [], |row| row.get(0))
            .unwrap()
    }

    #[test]
    fn round_trip_stores_and_retrieves_value() {
        let db = fresh_db();
        let cache = LlmCache::new(&db);
        cache.put("k1", "hello world", 60_000).unwrap();
        assert_eq!(cache.get("k1").unwrap().as_deref(), Some("hello world"));
    }

    #[test]
    fn missing_key_returns_none() {
        let db = fresh_db();
        let lookup = LlmCache::new(&db).get("nonexistent").unwrap();
        assert!(lookup.is_none());
    }

    #[test]
    fn expired_entry_returns_none() {
        let db = fresh_db();
        let cache = LlmCache::new(&db);
        // A zero TTL sets the expiry to "now", which `get` already treats as expired.
        cache.put("k2", "stale", 0).unwrap();
        let lookup = cache.get("k2").unwrap();
        assert!(lookup.is_none(), "expired entry must not be returned");
    }

    #[test]
    fn put_overwrites_existing_key() {
        let db = fresh_db();
        let cache = LlmCache::new(&db);
        for value in ["first", "second"] {
            cache.put("k3", value, 60_000).unwrap();
        }
        assert_eq!(cache.get("k3").unwrap().as_deref(), Some("second"));
    }

    #[test]
    fn purge_expired_removes_stale_rows() {
        let db = fresh_db();
        let cache = LlmCache::new(&db);
        let fixtures = [("stale1", "v1", 0), ("stale2", "v2", 0), ("live", "v3", 60_000)];
        for (key, value, ttl_ms) in fixtures {
            cache.put(key, value, ttl_ms).unwrap();
        }

        assert_eq!(
            cache.purge_expired().unwrap(),
            2,
            "two stale entries should be purged"
        );
        assert_eq!(row_count(&db), 1, "only the live entry should remain");
    }
}