// sqry_nl/cache/mod.rs

//! Translation cache for repeated queries.
//!
//! Provides a thread-safe LRU cache with context-aware keys ([`CacheKey`]) and TTL-based expiration.
4
5mod key;
6
7pub use key::CacheKey;
8
use crate::types::Intent;
use lru::LruCache;
use parking_lot::Mutex;
use std::num::NonZeroUsize;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
14
/// Cache configuration.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CacheConfig {
    /// Time-to-live for cache entries, measured from each entry's `created_at`.
    pub ttl: Duration,
    /// Maximum number of cached entries. `TranslationCache` treats 0 as 1,
    /// because its LRU store requires a non-zero capacity.
    pub capacity: usize,
}
23
24impl Default for CacheConfig {
25    fn default() -> Self {
26        Self {
27            ttl: Duration::from_secs(3600), // 1 hour
28            capacity: 1000,
29        }
30    }
31}
32
/// Cached translation result.
///
/// Stored via [`TranslationCache::put`] and handed back as a clone by
/// [`TranslationCache::get`] while it is younger than the configured TTL.
#[derive(Debug, Clone)]
pub struct CachedResult {
    /// The translated command
    pub command: String,
    /// The classified intent
    pub intent: Intent,
    /// The classifier confidence (not range-checked by the cache; tests use 0.95)
    pub confidence: f32,
    /// When this entry was created. The cache measures the entry's age from this
    /// instant rather than from insertion time, so callers normally set it to
    /// `Instant::now()` just before calling `put`.
    pub created_at: Instant,
}
45
/// Cache statistics.
///
/// A point-in-time snapshot produced by `TranslationCache::stats`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CacheStats {
    /// Current number of entries. Expired entries are only removed when looked up,
    /// so this may include entries that are already past their TTL.
    pub size: usize,
    /// Effective maximum capacity (a configured capacity of 0 is reported as 1)
    pub capacity: usize,
    /// Number of cache hits since creation or the last `clear`
    pub hits: u64,
    /// Number of cache misses (absent or expired keys) since creation or the last `clear`
    pub misses: u64,
}
58
59/// Thread-safe LRU translation cache.
60pub struct TranslationCache {
61    cache: Mutex<LruCache<CacheKey, CachedResult>>,
62    config: CacheConfig,
63    hits: Mutex<u64>,
64    misses: Mutex<u64>,
65}
66
67impl TranslationCache {
68    /// Create a new cache with the given capacity.
69    #[must_use]
70    pub fn new(capacity: usize) -> Self {
71        Self::with_config(CacheConfig {
72            capacity,
73            ..Default::default()
74        })
75    }
76
77    /// Create a new cache with full configuration.
78    #[must_use]
79    pub fn with_config(config: CacheConfig) -> Self {
80        let capacity = NonZeroUsize::new(config.capacity).unwrap_or(NonZeroUsize::MIN);
81        Self {
82            cache: Mutex::new(LruCache::new(capacity)),
83            config,
84            hits: Mutex::new(0),
85            misses: Mutex::new(0),
86        }
87    }
88
89    /// Get a cached result if valid.
90    ///
91    /// Returns `None` if not found or expired.
92    #[must_use]
93    pub fn get(&self, key: &CacheKey) -> Option<CachedResult> {
94        let mut cache = self.cache.lock();
95        if let Some(result) = cache.get(key) {
96            // Check TTL
97            if result.created_at.elapsed() < self.config.ttl {
98                *self.hits.lock() += 1;
99                return Some(result.clone());
100            }
101            // Expired - remove it
102            cache.pop(key);
103        }
104        *self.misses.lock() += 1;
105        None
106    }
107
108    /// Store a result in the cache.
109    pub fn put(&self, key: CacheKey, result: CachedResult) {
110        let mut cache = self.cache.lock();
111        cache.put(key, result);
112    }
113
114    /// Clear all cached entries.
115    pub fn clear(&self) {
116        let mut cache = self.cache.lock();
117        cache.clear();
118        *self.hits.lock() = 0;
119        *self.misses.lock() = 0;
120    }
121
122    /// Get cache statistics.
123    #[must_use]
124    pub fn stats(&self) -> CacheStats {
125        let cache = self.cache.lock();
126        CacheStats {
127            size: cache.len(),
128            capacity: cache.cap().get(),
129            hits: *self.hits.lock(),
130            misses: *self.misses.lock(),
131        }
132    }
133
134    /// Get the hit rate (0.0-1.0).
135    #[must_use]
136    pub fn hit_rate(&self) -> f64 {
137        let hits = *self.hits.lock();
138        let misses = *self.misses.lock();
139        let total = hits + misses;
140        if total == 0 {
141            0.0
142        } else {
143            let hits_f64 = f64::from(u32::try_from(hits).unwrap_or(u32::MAX));
144            let total_f64 = f64::from(u32::try_from(total).unwrap_or(u32::MAX));
145            hits_f64 / total_f64
146        }
147    }
148}
149
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fresh `SymbolQuery` result for `command`, timestamped now.
    fn sample_result(command: &str) -> CachedResult {
        CachedResult {
            command: command.to_string(),
            intent: Intent::SymbolQuery,
            confidence: 0.95,
            created_at: Instant::now(),
        }
    }

    /// Builds a key for `query` with the empty context used throughout these tests.
    fn key_for(query: &str) -> CacheKey {
        CacheKey::new(query, &[], None, 100)
    }

    #[test]
    fn test_cache_put_get() {
        let cache = TranslationCache::new(10);
        let key = key_for("find foo");
        let result = sample_result("sqry query \"foo\"");

        cache.put(key.clone(), result.clone());

        let cached = cache.get(&key).expect("entry should be cached");
        assert_eq!(cached.command, result.command);
    }

    #[test]
    fn test_cache_miss() {
        let cache = TranslationCache::new(10);
        assert!(cache.get(&key_for("find foo")).is_none());
    }

    #[test]
    fn test_cache_expiration() {
        let cache = TranslationCache::with_config(CacheConfig {
            ttl: Duration::from_millis(1),
            capacity: 10,
        });
        let key = key_for("find foo");
        cache.put(key.clone(), sample_result("sqry query \"foo\""));

        // Sleep well past the 1 ms TTL.
        std::thread::sleep(Duration::from_millis(10));

        assert!(cache.get(&key).is_none());

        // The expired entry is dropped on lookup and counted as a miss.
        let stats = cache.stats();
        assert_eq!(stats.size, 0);
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 1);
    }

    #[test]
    fn test_cache_stats() {
        let cache = TranslationCache::new(10);
        let key = key_for("find foo");

        // Miss, then insert, then hit.
        let _ = cache.get(&key);
        cache.put(key.clone(), sample_result("sqry query \"foo\""));
        let _ = cache.get(&key);

        let stats = cache.stats();
        assert_eq!(stats.size, 1);
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
    }

    #[test]
    fn test_hit_rate() {
        let cache = TranslationCache::new(10);
        assert!(cache.hit_rate().abs() < f64::EPSILON);

        let key = key_for("find foo");
        let _ = cache.get(&key); // miss
        cache.put(key.clone(), sample_result("sqry query \"foo\""));
        for _ in 0..3 {
            let _ = cache.get(&key); // hit
        }

        assert!((cache.hit_rate() - 0.75).abs() < f64::EPSILON);
    }

    #[test]
    fn test_clear_resets_entries_and_stats() {
        let cache = TranslationCache::new(10);
        let key = key_for("find foo");
        let _ = cache.get(&key);
        cache.put(key.clone(), sample_result("sqry query \"foo\""));
        let _ = cache.get(&key);

        cache.clear();

        let stats = cache.stats();
        assert_eq!(stats.size, 0);
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
        assert!(cache.get(&key).is_none());
    }

    #[test]
    fn test_zero_capacity_is_clamped_to_one() {
        let cache = TranslationCache::new(0);
        assert_eq!(cache.stats().capacity, 1);
    }

    #[test]
    fn test_lru_eviction() {
        let cache = TranslationCache::new(2);

        // Insert three entries into a cache that holds two.
        for i in 0..3 {
            cache.put(
                key_for(&format!("query {i}")),
                sample_result(&format!("sqry query \"{i}\"")),
            );
        }

        assert_eq!(cache.stats().size, 2);

        // The oldest entry is evicted; the two most recent survive.
        assert!(cache.get(&key_for("query 0")).is_none());
        assert!(cache.get(&key_for("query 1")).is_some());
        assert!(cache.get(&key_for("query 2")).is_some());
    }
}