1use chrono::Utc;
2use somatize_core::cache::{CacheKey, CacheStore, EntryMeta, Origin};
3use somatize_core::error::Result;
4use somatize_core::value::Value;
5use std::collections::{HashMap, VecDeque};
6use std::sync::Mutex;
7
/// In-memory cache whose entire state lives in an `LruStore` behind a mutex.
pub struct MemoryCache {
    /// Entries, recency order and byte accounting, all guarded by one lock.
    store: Mutex<LruStore>,
}
16
/// Bookkeeping behind the cache: the entries themselves, the order in which
/// their keys were last used, and a running byte total measured against a
/// budget.
struct LruStore {
    /// Cached entries, looked up by key.
    entries: HashMap<CacheKey, CacheEntry>,
    /// Keys ordered by recency: eviction pops from the front, while inserts
    /// and touches push to the back.
    access_order: VecDeque<CacheKey>,
    /// Sum of the `size` of every entry currently held.
    current_bytes: usize,
    /// Byte budget that eviction tries to keep `current_bytes` within.
    max_bytes: usize,
}
24
/// One cached value together with its metadata and the size it is charged.
struct CacheEntry {
    /// The cached payload.
    value: Value,
    /// Metadata kept alongside the value (key, size, timestamps, origin).
    meta: EntryMeta,
    /// Bytes this entry counts for in the store's `current_bytes` total.
    size: usize,
}
30
31impl LruStore {
32 fn new(max_bytes: usize) -> Self {
33 Self {
34 entries: HashMap::new(),
35 access_order: VecDeque::new(),
36 current_bytes: 0,
37 max_bytes,
38 }
39 }
40
41 fn touch(&mut self, key: &CacheKey) {
42 self.access_order.retain(|k| k != key);
43 self.access_order.push_back(key.clone());
44 }
45
46 fn evict_until_fits(&mut self, needed: usize) {
47 while self.current_bytes + needed > self.max_bytes && !self.access_order.is_empty() {
48 if let Some(oldest_key) = self.access_order.pop_front()
49 && let Some(entry) = self.entries.remove(&oldest_key)
50 {
51 self.current_bytes = self.current_bytes.saturating_sub(entry.size);
52 }
53 }
54 }
55
56 fn insert(&mut self, key: CacheKey, entry: CacheEntry) {
57 let size = entry.size;
58
59 if let Some(old) = self.entries.remove(&key) {
61 self.current_bytes = self.current_bytes.saturating_sub(old.size);
62 self.access_order.retain(|k| k != &key);
63 }
64
65 self.evict_until_fits(size);
67
68 self.current_bytes += size;
69 self.access_order.push_back(key.clone());
70 self.entries.insert(key, entry);
71 }
72
73 fn remove(&mut self, key: &CacheKey) {
74 if let Some(entry) = self.entries.remove(key) {
75 self.current_bytes = self.current_bytes.saturating_sub(entry.size);
76 self.access_order.retain(|k| k != key);
77 }
78 }
79}
80
81impl MemoryCache {
82 pub fn new(max_bytes: usize) -> Self {
84 Self {
85 store: Mutex::new(LruStore::new(max_bytes)),
86 }
87 }
88
89 pub fn len(&self) -> usize {
91 self.store
92 .lock()
93 .unwrap_or_else(|e| e.into_inner())
94 .entries
95 .len()
96 }
97
98 pub fn is_empty(&self) -> bool {
100 self.len() == 0
101 }
102
103 pub fn current_bytes(&self) -> usize {
105 self.store
106 .lock()
107 .unwrap_or_else(|e| e.into_inner())
108 .current_bytes
109 }
110
111 pub fn clear(&self) {
113 let mut store = self.store.lock().unwrap_or_else(|e| e.into_inner());
114 store.entries.clear();
115 store.access_order.clear();
116 store.current_bytes = 0;
117 }
118}
119
120impl Default for MemoryCache {
121 fn default() -> Self {
122 Self::new(1024 * 1024 * 1024) }
124}
125
126impl CacheStore for MemoryCache {
127 fn get(&self, key: &CacheKey) -> Result<Option<Value>> {
128 let mut store = self.store.lock().unwrap_or_else(|e| e.into_inner());
129 if store.entries.contains_key(key) {
130 store.touch(key);
131 if let Some(entry) = store.entries.get_mut(key) {
132 entry.meta.last_accessed = Utc::now();
133 return Ok(Some(entry.value.clone()));
134 }
135 }
136 Ok(None)
137 }
138
139 fn put(&self, key: &CacheKey, value: &Value) -> Result<()> {
140 let size = estimate_size(value);
141 let now = Utc::now();
142
143 let mut store = self.store.lock().unwrap_or_else(|e| e.into_inner());
144 store.insert(
145 key.clone(),
146 CacheEntry {
147 value: value.clone(),
148 meta: EntryMeta {
149 key: key.clone(),
150 size_bytes: size as u64,
151 created_at: now,
152 last_accessed: now,
153 ttl: None,
154 origin: Origin::Computed {
155 node_id: String::new(),
156 run_id: String::new(),
157 },
158 },
159 size,
160 },
161 );
162 Ok(())
163 }
164
165 fn exists(&self, key: &CacheKey) -> Result<bool> {
166 Ok(self
167 .store
168 .lock()
169 .unwrap_or_else(|e| e.into_inner())
170 .entries
171 .contains_key(key))
172 }
173
174 fn remove(&self, key: &CacheKey) -> Result<()> {
175 self.store
176 .lock()
177 .unwrap_or_else(|e| e.into_inner())
178 .remove(key);
179 Ok(())
180 }
181
182 fn metadata(&self, key: &CacheKey) -> Result<Option<EntryMeta>> {
183 Ok(self
184 .store
185 .lock()
186 .unwrap_or_else(|e| e.into_inner())
187 .entries
188 .get(key)
189 .map(|e| e.meta.clone()))
190 }
191}
192
193fn estimate_size(value: &Value) -> usize {
194 match value {
195 Value::Tensor { values, shape } => {
196 values.len() * std::mem::size_of::<f64>() + shape.len() * std::mem::size_of::<usize>()
197 }
198 Value::Json(v) => v.to_string().len(),
199 Value::Bytes(b) => b.len(),
200 Value::Empty => 0,
201 _ => 0,
202 }
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::mem::size_of;

    /// Bytes `estimate_size` charges for a tensor with `values` elements and
    /// `dims` shape dimensions. Computed from `size_of` so the expectations
    /// hold regardless of the target's pointer width.
    fn tensor_bytes(values: usize, dims: usize) -> usize {
        values * size_of::<f64>() + dims * size_of::<usize>()
    }

    #[test]
    fn put_and_get() {
        let cache = MemoryCache::default();
        let key = CacheKey::hash_data(b"test");
        let value = Value::tensor(vec![1.0, 2.0, 3.0], vec![3]);

        cache.put(&key, &value).unwrap();
        let retrieved = cache.get(&key).unwrap().unwrap();
        assert_eq!(retrieved, value);
    }

    #[test]
    fn get_missing_returns_none() {
        let cache = MemoryCache::default();
        let key = CacheKey::hash_data(b"nonexistent");
        assert!(cache.get(&key).unwrap().is_none());
    }

    #[test]
    fn exists_check() {
        let cache = MemoryCache::default();
        let key = CacheKey::hash_data(b"test");
        assert!(!cache.exists(&key).unwrap());

        cache.put(&key, &Value::Empty).unwrap();
        assert!(cache.exists(&key).unwrap());
    }

    #[test]
    fn remove_entry() {
        let cache = MemoryCache::default();
        let key = CacheKey::hash_data(b"test");
        cache.put(&key, &Value::Empty).unwrap();
        assert_eq!(cache.len(), 1);

        cache.remove(&key).unwrap();
        assert_eq!(cache.len(), 0);
        assert!(!cache.exists(&key).unwrap());
    }

    #[test]
    fn metadata_available() {
        let cache = MemoryCache::default();
        let key = CacheKey::hash_data(b"test");
        let value = Value::tensor(vec![1.0; 100], vec![10, 10]);

        cache.put(&key, &value).unwrap();
        let meta = cache.metadata(&key).unwrap().unwrap();
        // 100 elements and 2 dimensions: 816 bytes where `usize` is 8 bytes.
        assert_eq!(meta.size_bytes, tensor_bytes(100, 2) as u64);
    }

    #[test]
    fn clear_empties_cache() {
        let cache = MemoryCache::default();
        cache
            .put(&CacheKey::hash_data(b"a"), &Value::Empty)
            .unwrap();
        cache
            .put(&CacheKey::hash_data(b"b"), &Value::Empty)
            .unwrap();
        assert_eq!(cache.len(), 2);

        cache.clear();
        assert!(cache.is_empty());
        assert_eq!(cache.current_bytes(), 0);
    }

    #[test]
    fn overwrite_existing_key() {
        let cache = MemoryCache::default();
        let key = CacheKey::hash_data(b"test");

        cache.put(&key, &Value::json(json!(1))).unwrap();
        cache.put(&key, &Value::json(json!(2))).unwrap();

        let val = cache.get(&key).unwrap().unwrap();
        assert_eq!(val, Value::json(json!(2)));
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn multiple_keys() {
        let cache = MemoryCache::default();
        for i in 0..10 {
            let key = CacheKey::hash_data(format!("key_{i}").as_bytes());
            let val = Value::tensor(vec![i as f64], vec![1]);
            cache.put(&key, &val).unwrap();
        }
        assert_eq!(cache.len(), 10);

        let key5 = CacheKey::hash_data(b"key_5");
        let val = cache.get(&key5).unwrap().unwrap();
        let (data, _) = val.as_tensor().unwrap();
        assert_eq!(data, &[5.0]);
    }

    #[test]
    fn lru_evicts_oldest_when_full() {
        // Each 5-element, 1-dimensional tensor is charged 5 f64s plus one
        // usize, so two fit in the 100-byte budget and a third does not.
        let cache = MemoryCache::new(100);

        let k1 = CacheKey::hash_data(b"first");
        let k2 = CacheKey::hash_data(b"second");
        let k3 = CacheKey::hash_data(b"third");

        cache
            .put(&k1, &Value::tensor(vec![0.0; 5], vec![5]))
            .unwrap();
        cache
            .put(&k2, &Value::tensor(vec![0.0; 5], vec![5]))
            .unwrap();
        assert_eq!(cache.len(), 2);

        // The third insert exceeds the budget and must push out the oldest.
        cache
            .put(&k3, &Value::tensor(vec![0.0; 5], vec![5]))
            .unwrap();

        assert!(!cache.exists(&k1).unwrap(), "k1 should be evicted");
        assert!(cache.exists(&k2).unwrap(), "k2 should remain");
        assert!(cache.exists(&k3).unwrap(), "k3 should remain");
    }

    #[test]
    fn lru_access_prevents_eviction() {
        let cache = MemoryCache::new(100);

        let k1 = CacheKey::hash_data(b"first");
        let k2 = CacheKey::hash_data(b"second");
        let k3 = CacheKey::hash_data(b"third");

        cache
            .put(&k1, &Value::tensor(vec![0.0; 5], vec![5]))
            .unwrap();
        cache
            .put(&k2, &Value::tensor(vec![0.0; 5], vec![5]))
            .unwrap();

        // Reading k1 makes it the most recently used, leaving k2 as the LRU.
        cache.get(&k1).unwrap();

        // Over budget again: this time k2 is the one to go.
        cache
            .put(&k3, &Value::tensor(vec![0.0; 5], vec![5]))
            .unwrap();

        assert!(cache.exists(&k1).unwrap(), "k1 was accessed, should remain");
        assert!(!cache.exists(&k2).unwrap(), "k2 was LRU, should be evicted");
        assert!(cache.exists(&k3).unwrap(), "k3 is new, should remain");
    }

    #[test]
    fn lru_tracks_byte_usage() {
        let cache = MemoryCache::new(1024);

        assert_eq!(cache.current_bytes(), 0);

        cache
            .put(
                &CacheKey::hash_data(b"a"),
                &Value::tensor(vec![0.0; 10], vec![10]),
            )
            .unwrap();
        // 10 elements and 1 dimension: 88 bytes where `usize` is 8 bytes.
        assert_eq!(cache.current_bytes(), tensor_bytes(10, 1));

        cache.remove(&CacheKey::hash_data(b"a")).unwrap();
        assert_eq!(cache.current_bytes(), 0);
    }

    #[test]
    fn lru_overwrite_updates_size() {
        let cache = MemoryCache::new(1024);

        let key = CacheKey::hash_data(b"key");
        cache
            .put(&key, &Value::tensor(vec![0.0; 10], vec![10]))
            .unwrap();
        let size1 = cache.current_bytes();

        // Same key, larger payload: the byte total must follow the new value.
        cache
            .put(&key, &Value::tensor(vec![0.0; 20], vec![20]))
            .unwrap();
        let size2 = cache.current_bytes();

        assert!(size2 > size1, "larger value should use more bytes");
        assert_eq!(cache.len(), 1, "should still be one entry");
    }
}