// surfpool_core/storage/overlay.rs

1use std::{
2    collections::{HashMap, HashSet},
3    hash::Hash,
4    sync::{Arc, RwLock},
5};
6
7use serde::{Deserialize, Serialize};
8
9use super::{Storage, StorageError, StorageResult};
10
/// State of a single key as recorded by the overlay.
///
/// A key that is absent from the overlay map has no overlay state at all; reads
/// for it fall through to the base storage (unless the base has been cleared).
#[derive(Clone, Debug)]
enum OverlayEntry<V> {
    /// The key holds this value in the overlay, shadowing any value in the base.
    Written(V),
    /// The key was deleted in the overlay (a tombstone), hiding any value in the base.
    Deleted,
}
19
20/// Thread-safe overlay storage that wraps a base storage.
21/// All writes go to an in-memory HashMap overlay.
22/// Reads check overlay first, then fall through to base.
23/// Deletes are tracked as tombstones in the overlay.
24///
25/// This is useful for transaction profiling where we need to
26/// read from the database but not persist any mutations.
27pub struct OverlayStorage<K, V> {
28    /// The base storage (could be SQLite, Postgres, HashMap, etc.)
29    base: Box<dyn Storage<K, V>>,
30    /// In-memory overlay for writes and deletes
31    /// Using Arc<RwLock<_>> for thread-safety (Send + Sync)
32    overlay: Arc<RwLock<HashMap<K, OverlayEntry<V>>>>,
33    /// Track if base was "cleared" - if true, ignore base for reads
34    base_cleared: Arc<RwLock<bool>>,
35}
36
37impl<K, V> OverlayStorage<K, V>
38where
39    K: Clone + Eq + Hash + Send + Sync + 'static,
40    V: Clone + Send + Sync + 'static,
41{
42    /// Create a new overlay wrapping the given base storage
43    pub fn new(base: Box<dyn Storage<K, V>>) -> Self {
44        Self {
45            base,
46            overlay: Arc::new(RwLock::new(HashMap::new())),
47            base_cleared: Arc::new(RwLock::new(false)),
48        }
49    }
50}
51
52impl<K, V> OverlayStorage<K, V>
53where
54    K: Serialize + for<'de> Deserialize<'de> + Clone + Eq + Hash + Send + Sync + 'static,
55    V: Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync + 'static,
56{
57    /// Create a boxed overlay from a base storage.
58    /// This is a convenience method for wrapping storage fields.
59    pub fn wrap(base: Box<dyn Storage<K, V>>) -> Box<dyn Storage<K, V>> {
60        Box::new(OverlayStorage::new(base))
61    }
62}
63
64impl<K, V> Storage<K, V> for OverlayStorage<K, V>
65where
66    K: Serialize + for<'de> Deserialize<'de> + Clone + Eq + Hash + Send + Sync + 'static,
67    V: Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync + 'static,
68{
69    fn store(&mut self, key: K, value: V) -> StorageResult<()> {
70        // Write only to overlay, never to base
71        let mut overlay = self.overlay.write().map_err(|_| StorageError::LockError)?;
72        overlay.insert(key, OverlayEntry::Written(value));
73        Ok(())
74    }
75
76    fn get(&self, key: &K) -> StorageResult<Option<V>> {
77        // First check overlay
78        let overlay = self.overlay.read().map_err(|_| StorageError::LockError)?;
79
80        if let Some(entry) = overlay.get(key) {
81            return match entry {
82                OverlayEntry::Written(v) => Ok(Some(v.clone())),
83                OverlayEntry::Deleted => Ok(None), // Tombstone - don't query base
84            };
85        }
86        drop(overlay); // Release read lock before querying base
87
88        // Check if base was cleared
89        let base_cleared = self
90            .base_cleared
91            .read()
92            .map_err(|_| StorageError::LockError)?;
93        if *base_cleared {
94            return Ok(None);
95        }
96        drop(base_cleared);
97
98        // Fall through to base storage
99        self.base.get(key)
100    }
101
102    fn take(&mut self, key: &K) -> StorageResult<Option<V>> {
103        let mut overlay = self.overlay.write().map_err(|_| StorageError::LockError)?;
104
105        // Check if key exists in overlay
106        if let Some(entry) = overlay.get(key) {
107            match entry {
108                OverlayEntry::Written(v) => {
109                    let value = v.clone();
110                    // Replace with tombstone
111                    overlay.insert(key.clone(), OverlayEntry::Deleted);
112                    return Ok(Some(value));
113                }
114                OverlayEntry::Deleted => {
115                    // Already deleted
116                    return Ok(None);
117                }
118            }
119        }
120        drop(overlay);
121
122        // Check if base was cleared
123        let base_cleared = self
124            .base_cleared
125            .read()
126            .map_err(|_| StorageError::LockError)?;
127        if *base_cleared {
128            return Ok(None);
129        }
130        drop(base_cleared);
131
132        // Get from base (but don't modify base)
133        let value = self.base.get(key)?;
134
135        if value.is_some() {
136            // Mark as deleted in overlay
137            let mut overlay = self.overlay.write().map_err(|_| StorageError::LockError)?;
138            overlay.insert(key.clone(), OverlayEntry::Deleted);
139        }
140
141        Ok(value)
142    }
143
144    fn clear(&mut self) -> StorageResult<()> {
145        // Mark base as cleared and clear overlay
146        let mut base_cleared = self
147            .base_cleared
148            .write()
149            .map_err(|_| StorageError::LockError)?;
150        *base_cleared = true;
151
152        let mut overlay = self.overlay.write().map_err(|_| StorageError::LockError)?;
153        overlay.clear();
154
155        Ok(())
156    }
157
158    fn keys(&self) -> StorageResult<Vec<K>> {
159        let overlay = self.overlay.read().map_err(|_| StorageError::LockError)?;
160        let base_cleared = *self
161            .base_cleared
162            .read()
163            .map_err(|_| StorageError::LockError)?;
164
165        let mut result_keys: HashSet<K> = HashSet::new();
166        let mut deleted_keys: HashSet<K> = HashSet::new();
167
168        // Collect overlay keys (written) and deleted keys
169        for (k, entry) in overlay.iter() {
170            match entry {
171                OverlayEntry::Written(_) => {
172                    result_keys.insert(k.clone());
173                }
174                OverlayEntry::Deleted => {
175                    deleted_keys.insert(k.clone());
176                }
177            }
178        }
179
180        // If base not cleared, add base keys (excluding deleted ones)
181        if !base_cleared {
182            drop(overlay);
183
184            for key in self.base.keys()? {
185                if !deleted_keys.contains(&key) && !result_keys.contains(&key) {
186                    result_keys.insert(key);
187                }
188            }
189        }
190
191        Ok(result_keys.into_iter().collect())
192    }
193
194    fn into_iter(&self) -> StorageResult<Box<dyn Iterator<Item = (K, V)> + '_>> {
195        let overlay = self.overlay.read().map_err(|_| StorageError::LockError)?;
196        let base_cleared = *self
197            .base_cleared
198            .read()
199            .map_err(|_| StorageError::LockError)?;
200
201        // Collect deleted keys for filtering
202        let deleted_keys: HashSet<K> = overlay
203            .iter()
204            .filter_map(|(k, entry)| {
205                if matches!(entry, OverlayEntry::Deleted) {
206                    Some(k.clone())
207                } else {
208                    None
209                }
210            })
211            .collect();
212
213        // Collect overlay written entries
214        let overlay_entries: Vec<(K, V)> = overlay
215            .iter()
216            .filter_map(|(k, entry)| {
217                if let OverlayEntry::Written(v) = entry {
218                    Some((k.clone(), v.clone()))
219                } else {
220                    None
221                }
222            })
223            .collect();
224
225        let overlay_keys: HashSet<K> = overlay_entries.iter().map(|(k, _)| k.clone()).collect();
226
227        drop(overlay);
228
229        // Get base entries if not cleared
230        let base_entries: Vec<(K, V)> = if !base_cleared {
231            self.base
232                .into_iter()?
233                .filter(|(k, _)| !deleted_keys.contains(k) && !overlay_keys.contains(k))
234                .collect()
235        } else {
236            Vec::new()
237        };
238
239        // Chain overlay entries with filtered base entries
240        let all_entries = overlay_entries
241            .into_iter()
242            .chain(base_entries)
243            .collect::<Vec<_>>();
244        Ok(Box::new(all_entries.into_iter()))
245    }
246
247    fn count(&self) -> StorageResult<u64> {
248        // Use keys() which handles all edge cases correctly
249        Ok(self.keys()?.len() as u64)
250    }
251
252    fn shutdown(&self) {
253        // No-op - don't propagate to base
254        // The base storage should not be affected by overlay shutdown
255    }
256
257    fn clone_box(&self) -> Box<dyn Storage<K, V>> {
258        // Clone the overlay with its current state
259        let overlay = self.overlay.read().unwrap();
260        let base_cleared = *self.base_cleared.read().unwrap();
261
262        Box::new(OverlayStorage {
263            base: self.base.clone_box(),
264            overlay: Arc::new(RwLock::new(overlay.clone())),
265            base_cleared: Arc::new(RwLock::new(base_cleared)),
266        })
267    }
268}
269
#[cfg(test)]
mod tests {
    //! Behavioral tests for `OverlayStorage`: every test checks the overlay's view
    //! and, where relevant, that the wrapped base storage is left unchanged.

