1use dashmap::DashMap;
2use serde::{Deserialize, Serialize};
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5
6use super::{
7 cache_key::CacheKey,
8 eviction::{EvictionPolicy, LRUEviction, SizeBasedEviction, TTLEviction},
9 metrics::CacheMetrics,
10};
11
/// Tunables for the inference cache. Exactly one size limit drives eviction;
/// see `InferenceCache::create_eviction_policy` for the priority order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheConfig {
    /// Cap on the number of entries (LRU eviction), if set.
    pub max_entries: Option<usize>,
    /// Cap on total cached memory in bytes (size-based eviction), if set.
    pub max_memory_bytes: Option<usize>,
    /// Per-entry time-to-live (TTL eviction), if set.
    pub ttl: Option<Duration>,
    /// Whether to collect hit/miss/eviction metrics.
    pub enable_metrics: bool,
    /// Whether to compress stored values.
    pub compress_values: bool,
    /// Minimum value size in bytes before compression is attempted.
    pub compression_threshold: usize,
}
28
29impl Default for CacheConfig {
30 fn default() -> Self {
31 Self {
32 max_entries: Some(1000),
33 max_memory_bytes: Some(1024 * 1024 * 1024), ttl: Some(Duration::from_secs(3600)), enable_metrics: true,
36 compress_values: true,
37 compression_threshold: 1024, }
39 }
40}
41
/// A single cached value plus the bookkeeping read by eviction policies.
#[derive(Debug, Clone)]
pub struct CacheEntry {
    /// Stored bytes — possibly compressed (see `is_compressed`).
    pub value: Vec<u8>,
    /// Size of the value before compression, in bytes.
    pub uncompressed_size: usize,
    /// True when `value` holds compressed bytes.
    pub is_compressed: bool,
    /// When the entry was inserted.
    pub created_at: Instant,
    /// When the entry was last read (refreshed on every access).
    pub last_accessed: Instant,
    /// Number of reads since insertion.
    pub access_count: u64,
}
58
59impl CacheEntry {
60 fn new(value: Vec<u8>, is_compressed: bool, uncompressed_size: usize) -> Self {
61 let now = Instant::now();
62 Self {
63 value,
64 uncompressed_size,
65 is_compressed,
66 created_at: now,
67 last_accessed: now,
68 access_count: 0,
69 }
70 }
71
72 fn access(&mut self) {
73 self.last_accessed = Instant::now();
74 self.access_count += 1;
75 }
76
77 fn memory_size(&self) -> usize {
78 self.value.len() + std::mem::size_of::<Self>()
79 }
80}
81
/// Concurrent cache for inference results: a sharded map for lock-light
/// reads/writes, a pluggable eviction policy, and optional transparent
/// compression of stored values.
pub struct InferenceCache {
    /// Key -> entry storage; DashMap provides per-shard locking.
    cache: Arc<DashMap<CacheKey, CacheEntry>>,
    /// Eviction bookkeeping, serialized behind a single mutex.
    eviction_policy: Arc<parking_lot::Mutex<Box<dyn EvictionPolicy>>>,
    /// Configuration captured at construction; never mutated afterwards.
    config: CacheConfig,
    /// Present only when `config.enable_metrics` was true at construction.
    metrics: Option<Arc<CacheMetrics>>,
}
93
impl InferenceCache {
    /// Builds a cache from `config`, selecting an eviction policy and
    /// (optionally) a metrics collector.
    pub fn new(config: CacheConfig) -> Self {
        let eviction_policy = Self::create_eviction_policy(&config);

        let metrics =
            if config.enable_metrics { Some(Arc::new(CacheMetrics::new())) } else { None };

        Self {
            cache: Arc::new(DashMap::new()),
            eviction_policy: Arc::new(parking_lot::Mutex::new(eviction_policy)),
            config,
            metrics,
        }
    }

    /// Maps the config onto exactly one eviction policy. Priority:
    /// memory cap > entry cap > TTL > fallback 1000-entry LRU.
    /// NOTE(review): the limits are mutually exclusive here — a `ttl` set
    /// alongside `max_entries` is never enforced by eviction. Confirm this
    /// is intended.
    fn create_eviction_policy(config: &CacheConfig) -> Box<dyn EvictionPolicy> {
        if let Some(max_bytes) = config.max_memory_bytes {
            Box::new(SizeBasedEviction::new(max_bytes))
        }
        else if let Some(max_entries) = config.max_entries {
            Box::new(LRUEviction::new(max_entries))
        }
        else if let Some(ttl) = config.ttl {
            Box::new(TTLEviction::new(ttl))
        }
        else {
            Box::new(LRUEviction::new(1000))
        }
    }

