1use zeph_db::DbPool;
10#[allow(unused_imports)]
11use zeph_db::sql;
12
13use crate::error::MemoryError;
14
/// Response cache stored in the `response_cache` database table.
///
/// Entries can be fetched by exact key ([`ResponseCache::get`]) or by embedding similarity
/// ([`ResponseCache::get_semantic`]). Lookups only return rows whose `expires_at` is still in
/// the future.
pub struct ResponseCache {
    /// Connection pool every query of this cache is executed against.
    pool: DbPool,
    /// Lifetime of a written entry, in seconds. Writers cap the value at 31_536_000 seconds
    /// before adding it to the current time.
    ttl_secs: u64,
}
38
39impl ResponseCache {
40 #[must_use]
41 pub fn new(pool: DbPool, ttl_secs: u64) -> Self {
42 Self { pool, ttl_secs }
43 }
44
45 pub async fn get(&self, key: &str) -> Result<Option<String>, MemoryError> {
51 let now = unix_now();
52 let row: Option<(String,)> = zeph_db::query_as(sql!(
53 "SELECT response FROM response_cache WHERE cache_key = ? AND expires_at > ?"
54 ))
55 .bind(key)
56 .bind(now)
57 .fetch_optional(&self.pool)
58 .await?;
59 Ok(row.map(|(r,)| r))
60 }
61
62 pub async fn put(&self, key: &str, response: &str, model: &str) -> Result<(), MemoryError> {
68 let now = unix_now();
69 let expires_at = now.saturating_add(self.ttl_secs.min(31_536_000).cast_signed());
71 zeph_db::query(sql!(
72 "INSERT INTO response_cache (cache_key, response, model, created_at, expires_at) \
73 VALUES (?, ?, ?, ?, ?) \
74 ON CONFLICT(cache_key) DO UPDATE SET \
75 response = excluded.response, model = excluded.model, \
76 created_at = excluded.created_at, expires_at = excluded.expires_at"
77 ))
78 .bind(key)
79 .bind(response)
80 .bind(model)
81 .bind(now)
82 .bind(expires_at)
83 .execute(&self.pool)
84 .await?;
85 Ok(())
86 }
87
88 pub async fn get_semantic(
100 &self,
101 embedding: &[f32],
102 embedding_model: &str,
103 similarity_threshold: f32,
104 max_candidates: u32,
105 ) -> Result<Option<(String, f32)>, MemoryError> {
106 let now = unix_now();
107 let rows: Vec<(String, Vec<u8>)> = zeph_db::query_as(sql!(
108 "SELECT response, embedding FROM response_cache \
109 WHERE embedding_model = ? AND embedding IS NOT NULL AND expires_at > ? \
110 ORDER BY embedding_ts DESC LIMIT ?"
111 ))
112 .bind(embedding_model)
113 .bind(now)
114 .bind(max_candidates)
115 .fetch_all(&self.pool)
116 .await?;
117
118 let mut best_score = -1.0_f32;
119 let mut best_response: Option<String> = None;
120
121 for (response, blob) in &rows {
122 if blob.len() % 4 != 0 {
123 continue;
124 }
125 let stored: Vec<f32> = blob
126 .chunks_exact(4)
127 .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
128 .collect();
129 let score = zeph_common::math::cosine_similarity(embedding, &stored);
130 tracing::debug!(
131 score,
132 threshold = similarity_threshold,
133 "semantic cache candidate evaluated",
134 );
135 if score > best_score {
136 best_score = score;
137 best_response = Some(response.clone());
138 }
139 }
140
141 tracing::debug!(
142 examined = rows.len(),
143 best_score,
144 threshold = similarity_threshold,
145 hit = best_score >= similarity_threshold,
146 "semantic cache scan complete",
147 );
148
149 if best_score >= similarity_threshold {
150 Ok(best_response.map(|r| (r, best_score)))
151 } else {
152 Ok(None)
153 }
154 }
155
156 pub async fn put_with_embedding(
164 &self,
165 key: &str,
166 response: &str,
167 model: &str,
168 embedding: &[f32],
169 embedding_model: &str,
170 ) -> Result<(), MemoryError> {
171 let now = unix_now();
172 let expires_at = now.saturating_add(self.ttl_secs.min(31_536_000).cast_signed());
173 let blob: Vec<u8> = embedding.iter().flat_map(|f| f.to_le_bytes()).collect();
174 zeph_db::query(
175 sql!("INSERT INTO response_cache \
176 (cache_key, response, model, created_at, expires_at, embedding, embedding_model, embedding_ts) \
177 VALUES (?, ?, ?, ?, ?, ?, ?, ?) \
178 ON CONFLICT(cache_key) DO UPDATE SET \
179 response = excluded.response, model = excluded.model, \
180 created_at = excluded.created_at, expires_at = excluded.expires_at, \
181 embedding = excluded.embedding, embedding_model = excluded.embedding_model, \
182 embedding_ts = excluded.embedding_ts"),
183 )
184 .bind(key)
185 .bind(response)
186 .bind(model)
187 .bind(now)
188 .bind(expires_at)
189 .bind(blob)
190 .bind(embedding_model)
191 .bind(now)
192 .execute(&self.pool)
193 .await?;
194 Ok(())
195 }
196
197 pub async fn invalidate_embeddings_for_model(
206 &self,
207 old_model: &str,
208 ) -> Result<u64, MemoryError> {
209 let result = zeph_db::query(sql!(
210 "UPDATE response_cache \
211 SET embedding = NULL, embedding_model = NULL, embedding_ts = NULL \
212 WHERE embedding_model = ?"
213 ))
214 .bind(old_model)
215 .execute(&self.pool)
216 .await?;
217 Ok(result.rows_affected())
218 }
219
220 pub async fn cleanup(&self, current_embedding_model: &str) -> Result<u64, MemoryError> {
232 let now = unix_now();
233 let deleted = zeph_db::query(sql!("DELETE FROM response_cache WHERE expires_at <= ?"))
234 .bind(now)
235 .execute(&self.pool)
236 .await?
237 .rows_affected();
238
239 let updated = zeph_db::query(sql!(
240 "UPDATE response_cache \
241 SET embedding = NULL, embedding_model = NULL, embedding_ts = NULL \
242 WHERE embedding IS NOT NULL AND embedding_model != ?"
243 ))
244 .bind(current_embedding_model)
245 .execute(&self.pool)
246 .await?
247 .rows_affected();
248
249 Ok(deleted + updated)
250 }
251
252 pub async fn cleanup_expired(&self) -> Result<u64, MemoryError> {
258 let now = unix_now();
259 let result = zeph_db::query(sql!("DELETE FROM response_cache WHERE expires_at <= ?"))
260 .bind(now)
261 .execute(&self.pool)
262 .await?;
263 Ok(result.rows_affected())
264 }
265
266 #[must_use]
273 pub fn compute_key(last_user_message: &str, model: &str) -> String {
274 let mut hasher = blake3::Hasher::new();
275 let content = last_user_message.as_bytes();
276 hasher.update(&(content.len() as u64).to_le_bytes());
277 hasher.update(content);
278 let model_bytes = model.as_bytes();
279 hasher.update(&(model_bytes.len() as u64).to_le_bytes());
280 hasher.update(model_bytes);
281 hasher.finalize().to_hex().to_string()
282 }
283}
284
/// Current Unix time in whole seconds.
///
/// Yields `0` when the system clock reports a time before the Unix epoch.
fn unix_now() -> i64 {
    let wall_clock = std::time::SystemTime::now();
    let secs = wall_clock
        .duration_since(std::time::UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_secs());
    secs.cast_signed()
}
292
#[cfg(test)]
mod tests {
    use super::*;
    use crate::store::SqliteStore;

