//! Inference cache for trustformers_core (`trustformers_core/cache/inference_cache.rs`).

use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::{Duration, Instant};

use super::{
    cache_key::CacheKey,
    eviction::{EvictionPolicy, LRUEviction, SizeBasedEviction, TTLEviction},
    metrics::CacheMetrics,
};
11
/// Configuration for the inference cache.
///
/// Only one eviction policy is active at a time; `create_eviction_policy`
/// selects it in priority order: memory limit, then entry limit, then TTL.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheConfig {
    /// Maximum number of entries (drives LRU eviction, but only when
    /// `max_memory_bytes` is `None`).
    pub max_entries: Option<usize>,
    /// Maximum memory usage in bytes (takes precedence over `max_entries`
    /// when selecting the eviction policy).
    pub max_memory_bytes: Option<usize>,
    /// Time-to-live for entries (only drives eviction when both limits
    /// above are `None`).
    pub ttl: Option<Duration>,
    /// Whether to enable metrics collection.
    pub enable_metrics: bool,
    /// Whether to compress cached values before storing them.
    pub compress_values: bool,
    /// Minimum uncompressed value size (bytes) at which compression is attempted.
    pub compression_threshold: usize,
}
28
29impl Default for CacheConfig {
30    fn default() -> Self {
31        Self {
32            max_entries: Some(1000),
33            max_memory_bytes: Some(1024 * 1024 * 1024), // 1GB
34            ttl: Some(Duration::from_secs(3600)),       // 1 hour
35            enable_metrics: true,
36            compress_values: true,
37            compression_threshold: 1024, // 1KB
38        }
39    }
40}
41
/// A cached entry in the inference cache.
///
/// Stores the (possibly compressed) value bytes together with the metadata
/// needed for eviction decisions and metrics.
#[derive(Debug, Clone)]
pub struct CacheEntry {
    /// The cached value bytes (zstd-compressed when `is_compressed` is true).
    pub value: Vec<u8>,
    /// Size of the uncompressed value in bytes.
    pub uncompressed_size: usize,
    /// Whether `value` holds compressed bytes.
    pub is_compressed: bool,
    /// When the entry was created.
    pub created_at: Instant,
    /// When the entry was last accessed (updated on every `get`).
    pub last_accessed: Instant,
    /// Number of times the entry has been accessed since insertion.
    pub access_count: u64,
}
58
59impl CacheEntry {
60    fn new(value: Vec<u8>, is_compressed: bool, uncompressed_size: usize) -> Self {
61        let now = Instant::now();
62        Self {
63            value,
64            uncompressed_size,
65            is_compressed,
66            created_at: now,
67            last_accessed: now,
68            access_count: 0,
69        }
70    }
71
72    fn access(&mut self) {
73        self.last_accessed = Instant::now();
74        self.access_count += 1;
75    }
76
77    fn memory_size(&self) -> usize {
78        self.value.len() + std::mem::size_of::<Self>()
79    }
80}
81
/// Thread-safe inference cache with multiple eviction policies.
///
/// Reads/writes go through a sharded concurrent map; the eviction policy is
/// a single boxed trait object guarded by its own mutex, so policy updates
/// serialize while map access stays concurrent.
pub struct InferenceCache {
    /// The main cache storage (sharded concurrent map).
    cache: Arc<DashMap<CacheKey, CacheEntry>>,
    /// Eviction policy, chosen from the config at construction time.
    eviction_policy: Arc<parking_lot::Mutex<Box<dyn EvictionPolicy>>>,
    /// Cache configuration (kept so the policy can be rebuilt if needed).
    config: CacheConfig,
    /// Metrics collector; `None` when `config.enable_metrics` is false.
    metrics: Option<Arc<CacheMetrics>>,
}
93
94impl InferenceCache {
95    /// Create a new inference cache with the given configuration
96    pub fn new(config: CacheConfig) -> Self {
97        // Create composite eviction policy based on config
98        let eviction_policy = Self::create_eviction_policy(&config);
99
100        let metrics =
101            if config.enable_metrics { Some(Arc::new(CacheMetrics::new())) } else { None };
102
103        Self {
104            cache: Arc::new(DashMap::new()),
105            eviction_policy: Arc::new(parking_lot::Mutex::new(eviction_policy)),
106            config,
107            metrics,
108        }
109    }
110
111    fn create_eviction_policy(config: &CacheConfig) -> Box<dyn EvictionPolicy> {
112        // Use size-based eviction if memory limit is set
113        if let Some(max_bytes) = config.max_memory_bytes {
114            Box::new(SizeBasedEviction::new(max_bytes))
115        }
116        // Otherwise use LRU if entry limit is set
117        else if let Some(max_entries) = config.max_entries {
118            Box::new(LRUEviction::new(max_entries))
119        }
120        // Otherwise use TTL if set
121        else if let Some(ttl) = config.ttl {
122            Box::new(TTLEviction::new(ttl))
123        }
124        // Default to LRU with 1000 entries
125        else {
126            Box::new(LRUEviction::new(1000))
127        }
128    }
129
130    /// Get a value from the cache
131    pub fn get(&self, key: &CacheKey) -> Option<Vec<u8>> {
132        let start = Instant::now();
133
134        if let Some(mut entry) = self.cache.get_mut(key) {
135            entry.access();
136            let value = entry.value.clone();
137            let is_compressed = entry.is_compressed;
138            drop(entry); // Release lock
139
140            // Update eviction policy
141            self.eviction_policy.lock().on_access(key);
142
143            // Decompress if needed
144            let result = if is_compressed { self.decompress(&value).ok() } else { Some(value) };
145
146            // Record metrics
147            if let Some(metrics) = &self.metrics {
148                let elapsed = start.elapsed();
149                if result.is_some() {
150                    metrics.record_hit(elapsed);
151                } else {
152                    metrics.record_miss(elapsed);
153                }
154            }
155
156            result
157        } else {
158            if let Some(metrics) = &self.metrics {
159                metrics.record_miss(start.elapsed());
160            }
161            None
162        }
163    }
164
165    /// Insert a value into the cache
166    pub fn insert(&self, key: CacheKey, value: Vec<u8>) {
167        let start = Instant::now();
168        let uncompressed_size = value.len();
169
170        // Compress if enabled and above threshold
171        let (stored_value, is_compressed) = if self.config.compress_values
172            && uncompressed_size >= self.config.compression_threshold
173        {
174            match self.compress(&value) {
175                Ok(compressed) if compressed.len() < uncompressed_size => (compressed, true),
176                _ => (value, false),
177            }
178        } else {
179            (value, false)
180        };
181
182        let entry = CacheEntry::new(stored_value, is_compressed, uncompressed_size);
183        let memory_size = entry.memory_size();
184
185        // Insert the entry
186        self.cache.insert(key.clone(), entry);
187
188        // Update eviction policy
189        self.eviction_policy.lock().on_insert(&key, memory_size);
190
191        // Check if we need to evict after insertion
192        self.maybe_evict();
193
194        // Record metrics
195        if let Some(metrics) = &self.metrics {
196            metrics.record_insert(memory_size, start.elapsed());
197        }
198    }
199
200    /// Remove a value from the cache
201    pub fn remove(&self, key: &CacheKey) -> Option<Vec<u8>> {
202        if let Some((_, entry)) = self.cache.remove(key) {
203            let memory_size = entry.memory_size();
204
205            // Update eviction policy
206            self.eviction_policy.lock().on_remove(key);
207
208            // Record metrics
209            if let Some(metrics) = &self.metrics {
210                metrics.record_eviction(memory_size);
211            }
212
213            // Decompress if needed
214            if entry.is_compressed {
215                self.decompress(&entry.value).ok()
216            } else {
217                Some(entry.value)
218            }
219        } else {
220            None
221        }
222    }
223
224    /// Clear all entries from the cache
225    pub fn clear(&self) {
226        self.cache.clear();
227
228        if let Some(metrics) = &self.metrics {
229            metrics.reset();
230        }
231    }
232
233    /// Get the number of entries in the cache
234    pub fn len(&self) -> usize {
235        self.cache.len()
236    }
237
238    /// Check if the cache is empty
239    pub fn is_empty(&self) -> bool {
240        self.cache.is_empty()
241    }
242
243    /// Get cache metrics if enabled
244    pub fn metrics(&self) -> Option<Arc<CacheMetrics>> {
245        self.metrics.clone()
246    }
247
248    /// Helper method to handle eviction for a specific key
249    fn handle_eviction(&self, key: &CacheKey) {
250        if let Some((_, entry)) = self.cache.remove(key) {
251            // Note: Don't call policy.on_remove() here since next_eviction
252            // already removed it from the policy's tracking
253            if let Some(metrics) = &self.metrics {
254                metrics.record_eviction(entry.memory_size());
255            }
256        }
257    }
258
259    /// Perform eviction if necessary
260    fn maybe_evict(&self) {
261        let mut policy = self.eviction_policy.lock();
262
263        while policy.should_evict() {
264            if let Some(key) = policy.next_eviction() {
265                self.handle_eviction(&key);
266            } else {
267                break;
268            }
269        }
270    }
271
272    /// Compress data using zstd
273    fn compress(&self, data: &[u8]) -> Result<Vec<u8>, std::io::Error> {
274        use std::io::Write;
275        let mut encoder = oxiarc_zstd::ZstdStreamEncoder::new(Vec::new(), 3);
276        encoder.write_all(data)?;
277        encoder.finish()
278    }
279
280    /// Decompress data using zstd
281    fn decompress(&self, data: &[u8]) -> Result<Vec<u8>, std::io::Error> {
282        oxiarc_zstd::decode_all(data).map_err(|e| std::io::Error::other(e.to_string()))
283    }
284}
285
/// Builder for creating inference caches with custom configuration.
///
/// Starts from `CacheConfig::default()` and overrides fields per method call.
pub struct InferenceCacheBuilder {
    // Accumulated configuration handed to `InferenceCache::new` by `build`.
    config: CacheConfig,
}
290
291impl InferenceCacheBuilder {
292    pub fn new() -> Self {
293        Self {
294            config: CacheConfig::default(),
295        }
296    }
297
298    pub fn max_entries(mut self, max_entries: usize) -> Self {
299        self.config.max_entries = Some(max_entries);
300        // Clear memory limit to ensure LRU eviction is used
301        self.config.max_memory_bytes = None;
302        self
303    }
304
305    pub fn max_memory_mb(mut self, max_memory_mb: usize) -> Self {
306        self.config.max_memory_bytes = Some(max_memory_mb * 1024 * 1024);
307        // Clear entry limit to ensure size-based eviction is used
308        self.config.max_entries = None;
309        self
310    }
311
312    pub fn ttl(mut self, ttl: Duration) -> Self {
313        self.config.ttl = Some(ttl);
314        self
315    }
316
317    pub fn enable_metrics(mut self, enable: bool) -> Self {
318        self.config.enable_metrics = enable;
319        self
320    }
321
322    pub fn enable_compression(mut self, enable: bool) -> Self {
323        self.config.compress_values = enable;
324        self
325    }
326
327    pub fn compression_threshold(mut self, threshold: usize) -> Self {
328        self.config.compression_threshold = threshold;
329        self
330    }
331
332    pub fn build(self) -> InferenceCache {
333        InferenceCache::new(self.config)
334    }
335}
336
337impl Default for InferenceCacheBuilder {
338    fn default() -> Self {
339        Self::new()
340    }
341}
342
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cache::cache_key::CacheKeyBuilder;

    /// Insert/get round trip plus hit/miss metric accounting.
    #[test]
    fn test_basic_cache_operations() {
        let cache = InferenceCacheBuilder::new().max_entries(10).enable_metrics(true).build();

        let key = CacheKeyBuilder::new("test-model", "classification")
            .with_text("Hello world")
            .build();

        let value = b"prediction result".to_vec();

        // Insert then read back the same key.
        cache.insert(key.clone(), value.clone());
        let retrieved = cache.get(&key).expect("value inserted above should be retrievable");
        assert_eq!(retrieved, value);

        // The single successful get must show up as exactly one hit.
        let metrics = cache.metrics().expect("metrics were enabled on the builder");
        let snapshot = metrics.snapshot();
        assert_eq!(snapshot.hits, 1);
        assert_eq!(snapshot.misses, 0);
        assert_eq!(snapshot.total_entries, 1);
    }

    /// Values over the threshold are stored compressed and round-trip intact.
    #[test]
    fn test_compression() {
        let cache = InferenceCacheBuilder::new()
            .enable_compression(true)
            .compression_threshold(10)
            .build();

        let key = CacheKeyBuilder::new("test-model", "generation")
            .with_text("Test prompt")
            .build();

        // Highly repetitive payload, well above the 10-byte threshold, so it
        // must compress to fewer bytes than the original.
        let value = vec![42u8; 1000];

        cache.insert(key.clone(), value.clone());
        let retrieved = cache.get(&key).expect("compressed value should round-trip");
        assert_eq!(retrieved, value);

        // Inspect the raw entry to confirm it was actually stored compressed.
        let entry = cache.cache.get(&key).expect("entry should still be cached");
        assert!(entry.is_compressed);
        assert!(entry.value.len() < entry.uncompressed_size);
    }

    /// LRU eviction drops the oldest entries once the bound is exceeded.
    #[test]
    fn test_eviction() {
        let cache = InferenceCacheBuilder::new().max_entries(3).enable_metrics(true).build();

        let keys: Vec<_> = (0..5)
            .map(|i| CacheKeyBuilder::new("model", "task").with_text(&format!("text{}", i)).build())
            .collect();

        // Insert more entries than the cache can hold.
        for (i, key) in keys.iter().enumerate() {
            cache.insert(key.clone(), vec![i as u8; 100]);
        }

        // The two oldest entries should have been evicted.
        assert!(cache.get(&keys[0]).is_none());
        assert!(cache.get(&keys[1]).is_none());

        // The three most recent entries should survive.
        assert!(cache.get(&keys[2]).is_some());
        assert!(cache.get(&keys[3]).is_some());
        assert!(cache.get(&keys[4]).is_some());

        // Eviction metrics must match the two removals.
        let metrics = cache.metrics().expect("metrics were enabled on the builder");
        let snapshot = metrics.snapshot();
        assert_eq!(snapshot.evictions, 2);
    }
}