1use std::sync::{Arc, RwLock};
6use std::collections::HashMap;
7use std::time::{SystemTime, UNIX_EPOCH};
8use serde::{Deserialize, Serialize};
9use crate::error::Result;
10use crate::config::CacheConfig;
11
/// A recognition result stored in, and handed back by, [`CacheManager`].
///
/// The cache stores and returns this value as-is (cloned on a hit). It keys
/// entries by its own hash of the image bytes and does not read or update any
/// of the fields below.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CachedResult {
    /// LaTeX markup associated with the cached image.
    pub latex: String,

    /// Additional representations keyed by name. Key semantics are defined by
    /// whoever produces the result (presumably an output-format name — TODO
    /// confirm against the producer).
    pub alternatives: HashMap<String, String>,

    /// Confidence score reported for `latex`; no range is enforced here.
    pub confidence: f32,

    /// Caller-supplied timestamp (unit not established in this file).
    pub timestamp: u64,

    /// Caller-maintained access counter; cache lookups do not increment it.
    pub access_count: usize,

    /// Caller-supplied image hash; the cache does not use it for keying.
    pub image_hash: String,
}
33
/// Internal cache slot: a stored result plus the data needed to match and
/// expire it.
#[derive(Debug, Clone)]
struct CacheEntry {
    /// Vector compared against a query's embedding with cosine similarity
    /// during the fuzzy part of a lookup.
    embedding: Vec<f32>,

    /// The value returned (as a clone) on a hit.
    result: CachedResult,

    /// Seconds since the Unix epoch at which the entry was stored. Lookups do
    /// not refresh it, so TTL expiry is measured from the most recent store.
    last_access: u64,
}
46
/// Result cache keyed by a hash of the raw image bytes, with a
/// cosine-similarity fallback, TTL expiry, and LRU eviction at capacity.
///
/// All mutable state sits behind `RwLock`s, so every method takes `&self`.
pub struct CacheManager {
    /// Behaviour switches and limits: enabled flag, capacity, TTL, similarity
    /// threshold, and embedding dimension.
    config: CacheConfig,

    /// Live entries keyed by the lowercase-hex hash of the image bytes.
    entries: Arc<RwLock<HashMap<String, CacheEntry>>>,

    /// Keys ordered from least recently used (front) to most recently used
    /// (back); eviction removes the front element.
    lru_order: Arc<RwLock<Vec<String>>>,

