synaptic_cache/
semantic.rs1use std::sync::Arc;
2
3use async_trait::async_trait;
4use synaptic_core::{ChatResponse, Embeddings, SynapticError};
5use tokio::sync::RwLock;
6
7use crate::LlmCache;
8
9struct SemanticEntry {
10 embedding: Vec<f32>,
11 response: ChatResponse,
12}
13
14pub struct SemanticCache {
20 embeddings: Arc<dyn Embeddings>,
21 entries: RwLock<Vec<SemanticEntry>>,
22 similarity_threshold: f32,
23}
24
25impl SemanticCache {
26 pub fn new(embeddings: Arc<dyn Embeddings>, similarity_threshold: f32) -> Self {
31 Self {
32 embeddings,
33 entries: RwLock::new(Vec::new()),
34 similarity_threshold,
35 }
36 }
37}
38
39#[async_trait]
40impl LlmCache for SemanticCache {
41 async fn get(&self, key: &str) -> Result<Option<ChatResponse>, SynapticError> {
42 let query_embedding =
43 self.embeddings.embed_query(key).await.map_err(|e| {
44 SynapticError::Cache(format!("embedding error during cache get: {e}"))
45 })?;
46
47 let entries = self.entries.read().await;
48 let mut best_score = f32::NEG_INFINITY;
49 let mut best_response = None;
50
51 for entry in entries.iter() {
52 let score = cosine_similarity(&query_embedding, &entry.embedding);
53 if score >= self.similarity_threshold && score > best_score {
54 best_score = score;
55 best_response = Some(entry.response.clone());
56 }
57 }
58
59 Ok(best_response)
60 }
61
62 async fn put(&self, key: &str, response: &ChatResponse) -> Result<(), SynapticError> {
63 let embedding =
64 self.embeddings.embed_query(key).await.map_err(|e| {
65 SynapticError::Cache(format!("embedding error during cache put: {e}"))
66 })?;
67
68 let mut entries = self.entries.write().await;
69 entries.push(SemanticEntry {
70 embedding,
71 response: response.clone(),
72 });
73
74 Ok(())
75 }
76
77 async fn clear(&self) -> Result<(), SynapticError> {
78 let mut entries = self.entries.write().await;
79 entries.clear();
80 Ok(())
81 }
82}
83
/// Computes the cosine similarity `a·b / (|a| |b|)` of two vectors.
///
/// Returns `0.0` instead of failing in three cases: the vectors differ in
/// length, they are empty, or either one has zero magnitude.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    // One pass that accumulates the dot product and both squared norms.
    // Each quantity is still summed left to right in index order.
    let (dot, norm_sq_a, norm_sq_b) = a.iter().zip(b.iter()).fold(
        (0.0_f32, 0.0_f32, 0.0_f32),
        |(dot, sa, sb), (&x, &y)| (dot + x * y, sa + x * x, sb + y * y),
    );

    let mag_a = norm_sq_a.sqrt();
    let mag_b = norm_sq_b.sqrt();

    if mag_a == 0.0 || mag_b == 0.0 {
        0.0
    } else {
        dot / (mag_a * mag_b)
    }
}
100
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cosine_similarity_identical_vectors() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        let sim = cosine_similarity(&a, &b);
        assert!((sim - 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_orthogonal_vectors() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        let sim = cosine_similarity(&a, &b);
        assert!(sim.abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_empty_vectors() {
        let a: Vec<f32> = vec![];
        let b: Vec<f32> = vec![];
        assert_eq!(cosine_similarity(&a, &b), 0.0);
    }

    #[test]
    fn cosine_similarity_opposite_vectors() {
        let sim = cosine_similarity(&[1.0, 0.0], &[-1.0, 0.0]);
        assert!((sim + 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_is_scale_invariant() {
        // Parallel vectors of different lengths are maximally similar.
        let sim = cosine_similarity(&[3.0, 4.0], &[6.0, 8.0]);
        assert!((sim - 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_mismatched_lengths() {
        assert_eq!(cosine_similarity(&[1.0, 2.0], &[1.0]), 0.0);
    }

    #[test]
    fn cosine_similarity_zero_magnitude() {
        assert_eq!(cosine_similarity(&[0.0, 0.0], &[1.0, 2.0]), 0.0);
        assert_eq!(cosine_similarity(&[1.0, 2.0], &[0.0, 0.0]), 0.0);
    }
}