    /// Embedding-model tag used by most semantic tests.
    const EMBED_MODEL: &str = "model-a";

    /// Opens a fresh in-memory store and returns a cache with the given TTL together with the
    /// pool it uses, so tests can also write rows directly.
    async fn cache_and_pool(ttl_secs: u64) -> (ResponseCache, DbPool) {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let pool = store.pool().clone();
        (ResponseCache::new(pool.clone(), ttl_secs), pool)
    }

    /// Cache on a fresh in-memory store with the given TTL.
    async fn cache_with_ttl(ttl_secs: u64) -> ResponseCache {
        cache_and_pool(ttl_secs).await.0
    }

    /// Cache whose entries stay valid for an hour.
    async fn test_cache() -> ResponseCache {
        cache_with_ttl(3600).await
    }

    /// Stores `response` with `embedding` under `key`, using LLM model "m1" and `EMBED_MODEL`.
    async fn store_embedded(cache: &ResponseCache, key: &str, response: &str, embedding: &[f32]) {
        cache
            .put_with_embedding(key, response, "m1", embedding, EMBED_MODEL)
            .await
            .unwrap();
    }

    /// Runs a semantic lookup against `EMBED_MODEL` and unwraps the database result.
    async fn semantic_lookup(
        cache: &ResponseCache,
        query: &[f32],
        threshold: f32,
        max_candidates: u32,
    ) -> Option<(String, f32)> {
        cache
            .get_semantic(query, EMBED_MODEL, threshold, max_candidates)
            .await
            .unwrap()
    }

