Skip to main content

simple_agents_cache/
memory.rs

//! In-memory cache implementation with LRU eviction.
2
3use async_trait::async_trait;
4use simple_agent_type::cache::Cache;
5use simple_agent_type::error::Result;
6use std::collections::{HashMap, VecDeque};
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9use tokio::sync::RwLock;
10
/// A single cached value together with its expiry deadline and LRU bookkeeping.
#[derive(Debug, Clone)]
struct CacheEntry {
    /// Cached payload; its length is what counts toward the byte limit.
    data: Vec<u8>,
    /// Instant at (or after) which the entry is considered expired.
    expires_at: Instant,
    /// Logical tick of the most recent write or read of this entry.
    ///
    /// Only the `access_order` record carrying this exact tick is live; older
    /// records for the same key are stale and are skipped during eviction.
    access_tick: u64,
}

/// Mutable cache state kept behind the cache's lock.
///
/// `total_size` tracks the sum of `data.len()` over all entries in `store`.
/// `access_order` is an access log ordered oldest-first. It may contain stale
/// records (superseded ticks or removed keys), which are skipped during
/// eviction and compacted away lazily.
#[derive(Debug, Default)]
struct CacheState {
    /// Key -> entry map.
    store: HashMap<String, CacheEntry>,
    /// Access log of `(key, tick)` records, oldest at the front.
    access_order: VecDeque<(String, u64)>,
    /// Sum of payload sizes in bytes across `store`.
    total_size: usize,
    /// Last tick handed out by `next_access_tick`.
    next_tick: u64,
}

impl CacheState {
    /// Stale records tolerated in `access_order`, beyond twice the live entry
    /// count, before the log is compacted. Keeps compaction amortized O(1) per
    /// access, even for very small caches.
    const ACCESS_ORDER_SLACK: usize = 32;

    /// Returns a fresh access tick (increasing, saturating at `u64::MAX`).
    fn next_access_tick(&mut self) -> u64 {
        self.next_tick = self.next_tick.saturating_add(1);
        self.next_tick
    }

    /// Returns `true` once the entry's deadline has been reached.
    fn is_expired(entry: &CacheEntry) -> bool {
        Instant::now() >= entry.expires_at
    }

    /// Computes `now + ttl`, clamping instead of panicking when the sum is not
    /// representable (for example when `Duration::MAX` is used as "never expire").
    fn deadline_after(now: Instant, ttl: Duration) -> Instant {
        let mut ttl = ttl;
        loop {
            match now.checked_add(ttl) {
                Some(deadline) => return deadline,
                // Halving always terminates: `now + Duration::ZERO` is representable.
                None => ttl /= 2,
            }
        }
    }

    /// Appends an access record for `key` and compacts the log if it has bloated.
    fn record_access(&mut self, key: &str, tick: u64) {
        self.access_order.push_back((key.to_string(), tick));
        self.compact_access_order_if_bloated();
    }

    /// Drops stale records from `access_order` once they dominate the log.
    ///
    /// Without this, every `get` on a cache that never reaches its limits
    /// (including an unlimited cache) would grow `access_order` without bound.
    fn compact_access_order_if_bloated(&mut self) {
        let threshold = self
            .store
            .len()
            .saturating_mul(2)
            .saturating_add(Self::ACCESS_ORDER_SLACK);
        if self.access_order.len() <= threshold {
            return;
        }

        let store = &self.store;
        // `retain` keeps relative order, so LRU ordering of live records is preserved.
        self.access_order.retain(|(key, tick)| {
            store
                .get(key.as_str())
                .is_some_and(|entry| entry.access_tick == *tick)
        });
    }

    /// Marks `key` as most recently used. Does nothing if the key is absent.
    fn touch_key(&mut self, key: &str) {
        let tick = self.next_access_tick();
        let Some(entry) = self.store.get_mut(key) else {
            return;
        };
        entry.access_tick = tick;
        self.record_access(key, tick);
    }

    /// Inserts or replaces `key`, resetting its TTL and marking it most recently used.
    ///
    /// Keeps `total_size` in sync by subtracting the replaced payload (if any)
    /// and adding the new one.
    fn upsert(&mut self, key: &str, value: Vec<u8>, ttl: Duration) {
        let tick = self.next_access_tick();
        let expires_at = Self::deadline_after(Instant::now(), ttl);
        let new_len = value.len();

        let previous = self.store.insert(
            key.to_string(),
            CacheEntry {
                data: value,
                expires_at,
                access_tick: tick,
            },
        );
        if let Some(previous) = previous {
            self.total_size = self.total_size.saturating_sub(previous.data.len());
        }
        self.total_size = self.total_size.saturating_add(new_len);

        self.record_access(key, tick);
    }

