1use dashmap::DashMap;
36use serde::{Deserialize, Serialize};
37use std::hash::Hash;
38use std::sync::atomic::{AtomicU64, Ordering};
39use std::sync::Arc;
40use std::time::{Duration, Instant};
41
42#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct CacheConfig {
45 pub max_entries: usize,
47 pub ttl_seconds: u64,
49 pub enable_stats: bool,
51}
52
53impl Default for CacheConfig {
54 fn default() -> Self {
55 Self {
56 max_entries: 10000,
57 ttl_seconds: 3600, enable_stats: true,
59 }
60 }
61}
62
/// A stored value plus the bookkeeping used for expiry and LRU eviction.
struct CacheEntry<V> {
    /// The cached payload.
    value: V,
    /// When the entry was inserted; expiry is measured from this instant.
    created_at: Instant,
    /// When the entry was last inserted or successfully read; drives LRU eviction.
    last_accessed: Instant,
    /// Number of inserts plus successful reads of this entry.
    access_count: u64,
}
70
71impl<V: Clone> CacheEntry<V> {
72 fn new(value: V) -> Self {
73 let now = Instant::now();
74 Self {
75 value,
76 created_at: now,
77 last_accessed: now,
78 access_count: 1,
79 }
80 }
81
82 fn is_expired(&self, ttl: Duration) -> bool {
83 if ttl.is_zero() {
84 return false;
85 }
86 self.created_at.elapsed() > ttl
87 }
88}
89
90#[derive(Debug, Clone, Default, Serialize, Deserialize)]
92pub struct CacheStats {
93 pub hits: u64,
95 pub misses: u64,
97 pub entries: usize,
99 pub evictions: u64,
101 pub hit_rate: f64,
103}
104
105pub struct Cache<K, V> {
107 data: DashMap<K, CacheEntry<V>>,
108 config: CacheConfig,
109 ttl: Duration,
110 hits: AtomicU64,
111 misses: AtomicU64,
112 evictions: AtomicU64,
113}
114
115impl<K, V> Cache<K, V>
116where
117 K: Eq + Hash + Clone,
118 V: Clone,
119{
120 pub fn new(config: CacheConfig) -> Self {
122 let ttl = Duration::from_secs(config.ttl_seconds);
123 Self {
124 data: DashMap::with_capacity(config.max_entries),
125 config,
126 ttl,
127 hits: AtomicU64::new(0),
128 misses: AtomicU64::new(0),
129 evictions: AtomicU64::new(0),
130 }
131 }
132
133 pub fn get(&self, key: &K) -> Option<V> {
135 if let Some(mut entry) = self.data.get_mut(key) {
136 if entry.is_expired(self.ttl) {
138 drop(entry);
139 self.data.remove(key);
140 self.misses.fetch_add(1, Ordering::Relaxed);
141 return None;
142 }
143
144 entry.last_accessed = Instant::now();
146 entry.access_count += 1;
147 self.hits.fetch_add(1, Ordering::Relaxed);
148 Some(entry.value.clone())
149 } else {
150 self.misses.fetch_add(1, Ordering::Relaxed);
151 None
152 }
153 }
154
155 pub fn insert(&self, key: K, value: V) {
157 if self.data.len() >= self.config.max_entries {
159 self.evict_lru();
160 }
161
162 self.data.insert(key, CacheEntry::new(value));
163 }
164
165 pub fn get_or_insert_with<F>(&self, key: K, f: F) -> V
167 where
168 F: FnOnce() -> V,
169 {
170 if let Some(value) = self.get(&key) {
171 return value;
172 }
173
174 let value = f();
175 self.insert(key, value.clone());
176 value
177 }
178
179 pub fn remove(&self, key: &K) -> Option<V> {
181 self.data.remove(key).map(|(_, entry)| entry.value)
182 }
183
184 pub fn clear(&self) {
186 self.data.clear();
187 }
188
189 pub fn stats(&self) -> CacheStats {
191 let hits = self.hits.load(Ordering::Relaxed);
192 let misses = self.misses.load(Ordering::Relaxed);
193 let total = hits + misses;
194
195 CacheStats {
196 hits,
197 misses,
198 entries: self.data.len(),
199 evictions: self.evictions.load(Ordering::Relaxed),
200 hit_rate: if total > 0 {
201 hits as f64 / total as f64
202 } else {
203 0.0
204 },
205 }
206 }
207
208 pub fn contains(&self, key: &K) -> bool {
210 if let Some(entry) = self.data.get(key) {
211 !entry.is_expired(self.ttl)
212 } else {
213 false
214 }
215 }
216
217 pub fn len(&self) -> usize {
219 self.data.len()
220 }
221
222 pub fn is_empty(&self) -> bool {
224 self.data.is_empty()
225 }
226
227 fn evict_lru(&self) {
229 let mut lru_key: Option<K> = None;
231 let mut oldest_access = Instant::now();
232
233 for entry in self.data.iter() {
234 if entry.last_accessed < oldest_access {
235 oldest_access = entry.last_accessed;
236 lru_key = Some(entry.key().clone());
237 }
238 }
239
240 if let Some(key) = lru_key {
242 self.data.remove(&key);
243 self.evictions.fetch_add(1, Ordering::Relaxed);
244 }
245 }
246
247 pub fn cleanup_expired(&self) {
249 if self.ttl.is_zero() {
250 return;
251 }
252
253 let keys_to_remove: Vec<K> = self
254 .data
255 .iter()
256 .filter(|entry| entry.is_expired(self.ttl))
257 .map(|entry| entry.key().clone())
258 .collect();
259
260 for key in keys_to_remove {
261 self.data.remove(&key);
262 self.evictions.fetch_add(1, Ordering::Relaxed);
263 }
264 }
265}
266
267pub type EmbeddingCache = Cache<String, Vec<f32>>;
269
270pub type InferenceCache = Cache<String, String>;
272
273#[derive(Debug, Clone, Serialize, Deserialize)]
275pub struct QueryResult {
276 pub entities: Vec<String>,
277 pub confidence: f64,
278}
279
280pub type QueryCache = Cache<String, QueryResult>;
281
282pub fn create_embedding_cache(max_size: usize) -> Arc<EmbeddingCache> {
284 Arc::new(EmbeddingCache::new(CacheConfig {
285 max_entries: max_size,
286 ttl_seconds: 86400, enable_stats: true,
288 }))
289}
290
291pub fn create_inference_cache(max_size: usize) -> Arc<InferenceCache> {
293 Arc::new(InferenceCache::new(CacheConfig {
294 max_entries: max_size,
295 ttl_seconds: 300, enable_stats: true,
297 }))
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building owned `String` keys.
    fn key(text: &str) -> String {
        text.to_string()
    }

    #[test]
    fn test_cache_basic_operations() {
        let cache: Cache<String, i32> = Cache::new(CacheConfig::default());
        cache.insert(key("key1"), 42);

        // One hit followed by one miss.
        assert_eq!(cache.get(&key("key1")), Some(42));
        assert_eq!(cache.get(&key("key2")), None);

        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert!((stats.hit_rate - 0.5).abs() < 0.1);
    }

    #[test]
    fn test_cache_eviction() {
        let config = CacheConfig {
            max_entries: 3,
            ttl_seconds: 0,
            enable_stats: true,
        };
        let cache: Cache<i32, i32> = Cache::new(config);

        for k in 1..=3 {
            cache.insert(k, k * 100);
        }

        // Touch key 1 so it is no longer the least recently used entry.
        let _ = cache.get(&1);

        cache.insert(4, 400);

        assert_eq!(cache.len(), 3);
        assert!(cache.contains(&1));
        assert!(cache.contains(&4));
    }

    #[test]
    fn test_get_or_insert_with() {
        let cache: Cache<String, i32> = Cache::new(CacheConfig::default());
        let mut calls = 0;

        let first = cache.get_or_insert_with(key("key"), || {
            calls += 1;
            42
        });
        assert_eq!(first, 42);
        assert_eq!(calls, 1);

        // The second lookup is served from the cache, so the closure must not run.
        let second = cache.get_or_insert_with(key("key"), || {
            calls += 1;
            99
        });
        assert_eq!(second, 42);
        assert_eq!(calls, 1);
    }

    #[test]
    fn test_cache_ttl() {
        // A TTL of zero disables expiry, so the value stays readable.
        let cache: Cache<String, i32> = Cache::new(CacheConfig {
            max_entries: 100,
            ttl_seconds: 0,
            enable_stats: true,
        });

        cache.insert(key("key"), 42);
        assert_eq!(cache.get(&key("key")), Some(42));
    }
}