Skip to main content

tandem_memory/
response_cache.rs

1//! LLM Response Cache — avoid burning tokens on repeated prompts.
2//!
//! Stores LLM responses in a separate SQLite table keyed by a SHA-256 hash of
//! `(model, system_prompt, user_prompt)`. Entries expire after a
5//! configurable TTL. The cache is optional and disabled by default — users
6//! opt in via `TANDEM_RESPONSE_CACHE_ENABLED=true`.
7//!
8//! Lives alongside `memory.sqlite` as `response_cache.db` so it can be
9//! independently wiped without touching memory chunks.
10
use std::path::{Path, PathBuf};
use std::sync::Arc;

use chrono::{Duration, Utc};
use rusqlite::{params, Connection, OptionalExtension};
use sha2::{Digest, Sha256};
use tokio::sync::Mutex;

use crate::types::{MemoryError, MemoryResult};
19
20/// Response cache backed by a dedicated SQLite database.
21pub struct ResponseCache {
22    conn: Arc<Mutex<Connection>>,
23    #[allow(dead_code)]
24    db_path: PathBuf,
25    ttl_minutes: i64,
26    max_entries: usize,
27}
28
29impl ResponseCache {
30    /// Open (or create) the response cache database at `{db_dir}/response_cache.db`.
31    pub async fn new(db_dir: &Path, ttl_minutes: u32, max_entries: usize) -> MemoryResult<Self> {
32        tokio::fs::create_dir_all(db_dir)
33            .await
34            .map_err(MemoryError::Io)?;
35
36        let db_path = db_dir.join("response_cache.db");
37
38        let conn = Connection::open(&db_path)?;
39        conn.execute_batch(
40            "PRAGMA journal_mode = WAL;
41             PRAGMA synchronous  = NORMAL;
42             PRAGMA temp_store   = MEMORY;",
43        )?;
44
45        conn.execute_batch(
46            "CREATE TABLE IF NOT EXISTS response_cache (
47                prompt_hash  TEXT PRIMARY KEY,
48                model        TEXT NOT NULL,
49                response     TEXT NOT NULL,
50                token_count  INTEGER NOT NULL DEFAULT 0,
51                created_at   TEXT NOT NULL,
52                accessed_at  TEXT NOT NULL,
53                hit_count    INTEGER NOT NULL DEFAULT 0
54            );
55            CREATE INDEX IF NOT EXISTS idx_rc_accessed ON response_cache(accessed_at);
56            CREATE INDEX IF NOT EXISTS idx_rc_created  ON response_cache(created_at);",
57        )?;
58
59        Ok(Self {
60            conn: Arc::new(Mutex::new(conn)),
61            db_path,
62            ttl_minutes: i64::from(ttl_minutes),
63            max_entries,
64        })
65    }
66
67    /// Build a deterministic cache key from model + system prompt + user prompt.
68    pub fn cache_key(model: &str, system_prompt: Option<&str>, user_prompt: &str) -> String {
69        let mut hasher = Sha256::new();
70        hasher.update(model.as_bytes());
71        hasher.update(b"|");
72        if let Some(sys) = system_prompt {
73            hasher.update(sys.as_bytes());
74        }
75        hasher.update(b"|");
76        hasher.update(user_prompt.as_bytes());
77        format!("{:064x}", hasher.finalize())
78    }
79
80    /// Look up a cached response. Returns `None` on miss or if the entry has expired.
81    pub async fn get(&self, key: &str) -> MemoryResult<Option<String>> {
82        let conn = self.conn.lock().await;
83        let cutoff = (Utc::now() - Duration::minutes(self.ttl_minutes)).to_rfc3339();
84
85        let result: Option<String> = conn
86            .query_row(
87                "SELECT response FROM response_cache
88                 WHERE prompt_hash = ?1 AND created_at > ?2",
89                params![key, cutoff],
90                |row| row.get(0),
91            )
92            .ok();
93
94        if result.is_some() {
95            let now = Utc::now().to_rfc3339();
96            conn.execute(
97                "UPDATE response_cache
98                 SET accessed_at = ?1, hit_count = hit_count + 1
99                 WHERE prompt_hash = ?2",
100                params![now, key],
101            )?;
102        }
103
104        Ok(result)
105    }
106
107    /// Store a response in the cache, evicting expired or least-recently-used entries.
108    pub async fn put(
109        &self,
110        key: &str,
111        model: &str,
112        response: &str,
113        token_count: u32,
114    ) -> MemoryResult<()> {
115        let conn = self.conn.lock().await;
116        let now = Utc::now().to_rfc3339();
117
118        conn.execute(
119            "INSERT OR REPLACE INTO response_cache
120             (prompt_hash, model, response, token_count, created_at, accessed_at, hit_count)
121             VALUES (?1, ?2, ?3, ?4, ?5, ?6, 0)",
122            params![key, model, response, token_count, now, now],
123        )?;
124
125        // Evict expired entries
126        let cutoff = (Utc::now() - Duration::minutes(self.ttl_minutes)).to_rfc3339();
127        conn.execute(
128            "DELETE FROM response_cache WHERE created_at <= ?1",
129            params![cutoff],
130        )?;
131
132        // LRU eviction if over max_entries
133        #[allow(clippy::cast_possible_wrap)]
134        let max = self.max_entries as i64;
135        conn.execute(
136            "DELETE FROM response_cache WHERE prompt_hash IN (
137                SELECT prompt_hash FROM response_cache
138                ORDER BY accessed_at ASC
139                LIMIT MAX(0, (SELECT COUNT(*) FROM response_cache) - ?1)
140            )",
141            params![max],
142        )?;
143
144        Ok(())
145    }
146
147    /// Return cache statistics: `(total_entries, total_hits, estimated_tokens_saved)`.
148    pub async fn stats(&self) -> MemoryResult<(usize, u64, u64)> {
149        let conn = self.conn.lock().await;
150
151        let count: i64 =
152            conn.query_row("SELECT COUNT(*) FROM response_cache", [], |row| row.get(0))?;
153
154        let hits: i64 = conn.query_row(
155            "SELECT COALESCE(SUM(hit_count), 0) FROM response_cache",
156            [],
157            |row| row.get(0),
158        )?;
159
160        let tokens_saved: i64 = conn.query_row(
161            "SELECT COALESCE(SUM(token_count * hit_count), 0) FROM response_cache",
162            [],
163            |row| row.get(0),
164        )?;
165
166        #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
167        Ok((count as usize, hits as u64, tokens_saved as u64))
168    }
169
170    /// Clear all cached entries.
171    pub async fn clear(&self) -> MemoryResult<usize> {
172        let conn = self.conn.lock().await;
173        let affected = conn.execute("DELETE FROM response_cache", [])?;
174        Ok(affected)
175    }
176}
177
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Open a cache with the given TTL and a 1000-entry cap in a fresh temp dir.
    ///
    /// The returned `TempDir` must be kept alive for the whole test; dropping
    /// it removes the directory that holds the database.
    async fn temp_cache(ttl_minutes: u32) -> (TempDir, ResponseCache) {
        let tmp = TempDir::new().unwrap();
        let cache = ResponseCache::new(tmp.path(), ttl_minutes, 1000)
            .await
            .unwrap();
        (tmp, cache)
    }

    // `cache_key` is synchronous, so these tests need no async runtime.

    #[test]
    fn cache_key_is_deterministic() {
        let k1 = ResponseCache::cache_key("gpt-4", Some("sys"), "hello");
        let k2 = ResponseCache::cache_key("gpt-4", Some("sys"), "hello");
        assert_eq!(k1, k2);
        assert_eq!(k1.len(), 64);
        assert!(k1.chars().all(|c| c.is_ascii_hexdigit()));
    }

    #[test]
    fn cache_key_varies_by_model() {
        let k1 = ResponseCache::cache_key("gpt-4", None, "hello");
        let k2 = ResponseCache::cache_key("claude-3", None, "hello");
        assert_ne!(k1, k2);
    }

    #[test]
    fn cache_key_varies_by_prompts() {
        let base = ResponseCache::cache_key("gpt-4", Some("sys"), "hello");
        assert_ne!(base, ResponseCache::cache_key("gpt-4", Some("other"), "hello"));
        assert_ne!(base, ResponseCache::cache_key("gpt-4", Some("sys"), "goodbye"));
    }

    #[tokio::test]
    async fn put_and_get_roundtrip() {
        let (_tmp, cache) = temp_cache(60).await;
        let key = ResponseCache::cache_key("gpt-4", None, "What is Rust?");
        cache
            .put(&key, "gpt-4", "Rust is a systems programming language.", 25)
            .await
            .unwrap();
        let result = cache.get(&key).await.unwrap();
        assert_eq!(
            result.as_deref(),
            Some("Rust is a systems programming language.")
        );
    }

    #[tokio::test]
    async fn put_overwrites_existing_key() {
        let (_tmp, cache) = temp_cache(60).await;
        let key = ResponseCache::cache_key("gpt-4", None, "q");
        cache.put(&key, "gpt-4", "first", 1).await.unwrap();
        cache.put(&key, "gpt-4", "second", 1).await.unwrap();
        assert_eq!(cache.get(&key).await.unwrap().as_deref(), Some("second"));
        let (count, _, _) = cache.stats().await.unwrap();
        assert_eq!(count, 1);
    }

    #[tokio::test]
    async fn miss_returns_none() {
        let (_tmp, cache) = temp_cache(60).await;
        let result = cache.get("nonexistent").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn expired_entry_returns_none() {
        let (_tmp, cache) = temp_cache(0).await; // 0 TTL → instantly expired
        let key = ResponseCache::cache_key("gpt-4", None, "test");
        cache.put(&key, "gpt-4", "response", 10).await.unwrap();
        let result = cache.get(&key).await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn stats_on_empty_cache_are_zero() {
        let (_tmp, cache) = temp_cache(60).await;
        assert_eq!(cache.stats().await.unwrap(), (0, 0, 0));
    }

    #[tokio::test]
    async fn stats_tracks_hits_and_tokens() {
        let (_tmp, cache) = temp_cache(60).await;
        let key = ResponseCache::cache_key("gpt-4", None, "explain rust");
        cache.put(&key, "gpt-4", "Rust is...", 100).await.unwrap();
        for _ in 0..5 {
            let _ = cache.get(&key).await.unwrap();
        }
        let (count, hits, tokens) = cache.stats().await.unwrap();
        assert_eq!(count, 1);
        assert_eq!(hits, 5);
        assert_eq!(tokens, 500);
    }

    #[tokio::test]
    async fn lru_eviction_respects_max_entries() {
        let tmp = TempDir::new().unwrap();
        let cache = ResponseCache::new(tmp.path(), 60, 3).await.unwrap();
        for i in 0..5 {
            let key = ResponseCache::cache_key("gpt-4", None, &format!("prompt {i}"));
            cache
                .put(&key, "gpt-4", &format!("response {i}"), 10)
                .await
                .unwrap();
        }
        // Each put past the cap trims back down to it, so exactly 3 remain.
        let (count, _, _) = cache.stats().await.unwrap();
        assert_eq!(count, 3, "cache must hold exactly max_entries after overflow");
    }

    #[tokio::test]
    async fn clear_removes_all_entries() {
        let (_tmp, cache) = temp_cache(60).await;
        for i in 0..3 {
            let key = ResponseCache::cache_key("gpt-4", None, &format!("prompt {i}"));
            cache.put(&key, "gpt-4", "response", 10).await.unwrap();
        }
        assert_eq!(cache.clear().await.unwrap(), 3);
        let (count, _, _) = cache.stats().await.unwrap();
        assert_eq!(count, 0);
    }
}