    use super::*;
    use crate::storage::StorageHashMap;

    #[test]
    fn test_overlay_write_does_not_affect_base() {
        let mut base: Box<dyn Storage<String, String>> =
            Box::new(StorageHashMap::<String, String>::new());
        base.store("key1".into(), "base_value".into()).unwrap();

        let mut overlay = OverlayStorage::new(base.clone_box());
        overlay
            .store("key1".into(), "overlay_value".into())
            .unwrap();

        // Overlay should return overlay value
        assert_eq!(
            overlay.get(&"key1".into()).unwrap(),
            Some("overlay_value".into())
        );

        // Base should still have original value
        assert_eq!(base.get(&"key1".into()).unwrap(), Some("base_value".into()));
    }

    #[test]
    fn test_overlay_read_falls_through_to_base() {
        let mut base: Box<dyn Storage<String, String>> =
            Box::new(StorageHashMap::<String, String>::new());
        base.store("key1".into(), "base_value".into()).unwrap();

        let overlay = OverlayStorage::new(base);

        assert_eq!(
            overlay.get(&"key1".into()).unwrap(),
            Some("base_value".into())
        );
    }

    #[test]
    fn test_overlay_delete_creates_tombstone() {
        let mut base: Box<dyn Storage<String, String>> =
            Box::new(StorageHashMap::<String, String>::new());
        base.store("key1".into(), "base_value".into()).unwrap();

        let mut overlay = OverlayStorage::new(base.clone_box());
        let taken = overlay.take(&"key1".into()).unwrap();

        assert_eq!(taken, Some("base_value".into()));
        assert_eq!(overlay.get(&"key1".into()).unwrap(), None);

        // Base should still have the value
        assert_eq!(base.get(&"key1".into()).unwrap(), Some("base_value".into()));
    }

