Skip to main content

synaptic_cache/
semantic.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use synaptic_core::{ChatResponse, Embeddings, SynapticError};
5use tokio::sync::RwLock;
6
7use crate::LlmCache;
8
9struct SemanticEntry {
10    embedding: Vec<f32>,
11    response: ChatResponse,
12}
13
14/// Cache that uses embedding similarity to match semantically equivalent queries.
15///
16/// When a cache lookup is performed, the key is embedded and compared against all
17/// stored entries using cosine similarity. If any entry exceeds the similarity
18/// threshold, its cached response is returned.
19pub struct SemanticCache {
20    embeddings: Arc<dyn Embeddings>,
21    entries: RwLock<Vec<SemanticEntry>>,
22    similarity_threshold: f32,
23}
24
25impl SemanticCache {
26    /// Create a new semantic cache with the given embeddings provider and similarity threshold.
27    ///
28    /// The threshold should be between 0.0 and 1.0. A typical value is 0.95, meaning
29    /// only very similar queries will match.
30    pub fn new(embeddings: Arc<dyn Embeddings>, similarity_threshold: f32) -> Self {
31        Self {
32            embeddings,
33            entries: RwLock::new(Vec::new()),
34            similarity_threshold,
35        }
36    }
37}
38
39#[async_trait]
40impl LlmCache for SemanticCache {
41    async fn get(&self, key: &str) -> Result<Option<ChatResponse>, SynapticError> {
42        let query_embedding =
43            self.embeddings.embed_query(key).await.map_err(|e| {
44                SynapticError::Cache(format!("embedding error during cache get: {e}"))
45            })?;
46
47        let entries = self.entries.read().await;
48        let mut best_score = f32::NEG_INFINITY;
49        let mut best_response = None;
50
51        for entry in entries.iter() {
52            let score = cosine_similarity(&query_embedding, &entry.embedding);
53            if score >= self.similarity_threshold && score > best_score {
54                best_score = score;
55                best_response = Some(entry.response.clone());
56            }
57        }
58
59        Ok(best_response)
60    }
61
62    async fn put(&self, key: &str, response: &ChatResponse) -> Result<(), SynapticError> {
63        let embedding =
64            self.embeddings.embed_query(key).await.map_err(|e| {
65                SynapticError::Cache(format!("embedding error during cache put: {e}"))
66            })?;
67
68        let mut entries = self.entries.write().await;
69        entries.push(SemanticEntry {
70            embedding,
71            response: response.clone(),
72        });
73
74        Ok(())
75    }
76
77    async fn clear(&self) -> Result<(), SynapticError> {
78        let mut entries = self.entries.write().await;
79        entries.clear();
80        Ok(())
81    }
82}
83
/// Cosine similarity of two vectors.
///
/// Returns `0.0` when the inputs are empty, differ in length, or when either
/// vector has zero magnitude; otherwise returns the dot product divided by the
/// product of the two Euclidean norms.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // Euclidean (L2) norm of a vector.
    fn l2_norm(v: &[f32]) -> f32 {
        v.iter().map(|c| c * c).sum::<f32>().sqrt()
    }

    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    let (norm_a, norm_b) = (l2_norm(a), l2_norm(b));
    if norm_a == 0.0 || norm_b == 0.0 {
        // Avoid dividing by zero for all-zero vectors.
        return 0.0;
    }

    let dot_product = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
    dot_product / (norm_a * norm_b)
}
100
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance used when comparing floating-point similarity scores.
    const TOLERANCE: f32 = 1e-6;

    /// Vectors pointing in the same direction have similarity 1.
    #[test]
    fn cosine_similarity_identical_vectors() {
        let left = [1.0_f32, 0.0, 0.0];
        let right = [1.0_f32, 0.0, 0.0];
        let score = cosine_similarity(&left, &right);
        assert!((score - 1.0).abs() < TOLERANCE);
    }

    /// Perpendicular vectors have similarity 0.
    #[test]
    fn cosine_similarity_orthogonal_vectors() {
        let x_axis = [1.0_f32, 0.0];
        let y_axis = [0.0_f32, 1.0];
        let score = cosine_similarity(&x_axis, &y_axis);
        assert!(score.abs() < TOLERANCE);
    }

    /// Empty inputs fall back to 0 rather than dividing by zero.
    #[test]
    fn cosine_similarity_empty_vectors() {
        let empty: [f32; 0] = [];
        assert_eq!(cosine_similarity(&empty, &empty), 0.0);
    }
}