Skip to main content

shape_wire/transport/
memoized.rs

//! Memoized transport wrapper that caches results of remote function calls.
//!
//! Successful results are stored in an LRU cache keyed by a SHA-256 digest of the
//! destination and payload. Calls to [`Transport::send`] are intercepted before they reach
//! the inner transport, and the cached result is returned when one is available.
6
7use super::{Connection, Transport, TransportError};
8use sha2::{Digest, Sha256};
9use std::collections::HashMap;
10use std::sync::Mutex;
11
/// Tuning knobs for [`MemoizedTransport`].
#[derive(Clone, Debug)]
pub struct MemoConfig {
    /// Cache capacity: when a new response is about to be stored and the cache already
    /// holds this many entries, one existing entry is evicted first.
    pub max_entries: usize,
    /// Master switch for caching. With `false`, every call bypasses the cache and goes
    /// straight to the inner transport.
    pub enabled: bool,
}

impl Default for MemoConfig {
    /// Caching switched on, with room for 1024 entries.
    fn default() -> Self {
        MemoConfig {
            enabled: true,
            max_entries: 1024,
        }
    }
}
29
/// One memoized response, plus how often it has been served from the cache.
#[derive(Clone, Debug)]
struct CacheEntry {
    /// Response bytes returned by the inner transport.
    result: Vec<u8>,
    /// Number of cache hits served from this entry since it was stored.
    hits: u64,
}
36
37/// Interior mutable cache state, protected by a `Mutex` so we satisfy
38/// the `Send + Sync` bound required by [`Transport`].
39#[derive(Debug)]
40struct CacheState {
41    cache: HashMap<[u8; 32], CacheEntry>,
42    stats: MemoStats,
43    /// Insertion order for LRU eviction (oldest first).
44    insertion_order: Vec<[u8; 32]>,
45}
46
47/// Memoized transport wrapper with LRU eviction.
48///
49/// Wraps any [`Transport`] implementation and caches the results of
50/// one-shot `send` calls. Persistent connections (`connect`) are
51/// forwarded directly to the inner transport without caching.
52pub struct MemoizedTransport<T: Transport> {
53    inner: T,
54    config: MemoConfig,
55    state: Mutex<CacheState>,
56}
57
/// Snapshot of cache activity counters, as returned by [`MemoizedTransport::stats`].
#[derive(Clone, Debug, Default)]
pub struct MemoStats {
    /// Requests answered from the cache without reaching the inner transport.
    pub cache_hits: u64,
    /// Requests looked up in the cache, not found, and forwarded to the inner transport.
    pub cache_misses: u64,
    /// Entries removed to make room for new ones.
    pub evictions: u64,
    /// Every call to `send`, including calls made while caching is disabled (those count
    /// here only, not as hits or misses).
    pub total_requests: u64,
}
66
67impl<T: Transport> MemoizedTransport<T> {
68    /// Create a new memoized transport wrapping `inner` with the given config.
69    pub fn new(inner: T, config: MemoConfig) -> Self {
70        let state = CacheState {
71            cache: HashMap::with_capacity(config.max_entries),
72            stats: MemoStats::default(),
73            insertion_order: Vec::new(),
74        };
75        Self {
76            inner,
77            config,
78            state: Mutex::new(state),
79        }
80    }
81
82    /// Compute the cache key as `SHA-256(destination || payload)`.
83    pub fn compute_cache_key(destination: &str, payload: &[u8]) -> [u8; 32] {
84        let mut hasher = Sha256::new();
85        hasher.update(destination.as_bytes());
86        hasher.update(payload);
87        hasher.finalize().into()
88    }
89
90    /// Invalidate a specific cache entry by key.
91    pub fn invalidate(&self, key: &[u8; 32]) {
92        let mut state = self.state.lock().unwrap();
93        if state.cache.remove(key).is_some() {
94            state.insertion_order.retain(|k| k != key);
95        }
96    }
97
98    /// Invalidate all cached entries.
99    pub fn invalidate_all(&self) {
100        let mut state = self.state.lock().unwrap();
101        state.cache.clear();
102        state.insertion_order.clear();
103    }
104
105    /// Return a snapshot of the current cache statistics.
106    pub fn stats(&self) -> MemoStats {
107        self.state.lock().unwrap().stats.clone()
108    }
109
110    /// Return the current number of cached entries.
111    pub fn cache_len(&self) -> usize {
112        self.state.lock().unwrap().cache.len()
113    }
114}
115
116impl CacheState {
117    /// Evict the oldest entry to make room for a new one.
118    fn evict_oldest(&mut self) {
119        if let Some(oldest_key) = self.insertion_order.first().copied() {
120            self.cache.remove(&oldest_key);
121            self.insertion_order.remove(0);
122            self.stats.evictions += 1;
123        }
124    }
125}
126
127impl<T: Transport> Transport for MemoizedTransport<T> {
128    fn send(&self, destination: &str, payload: &[u8]) -> Result<Vec<u8>, TransportError> {
129        let key = MemoizedTransport::<T>::compute_cache_key(destination, payload);
130
131        {
132            let mut state = self.state.lock().unwrap();
133            state.stats.total_requests += 1;
134
135            if !self.config.enabled {
136                // Drop the lock before calling inner.
137                drop(state);
138                return self.inner.send(destination, payload);
139            }
140
141            // Check cache.
142            if let Some(entry) = state.cache.get_mut(&key) {
143                let result = entry.result.clone();
144                entry.hits += 1;
145                state.stats.cache_hits += 1;
146                return Ok(result);
147            }
148
149            state.stats.cache_misses += 1;
150            // Drop lock before the potentially blocking inner send.
151        }
152
153        // Cache miss -- delegate to inner transport (lock not held).
154        let result = self.inner.send(destination, payload)?;
155
156        // Re-acquire lock to insert the result.
157        {
158            let mut state = self.state.lock().unwrap();
159
160            // Evict if at capacity.
161            if state.cache.len() >= self.config.max_entries {
162                state.evict_oldest();
163            }
164
165            state.insertion_order.push(key);
166            state.cache.insert(
167                key,
168                CacheEntry {
169                    result: result.clone(),
170                    hits: 0,
171                },
172            );
173        }
174
175        Ok(result)
176    }
177
178    fn connect(&self, destination: &str) -> Result<Box<dyn Connection>, TransportError> {
179        // Persistent connections are not cacheable; delegate directly.
180        self.inner.connect(destination)
181    }
182}
183
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::sync::Arc;

    /// Stub inner transport: echoes the payload back and counts how many requests reach it.
    struct EchoTransport {
        call_count: Arc<AtomicU64>,
    }

    impl EchoTransport {
        /// Returns the transport together with a shared handle to its call counter.
        fn new() -> (Self, Arc<AtomicU64>) {
            let call_count = Arc::new(AtomicU64::new(0));
            let observer = Arc::clone(&call_count);
            (Self { call_count }, observer)
        }
    }

    impl Transport for EchoTransport {
        fn send(&self, _destination: &str, payload: &[u8]) -> Result<Vec<u8>, TransportError> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            Ok(payload.to_vec())
        }

        fn connect(&self, _destination: &str) -> Result<Box<dyn Connection>, TransportError> {
            Err(TransportError::ConnectionFailed("not supported".to_string()))
        }
    }

    type EchoMemo = MemoizedTransport<EchoTransport>;

    /// Wraps a fresh `EchoTransport` in a memoizing layer configured with `config`.
    fn echo_memo(config: MemoConfig) -> (EchoMemo, Arc<AtomicU64>) {
        let (echo, calls) = EchoTransport::new();
        (MemoizedTransport::new(echo, config), calls)
    }

    /// How many requests actually reached the inner echo transport.
    fn inner_calls(calls: &AtomicU64) -> u64 {
        calls.load(Ordering::SeqCst)
    }

    #[test]
    fn test_cache_hit() {
        let (memo, calls) = echo_memo(MemoConfig::default());

        let first = Transport::send(&memo, "host:1234", b"hello").unwrap();
        let second = Transport::send(&memo, "host:1234", b"hello").unwrap();

        assert_eq!(first, second);
        assert_eq!(inner_calls(&calls), 1);

        let stats = memo.stats();
        assert_eq!(
            (stats.cache_hits, stats.cache_misses, stats.total_requests),
            (1, 1, 2)
        );
    }

    #[test]
    fn test_cache_miss_different_payload() {
        let (memo, calls) = echo_memo(MemoConfig::default());

        for payload in [b"aaa", b"bbb"] {
            Transport::send(&memo, "host:1234", payload).unwrap();
        }

        assert_eq!(inner_calls(&calls), 2);
    }

    #[test]
    fn test_cache_miss_different_destination() {
        let (memo, calls) = echo_memo(MemoConfig::default());

        for destination in ["host-a:1234", "host-b:1234"] {
            Transport::send(&memo, destination, b"same").unwrap();
        }

        assert_eq!(inner_calls(&calls), 2);
    }

    #[test]
    fn test_lru_eviction() {
        let (memo, _calls) = echo_memo(MemoConfig {
            max_entries: 2,
            enabled: true,
        });

        // The third distinct request exceeds the capacity of two and evicts ("a", "1").
        for (destination, payload) in [("a", b"1"), ("b", b"2"), ("c", b"3")] {
            Transport::send(&memo, destination, payload).unwrap();
        }

        assert_eq!(memo.stats().evictions, 1);
        assert_eq!(memo.cache_len(), 2);

        let evicted = EchoMemo::compute_cache_key("a", b"1");
        assert!(!memo.state.lock().unwrap().cache.contains_key(&evicted));
    }

    #[test]
    fn test_disabled_passthrough() {
        let (memo, calls) = echo_memo(MemoConfig {
            max_entries: 1024,
            enabled: false,
        });

        for _ in 0..2 {
            Transport::send(&memo, "host", b"x").unwrap();
        }

        // With caching off, every request reaches the inner transport.
        assert_eq!(inner_calls(&calls), 2);
    }

    #[test]
    fn test_invalidate() {
        let (memo, calls) = echo_memo(MemoConfig::default());

        Transport::send(&memo, "host", b"data").unwrap();
        assert_eq!(inner_calls(&calls), 1);

        memo.invalidate(&EchoMemo::compute_cache_key("host", b"data"));

        // The entry is gone, so repeating the request must miss.
        Transport::send(&memo, "host", b"data").unwrap();
        assert_eq!(inner_calls(&calls), 2);
    }

    #[test]
    fn test_invalidate_all() {
        let (memo, calls) = echo_memo(MemoConfig::default());
        let requests = [("a", b"1"), ("b", b"2")];

        for (destination, payload) in requests {
            Transport::send(&memo, destination, payload).unwrap();
        }
        assert_eq!(inner_calls(&calls), 2);

        memo.invalidate_all();
        assert_eq!(memo.cache_len(), 0);

        for (destination, payload) in requests {
            Transport::send(&memo, destination, payload).unwrap();
        }
        assert_eq!(inner_calls(&calls), 4);
    }

    #[test]
    fn test_compute_cache_key_deterministic() {
        assert_eq!(
            EchoMemo::compute_cache_key("host", b"payload"),
            EchoMemo::compute_cache_key("host", b"payload")
        );
    }

    #[test]
    fn test_compute_cache_key_distinct() {
        assert_ne!(
            EchoMemo::compute_cache_key("host", b"aaa"),
            EchoMemo::compute_cache_key("host", b"bbb")
        );
    }
}