1use chrono::{Duration, Utc};
12use rusqlite::{params, Connection};
13use sha2::{Digest, Sha256};
14use std::path::{Path, PathBuf};
15use std::sync::Arc;
16use tokio::sync::Mutex;
17
18use crate::types::{MemoryError, MemoryResult};
19
/// Tenant and data-source scope that a cached response belongs to.
///
/// The scope is mixed into scoped cache keys (see
/// `ResponseCache::cache_key_scoped`) and stored alongside scoped rows (see
/// `ResponseCache::put_scoped`) so entries can later be invalidated per
/// tenant or per source binding.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ResponseCacheScope {
    /// Organisation identifier of the tenant.
    pub tenant_org_id: String,
    /// Workspace identifier of the tenant.
    pub tenant_workspace_id: String,
    /// Optional deployment identifier; `None` is fingerprinted as an empty string.
    pub tenant_deployment_id: Option<String>,
    /// Source binding identifiers attached to this scope.
    ///
    /// `with_source_bindings` stores these normalized, sorted and deduplicated.
    /// The field is public, so values assigned directly bypass that
    /// normalization.
    pub source_binding_ids: Vec<String>,
}
27
28impl ResponseCacheScope {
29 pub fn tenant(
30 tenant_org_id: impl Into<String>,
31 tenant_workspace_id: impl Into<String>,
32 tenant_deployment_id: Option<String>,
33 ) -> Self {
34 Self {
35 tenant_org_id: tenant_org_id.into(),
36 tenant_workspace_id: tenant_workspace_id.into(),
37 tenant_deployment_id,
38 source_binding_ids: Vec::new(),
39 }
40 }
41
42 pub fn with_source_bindings(mut self, source_binding_ids: Vec<String>) -> Self {
43 self.source_binding_ids = normalized_source_binding_ids(source_binding_ids);
44 self
45 }
46
47 fn source_binding_key(&self) -> String {
48 source_binding_key(&self.source_binding_ids)
49 }
50
51 fn fingerprint(&self) -> String {
52 format!(
53 "org={}|workspace={}|deployment={}|source_bindings={}",
54 self.tenant_org_id,
55 self.tenant_workspace_id,
56 self.tenant_deployment_id.as_deref().unwrap_or(""),
57 self.source_binding_key()
58 )
59 }
60}
61
/// SQLite-backed cache of model responses keyed by a prompt hash.
///
/// Entries expire `ttl_minutes` after creation and the table is trimmed to
/// `max_entries` rows (least recently accessed first) on every write.
pub struct ResponseCache {
    /// The single SQLite connection, shared behind an async mutex.
    conn: Arc<Mutex<Connection>>,
    // Not read anywhere in this file; the lint is silenced so the path of the
    // opened database stays available on the struct.
    #[allow(dead_code)]
    db_path: PathBuf,
    /// Entry lifetime in minutes, measured from `created_at`.
    ttl_minutes: i64,
    /// Maximum number of rows kept after eviction.
    max_entries: usize,
}
70
71impl ResponseCache {
72 pub async fn new(db_dir: &Path, ttl_minutes: u32, max_entries: usize) -> MemoryResult<Self> {
74 tokio::fs::create_dir_all(db_dir)
75 .await
76 .map_err(MemoryError::Io)?;
77
78 let db_path = db_dir.join("response_cache.db");
79
80 let conn = Connection::open(&db_path)?;
81 conn.execute_batch(
82 "PRAGMA journal_mode = WAL;
83 PRAGMA synchronous = NORMAL;
84 PRAGMA temp_store = MEMORY;",
85 )?;
86
87 conn.execute_batch(
88 "CREATE TABLE IF NOT EXISTS response_cache (
89 prompt_hash TEXT PRIMARY KEY,
90 model TEXT NOT NULL,
91 response TEXT NOT NULL,
92 token_count INTEGER NOT NULL DEFAULT 0,
93 created_at TEXT NOT NULL,
94 accessed_at TEXT NOT NULL,
95 hit_count INTEGER NOT NULL DEFAULT 0,
96 tenant_org_id TEXT,
97 tenant_workspace_id TEXT,
98 tenant_deployment_id TEXT,
99 source_binding_key TEXT NOT NULL DEFAULT ''
100 );
101 CREATE INDEX IF NOT EXISTS idx_rc_accessed ON response_cache(accessed_at);
102 CREATE INDEX IF NOT EXISTS idx_rc_created ON response_cache(created_at);
103 CREATE INDEX IF NOT EXISTS idx_rc_tenant_scope
104 ON response_cache(tenant_org_id, tenant_workspace_id, tenant_deployment_id);
105 CREATE INDEX IF NOT EXISTS idx_rc_source_binding
106 ON response_cache(source_binding_key);",
107 )?;
108 migrate_response_cache_scope_columns(&conn)?;
109
110 Ok(Self {
111 conn: Arc::new(Mutex::new(conn)),
112 db_path,
113 ttl_minutes: i64::from(ttl_minutes),
114 max_entries,
115 })
116 }
117
118 pub fn cache_key(model: &str, system_prompt: Option<&str>, user_prompt: &str) -> String {
120 let mut hasher = Sha256::new();
121 hasher.update(model.as_bytes());
122 hasher.update(b"|");
123 if let Some(sys) = system_prompt {
124 hasher.update(sys.as_bytes());
125 }
126 hasher.update(b"|");
127 hasher.update(user_prompt.as_bytes());
128 format!("{:064x}", hasher.finalize())
129 }
130
131 pub fn cache_key_scoped(
133 model: &str,
134 system_prompt: Option<&str>,
135 user_prompt: &str,
136 scope: &ResponseCacheScope,
137 ) -> String {
138 let mut hasher = Sha256::new();
139 hasher.update(model.as_bytes());
140 hasher.update(b"|");
141 if let Some(sys) = system_prompt {
142 hasher.update(sys.as_bytes());
143 }
144 hasher.update(b"|");
145 hasher.update(user_prompt.as_bytes());
146 hasher.update(b"|");
147 hasher.update(scope.fingerprint().as_bytes());
148 format!("{:064x}", hasher.finalize())
149 }
150
151 pub async fn get(&self, key: &str) -> MemoryResult<Option<String>> {
153 let conn = self.conn.lock().await;
154 let cutoff = (Utc::now() - Duration::minutes(self.ttl_minutes)).to_rfc3339();
155
156 let result: Option<String> = conn
157 .query_row(
158 "SELECT response FROM response_cache
159 WHERE prompt_hash = ?1 AND created_at > ?2",
160 params![key, cutoff],
161 |row| row.get(0),
162 )
163 .ok();
164
165 if result.is_some() {
166 let now = Utc::now().to_rfc3339();
167 conn.execute(
168 "UPDATE response_cache
169 SET accessed_at = ?1, hit_count = hit_count + 1
170 WHERE prompt_hash = ?2",
171 params![now, key],
172 )?;
173 }
174
175 Ok(result)
176 }
177
178 pub async fn put(
180 &self,
181 key: &str,
182 model: &str,
183 response: &str,
184 token_count: u32,
185 ) -> MemoryResult<()> {
186 let conn = self.conn.lock().await;
187 let now = Utc::now().to_rfc3339();
188
189 conn.execute(
190 "INSERT OR REPLACE INTO response_cache
191 (prompt_hash, model, response, token_count, created_at, accessed_at, hit_count)
192 VALUES (?1, ?2, ?3, ?4, ?5, ?6, 0)",
193 params![key, model, response, token_count, now, now],
194 )?;
195
196 let cutoff = (Utc::now() - Duration::minutes(self.ttl_minutes)).to_rfc3339();
198 conn.execute(
199 "DELETE FROM response_cache WHERE created_at <= ?1",
200 params![cutoff],
201 )?;
202
203 #[allow(clippy::cast_possible_wrap)]
205 let max = self.max_entries as i64;
206 conn.execute(
207 "DELETE FROM response_cache WHERE prompt_hash IN (
208 SELECT prompt_hash FROM response_cache
209 ORDER BY accessed_at ASC
210 LIMIT MAX(0, (SELECT COUNT(*) FROM response_cache) - ?1)
211 )",
212 params![max],
213 )?;
214
215 Ok(())
216 }
217
218 pub async fn put_scoped(
219 &self,
220 key: &str,
221 model: &str,
222 response: &str,
223 token_count: u32,
224 scope: &ResponseCacheScope,
225 ) -> MemoryResult<()> {
226 let conn = self.conn.lock().await;
227 let now = Utc::now().to_rfc3339();
228 let source_binding_key = scope.source_binding_key();
229
230 conn.execute(
231 "INSERT OR REPLACE INTO response_cache
232 (prompt_hash, model, response, token_count, created_at, accessed_at, hit_count,
233 tenant_org_id, tenant_workspace_id, tenant_deployment_id, source_binding_key)
234 VALUES (?1, ?2, ?3, ?4, ?5, ?6, 0, ?7, ?8, ?9, ?10)",
235 params![
236 key,
237 model,
238 response,
239 token_count,
240 now,
241 now,
242 scope.tenant_org_id,
243 scope.tenant_workspace_id,
244 scope.tenant_deployment_id,
245 source_binding_key
246 ],
247 )?;
248
249 self.evict_locked(&conn)?;
250 Ok(())
251 }
252
253 pub async fn invalidate_source_binding(
254 &self,
255 tenant_org_id: &str,
256 tenant_workspace_id: &str,
257 tenant_deployment_id: Option<&str>,
258 source_binding_id: &str,
259 ) -> MemoryResult<usize> {
260 let conn = self.conn.lock().await;
261 let needle = format!("%|{}|%", normalize_source_binding_id(source_binding_id));
262 let affected = conn.execute(
263 "DELETE FROM response_cache
264 WHERE tenant_org_id = ?1
265 AND tenant_workspace_id = ?2
266 AND IFNULL(tenant_deployment_id, '') = IFNULL(?3, '')
267 AND source_binding_key LIKE ?4",
268 params![
269 tenant_org_id,
270 tenant_workspace_id,
271 tenant_deployment_id,
272 needle
273 ],
274 )?;
275 Ok(affected)
276 }
277
278 pub async fn invalidate_tenant(
279 &self,
280 tenant_org_id: &str,
281 tenant_workspace_id: &str,
282 tenant_deployment_id: Option<&str>,
283 ) -> MemoryResult<usize> {
284 let conn = self.conn.lock().await;
285 let affected = conn.execute(
286 "DELETE FROM response_cache
287 WHERE tenant_org_id = ?1
288 AND tenant_workspace_id = ?2
289 AND IFNULL(tenant_deployment_id, '') = IFNULL(?3, '')",
290 params![tenant_org_id, tenant_workspace_id, tenant_deployment_id],
291 )?;
292 Ok(affected)
293 }
294
295 pub async fn stats(&self) -> MemoryResult<(usize, u64, u64)> {
297 let conn = self.conn.lock().await;
298
299 let count: i64 =
300 conn.query_row("SELECT COUNT(*) FROM response_cache", [], |row| row.get(0))?;
301
302 let hits: i64 = conn.query_row(
303 "SELECT COALESCE(SUM(hit_count), 0) FROM response_cache",
304 [],
305 |row| row.get(0),
306 )?;
307
308 let tokens_saved: i64 = conn.query_row(
309 "SELECT COALESCE(SUM(token_count * hit_count), 0) FROM response_cache",
310 [],
311 |row| row.get(0),
312 )?;
313
314 #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
315 Ok((count as usize, hits as u64, tokens_saved as u64))
316 }
317
318 pub async fn clear(&self) -> MemoryResult<usize> {
320 let conn = self.conn.lock().await;
321 let affected = conn.execute("DELETE FROM response_cache", [])?;
322 Ok(affected)
323 }
324
325 fn evict_locked(&self, conn: &Connection) -> MemoryResult<()> {
326 let cutoff = (Utc::now() - Duration::minutes(self.ttl_minutes)).to_rfc3339();
327 conn.execute(
328 "DELETE FROM response_cache WHERE created_at <= ?1",
329 params![cutoff],
330 )?;
331
332 #[allow(clippy::cast_possible_wrap)]
333 let max = self.max_entries as i64;
334 conn.execute(
335 "DELETE FROM response_cache WHERE prompt_hash IN (
336 SELECT prompt_hash FROM response_cache
337 ORDER BY accessed_at ASC
338 LIMIT MAX(0, (SELECT COUNT(*) FROM response_cache) - ?1)
339 )",
340 params![max],
341 )?;
342
343 Ok(())
344 }
345}
346
347fn migrate_response_cache_scope_columns(conn: &Connection) -> MemoryResult<()> {
348 let columns = conn
349 .prepare("PRAGMA table_info(response_cache)")?
350 .query_map([], |row| row.get::<_, String>(1))?
351 .collect::<Result<std::collections::HashSet<_>, _>>()?;
352 for (name, ddl) in [
353 (
354 "tenant_org_id",
355 "ALTER TABLE response_cache ADD COLUMN tenant_org_id TEXT",
356 ),
357 (
358 "tenant_workspace_id",
359 "ALTER TABLE response_cache ADD COLUMN tenant_workspace_id TEXT",
360 ),
361 (
362 "tenant_deployment_id",
363 "ALTER TABLE response_cache ADD COLUMN tenant_deployment_id TEXT",
364 ),
365 (
366 "source_binding_key",
367 "ALTER TABLE response_cache ADD COLUMN source_binding_key TEXT NOT NULL DEFAULT ''",
368 ),
369 ] {
370 if !columns.contains(name) {
371 conn.execute(ddl, [])?;
372 }
373 }
374 Ok(())
375}
376
377fn normalized_source_binding_ids(source_binding_ids: Vec<String>) -> Vec<String> {
378 let mut ids = source_binding_ids
379 .into_iter()
380 .map(|id| normalize_source_binding_id(&id))
381 .filter(|id| !id.is_empty())
382 .collect::<Vec<_>>();
383 ids.sort();
384 ids.dedup();
385 ids
386}
387
/// Canonical form of a source binding id: every `|` removed (it is the
/// delimiter used in stored binding keys) and surrounding whitespace trimmed.
///
/// The delimiter is stripped before trimming so that whitespace sitting next
/// to a removed `|` is trimmed as well; this makes the function idempotent
/// (`"a |"` and `"a"` normalize to the same id).
fn normalize_source_binding_id(source_binding_id: &str) -> String {
    source_binding_id.replace('|', "").trim().to_owned()
}
391
/// Encodes binding ids as a single delimited string of the form `|a|b|`.
///
/// An empty slice yields an empty string. The leading and trailing `|`
/// let a whole id be located as the substring `|id|`.
fn source_binding_key(source_binding_ids: &[String]) -> String {
    match source_binding_ids {
        [] => String::new(),
        ids => {
            let mut key = String::from("|");
            for id in ids {
                key.push_str(id);
                key.push('|');
            }
            key
        }
    }
}
398
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Creates a cache with room for 1000 entries in a fresh temporary
    /// directory. The `TempDir` is returned so it outlives the cache.
    async fn temp_cache(ttl_minutes: u32) -> (TempDir, ResponseCache) {
        let dir = TempDir::new().unwrap();
        let cache = ResponseCache::new(dir.path(), ttl_minutes, 1000)
            .await
            .unwrap();
        (dir, cache)
    }

    /// Scope for `org`/`workspace` without deployment, bound to one source.
    fn scope_with_binding(org: &str, workspace: &str, binding: &str) -> ResponseCacheScope {
        ResponseCacheScope::tenant(org, workspace, None)
            .with_source_bindings(vec![binding.to_string()])
    }

    #[test]
    fn cache_key_is_deterministic() {
        let first = ResponseCache::cache_key("gpt-4", Some("sys"), "hello");
        let second = ResponseCache::cache_key("gpt-4", Some("sys"), "hello");
        assert_eq!(first, second);
        assert_eq!(first.len(), 64);
    }

    #[test]
    fn cache_key_varies_by_model() {
        let gpt = ResponseCache::cache_key("gpt-4", None, "hello");
        let claude = ResponseCache::cache_key("claude-3", None, "hello");
        assert_ne!(gpt, claude);
    }

    #[test]
    fn scoped_cache_key_varies_by_tenant_and_source_binding() {
        let finance = scope_with_binding("org-a", "workspace-a", "finance-drive");
        let hr = scope_with_binding("org-a", "workspace-a", "hr-drive");
        let finance_key = ResponseCache::cache_key_scoped("gpt-4", Some("sys"), "hello", &finance);
        let hr_key = ResponseCache::cache_key_scoped("gpt-4", Some("sys"), "hello", &hr);
        assert_ne!(finance_key, hr_key);
    }

    #[tokio::test]
    async fn put_and_get_roundtrip() {
        let (_dir, cache) = temp_cache(60).await;
        let key = ResponseCache::cache_key("gpt-4", None, "What is Rust?");
        let stored = "Rust is a systems programming language.";
        cache.put(&key, "gpt-4", stored, 25).await.unwrap();

        let fetched = cache.get(&key).await.unwrap();
        assert_eq!(
            fetched.as_deref(),
            Some("Rust is a systems programming language.")
        );
    }

    #[tokio::test]
    async fn miss_returns_none() {
        let (_dir, cache) = temp_cache(60).await;
        let fetched = cache.get("nonexistent").await.unwrap();
        assert!(fetched.is_none());
    }

    #[tokio::test]
    async fn expired_entry_returns_none() {
        // A TTL of zero minutes makes every entry expired as soon as it is written.
        let (_dir, cache) = temp_cache(0).await;
        let key = ResponseCache::cache_key("gpt-4", None, "test");
        cache.put(&key, "gpt-4", "response", 10).await.unwrap();

        let fetched = cache.get(&key).await.unwrap();
        assert!(fetched.is_none());
    }

    #[tokio::test]
    async fn stats_tracks_hits_and_tokens() {
        let (_dir, cache) = temp_cache(60).await;
        let key = ResponseCache::cache_key("gpt-4", None, "explain rust");
        cache.put(&key, "gpt-4", "Rust is...", 100).await.unwrap();

        for _ in 0..5 {
            let _ = cache.get(&key).await.unwrap();
        }

        let (_, hits, tokens) = cache.stats().await.unwrap();
        assert_eq!(hits, 5);
        assert_eq!(tokens, 500);
    }

    #[tokio::test]
    async fn lru_eviction_respects_max_entries() {
        let dir = TempDir::new().unwrap();
        let cache = ResponseCache::new(dir.path(), 60, 3).await.unwrap();

        for i in 0..5 {
            let prompt = format!("prompt {i}");
            let response = format!("response {i}");
            let key = ResponseCache::cache_key("gpt-4", None, &prompt);
            cache.put(&key, "gpt-4", &response, 10).await.unwrap();
        }

        let (count, _, _) = cache.stats().await.unwrap();
        assert!(count <= 3, "cache must not exceed max_entries");
    }

    #[tokio::test]
    async fn invalidate_source_binding_removes_only_matching_tenant_entries() {
        let (_dir, cache) = temp_cache(60).await;
        let scopes = [
            scope_with_binding("org-a", "workspace-a", "finance-drive"),
            scope_with_binding("org-a", "workspace-a", "hr-drive"),
            scope_with_binding("org-b", "workspace-b", "finance-drive"),
        ];

        for (idx, scope) in scopes.iter().enumerate() {
            let prompt = format!("prompt {idx}");
            let response = format!("response {idx}");
            let key = ResponseCache::cache_key_scoped("gpt-4", None, &prompt, scope);
            cache
                .put_scoped(&key, "gpt-4", &response, 10, scope)
                .await
                .unwrap();
        }

        let removed = cache
            .invalidate_source_binding("org-a", "workspace-a", None, "finance-drive")
            .await
            .unwrap();
        assert_eq!(removed, 1);

        let (count, _, _) = cache.stats().await.unwrap();
        assert_eq!(count, 2);
    }
}