Skip to main content

simple_agents_cache/
memory.rs

//! In-memory cache implementation with TTL expiry and LRU eviction.
2
3use async_trait::async_trait;
4use simple_agent_type::cache::Cache;
5use simple_agent_type::error::Result;
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9use tokio::sync::RwLock;
10
/// One stored value plus the bookkeeping needed for TTL expiry and LRU eviction.
#[derive(Clone, Debug)]
struct CacheEntry {
    /// Bytes stored under the key.
    data: Vec<u8>,
    /// Deadline from which the entry is no longer served.
    expires_at: Instant,
    /// Time of the last insert or successful read; the LRU pass evicts
    /// entries with the oldest timestamp first.
    last_accessed: Instant,
}

impl CacheEntry {
    /// Reports whether the deadline has been reached. An entry whose
    /// `expires_at` equals the current instant already counts as expired.
    fn is_expired(&self) -> bool {
        self.expires_at <= Instant::now()
    }

    /// Stamps the entry as used right now, making it the most recently used.
    fn touch(&mut self) {
        let now = Instant::now();
        self.last_accessed = now;
    }
}
33
34/// In-memory cache with TTL and LRU eviction.
35///
36/// This cache stores entries in memory and automatically evicts:
37/// - Expired entries (based on TTL)
38/// - Least recently used entries (when max size or max entries exceeded)
39///
40/// # Example
41/// ```no_run
42/// use simple_agents_cache::InMemoryCache;
43/// use simple_agent_type::cache::Cache;
44/// use std::time::Duration;
45///
46/// # async fn example() {
47/// let cache = InMemoryCache::new(1024 * 1024, 100); // 1MB, 100 entries
48///
49/// cache.set("key1", b"value1".to_vec(), Duration::from_secs(60)).await.unwrap();
50/// let value = cache.get("key1").await.unwrap();
51/// assert_eq!(value, Some(b"value1".to_vec()));
52/// # }
53/// ```
54pub struct InMemoryCache {
55    /// The cache store
56    store: Arc<RwLock<HashMap<String, CacheEntry>>>,
57    /// Maximum total size in bytes
58    max_size: usize,
59    /// Maximum number of entries
60    max_entries: usize,
61}
62
63impl InMemoryCache {
64    /// Create a new in-memory cache.
65    ///
66    /// # Arguments
67    /// - `max_size`: Maximum total size in bytes (0 = unlimited)
68    /// - `max_entries`: Maximum number of entries (0 = unlimited)
69    pub fn new(max_size: usize, max_entries: usize) -> Self {
70        Self {
71            store: Arc::new(RwLock::new(HashMap::new())),
72            max_size,
73            max_entries,
74        }
75    }
76
77    /// Evict expired entries from the cache.
78    async fn evict_expired(&self) {
79        let mut store = self.store.write().await;
80        store.retain(|_, entry| !entry.is_expired());
81    }
82
83    /// Evict least recently used entries to enforce size/count limits.
84    async fn evict_lru(&self) {
85        let mut store = self.store.write().await;
86
87        // Calculate current size
88        let current_size: usize = store.values().map(|e| e.data.len()).sum();
89
90        // Check if we need to evict based on size or entry count
91        let needs_eviction = (self.max_size > 0 && current_size > self.max_size)
92            || (self.max_entries > 0 && store.len() > self.max_entries);
93
94        if !needs_eviction {
95            return;
96        }
97
98        // Sort entries by last accessed time (oldest first)
99        let mut entries: Vec<_> = store
100            .iter()
101            .map(|(k, v)| (k.clone(), v.last_accessed, v.data.len()))
102            .collect();
103        entries.sort_by_key(|(_, accessed, _)| *accessed);
104
105        // Remove oldest entries until we're under the limit
106        let mut remaining_size = current_size;
107        let mut remaining_count = store.len();
108        let mut entries_to_remove = Vec::new();
109
110        for (key, _, size) in entries {
111            // Check if we're now under both limits
112            let under_size_limit = self.max_size == 0 || remaining_size <= self.max_size;
113            let under_count_limit = self.max_entries == 0 || remaining_count <= self.max_entries;
114
115            if under_size_limit && under_count_limit {
116                break;
117            }
118
119            // Mark this entry for removal
120            remaining_size = remaining_size.saturating_sub(size);
121            remaining_count = remaining_count.saturating_sub(1);
122            entries_to_remove.push(key);
123        }
124
125        // Remove the marked entries
126        for key in entries_to_remove {
127            store.remove(&key);
128        }
129    }
130}
131
132#[async_trait]
133impl Cache for InMemoryCache {
134    async fn get(&self, key: &str) -> Result<Option<Vec<u8>>> {
135        {
136            let store = self.store.read().await;
137            match store.get(key) {
138                Some(entry) if !entry.is_expired() => {
139                    drop(store);
140
141                    let mut store = self.store.write().await;
142                    if let Some(entry) = store.get_mut(key) {
143                        if entry.is_expired() {
144                            store.remove(key);
145                            return Ok(None);
146                        }
147
148                        entry.touch();
149                        return Ok(Some(entry.data.clone()));
150                    }
151
152                    return Ok(None);
153                }
154                Some(_) => {} // expired; clean up below
155                None => return Ok(None),
156            }
157        }
158
159        let mut store = self.store.write().await;
160        if let Some(entry) = store.get_mut(key) {
161            if entry.is_expired() {
162                store.remove(key);
163                return Ok(None);
164            }
165
166            entry.touch();
167            Ok(Some(entry.data.clone()))
168        } else {
169            Ok(None)
170        }
171    }
172
173    async fn set(&self, key: &str, value: Vec<u8>, ttl: Duration) -> Result<()> {
174        let entry = CacheEntry {
175            data: value,
176            expires_at: Instant::now() + ttl,
177            last_accessed: Instant::now(),
178        };
179
180        {
181            let mut store = self.store.write().await;
182            store.insert(key.to_string(), entry);
183        }
184
185        // Periodically clear expired entries before enforcing limits
186        self.evict_expired().await;
187
188        // Evict if needed
189        self.evict_lru().await;
190
191        Ok(())
192    }
193
194    async fn delete(&self, key: &str) -> Result<()> {
195        let mut store = self.store.write().await;
196        store.remove(key);
197        Ok(())
198    }
199
200    async fn clear(&self) -> Result<()> {
201        let mut store = self.store.write().await;
202        store.clear();
203        Ok(())
204    }
205
206    fn name(&self) -> &str {
207        "in-memory"
208    }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::{sleep, Duration};

    #[tokio::test]
    async fn test_basic_set_get() {
        let cache = InMemoryCache::new(1024, 10);

        cache
            .set("key1", b"value1".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();
        let value = cache.get("key1").await.unwrap();

        assert_eq!(value, Some(b"value1".to_vec()));
    }

    #[tokio::test]
    async fn test_get_nonexistent() {
        let cache = InMemoryCache::new(1024, 10);
        let value = cache.get("nonexistent").await.unwrap();
        assert_eq!(value, None);
    }

    #[tokio::test]
    async fn test_ttl_expiration() {
        let cache = InMemoryCache::new(1024, 10);

        // Set with very short TTL
        cache
            .set("key1", b"value1".to_vec(), Duration::from_millis(100))
            .await
            .unwrap();

        // Should exist immediately
        let value = cache.get("key1").await.unwrap();
        assert_eq!(value, Some(b"value1".to_vec()));

        // Wait for expiration
        sleep(Duration::from_millis(150)).await;

        // Should be expired
        let value = cache.get("key1").await.unwrap();
        assert_eq!(value, None);
    }

    #[tokio::test]
    async fn test_delete() {
        let cache = InMemoryCache::new(1024, 10);

        cache
            .set("key1", b"value1".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();
        assert!(cache.get("key1").await.unwrap().is_some());

        cache.delete("key1").await.unwrap();
        assert!(cache.get("key1").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_clear() {
        let cache = InMemoryCache::new(1024, 10);

        cache
            .set("key1", b"value1".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();
        cache
            .set("key2", b"value2".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();

        cache.clear().await.unwrap();

        assert!(cache.get("key1").await.unwrap().is_none());
        assert!(cache.get("key2").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_lru_eviction_by_count() {
        let cache = InMemoryCache::new(0, 2); // Max 2 entries

        cache
            .set("key1", b"value1".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();
        cache
            .set("key2", b"value2".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();

        // At this point we have 2 entries (at limit)

        // Add a third entry, should trigger eviction
        cache
            .set("key3", b"value3".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();

        // Exactly one entry is evicted to get back down to the limit.
        let store = cache.store.read().await;
        assert_eq!(store.len(), 2, "Cache should hold exactly max_entries entries");
        // key3 (most recent) should definitely exist
        assert!(
            store.contains_key("key3"),
            "Most recently added key should exist"
        );
    }

    #[tokio::test]
    async fn test_lru_eviction_by_size() {
        let cache = InMemoryCache::new(10, 0); // Max 10 bytes

        cache
            .set("key1", vec![1, 2, 3, 4, 5], Duration::from_secs(60))
            .await
            .unwrap();
        cache
            .set("key2", vec![6, 7, 8, 9, 10], Duration::from_secs(60))
            .await
            .unwrap();

        // Access key1 to make it more recently used
        cache.get("key1").await.unwrap();

        // Add a new entry that would exceed size limit
        cache
            .set("key3", vec![11, 12], Duration::from_secs(60))
            .await
            .unwrap();

        // key1 should still exist, key2 should be evicted
        assert!(cache.get("key1").await.unwrap().is_some());
        assert!(cache.get("key2").await.unwrap().is_none());
        // key3 should exist
        assert!(cache.get("key3").await.unwrap().is_some());
    }

    #[tokio::test]
    async fn test_zero_limits_are_unlimited() {
        let cache = InMemoryCache::new(0, 0);

        for i in 0..50 {
            cache
                .set(&format!("key{}", i), vec![0; 64], Duration::from_secs(60))
                .await
                .unwrap();
        }

        assert_eq!(cache.store.read().await.len(), 50);
    }

    // Hits refresh the LRU timestamp under the write lock, so concurrent
    // readers are serialized; this checks they all still observe the value.
    #[tokio::test]
    async fn test_concurrent_gets_return_shared_value() {
        let cache = Arc::new(InMemoryCache::new(1024, 10));
        cache
            .set("shared", b"value".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();

        let mut handles = Vec::new();
        for _ in 0..25 {
            let cache = cache.clone();
            handles.push(tokio::spawn(
                async move { cache.get("shared").await.unwrap() },
            ));
        }

        for handle in handles {
            assert_eq!(handle.await.unwrap(), Some(b"value".to_vec()));
        }
    }

    #[tokio::test]
    async fn test_cache_name() {
        let cache = InMemoryCache::new(1024, 10);
        assert_eq!(cache.name(), "in-memory");
    }
}
376}