1use async_trait::async_trait;
48use parking_lot::RwLock;
49use std::collections::HashMap;
50use std::sync::Arc;
51use std::time::{Duration, SystemTime};
52
53use crate::cache::Cache;
54
/// Internal key type: the 64-bit hash of the caller-supplied string key.
type CacheKey = u64;
60
/// One cached payload together with the timestamps used for expiry and eviction.
#[derive(Clone)]
struct MemoryCacheEntry {
    /// The cached bytes; their length is what counts against the cache capacity.
    value: Vec<u8>,

    /// Once the current time passes this point, the entry is treated as absent.
    expires_at: SystemTime,

    /// When the entry was last stored or read; eviction prefers the smallest value.
    last_accessed_at: SystemTime,
}
76
77struct MemoryCacheInner {
82 capacity: usize,
84
85 max_item_size: usize,
87
88 size: usize,
90
91 keys: Vec<CacheKey>,
93
94 items: HashMap<CacheKey, MemoryCacheEntry>,
96}
97
98#[derive(Clone)]
122pub struct MemoryCache {
123 inner: Arc<RwLock<MemoryCacheInner>>,
124}
125
126impl MemoryCache {
127 pub fn new(capacity: usize, max_item_size: usize) -> Self {
143 MemoryCache {
144 inner: Arc::new(RwLock::new(MemoryCacheInner {
145 capacity,
146 max_item_size,
147 size: 0,
148 keys: Vec::new(),
149 items: HashMap::new(),
150 })),
151 }
152 }
153
154 #[inline]
158 fn hash_key(key: &str) -> CacheKey {
159 use std::collections::hash_map::DefaultHasher;
160 use std::hash::{Hash, Hasher};
161
162 let mut hasher = DefaultHasher::new();
163 key.hash(&mut hasher);
164 hasher.finish()
165 }
166
167 #[inline]
171 fn current_time() -> SystemTime {
172 SystemTime::now()
173 }
174
175 fn get_internal(&self, key: CacheKey) -> Option<Vec<u8>> {
180 let mut inner = self.inner.write();
181 let now = Self::current_time();
182
183 let entry = inner.items.get_mut(&key)?;
185
186 if entry.expires_at < now {
188 tracing::debug!("Cache: entry expired");
189 return None;
190 }
191
192 entry.last_accessed_at = now;
194
195 Some(entry.value.clone())
197 }
198
199 fn set_internal(&self, key: CacheKey, value: Vec<u8>, expires_at: SystemTime) {
203 let mut inner = self.inner.write();
204
205 let item_size = value.len();
206
207 if item_size > inner.max_item_size || item_size > inner.capacity {
209 tracing::debug!(
210 "Cache: item is too large to store, len={}, max_item_size={}, capacity={}",
211 item_size,
212 inner.max_item_size,
213 inner.capacity
214 );
215 return;
216 }
217
218 let limit = inner.capacity - item_size;
220 while inner.size > limit {
221 tracing::debug!(
222 "Cache: evicting item to make space, current_size={}, need_size={}",
223 inner.size,
224 limit
225 );
226 Self::evict_oldest_item(&mut inner);
227 }
228
229 if let Some(existing) = inner.items.get(&key) {
231 inner.size -= existing.value.len();
232 } else {
233 inner.keys.push(key);
235 }
236
237 inner.items.insert(
239 key,
240 MemoryCacheEntry {
241 last_accessed_at: Self::current_time(),
242 expires_at,
243 value,
244 },
245 );
246
247 inner.size += item_size;
249
250 tracing::debug!(
251 "Cache: added item, key={}, size={}, expires_at={:?}",
252 key,
253 item_size,
254 expires_at
255 );
256 }
257
258 fn evict_oldest_item(inner: &mut MemoryCacheInner) {
269 use rand::Rng;
270
271 if inner.keys.is_empty() {
272 return;
273 }
274
275 let mut oldest_key: CacheKey = 0;
276 let mut oldest_index: usize = 0;
277 let mut oldest_time = SystemTime::UNIX_EPOCH;
278
279 let now = Self::current_time();
280 let mut rng = rand::rng();
281
282 let sample_size = std::cmp::min(5, inner.keys.len());
284
285 for _ in 0..sample_size {
286 let index = rng.random_range(0..inner.keys.len());
287 let key = inner.keys[index];
288
289 if let Some(entry) = inner.items.get(&key) {
290 if entry.expires_at < now {
292 oldest_key = key;
293 oldest_index = index;
294 break;
295 }
296
297 if oldest_time == SystemTime::UNIX_EPOCH || entry.last_accessed_at < oldest_time {
299 oldest_time = entry.last_accessed_at;
300 oldest_key = key;
301 oldest_index = index;
302 }
303 }
304 }
305
306 let keys_len = inner.keys.len();
308 if oldest_index < keys_len {
309 inner.keys.swap(oldest_index, keys_len - 1);
310 inner.keys.pop();
311 }
312
313 if let Some(entry) = inner.items.remove(&oldest_key) {
315 inner.size -= entry.value.len();
316 tracing::debug!(
317 "Cache: evicted item, key={}, size={}",
318 oldest_key,
319 entry.value.len()
320 );
321 }
322 }
323
324 pub fn stats(&self) -> (usize, usize, usize) {
337 let inner = self.inner.read();
338 (inner.size, inner.capacity, inner.items.len())
339 }
340
341 pub fn clear(&self) {
354 let mut inner = self.inner.write();
355 inner.keys.clear();
356 inner.items.clear();
357 inner.size = 0;
358 tracing::debug!("Cache: cleared all entries");
359 }
360}
361
362#[async_trait]
363impl Cache for MemoryCache {
364 async fn get(&self, key: &str) -> anyhow::Result<Option<Vec<u8>>> {
365 let hash = Self::hash_key(key);
366 Ok(self.get_internal(hash))
367 }
368
369 async fn set(&self, key: &str, value: &[u8], ttl_seconds: u64) -> anyhow::Result<()> {
370 let hash = Self::hash_key(key);
371 let expires_at = if ttl_seconds > 0 {
372 Self::current_time() + Duration::from_secs(ttl_seconds)
373 } else {
374 Self::current_time() + Duration::from_secs(100 * 365 * 24 * 3600)
376 };
377
378 self.set_internal(hash, value.to_vec(), expires_at);
379 Ok(())
380 }
381
382 async fn delete(&self, key: &str) -> anyhow::Result<()> {
383 let hash = Self::hash_key(key);
384 let mut inner = self.inner.write();
385
386 if let Some(entry) = inner.items.remove(&hash) {
388 inner.size -= entry.value.len();
389
390 if let Some(pos) = inner.keys.iter().position(|&k| k == hash) {
392 let keys_len = inner.keys.len();
393 inner.keys.swap(pos, keys_len - 1);
394 inner.keys.pop();
395 }
396
397 tracing::debug!("Cache: deleted item, key={}", hash);
398 }
399
400 Ok(())
401 }
402
403 async fn exists(&self, key: &str) -> anyhow::Result<bool> {
404 let hash = Self::hash_key(key);
405 let inner = self.inner.read();
406 let now = Self::current_time();
407
408 if let Some(entry) = inner.items.get(&hash) {
409 Ok(entry.expires_at >= now)
411 } else {
412 Ok(false)
413 }
414 }
415}
416
#[cfg(test)]
mod tests {
    //! Behavioural tests for `MemoryCache`. The expiry tests sleep for two seconds because TTLs
    //! have one-second granularity.
    use super::*;

