1use sqlx::SqlitePool;
5
6use crate::error::MemoryError;
7
8pub struct ResponseCache {
9 pool: SqlitePool,
10 ttl_secs: u64,
11}
12
13impl ResponseCache {
14 #[must_use]
15 pub fn new(pool: SqlitePool, ttl_secs: u64) -> Self {
16 Self { pool, ttl_secs }
17 }
18
19 pub async fn get(&self, key: &str) -> Result<Option<String>, MemoryError> {
25 let now = unix_now();
26 let row: Option<(String,)> = sqlx::query_as(
27 "SELECT response FROM response_cache WHERE cache_key = ? AND expires_at > ?",
28 )
29 .bind(key)
30 .bind(now)
31 .fetch_optional(&self.pool)
32 .await?;
33 Ok(row.map(|(r,)| r))
34 }
35
36 pub async fn put(&self, key: &str, response: &str, model: &str) -> Result<(), MemoryError> {
42 let now = unix_now();
43 let expires_at = now.saturating_add(self.ttl_secs.min(31_536_000).cast_signed());
45 sqlx::query(
46 "INSERT OR REPLACE INTO response_cache (cache_key, response, model, created_at, expires_at) \
47 VALUES (?, ?, ?, ?, ?)",
48 )
49 .bind(key)
50 .bind(response)
51 .bind(model)
52 .bind(now)
53 .bind(expires_at)
54 .execute(&self.pool)
55 .await?;
56 Ok(())
57 }
58
59 pub async fn get_semantic(
71 &self,
72 embedding: &[f32],
73 embedding_model: &str,
74 similarity_threshold: f32,
75 max_candidates: u32,
76 ) -> Result<Option<(String, f32)>, MemoryError> {
77 let now = unix_now();
78 let rows: Vec<(String, Vec<u8>)> = sqlx::query_as(
79 "SELECT response, embedding FROM response_cache \
80 WHERE embedding_model = ? AND embedding IS NOT NULL AND expires_at > ? \
81 ORDER BY embedding_ts DESC LIMIT ?",
82 )
83 .bind(embedding_model)
84 .bind(now)
85 .bind(max_candidates)
86 .fetch_all(&self.pool)
87 .await?;
88
89 let mut best_score = -1.0_f32;
90 let mut best_response: Option<String> = None;
91
92 for (response, blob) in &rows {
93 match bytemuck::try_cast_slice::<u8, f32>(blob) {
95 Ok(stored) => {
96 let score = crate::math::cosine_similarity(embedding, stored);
97 tracing::debug!(
98 score,
99 threshold = similarity_threshold,
100 "semantic cache candidate evaluated",
101 );
102 if score > best_score {
103 best_score = score;
104 best_response = Some(response.clone());
105 }
106 }
107 Err(e) => {
108 tracing::warn!("semantic cache: failed to deserialize embedding blob: {e}");
109 }
110 }
111 }
112
113 tracing::debug!(
114 examined = rows.len(),
115 best_score,
116 threshold = similarity_threshold,
117 hit = best_score >= similarity_threshold,
118 "semantic cache scan complete",
119 );
120
121 if best_score >= similarity_threshold {
122 Ok(best_response.map(|r| (r, best_score)))
123 } else {
124 Ok(None)
125 }
126 }
127
128 pub async fn put_with_embedding(
136 &self,
137 key: &str,
138 response: &str,
139 model: &str,
140 embedding: &[f32],
141 embedding_model: &str,
142 ) -> Result<(), MemoryError> {
143 let now = unix_now();
144 let expires_at = now.saturating_add(self.ttl_secs.min(31_536_000).cast_signed());
145 let blob: &[u8] = bytemuck::cast_slice(embedding);
147 sqlx::query(
148 "INSERT OR REPLACE INTO response_cache \
149 (cache_key, response, model, created_at, expires_at, embedding, embedding_model, embedding_ts) \
150 VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
151 )
152 .bind(key)
153 .bind(response)
154 .bind(model)
155 .bind(now)
156 .bind(expires_at)
157 .bind(blob)
158 .bind(embedding_model)
159 .bind(now)
160 .execute(&self.pool)
161 .await?;
162 Ok(())
163 }
164
165 pub async fn invalidate_embeddings_for_model(
174 &self,
175 old_model: &str,
176 ) -> Result<u64, MemoryError> {
177 let result = sqlx::query(
178 "UPDATE response_cache \
179 SET embedding = NULL, embedding_model = NULL, embedding_ts = NULL \
180 WHERE embedding_model = ?",
181 )
182 .bind(old_model)
183 .execute(&self.pool)
184 .await?;
185 Ok(result.rows_affected())
186 }
187
188 pub async fn cleanup(&self, current_embedding_model: &str) -> Result<u64, MemoryError> {
200 let now = unix_now();
201 let deleted = sqlx::query("DELETE FROM response_cache WHERE expires_at <= ?")
202 .bind(now)
203 .execute(&self.pool)
204 .await?
205 .rows_affected();
206
207 let updated = sqlx::query(
208 "UPDATE response_cache \
209 SET embedding = NULL, embedding_model = NULL, embedding_ts = NULL \
210 WHERE embedding IS NOT NULL AND embedding_model != ?",
211 )
212 .bind(current_embedding_model)
213 .execute(&self.pool)
214 .await?
215 .rows_affected();
216
217 Ok(deleted + updated)
218 }
219
220 pub async fn cleanup_expired(&self) -> Result<u64, MemoryError> {
226 let now = unix_now();
227 let result = sqlx::query("DELETE FROM response_cache WHERE expires_at <= ?")
228 .bind(now)
229 .execute(&self.pool)
230 .await?;
231 Ok(result.rows_affected())
232 }
233
234 #[must_use]
241 pub fn compute_key(last_user_message: &str, model: &str) -> String {
242 let mut hasher = blake3::Hasher::new();
243 let content = last_user_message.as_bytes();
244 hasher.update(&(content.len() as u64).to_le_bytes());
245 hasher.update(content);
246 let model_bytes = model.as_bytes();
247 hasher.update(&(model_bytes.len() as u64).to_le_bytes());
248 hasher.update(model_bytes);
249 hasher.finalize().to_hex().to_string()
250 }
251}
252
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// A clock set before 1970 yields 0 rather than an error.
fn unix_now() -> i64 {
    use std::time::SystemTime;

    let since_epoch = SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .unwrap_or_default();
    since_epoch.as_secs().cast_signed()
}
260
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sqlite::SqliteStore;

    /// Builds a cache with the given TTL over a fresh in-memory store.
    async fn test_cache_with_ttl(ttl_secs: u64) -> ResponseCache {
        let store = SqliteStore::new(":memory:").await.unwrap();
        ResponseCache::new(store.pool().clone(), ttl_secs)
    }

    /// Builds a cache whose one-hour TTL cannot lapse during a test.
    async fn test_cache() -> ResponseCache {
        test_cache_with_ttl(3600).await
    }

    /// Inserts an unexpired `model-a` row directly, bypassing `put_with_embedding`.
    /// This lets tests control the raw embedding bytes and the `embedding_ts` value
    /// that orders semantic candidates.
    async fn insert_raw_row(
        pool: &SqlitePool,
        key: &str,
        response: &str,
        blob: &[u8],
        embedding_ts: i64,
    ) {
        let now = unix_now();
        sqlx::query(
            "INSERT INTO response_cache \
             (cache_key, response, model, created_at, expires_at, embedding, embedding_model, embedding_ts) \
             VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
        )
        .bind(key)
        .bind(response)
        .bind("m1")
        .bind(now)
        .bind(now + 3600)
        .bind(blob)
        .bind("model-a")
        .bind(embedding_ts)
        .execute(pool)
        .await
        .unwrap();
    }

    /// Inserts a row whose embedding BLOB is arbitrary (typically malformed) bytes.
    async fn insert_corrupt_blob(pool: &SqlitePool, key: &str, blob: &[u8]) {
        insert_raw_row(pool, key, "corrupt-response", blob, unix_now()).await;
    }

    #[tokio::test]
    async fn cache_miss_returns_none() {
        let cache = test_cache().await;
        let result = cache.get("nonexistent").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn cache_put_and_get_roundtrip() {
        let cache = test_cache().await;
        cache.put("key1", "response text", "gpt-4").await.unwrap();
        let result = cache.get("key1").await.unwrap();
        assert_eq!(result.as_deref(), Some("response text"));
    }

    #[tokio::test]
    async fn cache_expired_entry_returns_none() {
        let cache = test_cache_with_ttl(0).await;
        cache.put("key1", "response", "model").await.unwrap();
        let result = cache.get("key1").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn cleanup_expired_removes_entries() {
        let cache = test_cache_with_ttl(0).await;
        cache.put("key1", "response", "model").await.unwrap();
        let deleted = cache.cleanup_expired().await.unwrap();
        assert_eq!(deleted, 1);
    }

    #[tokio::test]
    async fn cleanup_does_not_remove_valid_entries() {
        let cache = test_cache().await;
        cache.put("key1", "response", "model").await.unwrap();
        let deleted = cache.cleanup_expired().await.unwrap();
        assert_eq!(deleted, 0);
        let result = cache.get("key1").await.unwrap();
        assert!(result.is_some());
    }

    #[test]
    fn compute_key_deterministic() {
        let k1 = ResponseCache::compute_key("hello", "gpt-4");
        let k2 = ResponseCache::compute_key("hello", "gpt-4");
        assert_eq!(k1, k2);
    }

    #[test]
    fn compute_key_different_for_different_content() {
        assert_ne!(
            ResponseCache::compute_key("hello", "gpt-4"),
            ResponseCache::compute_key("world", "gpt-4")
        );
    }

    #[test]
    fn compute_key_different_for_different_model() {
        assert_ne!(
            ResponseCache::compute_key("hello", "gpt-4"),
            ResponseCache::compute_key("hello", "gpt-3.5")
        );
    }

    #[test]
    fn compute_key_empty_message() {
        let k = ResponseCache::compute_key("", "model");
        assert!(!k.is_empty());
    }

    #[test]
    fn compute_key_length_prefix_separates_field_boundaries() {
        // Same concatenated bytes, different split between message and model.
        assert_ne!(
            ResponseCache::compute_key("ab", "c"),
            ResponseCache::compute_key("a", "bc")
        );
    }

    #[tokio::test]
    async fn ttl_extreme_value_does_not_overflow() {
        let cache = test_cache_with_ttl(u64::MAX - 1).await;
        cache.put("key1", "response", "model").await.unwrap();
        let result = cache.get("key1").await.unwrap();
        assert_eq!(result.as_deref(), Some("response"));
    }

    #[tokio::test]
    async fn insert_or_replace_updates_existing_entry() {
        let cache = test_cache().await;
        cache.put("key1", "first response", "gpt-4").await.unwrap();
        cache.put("key1", "second response", "gpt-4").await.unwrap();
        let result = cache.get("key1").await.unwrap();
        assert_eq!(result.as_deref(), Some("second response"));
    }

    #[tokio::test]
    async fn test_semantic_get_empty_cache() {
        let cache = test_cache().await;
        let result = cache
            .get_semantic(&[1.0, 0.0], "model-a", 0.9, 10)
            .await
            .unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_semantic_get_identical_embedding() {
        let cache = test_cache().await;
        let embedding = [1.0_f32, 0.0, 0.0];
        cache
            .put_with_embedding("k1", "response-a", "m1", &embedding, "model-a")
            .await
            .unwrap();
        let result = cache
            .get_semantic(&embedding, "model-a", 0.9, 10)
            .await
            .unwrap();
        let (resp, score) = result.expect("identical embedding should hit");
        assert_eq!(resp, "response-a");
        assert!(
            (score - 1.0).abs() < 1e-5,
            "expected score ~1.0, got {score}"
        );
    }

    #[tokio::test]
    async fn test_semantic_get_orthogonal_vectors() {
        let cache = test_cache().await;
        cache
            .put_with_embedding("k1", "response-a", "m1", &[1.0, 0.0, 0.0], "model-a")
            .await
            .unwrap();
        let result = cache
            .get_semantic(&[0.0, 1.0, 0.0], "model-a", 0.5, 10)
            .await
            .unwrap();
        assert!(result.is_none(), "orthogonal vectors should not hit");
    }

    #[tokio::test]
    async fn test_semantic_get_similar_above_threshold() {
        let cache = test_cache().await;
        let stored = [1.0_f32, 0.1, 0.0];
        cache
            .put_with_embedding("k1", "response-a", "m1", &stored, "model-a")
            .await
            .unwrap();
        let query = [1.0_f32, 0.05, 0.0];
        let result = cache
            .get_semantic(&query, "model-a", 0.9, 10)
            .await
            .unwrap();
        assert!(
            result.is_some(),
            "similar vector should hit at threshold 0.9"
        );
    }

    #[tokio::test]
    async fn test_semantic_get_similar_below_threshold() {
        let cache = test_cache().await;
        cache
            .put_with_embedding("k1", "response-a", "m1", &[1.0, 0.0], "model-a")
            .await
            .unwrap();
        // cos([1, 0], [1, 1]) = 1/sqrt(2) ~ 0.707: clearly related, but below 0.95.
        let query = [1.0_f32, 1.0];
        let strict = cache
            .get_semantic(&query, "model-a", 0.95, 10)
            .await
            .unwrap();
        assert!(
            strict.is_none(),
            "similar vector should not hit at high threshold"
        );
        let lenient = cache
            .get_semantic(&query, "model-a", 0.5, 10)
            .await
            .unwrap();
        assert!(
            lenient.is_some(),
            "the same similar vector should hit at a lower threshold"
        );
    }

    #[tokio::test]
    async fn test_semantic_get_max_candidates_limit() {
        let cache = test_cache().await;
        let now = unix_now();
        let matching = [1.0_f32, 0.0];
        let orthogonal = [0.0_f32, 1.0];
        // The only matching row is the oldest. Three newer orthogonal rows sort ahead
        // of it under `ORDER BY embedding_ts DESC`.
        insert_raw_row(
            &cache.pool,
            "old-match",
            "old-match",
            bytemuck::cast_slice(&matching),
            now - 100,
        )
        .await;
        for i in 0..3_u8 {
            insert_raw_row(
                &cache.pool,
                &format!("new-{i}"),
                &format!("new-{i}"),
                bytemuck::cast_slice(&orthogonal),
                now,
            )
            .await;
        }

        let limited = cache
            .get_semantic(&matching, "model-a", 0.9, 3)
            .await
            .unwrap();
        assert!(
            limited.is_none(),
            "a match outside the candidate window must not be considered"
        );

        let widened = cache
            .get_semantic(&matching, "model-a", 0.9, 4)
            .await
            .unwrap();
        let (resp, _) = widened.expect("widened window should include the match");
        assert_eq!(resp, "old-match");
    }

    #[tokio::test]
    async fn test_semantic_get_ignores_expired() {
        let cache = test_cache_with_ttl(0).await;
        cache
            .put_with_embedding("k1", "response-a", "m1", &[1.0, 0.0], "model-a")
            .await
            .unwrap();
        let result = cache
            .get_semantic(&[1.0, 0.0], "model-a", 0.9, 10)
            .await
            .unwrap();
        assert!(result.is_none(), "expired entries should not be returned");
    }

    #[tokio::test]
    async fn test_semantic_get_filters_by_embedding_model() {
        let cache = test_cache().await;
        cache
            .put_with_embedding("k1", "response-a", "m1", &[1.0, 0.0], "model-a")
            .await
            .unwrap();
        let result = cache
            .get_semantic(&[1.0, 0.0], "model-b", 0.9, 10)
            .await
            .unwrap();
        assert!(result.is_none(), "wrong embedding model should not match");
    }

    #[tokio::test]
    async fn test_put_with_embedding_roundtrip() {
        let cache = test_cache().await;
        let embedding = [0.5_f32, 0.5, 0.707];
        cache
            .put_with_embedding(
                "key1",
                "semantic response",
                "gpt-4",
                &embedding,
                "embed-model",
            )
            .await
            .unwrap();
        let exact = cache.get("key1").await.unwrap();
        assert_eq!(exact.as_deref(), Some("semantic response"));
        let semantic = cache
            .get_semantic(&embedding, "embed-model", 0.99, 10)
            .await
            .unwrap();
        let (resp, score) = semantic.expect("identical embedding should hit");
        assert_eq!(resp, "semantic response");
        assert!((score - 1.0).abs() < 1e-5);
    }

    #[tokio::test]
    async fn test_invalidate_embeddings_for_model() {
        let cache = test_cache().await;
        cache
            .put_with_embedding("k1", "resp", "m1", &[1.0, 0.0], "model-a")
            .await
            .unwrap();
        let updated = cache
            .invalidate_embeddings_for_model("model-a")
            .await
            .unwrap();
        assert_eq!(updated, 1);
        let exact = cache.get("k1").await.unwrap();
        assert_eq!(exact.as_deref(), Some("resp"));
        let semantic = cache
            .get_semantic(&[1.0, 0.0], "model-a", 0.9, 10)
            .await
            .unwrap();
        assert!(semantic.is_none());
    }

    #[tokio::test]
    async fn test_cleanup_nulls_stale_embeddings() {
        let cache = test_cache().await;
        cache
            .put_with_embedding("k1", "resp", "m1", &[1.0, 0.0], "model-old")
            .await
            .unwrap();
        let affected = cache.cleanup("model-new").await.unwrap();
        assert_eq!(affected, 1, "should have updated the stale embedding row");
        let exact = cache.get("k1").await.unwrap();
        assert_eq!(exact.as_deref(), Some("resp"));
        let semantic = cache
            .get_semantic(&[1.0, 0.0], "model-old", 0.9, 10)
            .await
            .unwrap();
        assert!(semantic.is_none());
    }

    #[tokio::test]
    async fn test_cleanup_deletes_expired() {
        let cache = test_cache_with_ttl(0).await;
        cache.put("k1", "resp", "m1").await.unwrap();
        let affected = cache.cleanup("model-a").await.unwrap();
        assert_eq!(affected, 1);
        let result = cache.get("k1").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_cleanup_preserves_valid() {
        let cache = test_cache().await;
        cache
            .put_with_embedding("k1", "resp", "m1", &[1.0, 0.0], "model-current")
            .await
            .unwrap();
        let affected = cache.cleanup("model-current").await.unwrap();
        assert_eq!(affected, 0, "valid entries should not be affected");
        let semantic = cache
            .get_semantic(&[1.0, 0.0], "model-current", 0.9, 10)
            .await
            .unwrap();
        assert!(semantic.is_some());
    }

    #[tokio::test]
    async fn test_semantic_get_corrupted_blob_odd_length() {
        let cache = test_cache().await;
        insert_corrupt_blob(&cache.pool, "corrupt-key", &[0xAB, 0xCD, 0xEF, 0x01, 0x02]).await;

        let result = cache
            .get_semantic(&[1.0, 0.0, 0.0], "model-a", 0.9, 10)
            .await
            .unwrap();
        assert!(
            result.is_none(),
            "corrupt odd-length BLOB must yield Ok(None)"
        );
    }

    #[tokio::test]
    async fn test_semantic_get_corrupted_blob_skips_to_valid() {
        let cache = test_cache().await;
        insert_corrupt_blob(&cache.pool, "corrupt-key", &[0x01, 0x02, 0x03]).await;

        let valid_embedding = [1.0_f32, 0.0, 0.0];
        cache
            .put_with_embedding(
                "valid-key",
                "valid-response",
                "m1",
                &valid_embedding,
                "model-a",
            )
            .await
            .unwrap();

        let result = cache
            .get_semantic(&valid_embedding, "model-a", 0.9, 10)
            .await
            .unwrap();
        let (resp, score) = result.expect("valid row must be returned despite corrupt sibling");
        assert_eq!(resp, "valid-response");
        assert!(
            (score - 1.0).abs() < 1e-5,
            "identical vectors must yield score ~1.0, got {score}"
        );
    }

    #[tokio::test]
    async fn test_semantic_get_empty_blob() {
        let cache = test_cache().await;
        insert_corrupt_blob(&cache.pool, "empty-blob-key", &[]).await;

        let result = cache
            .get_semantic(&[1.0, 0.0], "model-a", 0.9, 10)
            .await
            .unwrap();
        assert!(
            result.is_none(),
            "empty BLOB must yield Ok(None) at threshold 0.9"
        );
    }

    #[tokio::test]
    async fn test_semantic_get_all_blobs_corrupted() {
        let cache = test_cache().await;
        let corrupt_blobs: &[&[u8]] = &[
            &[0x01],
            &[0x01, 0x02, 0x03],
            &[0x01, 0x02, 0x03, 0x04, 0x05],
            &[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07],
            &[0x01, 0x02, 0x03, 0x04, 0x05, 0x06],
        ];
        for (i, blob) in corrupt_blobs.iter().enumerate() {
            insert_corrupt_blob(&cache.pool, &format!("corrupt-{i}"), blob).await;
        }

        let result = cache
            .get_semantic(&[1.0, 0.0, 0.0], "model-a", 0.9, 10)
            .await
            .unwrap();
        assert!(result.is_none(), "all corrupt BLOBs must yield Ok(None)");
    }

    #[tokio::test]
    async fn test_semantic_get_dimension_mismatch_returns_none() {
        let cache = test_cache().await;
        cache
            .put_with_embedding("k1", "resp-3d", "m1", &[1.0, 0.0, 0.0], "model-a")
            .await
            .unwrap();
        let result = cache
            .get_semantic(&[1.0, 0.0], "model-a", 0.01, 10)
            .await
            .unwrap();
        assert!(
            result.is_none(),
            "dimension mismatch must not produce a hit"
        );
    }

    #[tokio::test]
    async fn test_semantic_get_dimension_mismatch_query_longer() {
        let cache = test_cache().await;
        cache
            .put_with_embedding("k1", "resp-2d", "m1", &[1.0, 0.0], "model-a")
            .await
            .unwrap();
        let result = cache
            .get_semantic(&[1.0, 0.0, 0.0], "model-a", 0.01, 10)
            .await
            .unwrap();
        assert!(
            result.is_none(),
            "query longer than stored embedding must not produce a hit"
        );
    }

    #[tokio::test]
    async fn test_semantic_get_mixed_dimensions_picks_correct_match() {
        let cache = test_cache().await;
        cache
            .put_with_embedding("k-2d", "resp-2d", "m1", &[1.0, 0.0], "model-a")
            .await
            .unwrap();
        cache
            .put_with_embedding("k-3d", "resp-3d", "m1", &[1.0, 0.0, 0.0], "model-a")
            .await
            .unwrap();
        let result = cache
            .get_semantic(&[1.0, 0.0, 0.0], "model-a", 0.9, 10)
            .await
            .unwrap();
        let (response, score) = result.expect("matching dim=3 entry should be returned");
        assert_eq!(response, "resp-3d", "wrong entry returned");
        assert!(
            (score - 1.0).abs() < 1e-5,
            "expected score ~1.0 for identical vectors, got {score}"
        );
    }
}