    /// Writes a row with an arbitrary embedding blob straight into the table, bypassing
    /// `put_with_embedding`, so malformed blobs can be tested.
    async fn insert_corrupt_blob(pool: &DbPool, key: &str, blob: &[u8]) {
        let written_at = unix_now();
        let valid_until = written_at + 3600;
        zeph_db::query(sql!(
            "INSERT INTO response_cache \
             (cache_key, response, model, created_at, expires_at, embedding, embedding_model, embedding_ts) \
             VALUES (?, ?, ?, ?, ?, ?, ?, ?)"
        ))
        .bind(key)
        .bind("corrupt-response")
        .bind("m1")
        .bind(written_at)
        .bind(valid_until)
        .bind(blob)
        .bind(EMBED_MODEL)
        .bind(written_at)
        .execute(pool)
        .await
        .unwrap();
    }

    #[tokio::test]
    async fn cache_miss_returns_none() {
        let cache = test_cache().await;
        let fetched = cache.get("nonexistent").await.unwrap();
        assert!(fetched.is_none());
    }

    #[tokio::test]
    async fn cache_put_and_get_roundtrip() {
        let cache = test_cache().await;
        cache.put("key1", "response text", "gpt-4").await.unwrap();
        let fetched = cache.get("key1").await.unwrap();
        assert_eq!(fetched.as_deref(), Some("response text"));
    }

    #[tokio::test]
    async fn cache_expired_entry_returns_none() {
        let cache = cache_with_ttl(0).await;
        cache.put("key1", "response", "model").await.unwrap();
        let fetched = cache.get("key1").await.unwrap();
        assert!(fetched.is_none());
    }

    #[tokio::test]
    async fn cleanup_expired_removes_entries() {
        let cache = cache_with_ttl(0).await;
        cache.put("key1", "response", "model").await.unwrap();
        let removed = cache.cleanup_expired().await.unwrap();
        assert!(removed > 0);
    }

    #[tokio::test]
    async fn cleanup_does_not_remove_valid_entries() {
        let cache = test_cache().await;
        cache.put("key1", "response", "model").await.unwrap();
        let removed = cache.cleanup_expired().await.unwrap();
        assert_eq!(removed, 0);
        let fetched = cache.get("key1").await.unwrap();
        assert!(fetched.is_some());
    }

    #[test]
    fn compute_key_deterministic() {
        let first = ResponseCache::compute_key("hello", "gpt-4");
        let second = ResponseCache::compute_key("hello", "gpt-4");
        assert_eq!(first, second);
    }