    const KB: usize = 1024;
    const MB: usize = 1024 * KB;

    #[tokio::test]
    async fn test_store_and_retrieve() {
        let cache = MemoryCache::new(32 * MB, MB);
        cache
            .set("test_key", b"hello world", 30)
            .await
            .expect("set failed");

        let value = cache.get("test_key").await.expect("get failed");
        assert_eq!(value, Some(b"hello world".to_vec()));
    }

    #[tokio::test]
    async fn test_storing_updates_existing_value() {
        let cache = MemoryCache::new(32 * MB, MB);
        cache.set("key", b"first", 30).await.expect("set failed");
        cache.set("key", b"second", 30).await.expect("set failed");

        let value = cache.get("key").await.expect("get failed");
        assert_eq!(value, Some(b"second".to_vec()));
    }

    #[tokio::test]
    async fn test_storing_existing_value_keeps_size_correct() {
        let cache = MemoryCache::new(32 * MB, MB);
        cache.set("key", b"first", 30).await.expect("set failed");
        cache.set("key", b"second", 30).await.expect("set failed");

        let (size, _, count) = cache.stats();
        assert_eq!(count, 1);
        // Only the latest value ("second", 6 bytes) may be accounted for.
        assert_eq!(size, 6);
    }

    #[tokio::test]
    async fn test_expiry() {
        let cache = MemoryCache::new(32 * MB, MB);

        cache
            .set("key", b"hello world", 1)
            .await
            .expect("set failed");

        let value = cache.get("key").await.expect("get failed");
        assert_eq!(value, Some(b"hello world".to_vec()));

        tokio::time::sleep(Duration::from_secs(2)).await;

        let value = cache.get("key").await.expect("get failed");
        assert_eq!(value, None);
    }

