use chrono::{Duration, Utc};
use rusqlite::{params, Connection, OptionalExtension};
use sha2::{Digest, Sha256};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use tokio::sync::Mutex;

use crate::types::{MemoryError, MemoryResult};
19
/// SQLite-backed cache of model responses, keyed by a SHA-256 prompt hash.
///
/// Entries expire after `ttl_minutes`, and the table is capped at
/// `max_entries` by evicting the least-recently-accessed rows on `put`.
pub struct ResponseCache {
    // Shared connection; tokio's Mutex so a guard could be held across
    // `.await` points in the async API.
    conn: Arc<Mutex<Connection>>,
    // Kept for diagnostics; not read anywhere in this file yet.
    #[allow(dead_code)]
    db_path: PathBuf,
    // Time-to-live in minutes; rows older than this count as misses.
    ttl_minutes: i64,
    // Soft cap on row count, enforced after every insert.
    max_entries: usize,
}
28
29impl ResponseCache {
30 pub async fn new(db_dir: &Path, ttl_minutes: u32, max_entries: usize) -> MemoryResult<Self> {
32 tokio::fs::create_dir_all(db_dir)
33 .await
34 .map_err(MemoryError::Io)?;
35
36 let db_path = db_dir.join("response_cache.db");
37
38 let conn = Connection::open(&db_path)?;
39 conn.execute_batch(
40 "PRAGMA journal_mode = WAL;
41 PRAGMA synchronous = NORMAL;
42 PRAGMA temp_store = MEMORY;",
43 )?;
44
45 conn.execute_batch(
46 "CREATE TABLE IF NOT EXISTS response_cache (
47 prompt_hash TEXT PRIMARY KEY,
48 model TEXT NOT NULL,
49 response TEXT NOT NULL,
50 token_count INTEGER NOT NULL DEFAULT 0,
51 created_at TEXT NOT NULL,
52 accessed_at TEXT NOT NULL,
53 hit_count INTEGER NOT NULL DEFAULT 0
54 );
55 CREATE INDEX IF NOT EXISTS idx_rc_accessed ON response_cache(accessed_at);
56 CREATE INDEX IF NOT EXISTS idx_rc_created ON response_cache(created_at);",
57 )?;
58
59 Ok(Self {
60 conn: Arc::new(Mutex::new(conn)),
61 db_path,
62 ttl_minutes: i64::from(ttl_minutes),
63 max_entries,
64 })
65 }
66
67 pub fn cache_key(model: &str, system_prompt: Option<&str>, user_prompt: &str) -> String {
69 let mut hasher = Sha256::new();
70 hasher.update(model.as_bytes());
71 hasher.update(b"|");
72 if let Some(sys) = system_prompt {
73 hasher.update(sys.as_bytes());
74 }
75 hasher.update(b"|");
76 hasher.update(user_prompt.as_bytes());
77 format!("{:064x}", hasher.finalize())
78 }
79
80 pub async fn get(&self, key: &str) -> MemoryResult<Option<String>> {
82 let conn = self.conn.lock().await;
83 let cutoff = (Utc::now() - Duration::minutes(self.ttl_minutes)).to_rfc3339();
84
85 let result: Option<String> = conn
86 .query_row(
87 "SELECT response FROM response_cache
88 WHERE prompt_hash = ?1 AND created_at > ?2",
89 params![key, cutoff],
90 |row| row.get(0),
91 )
92 .ok();
93
94 if result.is_some() {
95 let now = Utc::now().to_rfc3339();
96 conn.execute(
97 "UPDATE response_cache
98 SET accessed_at = ?1, hit_count = hit_count + 1
99 WHERE prompt_hash = ?2",
100 params![now, key],
101 )?;
102 }
103
104 Ok(result)
105 }
106
107 pub async fn put(
109 &self,
110 key: &str,
111 model: &str,
112 response: &str,
113 token_count: u32,
114 ) -> MemoryResult<()> {
115 let conn = self.conn.lock().await;
116 let now = Utc::now().to_rfc3339();
117
118 conn.execute(
119 "INSERT OR REPLACE INTO response_cache
120 (prompt_hash, model, response, token_count, created_at, accessed_at, hit_count)
121 VALUES (?1, ?2, ?3, ?4, ?5, ?6, 0)",
122 params![key, model, response, token_count, now, now],
123 )?;
124
125 let cutoff = (Utc::now() - Duration::minutes(self.ttl_minutes)).to_rfc3339();
127 conn.execute(
128 "DELETE FROM response_cache WHERE created_at <= ?1",
129 params![cutoff],
130 )?;
131
132 #[allow(clippy::cast_possible_wrap)]
134 let max = self.max_entries as i64;
135 conn.execute(
136 "DELETE FROM response_cache WHERE prompt_hash IN (
137 SELECT prompt_hash FROM response_cache
138 ORDER BY accessed_at ASC
139 LIMIT MAX(0, (SELECT COUNT(*) FROM response_cache) - ?1)
140 )",
141 params![max],
142 )?;
143
144 Ok(())
145 }
146
147 pub async fn stats(&self) -> MemoryResult<(usize, u64, u64)> {
149 let conn = self.conn.lock().await;
150
151 let count: i64 =
152 conn.query_row("SELECT COUNT(*) FROM response_cache", [], |row| row.get(0))?;
153
154 let hits: i64 = conn.query_row(
155 "SELECT COALESCE(SUM(hit_count), 0) FROM response_cache",
156 [],
157 |row| row.get(0),
158 )?;
159
160 let tokens_saved: i64 = conn.query_row(
161 "SELECT COALESCE(SUM(token_count * hit_count), 0) FROM response_cache",
162 [],
163 |row| row.get(0),
164 )?;
165
166 #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
167 Ok((count as usize, hits as u64, tokens_saved as u64))
168 }
169
170 pub async fn clear(&self) -> MemoryResult<usize> {
172 let conn = self.conn.lock().await;
173 let affected = conn.execute("DELETE FROM response_cache", [])?;
174 Ok(affected)
175 }
176}
177
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Spins up a cache (capacity 1000) backed by a throwaway directory.
    /// The `TempDir` is returned so it outlives the cache under test.
    async fn temp_cache(ttl_minutes: u32) -> (TempDir, ResponseCache) {
        let dir = TempDir::new().unwrap();
        let cache = ResponseCache::new(dir.path(), ttl_minutes, 1000)
            .await
            .unwrap();
        (dir, cache)
    }

    #[tokio::test]
    async fn cache_key_is_deterministic() {
        let first = ResponseCache::cache_key("gpt-4", Some("sys"), "hello");
        let second = ResponseCache::cache_key("gpt-4", Some("sys"), "hello");
        assert_eq!(first, second);
        // hex-encoded SHA-256 is always 64 characters
        assert_eq!(first.len(), 64);
    }

    #[tokio::test]
    async fn cache_key_varies_by_model() {
        let a = ResponseCache::cache_key("gpt-4", None, "hello");
        let b = ResponseCache::cache_key("claude-3", None, "hello");
        assert_ne!(a, b);
    }

    #[tokio::test]
    async fn put_and_get_roundtrip() {
        let (_dir, cache) = temp_cache(60).await;
        let key = ResponseCache::cache_key("gpt-4", None, "What is Rust?");
        cache
            .put(&key, "gpt-4", "Rust is a systems programming language.", 25)
            .await
            .unwrap();
        let fetched = cache.get(&key).await.unwrap();
        assert_eq!(
            fetched.as_deref(),
            Some("Rust is a systems programming language.")
        );
    }

    #[tokio::test]
    async fn miss_returns_none() {
        let (_dir, cache) = temp_cache(60).await;
        assert!(cache.get("nonexistent").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn expired_entry_returns_none() {
        // A zero-minute TTL expires entries immediately.
        let (_dir, cache) = temp_cache(0).await;
        let key = ResponseCache::cache_key("gpt-4", None, "test");
        cache.put(&key, "gpt-4", "response", 10).await.unwrap();
        assert!(cache.get(&key).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn stats_tracks_hits_and_tokens() {
        let (_dir, cache) = temp_cache(60).await;
        let key = ResponseCache::cache_key("gpt-4", None, "explain rust");
        cache.put(&key, "gpt-4", "Rust is...", 100).await.unwrap();
        for _ in 0..5 {
            let _ = cache.get(&key).await.unwrap();
        }
        let (_, hits, tokens) = cache.stats().await.unwrap();
        assert_eq!(hits, 5);
        // 100 tokens saved on each of the 5 hits
        assert_eq!(tokens, 500);
    }

    #[tokio::test]
    async fn lru_eviction_respects_max_entries() {
        let dir = TempDir::new().unwrap();
        let cache = ResponseCache::new(dir.path(), 60, 3).await.unwrap();
        for i in 0..5 {
            let key = ResponseCache::cache_key("gpt-4", None, &format!("prompt {i}"));
            cache
                .put(&key, "gpt-4", &format!("response {i}"), 10)
                .await
                .unwrap();
        }
        let (count, _, _) = cache.stats().await.unwrap();
        assert!(count <= 3, "cache must not exceed max_entries");
    }
}
264}