Skip to main content

tandem_memory/
response_cache.rs

//! LLM Response Cache — avoid burning tokens on repeated prompts.
//!
//! Stores LLM responses in a separate SQLite table keyed by a SHA-256 hash of
//! `(model, system_prompt, user_prompt)`. Scoped keys additionally mix in the
//! tenant and source-binding scope, so identical prompts issued under different
//! scopes get different keys. Entries expire after a configurable TTL, and the
//! table is capped at a maximum entry count with least-recently-used eviction.
//! The cache is optional and disabled by default — users opt in via
//! `TANDEM_RESPONSE_CACHE_ENABLED=true`.
//!
//! Lives alongside `memory.sqlite` as `response_cache.db` so it can be
//! independently wiped without touching memory chunks.
10
11use chrono::{Duration, Utc};
12use rusqlite::{params, Connection};
13use sha2::{Digest, Sha256};
14use std::path::{Path, PathBuf};
15use std::sync::Arc;
16use tokio::sync::Mutex;
17
18use crate::types::{MemoryError, MemoryResult};
19
20#[derive(Debug, Clone, PartialEq, Eq)]
21pub struct ResponseCacheScope {
22    pub tenant_org_id: String,
23    pub tenant_workspace_id: String,
24    pub tenant_deployment_id: Option<String>,
25    pub source_binding_ids: Vec<String>,
26}
27
28impl ResponseCacheScope {
29    pub fn tenant(
30        tenant_org_id: impl Into<String>,
31        tenant_workspace_id: impl Into<String>,
32        tenant_deployment_id: Option<String>,
33    ) -> Self {
34        Self {
35            tenant_org_id: tenant_org_id.into(),
36            tenant_workspace_id: tenant_workspace_id.into(),
37            tenant_deployment_id,
38            source_binding_ids: Vec::new(),
39        }
40    }
41
42    pub fn with_source_bindings(mut self, source_binding_ids: Vec<String>) -> Self {
43        self.source_binding_ids = normalized_source_binding_ids(source_binding_ids);
44        self
45    }
46
47    fn source_binding_key(&self) -> String {
48        source_binding_key(&self.source_binding_ids)
49    }
50
51    fn fingerprint(&self) -> String {
52        format!(
53            "org={}|workspace={}|deployment={}|source_bindings={}",
54            self.tenant_org_id,
55            self.tenant_workspace_id,
56            self.tenant_deployment_id.as_deref().unwrap_or(""),
57            self.source_binding_key()
58        )
59    }
60}
61
62/// Response cache backed by a dedicated SQLite database.
63pub struct ResponseCache {
64    conn: Arc<Mutex<Connection>>,
65    #[allow(dead_code)]
66    db_path: PathBuf,
67    ttl_minutes: i64,
68    max_entries: usize,
69}
70
71impl ResponseCache {
72    /// Open (or create) the response cache database at `{db_dir}/response_cache.db`.
73    pub async fn new(db_dir: &Path, ttl_minutes: u32, max_entries: usize) -> MemoryResult<Self> {
74        tokio::fs::create_dir_all(db_dir)
75            .await
76            .map_err(MemoryError::Io)?;
77
78        let db_path = db_dir.join("response_cache.db");
79
80        let conn = Connection::open(&db_path)?;
81        conn.execute_batch(
82            "PRAGMA journal_mode = WAL;
83             PRAGMA synchronous  = NORMAL;
84             PRAGMA temp_store   = MEMORY;",
85        )?;
86
87        conn.execute_batch(
88            "CREATE TABLE IF NOT EXISTS response_cache (
89                prompt_hash  TEXT PRIMARY KEY,
90                model        TEXT NOT NULL,
91                response     TEXT NOT NULL,
92                token_count  INTEGER NOT NULL DEFAULT 0,
93                created_at   TEXT NOT NULL,
94                accessed_at  TEXT NOT NULL,
95                hit_count    INTEGER NOT NULL DEFAULT 0,
96                tenant_org_id TEXT,
97                tenant_workspace_id TEXT,
98                tenant_deployment_id TEXT,
99                source_binding_key TEXT NOT NULL DEFAULT ''
100            );
101            CREATE INDEX IF NOT EXISTS idx_rc_accessed ON response_cache(accessed_at);
102            CREATE INDEX IF NOT EXISTS idx_rc_created  ON response_cache(created_at);
103            CREATE INDEX IF NOT EXISTS idx_rc_tenant_scope
104                ON response_cache(tenant_org_id, tenant_workspace_id, tenant_deployment_id);
105            CREATE INDEX IF NOT EXISTS idx_rc_source_binding
106                ON response_cache(source_binding_key);",
107        )?;
108        migrate_response_cache_scope_columns(&conn)?;
109
110        Ok(Self {
111            conn: Arc::new(Mutex::new(conn)),
112            db_path,
113            ttl_minutes: i64::from(ttl_minutes),
114            max_entries,
115        })
116    }
117
118    /// Build a deterministic cache key from model + system prompt + user prompt.
119    pub fn cache_key(model: &str, system_prompt: Option<&str>, user_prompt: &str) -> String {
120        let mut hasher = Sha256::new();
121        hasher.update(model.as_bytes());
122        hasher.update(b"|");
123        if let Some(sys) = system_prompt {
124            hasher.update(sys.as_bytes());
125        }
126        hasher.update(b"|");
127        hasher.update(user_prompt.as_bytes());
128        format!("{:064x}", hasher.finalize())
129    }
130
131    /// Build a deterministic cache key that is partitioned by tenant and source bindings.
132    pub fn cache_key_scoped(
133        model: &str,
134        system_prompt: Option<&str>,
135        user_prompt: &str,
136        scope: &ResponseCacheScope,
137    ) -> String {
138        let mut hasher = Sha256::new();
139        hasher.update(model.as_bytes());
140        hasher.update(b"|");
141        if let Some(sys) = system_prompt {
142            hasher.update(sys.as_bytes());
143        }
144        hasher.update(b"|");
145        hasher.update(user_prompt.as_bytes());
146        hasher.update(b"|");
147        hasher.update(scope.fingerprint().as_bytes());
148        format!("{:064x}", hasher.finalize())
149    }
150
151    /// Look up a cached response. Returns `None` on miss or if the entry has expired.
152    pub async fn get(&self, key: &str) -> MemoryResult<Option<String>> {
153        let conn = self.conn.lock().await;
154        let cutoff = (Utc::now() - Duration::minutes(self.ttl_minutes)).to_rfc3339();
155
156        let result: Option<String> = conn
157            .query_row(
158                "SELECT response FROM response_cache
159                 WHERE prompt_hash = ?1 AND created_at > ?2",
160                params![key, cutoff],
161                |row| row.get(0),
162            )
163            .ok();
164
165        if result.is_some() {
166            let now = Utc::now().to_rfc3339();
167            conn.execute(
168                "UPDATE response_cache
169                 SET accessed_at = ?1, hit_count = hit_count + 1
170                 WHERE prompt_hash = ?2",
171                params![now, key],
172            )?;
173        }
174
175        Ok(result)
176    }
177
178    /// Store a response in the cache, evicting expired or least-recently-used entries.
179    pub async fn put(
180        &self,
181        key: &str,
182        model: &str,
183        response: &str,
184        token_count: u32,
185    ) -> MemoryResult<()> {
186        let conn = self.conn.lock().await;
187        let now = Utc::now().to_rfc3339();
188
189        conn.execute(
190            "INSERT OR REPLACE INTO response_cache
191             (prompt_hash, model, response, token_count, created_at, accessed_at, hit_count)
192             VALUES (?1, ?2, ?3, ?4, ?5, ?6, 0)",
193            params![key, model, response, token_count, now, now],
194        )?;
195
196        // Evict expired entries
197        let cutoff = (Utc::now() - Duration::minutes(self.ttl_minutes)).to_rfc3339();
198        conn.execute(
199            "DELETE FROM response_cache WHERE created_at <= ?1",
200            params![cutoff],
201        )?;
202
203        // LRU eviction if over max_entries
204        #[allow(clippy::cast_possible_wrap)]
205        let max = self.max_entries as i64;
206        conn.execute(
207            "DELETE FROM response_cache WHERE prompt_hash IN (
208                SELECT prompt_hash FROM response_cache
209                ORDER BY accessed_at ASC
210                LIMIT MAX(0, (SELECT COUNT(*) FROM response_cache) - ?1)
211            )",
212            params![max],
213        )?;
214
215        Ok(())
216    }
217
218    pub async fn put_scoped(
219        &self,
220        key: &str,
221        model: &str,
222        response: &str,
223        token_count: u32,
224        scope: &ResponseCacheScope,
225    ) -> MemoryResult<()> {
226        let conn = self.conn.lock().await;
227        let now = Utc::now().to_rfc3339();
228        let source_binding_key = scope.source_binding_key();
229
230        conn.execute(
231            "INSERT OR REPLACE INTO response_cache
232             (prompt_hash, model, response, token_count, created_at, accessed_at, hit_count,
233              tenant_org_id, tenant_workspace_id, tenant_deployment_id, source_binding_key)
234             VALUES (?1, ?2, ?3, ?4, ?5, ?6, 0, ?7, ?8, ?9, ?10)",
235            params![
236                key,
237                model,
238                response,
239                token_count,
240                now,
241                now,
242                scope.tenant_org_id,
243                scope.tenant_workspace_id,
244                scope.tenant_deployment_id,
245                source_binding_key
246            ],
247        )?;
248
249        self.evict_locked(&conn)?;
250        Ok(())
251    }
252
253    pub async fn invalidate_source_binding(
254        &self,
255        tenant_org_id: &str,
256        tenant_workspace_id: &str,
257        tenant_deployment_id: Option<&str>,
258        source_binding_id: &str,
259    ) -> MemoryResult<usize> {
260        let conn = self.conn.lock().await;
261        let needle = format!("%|{}|%", normalize_source_binding_id(source_binding_id));
262        let affected = conn.execute(
263            "DELETE FROM response_cache
264             WHERE tenant_org_id = ?1
265               AND tenant_workspace_id = ?2
266               AND IFNULL(tenant_deployment_id, '') = IFNULL(?3, '')
267               AND source_binding_key LIKE ?4",
268            params![
269                tenant_org_id,
270                tenant_workspace_id,
271                tenant_deployment_id,
272                needle
273            ],
274        )?;
275        Ok(affected)
276    }
277
278    pub async fn invalidate_tenant(
279        &self,
280        tenant_org_id: &str,
281        tenant_workspace_id: &str,
282        tenant_deployment_id: Option<&str>,
283    ) -> MemoryResult<usize> {
284        let conn = self.conn.lock().await;
285        let affected = conn.execute(
286            "DELETE FROM response_cache
287             WHERE tenant_org_id = ?1
288               AND tenant_workspace_id = ?2
289               AND IFNULL(tenant_deployment_id, '') = IFNULL(?3, '')",
290            params![tenant_org_id, tenant_workspace_id, tenant_deployment_id],
291        )?;
292        Ok(affected)
293    }
294
295    /// Return cache statistics: `(total_entries, total_hits, estimated_tokens_saved)`.
296    pub async fn stats(&self) -> MemoryResult<(usize, u64, u64)> {
297        let conn = self.conn.lock().await;
298
299        let count: i64 =
300            conn.query_row("SELECT COUNT(*) FROM response_cache", [], |row| row.get(0))?;
301
302        let hits: i64 = conn.query_row(
303            "SELECT COALESCE(SUM(hit_count), 0) FROM response_cache",
304            [],
305            |row| row.get(0),
306        )?;
307
308        let tokens_saved: i64 = conn.query_row(
309            "SELECT COALESCE(SUM(token_count * hit_count), 0) FROM response_cache",
310            [],
311            |row| row.get(0),
312        )?;
313
314        #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
315        Ok((count as usize, hits as u64, tokens_saved as u64))
316    }
317
318    /// Clear all cached entries.
319    pub async fn clear(&self) -> MemoryResult<usize> {
320        let conn = self.conn.lock().await;
321        let affected = conn.execute("DELETE FROM response_cache", [])?;
322        Ok(affected)
323    }
324
325    fn evict_locked(&self, conn: &Connection) -> MemoryResult<()> {
326        let cutoff = (Utc::now() - Duration::minutes(self.ttl_minutes)).to_rfc3339();
327        conn.execute(
328            "DELETE FROM response_cache WHERE created_at <= ?1",
329            params![cutoff],
330        )?;
331
332        #[allow(clippy::cast_possible_wrap)]
333        let max = self.max_entries as i64;
334        conn.execute(
335            "DELETE FROM response_cache WHERE prompt_hash IN (
336                SELECT prompt_hash FROM response_cache
337                ORDER BY accessed_at ASC
338                LIMIT MAX(0, (SELECT COUNT(*) FROM response_cache) - ?1)
339            )",
340            params![max],
341        )?;
342
343        Ok(())
344    }
345}
346
347fn migrate_response_cache_scope_columns(conn: &Connection) -> MemoryResult<()> {
348    let columns = conn
349        .prepare("PRAGMA table_info(response_cache)")?
350        .query_map([], |row| row.get::<_, String>(1))?
351        .collect::<Result<std::collections::HashSet<_>, _>>()?;
352    for (name, ddl) in [
353        (
354            "tenant_org_id",
355            "ALTER TABLE response_cache ADD COLUMN tenant_org_id TEXT",
356        ),
357        (
358            "tenant_workspace_id",
359            "ALTER TABLE response_cache ADD COLUMN tenant_workspace_id TEXT",
360        ),
361        (
362            "tenant_deployment_id",
363            "ALTER TABLE response_cache ADD COLUMN tenant_deployment_id TEXT",
364        ),
365        (
366            "source_binding_key",
367            "ALTER TABLE response_cache ADD COLUMN source_binding_key TEXT NOT NULL DEFAULT ''",
368        ),
369    ] {
370        if !columns.contains(name) {
371            conn.execute(ddl, [])?;
372        }
373    }
374    Ok(())
375}
376
377fn normalized_source_binding_ids(source_binding_ids: Vec<String>) -> Vec<String> {
378    let mut ids = source_binding_ids
379        .into_iter()
380        .map(|id| normalize_source_binding_id(&id))
381        .filter(|id| !id.is_empty())
382        .collect::<Vec<_>>();
383    ids.sort();
384    ids.dedup();
385    ids
386}
387
/// Canonicalize one source binding id for use inside a `|`-delimited key.
///
/// `|` is the key delimiter, so it is stripped first, and surrounding
/// whitespace is trimmed afterwards. In that order the result is canonical:
/// `" drive |"` and `"drive"` both become `"drive"`, and normalizing an
/// already-normalized id is a no-op. Trimming first would leave `"drive "`.
fn normalize_source_binding_id(source_binding_id: &str) -> String {
    source_binding_id.replace('|', "").trim().to_string()
}
391
/// Join source binding ids into the `|a|b|` form stored in the
/// `source_binding_key` column. An empty list yields an empty string.
///
/// Wrapping every id in `|` lets a whole id be located as the substring `|id|`.
fn source_binding_key(source_binding_ids: &[String]) -> String {
    match source_binding_ids {
        [] => String::new(),
        ids => ids.iter().fold(String::from("|"), |mut key, id| {
            key.push_str(id);
            key.push('|');
            key
        }),
    }
}
398
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Fresh cache (max 1000 entries) in its own temporary directory. The
    /// `TempDir` is returned so the directory outlives the cache.
    async fn temp_cache(ttl_minutes: u32) -> (TempDir, ResponseCache) {
        let tmp = TempDir::new().unwrap();
        let cache = ResponseCache::new(tmp.path(), ttl_minutes, 1000)
            .await
            .unwrap();
        (tmp, cache)
    }

    /// Store one scoped response per scope, using a distinct prompt for each.
    async fn put_for_scopes(cache: &ResponseCache, scopes: &[&ResponseCacheScope]) {
        for (idx, &scope) in scopes.iter().enumerate() {
            let key =
                ResponseCache::cache_key_scoped("gpt-4", None, &format!("prompt {idx}"), scope);
            cache
                .put_scoped(&key, "gpt-4", &format!("response {idx}"), 10, scope)
                .await
                .unwrap();
        }
    }

    /// Number of rows currently stored.
    async fn entry_count(cache: &ResponseCache) -> usize {
        let (count, _, _) = cache.stats().await.unwrap();
        count
    }

    // Key derivation is pure, so these tests need no async runtime.

    #[test]
    fn cache_key_is_deterministic() {
        let k1 = ResponseCache::cache_key("gpt-4", Some("sys"), "hello");
        let k2 = ResponseCache::cache_key("gpt-4", Some("sys"), "hello");
        assert_eq!(k1, k2);
        assert_eq!(k1.len(), 64);
    }

    #[test]
    fn cache_key_varies_by_model() {
        let k1 = ResponseCache::cache_key("gpt-4", None, "hello");
        let k2 = ResponseCache::cache_key("claude-3", None, "hello");
        assert_ne!(k1, k2);
    }

    #[test]
    fn scoped_cache_key_varies_by_tenant_and_source_binding() {
        let scope_a = ResponseCacheScope::tenant("org-a", "workspace-a", None)
            .with_source_bindings(vec!["finance-drive".to_string()]);
        let scope_b = ResponseCacheScope::tenant("org-a", "workspace-a", None)
            .with_source_bindings(vec!["hr-drive".to_string()]);
        let key_a = ResponseCache::cache_key_scoped("gpt-4", Some("sys"), "hello", &scope_a);
        let key_b = ResponseCache::cache_key_scoped("gpt-4", Some("sys"), "hello", &scope_b);
        assert_ne!(key_a, key_b);
    }

    #[test]
    fn scoped_cache_key_ignores_source_binding_order_and_duplicates() {
        let unordered = ResponseCacheScope::tenant("org-a", "workspace-a", None)
            .with_source_bindings(vec![
                "hr-drive".to_string(),
                " finance-drive ".to_string(),
                "hr-drive".to_string(),
            ]);
        let ordered = ResponseCacheScope::tenant("org-a", "workspace-a", None)
            .with_source_bindings(vec!["finance-drive".to_string(), "hr-drive".to_string()]);
        assert_eq!(unordered, ordered);
        assert_eq!(
            ResponseCache::cache_key_scoped("gpt-4", None, "hello", &unordered),
            ResponseCache::cache_key_scoped("gpt-4", None, "hello", &ordered)
        );
    }

    #[test]
    fn scoped_cache_key_differs_from_unscoped_key() {
        let scope = ResponseCacheScope::tenant("org-a", "workspace-a", None);
        assert_ne!(
            ResponseCache::cache_key("gpt-4", Some("sys"), "hello"),
            ResponseCache::cache_key_scoped("gpt-4", Some("sys"), "hello", &scope)
        );
    }

    #[tokio::test]
    async fn put_and_get_roundtrip() {
        let (_tmp, cache) = temp_cache(60).await;
        let key = ResponseCache::cache_key("gpt-4", None, "What is Rust?");
        cache
            .put(&key, "gpt-4", "Rust is a systems programming language.", 25)
            .await
            .unwrap();
        let result = cache.get(&key).await.unwrap();
        assert_eq!(
            result.as_deref(),
            Some("Rust is a systems programming language.")
        );
    }

    #[tokio::test]
    async fn miss_returns_none() {
        let (_tmp, cache) = temp_cache(60).await;
        let result = cache.get("nonexistent").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn expired_entry_returns_none() {
        let (_tmp, cache) = temp_cache(0).await; // 0 TTL → instantly expired
        let key = ResponseCache::cache_key("gpt-4", None, "test");
        cache.put(&key, "gpt-4", "response", 10).await.unwrap();
        let result = cache.get(&key).await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn stats_tracks_hits_and_tokens() {
        let (_tmp, cache) = temp_cache(60).await;
        let key = ResponseCache::cache_key("gpt-4", None, "explain rust");
        cache.put(&key, "gpt-4", "Rust is...", 100).await.unwrap();
        for _ in 0..5 {
            let _ = cache.get(&key).await.unwrap();
        }
        let (_, hits, tokens) = cache.stats().await.unwrap();
        assert_eq!(hits, 5);
        assert_eq!(tokens, 500);
    }

    #[tokio::test]
    async fn stats_are_zero_for_empty_cache() {
        let (_tmp, cache) = temp_cache(60).await;
        assert_eq!(cache.stats().await.unwrap(), (0, 0, 0));
    }

    #[tokio::test]
    async fn lru_eviction_respects_max_entries() {
        let tmp = TempDir::new().unwrap();
        let cache = ResponseCache::new(tmp.path(), 60, 3).await.unwrap();
        for i in 0..5 {
            let key = ResponseCache::cache_key("gpt-4", None, &format!("prompt {i}"));
            cache
                .put(&key, "gpt-4", &format!("response {i}"), 10)
                .await
                .unwrap();
        }
        assert!(
            entry_count(&cache).await <= 3,
            "cache must not exceed max_entries"
        );
    }

    #[tokio::test]
    async fn reopening_existing_database_preserves_entries() {
        let tmp = TempDir::new().unwrap();
        let key = ResponseCache::cache_key("gpt-4", None, "persist me");
        {
            let cache = ResponseCache::new(tmp.path(), 60, 1000).await.unwrap();
            cache.put(&key, "gpt-4", "still here", 5).await.unwrap();
        }
        let reopened = ResponseCache::new(tmp.path(), 60, 1000).await.unwrap();
        assert_eq!(
            reopened.get(&key).await.unwrap().as_deref(),
            Some("still here")
        );
    }

    #[tokio::test]
    async fn invalidate_source_binding_removes_only_matching_tenant_entries() {
        let (_tmp, cache) = temp_cache(60).await;
        let finance_a = ResponseCacheScope::tenant("org-a", "workspace-a", None)
            .with_source_bindings(vec!["finance-drive".to_string()]);
        let hr_a = ResponseCacheScope::tenant("org-a", "workspace-a", None)
            .with_source_bindings(vec!["hr-drive".to_string()]);
        let finance_b = ResponseCacheScope::tenant("org-b", "workspace-b", None)
            .with_source_bindings(vec!["finance-drive".to_string()]);
        put_for_scopes(&cache, &[&finance_a, &hr_a, &finance_b]).await;

        let removed = cache
            .invalidate_source_binding("org-a", "workspace-a", None, "finance-drive")
            .await
            .unwrap();
        assert_eq!(removed, 1);
        assert_eq!(entry_count(&cache).await, 2);
    }

    #[tokio::test]
    async fn invalidate_source_binding_matches_any_binding_of_a_scope() {
        let (_tmp, cache) = temp_cache(60).await;
        let both = ResponseCacheScope::tenant("org-a", "workspace-a", None)
            .with_source_bindings(vec!["finance-drive".to_string(), "hr-drive".to_string()]);
        let finance_only = ResponseCacheScope::tenant("org-a", "workspace-a", None)
            .with_source_bindings(vec!["finance-drive".to_string()]);
        put_for_scopes(&cache, &[&both, &finance_only]).await;

        let removed = cache
            .invalidate_source_binding("org-a", "workspace-a", None, "hr-drive")
            .await
            .unwrap();
        assert_eq!(removed, 1);
        assert_eq!(entry_count(&cache).await, 1);
    }

    #[tokio::test]
    async fn invalidate_source_binding_ignores_partial_id_matches() {
        let (_tmp, cache) = temp_cache(60).await;
        let finance = ResponseCacheScope::tenant("org-a", "workspace-a", None)
            .with_source_bindings(vec!["finance-drive".to_string()]);
        put_for_scopes(&cache, &[&finance]).await;

        for partial in ["drive", "finance"] {
            let removed = cache
                .invalidate_source_binding("org-a", "workspace-a", None, partial)
                .await
                .unwrap();
            assert_eq!(removed, 0, "`{partial}` must not match `finance-drive`");
        }
        assert_eq!(entry_count(&cache).await, 1);
    }

    #[tokio::test]
    async fn invalidate_tenant_removes_only_that_tenant() {
        let (_tmp, cache) = temp_cache(60).await;
        let tenant_a = ResponseCacheScope::tenant("org-a", "workspace-a", None);
        let tenant_a_deployed =
            ResponseCacheScope::tenant("org-a", "workspace-a", Some("deploy-1".to_string()));
        let tenant_b = ResponseCacheScope::tenant("org-b", "workspace-b", None);
        put_for_scopes(&cache, &[&tenant_a, &tenant_a_deployed, &tenant_b]).await;
        // Unscoped entries carry no tenant columns and must survive.
        let unscoped = ResponseCache::cache_key("gpt-4", None, "unscoped");
        cache
            .put(&unscoped, "gpt-4", "unscoped response", 10)
            .await
            .unwrap();

        let removed = cache
            .invalidate_tenant("org-a", "workspace-a", None)
            .await
            .unwrap();
        assert_eq!(removed, 1);
        assert_eq!(entry_count(&cache).await, 3);
        assert!(cache.get(&unscoped).await.unwrap().is_some());
    }

    #[tokio::test]
    async fn clear_removes_all_entries() {
        let (_tmp, cache) = temp_cache(60).await;
        for i in 0..2 {
            let key = ResponseCache::cache_key("gpt-4", None, &format!("prompt {i}"));
            cache
                .put(&key, "gpt-4", &format!("response {i}"), 10)
                .await
                .unwrap();
        }
        assert_eq!(cache.clear().await.unwrap(), 2);
        assert_eq!(cache.stats().await.unwrap(), (0, 0, 0));
    }
}