simple_agents_cache/
memory.rs1use async_trait::async_trait;
4use simple_agent_type::cache::Cache;
5use simple_agent_type::error::Result;
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9use tokio::sync::RwLock;
10
/// One stored value plus the bookkeeping needed for TTL expiry and
/// least-recently-used eviction.
#[derive(Clone, Debug)]
struct CacheEntry {
    /// Raw payload bytes handed to `set`.
    data: Vec<u8>,
    /// Deadline at (or after) which the entry counts as expired.
    expires_at: Instant,
    /// Time of the latest access; refreshed by `touch` and used to order
    /// LRU eviction (oldest first).
    last_accessed: Instant,
}

impl CacheEntry {
    /// Reports whether the deadline has been reached (inclusive).
    fn is_expired(&self) -> bool {
        let now = Instant::now();
        self.expires_at <= now
    }

    /// Stamps the entry with the current time as its most recent access.
    fn touch(&mut self) {
        let now = Instant::now();
        self.last_accessed = now;
    }
}
33
34pub struct InMemoryCache {
55 store: Arc<RwLock<HashMap<String, CacheEntry>>>,
57 max_size: usize,
59 max_entries: usize,
61}
62
63impl InMemoryCache {
64 pub fn new(max_size: usize, max_entries: usize) -> Self {
70 Self {
71 store: Arc::new(RwLock::new(HashMap::new())),
72 max_size,
73 max_entries,
74 }
75 }
76
77 async fn evict_expired(&self) {
79 let mut store = self.store.write().await;
80 store.retain(|_, entry| !entry.is_expired());
81 }
82
83 async fn evict_lru(&self) {
85 let mut store = self.store.write().await;
86
87 let current_size: usize = store.values().map(|e| e.data.len()).sum();
89
90 let needs_eviction = (self.max_size > 0 && current_size > self.max_size)
92 || (self.max_entries > 0 && store.len() > self.max_entries);
93
94 if !needs_eviction {
95 return;
96 }
97
98 let mut entries: Vec<_> = store
100 .iter()
101 .map(|(k, v)| (k.clone(), v.last_accessed, v.data.len()))
102 .collect();
103 entries.sort_by_key(|(_, accessed, _)| *accessed);
104
105 let mut remaining_size = current_size;
107 let mut remaining_count = store.len();
108 let mut entries_to_remove = Vec::new();
109
110 for (key, _, size) in entries {
111 let under_size_limit = self.max_size == 0 || remaining_size <= self.max_size;
113 let under_count_limit = self.max_entries == 0 || remaining_count <= self.max_entries;
114
115 if under_size_limit && under_count_limit {
116 break;
117 }
118
119 remaining_size = remaining_size.saturating_sub(size);
121 remaining_count = remaining_count.saturating_sub(1);
122 entries_to_remove.push(key);
123 }
124
125 for key in entries_to_remove {
127 store.remove(&key);
128 }
129 }
130}
131
132#[async_trait]
133impl Cache for InMemoryCache {
134 async fn get(&self, key: &str) -> Result<Option<Vec<u8>>> {
135 self.evict_expired().await;
137
138 let mut store = self.store.write().await;
139
140 if let Some(entry) = store.get_mut(key) {
141 if entry.is_expired() {
142 store.remove(key);
143 return Ok(None);
144 }
145
146 entry.touch();
147 Ok(Some(entry.data.clone()))
148 } else {
149 Ok(None)
150 }
151 }
152
153 async fn set(&self, key: &str, value: Vec<u8>, ttl: Duration) -> Result<()> {
154 let entry = CacheEntry {
155 data: value,
156 expires_at: Instant::now() + ttl,
157 last_accessed: Instant::now(),
158 };
159
160 {
161 let mut store = self.store.write().await;
162 store.insert(key.to_string(), entry);
163 }
164
165 self.evict_lru().await;
167
168 Ok(())
169 }
170
171 async fn delete(&self, key: &str) -> Result<()> {
172 let mut store = self.store.write().await;
173 store.remove(key);
174 Ok(())
175 }
176
177 async fn clear(&self) -> Result<()> {
178 let mut store = self.store.write().await;
179 store.clear();
180 Ok(())
181 }
182
183 fn name(&self) -> &str {
184 "in-memory"
185 }
186}
187
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::{sleep, Duration};

    #[tokio::test]
    async fn test_basic_set_get() {
        let cache = InMemoryCache::new(1024, 10);

        cache
            .set("key1", b"value1".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();
        let value = cache.get("key1").await.unwrap();

        assert_eq!(value, Some(b"value1".to_vec()));
    }

    #[tokio::test]
    async fn test_get_nonexistent() {
        let cache = InMemoryCache::new(1024, 10);
        let value = cache.get("nonexistent").await.unwrap();
        assert_eq!(value, None);
    }

    #[tokio::test]
    async fn test_ttl_expiration() {
        let cache = InMemoryCache::new(1024, 10);

        cache
            .set("key1", b"value1".to_vec(), Duration::from_millis(100))
            .await
            .unwrap();

        // Still inside the TTL window.
        let value = cache.get("key1").await.unwrap();
        assert_eq!(value, Some(b"value1".to_vec()));

        // Expiry is measured with `std::time::Instant`, so wait in real time.
        sleep(Duration::from_millis(150)).await;

        let value = cache.get("key1").await.unwrap();
        assert_eq!(value, None);
    }

    #[tokio::test]
    async fn test_delete() {
        let cache = InMemoryCache::new(1024, 10);

        cache
            .set("key1", b"value1".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();
        assert!(cache.get("key1").await.unwrap().is_some());

        cache.delete("key1").await.unwrap();
        assert!(cache.get("key1").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_clear() {
        let cache = InMemoryCache::new(1024, 10);

        cache
            .set("key1", b"value1".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();
        cache
            .set("key2", b"value2".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();

        cache.clear().await.unwrap();

        assert!(cache.get("key1").await.unwrap().is_none());
        assert!(cache.get("key2").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_lru_eviction_by_count() {
        // No byte limit; at most two entries.
        let cache = InMemoryCache::new(0, 2);

        cache
            .set("key1", b"value1".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();
        cache
            .set("key2", b"value2".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();

        // The third insert exceeds max_entries; key1 has the oldest access time.
        cache
            .set("key3", b"value3".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();

        let store = cache.store.read().await;
        assert!(store.len() <= 2, "Cache should not exceed max_entries");
        assert!(
            store.contains_key("key3"),
            "Most recently added key should exist"
        );
        assert!(
            !store.contains_key("key1"),
            "Least recently used key should be evicted"
        );
    }

    #[tokio::test]
    async fn test_lru_eviction_by_size() {
        // At most 10 payload bytes; no entry-count limit.
        let cache = InMemoryCache::new(10, 0);

        cache
            .set("key1", vec![1, 2, 3, 4, 5], Duration::from_secs(60))
            .await
            .unwrap();
        cache
            .set("key2", vec![6, 7, 8, 9, 10], Duration::from_secs(60))
            .await
            .unwrap();

        // Touch key1 so that key2 becomes the least recently used entry.
        cache.get("key1").await.unwrap();

        // 12 bytes > 10: evicting key2 alone brings the total back to 7.
        cache
            .set("key3", vec![11, 12], Duration::from_secs(60))
            .await
            .unwrap();

        // Without this check the test would also pass if eviction were a no-op.
        assert!(
            cache.get("key2").await.unwrap().is_none(),
            "Least recently used key2 should have been evicted"
        );
        assert!(cache.get("key1").await.unwrap().is_some());
        assert!(cache.get("key3").await.unwrap().is_some());
    }

    #[tokio::test]
    async fn test_cache_name() {
        let cache = InMemoryCache::new(1024, 10);
        assert_eq!(cache.name(), "in-memory");
    }
}