// stateset_nsr/cache/local.rs
//! # Caching Layer for Performance Optimization
//!
//! Provides thread-safe caching for expensive computations:
//! - Embedding computations
//! - Inference results
//! - Knowledge base queries
//!
//! ## Features
//!
//! - **LRU Eviction**: Least-recently-used eviction policy
//! - **TTL Support**: Time-to-live for cache entries
//! - **Thread-Safe**: Sharded concurrent access via DashMap
//! - **Metrics**: Cache hit/miss statistics
//!
//! ## Example
//!
//! ```rust
//! use stateset_nsr::cache::{Cache, CacheConfig};
//!
//! let cache = Cache::<String, Vec<f32>>::new(CacheConfig {
//!     max_entries: 10000,
//!     ttl_seconds: 3600,
//!     ..Default::default()
//! });
//!
//! // Store embedding
//! cache.insert("hello".to_string(), vec![0.1, 0.2, 0.3]);
//!
//! // Retrieve (cache hit)
//! if let Some(embedding) = cache.get(&"hello".to_string()) {
//!     println!("Got cached embedding: {:?}", embedding);
//! }
//! ```

use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use std::hash::Hash;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
41
/// Cache configuration.
///
/// Derives `Serialize`/`Deserialize` so it can be loaded from
/// application configuration files.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheConfig {
    /// Maximum number of entries in the cache; inserting beyond this
    /// triggers LRU eviction.
    pub max_entries: usize,
    /// Time-to-live in seconds (0 = no expiration)
    pub ttl_seconds: u64,
    /// Whether to enable cache statistics
    // NOTE(review): this flag is not consulted anywhere in this file —
    // `Cache` maintains hit/miss counters unconditionally. Confirm intent.
    pub enable_stats: bool,
}
52
53impl Default for CacheConfig {
54    fn default() -> Self {
55        Self {
56            max_entries: 10000,
57            ttl_seconds: 3600, // 1 hour
58            enable_stats: true,
59        }
60    }
61}
62
/// Cache entry pairing a value with the bookkeeping metadata used for
/// TTL expiry and LRU eviction decisions.
struct CacheEntry<V> {
    // The cached value itself.
    value: V,
    // Insertion time; TTL expiry is measured from here (not from last access).
    created_at: Instant,
    // Last read time; used to pick the LRU victim.
    last_accessed: Instant,
    // Number of accesses; creation counts as the first access.
    access_count: u64,
}

// No `V: Clone` bound needed here: neither method clones the value.
// (The previous `impl<V: Clone>` bound was unnecessarily restrictive.)
impl<V> CacheEntry<V> {
    /// Wrap `value` in a fresh entry stamped with the current time.
    fn new(value: V) -> Self {
        let now = Instant::now();
        Self {
            value,
            created_at: now,
            last_accessed: now,
            access_count: 1,
        }
    }