    #[test]
    fn test_overlay_take_removes_overlay_written_value() {
        let mut base: Box<dyn Storage<String, String>> =
            Box::new(StorageHashMap::<String, String>::new());
        base.store("key1".into(), "base_value".into()).unwrap();

        let mut overlay = OverlayStorage::new(base.clone_box());
        overlay
            .store("key1".into(), "overlay_value".into())
            .unwrap();

        // Taking returns the overlay value, not the base value
        assert_eq!(
            overlay.take(&"key1".into()).unwrap(),
            Some("overlay_value".into())
        );
        // The resulting tombstone hides the base value as well
        assert_eq!(overlay.get(&"key1".into()).unwrap(), None);
        // A second take only finds the tombstone
        assert_eq!(overlay.take(&"key1".into()).unwrap(), None);

        // Base should still have the original value
        assert_eq!(base.get(&"key1".into()).unwrap(), Some("base_value".into()));
    }

    #[test]
    fn test_overlay_keys_merges_correctly() {
        let mut base: Box<dyn Storage<String, String>> =
            Box::new(StorageHashMap::<String, String>::new());
        base.store("base_key".into(), "value".into()).unwrap();

        let mut overlay = OverlayStorage::new(base);
        overlay.store("overlay_key".into(), "value".into()).unwrap();

        let keys = overlay.keys().unwrap();
        assert!(keys.contains(&"base_key".into()));
        assert!(keys.contains(&"overlay_key".into()));
    }