    /// Fetches `key`, transparently decompressing the stored bytes.
    /// Refreshes the entry's recency/access counters, notifies the eviction
    /// policy, and records hit/miss latency when metrics are enabled.
    pub fn get(&self, key: &CacheKey) -> Option<Vec<u8>> {
        let start = Instant::now();

        if let Some(mut entry) = self.cache.get_mut(key) {
            entry.access();
            let value = entry.value.clone();
            let is_compressed = entry.is_compressed;
            // Release the shard guard BEFORE taking the policy mutex so the
            // two locks are never held at once.
            drop(entry);
            self.eviction_policy.lock().on_access(key);

            // A hit whose decompression fails yields None and is counted as
            // a miss below; the (presumably corrupt) entry stays in the map.
            // NOTE(review): consider removing such entries.
            let result = if is_compressed { self.decompress(&value).ok() } else { Some(value) };

            if let Some(metrics) = &self.metrics {
                let elapsed = start.elapsed();
                if result.is_some() {
                    metrics.record_hit(elapsed);
                } else {
                    metrics.record_miss(elapsed);
                }
            }

            result
        } else {
            if let Some(metrics) = &self.metrics {
                metrics.record_miss(start.elapsed());
            }
            None
        }
    }

    /// Inserts `value` under `key`. The value is compressed first when
    /// compression is enabled, the value meets the size threshold, and the
    /// compressed form is strictly smaller. May evict other entries.
    /// NOTE(review): overwriting an existing key fires `on_insert` without a
    /// matching `on_remove` for the replaced entry — the policy's size
    /// accounting may drift; verify against the policy implementations.
    pub fn insert(&self, key: CacheKey, value: Vec<u8>) {
        let start = Instant::now();
        let uncompressed_size = value.len();

        let (stored_value, is_compressed) = if self.config.compress_values
            && uncompressed_size >= self.config.compression_threshold
        {
            match self.compress(&value) {
                // Keep the compressed form only when it is a strict win.
                Ok(compressed) if compressed.len() < uncompressed_size => (compressed, true),
                _ => (value, false),
            }
        } else {
            (value, false)
        };

        let entry = CacheEntry::new(stored_value, is_compressed, uncompressed_size);
        let memory_size = entry.memory_size();

        self.cache.insert(key.clone(), entry);

        self.eviction_policy.lock().on_insert(&key, memory_size);

        self.maybe_evict();

        if let Some(metrics) = &self.metrics {
            metrics.record_insert(memory_size, start.elapsed());
        }
    }

    /// Removes `key`, returning its (decompressed) value if present.
    /// NOTE(review): an explicit removal is reported to metrics via
    /// `record_eviction`, so it is indistinguishable from a policy eviction
    /// in the stats — confirm that is intended.
    pub fn remove(&self, key: &CacheKey) -> Option<Vec<u8>> {
        if let Some((_, entry)) = self.cache.remove(key) {
            let memory_size = entry.memory_size();

            self.eviction_policy.lock().on_remove(key);

            if let Some(metrics) = &self.metrics {
                metrics.record_eviction(memory_size);
            }

            if entry.is_compressed {
                self.decompress(&entry.value).ok()
            } else {
                Some(entry.value)
            }
        } else {
            None
        }
    }

    /// Drops every entry and resets metrics.
    /// NOTE(review): the eviction policy's internal state is NOT cleared
    /// here — verify the policies tolerate stale keys.
    pub fn clear(&self) {
        self.cache.clear();

        if let Some(metrics) = &self.metrics {
            metrics.reset();
        }
    }

    /// Current number of entries.
    pub fn len(&self) -> usize {
        self.cache.len()
    }

    /// True when the cache holds no entries.
    pub fn is_empty(&self) -> bool {
        self.cache.is_empty()
    }

    /// Handle to the metrics collector, if metrics are enabled.
    pub fn metrics(&self) -> Option<Arc<CacheMetrics>> {
        self.metrics.clone()
    }

    /// Removes a policy-chosen victim from the map and records the freed
    /// bytes. Called by `maybe_evict` with the policy mutex already held;
    /// must not re-lock the policy.
    fn handle_eviction(&self, key: &CacheKey) {
        if let Some((_, entry)) = self.cache.remove(key) {
            if let Some(metrics) = &self.metrics {
                metrics.record_eviction(entry.memory_size());
            }
        }
    }

    /// Evicts entries until the policy reports the cache is within bounds
    /// or it stops proposing victims.
    fn maybe_evict(&self) {
        let mut policy = self.eviction_policy.lock();

        while policy.should_evict() {
            if let Some(key) = policy.next_eviction() {
                self.handle_eviction(&key);
            } else {
                break;
            }
        }
    }

    /// Compresses `data` with the project's zstd encoder at level 3.
    fn compress(&self, data: &[u8]) -> Result<Vec<u8>, std::io::Error> {
        use std::io::Write;
        let mut encoder = oxiarc_zstd::ZstdStreamEncoder::new(Vec::new(), 3);
        encoder.write_all(data)?;
        encoder.finish()
    }

