simple_agents_cache/
memory.rs1use async_trait::async_trait;
4use simple_agent_type::cache::Cache;
5use simple_agent_type::error::Result;
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9use tokio::sync::RwLock;
10
/// One stored value plus the bookkeeping needed for TTL expiry and LRU eviction.
#[derive(Debug, Clone)]
struct CacheEntry {
    /// The cached payload; its byte length counts towards the cache's size budget.
    data: Vec<u8>,
    /// Deadline from which the entry is treated as stale.
    expires_at: Instant,
    /// Most recent write or read of this entry; the oldest value is evicted first.
    last_accessed: Instant,
}

impl CacheEntry {
    /// Reports whether the deadline has been reached. An entry whose deadline equals
    /// the current instant already counts as expired.
    fn is_expired(&self) -> bool {
        let now = Instant::now();
        self.expires_at <= now
    }

    /// Records a use of the entry at the current instant, making it the most
    /// recently accessed one.
    fn touch(&mut self) {
        let now = Instant::now();
        self.last_accessed = now;
    }
}
33
34pub struct InMemoryCache {
55 store: Arc<RwLock<HashMap<String, CacheEntry>>>,
57 max_size: usize,
59 max_entries: usize,
61}
62
63impl InMemoryCache {
64 pub fn new(max_size: usize, max_entries: usize) -> Self {
70 Self {
71 store: Arc::new(RwLock::new(HashMap::new())),
72 max_size,
73 max_entries,
74 }
75 }
76
77 async fn evict_expired(&self) {
79 let mut store = self.store.write().await;
80 store.retain(|_, entry| !entry.is_expired());
81 }
82
83 async fn evict_lru(&self) {
85 let mut store = self.store.write().await;
86
87 let current_size: usize = store.values().map(|e| e.data.len()).sum();
89
90 let needs_eviction = (self.max_size > 0 && current_size > self.max_size)
92 || (self.max_entries > 0 && store.len() > self.max_entries);
93
94 if !needs_eviction {
95 return;
96 }
97
98 let mut entries: Vec<_> = store
100 .iter()
101 .map(|(k, v)| (k.clone(), v.last_accessed, v.data.len()))
102 .collect();
103 entries.sort_by_key(|(_, accessed, _)| *accessed);
104
105 let mut remaining_size = current_size;
107 let mut remaining_count = store.len();
108 let mut entries_to_remove = Vec::new();
109
110 for (key, _, size) in entries {
111 let under_size_limit = self.max_size == 0 || remaining_size <= self.max_size;
113 let under_count_limit = self.max_entries == 0 || remaining_count <= self.max_entries;
114
115 if under_size_limit && under_count_limit {
116 break;
117 }
118
119 remaining_size = remaining_size.saturating_sub(size);
121 remaining_count = remaining_count.saturating_sub(1);
122 entries_to_remove.push(key);
123 }
124
125 for key in entries_to_remove {
127 store.remove(&key);
128 }
129 }
130}
131
132#[async_trait]
133impl Cache for InMemoryCache {
134 async fn get(&self, key: &str) -> Result<Option<Vec<u8>>> {
135 {
136 let store = self.store.read().await;
137 match store.get(key) {
138 Some(entry) if !entry.is_expired() => {
139 drop(store);
140
141 let mut store = self.store.write().await;
142 if let Some(entry) = store.get_mut(key) {
143 if entry.is_expired() {
144 store.remove(key);
145 return Ok(None);
146 }
147
148 entry.touch();
149 return Ok(Some(entry.data.clone()));
150 }
151
152 return Ok(None);
153 }
154 Some(_) => {} None => return Ok(None),
156 }
157 }
158
159 let mut store = self.store.write().await;
160 if let Some(entry) = store.get_mut(key) {
161 if entry.is_expired() {
162 store.remove(key);
163 return Ok(None);
164 }
165
166 entry.touch();
167 Ok(Some(entry.data.clone()))
168 } else {
169 Ok(None)
170 }
171 }
172
173 async fn set(&self, key: &str, value: Vec<u8>, ttl: Duration) -> Result<()> {
174 let entry = CacheEntry {
175 data: value,
176 expires_at: Instant::now() + ttl,
177 last_accessed: Instant::now(),
178 };
179
180 {
181 let mut store = self.store.write().await;
182 store.insert(key.to_string(), entry);
183 }
184
185 self.evict_expired().await;
187
188 self.evict_lru().await;
190
191 Ok(())
192 }
193
194 async fn delete(&self, key: &str) -> Result<()> {
195 let mut store = self.store.write().await;
196 store.remove(key);
197 Ok(())
198 }
199
200 async fn clear(&self) -> Result<()> {
201 let mut store = self.store.write().await;
202 store.clear();
203 Ok(())
204 }
205
206 fn name(&self) -> &str {
207 "in-memory"
208 }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::{sleep, Duration};

    /// Short pause so consecutive operations get strictly increasing `Instant`s even on
    /// platforms with a coarse monotonic clock. The LRU victim assertions rely on that.
    async fn tick() {
        sleep(Duration::from_millis(2)).await;
    }

    #[tokio::test]
    async fn test_basic_set_get() {
        let cache = InMemoryCache::new(1024, 10);

        cache
            .set("key1", b"value1".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();
        let value = cache.get("key1").await.unwrap();

        assert_eq!(value, Some(b"value1".to_vec()));
    }

    #[tokio::test]
    async fn test_get_nonexistent() {
        let cache = InMemoryCache::new(1024, 10);
        let value = cache.get("nonexistent").await.unwrap();
        assert_eq!(value, None);
    }

    #[tokio::test]
    async fn test_overwrite_replaces_value() {
        let cache = InMemoryCache::new(1024, 10);

        cache
            .set("key1", b"first".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();
        cache
            .set("key1", b"second".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();

        assert_eq!(cache.get("key1").await.unwrap(), Some(b"second".to_vec()));
        assert_eq!(cache.store.read().await.len(), 1);
    }

    #[tokio::test]
    async fn test_ttl_expiration() {
        let cache = InMemoryCache::new(1024, 10);

        cache
            .set("key1", b"value1".to_vec(), Duration::from_millis(100))
            .await
            .unwrap();

        let value = cache.get("key1").await.unwrap();
        assert_eq!(value, Some(b"value1".to_vec()));

        sleep(Duration::from_millis(150)).await;

        let value = cache.get("key1").await.unwrap();
        assert_eq!(value, None);
    }

    #[tokio::test]
    async fn test_set_purges_expired_entries() {
        let cache = InMemoryCache::new(1024, 10);

        cache
            .set("short", b"gone".to_vec(), Duration::from_millis(20))
            .await
            .unwrap();
        sleep(Duration::from_millis(40)).await;
        cache
            .set("long", b"kept".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();

        let store = cache.store.read().await;
        assert!(!store.contains_key("short"), "expired entry should be purged on set");
        assert!(store.contains_key("long"));
    }

    #[tokio::test]
    async fn test_delete() {
        let cache = InMemoryCache::new(1024, 10);

        cache
            .set("key1", b"value1".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();
        assert!(cache.get("key1").await.unwrap().is_some());

        cache.delete("key1").await.unwrap();
        assert!(cache.get("key1").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_clear() {
        let cache = InMemoryCache::new(1024, 10);

        cache
            .set("key1", b"value1".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();
        cache
            .set("key2", b"value2".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();

        cache.clear().await.unwrap();

        assert!(cache.get("key1").await.unwrap().is_none());
        assert!(cache.get("key2").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_lru_eviction_by_count() {
        let cache = InMemoryCache::new(0, 2);

        cache
            .set("key1", b"value1".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();
        tick().await;
        cache
            .set("key2", b"value2".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();
        tick().await;
        cache
            .set("key3", b"value3".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();

        let store = cache.store.read().await;
        assert_eq!(store.len(), 2, "Cache should hold exactly max_entries");
        assert!(!store.contains_key("key1"), "Least recently used key should be evicted");
        assert!(store.contains_key("key2"));
        assert!(
            store.contains_key("key3"),
            "Most recently added key should exist"
        );
    }

    #[tokio::test]
    async fn test_lru_eviction_by_size() {
        let cache = InMemoryCache::new(10, 0);

        cache
            .set("key1", vec![1, 2, 3, 4, 5], Duration::from_secs(60))
            .await
            .unwrap();
        tick().await;
        cache
            .set("key2", vec![6, 7, 8, 9, 10], Duration::from_secs(60))
            .await
            .unwrap();
        tick().await;

        // Reading key1 makes key2 the least recently used entry.
        cache.get("key1").await.unwrap();
        tick().await;

        cache
            .set("key3", vec![11, 12], Duration::from_secs(60))
            .await
            .unwrap();

        assert!(cache.get("key1").await.unwrap().is_some());
        assert!(cache.get("key2").await.unwrap().is_none(), "key2 should be evicted");
        assert!(cache.get("key3").await.unwrap().is_some());
    }

    #[tokio::test]
    async fn test_entry_larger_than_max_size_is_not_retained() {
        let cache = InMemoryCache::new(4, 0);

        cache
            .set("big", vec![0; 8], Duration::from_secs(60))
            .await
            .unwrap();

        assert!(cache.get("big").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_concurrent_gets_return_shared_value() {
        let cache = Arc::new(InMemoryCache::new(1024, 10));
        cache
            .set("shared", b"value".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();

        let mut handles = Vec::new();
        for _ in 0..25 {
            let cache = cache.clone();
            handles.push(tokio::spawn(
                async move { cache.get("shared").await.unwrap() },
            ));
        }

        for handle in handles {
            assert_eq!(handle.await.unwrap(), Some(b"value".to_vec()));
        }
    }

    #[tokio::test]
    async fn test_cache_name() {
        let cache = InMemoryCache::new(1024, 10);
        assert_eq!(cache.name(), "in-memory");
    }
}