    #[test]
    fn compute_key_different_for_different_content() {
        let hello_key = ResponseCache::compute_key("hello", "gpt-4");
        let world_key = ResponseCache::compute_key("world", "gpt-4");
        assert_ne!(hello_key, world_key);
    }

    #[test]
    fn compute_key_different_for_different_model() {
        let gpt4_key = ResponseCache::compute_key("hello", "gpt-4");
        let gpt35_key = ResponseCache::compute_key("hello", "gpt-3.5");
        assert_ne!(gpt4_key, gpt35_key);
    }

    #[test]
    fn compute_key_empty_message() {
        let key = ResponseCache::compute_key("", "model");
        assert!(!key.is_empty());
    }

    #[tokio::test]
    async fn ttl_extreme_value_does_not_overflow() {
        let cache = cache_with_ttl(u64::MAX - 1).await;
        cache.put("key1", "response", "model").await.unwrap();
        let fetched = cache.get("key1").await.unwrap();
        assert_eq!(fetched.as_deref(), Some("response"));
    }

    #[tokio::test]
    async fn insert_or_replace_updates_existing_entry() {
        let cache = test_cache().await;
        for response in ["first response", "second response"] {
            cache.put("key1", response, "gpt-4").await.unwrap();
        }
        let fetched = cache.get("key1").await.unwrap();
        assert_eq!(fetched.as_deref(), Some("second response"));
    }

    #[tokio::test]
    async fn test_semantic_get_empty_cache() {
        let cache = test_cache().await;
        let hit = semantic_lookup(&cache, &[1.0, 0.0], 0.9, 10).await;
        assert!(hit.is_none());
    }

    #[tokio::test]
    async fn test_semantic_get_identical_embedding() {
        let cache = test_cache().await;
        let embedding = [1.0_f32, 0.0, 0.0];
        store_embedded(&cache, "k1", "response-a", &embedding).await;
        let hit = semantic_lookup(&cache, &embedding, 0.9, 10).await;
        assert!(hit.is_some());
        let (resp, score) = hit.unwrap();
        assert_eq!(resp, "response-a");
        assert!(
            (score - 1.0).abs() < 1e-5,
            "expected score ~1.0, got {score}"
        );
    }

    #[tokio::test]
    async fn test_semantic_get_orthogonal_vectors() {
        let cache = test_cache().await;
        store_embedded(&cache, "k1", "response-a", &[1.0, 0.0, 0.0]).await;
        let hit = semantic_lookup(&cache, &[0.0, 1.0, 0.0], 0.5, 10).await;
        assert!(hit.is_none(), "orthogonal vectors should not hit");
    }

    #[tokio::test]
    async fn test_semantic_get_similar_above_threshold() {
        let cache = test_cache().await;
        store_embedded(&cache, "k1", "response-a", &[1.0_f32, 0.1, 0.0]).await;
        let hit = semantic_lookup(&cache, &[1.0_f32, 0.05, 0.0], 0.9, 10).await;
        assert!(hit.is_some(), "similar vector should hit at threshold 0.9");
    }

    #[tokio::test]
    async fn test_semantic_get_similar_below_threshold() {
        let cache = test_cache().await;
        store_embedded(&cache, "k1", "response-a", &[1.0, 0.0, 0.0]).await;
        let hit = semantic_lookup(&cache, &[0.0_f32, 1.0, 0.0], 0.95, 10).await;
        assert!(
            hit.is_none(),
            "dissimilar vector should not hit at high threshold"
        );
    }

    #[tokio::test]
    async fn test_semantic_get_max_candidates_limit() {
        let cache = test_cache().await;
        for i in 0..5_u8 {
            let key = format!("k{i}");
            let response = format!("response-{i}");
            store_embedded(&cache, &key, &response, &[1.0, 0.0]).await;
        }
        let hit = semantic_lookup(&cache, &[1.0, 0.0], 0.9, 2).await;
        assert!(hit.is_some());
    }