    #[test]
    fn test_overlay_clear_ignores_base() {
        let mut base: Box<dyn Storage<String, String>> =
            Box::new(StorageHashMap::<String, String>::new());
        base.store("key1".into(), "base_value".into()).unwrap();

        let mut overlay = OverlayStorage::new(base.clone_box());
        overlay.clear().unwrap();

        assert_eq!(overlay.get(&"key1".into()).unwrap(), None);
        assert_eq!(overlay.keys().unwrap().len(), 0);

        // Base should still have the value
        assert_eq!(base.get(&"key1".into()).unwrap(), Some("base_value".into()));
    }

    #[test]
    fn test_overlay_store_after_clear() {
        let mut base: Box<dyn Storage<String, String>> =
            Box::new(StorageHashMap::<String, String>::new());
        base.store("key1".into(), "value1".into()).unwrap();
        base.store("key2".into(), "value2".into()).unwrap();

        let mut overlay = OverlayStorage::new(base);
        overlay.clear().unwrap();
        overlay.store("key2".into(), "new_value".into()).unwrap();

        // Only the post-clear write is visible; the base is ignored entirely
        assert_eq!(overlay.get(&"key1".into()).unwrap(), None);
        assert_eq!(
            overlay.get(&"key2".into()).unwrap(),
            Some("new_value".into())
        );
        assert_eq!(overlay.count().unwrap(), 1);

        let entries: Vec<(String, String)> = overlay.into_iter().unwrap().collect();
        assert_eq!(
            entries,
            vec![("key2".to_string(), "new_value".to_string())]
        );
    }

    #[test]
    fn test_overlay_clone_box_creates_independent_copy() {
        let base: Box<dyn Storage<String, String>> =
            Box::new(StorageHashMap::<String, String>::new());

        let mut overlay = OverlayStorage::new(base);
        overlay.store("key1".into(), "value1".into()).unwrap();

        let mut cloned = overlay.clone_box();
        cloned.store("key2".into(), "value2".into()).unwrap();

        // Original should not have key2
        assert_eq!(overlay.get(&"key2".into()).unwrap(), None);
        // Clone should have both
        assert_eq!(cloned.get(&"key1".into()).unwrap(), Some("value1".into()));
        assert_eq!(cloned.get(&"key2".into()).unwrap(), Some("value2".into()));
    }

    #[test]
    fn test_overlay_clone_box_preserves_cleared_base() {
        let mut base: Box<dyn Storage<String, String>> =
            Box::new(StorageHashMap::<String, String>::new());
        base.store("key1".into(), "base_value".into()).unwrap();

        let mut overlay = OverlayStorage::new(base);
        overlay.clear().unwrap();

        // The clone must inherit the cleared flag rather than re-exposing the base
        let cloned = overlay.clone_box();
        assert_eq!(cloned.get(&"key1".into()).unwrap(), None);
        assert_eq!(cloned.keys().unwrap().len(), 0);
    }

    #[test]
    fn test_overlay_count_accounts_for_tombstones() {
        let mut base: Box<dyn Storage<String, String>> =
            Box::new(StorageHashMap::<String, String>::new());
        base.store("key1".into(), "value1".into()).unwrap();
        base.store("key2".into(), "value2".into()).unwrap();

        let mut overlay = OverlayStorage::new(base);
        overlay.take(&"key1".into()).unwrap(); // Delete key1
        overlay.store("key3".into(), "value3".into()).unwrap(); // Add key3

        // Should have key2 (from base) and key3 (from overlay), but not key1 (deleted)
        assert_eq!(overlay.count().unwrap(), 2);
    }

    #[test]
    fn test_overlay_into_iter_merges_correctly() {
        let mut base: Box<dyn Storage<String, String>> =
            Box::new(StorageHashMap::<String, String>::new());
        base.store("base_key".into(), "base_value".into()).unwrap();
        base.store("deleted_key".into(), "deleted_value".into()).unwrap();

        let mut overlay = OverlayStorage::new(base);
        overlay
            .store("overlay_key".into(), "overlay_value".into())
            .unwrap();
        overlay.take(&"deleted_key".into()).unwrap();

        let entries: Vec<(String, String)> = overlay.into_iter().unwrap().collect();

        assert_eq!(entries.len(), 2);
        assert!(entries.contains(&("base_key".into(), "base_value".into())));
        assert!(entries.contains(&("overlay_key".into(), "overlay_value".into())));
        assert!(
            !entries
                .iter()
                .any(|(k, _)| k == &String::from("deleted_key"))
        );
    }
}