Skip to main content

simple_agents_cache/
memory.rs

1//! In-memory cache implementation with LRU eviction.
2
3use async_trait::async_trait;
4use simple_agent_type::cache::Cache;
5use simple_agent_type::error::Result;
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9use tokio::sync::RwLock;
10
/// One stored value plus the two timestamps the cache needs: when it stops
/// being valid and when it was last used.
#[derive(Clone, Debug)]
struct CacheEntry {
    /// Raw bytes stored for the key.
    data: Vec<u8>,
    /// Deadline at (and after) which the entry counts as expired.
    expires_at: Instant,
    /// Time of the most recent insert or touch; LRU eviction removes the
    /// smallest values first.
    last_accessed: Instant,
}

impl CacheEntry {
    /// Returns `true` once the current instant has reached `expires_at`.
    fn is_expired(&self) -> bool {
        let now = Instant::now();
        self.expires_at <= now
    }

    /// Records the current instant as the most recent access.
    fn touch(&mut self) {
        let now = Instant::now();
        self.last_accessed = now;
    }
}
33
34/// In-memory cache with TTL and LRU eviction.
35///
36/// This cache stores entries in memory and automatically evicts:
37/// - Expired entries (based on TTL)
38/// - Least recently used entries (when max size or max entries exceeded)
39///
40/// # Example
41/// ```no_run
42/// use simple_agents_cache::InMemoryCache;
43/// use simple_agent_type::cache::Cache;
44/// use std::time::Duration;
45///
46/// # async fn example() {
47/// let cache = InMemoryCache::new(1024 * 1024, 100); // 1MB, 100 entries
48///
49/// cache.set("key1", b"value1".to_vec(), Duration::from_secs(60)).await.unwrap();
50/// let value = cache.get("key1").await.unwrap();
51/// assert_eq!(value, Some(b"value1".to_vec()));
52/// # }
53/// ```
54pub struct InMemoryCache {
55    /// The cache store
56    store: Arc<RwLock<HashMap<String, CacheEntry>>>,
57    /// Maximum total size in bytes
58    max_size: usize,
59    /// Maximum number of entries
60    max_entries: usize,
61}
62
63impl InMemoryCache {
64    /// Create a new in-memory cache.
65    ///
66    /// # Arguments
67    /// - `max_size`: Maximum total size in bytes (0 = unlimited)
68    /// - `max_entries`: Maximum number of entries (0 = unlimited)
69    pub fn new(max_size: usize, max_entries: usize) -> Self {
70        Self {
71            store: Arc::new(RwLock::new(HashMap::new())),
72            max_size,
73            max_entries,
74        }
75    }
76
77    /// Evict expired entries from the cache.
78    async fn evict_expired(&self) {
79        let mut store = self.store.write().await;
80        store.retain(|_, entry| !entry.is_expired());
81    }
82
83    /// Evict least recently used entries to enforce size/count limits.
84    async fn evict_lru(&self) {
85        let mut store = self.store.write().await;
86
87        // Calculate current size
88        let current_size: usize = store.values().map(|e| e.data.len()).sum();
89
90        // Check if we need to evict based on size or entry count
91        let needs_eviction = (self.max_size > 0 && current_size > self.max_size)
92            || (self.max_entries > 0 && store.len() > self.max_entries);
93
94        if !needs_eviction {
95            return;
96        }
97
98        // Sort entries by last accessed time (oldest first)
99        let mut entries: Vec<_> = store
100            .iter()
101            .map(|(k, v)| (k.clone(), v.last_accessed, v.data.len()))
102            .collect();
103        entries.sort_by_key(|(_, accessed, _)| *accessed);
104
105        // Remove oldest entries until we're under the limit
106        let mut remaining_size = current_size;
107        let mut remaining_count = store.len();
108        let mut entries_to_remove = Vec::new();
109
110        for (key, _, size) in entries {
111            // Check if we're now under both limits
112            let under_size_limit = self.max_size == 0 || remaining_size <= self.max_size;
113            let under_count_limit = self.max_entries == 0 || remaining_count <= self.max_entries;
114
115            if under_size_limit && under_count_limit {
116                break;
117            }
118
119            // Mark this entry for removal
120            remaining_size = remaining_size.saturating_sub(size);
121            remaining_count = remaining_count.saturating_sub(1);
122            entries_to_remove.push(key);
123        }
124
125        // Remove the marked entries
126        for key in entries_to_remove {
127            store.remove(&key);
128        }
129    }
130}
131
132#[async_trait]
133impl Cache for InMemoryCache {
134    async fn get(&self, key: &str) -> Result<Option<Vec<u8>>> {
135        // First evict expired entries
136        self.evict_expired().await;
137
138        let mut store = self.store.write().await;
139
140        if let Some(entry) = store.get_mut(key) {
141            if entry.is_expired() {
142                store.remove(key);
143                return Ok(None);
144            }
145
146            entry.touch();
147            Ok(Some(entry.data.clone()))
148        } else {
149            Ok(None)
150        }
151    }
152
153    async fn set(&self, key: &str, value: Vec<u8>, ttl: Duration) -> Result<()> {
154        let entry = CacheEntry {
155            data: value,
156            expires_at: Instant::now() + ttl,
157            last_accessed: Instant::now(),
158        };
159
160        {
161            let mut store = self.store.write().await;
162            store.insert(key.to_string(), entry);
163        }
164
165        // Evict if needed
166        self.evict_lru().await;
167
168        Ok(())
169    }
170
171    async fn delete(&self, key: &str) -> Result<()> {
172        let mut store = self.store.write().await;
173        store.remove(key);
174        Ok(())
175    }
176
177    async fn clear(&self) -> Result<()> {
178        let mut store = self.store.write().await;
179        store.clear();
180        Ok(())
181    }
182
183    fn name(&self) -> &str {
184        "in-memory"
185    }
186}
187
#[cfg(test)]
mod tests {
    //! Behavioral tests for `InMemoryCache`: basic CRUD, TTL expiry, and
    //! LRU eviction by entry count and by byte size.