    /// Activity counters, returned as a snapshot by `stats()`.
    stats: Arc<RwLock<CacheStats>>,
}
61
/// Counters describing cache activity, as returned by `CacheManager::stats`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CacheStats {
    /// Lookups answered from the cache (exact or similarity match).
    pub hits: u64,

    /// Lookups on an enabled cache that found no usable entry.
    pub misses: u64,

    /// Number of entries after the most recent store, clear, or cleanup.
    pub entries: usize,

    /// Entries removed to make room at capacity; TTL cleanup is not counted.
    pub evictions: u64,

    /// Running average of match similarity, maintained by `CacheManager`
    /// when it records hits.
    pub avg_similarity: f32,
}
80
81impl CacheStats {
82 pub fn hit_rate(&self) -> f32 {
84 if self.hits + self.misses == 0 {
85 return 0.0;
86 }
87 self.hits as f32 / (self.hits + self.misses) as f32
88 }
89}
90
91impl CacheManager {
92 pub fn new(config: CacheConfig) -> Self {
116 Self {
117 config,
118 entries: Arc::new(RwLock::new(HashMap::new())),
119 lru_order: Arc::new(RwLock::new(Vec::new())),
120 stats: Arc::new(RwLock::new(CacheStats::default())),
121 }
122 }
123
124 fn generate_embedding(&self, image_data: &[u8]) -> Result<Vec<f32>> {
128 let hash = self.hash_image(image_data);
131 let mut embedding = vec![0.0; self.config.vector_dimension];
132
133 for (i, byte) in hash.as_bytes().iter().enumerate() {
134 if i < embedding.len() {
135 embedding[i] = *byte as f32 / 255.0;
136 }
137 }
138
139 Ok(embedding)
140 }
141
142 fn hash_image(&self, image_data: &[u8]) -> String {
144 use std::collections::hash_map::DefaultHasher;
145 use std::hash::{Hash, Hasher};
146
147 let mut hasher = DefaultHasher::new();
148 image_data.hash(&mut hasher);
149 format!("{:x}", hasher.finish())
150 }
151
152 fn cosine_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
154 if a.len() != b.len() {
155 return 0.0;
156 }
157
158 let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
159 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
160 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
161
162 if norm_a == 0.0 || norm_b == 0.0 {
163 return 0.0;
164 }
165
166 dot_product / (norm_a * norm_b)
167 }
168
169 pub fn lookup(&self, image_data: &[u8]) -> Result<Option<CachedResult>> {
179 if !self.config.enabled {
180 return Ok(None);
181 }
182
183 let embedding = self.generate_embedding(image_data)?;
184 let hash = self.hash_image(image_data);
185
186 let entries = self.entries.read().unwrap();
187
188 if let Some(entry) = entries.get(&hash) {
190 if !self.is_expired(&entry) {
191 self.record_hit();
192 self.update_lru(&hash);
193 return Ok(Some(entry.result.clone()));
194 }
195 }
196
197 let mut best_match: Option<(String, f32, CachedResult)> = None;
199
200 for (key, entry) in entries.iter() {
201 if self.is_expired(entry) {
202 continue;
203 }
204
205 let similarity = self.cosine_similarity(&embedding, &entry.embedding);
206
207 if similarity >= self.config.similarity_threshold {
208 if best_match.is_none() || similarity > best_match.as_ref().unwrap().1 {
209 best_match = Some((key.clone(), similarity, entry.result.clone()));
210 }
211 }
212 }
213
214 if let Some((key, similarity, result)) = best_match {
215 self.record_hit_with_similarity(similarity);
216 self.update_lru(&key);
217 Ok(Some(result))
218 } else {
219 self.record_miss();
220 Ok(None)
221 }
222 }
223
224 pub fn store(&self, image_data: &[u8], result: CachedResult) -> Result<()> {
231 if !self.config.enabled {
232 return Ok(());
233 }
234
235 let embedding = self.generate_embedding(image_data)?;
236 let hash = self.hash_image(image_data);
237
238 let entry = CacheEntry {
239 embedding,
240 result,
241 last_access: self.current_timestamp(),
242 };
243
244 let mut entries = self.entries.write().unwrap();
245
246 if entries.len() >= self.config.capacity && !entries.contains_key(&hash) {
248 self.evict_lru(&mut entries);
249 }
250
251 entries.insert(hash.clone(), entry);
252 self.update_lru(&hash);
253 self.update_stats_entries(entries.len());
254
255 Ok(())
256 }
257
258 fn is_expired(&self, entry: &CacheEntry) -> bool {
260 let current = self.current_timestamp();
261 current - entry.last_access > self.config.ttl
262 }
263
264 fn current_timestamp(&self) -> u64 {
266 SystemTime::now()
267 .duration_since(UNIX_EPOCH)
268 .unwrap()
269 .as_secs()
270 }
271
272 fn evict_lru(&self, entries: &mut HashMap<String, CacheEntry>) {
274 let mut lru = self.lru_order.write().unwrap();
275
276 if let Some(key) = lru.first() {
277 entries.remove(key);
278 lru.remove(0);
279 self.record_eviction();
280 }
281 }
282
283 fn update_lru(&self, key: &str) {
285 let mut lru = self.lru_order.write().unwrap();
286 lru.retain(|k| k != key);
287 lru.push(key.to_string());
288 }
289
290 fn record_hit(&self) {
292 let mut stats = self.stats.write().unwrap();
293 stats.hits += 1;
294 }
295
296 fn record_hit_with_similarity(&self, similarity: f32) {
298 let mut stats = self.stats.write().unwrap();
299 stats.hits += 1;
300
301 let total = stats.hits as f32;
303 stats.avg_similarity = (stats.avg_similarity * (total - 1.0) + similarity) / total;
304 }
305
306 fn record_miss(&self) {
308 let mut stats = self.stats.write().unwrap();
309 stats.misses += 1;
310 }
311
312 fn record_eviction(&self) {
314 let mut stats = self.stats.write().unwrap();
315 stats.evictions += 1;
316 }
317
318 fn update_stats_entries(&self, count: usize) {
320 let mut stats = self.stats.write().unwrap();
321 stats.entries = count;
322 }
323
324 pub fn stats(&self) -> CacheStats {
326 self.stats.read().unwrap().clone()
327 }
328
329 pub fn clear(&self) {
331 let mut entries = self.entries.write().unwrap();
332 let mut lru = self.lru_order.write().unwrap();
333
334 entries.clear();
335 lru.clear();
336
337 self.update_stats_entries(0);
338 }
339
340 pub fn cleanup(&self) {
342 let mut entries = self.entries.write().unwrap();
343 let mut lru = self.lru_order.write().unwrap();
344
345 let expired: Vec<String> = entries
346 .iter()
347 .filter(|(_, entry)| self.is_expired(entry))
348 .map(|(key, _)| key.clone())
349 .collect();
350
351 for key in &expired {
352 entries.remove(key);
353 lru.retain(|k| k != key);
354 }
355
356 self.update_stats_entries(entries.len());
357 }
358}
359
#[cfg(test)]
mod tests {
    use super::*;

