threatflux_cache/backends/
memory.rs1use async_trait::async_trait;
4use serde::{de::DeserializeOwned, Serialize};
5use std::collections::HashMap;
6use std::hash::Hash;
7use std::sync::Arc;
8use tokio::sync::RwLock;
9
10use crate::{CacheEntry, EntryMetadata, Result, StorageBackend};
11
12#[allow(clippy::type_complexity)]
14pub struct MemoryBackend<K, V, M = ()>
15where
16 K: Hash + Eq + Clone + Send + Sync,
17 V: Clone + Send + Sync,
18 M: Clone + Send + Sync,
19{
20 data: Arc<RwLock<HashMap<K, Vec<CacheEntry<K, V, M>>>>>,
21}
22
23impl<K, V, M> MemoryBackend<K, V, M>
24where
25 K: Hash + Eq + Clone + Send + Sync,
26 V: Clone + Send + Sync,
27 M: Clone + Send + Sync,
28{
29 pub fn new() -> Self {
31 Self {
32 data: Arc::new(RwLock::new(HashMap::new())),
33 }
34 }
35}
36
37impl<K, V, M> Default for MemoryBackend<K, V, M>
38where
39 K: Hash + Eq + Clone + Send + Sync,
40 V: Clone + Send + Sync,
41 M: Clone + Send + Sync,
42{
43 fn default() -> Self {
44 Self::new()
45 }
46}
47
48impl<K, V, M> Clone for MemoryBackend<K, V, M>
49where
50 K: Hash + Eq + Clone + Send + Sync,
51 V: Clone + Send + Sync,
52 M: Clone + Send + Sync,
53{
54 fn clone(&self) -> Self {
55 Self {
56 data: Arc::clone(&self.data),
57 }
58 }
59}
60
61#[async_trait]
62impl<K, V, M> StorageBackend for MemoryBackend<K, V, M>
63where
64 K: Serialize + DeserializeOwned + Hash + Eq + Clone + Send + Sync + 'static,
65 V: Serialize + DeserializeOwned + Clone + Send + Sync + 'static,
66 M: Serialize + DeserializeOwned + Clone + Send + Sync + EntryMetadata,
67{
68 type Key = K;
69 type Value = V;
70 type Metadata = M;
71
72 async fn save(&self, entries: &HashMap<K, Vec<CacheEntry<K, V, M>>>) -> Result<()> {
73 let mut data = self.data.write().await;
74 *data = entries.clone();
75 Ok(())
76 }
77
78 async fn load(&self) -> Result<HashMap<K, Vec<CacheEntry<K, V, M>>>> {
79 let data = self.data.read().await;
80 Ok(data.clone())
81 }
82
83 async fn remove(&self, key: &K) -> Result<()> {
84 let mut data = self.data.write().await;
85 data.remove(key);
86 Ok(())
87 }
88
89 async fn clear(&self) -> Result<()> {
90 let mut data = self.data.write().await;
91 data.clear();
92 Ok(())
93 }
94
95 async fn contains(&self, key: &K) -> Result<bool> {
96 let data = self.data.read().await;
97 Ok(data.contains_key(key))
98 }
99
100 async fn size_bytes(&self) -> Result<u64> {
101 let data = self.data.read().await;
102
103 let total_entries: usize = data.values().map(|v| v.len()).sum();
105 let estimated_size = total_entries * std::mem::size_of::<CacheEntry<K, V, M>>();
106
107 Ok(estimated_size as u64)
108 }
109}
110
#[cfg(test)]
mod tests {
    use super::*;

    type TestEntries = HashMap<String, Vec<CacheEntry<String, String, ()>>>;

    /// Builds a map with one key whose list holds a single `key -> value` entry.
    fn single_entry(key: &str, value: &str) -> TestEntries {
        let mut entries = HashMap::new();
        entries.insert(
            key.to_string(),
            vec![CacheEntry::new(key.to_string(), value.to_string())],
        );
        entries
    }

    #[tokio::test]
    async fn test_memory_backend_operations() {
        let backend: MemoryBackend<String, String> = MemoryBackend::new();

        // A fresh backend starts empty.
        let loaded = backend.load().await.unwrap();
        assert!(loaded.is_empty());

        // Save followed by load round-trips the map.
        let entries = single_entry("key1", "value1");
        backend.save(&entries).await.unwrap();
        let loaded = backend.load().await.unwrap();
        assert_eq!(loaded.len(), 1);
        assert!(loaded.contains_key("key1"));

        assert!(backend.contains(&"key1".to_string()).await.unwrap());
        assert!(!backend.contains(&"key2".to_string()).await.unwrap());

        backend.remove(&"key1".to_string()).await.unwrap();
        assert!(!backend.contains(&"key1".to_string()).await.unwrap());

        backend.save(&entries).await.unwrap();
        backend.clear().await.unwrap();
        let loaded = backend.load().await.unwrap();
        assert!(loaded.is_empty());
    }

    #[tokio::test]
    async fn test_memory_backend_save_replaces_contents() {
        let backend: MemoryBackend<String, String> = MemoryBackend::new();
        backend.save(&single_entry("key1", "value1")).await.unwrap();
        backend.save(&single_entry("key2", "value2")).await.unwrap();

        // `save` overwrites the whole map rather than merging into it.
        assert!(!backend.contains(&"key1".to_string()).await.unwrap());
        assert!(backend.contains(&"key2".to_string()).await.unwrap());
        assert_eq!(backend.load().await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_memory_backend_size_bytes() {
        let backend: MemoryBackend<String, String> = MemoryBackend::new();
        assert_eq!(backend.size_bytes().await.unwrap(), 0);

        backend.save(&single_entry("key1", "value1")).await.unwrap();
        let one = backend.size_bytes().await.unwrap();
        assert!(one > 0);

        // The estimate scales linearly with the number of stored entries.
        let mut two = single_entry("key1", "value1");
        two.extend(single_entry("key2", "value2"));
        backend.save(&two).await.unwrap();
        assert_eq!(backend.size_bytes().await.unwrap(), 2 * one);
    }

    #[tokio::test]
    async fn test_memory_backend_clone() {
        let backend1: MemoryBackend<String, String> = MemoryBackend::new();
        let backend2 = backend1.clone();

        // Writes through one handle are visible through its clone.
        backend1.save(&single_entry("key1", "value1")).await.unwrap();
        assert!(backend2.contains(&"key1".to_string()).await.unwrap());
    }
}