1use zeph_db::DbPool;
5#[allow(unused_imports)]
6use zeph_db::sql;
7
8use crate::error::MemoryError;
9
/// Exact-key and semantic (embedding-similarity) response cache backed by
/// the `response_cache` table.
pub struct ResponseCache {
    // Connection pool for the backing database.
    pool: DbPool,
    // Time-to-live applied to every write, in seconds; clamped to one year
    // at write time (see the `.min(31_536_000)` in `put`).
    ttl_secs: u64,
}
14
15impl ResponseCache {
16 #[must_use]
17 pub fn new(pool: DbPool, ttl_secs: u64) -> Self {
18 Self { pool, ttl_secs }
19 }
20
21 pub async fn get(&self, key: &str) -> Result<Option<String>, MemoryError> {
27 let now = unix_now();
28 let row: Option<(String,)> = zeph_db::query_as(sql!(
29 "SELECT response FROM response_cache WHERE cache_key = ? AND expires_at > ?"
30 ))
31 .bind(key)
32 .bind(now)
33 .fetch_optional(&self.pool)
34 .await?;
35 Ok(row.map(|(r,)| r))
36 }
37
38 pub async fn put(&self, key: &str, response: &str, model: &str) -> Result<(), MemoryError> {
44 let now = unix_now();
45 let expires_at = now.saturating_add(self.ttl_secs.min(31_536_000).cast_signed());
47 zeph_db::query(sql!(
48 "INSERT INTO response_cache (cache_key, response, model, created_at, expires_at) \
49 VALUES (?, ?, ?, ?, ?) \
50 ON CONFLICT(cache_key) DO UPDATE SET \
51 response = excluded.response, model = excluded.model, \
52 created_at = excluded.created_at, expires_at = excluded.expires_at"
53 ))
54 .bind(key)
55 .bind(response)
56 .bind(model)
57 .bind(now)
58 .bind(expires_at)
59 .execute(&self.pool)
60 .await?;
61 Ok(())
62 }
63
64 pub async fn get_semantic(
76 &self,
77 embedding: &[f32],
78 embedding_model: &str,
79 similarity_threshold: f32,
80 max_candidates: u32,
81 ) -> Result<Option<(String, f32)>, MemoryError> {
82 let now = unix_now();
83 let rows: Vec<(String, Vec<u8>)> = zeph_db::query_as(sql!(
84 "SELECT response, embedding FROM response_cache \
85 WHERE embedding_model = ? AND embedding IS NOT NULL AND expires_at > ? \
86 ORDER BY embedding_ts DESC LIMIT ?"
87 ))
88 .bind(embedding_model)
89 .bind(now)
90 .bind(max_candidates)
91 .fetch_all(&self.pool)
92 .await?;
93
94 let mut best_score = -1.0_f32;
95 let mut best_response: Option<String> = None;
96
97 for (response, blob) in &rows {
98 if blob.len() % 4 != 0 {
99 continue;
100 }
101 let stored: Vec<f32> = blob
102 .chunks_exact(4)
103 .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
104 .collect();
105 let score = zeph_common::math::cosine_similarity(embedding, &stored);
106 tracing::debug!(
107 score,
108 threshold = similarity_threshold,
109 "semantic cache candidate evaluated",
110 );
111 if score > best_score {
112 best_score = score;
113 best_response = Some(response.clone());
114 }
115 }
116
117 tracing::debug!(
118 examined = rows.len(),
119 best_score,
120 threshold = similarity_threshold,
121 hit = best_score >= similarity_threshold,
122 "semantic cache scan complete",
123 );
124
125 if best_score >= similarity_threshold {
126 Ok(best_response.map(|r| (r, best_score)))
127 } else {
128 Ok(None)
129 }
130 }
131
132 pub async fn put_with_embedding(
140 &self,
141 key: &str,
142 response: &str,
143 model: &str,
144 embedding: &[f32],
145 embedding_model: &str,
146 ) -> Result<(), MemoryError> {
147 let now = unix_now();
148 let expires_at = now.saturating_add(self.ttl_secs.min(31_536_000).cast_signed());
149 let blob: Vec<u8> = embedding.iter().flat_map(|f| f.to_le_bytes()).collect();
150 zeph_db::query(
151 sql!("INSERT INTO response_cache \
152 (cache_key, response, model, created_at, expires_at, embedding, embedding_model, embedding_ts) \
153 VALUES (?, ?, ?, ?, ?, ?, ?, ?) \
154 ON CONFLICT(cache_key) DO UPDATE SET \
155 response = excluded.response, model = excluded.model, \
156 created_at = excluded.created_at, expires_at = excluded.expires_at, \
157 embedding = excluded.embedding, embedding_model = excluded.embedding_model, \
158 embedding_ts = excluded.embedding_ts"),
159 )
160 .bind(key)
161 .bind(response)
162 .bind(model)
163 .bind(now)
164 .bind(expires_at)
165 .bind(blob)
166 .bind(embedding_model)
167 .bind(now)
168 .execute(&self.pool)
169 .await?;
170 Ok(())
171 }
172
173 pub async fn invalidate_embeddings_for_model(
182 &self,
183 old_model: &str,
184 ) -> Result<u64, MemoryError> {
185 let result = zeph_db::query(sql!(
186 "UPDATE response_cache \
187 SET embedding = NULL, embedding_model = NULL, embedding_ts = NULL \
188 WHERE embedding_model = ?"
189 ))
190 .bind(old_model)
191 .execute(&self.pool)
192 .await?;
193 Ok(result.rows_affected())
194 }
195
196 pub async fn cleanup(&self, current_embedding_model: &str) -> Result<u64, MemoryError> {
208 let now = unix_now();
209 let deleted = zeph_db::query(sql!("DELETE FROM response_cache WHERE expires_at <= ?"))
210 .bind(now)
211 .execute(&self.pool)
212 .await?
213 .rows_affected();
214
215 let updated = zeph_db::query(sql!(
216 "UPDATE response_cache \
217 SET embedding = NULL, embedding_model = NULL, embedding_ts = NULL \
218 WHERE embedding IS NOT NULL AND embedding_model != ?"
219 ))
220 .bind(current_embedding_model)
221 .execute(&self.pool)
222 .await?
223 .rows_affected();
224
225 Ok(deleted + updated)
226 }
227
228 pub async fn cleanup_expired(&self) -> Result<u64, MemoryError> {
234 let now = unix_now();
235 let result = zeph_db::query(sql!("DELETE FROM response_cache WHERE expires_at <= ?"))
236 .bind(now)
237 .execute(&self.pool)
238 .await?;
239 Ok(result.rows_affected())
240 }
241
242 #[must_use]
249 pub fn compute_key(last_user_message: &str, model: &str) -> String {
250 let mut hasher = blake3::Hasher::new();
251 let content = last_user_message.as_bytes();
252 hasher.update(&(content.len() as u64).to_le_bytes());
253 hasher.update(content);
254 let model_bytes = model.as_bytes();
255 hasher.update(&(model_bytes.len() as u64).to_le_bytes());
256 hasher.update(model_bytes);
257 hasher.finalize().to_hex().to_string()
258 }
259}
260
/// Current Unix timestamp in whole seconds, as `i64`.
///
/// A clock reading before the epoch (i.e. a failed `duration_since`)
/// collapses to 0 rather than panicking, so callers always get a usable
/// timestamp.
fn unix_now() -> i64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    let secs = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|elapsed| elapsed.as_secs())
        .unwrap_or(0);
    secs.cast_signed()
}
268
#[cfg(test)]
mod tests {
    //! Unit tests for `ResponseCache`: exact-key round-trips, TTL/expiry,
    //! semantic (cosine-similarity) lookups, corrupt-BLOB tolerance, and
    //! cleanup/invalidation behavior. Every test runs against a fresh
    //! in-memory SQLite database, so tests are independent of one another.
    use super::*;
    use crate::store::SqliteStore;

