1use chrono::{Duration, Utc};
12use rusqlite::{params, Connection};
13use sha2::{Digest, Sha256};
14use std::path::{Path, PathBuf};
15use std::sync::Arc;
16use tokio::sync::Mutex;
17
18use crate::types::{MemoryError, MemoryResult};
19
20#[derive(Debug, Clone, PartialEq, Eq)]
21pub struct ResponseCacheScope {
22 pub tenant_org_id: String,
23 pub tenant_workspace_id: String,
24 pub tenant_deployment_id: Option<String>,
25 pub source_binding_ids: Vec<String>,
26}
27
28impl ResponseCacheScope {
29 pub fn tenant(
30 tenant_org_id: impl Into<String>,
31 tenant_workspace_id: impl Into<String>,
32 tenant_deployment_id: Option<String>,
33 ) -> Self {
34 Self {
35 tenant_org_id: tenant_org_id.into(),
36 tenant_workspace_id: tenant_workspace_id.into(),
37 tenant_deployment_id,
38 source_binding_ids: Vec::new(),
39 }
40 }
41
42 pub fn with_source_bindings(mut self, source_binding_ids: Vec<String>) -> Self {
43 self.source_binding_ids = normalized_source_binding_ids(source_binding_ids);
44 self
45 }
46
47 fn source_binding_key(&self) -> String {
48 source_binding_key(&self.source_binding_ids)
49 }
50
51 fn fingerprint(&self) -> String {
52 format!(
53 "org={}|workspace={}|deployment={}|source_bindings={}",
54 self.tenant_org_id,
55 self.tenant_workspace_id,
56 self.tenant_deployment_id.as_deref().unwrap_or(""),
57 self.source_binding_key()
58 )
59 }
60}
61
62pub struct ResponseCache {
64 conn: Arc<Mutex<Connection>>,
65 #[allow(dead_code)]
66 db_path: PathBuf,
67 ttl_minutes: i64,
68 max_entries: usize,
69 crypto: crate::crypto::MemoryCryptoProvider,
70}
71
72impl ResponseCache {
73 pub async fn new(db_dir: &Path, ttl_minutes: u32, max_entries: usize) -> MemoryResult<Self> {
75 tokio::fs::create_dir_all(db_dir)
76 .await
77 .map_err(MemoryError::Io)?;
78
79 let db_path = db_dir.join("response_cache.db");
80
81 let conn = Connection::open(&db_path)?;
82 conn.execute_batch(
83 "PRAGMA journal_mode = WAL;
84 PRAGMA synchronous = NORMAL;
85 PRAGMA temp_store = MEMORY;",
86 )?;
87
88 conn.execute_batch(
89 "CREATE TABLE IF NOT EXISTS response_cache (
90 prompt_hash TEXT PRIMARY KEY,
91 model TEXT NOT NULL,
92 response TEXT NOT NULL,
93 token_count INTEGER NOT NULL DEFAULT 0,
94 created_at TEXT NOT NULL,
95 accessed_at TEXT NOT NULL,
96 hit_count INTEGER NOT NULL DEFAULT 0,
97 tenant_org_id TEXT,
98 tenant_workspace_id TEXT,
99 tenant_deployment_id TEXT,
100 source_binding_key TEXT NOT NULL DEFAULT ''
101 );
102 CREATE INDEX IF NOT EXISTS idx_rc_accessed ON response_cache(accessed_at);
103 CREATE INDEX IF NOT EXISTS idx_rc_created ON response_cache(created_at);
104 CREATE INDEX IF NOT EXISTS idx_rc_tenant_scope
105 ON response_cache(tenant_org_id, tenant_workspace_id, tenant_deployment_id);
106 CREATE INDEX IF NOT EXISTS idx_rc_source_binding
107 ON response_cache(source_binding_key);",
108 )?;
109 migrate_response_cache_scope_columns(&conn)?;
110
111 Ok(Self {
112 conn: Arc::new(Mutex::new(conn)),
113 db_path,
114 ttl_minutes: i64::from(ttl_minutes),
115 max_entries,
116 crypto: crate::crypto::MemoryCryptoProvider::from_env(),
117 })
118 }
119
120 pub fn with_crypto_provider(mut self, crypto: crate::crypto::MemoryCryptoProvider) -> Self {
123 self.crypto = crypto;
124 self
125 }
126
127 pub fn cache_key(model: &str, system_prompt: Option<&str>, user_prompt: &str) -> String {
129 let mut hasher = Sha256::new();
130 hasher.update(model.as_bytes());
131 hasher.update(b"|");
132 if let Some(sys) = system_prompt {
133 hasher.update(sys.as_bytes());
134 }
135 hasher.update(b"|");
136 hasher.update(user_prompt.as_bytes());
137 format!("{:064x}", hasher.finalize())
138 }
139
140 pub fn cache_key_scoped(
142 model: &str,
143 system_prompt: Option<&str>,
144 user_prompt: &str,
145 scope: &ResponseCacheScope,
146 ) -> String {
147 let mut hasher = Sha256::new();
148 hasher.update(model.as_bytes());
149 hasher.update(b"|");
150 if let Some(sys) = system_prompt {
151 hasher.update(sys.as_bytes());
152 }
153 hasher.update(b"|");
154 hasher.update(user_prompt.as_bytes());
155 hasher.update(b"|");
156 hasher.update(scope.fingerprint().as_bytes());
157 format!("{:064x}", hasher.finalize())
158 }
159
160 pub async fn get(&self, key: &str) -> MemoryResult<Option<String>> {
162 let conn = self.conn.lock().await;
163 let cutoff = (Utc::now() - Duration::minutes(self.ttl_minutes)).to_rfc3339();
164
165 let stored: Option<String> = conn
166 .query_row(
167 "SELECT response FROM response_cache
168 WHERE prompt_hash = ?1 AND created_at > ?2",
169 params![key, cutoff],
170 |row| row.get(0),
171 )
172 .ok();
173 let result = match stored {
174 Some(value) => Some(self.crypto.decrypt_field(&value)?),
175 None => None,
176 };
177
178 if result.is_some() {
179 let now = Utc::now().to_rfc3339();
180 conn.execute(
181 "UPDATE response_cache
182 SET accessed_at = ?1, hit_count = hit_count + 1
183 WHERE prompt_hash = ?2",
184 params![now, key],
185 )?;
186 }
187
188 Ok(result)
189 }
190
191 pub async fn put(
193 &self,
194 key: &str,
195 model: &str,
196 response: &str,
197 token_count: u32,
198 ) -> MemoryResult<()> {
199 let response_stored = self.crypto.encrypt_field(response)?;
200 let conn = self.conn.lock().await;
201 let now = Utc::now().to_rfc3339();
202
203 conn.execute(
204 "INSERT OR REPLACE INTO response_cache
205 (prompt_hash, model, response, token_count, created_at, accessed_at, hit_count)
206 VALUES (?1, ?2, ?3, ?4, ?5, ?6, 0)",
207 params![key, model, response_stored, token_count, now, now],
208 )?;
209
210 let cutoff = (Utc::now() - Duration::minutes(self.ttl_minutes)).to_rfc3339();
212 conn.execute(
213 "DELETE FROM response_cache WHERE created_at <= ?1",
214 params![cutoff],
215 )?;
216
217 #[allow(clippy::cast_possible_wrap)]
219 let max = self.max_entries as i64;
220 conn.execute(
221 "DELETE FROM response_cache WHERE prompt_hash IN (
222 SELECT prompt_hash FROM response_cache
223 ORDER BY accessed_at ASC
224 LIMIT MAX(0, (SELECT COUNT(*) FROM response_cache) - ?1)
225 )",
226 params![max],
227 )?;
228
229 Ok(())
230 }
231
232 pub async fn put_scoped(
233 &self,
234 key: &str,
235 model: &str,
236 response: &str,
237 token_count: u32,
238 scope: &ResponseCacheScope,
239 ) -> MemoryResult<()> {
240 let response_stored = self.crypto.encrypt_field(response)?;
241 let conn = self.conn.lock().await;
242 let now = Utc::now().to_rfc3339();
243 let source_binding_key = scope.source_binding_key();
244
245 conn.execute(
246 "INSERT OR REPLACE INTO response_cache
247 (prompt_hash, model, response, token_count, created_at, accessed_at, hit_count,
248 tenant_org_id, tenant_workspace_id, tenant_deployment_id, source_binding_key)
249 VALUES (?1, ?2, ?3, ?4, ?5, ?6, 0, ?7, ?8, ?9, ?10)",
250 params![
251 key,
252 model,
253 response_stored,
254 token_count,
255 now,
256 now,
257 scope.tenant_org_id,
258 scope.tenant_workspace_id,
259 scope.tenant_deployment_id,
260 source_binding_key
261 ],
262 )?;
263
264 self.evict_locked(&conn)?;
265 Ok(())
266 }
267
268 pub async fn invalidate_source_binding(
269 &self,
270 tenant_org_id: &str,
271 tenant_workspace_id: &str,
272 tenant_deployment_id: Option<&str>,
273 source_binding_id: &str,
274 ) -> MemoryResult<usize> {
275 let conn = self.conn.lock().await;
276 let needle = format!("%|{}|%", normalize_source_binding_id(source_binding_id));
277 let affected = conn.execute(
278 "DELETE FROM response_cache
279 WHERE tenant_org_id = ?1
280 AND tenant_workspace_id = ?2
281 AND IFNULL(tenant_deployment_id, '') = IFNULL(?3, '')
282 AND source_binding_key LIKE ?4",
283 params![
284 tenant_org_id,
285 tenant_workspace_id,
286 tenant_deployment_id,
287 needle
288 ],
289 )?;
290 Ok(affected)
291 }
292
293 pub async fn invalidate_tenant(
294 &self,
295 tenant_org_id: &str,
296 tenant_workspace_id: &str,
297 tenant_deployment_id: Option<&str>,
298 ) -> MemoryResult<usize> {
299 let conn = self.conn.lock().await;
300 let affected = conn.execute(
301 "DELETE FROM response_cache
302 WHERE tenant_org_id = ?1
303 AND tenant_workspace_id = ?2
304 AND IFNULL(tenant_deployment_id, '') = IFNULL(?3, '')",
305 params![tenant_org_id, tenant_workspace_id, tenant_deployment_id],
306 )?;
307 Ok(affected)
308 }
309
310 pub async fn stats(&self) -> MemoryResult<(usize, u64, u64)> {
312 let conn = self.conn.lock().await;
313
314 let count: i64 =
315 conn.query_row("SELECT COUNT(*) FROM response_cache", [], |row| row.get(0))?;
316
317 let hits: i64 = conn.query_row(
318 "SELECT COALESCE(SUM(hit_count), 0) FROM response_cache",
319 [],
320 |row| row.get(0),
321 )?;
322
323 let tokens_saved: i64 = conn.query_row(
324 "SELECT COALESCE(SUM(token_count * hit_count), 0) FROM response_cache",
325 [],
326 |row| row.get(0),
327 )?;
328
329 #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
330 Ok((count as usize, hits as u64, tokens_saved as u64))
331 }
332
333 pub async fn clear(&self) -> MemoryResult<usize> {
335 let conn = self.conn.lock().await;
336 let affected = conn.execute("DELETE FROM response_cache", [])?;
337 Ok(affected)
338 }
339
340 fn evict_locked(&self, conn: &Connection) -> MemoryResult<()> {
341 let cutoff = (Utc::now() - Duration::minutes(self.ttl_minutes)).to_rfc3339();
342 conn.execute(
343 "DELETE FROM response_cache WHERE created_at <= ?1",
344 params![cutoff],
345 )?;
346
347 #[allow(clippy::cast_possible_wrap)]
348 let max = self.max_entries as i64;
349 conn.execute(
350 "DELETE FROM response_cache WHERE prompt_hash IN (
351 SELECT prompt_hash FROM response_cache
352 ORDER BY accessed_at ASC
353 LIMIT MAX(0, (SELECT COUNT(*) FROM response_cache) - ?1)
354 )",
355 params![max],
356 )?;
357
358 Ok(())
359 }
360}
361
362fn migrate_response_cache_scope_columns(conn: &Connection) -> MemoryResult<()> {
363 let columns = conn
364 .prepare("PRAGMA table_info(response_cache)")?
365 .query_map([], |row| row.get::<_, String>(1))?
366 .collect::<Result<std::collections::HashSet<_>, _>>()?;
367 for (name, ddl) in [
368 (
369 "tenant_org_id",
370 "ALTER TABLE response_cache ADD COLUMN tenant_org_id TEXT",
371 ),
372 (
373 "tenant_workspace_id",
374 "ALTER TABLE response_cache ADD COLUMN tenant_workspace_id TEXT",
375 ),
376 (
377 "tenant_deployment_id",
378 "ALTER TABLE response_cache ADD COLUMN tenant_deployment_id TEXT",
379 ),
380 (
381 "source_binding_key",
382 "ALTER TABLE response_cache ADD COLUMN source_binding_key TEXT NOT NULL DEFAULT ''",
383 ),
384 ] {
385 if !columns.contains(name) {
386 conn.execute(ddl, [])?;
387 }
388 }
389 Ok(())
390}
391
392fn normalized_source_binding_ids(source_binding_ids: Vec<String>) -> Vec<String> {
393 let mut ids = source_binding_ids
394 .into_iter()
395 .map(|id| normalize_source_binding_id(&id))
396 .filter(|id| !id.is_empty())
397 .collect::<Vec<_>>();
398 ids.sort();
399 ids.dedup();
400 ids
401}
402
/// Canonicalizes a source-binding id for use inside a `|`-delimited binding key.
///
/// The `|` delimiter is stripped first and surrounding whitespace trimmed afterwards.
/// Trimming first (the previous order) left whitespace exposed by the removed delimiter in
/// place, e.g. `"| a"` became `" a"`. That also made the function non-idempotent:
/// normalizing again gave `"a"`.
fn normalize_source_binding_id(source_binding_id: &str) -> String {
    source_binding_id.replace('|', "").trim().to_owned()
}
406
/// Builds the stored binding key: each id wrapped in `|` delimiters (`|a|b|`), or an empty
/// string when there are no ids.
fn source_binding_key(source_binding_ids: &[String]) -> String {
    match source_binding_ids {
        [] => String::new(),
        ids => {
            let mut key = String::from("|");
            for id in ids {
                key.push_str(id);
                key.push('|');
            }
            key
        }
    }
}
413
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Opens a cache with room for 1000 entries in a fresh temporary directory.
    /// The `TempDir` is returned so the directory outlives the cache.
    async fn temp_cache(ttl_minutes: u32) -> (TempDir, ResponseCache) {
        let tmp = TempDir::new().unwrap();
        let cache = ResponseCache::new(tmp.path(), ttl_minutes, 1000)
            .await
            .unwrap();
        (tmp, cache)
    }

    /// Puts one scoped entry per scope, using the scope's position to vary the prompt.
    async fn put_scoped_entries(cache: &ResponseCache, scopes: &[&ResponseCacheScope]) {
        for (idx, scope) in scopes.iter().enumerate() {
            let key =
                ResponseCache::cache_key_scoped("gpt-4", None, &format!("prompt {idx}"), scope);
            cache
                .put_scoped(&key, "gpt-4", &format!("response {idx}"), 10, scope)
                .await
                .unwrap();
        }
    }

    #[test]
    fn cache_key_is_deterministic() {
        let k1 = ResponseCache::cache_key("gpt-4", Some("sys"), "hello");
        let k2 = ResponseCache::cache_key("gpt-4", Some("sys"), "hello");
        assert_eq!(k1, k2);
        assert_eq!(k1.len(), 64);
    }

    #[test]
    fn cache_key_varies_by_model() {
        let k1 = ResponseCache::cache_key("gpt-4", None, "hello");
        let k2 = ResponseCache::cache_key("claude-3", None, "hello");
        assert_ne!(k1, k2);
    }

    #[test]
    fn scoped_cache_key_varies_by_tenant_and_source_binding() {
        let scope_a = ResponseCacheScope::tenant("org-a", "workspace-a", None)
            .with_source_bindings(vec!["finance-drive".to_string()]);
        let scope_b = ResponseCacheScope::tenant("org-a", "workspace-a", None)
            .with_source_bindings(vec!["hr-drive".to_string()]);
        let key_a = ResponseCache::cache_key_scoped("gpt-4", Some("sys"), "hello", &scope_a);
        let key_b = ResponseCache::cache_key_scoped("gpt-4", Some("sys"), "hello", &scope_b);
        assert_ne!(key_a, key_b);
    }

    #[test]
    fn scoped_cache_key_ignores_binding_order_whitespace_and_duplicates() {
        let canonical = ResponseCacheScope::tenant("org-a", "workspace-a", None)
            .with_source_bindings(vec!["a".to_string(), "b".to_string()]);
        let messy = ResponseCacheScope::tenant("org-a", "workspace-a", None).with_source_bindings(
            vec!["b".to_string(), " a ".to_string(), "b".to_string()],
        );
        assert_eq!(
            ResponseCache::cache_key_scoped("gpt-4", None, "hello", &canonical),
            ResponseCache::cache_key_scoped("gpt-4", None, "hello", &messy)
        );
    }

    #[tokio::test]
    async fn put_and_get_roundtrip() {
        let (_tmp, cache) = temp_cache(60).await;
        let key = ResponseCache::cache_key("gpt-4", None, "What is Rust?");
        cache
            .put(&key, "gpt-4", "Rust is a systems programming language.", 25)
            .await
            .unwrap();
        let result = cache.get(&key).await.unwrap();
        assert_eq!(
            result.as_deref(),
            Some("Rust is a systems programming language.")
        );
    }

    #[tokio::test]
    async fn miss_returns_none() {
        let (_tmp, cache) = temp_cache(60).await;
        let result = cache.get("nonexistent").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn expired_entry_returns_none() {
        let (_tmp, cache) = temp_cache(0).await;
        let key = ResponseCache::cache_key("gpt-4", None, "test");
        cache.put(&key, "gpt-4", "response", 10).await.unwrap();
        let result = cache.get(&key).await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn stats_tracks_hits_and_tokens() {
        let (_tmp, cache) = temp_cache(60).await;
        let key = ResponseCache::cache_key("gpt-4", None, "explain rust");
        cache.put(&key, "gpt-4", "Rust is...", 100).await.unwrap();
        for _ in 0..5 {
            let _ = cache.get(&key).await.unwrap();
        }
        let (_, hits, tokens) = cache.stats().await.unwrap();
        assert_eq!(hits, 5);
        assert_eq!(tokens, 500);
    }

    #[tokio::test]
    async fn lru_eviction_respects_max_entries() {
        let tmp = TempDir::new().unwrap();
        let cache = ResponseCache::new(tmp.path(), 60, 3).await.unwrap();
        for i in 0..5 {
            let key = ResponseCache::cache_key("gpt-4", None, &format!("prompt {i}"));
            cache
                .put(&key, "gpt-4", &format!("response {i}"), 10)
                .await
                .unwrap();
        }
        let (count, _, _) = cache.stats().await.unwrap();
        assert!(count <= 3, "cache must not exceed max_entries");
    }

    #[tokio::test]
    async fn invalidate_source_binding_removes_only_matching_tenant_entries() {
        let (_tmp, cache) = temp_cache(60).await;
        let finance_a = ResponseCacheScope::tenant("org-a", "workspace-a", None)
            .with_source_bindings(vec!["finance-drive".to_string()]);
        let hr_a = ResponseCacheScope::tenant("org-a", "workspace-a", None)
            .with_source_bindings(vec!["hr-drive".to_string()]);
        let finance_b = ResponseCacheScope::tenant("org-b", "workspace-b", None)
            .with_source_bindings(vec!["finance-drive".to_string()]);

        put_scoped_entries(&cache, &[&finance_a, &hr_a, &finance_b]).await;

        let removed = cache
            .invalidate_source_binding("org-a", "workspace-a", None, "finance-drive")
            .await
            .unwrap();
        assert_eq!(removed, 1);
        let (count, _, _) = cache.stats().await.unwrap();
        assert_eq!(count, 2);
    }

    #[tokio::test]
    async fn invalidate_tenant_removes_only_that_tenant() {
        let (_tmp, cache) = temp_cache(60).await;
        let tenant_a = ResponseCacheScope::tenant("org-a", "workspace-a", None);
        let tenant_b = ResponseCacheScope::tenant("org-b", "workspace-b", None);

        put_scoped_entries(&cache, &[&tenant_a, &tenant_a, &tenant_b]).await;

        let removed = cache
            .invalidate_tenant("org-a", "workspace-a", None)
            .await
            .unwrap();
        assert_eq!(removed, 2);
        let (count, _, _) = cache.stats().await.unwrap();
        assert_eq!(count, 1);
    }

    #[tokio::test]
    async fn clear_removes_every_entry() {
        let (_tmp, cache) = temp_cache(60).await;
        for i in 0..3 {
            let key = ResponseCache::cache_key("gpt-4", None, &format!("prompt {i}"));
            cache
                .put(&key, "gpt-4", &format!("response {i}"), 10)
                .await
                .unwrap();
        }
        assert_eq!(cache.clear().await.unwrap(), 3);
        assert_eq!(cache.stats().await.unwrap(), (0, 0, 0));
    }
}