    /// Whether this entry has outlived `ttl`.
    ///
    /// A zero `ttl` means "never expires", matching the
    /// `ttl_seconds = 0` convention in `CacheConfig`.
    fn is_expired(&self, ttl: Duration) -> bool {
        if ttl.is_zero() {
            return false;
        }
        self.created_at.elapsed() > ttl
    }
}
89
/// Cache statistics snapshot, as produced by `Cache::stats`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CacheStats {
    /// Number of cache hits
    pub hits: u64,
    /// Number of cache misses
    pub misses: u64,
    /// Number of entries currently in cache
    pub entries: usize,
    /// Number of evictions
    pub evictions: u64,
    /// Cache hit rate (0.0-1.0); computed as hits / (hits + misses),
    /// 0.0 when no lookups have happened yet.
    pub hit_rate: f64,
}
104
/// Thread-safe LRU cache with TTL support
///
/// Backed by a sharded concurrent map (`DashMap`); hit/miss/eviction
/// counters are plain atomics updated with relaxed ordering.
pub struct Cache<K, V> {
    // Entry storage keyed by K.
    data: DashMap<K, CacheEntry<V>>,
    // Configuration this cache was constructed with.
    config: CacheConfig,
    // config.ttl_seconds pre-converted to a Duration (zero = never expire).
    ttl: Duration,
    // Statistics counters (maintained unconditionally by the methods below).
    hits: AtomicU64,
    misses: AtomicU64,
    evictions: AtomicU64,
}
114
115impl<K, V> Cache<K, V>
116where
117    K: Eq + Hash + Clone,
118    V: Clone,
119{
120    /// Create a new cache with the given configuration
121    pub fn new(config: CacheConfig) -> Self {
122        let ttl = Duration::from_secs(config.ttl_seconds);
123        Self {
124            data: DashMap::with_capacity(config.max_entries),
125            config,
126            ttl,
127            hits: AtomicU64::new(0),
128            misses: AtomicU64::new(0),
129            evictions: AtomicU64::new(0),
130        }
131    }
132
133    /// Get a value from the cache
134    pub fn get(&self, key: &K) -> Option<V> {
135        if let Some(mut entry) = self.data.get_mut(key) {
136            // Check TTL
137            if entry.is_expired(self.ttl) {
138                drop(entry);
139                self.data.remove(key);
140                self.misses.fetch_add(1, Ordering::Relaxed);
141                return None;
142            }
143
144            // Update access metadata
145            entry.last_accessed = Instant::now();
146            entry.access_count += 1;
147            self.hits.fetch_add(1, Ordering::Relaxed);
148            Some(entry.value.clone())
149        } else {
150            self.misses.fetch_add(1, Ordering::Relaxed);
151            None
152        }
153    }
154
155    /// Insert a value into the cache
156    pub fn insert(&self, key: K, value: V) {
157        // Check if we need to evict
158        if self.data.len() >= self.config.max_entries {
159            self.evict_lru();
160        }
161
162        self.data.insert(key, CacheEntry::new(value));
163    }
164
165    /// Get or insert a value using a closure
166    pub fn get_or_insert_with<F>(&self, key: K, f: F) -> V
167    where
168        F: FnOnce() -> V,
169    {
170        if let Some(value) = self.get(&key) {
171            return value;
172        }
173
174        let value = f();
175        self.insert(key, value.clone());
176        value
177    }
178
179    /// Remove a value from the cache
180    pub fn remove(&self, key: &K) -> Option<V> {
181        self.data.remove(key).map(|(_, entry)| entry.value)
182    }
183
184    /// Clear all entries from the cache
185    pub fn clear(&self) {
186        self.data.clear();
187    }
188
189    /// Get cache statistics
190    pub fn stats(&self) -> CacheStats {
191        let hits = self.hits.load(Ordering::Relaxed);
192        let misses = self.misses.load(Ordering::Relaxed);
193        let total = hits + misses;
194
195        CacheStats {
196            hits,
197            misses,
198            entries: self.data.len(),
199            evictions: self.evictions.load(Ordering::Relaxed),
200            hit_rate: if total > 0 {
201                hits as f64 / total as f64
202            } else {
203                0.0
204            },
205        }
206    }
207
208    /// Check if cache contains a key
209    pub fn contains(&self, key: &K) -> bool {
210        if let Some(entry) = self.data.get(key) {
211            !entry.is_expired(self.ttl)
212        } else {
213            false
214        }
215    }
216
217    /// Get the number of entries in the cache
218    pub fn len(&self) -> usize {
219        self.data.len()
220    }
221
222    /// Check if cache is empty
223    pub fn is_empty(&self) -> bool {
224        self.data.is_empty()
225    }
226
227    /// Evict the least recently used entry
228    fn evict_lru(&self) {
229        // Find the LRU entry
230        let mut lru_key: Option<K> = None;
231        let mut oldest_access = Instant::now();
232
233        for entry in self.data.iter() {
234            if entry.last_accessed < oldest_access {
235                oldest_access = entry.last_accessed;
236                lru_key = Some(entry.key().clone());
237            }
238        }
239
240        // Remove LRU entry
241        if let Some(key) = lru_key {
242            self.data.remove(&key);
243            self.evictions.fetch_add(1, Ordering::Relaxed);
244        }
245    }
246
247    /// Remove all expired entries
248    pub fn cleanup_expired(&self) {
249        if self.ttl.is_zero() {
250            return;
251        }
252
253        let keys_to_remove: Vec<K> = self
254            .data
255            .iter()
256            .filter(|entry| entry.is_expired(self.ttl))
257            .map(|entry| entry.key().clone())
258            .collect();
259
260        for key in keys_to_remove {
261            self.data.remove(&key);
262            self.evictions.fetch_add(1, Ordering::Relaxed);
263        }
264    }
265}
266
/// Embedding cache specialized for text-to-embedding lookups
pub type EmbeddingCache = Cache<String, Vec<f32>>;