    // Fresh in-memory cache with a 1-hour TTL (entries never expire within
    // a test run).
    async fn test_cache() -> ResponseCache {
        let store = SqliteStore::new(":memory:").await.unwrap();
        ResponseCache::new(store.pool().clone(), 3600)
    }

    #[tokio::test]
    async fn cache_miss_returns_none() {
        let cache = test_cache().await;
        let result = cache.get("nonexistent").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn cache_put_and_get_roundtrip() {
        let cache = test_cache().await;
        cache.put("key1", "response text", "gpt-4").await.unwrap();
        let result = cache.get("key1").await.unwrap();
        assert_eq!(result.as_deref(), Some("response text"));
    }

    #[tokio::test]
    async fn cache_expired_entry_returns_none() {
        // TTL of 0 makes every entry expire at its own creation timestamp.
        let store = SqliteStore::new(":memory:").await.unwrap();
        let cache = ResponseCache::new(store.pool().clone(), 0);
        cache.put("key1", "response", "model").await.unwrap();
        let result = cache.get("key1").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn cleanup_expired_removes_entries() {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let cache = ResponseCache::new(store.pool().clone(), 0);
        cache.put("key1", "response", "model").await.unwrap();
        let deleted = cache.cleanup_expired().await.unwrap();
        assert!(deleted > 0);
    }

    #[tokio::test]
    async fn cleanup_does_not_remove_valid_entries() {
        let cache = test_cache().await;
        cache.put("key1", "response", "model").await.unwrap();
        let deleted = cache.cleanup_expired().await.unwrap();
        assert_eq!(deleted, 0);
        let result = cache.get("key1").await.unwrap();
        assert!(result.is_some());
    }

    #[test]
    fn compute_key_deterministic() {
        let k1 = ResponseCache::compute_key("hello", "gpt-4");
        let k2 = ResponseCache::compute_key("hello", "gpt-4");
        assert_eq!(k1, k2);
    }

    #[test]
    fn compute_key_different_for_different_content() {
        assert_ne!(
            ResponseCache::compute_key("hello", "gpt-4"),
            ResponseCache::compute_key("world", "gpt-4")
        );
    }

    #[test]
    fn compute_key_different_for_different_model() {
        assert_ne!(
            ResponseCache::compute_key("hello", "gpt-4"),
            ResponseCache::compute_key("hello", "gpt-3.5")
        );
    }

    #[test]
    fn compute_key_empty_message() {
        let k = ResponseCache::compute_key("", "model");
        assert!(!k.is_empty());
    }

    #[tokio::test]
    async fn ttl_extreme_value_does_not_overflow() {
        // Near-u64::MAX TTL exercises the one-year clamp in `put`; without
        // it the expiry arithmetic would overflow when cast to i64.
        let store = SqliteStore::new(":memory:").await.unwrap();
        let cache = ResponseCache::new(store.pool().clone(), u64::MAX - 1);
        cache.put("key1", "response", "model").await.unwrap();
        let result = cache.get("key1").await.unwrap();
        assert_eq!(result.as_deref(), Some("response"));
    }

    #[tokio::test]
    async fn insert_or_replace_updates_existing_entry() {
        let cache = test_cache().await;
        cache.put("key1", "first response", "gpt-4").await.unwrap();
        cache.put("key1", "second response", "gpt-4").await.unwrap();
        let result = cache.get("key1").await.unwrap();
        assert_eq!(result.as_deref(), Some("second response"));
    }

    #[tokio::test]
    async fn test_semantic_get_empty_cache() {
        let cache = test_cache().await;
        let result = cache
            .get_semantic(&[1.0, 0.0], "model-a", 0.9, 10)
            .await
            .unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_semantic_get_identical_embedding() {
        // Identical query and stored vectors: cosine similarity ~1.0.
        let cache = test_cache().await;
        let embedding = vec![1.0_f32, 0.0, 0.0];
        cache
            .put_with_embedding("k1", "response-a", "m1", &embedding, "model-a")
            .await
            .unwrap();
        let result = cache
            .get_semantic(&embedding, "model-a", 0.9, 10)
            .await
            .unwrap();
        assert!(result.is_some());
        let (resp, score) = result.unwrap();
        assert_eq!(resp, "response-a");
        assert!(
            (score - 1.0).abs() < 1e-5,
            "expected score ~1.0, got {score}"
        );
    }

    #[tokio::test]
    async fn test_semantic_get_orthogonal_vectors() {
        // Orthogonal vectors: cosine similarity 0.0, below the 0.5 threshold.
        let cache = test_cache().await;
        cache
            .put_with_embedding("k1", "response-a", "m1", &[1.0, 0.0, 0.0], "model-a")
            .await
            .unwrap();
        let result = cache
            .get_semantic(&[0.0, 1.0, 0.0], "model-a", 0.5, 10)
            .await
            .unwrap();
        assert!(result.is_none(), "orthogonal vectors should not hit");
    }

    #[tokio::test]
    async fn test_semantic_get_similar_above_threshold() {
        let cache = test_cache().await;
        let stored = vec![1.0_f32, 0.1, 0.0];
        cache
            .put_with_embedding("k1", "response-a", "m1", &stored, "model-a")
            .await
            .unwrap();
        let query = vec![1.0_f32, 0.05, 0.0];
        let result = cache
            .get_semantic(&query, "model-a", 0.9, 10)
            .await
            .unwrap();
        assert!(
            result.is_some(),
            "similar vector should hit at threshold 0.9"
        );
    }

    #[tokio::test]
    async fn test_semantic_get_similar_below_threshold() {
        let cache = test_cache().await;
        cache
            .put_with_embedding("k1", "response-a", "m1", &[1.0, 0.0, 0.0], "model-a")
            .await
            .unwrap();
        let query = vec![0.0_f32, 1.0, 0.0];
        let result = cache
            .get_semantic(&query, "model-a", 0.95, 10)
            .await
            .unwrap();
        assert!(
            result.is_none(),
            "dissimilar vector should not hit at high threshold"
        );
    }

    #[tokio::test]
    async fn test_semantic_get_max_candidates_limit() {
        // Five identical candidates, limit 2: the scan still finds a hit
        // among the examined subset.
        let cache = test_cache().await;
        for i in 0..5_u8 {
            cache
                .put_with_embedding(
                    &format!("k{i}"),
                    &format!("response-{i}"),
                    "m1",
                    &[1.0, 0.0],
                    "model-a",
                )
                .await
                .unwrap();
        }
        let result = cache
            .get_semantic(&[1.0, 0.0], "model-a", 0.9, 2)
            .await
            .unwrap();
        assert!(result.is_some());
    }

    #[tokio::test]
    async fn test_semantic_get_ignores_expired() {
        // TTL 0: the embedded entry is already expired when queried.
        let store = crate::store::SqliteStore::new(":memory:").await.unwrap();
        let cache = ResponseCache::new(store.pool().clone(), 0);
        cache
            .put_with_embedding("k1", "response-a", "m1", &[1.0, 0.0], "model-a")
            .await
            .unwrap();
        let result = cache
            .get_semantic(&[1.0, 0.0], "model-a", 0.9, 10)
            .await
            .unwrap();
        assert!(result.is_none(), "expired entries should not be returned");
    }

    #[tokio::test]
    async fn test_semantic_get_filters_by_embedding_model() {
        let cache = test_cache().await;
        cache
            .put_with_embedding("k1", "response-a", "m1", &[1.0, 0.0], "model-a")
            .await
            .unwrap();
        let result = cache
            .get_semantic(&[1.0, 0.0], "model-b", 0.9, 10)
            .await
            .unwrap();
        assert!(result.is_none(), "wrong embedding model should not match");
    }

    #[tokio::test]
    async fn test_put_with_embedding_roundtrip() {
        // An embedded entry is reachable both by exact key and semantically.
        let cache = test_cache().await;
        let embedding = vec![0.5_f32, 0.5, 0.707];
        cache
            .put_with_embedding(
                "key1",
                "semantic response",
                "gpt-4",
                &embedding,
                "embed-model",
            )
            .await
            .unwrap();
        let exact = cache.get("key1").await.unwrap();
        assert_eq!(exact.as_deref(), Some("semantic response"));
        let semantic = cache
            .get_semantic(&embedding, "embed-model", 0.99, 10)
            .await
            .unwrap();
        assert!(semantic.is_some());
        let (resp, score) = semantic.unwrap();
        assert_eq!(resp, "semantic response");
        assert!((score - 1.0).abs() < 1e-5);
    }

    #[tokio::test]
    async fn test_invalidate_embeddings_for_model() {
        // Invalidation drops semantic reachability but keeps the response
        // available via exact-key lookup.
        let cache = test_cache().await;
        cache
            .put_with_embedding("k1", "resp", "m1", &[1.0, 0.0], "model-a")
            .await
            .unwrap();
        let updated = cache
            .invalidate_embeddings_for_model("model-a")
            .await
            .unwrap();
        assert_eq!(updated, 1);
        let exact = cache.get("k1").await.unwrap();
        assert_eq!(exact.as_deref(), Some("resp"));
        let semantic = cache
            .get_semantic(&[1.0, 0.0], "model-a", 0.9, 10)
            .await
            .unwrap();
        assert!(semantic.is_none());
    }

    #[tokio::test]
    async fn test_cleanup_nulls_stale_embeddings() {
        // cleanup() with a new current model nulls embeddings from the old
        // model but preserves the response rows.
        let cache = test_cache().await;
        cache
            .put_with_embedding("k1", "resp", "m1", &[1.0, 0.0], "model-old")
            .await
            .unwrap();
        let affected = cache.cleanup("model-new").await.unwrap();
        assert!(affected > 0, "should have updated stale embedding row");
        let exact = cache.get("k1").await.unwrap();
        assert_eq!(exact.as_deref(), Some("resp"));
        let semantic = cache
            .get_semantic(&[1.0, 0.0], "model-old", 0.9, 10)
            .await
            .unwrap();
        assert!(semantic.is_none());
    }

    #[tokio::test]
    async fn test_cleanup_deletes_expired() {
        let store = crate::store::SqliteStore::new(":memory:").await.unwrap();
        let cache = ResponseCache::new(store.pool().clone(), 0);
        cache.put("k1", "resp", "m1").await.unwrap();
        let affected = cache.cleanup("model-a").await.unwrap();
        assert!(affected > 0);
        let result = cache.get("k1").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_cleanup_preserves_valid() {
        let cache = test_cache().await;
        cache
            .put_with_embedding("k1", "resp", "m1", &[1.0, 0.0], "model-current")
            .await
            .unwrap();
        let affected = cache.cleanup("model-current").await.unwrap();
        assert_eq!(affected, 0, "valid entries should not be affected");
        let semantic = cache
            .get_semantic(&[1.0, 0.0], "model-current", 0.9, 10)
            .await
            .unwrap();
        assert!(semantic.is_some());
    }

    // Writes a row with an arbitrary raw embedding BLOB, bypassing
    // `put_with_embedding`'s f32 serialization, to simulate on-disk
    // corruption. Entry is unexpired (now + 3600).
    async fn insert_corrupt_blob(pool: &DbPool, key: &str, blob: &[u8]) {
        let now = unix_now();
        let expires_at = now + 3600;
        zeph_db::query(
            sql!("INSERT INTO response_cache \
                (cache_key, response, model, created_at, expires_at, embedding, embedding_model, embedding_ts) \
                VALUES (?, ?, ?, ?, ?, ?, ?, ?)"),
        )
        .bind(key)
        .bind("corrupt-response")
        .bind("m1")
        .bind(now)
        .bind(expires_at)
        .bind(blob)
        .bind("model-a")
        .bind(now)
        .execute(pool)
        .await
        .unwrap();
    }

    #[tokio::test]
    async fn test_semantic_get_corrupted_blob_odd_length() {
        // 5 bytes is not a multiple of 4: the row must be skipped, not error.
        let store = SqliteStore::new(":memory:").await.unwrap();
        let pool = store.pool().clone();
        let cache = ResponseCache::new(pool.clone(), 3600);

        insert_corrupt_blob(&pool, "corrupt-key", &[0xAB, 0xCD, 0xEF, 0x01, 0x02]).await;

        let result = cache
            .get_semantic(&[1.0, 0.0, 0.0], "model-a", 0.9, 10)
            .await
            .unwrap();
        assert!(
            result.is_none(),
            "corrupt odd-length BLOB must yield Ok(None)"
        );
    }

    #[tokio::test]
    async fn test_semantic_get_corrupted_blob_skips_to_valid() {
        // A corrupt row must not mask a valid sibling row.
        let store = SqliteStore::new(":memory:").await.unwrap();
        let pool = store.pool().clone();
        let cache = ResponseCache::new(pool.clone(), 3600);

        insert_corrupt_blob(&pool, "corrupt-key", &[0x01, 0x02, 0x03]).await;

        let valid_embedding = vec![1.0_f32, 0.0, 0.0];
        cache
            .put_with_embedding(
                "valid-key",
                "valid-response",
                "m1",
                &valid_embedding,
                "model-a",
            )
            .await
            .unwrap();

        let result = cache
            .get_semantic(&valid_embedding, "model-a", 0.9, 10)
            .await
            .unwrap();
        assert!(
            result.is_some(),
            "valid row must be returned despite corrupt sibling"
        );
        let (resp, cache_score) = result.unwrap();
        assert_eq!(resp, "valid-response");
        assert!(
            (cache_score - 1.0).abs() < 1e-5,
            "identical vectors must yield score ~1.0, got {cache_score}"
        );
    }

    #[tokio::test]
    async fn test_semantic_get_empty_blob() {
        // An empty BLOB decodes to a zero-length vector; it must not match.
        let store = SqliteStore::new(":memory:").await.unwrap();
        let pool = store.pool().clone();
        let cache = ResponseCache::new(pool.clone(), 3600);

        insert_corrupt_blob(&pool, "empty-blob-key", &[]).await;

        let result = cache
            .get_semantic(&[1.0, 0.0], "model-a", 0.9, 10)
            .await
            .unwrap();
        assert!(
            result.is_none(),
            "empty BLOB must yield Ok(None) at threshold 0.9"
        );
    }

    #[tokio::test]
    async fn test_semantic_get_all_blobs_corrupted() {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let pool = store.pool().clone();
        let cache = ResponseCache::new(pool.clone(), 3600);

        // Lengths 1, 3, 5, 7 are not multiples of 4; length 6 is also not.
        let corrupt_blobs: &[&[u8]] = &[
            &[0x01],
            &[0x01, 0x02, 0x03],
            &[0x01, 0x02, 0x03, 0x04, 0x05],
            &[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07],
            &[0x01, 0x02, 0x03, 0x04, 0x05, 0x06],
        ];
        for (i, blob) in corrupt_blobs.iter().enumerate() {
            insert_corrupt_blob(&pool, &format!("corrupt-{i}"), blob).await;
        }

        let result = cache
            .get_semantic(&[1.0, 0.0, 0.0], "model-a", 0.9, 10)
            .await
            .unwrap();
        assert!(result.is_none(), "all corrupt BLOBs must yield Ok(None)");
    }

    #[tokio::test]
    async fn test_semantic_get_dimension_mismatch_returns_none() {
        // Stored vector has 3 dims, query has 2; even a near-zero threshold
        // must not produce a hit.
        let cache = test_cache().await;
        cache
            .put_with_embedding("k1", "resp-3d", "m1", &[1.0, 0.0, 0.0], "model-a")
            .await
            .unwrap();
        let result = cache
            .get_semantic(&[1.0, 0.0], "model-a", 0.01, 10)
            .await
            .unwrap();
        assert!(
            result.is_none(),
            "dimension mismatch must not produce a hit"
        );
    }

    #[tokio::test]
    async fn test_semantic_get_dimension_mismatch_query_longer() {
        let cache = test_cache().await;
        cache
            .put_with_embedding("k1", "resp-2d", "m1", &[1.0, 0.0], "model-a")
            .await
            .unwrap();
        let result = cache
            .get_semantic(&[1.0, 0.0, 0.0], "model-a", 0.01, 10)
            .await
            .unwrap();
        assert!(
            result.is_none(),
            "query longer than stored embedding must not produce a hit"
        );
    }

    #[tokio::test]
    async fn test_semantic_get_mixed_dimensions_picks_correct_match() {
        // With both a 2d and a 3d entry present, only the dimension-matching
        // 3d entry may be returned for a 3d query.
        let cache = test_cache().await;
        cache
            .put_with_embedding("k-2d", "resp-2d", "m1", &[1.0, 0.0], "model-a")
            .await
            .unwrap();
        cache
            .put_with_embedding("k-3d", "resp-3d", "m1", &[1.0, 0.0, 0.0], "model-a")
            .await
            .unwrap();
        let result = cache
            .get_semantic(&[1.0, 0.0, 0.0], "model-a", 0.9, 10)
            .await
            .unwrap();
        assert!(result.is_some(), "matching dim=3 entry should be returned");
        let (response, score) = result.unwrap();
        assert_eq!(response, "resp-3d", "wrong entry returned");
        assert!(
            (score - 1.0).abs() < 1e-5,
            "expected score ~1.0 for identical vectors, got {score}"
        );
    }
}