synaptic_cache/
semantic.rs1use std::sync::Arc;
2
3use async_trait::async_trait;
4use synaptic_core::{ChatResponse, SynapseError};
5use synaptic_embeddings::Embeddings;
6use tokio::sync::RwLock;
7
8use crate::LlmCache;
9
10struct SemanticEntry {
11 embedding: Vec<f32>,
12 response: ChatResponse,
13}
14
15pub struct SemanticCache {
21 embeddings: Arc<dyn Embeddings>,
22 entries: RwLock<Vec<SemanticEntry>>,
23 similarity_threshold: f32,
24}
25
26impl SemanticCache {
27 pub fn new(embeddings: Arc<dyn Embeddings>, similarity_threshold: f32) -> Self {
32 Self {
33 embeddings,
34 entries: RwLock::new(Vec::new()),
35 similarity_threshold,
36 }
37 }
38}
39
40#[async_trait]
41impl LlmCache for SemanticCache {
42 async fn get(&self, key: &str) -> Result<Option<ChatResponse>, SynapseError> {
43 let query_embedding =
44 self.embeddings.embed_query(key).await.map_err(|e| {
45 SynapseError::Cache(format!("embedding error during cache get: {e}"))
46 })?;
47
48 let entries = self.entries.read().await;
49 let mut best_score = f32::NEG_INFINITY;
50 let mut best_response = None;
51
52 for entry in entries.iter() {
53 let score = cosine_similarity(&query_embedding, &entry.embedding);
54 if score >= self.similarity_threshold && score > best_score {
55 best_score = score;
56 best_response = Some(entry.response.clone());
57 }
58 }
59
60 Ok(best_response)
61 }
62
63 async fn put(&self, key: &str, response: &ChatResponse) -> Result<(), SynapseError> {
64 let embedding =
65 self.embeddings.embed_query(key).await.map_err(|e| {
66 SynapseError::Cache(format!("embedding error during cache put: {e}"))
67 })?;
68
69 let mut entries = self.entries.write().await;
70 entries.push(SemanticEntry {
71 embedding,
72 response: response.clone(),
73 });
74
75 Ok(())
76 }
77
78 async fn clear(&self) -> Result<(), SynapseError> {
79 let mut entries = self.entries.write().await;
80 entries.clear();
81 Ok(())
82 }
83}
84
/// Computes the cosine similarity between `a` and `b`.
///
/// Returns `0.0` when the slices differ in length, are empty, or when either
/// has zero magnitude. Otherwise the result lies in `[-1.0, 1.0]` for finite
/// inputs. Non-finite components (NaN or infinity) can produce a NaN result.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    // Accumulate in f64: squaring f32 components overflows to infinity for
    // large values and underflows to zero for tiny ones, which would turn a
    // perfectly valid pair of vectors into NaN or a spurious 0.0.
    let (dot, sq_a, sq_b) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, sq_a, sq_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, sq_a + x * x, sq_b + y * y)
        },
    );

    let denominator = sq_a.sqrt() * sq_b.sqrt();
    if denominator == 0.0 {
        return 0.0;
    }

    // Clamp away rounding drift just outside [-1, 1]; NaN passes through
    // unchanged. Narrowing back to f32 is the intended precision of the result.
    (dot / denominator).clamp(-1.0, 1.0) as f32
}
101
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance used when comparing against analytically known similarities.
    const EPSILON: f32 = 1e-6;

    #[test]
    fn cosine_similarity_identical_vectors() {
        let sim = cosine_similarity(&[1.0, 0.0, 0.0], &[1.0, 0.0, 0.0]);
        assert!((sim - 1.0).abs() < EPSILON);
    }

    #[test]
    fn cosine_similarity_orthogonal_vectors() {
        let sim = cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]);
        assert!(sim.abs() < EPSILON);
    }

    #[test]
    fn cosine_similarity_opposite_vectors() {
        // 3-4-5 triangle keeps every intermediate value exactly representable.
        let sim = cosine_similarity(&[3.0, 4.0], &[-3.0, -4.0]);
        assert!((sim + 1.0).abs() < EPSILON);
    }

    #[test]
    fn cosine_similarity_empty_vectors() {
        let empty: Vec<f32> = Vec::new();
        assert_eq!(cosine_similarity(&empty, &empty), 0.0);
    }

    #[test]
    fn cosine_similarity_mismatched_lengths() {
        assert_eq!(cosine_similarity(&[1.0, 2.0], &[1.0]), 0.0);
    }

    #[test]
    fn cosine_similarity_zero_magnitude() {
        assert_eq!(cosine_similarity(&[0.0, 0.0], &[1.0, 2.0]), 0.0);
        assert_eq!(cosine_similarity(&[1.0, 2.0], &[0.0, 0.0]), 0.0);
    }
}