1use async_trait::async_trait;
4use simple_agent_type::cache::Cache;
5use simple_agent_type::error::Result;
6use std::collections::{HashMap, VecDeque};
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9use tokio::sync::RwLock;
10
/// A single cached value together with its expiry and recency bookkeeping.
#[derive(Debug, Clone)]
struct CacheEntry {
    /// The cached bytes; their length is what counts towards `total_size`.
    data: Vec<u8>,
    /// Moment from which the entry is considered expired, or `None` when
    /// `now + ttl` is not representable as an `Instant` (e.g. `Duration::MAX`);
    /// such entries never expire.
    expires_at: Option<Instant>,
    /// Tick of the most recent write or read of this entry. Only the
    /// `access_order` record carrying this exact tick is current for the key.
    access_tick: u64,
}

/// Mutable cache state guarded by the owning cache's lock.
///
/// Recency is tracked lazily: every write or read appends a `(key, tick)`
/// record to `access_order` instead of moving an existing record. A record
/// whose tick no longer matches its entry's `access_tick` (or whose key is
/// gone) is stale and is skipped during eviction.
#[derive(Debug, Default)]
struct CacheState {
    /// Entries by key (may include expired entries not yet swept).
    store: HashMap<String, CacheEntry>,
    /// Append-only recency log, oldest first; may contain stale records.
    access_order: VecDeque<(String, u64)>,
    /// Sum of `data.len()` over every entry in `store`.
    total_size: usize,
    /// Last tick handed out by `next_access_tick`.
    next_tick: u64,
}

impl CacheState {
    /// Number of stale `access_order` records tolerated beyond twice the live
    /// entry count before the log is compacted. The slack stops tiny caches
    /// from compacting on almost every access.
    const COMPACTION_SLACK: usize = 64;

    /// Returns the next recency tick (monotonically increasing, saturating at `u64::MAX`).
    fn next_access_tick(&mut self) -> u64 {
        self.next_tick = self.next_tick.saturating_add(1);
        self.next_tick
    }

    /// Returns `true` once the entry's deadline has been reached.
    /// Entries without a deadline (`expires_at == None`) never expire.
    fn is_expired(entry: &CacheEntry) -> bool {
        entry
            .expires_at
            .is_some_and(|deadline| Instant::now() >= deadline)
    }

    /// Marks `key` as most recently used. If the key is absent nothing is
    /// recorded (a tick is still consumed).
    fn touch_key(&mut self, key: &str) {
        let tick = self.next_access_tick();
        let Some(entry) = self.store.get_mut(key) else {
            return;
        };
        entry.access_tick = tick;
        self.access_order.push_back((key.to_string(), tick));
        self.compact_access_order_if_bloated();
    }

    /// Inserts or replaces `key` with `value`, expiring after `ttl`, keeps
    /// `total_size` in sync and marks the key as most recently used.
    ///
    /// This does not enforce any limit; callers evict afterwards.
    fn upsert(&mut self, key: &str, value: Vec<u8>, ttl: Duration) {
        let tick = self.next_access_tick();
        let added_len = value.len();
        let entry = CacheEntry {
            data: value,
            // `Instant + Duration` panics on overflow, which callers passing a
            // huge TTL (e.g. `Duration::MAX` as "never expire") would hit.
            // An unrepresentable deadline becomes `None`, i.e. no expiry.
            expires_at: Instant::now().checked_add(ttl),
            access_tick: tick,
        };

        // `insert` hands back the replaced entry, so its size can be released
        // without a second lookup.
        if let Some(previous) = self.store.insert(key.to_string(), entry) {
            self.total_size = self.total_size.saturating_sub(previous.data.len());
        }
        self.total_size = self.total_size.saturating_add(added_len);

        self.access_order.push_back((key.to_string(), tick));
        self.compact_access_order_if_bloated();
    }

    /// Removes `key` (if present) and releases its bytes from `total_size`.
    /// Its `access_order` records become stale and are dropped lazily.
    fn remove_key(&mut self, key: &str) {
        if let Some(entry) = self.store.remove(key) {
            self.total_size = self.total_size.saturating_sub(entry.data.len());
        }
    }