    use super::*;
    use tokio::time::{sleep, Duration};

    #[tokio::test]
    async fn test_basic_set_get() {
        let cache = InMemoryCache::new(1024, 10);

        cache
            .set("key1", b"value1".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();
        let value = cache.get("key1").await.unwrap();

        assert_eq!(value, Some(b"value1".to_vec()));
    }

    #[tokio::test]
    async fn test_get_nonexistent() {
        let cache = InMemoryCache::new(1024, 10);
        let value = cache.get("nonexistent").await.unwrap();
        assert_eq!(value, None);
    }

    #[tokio::test]
    async fn test_overwrite_replaces_value() {
        let cache = InMemoryCache::new(1024, 10);

        cache
            .set("key1", b"old".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();
        cache
            .set("key1", b"new".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();

        assert_eq!(cache.get("key1").await.unwrap(), Some(b"new".to_vec()));
        // Overwriting must not create a second entry.
        assert_eq!(cache.store.read().await.len(), 1);
    }

    #[tokio::test]
    async fn test_ttl_expiration() {
        let cache = InMemoryCache::new(1024, 10);

        // Set with very short TTL
        cache
            .set("key1", b"value1".to_vec(), Duration::from_millis(100))
            .await
            .unwrap();

        // Should exist immediately
        let value = cache.get("key1").await.unwrap();
        assert_eq!(value, Some(b"value1".to_vec()));

        // Wait for expiration
        sleep(Duration::from_millis(150)).await;

        // Should be expired
        let value = cache.get("key1").await.unwrap();
        assert_eq!(value, None);
    }

    #[tokio::test]
    async fn test_delete() {
        let cache = InMemoryCache::new(1024, 10);

        cache
            .set("key1", b"value1".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();
        assert!(cache.get("key1").await.unwrap().is_some());

        cache.delete("key1").await.unwrap();
        assert!(cache.get("key1").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_delete_missing_key_is_ok() {
        let cache = InMemoryCache::new(1024, 10);

        cache.delete("missing").await.unwrap();
        assert!(cache.get("missing").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_clear() {
        let cache = InMemoryCache::new(1024, 10);

        cache
            .set("key1", b"value1".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();
        cache
            .set("key2", b"value2".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();

        cache.clear().await.unwrap();

        assert!(cache.get("key1").await.unwrap().is_none());
        assert!(cache.get("key2").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_lru_eviction_by_count() {
        let cache = InMemoryCache::new(0, 2); // Max 2 entries

        cache
            .set("key1", b"value1".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();
        cache
            .set("key2", b"value2".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();

        // At this point we have 2 entries (at limit)

        // Add a third entry, should trigger eviction
        cache
            .set("key3", b"value3".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();

        // Exactly one entry is evicted to get back to the limit.
        let store = cache.store.read().await;
        assert_eq!(store.len(), 2, "Cache should hold exactly max_entries");
        // key3 (most recent) should definitely exist
        assert!(
            store.contains_key("key3"),
            "Most recently added key should exist"
        );
    }

    #[tokio::test]
    async fn test_lru_eviction_by_size() {
        let cache = InMemoryCache::new(10, 0); // Max 10 bytes

        cache
            .set("key1", vec![1, 2, 3, 4, 5], Duration::from_secs(60))
            .await
            .unwrap();
        cache
            .set("key2", vec![6, 7, 8, 9, 10], Duration::from_secs(60))
            .await
            .unwrap();

        // Access key1 to make it more recently used
        cache.get("key1").await.unwrap();

        // Add a new entry that would exceed size limit
        cache
            .set("key3", vec![11, 12], Duration::from_secs(60))
            .await
            .unwrap();

        // key1 should still exist, key2 should be evicted
        assert!(cache.get("key1").await.unwrap().is_some());
        assert!(cache.get("key2").await.unwrap().is_none());
        // key3 should exist
        assert!(cache.get("key3").await.unwrap().is_some());
    }

    #[tokio::test]
    async fn test_cache_name() {
        let cache = InMemoryCache::new(1024, 10);
        assert_eq!(cache.name(), "in-memory");
    }
}