    /// Removes `key` (if present) and subtracts its payload from `total_size`.
    ///
    /// The key's access records become stale and are cleaned up lazily.
    fn remove_key(&mut self, key: &str) {
        if let Some(entry) = self.store.remove(key) {
            self.total_size = self.total_size.saturating_sub(entry.data.len());
        }
    }

    /// Removes every expired entry and updates `total_size` accordingly.
    fn evict_expired(&mut self) {
        let now = Instant::now();
        // Borrow the counter separately so the closure does not capture all of `self`.
        let total_size = &mut self.total_size;
        self.store.retain(|_, entry| {
            let alive = now < entry.expires_at;
            if !alive {
                *total_size = total_size.saturating_sub(entry.data.len());
            }
            alive
        });
        self.compact_access_order_if_bloated();
    }

    /// Evicts least-recently-used entries until both limits are satisfied.
    ///
    /// A limit of `0` disables that check. An entry that alone exceeds
    /// `max_size` is evicted too, including one that was just inserted.
    fn evict_lru_until_within_limits(&mut self, max_size: usize, max_entries: usize) {
        let over_size = |total_size: usize| max_size > 0 && total_size > max_size;
        let over_count = |entry_count: usize| max_entries > 0 && entry_count > max_entries;

        while over_size(self.total_size) || over_count(self.store.len()) {
            // Every live entry has a record in the log, so an empty log means
            // nothing is left to evict; stop rather than spin.
            let Some((key, tick)) = self.access_order.pop_front() else {
                break;
            };

            // Skip stale records: only the record matching the entry's current tick counts.
            let is_live = self
                .store
                .get(key.as_str())
                .is_some_and(|entry| entry.access_tick == tick);

            if is_live {
                self.remove_key(key.as_str());
            }
        }
    }
}
104
105/// In-memory cache with TTL and LRU eviction.
106///
107/// This cache stores entries in memory and automatically evicts:
108/// - Expired entries (based on TTL)
109/// - Least recently used entries (when max size or max entries exceeded)
110///
111/// # Example
112/// ```no_run
113/// use simple_agents_cache::InMemoryCache;
114/// use simple_agent_type::cache::Cache;
115/// use std::time::Duration;
116///
117/// # async fn example() {
118/// let cache = InMemoryCache::new(1024 * 1024, 100); // 1MB, 100 entries
119///
120/// cache.set("key1", b"value1".to_vec(), Duration::from_secs(60)).await.unwrap();
121/// let value = cache.get("key1").await.unwrap();
122/// assert_eq!(value, Some(b"value1".to_vec()));
123/// # }
124/// ```
125pub struct InMemoryCache {
126    /// Shared cache state.
127    state: Arc<RwLock<CacheState>>,
128    /// Maximum total size in bytes
129    max_size: usize,
130    /// Maximum number of entries
131    max_entries: usize,
132}
133
134impl InMemoryCache {
135    /// Create a new in-memory cache.
136    ///
137    /// # Arguments
138    /// - `max_size`: Maximum total size in bytes (0 = unlimited)
139    /// - `max_entries`: Maximum number of entries (0 = unlimited)
140    pub fn new(max_size: usize, max_entries: usize) -> Self {
141        Self {
142            state: Arc::new(RwLock::new(CacheState::default())),
143            max_size,
144            max_entries,
145        }
146    }
147}
148
149#[async_trait]
150impl Cache for InMemoryCache {
151    async fn get(&self, key: &str) -> Result<Option<Vec<u8>>> {
152        let mut state = self.state.write().await;
153        match state.store.get(key) {
154            Some(entry) if CacheState::is_expired(entry) => {
155                state.remove_key(key);
156                Ok(None)
157            }
158            Some(_) => {
159                let value = state.store.get(key).map(|entry| entry.data.clone());
160                state.touch_key(key);
161                Ok(value)
162            }
163            None => Ok(None),
164        }
165    }
166
167    async fn set(&self, key: &str, value: Vec<u8>, ttl: Duration) -> Result<()> {
168        let mut state = self.state.write().await;
169        state.upsert(key, value, ttl);
170        state.evict_expired();
171        state.evict_lru_until_within_limits(self.max_size, self.max_entries);
172
173        Ok(())
174    }
175
176    async fn delete(&self, key: &str) -> Result<()> {
177        let mut state = self.state.write().await;
178        state.remove_key(key);
179        Ok(())
180    }
181
182    async fn clear(&self) -> Result<()> {
183        let mut state = self.state.write().await;
184        state.store.clear();
185        state.access_order.clear();
186        state.total_size = 0;
187        Ok(())
188    }
189
190    fn name(&self) -> &str {
191        "in-memory"
192    }
193}
194
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::{sleep, Duration};

    #[tokio::test]
    async fn test_basic_set_get() {
        let cache = InMemoryCache::new(1024, 10);

        cache
            .set("key1", b"value1".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();
        let value = cache.get("key1").await.unwrap();

        assert_eq!(value, Some(b"value1".to_vec()));
    }

    #[tokio::test]
    async fn test_get_nonexistent() {
        let cache = InMemoryCache::new(1024, 10);
        let value = cache.get("nonexistent").await.unwrap();
        assert_eq!(value, None);
    }

    #[tokio::test]
    async fn test_ttl_expiration() {
        let cache = InMemoryCache::new(1024, 10);

        // Set with very short TTL
        cache
            .set("key1", b"value1".to_vec(), Duration::from_millis(100))
            .await
            .unwrap();

        // Should exist immediately
        let value = cache.get("key1").await.unwrap();
        assert_eq!(value, Some(b"value1".to_vec()));

        // Wait for expiration
        sleep(Duration::from_millis(150)).await;

        // Should be expired, and the expired entry removed from the store
        let value = cache.get("key1").await.unwrap();
        assert_eq!(value, None);
        assert!(!cache.state.read().await.store.contains_key("key1"));
    }

    #[tokio::test]
    async fn test_delete() {
        let cache = InMemoryCache::new(1024, 10);

        cache
            .set("key1", b"value1".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();
        assert!(cache.get("key1").await.unwrap().is_some());

        cache.delete("key1").await.unwrap();
        assert!(cache.get("key1").await.unwrap().is_none());
        assert_eq!(cache.state.read().await.total_size, 0);

        // Deleting a missing key is not an error.
        cache.delete("key1").await.unwrap();
    }

    #[tokio::test]
    async fn test_clear() {
        let cache = InMemoryCache::new(1024, 10);

        cache
            .set("key1", b"value1".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();
        cache
            .set("key2", b"value2".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();

        cache.clear().await.unwrap();

        assert!(cache.get("key1").await.unwrap().is_none());
        assert!(cache.get("key2").await.unwrap().is_none());
        assert_eq!(cache.state.read().await.total_size, 0);
    }

    #[tokio::test]
    async fn test_lru_eviction_by_count() {
        let cache = InMemoryCache::new(0, 2); // Max 2 entries

        cache
            .set("key1", b"value1".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();
        cache
            .set("key2", b"value2".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();

        // At this point we have 2 entries (at limit)

        // Add a third entry, should trigger eviction
        cache
            .set("key3", b"value3".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();

        let store = cache.state.read().await;
        assert!(
            store.store.len() <= 2,
            "Cache should not exceed max_entries"
        );
        assert!(
            !store.store.contains_key("key1"),
            "Least recently used key should be evicted"
        );
        assert!(
            store.store.contains_key("key2"),
            "More recently used key should survive"
        );
        assert!(
            store.store.contains_key("key3"),
            "Most recently added key should exist"
        );
    }

    #[tokio::test]
    async fn test_lru_eviction_by_size() {
        let cache = InMemoryCache::new(10, 0); // Max 10 bytes

        cache
            .set("key1", vec![1, 2, 3, 4, 5], Duration::from_secs(60))
            .await
            .unwrap();
        cache
            .set("key2", vec![6, 7, 8, 9, 10], Duration::from_secs(60))
            .await
            .unwrap();

        // Access key1 to make it more recently used
        cache.get("key1").await.unwrap();

        // Add a new entry that would exceed size limit
        cache
            .set("key3", vec![11, 12], Duration::from_secs(60))
            .await
            .unwrap();

        // key2 is the least recently used and must be evicted
        assert!(cache.get("key2").await.unwrap().is_none());
        // key1 should still exist
        assert!(cache.get("key1").await.unwrap().is_some());
        // key3 should exist
        assert!(cache.get("key3").await.unwrap().is_some());
        assert_eq!(cache.state.read().await.total_size, 7);
    }

    #[tokio::test]
    async fn test_concurrent_gets_all_see_value() {
        // `get` takes the write lock (it updates LRU state), so concurrent readers
        // are serialized; this test only checks they all observe the value.
        let cache = Arc::new(InMemoryCache::new(1024, 10));
        cache
            .set("shared", b"value".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();

        let mut handles = Vec::new();
        for _ in 0..25 {
            let cache = cache.clone();
            handles.push(tokio::spawn(
                async move { cache.get("shared").await.unwrap() },
            ));
        }

        for handle in handles {
            assert_eq!(handle.await.unwrap(), Some(b"value".to_vec()));
        }
    }

    #[tokio::test]
    async fn test_cache_name() {
        let cache = InMemoryCache::new(1024, 10);
        assert_eq!(cache.name(), "in-memory");
    }

    #[tokio::test]
    async fn test_update_existing_key_keeps_latest_value() {
        let cache = InMemoryCache::new(6, 2);

        cache
            .set("key", b"abc".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();
        cache
            .set("key", b"abcdef".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();

        let value = cache.get("key").await.unwrap();
        assert_eq!(value, Some(b"abcdef".to_vec()));
        // The replaced payload's bytes must not be double-counted.
        assert_eq!(cache.state.read().await.total_size, 6);
    }
}
380}