    /// Drops every expired entry from `store`, keeping `total_size` in sync.
    fn evict_expired(&mut self) {
        let now = Instant::now();
        // Borrow the counter separately so the `retain` closure does not need
        // to capture all of `self` while `self.store` is mutably borrowed.
        let total_size = &mut self.total_size;
        self.store.retain(|_, entry| {
            let expired = entry.expires_at.is_some_and(|deadline| now >= deadline);
            if expired {
                *total_size = total_size.saturating_sub(entry.data.len());
            }
            !expired
        });
    }

    /// Evicts least-recently-used entries until both limits hold.
    /// A limit of `0` disables that limit.
    fn evict_lru_until_within_limits(&mut self, max_size: usize, max_entries: usize) {
        let over_size = |total_size: usize| max_size > 0 && total_size > max_size;
        let over_count = |entry_count: usize| max_entries > 0 && entry_count > max_entries;

        while over_size(self.total_size) || over_count(self.store.len()) {
            // An exhausted log means there is nothing left to evict.
            let Some((key, tick)) = self.access_order.pop_front() else {
                break;
            };

            // Only the record matching the entry's current tick reflects its
            // true recency; older records for the same key are skipped.
            let is_current = self
                .store
                .get(key.as_str())
                .is_some_and(|entry| entry.access_tick == tick);

            if is_current {
                self.remove_key(key.as_str());
            }
        }
    }

    /// Removes every record from `access_order` that is not the current record
    /// of a live entry. Survivors keep their relative order, so the LRU
    /// eviction sequence is unchanged.
    fn compact_access_order(&mut self) {
        let store = &self.store;
        self.access_order.retain(|(key, tick)| {
            store
                .get(key.as_str())
                .is_some_and(|entry| entry.access_tick == *tick)
        });
    }

