1use super::{Connection, Transport, TransportError};
8use sha2::{Digest, Sha256};
9use std::collections::HashMap;
10use std::sync::Mutex;
11
/// Tuning knobs for [`MemoizedTransport`].
#[derive(Debug, Clone)]
pub struct MemoConfig {
    /// Number of responses the cache holds before it starts evicting the
    /// oldest-inserted entry to make room for a new one.
    pub max_entries: usize,
    /// When `false`, every `send` goes straight to the inner transport and
    /// nothing is cached (the request is still counted in the statistics).
    pub enabled: bool,
}
20
21impl Default for MemoConfig {
22 fn default() -> Self {
23 Self {
24 max_entries: 1024,
25 enabled: true,
26 }
27 }
28}
29
/// One cached response together with how often it has been served from the cache.
#[derive(Debug, Clone)]
struct CacheEntry {
    /// Response bytes returned by the inner transport.
    result: Vec<u8>,
    /// Number of cache hits served from this entry (0 right after insertion).
    hits: u64,
}
36
37#[derive(Debug)]
40struct CacheState {
41 cache: HashMap<[u8; 32], CacheEntry>,
42 stats: MemoStats,
43 insertion_order: Vec<[u8; 32]>,
45}
46
47pub struct MemoizedTransport<T: Transport> {
53 inner: T,
54 config: MemoConfig,
55 state: Mutex<CacheState>,
56}
57
/// Snapshot of the memoizer's counters.
#[derive(Debug, Default, Clone)]
pub struct MemoStats {
    /// Requests answered from the cache.
    pub cache_hits: u64,
    /// Requests (with caching enabled) that had to go to the inner transport.
    pub cache_misses: u64,
    /// Entries dropped to make room for new ones.
    pub evictions: u64,
    /// Every call to `send`, whether or not caching is enabled.
    pub total_requests: u64,
}
66
67impl<T: Transport> MemoizedTransport<T> {
68 pub fn new(inner: T, config: MemoConfig) -> Self {
70 let state = CacheState {
71 cache: HashMap::with_capacity(config.max_entries),
72 stats: MemoStats::default(),
73 insertion_order: Vec::new(),
74 };
75 Self {
76 inner,
77 config,
78 state: Mutex::new(state),
79 }
80 }
81
82 pub fn compute_cache_key(destination: &str, payload: &[u8]) -> [u8; 32] {
84 let mut hasher = Sha256::new();
85 hasher.update(destination.as_bytes());
86 hasher.update(payload);
87 hasher.finalize().into()
88 }
89
90 pub fn invalidate(&self, key: &[u8; 32]) {
92 let mut state = self.state.lock().unwrap();
93 if state.cache.remove(key).is_some() {
94 state.insertion_order.retain(|k| k != key);
95 }
96 }
97
98 pub fn invalidate_all(&self) {
100 let mut state = self.state.lock().unwrap();
101 state.cache.clear();
102 state.insertion_order.clear();
103 }
104
105 pub fn stats(&self) -> MemoStats {
107 self.state.lock().unwrap().stats.clone()
108 }
109
110 pub fn cache_len(&self) -> usize {
112 self.state.lock().unwrap().cache.len()
113 }
114}
115
116impl CacheState {
117 fn evict_oldest(&mut self) {
119 if let Some(oldest_key) = self.insertion_order.first().copied() {
120 self.cache.remove(&oldest_key);
121 self.insertion_order.remove(0);
122 self.stats.evictions += 1;
123 }
124 }
125}
126
127impl<T: Transport> Transport for MemoizedTransport<T> {
128 fn send(&self, destination: &str, payload: &[u8]) -> Result<Vec<u8>, TransportError> {
129 let key = MemoizedTransport::<T>::compute_cache_key(destination, payload);
130
131 {
132 let mut state = self.state.lock().unwrap();
133 state.stats.total_requests += 1;
134
135 if !self.config.enabled {
136 drop(state);
138 return self.inner.send(destination, payload);
139 }
140
141 if let Some(entry) = state.cache.get_mut(&key) {
143 let result = entry.result.clone();
144 entry.hits += 1;
145 state.stats.cache_hits += 1;
146 return Ok(result);
147 }
148
149 state.stats.cache_misses += 1;
150 }
152
153 let result = self.inner.send(destination, payload)?;
155
156 {
158 let mut state = self.state.lock().unwrap();
159
160 if state.cache.len() >= self.config.max_entries {
162 state.evict_oldest();
163 }
164
165 state.insertion_order.push(key);
166 state.cache.insert(
167 key,
168 CacheEntry {
169 result: result.clone(),
170 hits: 0,
171 },
172 );
173 }
174
175 Ok(result)
176 }
177
178 fn connect(&self, destination: &str) -> Result<Box<dyn Connection>, TransportError> {
179 self.inner.connect(destination)
181 }
182}
183
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU64, Ordering};

    /// Test double: echoes the payload back and counts how often `send` is invoked.
    struct EchoTransport {
        call_count: Arc<AtomicU64>,
    }

    impl EchoTransport {
        /// Returns the transport plus a shared handle to its call counter.
        fn new() -> (Self, Arc<AtomicU64>) {
            let counter = Arc::new(AtomicU64::new(0));
            (
                Self {
                    call_count: counter.clone(),
                },
                counter,
            )
        }
    }

    impl Transport for EchoTransport {
        fn send(&self, _destination: &str, payload: &[u8]) -> Result<Vec<u8>, TransportError> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            Ok(payload.to_vec())
        }

        fn connect(&self, _destination: &str) -> Result<Box<dyn Connection>, TransportError> {
            Err(TransportError::ConnectionFailed(
                "not supported".to_string(),
            ))
        }
    }

    #[test]
    fn test_cache_hit() {
        let (echo, counter) = EchoTransport::new();
        let memo = MemoizedTransport::new(echo, MemoConfig::default());

        let r1 = Transport::send(&memo, "host:1234", b"hello").unwrap();
        let r2 = Transport::send(&memo, "host:1234", b"hello").unwrap();

        assert_eq!(r1, r2);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
        let stats = memo.stats();
        assert_eq!(stats.cache_hits, 1);
        assert_eq!(stats.cache_misses, 1);
        assert_eq!(stats.total_requests, 2);

        // The per-entry hit counter tracks hits served from that entry.
        let key = MemoizedTransport::<EchoTransport>::compute_cache_key("host:1234", b"hello");
        assert_eq!(memo.state.lock().unwrap().cache[&key].hits, 1);
    }

    #[test]
    fn test_cache_miss_different_payload() {
        let (echo, counter) = EchoTransport::new();
        let memo = MemoizedTransport::new(echo, MemoConfig::default());

        Transport::send(&memo, "host:1234", b"aaa").unwrap();
        Transport::send(&memo, "host:1234", b"bbb").unwrap();

        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn test_cache_miss_different_destination() {
        let (echo, counter) = EchoTransport::new();
        let memo = MemoizedTransport::new(echo, MemoConfig::default());

        Transport::send(&memo, "host-a:1234", b"same").unwrap();
        Transport::send(&memo, "host-b:1234", b"same").unwrap();

        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }

    /// Eviction follows insertion order (FIFO), not recency of use.
    #[test]
    fn test_fifo_eviction() {
        let (echo, _counter) = EchoTransport::new();
        let memo = MemoizedTransport::new(
            echo,
            MemoConfig {
                max_entries: 2,
                enabled: true,
            },
        );

        Transport::send(&memo, "a", b"1").unwrap();
        Transport::send(&memo, "b", b"2").unwrap();
        Transport::send(&memo, "c", b"3").unwrap();

        assert_eq!(memo.stats().evictions, 1);
        assert_eq!(memo.cache_len(), 2);

        let key_a = MemoizedTransport::<EchoTransport>::compute_cache_key("a", b"1");
        // "a" was inserted first, so it is the one evicted.
        assert!(!memo.state.lock().unwrap().cache.contains_key(&key_a));
    }

    /// A cache hit does not protect an entry from eviction: the oldest insertion still goes first.
    #[test]
    fn test_hit_does_not_refresh_eviction_order() {
        let (echo, counter) = EchoTransport::new();
        let memo = MemoizedTransport::new(
            echo,
            MemoConfig {
                max_entries: 2,
                enabled: true,
            },
        );

        Transport::send(&memo, "a", b"1").unwrap();
        Transport::send(&memo, "b", b"2").unwrap();
        Transport::send(&memo, "a", b"1").unwrap(); // hit
        assert_eq!(counter.load(Ordering::SeqCst), 2);

        Transport::send(&memo, "c", b"3").unwrap(); // evicts "a", the oldest insertion

        let key_a = MemoizedTransport::<EchoTransport>::compute_cache_key("a", b"1");
        let key_b = MemoizedTransport::<EchoTransport>::compute_cache_key("b", b"2");
        {
            let state = memo.state.lock().unwrap();
            assert!(!state.cache.contains_key(&key_a));
            assert!(state.cache.contains_key(&key_b));
        }
        assert_eq!(memo.stats().evictions, 1);
    }

    #[test]
    fn test_disabled_passthrough() {
        let (echo, counter) = EchoTransport::new();
        let memo = MemoizedTransport::new(
            echo,
            MemoConfig {
                max_entries: 1024,
                enabled: false,
            },
        );

        Transport::send(&memo, "host", b"x").unwrap();
        Transport::send(&memo, "host", b"x").unwrap();

        assert_eq!(counter.load(Ordering::SeqCst), 2);

        // Requests are counted, but neither hits nor misses are recorded and nothing is stored.
        let stats = memo.stats();
        assert_eq!(stats.total_requests, 2);
        assert_eq!(stats.cache_hits, 0);
        assert_eq!(stats.cache_misses, 0);
        assert_eq!(memo.cache_len(), 0);
    }

    #[test]
    fn test_invalidate() {
        let (echo, counter) = EchoTransport::new();
        let memo = MemoizedTransport::new(echo, MemoConfig::default());

        Transport::send(&memo, "host", b"data").unwrap();
        assert_eq!(counter.load(Ordering::SeqCst), 1);

        let key = MemoizedTransport::<EchoTransport>::compute_cache_key("host", b"data");
        memo.invalidate(&key);

        Transport::send(&memo, "host", b"data").unwrap();
        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn test_invalidate_all() {
        let (echo, counter) = EchoTransport::new();
        let memo = MemoizedTransport::new(echo, MemoConfig::default());

        Transport::send(&memo, "a", b"1").unwrap();
        Transport::send(&memo, "b", b"2").unwrap();
        assert_eq!(counter.load(Ordering::SeqCst), 2);

        memo.invalidate_all();
        assert_eq!(memo.cache_len(), 0);

        Transport::send(&memo, "a", b"1").unwrap();
        Transport::send(&memo, "b", b"2").unwrap();
        assert_eq!(counter.load(Ordering::SeqCst), 4);
    }

    #[test]
    fn test_connect_passthrough() {
        let (echo, _counter) = EchoTransport::new();
        let memo = MemoizedTransport::new(echo, MemoConfig::default());

        // EchoTransport refuses connections; the memoizer must surface that unchanged.
        assert!(Transport::connect(&memo, "host").is_err());
    }

    #[test]
    fn test_compute_cache_key_deterministic() {
        let k1 = MemoizedTransport::<EchoTransport>::compute_cache_key("host", b"payload");
        let k2 = MemoizedTransport::<EchoTransport>::compute_cache_key("host", b"payload");
        assert_eq!(k1, k2);
    }

    #[test]
    fn test_compute_cache_key_distinct() {
        let k1 = MemoizedTransport::<EchoTransport>::compute_cache_key("host", b"aaa");
        let k2 = MemoizedTransport::<EchoTransport>::compute_cache_key("host", b"bbb");
        assert_ne!(k1, k2);
    }
}