    /// Enabled configuration with room for 100 entries and a one-hour TTL.
    fn test_config() -> CacheConfig {
        CacheConfig {
            enabled: true,
            capacity: 100,
            similarity_threshold: 0.95,
            ttl: 3600,
            vector_dimension: 128,
            persistent: false,
            cache_dir: ".cache/test".to_string(),
        }
    }

    /// A representative result to store.
    fn test_result() -> CachedResult {
        CachedResult {
            latex: r"\frac{x^2}{2}".to_string(),
            alternatives: HashMap::new(),
            confidence: 0.95,
            timestamp: 0,
            access_count: 0,
            image_hash: "test".to_string(),
        }
    }

    #[test]
    fn test_cache_creation() {
        let config = test_config();
        let cache = CacheManager::new(config);
        assert_eq!(cache.stats().hits, 0);
        assert_eq!(cache.stats().misses, 0);
    }

    #[test]
    fn test_store_and_lookup() {
        let config = test_config();
        let cache = CacheManager::new(config);

        let image_data = b"test image data";
        let result = test_result();

        cache.store(image_data, result.clone()).unwrap();

        let lookup_result = cache.lookup(image_data).unwrap();
        assert!(lookup_result.is_some());
        assert_eq!(lookup_result.unwrap().latex, result.latex);
    }

    #[test]
    fn test_cache_miss() {
        let config = test_config();
        let cache = CacheManager::new(config);

        let image_data = b"nonexistent image";
        let lookup_result = cache.lookup(image_data).unwrap();

        assert!(lookup_result.is_none());
        assert_eq!(cache.stats().misses, 1);
    }