    #[tokio::test]
    async fn test_semantic_get_ignores_expired() {
        let cache = cache_with_ttl(0).await;
        store_embedded(&cache, "k1", "response-a", &[1.0, 0.0]).await;
        let hit = semantic_lookup(&cache, &[1.0, 0.0], 0.9, 10).await;
        assert!(hit.is_none(), "expired entries should not be returned");
    }

    #[tokio::test]
    async fn test_semantic_get_filters_by_embedding_model() {
        let cache = test_cache().await;
        store_embedded(&cache, "k1", "response-a", &[1.0, 0.0]).await;
        let hit = cache
            .get_semantic(&[1.0, 0.0], "model-b", 0.9, 10)
            .await
            .unwrap();
        assert!(hit.is_none(), "wrong embedding model should not match");
    }

    #[tokio::test]
    async fn test_put_with_embedding_roundtrip() {
        let cache = test_cache().await;
        let embedding = [0.5_f32, 0.5, 0.707];
        cache
            .put_with_embedding("key1", "semantic response", "gpt-4", &embedding, "embed-model")
            .await
            .unwrap();

        let by_key = cache.get("key1").await.unwrap();
        assert_eq!(by_key.as_deref(), Some("semantic response"));

        let by_similarity = cache
            .get_semantic(&embedding, "embed-model", 0.99, 10)
            .await
            .unwrap();
        assert!(by_similarity.is_some());
        let (resp, score) = by_similarity.unwrap();
        assert_eq!(resp, "semantic response");
        assert!((score - 1.0).abs() < 1e-5);
    }

    #[tokio::test]
    async fn test_invalidate_embeddings_for_model() {
        let cache = test_cache().await;
        store_embedded(&cache, "k1", "resp", &[1.0, 0.0]).await;

        let cleared = cache
            .invalidate_embeddings_for_model("model-a")
            .await
            .unwrap();
        assert_eq!(cleared, 1);

        let by_key = cache.get("k1").await.unwrap();
        assert_eq!(by_key.as_deref(), Some("resp"));
        let by_similarity = semantic_lookup(&cache, &[1.0, 0.0], 0.9, 10).await;
        assert!(by_similarity.is_none());
    }

    #[tokio::test]
    async fn test_cleanup_nulls_stale_embeddings() {
        let cache = test_cache().await;
        cache
            .put_with_embedding("k1", "resp", "m1", &[1.0, 0.0], "model-old")
            .await
            .unwrap();

        let touched = cache.cleanup("model-new").await.unwrap();
        assert!(touched > 0, "should have updated stale embedding row");

        let by_key = cache.get("k1").await.unwrap();
        assert_eq!(by_key.as_deref(), Some("resp"));
        let by_similarity = cache
            .get_semantic(&[1.0, 0.0], "model-old", 0.9, 10)
            .await
            .unwrap();
        assert!(by_similarity.is_none());
    }

    #[tokio::test]
    async fn test_cleanup_deletes_expired() {
        let cache = cache_with_ttl(0).await;
        cache.put("k1", "resp", "m1").await.unwrap();
        let touched = cache.cleanup("model-a").await.unwrap();
        assert!(touched > 0);
        let fetched = cache.get("k1").await.unwrap();
        assert!(fetched.is_none());
    }

    #[tokio::test]
    async fn test_cleanup_preserves_valid() {
        let cache = test_cache().await;
        cache
            .put_with_embedding("k1", "resp", "m1", &[1.0, 0.0], "model-current")
            .await
            .unwrap();

        let touched = cache.cleanup("model-current").await.unwrap();
        assert_eq!(touched, 0, "valid entries should not be affected");

        let by_similarity = cache
            .get_semantic(&[1.0, 0.0], "model-current", 0.9, 10)
            .await
            .unwrap();
        assert!(by_similarity.is_some());
    }

    #[tokio::test]
    async fn test_semantic_get_corrupted_blob_odd_length() {
        let (cache, pool) = cache_and_pool(3600).await;
        insert_corrupt_blob(&pool, "corrupt-key", &[0xAB, 0xCD, 0xEF, 0x01, 0x02]).await;

        let hit = semantic_lookup(&cache, &[1.0, 0.0, 0.0], 0.9, 10).await;
        assert!(hit.is_none(), "corrupt odd-length BLOB must yield Ok(None)");
    }

