1use parking_lot::RwLock;
8use serde::{Deserialize, Serialize};
9use sha2::{Digest, Sha256};
10use std::collections::HashMap;
11use std::sync::Arc;
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct CachedResponse {
15 pub response: String,
16 pub similarity: f32,
17 pub cached_at: i64,
18 pub token_count: u32,
19}
20
21#[derive(Debug, Clone)]
22pub struct CacheEntry {
23 pub response: CachedResponse,
24 pub embedding: Vec<f32>,
25}
26
27pub struct StringSimilarityCache {
28 entries: Arc<RwLock<HashMap<String, CacheEntry>>>,
29 similarity_threshold: f32,
30 max_cache_size: usize,
31 ttl_seconds: i64,
32}
33
34pub type SemanticCache = StringSimilarityCache;
36
37impl StringSimilarityCache {
38 pub fn new(similarity_threshold: f32, max_cache_size: usize, ttl_seconds: i64) -> Self {
39 Self {
40 entries: Arc::new(RwLock::new(HashMap::new())),
41 similarity_threshold,
42 max_cache_size,
43 ttl_seconds,
44 }
45 }
46
47 pub fn get(&self, query: &str) -> Option<CachedResponse> {
48 let query_embedding = self.compute_embedding(query);
49 let entries = self.entries.read();
50
51 let mut best_match: Option<(f32, &CacheEntry)> = None;
52
53 for (_key, entry) in entries.iter() {
54 let similarity = cosine_similarity(&query_embedding, &entry.embedding);
55
56 if similarity >= self.similarity_threshold
57 && (best_match.is_none() || similarity > best_match.as_ref().unwrap().0)
58 {
59 best_match = Some((similarity, entry));
60 }
61 }
62
63 if let Some((similarity, entry)) = best_match {
64 let now = chrono::Utc::now().timestamp();
65 if now - entry.response.cached_at < self.ttl_seconds {
66 let mut response = entry.response.clone();
67 response.similarity = similarity;
68 return Some(response);
69 }
70 }
71
72 None
73 }
74
75 pub fn store(&self, query: &str, response: String, token_count: u32) {
76 let key = self.compute_key(query);
77 let embedding = self.compute_embedding(query);
78
79 let mut entries = self.entries.write();
80
81 if entries.len() >= self.max_cache_size {
82 if let Some(oldest_key) = entries
83 .iter()
84 .min_by_key(|(_, e)| e.response.cached_at)
85 .map(|(k, _)| k.clone())
86 {
87 entries.remove(&oldest_key);
88 }
89 }
90
91 entries.insert(
92 key,
93 CacheEntry {
94 response: CachedResponse {
95 response,
96 similarity: 1.0,
97 cached_at: chrono::Utc::now().timestamp(),
98 token_count,
99 },
100 embedding,
101 },
102 );
103 }
104
105 fn compute_key(&self, query: &str) -> String {
106 let mut hasher = Sha256::new();
107 hasher.update(query.as_bytes());
108 hex::encode(hasher.finalize())
109 }
110
111 fn compute_embedding(&self, query: &str) -> Vec<f32> {
112 hash_based_embedding(query)
113 }
114
115 pub fn stats(&self) -> CacheStats {
116 let entries = self.entries.read();
117 let now = chrono::Utc::now().timestamp();
118
119 let valid_entries = entries
120 .values()
121 .filter(|e| now - e.response.cached_at < self.ttl_seconds)
122 .count();
123
124 CacheStats {
125 total_entries: entries.len(),
126 valid_entries,
127 cache_size_bytes: entries
128 .values()
129 .map(|e| e.response.response.len() + e.embedding.len() * 4)
130 .sum(),
131 }
132 }
133
134 pub fn clear(&self) {
135 let mut entries = self.entries.write();
136 entries.clear();
137 }
138}
139
/// Computes the cosine similarity of two vectors.
///
/// Returns `0.0` when the slices differ in length, are empty, or when either
/// one has zero magnitude; otherwise returns `dot(a, b) / (|a| * |b|)`.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // Vectors of different (or zero) dimensionality are treated as unrelated.
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    let magnitude = |v: &[f32]| v.iter().map(|c| c * c).sum::<f32>().sqrt();
    let (mag_a, mag_b) = (magnitude(a), magnitude(b));

    // A zero vector has no direction; avoid dividing by zero.
    if mag_a == 0.0 || mag_b == 0.0 {
        return 0.0;
    }

    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    dot / (mag_a * mag_b)
}
155
/// Builds a 64-dimensional, L2-normalised bag-of-words vector for `text`.
///
/// The text is lowercased and split on whitespace; every word increments the
/// bucket selected by its hash (the "hashing trick"). Texts therefore score
/// high under cosine similarity when they share words and low when they do
/// not. All words contribute, however long the text is.
///
/// This is a lexical approximation, not a semantic embedding: word order is
/// ignored, and two different words can land in the same bucket. Text without
/// any words yields the all-zero vector.
fn hash_based_embedding(text: &str) -> Vec<f32> {
    const BUCKET_BITS: u32 = 6;
    const DIMENSIONS: usize = 1 << BUCKET_BITS;
    // 2^32 / golden ratio. Multiplying by it and keeping the top bits
    // (Fibonacci hashing) spreads djb2's weakly mixed low bits over the buckets.
    const FIBONACCI_MULTIPLIER: u32 = 0x9E37_79B9;

    let text_lower = text.to_lowercase();
    let mut embedding = vec![0.0f32; DIMENSIONS];

    for word in text_lower.split_whitespace() {
        let mixed = djb2_hash(word).wrapping_mul(FIBONACCI_MULTIPLIER);
        // The top BUCKET_BITS bits are always < DIMENSIONS, so indexing is in bounds.
        let bucket = (mixed >> (u32::BITS - BUCKET_BITS)) as usize;
        embedding[bucket] += 1.0;
    }

    let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 0.0 {
        for component in &mut embedding {
            *component /= norm;
        }
    }

    embedding
}

/// Computes the 32-bit djb2 hash of `s` (`hash = hash * 33 + byte`, seeded
/// with 5381), wrapping on overflow.
fn djb2_hash(s: &str) -> u32 {
    s.bytes().fold(5381u32, |hash, byte| {
        hash.wrapping_mul(33).wrapping_add(u32::from(byte))
    })
}
187
188#[derive(Debug, Clone, Serialize, Deserialize)]
189pub struct CacheStats {
190 pub total_entries: usize,
191 pub valid_entries: usize,
192 pub cache_size_bytes: usize,
193}
194
195impl Default for StringSimilarityCache {
196 fn default() -> Self {
197 Self::new(0.85, 10000, 86400)
198 }
199}