/// Inference result cache
pub type InferenceCache = Cache<String, String>;

/// Query result cache for knowledge base queries
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryResult {
    // Entity identifiers returned by the query.
    pub entities: Vec<String>,
    // Confidence score for the result.
    // NOTE(review): presumably in [0.0, 1.0] — confirm with producers.
    pub confidence: f64,
}

/// Cache mapping query strings to their knowledge-base results.
pub type QueryCache = Cache<String, QueryResult>;
281
282/// Create a shared embedding cache
283pub fn create_embedding_cache(max_size: usize) -> Arc<EmbeddingCache> {
284    Arc::new(EmbeddingCache::new(CacheConfig {
285        max_entries: max_size,
286        ttl_seconds: 86400, // 24 hours for embeddings
287        enable_stats: true,
288    }))
289}
290
291/// Create a shared inference cache
292pub fn create_inference_cache(max_size: usize) -> Arc<InferenceCache> {
293    Arc::new(InferenceCache::new(CacheConfig {
294        max_entries: max_size,
295        ttl_seconds: 300, // 5 minutes for inference results
296        enable_stats: true,
297    }))
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cache_basic_operations() {
        let cache: Cache<String, i32> = Cache::new(CacheConfig::default());

        // Round-trip one value and probe a key that was never stored.
        cache.insert("key1".to_string(), 42);
        assert_eq!(cache.get(&"key1".to_string()), Some(42));
        assert_eq!(cache.get(&"key2".to_string()), None);

        // One hit plus one miss => hit rate of exactly 0.5.
        let snapshot = cache.stats();
        assert_eq!(snapshot.hits, 1);
        assert_eq!(snapshot.misses, 1);
        assert!(snapshot.hit_rate > 0.4 && snapshot.hit_rate < 0.6);
    }

    #[test]
    fn test_cache_eviction() {
        let config = CacheConfig {
            max_entries: 3,
            ttl_seconds: 0,
            enable_stats: true,
        };
        let cache = Cache::<i32, i32>::new(config);

        // Fill the cache to capacity.
        for (k, v) in [(1, 100), (2, 200), (3, 300)] {
            cache.insert(k, v);
        }

        // Touch key 1 so it becomes the most recently used entry.
        cache.get(&1);

        // A fourth insert forces the LRU victim (key 2 or 3) out.
        cache.insert(4, 400);

        assert_eq!(cache.len(), 3);
        assert!(cache.contains(&1)); // recently accessed, must survive
        assert!(cache.contains(&4)); // freshly inserted
    }

    #[test]
    fn test_get_or_insert_with() {
        let cache = Cache::<String, i32>::new(CacheConfig::default());
        let mut computed = false;

        // A cold key runs the closure.
        let first = cache.get_or_insert_with("key".to_string(), || {
            computed = true;
            42
        });
        assert_eq!(first, 42);
        assert!(computed);

        // A warm key returns the cached value and skips the closure.
        computed = false;
        let second = cache.get_or_insert_with("key".to_string(), || {
            computed = true;
            99
        });
        assert_eq!(second, 42);
        assert!(!computed);
    }

    #[test]
    fn test_cache_ttl() {
        // ttl_seconds = 0 disables expiration entirely; exercising a real
        // expiry would require time manipulation.
        let cache = Cache::<String, i32>::new(CacheConfig {
            max_entries: 100,
            ttl_seconds: 0,
            enable_stats: true,
        });

        cache.insert("key".to_string(), 42);
        assert_eq!(cache.get(&"key".to_string()), Some(42));
    }
}
379}