    #[tokio::test]
    async fn test_semantic_get_corrupted_blob_skips_to_valid() {
        let (cache, pool) = cache_and_pool(3600).await;
        insert_corrupt_blob(&pool, "corrupt-key", &[0x01, 0x02, 0x03]).await;

        let valid_embedding = [1.0_f32, 0.0, 0.0];
        store_embedded(&cache, "valid-key", "valid-response", &valid_embedding).await;

        let hit = semantic_lookup(&cache, &valid_embedding, 0.9, 10).await;
        assert!(
            hit.is_some(),
            "valid row must be returned despite corrupt sibling"
        );
        let (resp, cache_score) = hit.unwrap();
        assert_eq!(resp, "valid-response");
        assert!(
            (cache_score - 1.0).abs() < 1e-5,
            "identical vectors must yield score ~1.0, got {cache_score}"
        );
    }

    #[tokio::test]
    async fn test_semantic_get_empty_blob() {
        let (cache, pool) = cache_and_pool(3600).await;
        insert_corrupt_blob(&pool, "empty-blob-key", &[]).await;

        let hit = semantic_lookup(&cache, &[1.0, 0.0], 0.9, 10).await;
        assert!(
            hit.is_none(),
            "empty BLOB must yield Ok(None) at threshold 0.9"
        );
    }

    #[tokio::test]
    async fn test_semantic_get_all_blobs_corrupted() {
        let (cache, pool) = cache_and_pool(3600).await;

        // Every length here is not a multiple of four bytes.
        let corrupt_blobs: [&[u8]; 5] = [
            &[0x01],
            &[0x01, 0x02, 0x03],
            &[0x01, 0x02, 0x03, 0x04, 0x05],
            &[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07],
            &[0x01, 0x02, 0x03, 0x04, 0x05, 0x06],
        ];
        for (i, blob) in corrupt_blobs.iter().enumerate() {
            let key = format!("corrupt-{i}");
            insert_corrupt_blob(&pool, &key, blob).await;
        }

        let hit = semantic_lookup(&cache, &[1.0, 0.0, 0.0], 0.9, 10).await;
        assert!(hit.is_none(), "all corrupt BLOBs must yield Ok(None)");
    }

    #[tokio::test]
    async fn test_semantic_get_dimension_mismatch_returns_none() {
        let cache = test_cache().await;
        store_embedded(&cache, "k1", "resp-3d", &[1.0, 0.0, 0.0]).await;
        let hit = semantic_lookup(&cache, &[1.0, 0.0], 0.01, 10).await;
        assert!(hit.is_none(), "dimension mismatch must not produce a hit");
    }

    #[tokio::test]
    async fn test_semantic_get_dimension_mismatch_query_longer() {
        let cache = test_cache().await;
        store_embedded(&cache, "k1", "resp-2d", &[1.0, 0.0]).await;
        let hit = semantic_lookup(&cache, &[1.0, 0.0, 0.0], 0.01, 10).await;
        assert!(
            hit.is_none(),
            "query longer than stored embedding must not produce a hit"
        );
    }

    #[tokio::test]
    async fn test_semantic_get_mixed_dimensions_picks_correct_match() {
        let cache = test_cache().await;
        store_embedded(&cache, "k-2d", "resp-2d", &[1.0, 0.0]).await;
        store_embedded(&cache, "k-3d", "resp-3d", &[1.0, 0.0, 0.0]).await;

        let hit = semantic_lookup(&cache, &[1.0, 0.0, 0.0], 0.9, 10).await;
        assert!(hit.is_some(), "matching dim=3 entry should be returned");
        let (response, score) = hit.unwrap();
        assert_eq!(response, "resp-3d", "wrong entry returned");
        assert!(
            (score - 1.0).abs() < 1e-5,
            "expected score ~1.0 for identical vectors, got {score}"
        );
    }
}