    #[test]
    fn test_cache_hit_rate() {
        let config = test_config();
        let cache = CacheManager::new(config);

        let image_data = b"test image";
        let result = test_result();

        cache.store(image_data, result).unwrap();

        // Two exact hits followed by one miss.
        cache.lookup(image_data).unwrap();
        cache.lookup(image_data).unwrap();
        cache.lookup(b"different image").unwrap();

        let stats = cache.stats();
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 1);
        assert!((stats.hit_rate() - 2.0 / 3.0).abs() < 0.01);
    }

    #[test]
    fn test_cosine_similarity() {
        let config = test_config();
        let cache = CacheManager::new(config);

        let vec_a = vec![1.0, 0.0, 0.0];
        let vec_b = vec![1.0, 0.0, 0.0];
        let vec_c = vec![0.0, 1.0, 0.0];
        let opposite = vec![-1.0, 0.0, 0.0];

        assert!((cache.cosine_similarity(&vec_a, &vec_b) - 1.0).abs() < 0.01);
        assert!((cache.cosine_similarity(&vec_a, &vec_c) - 0.0).abs() < 0.01);
        assert!((cache.cosine_similarity(&vec_a, &opposite) + 1.0).abs() < 0.01);

        // Mismatched lengths and zero vectors are defined as 0.0.
        assert_eq!(cache.cosine_similarity(&vec_a, &[1.0, 0.0]), 0.0);
        assert_eq!(cache.cosine_similarity(&vec_a, &[0.0, 0.0, 0.0]), 0.0);
    }

    #[test]
    fn test_lru_eviction() {
        let mut config = test_config();
        config.capacity = 2;
        let cache = CacheManager::new(config);

        cache.store(b"first", test_result()).unwrap();
        cache.store(b"second", test_result()).unwrap();
        // Touch "first" so that "second" becomes the least recently used.
        assert!(cache.lookup(b"first").unwrap().is_some());
        cache.store(b"third", test_result()).unwrap();

        let stats = cache.stats();
        assert_eq!(stats.entries, 2);
        assert_eq!(stats.evictions, 1);

        {
            let entries = cache.entries.read().unwrap();
            assert!(entries.contains_key(&cache.hash_image(b"first")));
            assert!(!entries.contains_key(&cache.hash_image(b"second")));
            assert!(entries.contains_key(&cache.hash_image(b"third")));
        }

        // Re-storing an image that is already cached replaces it in place.
        cache.store(b"third", test_result()).unwrap();
        assert_eq!(cache.stats().evictions, 1);
        assert_eq!(cache.stats().entries, 2);
    }

    #[test]
    fn test_cleanup_removes_expired_entries() {
        let cache = CacheManager::new(test_config());
        cache.store(b"fresh image", test_result()).unwrap();

        // Plant an entry stamped at the Unix epoch: far older than the TTL.
        {
            let mut entries = cache.entries.write().unwrap();
            let mut lru = cache.lru_order.write().unwrap();
            entries.insert(
                "stale".to_string(),
                CacheEntry {
                    embedding: vec![0.0; 128],
                    result: test_result(),
                    last_access: 0,
                },
            );
            lru.push("stale".to_string());
        }

        cache.cleanup();

        assert_eq!(cache.stats().entries, 1);
        assert!(!cache.entries.read().unwrap().contains_key("stale"));
        assert!(cache
            .lru_order
            .read()
            .unwrap()
            .iter()
            .all(|key| key != "stale"));
    }

    #[test]
    fn test_cache_clear() {
        let config = test_config();
        let cache = CacheManager::new(config);

        let image_data = b"test image";
        let result = test_result();

        cache.store(image_data, result).unwrap();
        assert_eq!(cache.stats().entries, 1);

        cache.clear();
        assert_eq!(cache.stats().entries, 0);
    }

    #[test]
    fn test_disabled_cache() {
        let mut config = test_config();
        config.enabled = false;
        let cache = CacheManager::new(config);

        let image_data = b"test image";
        let result = test_result();

        cache.store(image_data, result).unwrap();
        let lookup_result = cache.lookup(image_data).unwrap();

        assert!(lookup_result.is_none());
        // A disabled cache neither stores entries nor counts lookups.
        assert_eq!(cache.stats().entries, 0);
        assert_eq!(cache.stats().misses, 0);
    }
}