    /// Compacts `access_order` once it holds more than twice as many records as
    /// there are live entries (plus slack).
    ///
    /// Without this, reads and overwrites on a cache that never reaches its
    /// limits would grow the log without bound. Triggering on a proportional
    /// threshold spreads the O(len) compaction over the pushes that produced
    /// the stale records.
    fn compact_access_order_if_bloated(&mut self) {
        let threshold = self
            .store
            .len()
            .saturating_mul(2)
            .saturating_add(Self::COMPACTION_SLACK);
        if self.access_order.len() > threshold {
            self.compact_access_order();
        }
    }
}
104
105pub struct InMemoryCache {
126 state: Arc<RwLock<CacheState>>,
128 max_size: usize,
130 max_entries: usize,
132}
133
134impl InMemoryCache {
135 pub fn new(max_size: usize, max_entries: usize) -> Self {
141 Self {
142 state: Arc::new(RwLock::new(CacheState::default())),
143 max_size,
144 max_entries,
145 }
146 }
147}
148
149#[async_trait]
150impl Cache for InMemoryCache {
151 async fn get(&self, key: &str) -> Result<Option<Vec<u8>>> {
152 let mut state = self.state.write().await;
153 match state.store.get(key) {
154 Some(entry) if CacheState::is_expired(entry) => {
155 state.remove_key(key);
156 Ok(None)
157 }
158 Some(_) => {
159 let value = state.store.get(key).map(|entry| entry.data.clone());
160 state.touch_key(key);
161 Ok(value)
162 }
163 None => Ok(None),
164 }
165 }
166
167 async fn set(&self, key: &str, value: Vec<u8>, ttl: Duration) -> Result<()> {
168 let mut state = self.state.write().await;
169 state.upsert(key, value, ttl);
170 state.evict_expired();
171 state.evict_lru_until_within_limits(self.max_size, self.max_entries);
172
173 Ok(())
174 }
175
176 async fn delete(&self, key: &str) -> Result<()> {
177 let mut state = self.state.write().await;
178 state.remove_key(key);
179 Ok(())
180 }
181
182 async fn clear(&self) -> Result<()> {
183 let mut state = self.state.write().await;
184 state.store.clear();
185 state.access_order.clear();
186 state.total_size = 0;
187 Ok(())
188 }
189
190 fn name(&self) -> &str {
191 "in-memory"
192 }
193}
194
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::{sleep, Duration};

    #[tokio::test]
    async fn test_basic_set_get() {
        let cache = InMemoryCache::new(1024, 10);

        cache
            .set("key1", b"value1".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();
        let value = cache.get("key1").await.unwrap();

        assert_eq!(value, Some(b"value1".to_vec()));
    }

    #[tokio::test]
    async fn test_get_nonexistent() {
        let cache = InMemoryCache::new(1024, 10);
        let value = cache.get("nonexistent").await.unwrap();
        assert_eq!(value, None);
    }

    #[tokio::test]
    async fn test_ttl_expiration() {
        let cache = InMemoryCache::new(1024, 10);

        cache
            .set("key1", b"value1".to_vec(), Duration::from_millis(100))
            .await
            .unwrap();

        // Still within its TTL.
        let value = cache.get("key1").await.unwrap();
        assert_eq!(value, Some(b"value1".to_vec()));

        sleep(Duration::from_millis(150)).await;

        // Past its TTL: the entry must read as missing.
        let value = cache.get("key1").await.unwrap();
        assert_eq!(value, None);
    }

    #[tokio::test]
    async fn test_delete() {
        let cache = InMemoryCache::new(1024, 10);

        cache
            .set("key1", b"value1".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();
        assert!(cache.get("key1").await.unwrap().is_some());

        cache.delete("key1").await.unwrap();
        assert!(cache.get("key1").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_clear() {
        let cache = InMemoryCache::new(1024, 10);

        cache
            .set("key1", b"value1".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();
        cache
            .set("key2", b"value2".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();

        cache.clear().await.unwrap();

        assert!(cache.get("key1").await.unwrap().is_none());
        assert!(cache.get("key2").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_lru_eviction_by_count() {
        // No size limit, at most two entries.
        let cache = InMemoryCache::new(0, 2);

        cache
            .set("key1", b"value1".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();
        cache
            .set("key2", b"value2".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();

        // A third key exceeds `max_entries`; the oldest write (`key1`) must go.
        cache
            .set("key3", b"value3".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();

        let state = cache.state.read().await;
        assert!(
            state.store.len() <= 2,
            "Cache should not exceed max_entries"
        );
        assert!(
            state.store.contains_key("key3"),
            "Most recently added key should exist"
        );
        assert!(
            !state.store.contains_key("key1"),
            "Least recently used key should be evicted"
        );
        assert!(
            state.store.contains_key("key2"),
            "Only the least recently used key should be evicted"
        );
    }

    #[tokio::test]
    async fn test_lru_eviction_by_size() {
        // At most 10 bytes, no count limit.
        let cache = InMemoryCache::new(10, 0);

        cache
            .set("key1", vec![1, 2, 3, 4, 5], Duration::from_secs(60))
            .await
            .unwrap();
        cache
            .set("key2", vec![6, 7, 8, 9, 10], Duration::from_secs(60))
            .await
            .unwrap();

        // Reading `key1` makes `key2` the least recently used entry.
        cache.get("key1").await.unwrap();

        // 12 bytes > 10: `key2` must be evicted to make room.
        cache
            .set("key3", vec![11, 12], Duration::from_secs(60))
            .await
            .unwrap();

        assert!(cache.get("key1").await.unwrap().is_some());
        assert!(cache.get("key3").await.unwrap().is_some());
        assert!(cache.get("key2").await.unwrap().is_none());
    }

    /// `get` takes the write lock (it refreshes LRU order and may drop expired
    /// entries), so these reads are serialized. This only checks that many
    /// concurrent readers all observe the shared value.
    #[tokio::test]
    async fn test_concurrent_gets_return_shared_value() {
        let cache = Arc::new(InMemoryCache::new(1024, 10));
        cache
            .set("shared", b"value".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();

        let mut handles = Vec::new();
        for _ in 0..25 {
            let cache = cache.clone();
            handles.push(tokio::spawn(
                async move { cache.get("shared").await.unwrap() },
            ));
        }

        for handle in handles {
            assert_eq!(handle.await.unwrap(), Some(b"value".to_vec()));
        }
    }

    #[tokio::test]
    async fn test_cache_name() {
        let cache = InMemoryCache::new(1024, 10);
        assert_eq!(cache.name(), "in-memory");
    }

    #[tokio::test]
    async fn test_update_existing_key_keeps_latest_value() {
        let cache = InMemoryCache::new(6, 2);

        cache
            .set("key", b"abc".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();
        // Overwriting must release the old 3 bytes so the new 6 bytes fit exactly.
        cache
            .set("key", b"abcdef".to_vec(), Duration::from_secs(60))
            .await
            .unwrap();

        let value = cache.get("key").await.unwrap();
        assert_eq!(value, Some(b"abcdef".to_vec()));
    }
}