    /// Inverse of `compress`; maps decoder errors into `io::Error`.
    fn decompress(&self, data: &[u8]) -> Result<Vec<u8>, std::io::Error> {
        oxiarc_zstd::decode_all(data).map_err(|e| std::io::Error::other(e.to_string()))
    }
}
285
/// Fluent builder over `CacheConfig` for constructing an `InferenceCache`.
pub struct InferenceCacheBuilder {
    /// Pending configuration, finalized by `build`.
    config: CacheConfig,
}
290
291impl InferenceCacheBuilder {
292 pub fn new() -> Self {
293 Self {
294 config: CacheConfig::default(),
295 }
296 }
297
298 pub fn max_entries(mut self, max_entries: usize) -> Self {
299 self.config.max_entries = Some(max_entries);
300 self.config.max_memory_bytes = None;
302 self
303 }
304
305 pub fn max_memory_mb(mut self, max_memory_mb: usize) -> Self {
306 self.config.max_memory_bytes = Some(max_memory_mb * 1024 * 1024);
307 self.config.max_entries = None;
309 self
310 }
311
312 pub fn ttl(mut self, ttl: Duration) -> Self {
313 self.config.ttl = Some(ttl);
314 self
315 }
316
317 pub fn enable_metrics(mut self, enable: bool) -> Self {
318 self.config.enable_metrics = enable;
319 self
320 }
321
322 pub fn enable_compression(mut self, enable: bool) -> Self {
323 self.config.compress_values = enable;
324 self
325 }
326
327 pub fn compression_threshold(mut self, threshold: usize) -> Self {
328 self.config.compression_threshold = threshold;
329 self
330 }
331
332 pub fn build(self) -> InferenceCache {
333 InferenceCache::new(self.config)
334 }
335}
336
337impl Default for InferenceCacheBuilder {
338 fn default() -> Self {
339 Self::new()
340 }
341}
342
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cache::cache_key::CacheKeyBuilder;

    /// Insert + get round-trip, then checks the metrics snapshot agrees
    /// (one hit, no misses, one entry).
    #[test]
    fn test_basic_cache_operations() {
        let cache = InferenceCacheBuilder::new().max_entries(10).enable_metrics(true).build();

        let key = CacheKeyBuilder::new("test-model", "classification")
            .with_text("Hello world")
            .build();

        let value = b"prediction result".to_vec();

        cache.insert(key.clone(), value.clone());
        let retrieved = cache.get(&key).expect("expected value not found");
        assert_eq!(retrieved, value);

        let metrics = cache.metrics().expect("operation failed in test");
        let snapshot = metrics.snapshot();
        assert_eq!(snapshot.hits, 1);
        assert_eq!(snapshot.misses, 0);
        assert_eq!(snapshot.total_entries, 1);
    }

    /// A highly repetitive 1000-byte value above the 10-byte threshold must
    /// be stored compressed (and smaller) yet round-trip unchanged.
    #[test]
    fn test_compression() {
        let cache = InferenceCacheBuilder::new()
            .enable_compression(true)
            .compression_threshold(10)
            .build();

        let key = CacheKeyBuilder::new("test-model", "generation")
            .with_text("Test prompt")
            .build();

        // 1000 identical bytes: trivially compressible.
        let value = vec![42u8; 1000];

        cache.insert(key.clone(), value.clone());
        let retrieved = cache.get(&key).expect("expected value not found");
        assert_eq!(retrieved, value);

        // Peek at the raw entry to confirm compression actually happened.
        let entry = cache.cache.get(&key).expect("expected value not found");
        assert!(entry.is_compressed);
        assert!(entry.value.len() < entry.uncompressed_size);
    }

    /// With a 3-entry cap, inserting 5 keys should evict the 2 oldest and
    /// report exactly 2 evictions in the metrics.
    #[test]
    fn test_eviction() {
        let cache = InferenceCacheBuilder::new().max_entries(3).enable_metrics(true).build();

        let keys: Vec<_> = (0..5)
            .map(|i| CacheKeyBuilder::new("model", "task").with_text(&format!("text{}", i)).build())
            .collect();

        for (i, key) in keys.iter().enumerate() {
            cache.insert(key.clone(), vec![i as u8; 100]);
        }

        // Oldest two are gone...
        assert!(cache.get(&keys[0]).is_none());
        assert!(cache.get(&keys[1]).is_none());

        // ...newest three survive.
        assert!(cache.get(&keys[2]).is_some());
        assert!(cache.get(&keys[3]).is_some());
        assert!(cache.get(&keys[4]).is_some());

        let metrics = cache.metrics().expect("operation failed in test");
        let snapshot = metrics.snapshot();
        assert_eq!(snapshot.evictions, 2);
    }
}