//! `somatize_runtime/cache/memory.rs`: in-memory LRU implementation of `CacheStore`.

1use chrono::Utc;
2use somatize_core::cache::{CacheKey, CacheStore, EntryMeta, Origin};
3use somatize_core::error::Result;
4use somatize_core::value::Value;
5use std::collections::{HashMap, VecDeque};
6use std::sync::Mutex;
7
8/// In-memory LRU cache store.
9///
10/// Enforces a maximum byte limit. When the limit is exceeded,
11/// the least recently accessed entries are evicted.
12/// Thread-safe via Mutex.
13pub struct MemoryCache {
14    store: Mutex<LruStore>,
15}
16
17struct LruStore {
18    entries: HashMap<CacheKey, CacheEntry>,
19    /// Access order: most recent at back, least recent at front.
20    access_order: VecDeque<CacheKey>,
21    current_bytes: usize,
22    max_bytes: usize,
23}
24
25struct CacheEntry {
26    value: Value,
27    meta: EntryMeta,
28    size: usize,
29}
30
31impl LruStore {
32    fn new(max_bytes: usize) -> Self {
33        Self {
34            entries: HashMap::new(),
35            access_order: VecDeque::new(),
36            current_bytes: 0,
37            max_bytes,
38        }
39    }
40
41    fn touch(&mut self, key: &CacheKey) {
42        self.access_order.retain(|k| k != key);
43        self.access_order.push_back(key.clone());
44    }
45
46    fn evict_until_fits(&mut self, needed: usize) {
47        while self.current_bytes + needed > self.max_bytes && !self.access_order.is_empty() {
48            if let Some(oldest_key) = self.access_order.pop_front()
49                && let Some(entry) = self.entries.remove(&oldest_key)
50            {
51                self.current_bytes = self.current_bytes.saturating_sub(entry.size);
52            }
53        }
54    }
55
56    fn insert(&mut self, key: CacheKey, entry: CacheEntry) {
57        let size = entry.size;
58
59        // Remove old entry if exists
60        if let Some(old) = self.entries.remove(&key) {
61            self.current_bytes = self.current_bytes.saturating_sub(old.size);
62            self.access_order.retain(|k| k != &key);
63        }
64
65        // Evict if needed
66        self.evict_until_fits(size);
67
68        self.current_bytes += size;
69        self.access_order.push_back(key.clone());
70        self.entries.insert(key, entry);
71    }
72
73    fn remove(&mut self, key: &CacheKey) {
74        if let Some(entry) = self.entries.remove(key) {
75            self.current_bytes = self.current_bytes.saturating_sub(entry.size);
76            self.access_order.retain(|k| k != key);
77        }
78    }
79}
80
81impl MemoryCache {
82    /// Create a new memory cache with a maximum byte limit.
83    pub fn new(max_bytes: usize) -> Self {
84        Self {
85            store: Mutex::new(LruStore::new(max_bytes)),
86        }
87    }
88
89    /// Number of entries currently in the cache.
90    pub fn len(&self) -> usize {
91        self.store
92            .lock()
93            .unwrap_or_else(|e| e.into_inner())
94            .entries
95            .len()
96    }
97
98    /// Whether the cache is empty.
99    pub fn is_empty(&self) -> bool {
100        self.len() == 0
101    }
102
103    /// Current memory usage in bytes.
104    pub fn current_bytes(&self) -> usize {
105        self.store
106            .lock()
107            .unwrap_or_else(|e| e.into_inner())
108            .current_bytes
109    }
110
111    /// Clear all entries.
112    pub fn clear(&self) {
113        let mut store = self.store.lock().unwrap_or_else(|e| e.into_inner());
114        store.entries.clear();
115        store.access_order.clear();
116        store.current_bytes = 0;
117    }
118}
119
120impl Default for MemoryCache {
121    fn default() -> Self {
122        Self::new(1024 * 1024 * 1024) // 1GB
123    }
124}
125
126impl CacheStore for MemoryCache {
127    fn get(&self, key: &CacheKey) -> Result<Option<Value>> {
128        let mut store = self.store.lock().unwrap_or_else(|e| e.into_inner());
129        if store.entries.contains_key(key) {
130            store.touch(key);
131            if let Some(entry) = store.entries.get_mut(key) {
132                entry.meta.last_accessed = Utc::now();
133                return Ok(Some(entry.value.clone()));
134            }
135        }
136        Ok(None)
137    }
138
139    fn put(&self, key: &CacheKey, value: &Value) -> Result<()> {
140        let size = estimate_size(value);
141        let now = Utc::now();
142
143        let mut store = self.store.lock().unwrap_or_else(|e| e.into_inner());
144        store.insert(
145            key.clone(),
146            CacheEntry {
147                value: value.clone(),
148                meta: EntryMeta {
149                    key: key.clone(),
150                    size_bytes: size as u64,
151                    created_at: now,
152                    last_accessed: now,
153                    ttl: None,
154                    origin: Origin::Computed {
155                        node_id: String::new(),
156                        run_id: String::new(),
157                    },
158                },
159                size,
160            },
161        );
162        Ok(())
163    }
164
165    fn exists(&self, key: &CacheKey) -> Result<bool> {
166        Ok(self
167            .store
168            .lock()
169            .unwrap_or_else(|e| e.into_inner())
170            .entries
171            .contains_key(key))
172    }
173
174    fn remove(&self, key: &CacheKey) -> Result<()> {
175        self.store
176            .lock()
177            .unwrap_or_else(|e| e.into_inner())
178            .remove(key);
179        Ok(())
180    }
181
182    fn metadata(&self, key: &CacheKey) -> Result<Option<EntryMeta>> {
183        Ok(self
184            .store
185            .lock()
186            .unwrap_or_else(|e| e.into_inner())
187            .entries
188            .get(key)
189            .map(|e| e.meta.clone()))
190    }
191}
192
193fn estimate_size(value: &Value) -> usize {
194    match value {
195        Value::Tensor { values, shape } => {
196            values.len() * std::mem::size_of::<f64>() + shape.len() * std::mem::size_of::<usize>()
197        }
198        Value::Json(v) => v.to_string().len(),
199        Value::Bytes(b) => b.len(),
200        Value::Empty => 0,
201        _ => 0,
202    }
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn put_and_get() {
        let cache = MemoryCache::default();
        let key = CacheKey::hash_data(b"test");
        let value = Value::tensor(vec![1.0, 2.0, 3.0], vec![3]);

        cache.put(&key, &value).unwrap();
        let retrieved = cache.get(&key).unwrap().unwrap();
        assert_eq!(retrieved, value);
    }

    #[test]
    fn get_missing_returns_none() {
        let cache = MemoryCache::default();
        let key = CacheKey::hash_data(b"nonexistent");
        assert!(cache.get(&key).unwrap().is_none());
    }

    #[test]
    fn new_cache_is_empty() {
        let cache = MemoryCache::new(1024);
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
        assert_eq!(cache.current_bytes(), 0);
    }

    #[test]
    fn exists_check() {
        let cache = MemoryCache::default();
        let key = CacheKey::hash_data(b"test");
        assert!(!cache.exists(&key).unwrap());

        cache.put(&key, &Value::Empty).unwrap();
        assert!(cache.exists(&key).unwrap());
    }

    #[test]
    fn remove_entry() {
        let cache = MemoryCache::default();
        let key = CacheKey::hash_data(b"test");
        cache.put(&key, &Value::Empty).unwrap();
        assert_eq!(cache.len(), 1);

        cache.remove(&key).unwrap();
        assert_eq!(cache.len(), 0);
        assert!(!cache.exists(&key).unwrap());
    }

    #[test]
    fn remove_missing_key_is_noop() {
        let cache = MemoryCache::default();
        cache.remove(&CacheKey::hash_data(b"absent")).unwrap();
        assert!(cache.is_empty());
        assert_eq!(cache.current_bytes(), 0);
    }

    #[test]
    fn metadata_available() {
        let cache = MemoryCache::default();
        let key = CacheKey::hash_data(b"test");
        let value = Value::tensor(vec![1.0; 100], vec![10, 10]);

        cache.put(&key, &value).unwrap();
        let meta = cache.metadata(&key).unwrap().unwrap();
        // 100 f64 values * 8 bytes + 2 shape elements * 8 bytes = 816
        assert_eq!(meta.size_bytes, 816);
        assert!(meta.key == key, "metadata should carry the entry's key");
    }

    #[test]
    fn metadata_missing_returns_none() {
        let cache = MemoryCache::default();
        assert!(cache
            .metadata(&CacheKey::hash_data(b"absent"))
            .unwrap()
            .is_none());
    }

    #[test]
    fn clear_empties_cache() {
        let cache = MemoryCache::default();
        cache
            .put(&CacheKey::hash_data(b"a"), &Value::Empty)
            .unwrap();
        cache
            .put(&CacheKey::hash_data(b"b"), &Value::Empty)
            .unwrap();
        assert_eq!(cache.len(), 2);

        cache.clear();
        assert!(cache.is_empty());
        assert_eq!(cache.current_bytes(), 0);
    }

    #[test]
    fn overwrite_existing_key() {
        let cache = MemoryCache::default();
        let key = CacheKey::hash_data(b"test");

        cache.put(&key, &Value::json(json!(1))).unwrap();
        cache.put(&key, &Value::json(json!(2))).unwrap();

        let val = cache.get(&key).unwrap().unwrap();
        assert_eq!(val, Value::json(json!(2)));
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn multiple_keys() {
        let cache = MemoryCache::default();
        for i in 0..10 {
            let key = CacheKey::hash_data(format!("key_{i}").as_bytes());
            let val = Value::tensor(vec![i as f64], vec![1]);
            cache.put(&key, &val).unwrap();
        }
        assert_eq!(cache.len(), 10);

        let key5 = CacheKey::hash_data(b"key_5");
        let val = cache.get(&key5).unwrap().unwrap();
        let (data, _) = val.as_tensor().unwrap();
        assert_eq!(data, &[5.0]);
    }

    // ── LRU eviction tests ──

    #[test]
    fn lru_evicts_oldest_when_full() {
        // Cache with 100 bytes max
        let cache = MemoryCache::new(100);

        // Each tensor of 5 f64s = 5*8 + 1*8 = 48 bytes
        let k1 = CacheKey::hash_data(b"first");
        let k2 = CacheKey::hash_data(b"second");
        let k3 = CacheKey::hash_data(b"third");

        cache
            .put(&k1, &Value::tensor(vec![0.0; 5], vec![5]))
            .unwrap();
        cache
            .put(&k2, &Value::tensor(vec![0.0; 5], vec![5]))
            .unwrap();
        assert_eq!(cache.len(), 2);

        // Adding third should evict first (48+48+48=144 > 100)
        cache
            .put(&k3, &Value::tensor(vec![0.0; 5], vec![5]))
            .unwrap();

        assert!(!cache.exists(&k1).unwrap(), "k1 should be evicted");
        assert!(cache.exists(&k2).unwrap(), "k2 should remain");
        assert!(cache.exists(&k3).unwrap(), "k3 should remain");
        assert_eq!(cache.current_bytes(), 96, "two 48-byte tensors remain");
    }

    #[test]
    fn lru_evicts_as_many_entries_as_needed() {
        // Limit 100 bytes; each 2-element tensor = 2*8 + 1*8 = 24 bytes.
        let cache = MemoryCache::new(100);
        let k1 = CacheKey::hash_data(b"first");
        let k2 = CacheKey::hash_data(b"second");
        let k3 = CacheKey::hash_data(b"third");
        let k4 = CacheKey::hash_data(b"fourth");

        for key in [&k1, &k2, &k3] {
            cache
                .put(key, &Value::tensor(vec![0.0; 2], vec![2]))
                .unwrap();
        }
        assert_eq!(cache.current_bytes(), 72);

        // 8-element tensor = 8*8 + 1*8 = 72 bytes. 72 + 72 > 100 and
        // 48 + 72 > 100, so both k1 and k2 must go before it fits (24 + 72 = 96).
        cache
            .put(&k4, &Value::tensor(vec![0.0; 8], vec![8]))
            .unwrap();

        assert!(!cache.exists(&k1).unwrap(), "k1 should be evicted");
        assert!(!cache.exists(&k2).unwrap(), "k2 should be evicted");
        assert!(cache.exists(&k3).unwrap(), "k3 should remain");
        assert!(cache.exists(&k4).unwrap(), "k4 should remain");
        assert_eq!(cache.current_bytes(), 96);
    }

    #[test]
    fn lru_access_prevents_eviction() {
        let cache = MemoryCache::new(100);

        let k1 = CacheKey::hash_data(b"first");
        let k2 = CacheKey::hash_data(b"second");
        let k3 = CacheKey::hash_data(b"third");

        cache
            .put(&k1, &Value::tensor(vec![0.0; 5], vec![5]))
            .unwrap();
        cache
            .put(&k2, &Value::tensor(vec![0.0; 5], vec![5]))
            .unwrap();

        // Access k1, making k2 the least recently used
        cache.get(&k1).unwrap();

        // Adding k3 should evict k2 (LRU), not k1
        cache
            .put(&k3, &Value::tensor(vec![0.0; 5], vec![5]))
            .unwrap();

        assert!(cache.exists(&k1).unwrap(), "k1 was accessed, should remain");
        assert!(!cache.exists(&k2).unwrap(), "k2 was LRU, should be evicted");
        assert!(cache.exists(&k3).unwrap(), "k3 is new, should remain");
    }

    #[test]
    fn exists_and_metadata_do_not_refresh_recency() {
        let cache = MemoryCache::new(100);

        let k1 = CacheKey::hash_data(b"first");
        let k2 = CacheKey::hash_data(b"second");
        let k3 = CacheKey::hash_data(b"third");

        cache
            .put(&k1, &Value::tensor(vec![0.0; 5], vec![5]))
            .unwrap();
        cache
            .put(&k2, &Value::tensor(vec![0.0; 5], vec![5]))
            .unwrap();

        // Probing k1 must not count as a use, so k1 stays the LRU entry.
        assert!(cache.exists(&k1).unwrap());
        assert!(cache.metadata(&k1).unwrap().is_some());

        cache
            .put(&k3, &Value::tensor(vec![0.0; 5], vec![5]))
            .unwrap();

        assert!(!cache.exists(&k1).unwrap(), "k1 was only probed, should be evicted");
        assert!(cache.exists(&k2).unwrap(), "k2 should remain");
        assert!(cache.exists(&k3).unwrap(), "k3 should remain");
    }

    #[test]
    fn lru_tracks_byte_usage() {
        let cache = MemoryCache::new(1024);

        assert_eq!(cache.current_bytes(), 0);

        // 10 f64s = 80 bytes data + 8 bytes shape = 88
        cache
            .put(
                &CacheKey::hash_data(b"a"),
                &Value::tensor(vec![0.0; 10], vec![10]),
            )
            .unwrap();
        assert_eq!(cache.current_bytes(), 88);

        cache.remove(&CacheKey::hash_data(b"a")).unwrap();
        assert_eq!(cache.current_bytes(), 0);
    }

    #[test]
    fn lru_overwrite_updates_size() {
        let cache = MemoryCache::new(1024);

        let key = CacheKey::hash_data(b"key");
        cache
            .put(&key, &Value::tensor(vec![0.0; 10], vec![10]))
            .unwrap();
        let size1 = cache.current_bytes();

        // Replace with larger value
        cache
            .put(&key, &Value::tensor(vec![0.0; 20], vec![20]))
            .unwrap();
        let size2 = cache.current_bytes();

        // 10*8 + 8 = 88; 20*8 + 8 = 168. Exactly 168 proves the old 88 bytes
        // were released rather than accumulated.
        assert_eq!(size1, 88);
        assert_eq!(size2, 168, "overwrite should replace, not add to, the old size");
        assert_eq!(cache.len(), 1, "should still be one entry");
    }
}