    #[tokio::test]
    async fn test_does_not_store_items_over_cache_limit() {
        let cache = MemoryCache::new(3 * KB, 50 * KB);

        let payload = vec![0u8; 10 * KB];
        cache.set("key", &payload, 3600).await.expect("set failed");

        let value = cache.get("key").await.expect("get failed");
        assert_eq!(value, None);
    }

    #[tokio::test]
    async fn test_cache_of_size_zero_does_not_store_items() {
        let cache = MemoryCache::new(0, KB);

        cache
            .set("key", b"no storage", 3600)
            .await
            .expect("set failed");

        let value = cache.get("key").await.expect("get failed");
        assert_eq!(value, None);
    }

    #[tokio::test]
    async fn test_items_are_evicted_to_make_space() {
        let max_cache_size = 10 * KB;
        let cache = MemoryCache::new(max_cache_size, KB);

        // Twice as many 1 KB items as fit; each insert must still be readable immediately.
        for i in 0u32..20 {
            let key = format!("key_{}", i);
            let payload = vec![i as u8; KB];
            cache.set(&key, &payload, 3600).await.expect("set failed");

            let value = cache.get(&key).await.expect("get failed");
            assert_eq!(value, Some(payload));
        }

        // Eviction should free exactly enough space, leaving the cache full.
        let (size, capacity, _) = cache.stats();
        assert_eq!(size, capacity);
        assert_eq!(size, max_cache_size);
    }

    #[tokio::test]
    async fn test_does_not_store_items_over_item_limit() {
        let cache = MemoryCache::new(50 * KB, 3 * KB);

        let payload = vec![0u8; 10 * KB];
        cache.set("key", &payload, 3600).await.expect("set failed");

        let value = cache.get("key").await.expect("get failed");
        assert_eq!(value, None);
    }

    #[tokio::test]
    async fn test_delete() {
        let cache = MemoryCache::new(32 * MB, MB);
        cache.set("key", b"value", 30).await.expect("set failed");

        let exists = cache.exists("key").await.expect("exists failed");
        assert!(exists);

        cache.delete("key").await.expect("delete failed");

        let exists = cache.exists("key").await.expect("exists failed");
        assert!(!exists);

        let value = cache.get("key").await.expect("get failed");
        assert_eq!(value, None);
    }

    #[tokio::test]
    async fn test_exists() {
        let cache = MemoryCache::new(32 * MB, MB);

        let exists = cache.exists("missing").await.expect("exists failed");
        assert!(!exists);

        cache.set("key", b"value", 30).await.expect("set failed");

        let exists = cache.exists("key").await.expect("exists failed");
        assert!(exists);
    }

    #[tokio::test]
    async fn test_exists_returns_false_for_expired() {
        let cache = MemoryCache::new(32 * MB, MB);

        cache.set("key", b"value", 1).await.expect("set failed");

        let exists = cache.exists("key").await.expect("exists failed");
        assert!(exists);

        tokio::time::sleep(Duration::from_secs(2)).await;

        let exists = cache.exists("key").await.expect("exists failed");
        assert!(!exists);
    }

    #[tokio::test]
    async fn test_clear() {
        let cache = MemoryCache::new(32 * MB, MB);

        for i in 0..10 {
            cache
                .set(&format!("key_{}", i), b"value", 30)
                .await
                .expect("set failed");
        }

        let (_, _, count_before) = cache.stats();
        assert_eq!(count_before, 10);

        cache.clear();

        let (size, _, count_after) = cache.stats();
        assert_eq!(count_after, 0);
        assert_eq!(size, 0);
    }

    #[tokio::test]
    async fn test_concurrent_access() {
        // `Arc` is already in scope through `use super::*`.
        use tokio::task;

        let cache = Arc::new(MemoryCache::new(32 * MB, MB));

        let mut handles = vec![];

        for i in 0..10 {
            let cache = Arc::clone(&cache);
            let handle = task::spawn(async move {
                for j in 0..100 {
                    let key = format!("key_{}_{}", i, j);
                    let value = format!("value_{}_{}", i, j);

                    cache.set(&key, value.as_bytes(), 30).await.unwrap();

                    let retrieved = cache.get(&key).await.unwrap();
                    assert_eq!(retrieved, Some(value.as_bytes().to_vec()));
                }
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.await.expect("task failed");
        }

        let (_, _, count) = cache.stats();
        assert!(count > 0);
    }

    #[test]
    fn test_hash_key_consistency() {
        let key1 = "test_key";
        let key2 = "test_key";
        let key3 = "different_key";

        assert_eq!(MemoryCache::hash_key(key1), MemoryCache::hash_key(key2));
        assert_ne!(MemoryCache::hash_key(key1), MemoryCache::hash_key(key3));
    }
}