// vex_router/cache/mod.rs
1//! String Similarity Caching - Cache responses using character-level hash similarity
2//!
3//! **Note:** Despite the historical naming, this cache uses DJB2-based character hashing
4//! (not neural embeddings) to compute similarity. For true semantic similarity,
5//! integrate an `EmbeddingProvider` via the optional field.
6
7use parking_lot::RwLock;
8use serde::{Deserialize, Serialize};
9use sha2::{Digest, Sha256};
10use std::collections::HashMap;
11use std::sync::Arc;
12
/// A cached response plus the metadata callers need to judge the hit.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CachedResponse {
    /// The cached response text.
    pub response: String,
    /// Cosine similarity between the lookup query and this entry.
    /// Stored as 1.0 on insert; overwritten with the actual score on `get`.
    pub similarity: f32,
    /// Unix timestamp (seconds, UTC) at which the entry was stored.
    pub cached_at: i64,
    /// Token count of the response, as supplied by the caller of `store`.
    pub token_count: u32,
}
20
/// Internal cache slot: the stored response together with the hash-based
/// pseudo-embedding of the query it was stored under.
#[derive(Debug, Clone)]
pub struct CacheEntry {
    pub response: CachedResponse,
    /// 64-dimensional, L2-normalized vector produced by `hash_based_embedding`.
    pub embedding: Vec<f32>,
}
26
/// In-memory response cache keyed by SHA-256 of the query text and matched
/// by cosine similarity over hash-based pseudo-embeddings (see module docs —
/// this is character hashing, not semantic embedding).
pub struct StringSimilarityCache {
    // Shared map behind a parking_lot RwLock; keys are hex-encoded SHA-256
    // digests of the original query text.
    entries: Arc<RwLock<HashMap<String, CacheEntry>>>,
    // Minimum cosine similarity for a lookup to count as a hit.
    similarity_threshold: f32,
    // Entry-count cap; the oldest entry is evicted when the cap is reached.
    max_cache_size: usize,
    // Per-entry lifetime in seconds.
    ttl_seconds: i64,
}
33
/// Backward-compatible alias retained from when this module was called a
/// "semantic cache"; see the module docs for why the similarity is
/// hash-based rather than truly semantic.
pub type SemanticCache = StringSimilarityCache;
36
37impl StringSimilarityCache {
38    pub fn new(similarity_threshold: f32, max_cache_size: usize, ttl_seconds: i64) -> Self {
39        Self {
40            entries: Arc::new(RwLock::new(HashMap::new())),
41            similarity_threshold,
42            max_cache_size,
43            ttl_seconds,
44        }
45    }
46
47    pub fn get(&self, query: &str) -> Option<CachedResponse> {
48        let query_embedding = self.compute_embedding(query);
49        let entries = self.entries.read();
50
51        let mut best_match: Option<(f32, &CacheEntry)> = None;
52
53        for (_key, entry) in entries.iter() {
54            let similarity = cosine_similarity(&query_embedding, &entry.embedding);
55
56            if similarity >= self.similarity_threshold
57                && (best_match.is_none() || similarity > best_match.as_ref().unwrap().0)
58            {
59                best_match = Some((similarity, entry));
60            }
61        }
62
63        if let Some((similarity, entry)) = best_match {
64            let now = chrono::Utc::now().timestamp();
65            if now - entry.response.cached_at < self.ttl_seconds {
66                let mut response = entry.response.clone();
67                response.similarity = similarity;
68                return Some(response);
69            }
70        }
71
72        None
73    }
74
75    pub fn store(&self, query: &str, response: String, token_count: u32) {
76        let key = self.compute_key(query);
77        let embedding = self.compute_embedding(query);
78
79        let mut entries = self.entries.write();
80
81        if entries.len() >= self.max_cache_size {
82            if let Some(oldest_key) = entries
83                .iter()
84                .min_by_key(|(_, e)| e.response.cached_at)
85                .map(|(k, _)| k.clone())
86            {
87                entries.remove(&oldest_key);
88            }
89        }
90
91        entries.insert(
92            key,
93            CacheEntry {
94                response: CachedResponse {
95                    response,
96                    similarity: 1.0,
97                    cached_at: chrono::Utc::now().timestamp(),
98                    token_count,
99                },
100                embedding,
101            },
102        );
103    }
104
105    fn compute_key(&self, query: &str) -> String {
106        let mut hasher = Sha256::new();
107        hasher.update(query.as_bytes());
108        hex::encode(hasher.finalize())
109    }
110
111    fn compute_embedding(&self, query: &str) -> Vec<f32> {
112        hash_based_embedding(query)
113    }
114
115    pub fn stats(&self) -> CacheStats {
116        let entries = self.entries.read();
117        let now = chrono::Utc::now().timestamp();
118
119        let valid_entries = entries
120            .values()
121            .filter(|e| now - e.response.cached_at < self.ttl_seconds)
122            .count();
123
124        CacheStats {
125            total_entries: entries.len(),
126            valid_entries,
127            cache_size_bytes: entries
128                .values()
129                .map(|e| e.response.response.len() + e.embedding.len() * 4)
130                .sum(),
131        }
132    }
133
134    pub fn clear(&self) {
135        let mut entries = self.entries.write();
136        entries.clear();
137    }
138}
139
/// Cosine similarity of two equal-length vectors.
///
/// Returns 0.0 for mismatched lengths, empty inputs, or when either vector
/// has zero norm (so callers never divide by zero).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    // Single pass accumulating dot product and both squared norms.
    let (mut dot, mut sq_a, mut sq_b) = (0.0f32, 0.0f32, 0.0f32);
    for (&x, &y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }

    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    dot / (norm_a * norm_b)
}
155
/// Compute a hash-based pseudo-embedding vector from text.
/// Uses DJB2 character hashing — not a neural embedding.
///
/// Every whitespace-separated (lowercased) word is folded into one of 64
/// buckets by position, and the result is L2-normalized. Returns the zero
/// vector for empty/whitespace-only input.
fn hash_based_embedding(text: &str) -> Vec<f32> {
    let text_lower = text.to_lowercase();
    let words: Vec<&str> = text_lower.split_whitespace().collect();

    let mut embedding = vec![0.0f32; 64];

    // Fold ALL words into the 64 buckets. The previous `.take(64)` silently
    // ignored words past the 64th, which made the `i % 64` wrap dead code
    // and meant long texts with differing tails hashed identically.
    // The 1/sqrt(len) scale cancels under normalization but is kept for
    // continuity with the original formula.
    for (i, word) in words.iter().enumerate() {
        let hash = djb2_hash(word);
        embedding[i % 64] += (hash as f32) / (words.len() as f32).sqrt();
    }

    // L2-normalize so cosine similarity reduces to a dot product.
    let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 0.0 {
        for x in &mut embedding {
            *x /= norm;
        }
    }

    embedding
}

/// DJB2 hash function for string hashing (hash = hash * 33 + byte,
/// seeded with 5381), with wrapping arithmetic to avoid overflow panics.
fn djb2_hash(s: &str) -> u32 {
    let mut hash: u32 = 5381;
    for c in s.bytes() {
        hash = hash.wrapping_mul(33).wrapping_add(c as u32);
    }
    hash
}
187
/// Snapshot of cache occupancy returned by `StringSimilarityCache::stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheStats {
    /// All entries currently held, including expired ones not yet evicted.
    pub total_entries: usize,
    /// Entries still within their TTL at the time of the call.
    pub valid_entries: usize,
    /// Approximate footprint: response bytes + 4 bytes per embedding element
    /// (struct overhead not counted).
    pub cache_size_bytes: usize,
}
194
195impl Default for StringSimilarityCache {
196    fn default() -> Self {
197        Self::new(0.85, 10000, 86400)
198    }
199}