Skip to main content

tandem_memory/
response_cache.rs

//! LLM response cache — avoid burning tokens on repeated prompts.
//!
//! Stores LLM responses in a dedicated SQLite table keyed by a SHA-256 hash of
//! `(model, system_prompt, user_prompt)`, optionally partitioned by tenant and
//! source bindings (see `ResponseCache::cache_key_scoped`). Entries expire after
//! a configurable TTL, and the least-recently-accessed entries are evicted once
//! the configured maximum entry count is exceeded. The cache is optional and
//! disabled by default — users opt in via `TANDEM_RESPONSE_CACHE_ENABLED=true`.
//!
//! Lives alongside `memory.sqlite` as `response_cache.db` so it can be
//! independently wiped without touching memory chunks.
10
11use chrono::{Duration, Utc};
12use rusqlite::{params, Connection};
13use sha2::{Digest, Sha256};
14use std::path::{Path, PathBuf};
15use std::sync::Arc;
16use tokio::sync::Mutex;
17
18use crate::types::{MemoryError, MemoryResult};
19
/// Partition of the response cache: the tenant an entry belongs to and the
/// source bindings it is associated with (entries can be invalidated per
/// tenant or per binding).
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ResponseCacheScope {
    /// Organization that owns the cached entry.
    pub tenant_org_id: String,
    /// Workspace inside the organization.
    pub tenant_workspace_id: String,
    /// Optional deployment inside the workspace; `None` is treated like an
    /// empty id when keys are built and when entries are invalidated.
    pub tenant_deployment_id: Option<String>,
    /// Source binding ids; trimmed, de-duplicated and sorted when set through
    /// `with_source_bindings`.
    pub source_binding_ids: Vec<String>,
}
27
28impl ResponseCacheScope {
29    pub fn tenant(
30        tenant_org_id: impl Into<String>,
31        tenant_workspace_id: impl Into<String>,
32        tenant_deployment_id: Option<String>,
33    ) -> Self {
34        Self {
35            tenant_org_id: tenant_org_id.into(),
36            tenant_workspace_id: tenant_workspace_id.into(),
37            tenant_deployment_id,
38            source_binding_ids: Vec::new(),
39        }
40    }
41
42    pub fn with_source_bindings(mut self, source_binding_ids: Vec<String>) -> Self {
43        self.source_binding_ids = normalized_source_binding_ids(source_binding_ids);
44        self
45    }
46
47    fn source_binding_key(&self) -> String {
48        source_binding_key(&self.source_binding_ids)
49    }
50
51    fn fingerprint(&self) -> String {
52        format!(
53            "org={}|workspace={}|deployment={}|source_bindings={}",
54            self.tenant_org_id,
55            self.tenant_workspace_id,
56            self.tenant_deployment_id.as_deref().unwrap_or(""),
57            self.source_binding_key()
58        )
59    }
60}
61
62/// Response cache backed by a dedicated SQLite database.
63pub struct ResponseCache {
64    conn: Arc<Mutex<Connection>>,
65    #[allow(dead_code)]
66    db_path: PathBuf,
67    ttl_minutes: i64,
68    max_entries: usize,
69    crypto: crate::crypto::MemoryCryptoProvider,
70}
71
72impl ResponseCache {
73    /// Open (or create) the response cache database at `{db_dir}/response_cache.db`.
74    pub async fn new(db_dir: &Path, ttl_minutes: u32, max_entries: usize) -> MemoryResult<Self> {
75        tokio::fs::create_dir_all(db_dir)
76            .await
77            .map_err(MemoryError::Io)?;
78
79        let db_path = db_dir.join("response_cache.db");
80
81        let conn = Connection::open(&db_path)?;
82        conn.execute_batch(
83            "PRAGMA journal_mode = WAL;
84             PRAGMA synchronous  = NORMAL;
85             PRAGMA temp_store   = MEMORY;",
86        )?;
87
88        conn.execute_batch(
89            "CREATE TABLE IF NOT EXISTS response_cache (
90                prompt_hash  TEXT PRIMARY KEY,
91                model        TEXT NOT NULL,
92                response     TEXT NOT NULL,
93                token_count  INTEGER NOT NULL DEFAULT 0,
94                created_at   TEXT NOT NULL,
95                accessed_at  TEXT NOT NULL,
96                hit_count    INTEGER NOT NULL DEFAULT 0,
97                tenant_org_id TEXT,
98                tenant_workspace_id TEXT,
99                tenant_deployment_id TEXT,
100                source_binding_key TEXT NOT NULL DEFAULT ''
101            );
102            CREATE INDEX IF NOT EXISTS idx_rc_accessed ON response_cache(accessed_at);
103            CREATE INDEX IF NOT EXISTS idx_rc_created  ON response_cache(created_at);
104            CREATE INDEX IF NOT EXISTS idx_rc_tenant_scope
105                ON response_cache(tenant_org_id, tenant_workspace_id, tenant_deployment_id);
106            CREATE INDEX IF NOT EXISTS idx_rc_source_binding
107                ON response_cache(source_binding_key);",
108        )?;
109        migrate_response_cache_scope_columns(&conn)?;
110
111        Ok(Self {
112            conn: Arc::new(Mutex::new(conn)),
113            db_path,
114            ttl_minutes: i64::from(ttl_minutes),
115            max_entries,
116            crypto: crate::crypto::MemoryCryptoProvider::from_env(),
117        })
118    }
119
120    /// Override the payload crypto provider (cached responses are semantic
121    /// memory and are encrypted at rest in encrypted modes). Used in tests.
122    pub fn with_crypto_provider(mut self, crypto: crate::crypto::MemoryCryptoProvider) -> Self {
123        self.crypto = crypto;
124        self
125    }
126
127    /// Build a deterministic cache key from model + system prompt + user prompt.
128    pub fn cache_key(model: &str, system_prompt: Option<&str>, user_prompt: &str) -> String {
129        let mut hasher = Sha256::new();
130        hasher.update(model.as_bytes());
131        hasher.update(b"|");
132        if let Some(sys) = system_prompt {
133            hasher.update(sys.as_bytes());
134        }
135        hasher.update(b"|");
136        hasher.update(user_prompt.as_bytes());
137        format!("{:064x}", hasher.finalize())
138    }
139
140    /// Build a deterministic cache key that is partitioned by tenant and source bindings.
141    pub fn cache_key_scoped(
142        model: &str,
143        system_prompt: Option<&str>,
144        user_prompt: &str,
145        scope: &ResponseCacheScope,
146    ) -> String {
147        let mut hasher = Sha256::new();
148        hasher.update(model.as_bytes());
149        hasher.update(b"|");
150        if let Some(sys) = system_prompt {
151            hasher.update(sys.as_bytes());
152        }
153        hasher.update(b"|");
154        hasher.update(user_prompt.as_bytes());
155        hasher.update(b"|");
156        hasher.update(scope.fingerprint().as_bytes());
157        format!("{:064x}", hasher.finalize())
158    }
159
160    /// Look up a cached response. Returns `None` on miss or if the entry has expired.
161    pub async fn get(&self, key: &str) -> MemoryResult<Option<String>> {
162        let conn = self.conn.lock().await;
163        let cutoff = (Utc::now() - Duration::minutes(self.ttl_minutes)).to_rfc3339();
164
165        let stored: Option<String> = conn
166            .query_row(
167                "SELECT response FROM response_cache
168                 WHERE prompt_hash = ?1 AND created_at > ?2",
169                params![key, cutoff],
170                |row| row.get(0),
171            )
172            .ok();
173        let result = match stored {
174            Some(value) => Some(self.crypto.decrypt_field(&value)?),
175            None => None,
176        };
177
178        if result.is_some() {
179            let now = Utc::now().to_rfc3339();
180            conn.execute(
181                "UPDATE response_cache
182                 SET accessed_at = ?1, hit_count = hit_count + 1
183                 WHERE prompt_hash = ?2",
184                params![now, key],
185            )?;
186        }
187
188        Ok(result)
189    }
190
191    /// Store a response in the cache, evicting expired or least-recently-used entries.
192    pub async fn put(
193        &self,
194        key: &str,
195        model: &str,
196        response: &str,
197        token_count: u32,
198    ) -> MemoryResult<()> {
199        let response_stored = self.crypto.encrypt_field(response)?;
200        let conn = self.conn.lock().await;
201        let now = Utc::now().to_rfc3339();
202
203        conn.execute(
204            "INSERT OR REPLACE INTO response_cache
205             (prompt_hash, model, response, token_count, created_at, accessed_at, hit_count)
206             VALUES (?1, ?2, ?3, ?4, ?5, ?6, 0)",
207            params![key, model, response_stored, token_count, now, now],
208        )?;
209
210        // Evict expired entries
211        let cutoff = (Utc::now() - Duration::minutes(self.ttl_minutes)).to_rfc3339();
212        conn.execute(
213            "DELETE FROM response_cache WHERE created_at <= ?1",
214            params![cutoff],
215        )?;
216
217        // LRU eviction if over max_entries
218        #[allow(clippy::cast_possible_wrap)]
219        let max = self.max_entries as i64;
220        conn.execute(
221            "DELETE FROM response_cache WHERE prompt_hash IN (
222                SELECT prompt_hash FROM response_cache
223                ORDER BY accessed_at ASC
224                LIMIT MAX(0, (SELECT COUNT(*) FROM response_cache) - ?1)
225            )",
226            params![max],
227        )?;
228
229        Ok(())
230    }
231
232    pub async fn put_scoped(
233        &self,
234        key: &str,
235        model: &str,
236        response: &str,
237        token_count: u32,
238        scope: &ResponseCacheScope,
239    ) -> MemoryResult<()> {
240        let response_stored = self.crypto.encrypt_field(response)?;
241        let conn = self.conn.lock().await;
242        let now = Utc::now().to_rfc3339();
243        let source_binding_key = scope.source_binding_key();
244
245        conn.execute(
246            "INSERT OR REPLACE INTO response_cache
247             (prompt_hash, model, response, token_count, created_at, accessed_at, hit_count,
248              tenant_org_id, tenant_workspace_id, tenant_deployment_id, source_binding_key)
249             VALUES (?1, ?2, ?3, ?4, ?5, ?6, 0, ?7, ?8, ?9, ?10)",
250            params![
251                key,
252                model,
253                response_stored,
254                token_count,
255                now,
256                now,
257                scope.tenant_org_id,
258                scope.tenant_workspace_id,
259                scope.tenant_deployment_id,
260                source_binding_key
261            ],
262        )?;
263
264        self.evict_locked(&conn)?;
265        Ok(())
266    }
267
268    pub async fn invalidate_source_binding(
269        &self,
270        tenant_org_id: &str,
271        tenant_workspace_id: &str,
272        tenant_deployment_id: Option<&str>,
273        source_binding_id: &str,
274    ) -> MemoryResult<usize> {
275        let conn = self.conn.lock().await;
276        let needle = format!("%|{}|%", normalize_source_binding_id(source_binding_id));
277        let affected = conn.execute(
278            "DELETE FROM response_cache
279             WHERE tenant_org_id = ?1
280               AND tenant_workspace_id = ?2
281               AND IFNULL(tenant_deployment_id, '') = IFNULL(?3, '')
282               AND source_binding_key LIKE ?4",
283            params![
284                tenant_org_id,
285                tenant_workspace_id,
286                tenant_deployment_id,
287                needle
288            ],
289        )?;
290        Ok(affected)
291    }
292
293    pub async fn invalidate_tenant(
294        &self,
295        tenant_org_id: &str,
296        tenant_workspace_id: &str,
297        tenant_deployment_id: Option<&str>,
298    ) -> MemoryResult<usize> {
299        let conn = self.conn.lock().await;
300        let affected = conn.execute(
301            "DELETE FROM response_cache
302             WHERE tenant_org_id = ?1
303               AND tenant_workspace_id = ?2
304               AND IFNULL(tenant_deployment_id, '') = IFNULL(?3, '')",
305            params![tenant_org_id, tenant_workspace_id, tenant_deployment_id],
306        )?;
307        Ok(affected)
308    }
309
310    /// Return cache statistics: `(total_entries, total_hits, estimated_tokens_saved)`.
311    pub async fn stats(&self) -> MemoryResult<(usize, u64, u64)> {
312        let conn = self.conn.lock().await;
313
314        let count: i64 =
315            conn.query_row("SELECT COUNT(*) FROM response_cache", [], |row| row.get(0))?;
316
317        let hits: i64 = conn.query_row(
318            "SELECT COALESCE(SUM(hit_count), 0) FROM response_cache",
319            [],
320            |row| row.get(0),
321        )?;
322
323        let tokens_saved: i64 = conn.query_row(
324            "SELECT COALESCE(SUM(token_count * hit_count), 0) FROM response_cache",
325            [],
326            |row| row.get(0),
327        )?;
328
329        #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
330        Ok((count as usize, hits as u64, tokens_saved as u64))
331    }
332
333    /// Clear all cached entries.
334    pub async fn clear(&self) -> MemoryResult<usize> {
335        let conn = self.conn.lock().await;
336        let affected = conn.execute("DELETE FROM response_cache", [])?;
337        Ok(affected)
338    }
339
340    fn evict_locked(&self, conn: &Connection) -> MemoryResult<()> {
341        let cutoff = (Utc::now() - Duration::minutes(self.ttl_minutes)).to_rfc3339();
342        conn.execute(
343            "DELETE FROM response_cache WHERE created_at <= ?1",
344            params![cutoff],
345        )?;
346
347        #[allow(clippy::cast_possible_wrap)]
348        let max = self.max_entries as i64;
349        conn.execute(
350            "DELETE FROM response_cache WHERE prompt_hash IN (
351                SELECT prompt_hash FROM response_cache
352                ORDER BY accessed_at ASC
353                LIMIT MAX(0, (SELECT COUNT(*) FROM response_cache) - ?1)
354            )",
355            params![max],
356        )?;
357
358        Ok(())
359    }
360}
361
362fn migrate_response_cache_scope_columns(conn: &Connection) -> MemoryResult<()> {
363    let columns = conn
364        .prepare("PRAGMA table_info(response_cache)")?
365        .query_map([], |row| row.get::<_, String>(1))?
366        .collect::<Result<std::collections::HashSet<_>, _>>()?;
367    for (name, ddl) in [
368        (
369            "tenant_org_id",
370            "ALTER TABLE response_cache ADD COLUMN tenant_org_id TEXT",
371        ),
372        (
373            "tenant_workspace_id",
374            "ALTER TABLE response_cache ADD COLUMN tenant_workspace_id TEXT",
375        ),
376        (
377            "tenant_deployment_id",
378            "ALTER TABLE response_cache ADD COLUMN tenant_deployment_id TEXT",
379        ),
380        (
381            "source_binding_key",
382            "ALTER TABLE response_cache ADD COLUMN source_binding_key TEXT NOT NULL DEFAULT ''",
383        ),
384    ] {
385        if !columns.contains(name) {
386            conn.execute(ddl, [])?;
387        }
388    }
389    Ok(())
390}
391
392fn normalized_source_binding_ids(source_binding_ids: Vec<String>) -> Vec<String> {
393    let mut ids = source_binding_ids
394        .into_iter()
395        .map(|id| normalize_source_binding_id(&id))
396        .filter(|id| !id.is_empty())
397        .collect::<Vec<_>>();
398    ids.sort();
399    ids.dedup();
400    ids
401}
402
/// Canonicalize a single source binding id.
///
/// Every `|` is removed first, because `|` is the separator inside stored
/// binding keys. Surrounding whitespace is trimmed afterwards. Trimming last
/// makes the function idempotent: trimming first would turn `"| a"` into
/// `" a"`, which normalizes again to `"a"`. The stored binding key and the
/// invalidation needle could then disagree.
fn normalize_source_binding_id(source_binding_id: &str) -> String {
    source_binding_id.replace('|', "").trim().to_string()
}
406
/// Encode binding ids as `|id1|id2|…|`, so each id can be matched as a
/// delimited `|id|` segment. An empty list encodes as the empty string.
fn source_binding_key(source_binding_ids: &[String]) -> String {
    match source_binding_ids {
        [] => String::new(),
        ids => {
            let mut key = String::from("|");
            for id in ids {
                key.push_str(id);
                key.push('|');
            }
            key
        }
    }
}
413
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Open a cache with room for 1000 entries in a fresh temporary directory.
    /// The directory guard is returned so it outlives the cache.
    async fn temp_cache(ttl_minutes: u32) -> (TempDir, ResponseCache) {
        let tmp = TempDir::new().unwrap();
        let cache = ResponseCache::new(tmp.path(), ttl_minutes, 1000)
            .await
            .unwrap();
        (tmp, cache)
    }

    /// Scope for `org`/`workspace` (no deployment) bound to a single source.
    fn single_binding_scope(org: &str, workspace: &str, binding: &str) -> ResponseCacheScope {
        ResponseCacheScope::tenant(org, workspace, None)
            .with_source_bindings(vec![binding.to_string()])
    }

    // Key derivation is synchronous, so these tests need no async runtime.

    #[test]
    fn cache_key_is_deterministic() {
        let k1 = ResponseCache::cache_key("gpt-4", Some("sys"), "hello");
        let k2 = ResponseCache::cache_key("gpt-4", Some("sys"), "hello");
        assert_eq!(k1, k2);
        assert_eq!(k1.len(), 64);
    }

    #[test]
    fn cache_key_varies_by_model() {
        let k1 = ResponseCache::cache_key("gpt-4", None, "hello");
        let k2 = ResponseCache::cache_key("claude-3", None, "hello");
        assert_ne!(k1, k2);
    }

    #[test]
    fn scoped_cache_key_varies_by_tenant_and_source_binding() {
        let scope_a = single_binding_scope("org-a", "workspace-a", "finance-drive");
        let scope_b = single_binding_scope("org-a", "workspace-a", "hr-drive");
        let key_a = ResponseCache::cache_key_scoped("gpt-4", Some("sys"), "hello", &scope_a);
        let key_b = ResponseCache::cache_key_scoped("gpt-4", Some("sys"), "hello", &scope_b);
        assert_ne!(key_a, key_b);
    }

    #[tokio::test]
    async fn put_and_get_roundtrip() {
        let (_tmp, cache) = temp_cache(60).await;
        let key = ResponseCache::cache_key("gpt-4", None, "What is Rust?");
        cache
            .put(&key, "gpt-4", "Rust is a systems programming language.", 25)
            .await
            .unwrap();
        let result = cache.get(&key).await.unwrap();
        assert_eq!(
            result.as_deref(),
            Some("Rust is a systems programming language.")
        );
    }

    #[tokio::test]
    async fn scoped_put_is_readable_through_get() {
        let (_tmp, cache) = temp_cache(60).await;
        let scope = single_binding_scope("org-a", "workspace-a", "finance-drive");
        let key = ResponseCache::cache_key_scoped("gpt-4", Some("sys"), "q3 revenue?", &scope);
        cache
            .put_scoped(&key, "gpt-4", "Up 12%.", 42, &scope)
            .await
            .unwrap();
        assert_eq!(cache.get(&key).await.unwrap().as_deref(), Some("Up 12%."));
    }

    #[tokio::test]
    async fn miss_returns_none() {
        let (_tmp, cache) = temp_cache(60).await;
        let result = cache.get("nonexistent").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn expired_entry_returns_none() {
        let (_tmp, cache) = temp_cache(0).await; // 0 TTL → instantly expired
        let key = ResponseCache::cache_key("gpt-4", None, "test");
        cache.put(&key, "gpt-4", "response", 10).await.unwrap();
        let result = cache.get(&key).await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn stats_tracks_hits_and_tokens() {
        let (_tmp, cache) = temp_cache(60).await;
        let key = ResponseCache::cache_key("gpt-4", None, "explain rust");
        cache.put(&key, "gpt-4", "Rust is...", 100).await.unwrap();
        for _ in 0..5 {
            let _ = cache.get(&key).await.unwrap();
        }
        let (_, hits, tokens) = cache.stats().await.unwrap();
        assert_eq!(hits, 5);
        assert_eq!(tokens, 500);
    }

    #[tokio::test]
    async fn lru_eviction_respects_max_entries() {
        let tmp = TempDir::new().unwrap();
        let cache = ResponseCache::new(tmp.path(), 60, 3).await.unwrap();
        for i in 0..5 {
            let key = ResponseCache::cache_key("gpt-4", None, &format!("prompt {i}"));
            cache
                .put(&key, "gpt-4", &format!("response {i}"), 10)
                .await
                .unwrap();
        }
        let (count, _, _) = cache.stats().await.unwrap();
        assert!(count <= 3, "cache must not exceed max_entries");
    }

    #[tokio::test]
    async fn invalidate_source_binding_removes_only_matching_tenant_entries() {
        let (_tmp, cache) = temp_cache(60).await;
        let finance_a = single_binding_scope("org-a", "workspace-a", "finance-drive");
        let hr_a = single_binding_scope("org-a", "workspace-a", "hr-drive");
        let finance_b = single_binding_scope("org-b", "workspace-b", "finance-drive");

        for (idx, scope) in [&finance_a, &hr_a, &finance_b].into_iter().enumerate() {
            let key =
                ResponseCache::cache_key_scoped("gpt-4", None, &format!("prompt {idx}"), scope);
            cache
                .put_scoped(&key, "gpt-4", &format!("response {idx}"), 10, scope)
                .await
                .unwrap();
        }

        let removed = cache
            .invalidate_source_binding("org-a", "workspace-a", None, "finance-drive")
            .await
            .unwrap();
        assert_eq!(removed, 1);
        let (count, _, _) = cache.stats().await.unwrap();
        assert_eq!(count, 2);
    }

    #[tokio::test]
    async fn invalidate_tenant_respects_deployment_and_tenant_boundaries() {
        let (_tmp, cache) = temp_cache(60).await;
        let finance_a = single_binding_scope("org-a", "workspace-a", "finance-drive");
        let hr_a = single_binding_scope("org-a", "workspace-a", "hr-drive");
        let finance_b = single_binding_scope("org-b", "workspace-b", "finance-drive");

        let mut keys = Vec::new();
        for (idx, scope) in [&finance_a, &hr_a, &finance_b].into_iter().enumerate() {
            let key =
                ResponseCache::cache_key_scoped("gpt-4", None, &format!("prompt {idx}"), scope);
            cache
                .put_scoped(&key, "gpt-4", &format!("response {idx}"), 10, scope)
                .await
                .unwrap();
            keys.push(key);
        }

        // A named deployment must not match entries stored without one.
        let removed = cache
            .invalidate_tenant("org-a", "workspace-a", Some("deploy-1"))
            .await
            .unwrap();
        assert_eq!(removed, 0);

        let removed = cache
            .invalidate_tenant("org-a", "workspace-a", None)
            .await
            .unwrap();
        assert_eq!(removed, 2);
        assert!(cache.get(&keys[0]).await.unwrap().is_none());
        assert_eq!(
            cache.get(&keys[2]).await.unwrap().as_deref(),
            Some("response 2")
        );
    }

    #[tokio::test]
    async fn clear_removes_all_entries() {
        let (_tmp, cache) = temp_cache(60).await;
        for i in 0..2 {
            let key = ResponseCache::cache_key("gpt-4", None, &format!("prompt {i}"));
            cache.put(&key, "gpt-4", "response", 10).await.unwrap();
        }
        assert_eq!(cache.clear().await.unwrap(), 2);
        let (count, hits, tokens) = cache.stats().await.unwrap();
        assert_eq!((count, hits, tokens), (0, 0, 0));
    }
}