Skip to main content

threatflux_cache/backends/
memory.rs

1//! In-memory storage backend
2
3use async_trait::async_trait;
4use serde::{de::DeserializeOwned, Serialize};
5use std::collections::HashMap;
6use std::hash::Hash;
7use std::sync::Arc;
8use tokio::sync::RwLock;
9
10use crate::{CacheEntry, EntryMetadata, Result, StorageBackend};
11
12/// In-memory storage backend
13#[allow(clippy::type_complexity)]
14pub struct MemoryBackend<K, V, M = ()>
15where
16    K: Hash + Eq + Clone + Send + Sync,
17    V: Clone + Send + Sync,
18    M: Clone + Send + Sync,
19{
20    data: Arc<RwLock<HashMap<K, Vec<CacheEntry<K, V, M>>>>>,
21}
22
23impl<K, V, M> MemoryBackend<K, V, M>
24where
25    K: Hash + Eq + Clone + Send + Sync,
26    V: Clone + Send + Sync,
27    M: Clone + Send + Sync,
28{
29    /// Create a new memory backend
30    pub fn new() -> Self {
31        Self {
32            data: Arc::new(RwLock::new(HashMap::new())),
33        }
34    }
35}
36
37impl<K, V, M> Default for MemoryBackend<K, V, M>
38where
39    K: Hash + Eq + Clone + Send + Sync,
40    V: Clone + Send + Sync,
41    M: Clone + Send + Sync,
42{
43    fn default() -> Self {
44        Self::new()
45    }
46}
47
48impl<K, V, M> Clone for MemoryBackend<K, V, M>
49where
50    K: Hash + Eq + Clone + Send + Sync,
51    V: Clone + Send + Sync,
52    M: Clone + Send + Sync,
53{
54    fn clone(&self) -> Self {
55        Self {
56            data: Arc::clone(&self.data),
57        }
58    }
59}
60
61#[async_trait]
62impl<K, V, M> StorageBackend for MemoryBackend<K, V, M>
63where
64    K: Serialize + DeserializeOwned + Hash + Eq + Clone + Send + Sync + 'static,
65    V: Serialize + DeserializeOwned + Clone + Send + Sync + 'static,
66    M: Serialize + DeserializeOwned + Clone + Send + Sync + EntryMetadata,
67{
68    type Key = K;
69    type Value = V;
70    type Metadata = M;
71
72    async fn save(&self, entries: &HashMap<K, Vec<CacheEntry<K, V, M>>>) -> Result<()> {
73        let mut data = self.data.write().await;
74        *data = entries.clone();
75        Ok(())
76    }
77
78    async fn load(&self) -> Result<HashMap<K, Vec<CacheEntry<K, V, M>>>> {
79        let data = self.data.read().await;
80        Ok(data.clone())
81    }
82
83    async fn remove(&self, key: &K) -> Result<()> {
84        let mut data = self.data.write().await;
85        data.remove(key);
86        Ok(())
87    }
88
89    async fn clear(&self) -> Result<()> {
90        let mut data = self.data.write().await;
91        data.clear();
92        Ok(())
93    }
94
95    async fn contains(&self, key: &K) -> Result<bool> {
96        let data = self.data.read().await;
97        Ok(data.contains_key(key))
98    }
99
100    async fn size_bytes(&self) -> Result<u64> {
101        let data = self.data.read().await;
102
103        // Estimate size based on number of entries
104        let total_entries: usize = data.values().map(|v| v.len()).sum();
105        let estimated_size = total_entries * std::mem::size_of::<CacheEntry<K, V, M>>();
106
107        Ok(estimated_size as u64)
108    }
109}
110
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the shared fixture: `"key1"` mapped to one entry holding `"value1"`.
    fn sample_entries() -> HashMap<String, Vec<CacheEntry<String, String, ()>>> {
        let mut map = HashMap::new();
        map.insert(
            "key1".to_string(),
            vec![CacheEntry::new("key1".to_string(), "value1".to_string())],
        );
        map
    }

    #[tokio::test]
    async fn test_memory_backend_operations() {
        let backend: MemoryBackend<String, String> = MemoryBackend::new();
        let present = "key1".to_string();
        let absent = "key2".to_string();

        // A freshly created backend holds nothing.
        assert!(backend.load().await.unwrap().is_empty());

        // Whatever is saved comes back through load().
        let entries = sample_entries();
        backend.save(&entries).await.unwrap();
        let snapshot = backend.load().await.unwrap();
        assert_eq!(snapshot.len(), 1);
        assert!(snapshot.contains_key("key1"));

        // contains() reports exactly the saved keys.
        assert!(backend.contains(&present).await.unwrap());
        assert!(!backend.contains(&absent).await.unwrap());

        // remove() drops a single key.
        backend.remove(&present).await.unwrap();
        assert!(!backend.contains(&present).await.unwrap());

        // clear() empties a repopulated backend.
        backend.save(&entries).await.unwrap();
        backend.clear().await.unwrap();
        assert!(backend.load().await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_memory_backend_clone() {
        let original: MemoryBackend<String, String> = MemoryBackend::new();
        let shared = original.clone();

        // Both handles point at one map, so a save through one is seen by the other.
        original.save(&sample_entries()).await.unwrap();
        assert!(shared.contains(&"key1".to_string()).await.unwrap());
    }
}