1use std::{
2 collections::{HashMap, HashSet},
3 hash::Hash,
4 sync::{Arc, RwLock},
5};
6
7use serde::{Deserialize, Serialize};
8
9use super::{Storage, StorageError, StorageResult};
10
/// State of a single key inside the overlay layer.
///
/// A key that has no entry at all in the overlay map falls through to the base
/// store (unless the base has been cleared).
#[derive(Clone, Debug)]
enum OverlayEntry<V> {
    /// A value stored through the overlay; it shadows any base value for the key.
    Written(V),
    /// A tombstone: the key reads as absent even if the base store still holds it.
    Deleted,
}
19
/// A storage layer stacked on top of a `base` store that records mutations locally.
///
/// Reads consult the overlay map first and fall back to `base` unless the overlay
/// has been cleared. `store`, `take` and `clear` only modify the overlay state;
/// `base` is read (`get`, `keys`, `into_iter`, `clone_box`) but never written
/// through this type.
pub struct OverlayStorage<K, V> {
    // The underlying store that unshadowed reads fall through to.
    base: Box<dyn Storage<K, V>>,
    // Per-key overlay state: written values and tombstones for removed keys.
    overlay: Arc<RwLock<HashMap<K, OverlayEntry<V>>>>,
    // Set to `true` by `clear`; while set, `base` is ignored entirely.
    base_cleared: Arc<RwLock<bool>>,
}
36
37impl<K, V> OverlayStorage<K, V>
38where
39 K: Clone + Eq + Hash + Send + Sync + 'static,
40 V: Clone + Send + Sync + 'static,
41{
42 pub fn new(base: Box<dyn Storage<K, V>>) -> Self {
44 Self {
45 base,
46 overlay: Arc::new(RwLock::new(HashMap::new())),
47 base_cleared: Arc::new(RwLock::new(false)),
48 }
49 }
50}
51
52impl<K, V> OverlayStorage<K, V>
53where
54 K: Serialize + for<'de> Deserialize<'de> + Clone + Eq + Hash + Send + Sync + 'static,
55 V: Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync + 'static,
56{
57 pub fn wrap(base: Box<dyn Storage<K, V>>) -> Box<dyn Storage<K, V>> {
60 Box::new(OverlayStorage::new(base))
61 }
62}
63
64impl<K, V> Storage<K, V> for OverlayStorage<K, V>
65where
66 K: Serialize + for<'de> Deserialize<'de> + Clone + Eq + Hash + Send + Sync + 'static,
67 V: Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync + 'static,
68{
69 fn store(&mut self, key: K, value: V) -> StorageResult<()> {
70 let mut overlay = self.overlay.write().map_err(|_| StorageError::LockError)?;
72 overlay.insert(key, OverlayEntry::Written(value));
73 Ok(())
74 }
75
76 fn get(&self, key: &K) -> StorageResult<Option<V>> {
77 let overlay = self.overlay.read().map_err(|_| StorageError::LockError)?;
79
80 if let Some(entry) = overlay.get(key) {
81 return match entry {
82 OverlayEntry::Written(v) => Ok(Some(v.clone())),
83 OverlayEntry::Deleted => Ok(None), };
85 }
86 drop(overlay); let base_cleared = self
90 .base_cleared
91 .read()
92 .map_err(|_| StorageError::LockError)?;
93 if *base_cleared {
94 return Ok(None);
95 }
96 drop(base_cleared);
97
98 self.base.get(key)
100 }
101
102 fn take(&mut self, key: &K) -> StorageResult<Option<V>> {
103 let mut overlay = self.overlay.write().map_err(|_| StorageError::LockError)?;
104
105 if let Some(entry) = overlay.get(key) {
107 match entry {
108 OverlayEntry::Written(v) => {
109 let value = v.clone();
110 overlay.insert(key.clone(), OverlayEntry::Deleted);
112 return Ok(Some(value));
113 }
114 OverlayEntry::Deleted => {
115 return Ok(None);
117 }
118 }
119 }
120 drop(overlay);
121
122 let base_cleared = self
124 .base_cleared
125 .read()
126 .map_err(|_| StorageError::LockError)?;
127 if *base_cleared {
128 return Ok(None);
129 }
130 drop(base_cleared);
131
132 let value = self.base.get(key)?;
134
135 if value.is_some() {
136 let mut overlay = self.overlay.write().map_err(|_| StorageError::LockError)?;
138 overlay.insert(key.clone(), OverlayEntry::Deleted);
139 }
140
141 Ok(value)
142 }
143
144 fn clear(&mut self) -> StorageResult<()> {
145 let mut base_cleared = self
147 .base_cleared
148 .write()
149 .map_err(|_| StorageError::LockError)?;
150 *base_cleared = true;
151
152 let mut overlay = self.overlay.write().map_err(|_| StorageError::LockError)?;
153 overlay.clear();
154
155 Ok(())
156 }
157
158 fn keys(&self) -> StorageResult<Vec<K>> {
159 let overlay = self.overlay.read().map_err(|_| StorageError::LockError)?;
160 let base_cleared = *self
161 .base_cleared
162 .read()
163 .map_err(|_| StorageError::LockError)?;
164
165 let mut result_keys: HashSet<K> = HashSet::new();
166 let mut deleted_keys: HashSet<K> = HashSet::new();
167
168 for (k, entry) in overlay.iter() {
170 match entry {
171 OverlayEntry::Written(_) => {
172 result_keys.insert(k.clone());
173 }
174 OverlayEntry::Deleted => {
175 deleted_keys.insert(k.clone());
176 }
177 }
178 }
179
180 if !base_cleared {
182 drop(overlay);
183
184 for key in self.base.keys()? {
185 if !deleted_keys.contains(&key) && !result_keys.contains(&key) {
186 result_keys.insert(key);
187 }
188 }
189 }
190
191 Ok(result_keys.into_iter().collect())
192 }
193
194 fn into_iter(&self) -> StorageResult<Box<dyn Iterator<Item = (K, V)> + '_>> {
195 let overlay = self.overlay.read().map_err(|_| StorageError::LockError)?;
196 let base_cleared = *self
197 .base_cleared
198 .read()
199 .map_err(|_| StorageError::LockError)?;
200
201 let deleted_keys: HashSet<K> = overlay
203 .iter()
204 .filter_map(|(k, entry)| {
205 if matches!(entry, OverlayEntry::Deleted) {
206 Some(k.clone())
207 } else {
208 None
209 }
210 })
211 .collect();
212
213 let overlay_entries: Vec<(K, V)> = overlay
215 .iter()
216 .filter_map(|(k, entry)| {
217 if let OverlayEntry::Written(v) = entry {
218 Some((k.clone(), v.clone()))
219 } else {
220 None
221 }
222 })
223 .collect();
224
225 let overlay_keys: HashSet<K> = overlay_entries.iter().map(|(k, _)| k.clone()).collect();
226
227 drop(overlay);
228
229 let base_entries: Vec<(K, V)> = if !base_cleared {
231 self.base
232 .into_iter()?
233 .filter(|(k, _)| !deleted_keys.contains(k) && !overlay_keys.contains(k))
234 .collect()
235 } else {
236 Vec::new()
237 };
238
239 let all_entries = overlay_entries
241 .into_iter()
242 .chain(base_entries)
243 .collect::<Vec<_>>();
244 Ok(Box::new(all_entries.into_iter()))
245 }
246
247 fn count(&self) -> StorageResult<u64> {
248 Ok(self.keys()?.len() as u64)
250 }
251
252 fn shutdown(&self) {
253 }
256
257 fn clone_box(&self) -> Box<dyn Storage<K, V>> {
258 let overlay = self.overlay.read().unwrap();
260 let base_cleared = *self.base_cleared.read().unwrap();
261
262 Box::new(OverlayStorage {
263 base: self.base.clone_box(),
264 overlay: Arc::new(RwLock::new(overlay.clone())),
265 base_cleared: Arc::new(RwLock::new(base_cleared)),
266 })
267 }
268}
269
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::StorageHashMap;

    /// Builds a `StorageHashMap`-backed base store pre-populated with `entries`.
    fn base_with(entries: &[(&str, &str)]) -> Box<dyn Storage<String, String>> {
        let mut base: Box<dyn Storage<String, String>> =
            Box::new(StorageHashMap::<String, String>::new());
        for (key, value) in entries {
            base.store(String::from(*key), String::from(*value)).unwrap();
        }
        base
    }

    #[test]
    fn test_overlay_write_does_not_affect_base() {
        let base = base_with(&[("key1", "base_value")]);

        let mut overlay = OverlayStorage::new(base.clone_box());
        overlay
            .store("key1".into(), "overlay_value".into())
            .unwrap();

        assert_eq!(
            overlay.get(&"key1".into()).unwrap(),
            Some("overlay_value".into())
        );
        assert_eq!(base.get(&"key1".into()).unwrap(), Some("base_value".into()));
    }

    #[test]
    fn test_overlay_read_falls_through_to_base() {
        let overlay = OverlayStorage::new(base_with(&[("key1", "base_value")]));

        assert_eq!(
            overlay.get(&"key1".into()).unwrap(),
            Some("base_value".into())
        );
    }

    #[test]
    fn test_overlay_delete_creates_tombstone() {
        let base = base_with(&[("key1", "base_value")]);

        let mut overlay = OverlayStorage::new(base.clone_box());
        let taken = overlay.take(&"key1".into()).unwrap();

        assert_eq!(taken, Some("base_value".into()));
        assert_eq!(overlay.get(&"key1".into()).unwrap(), None);
        assert_eq!(base.get(&"key1".into()).unwrap(), Some("base_value".into()));
    }

    #[test]
    fn test_overlay_take_of_overlay_write_hides_base() {
        let base = base_with(&[("key1", "base_value")]);

        let mut overlay = OverlayStorage::new(base.clone_box());
        overlay
            .store("key1".into(), "overlay_value".into())
            .unwrap();

        assert_eq!(
            overlay.take(&"key1".into()).unwrap(),
            Some("overlay_value".into())
        );
        // The base copy must stay hidden behind the tombstone.
        assert_eq!(overlay.get(&"key1".into()).unwrap(), None);
        assert!(overlay.keys().unwrap().is_empty());
        assert_eq!(overlay.count().unwrap(), 0);
        assert_eq!(base.get(&"key1".into()).unwrap(), Some("base_value".into()));
    }

    #[test]
    fn test_overlay_take_twice_and_missing_key() {
        let mut overlay = OverlayStorage::new(base_with(&[("key1", "value1")]));

        assert_eq!(overlay.take(&"key1".into()).unwrap(), Some("value1".into()));
        assert_eq!(overlay.take(&"key1".into()).unwrap(), None);
        assert_eq!(overlay.take(&"missing".into()).unwrap(), None);
    }

    #[test]
    fn test_overlay_keys_merges_correctly() {
        let mut overlay = OverlayStorage::new(base_with(&[("base_key", "value")]));
        overlay.store("overlay_key".into(), "value".into()).unwrap();

        let keys = overlay.keys().unwrap();
        assert!(keys.contains(&"base_key".into()));
        assert!(keys.contains(&"overlay_key".into()));
    }

    #[test]
    fn test_overlay_clear_ignores_base() {
        let base = base_with(&[("key1", "base_value")]);

        let mut overlay = OverlayStorage::new(base.clone_box());
        overlay.clear().unwrap();

        assert_eq!(overlay.get(&"key1".into()).unwrap(), None);
        assert_eq!(overlay.keys().unwrap().len(), 0);
        assert_eq!(base.get(&"key1".into()).unwrap(), Some("base_value".into()));
    }

    #[test]
    fn test_overlay_store_after_clear_is_visible() {
        let mut overlay = OverlayStorage::new(base_with(&[("key1", "v1"), ("key2", "v2")]));
        overlay.clear().unwrap();
        overlay.store("key2".into(), "new".into()).unwrap();

        assert_eq!(overlay.get(&"key1".into()).unwrap(), None);
        assert_eq!(overlay.get(&"key2".into()).unwrap(), Some("new".into()));
        assert_eq!(overlay.keys().unwrap(), vec![String::from("key2")]);
        assert_eq!(overlay.count().unwrap(), 1);

        let entries: Vec<(String, String)> = overlay.into_iter().unwrap().collect();
        assert_eq!(entries, vec![(String::from("key2"), String::from("new"))]);
    }

    #[test]
    fn test_overlay_clone_box_creates_independent_copy() {
        let mut overlay = OverlayStorage::new(base_with(&[]));
        overlay.store("key1".into(), "value1".into()).unwrap();

        let mut cloned = overlay.clone_box();
        cloned.store("key2".into(), "value2".into()).unwrap();

        assert_eq!(overlay.get(&"key2".into()).unwrap(), None);
        assert_eq!(cloned.get(&"key1".into()).unwrap(), Some("value1".into()));
        assert_eq!(cloned.get(&"key2".into()).unwrap(), Some("value2".into()));
    }

    #[test]
    fn test_overlay_count_accounts_for_tombstones() {
        let mut overlay = OverlayStorage::new(base_with(&[("key1", "value1"), ("key2", "value2")]));

        overlay.take(&"key1".into()).unwrap();
        overlay.store("key3".into(), "value3".into()).unwrap();

        // Visible keys: key2 (base) and key3 (overlay); key1 is tombstoned.
        assert_eq!(overlay.count().unwrap(), 2);
    }

    #[test]
    fn test_overlay_into_iter_merges_correctly() {
        let mut overlay = OverlayStorage::new(base_with(&[
            ("base_key", "base_value"),
            ("deleted_key", "deleted_value"),
        ]));
        overlay
            .store("overlay_key".into(), "overlay_value".into())
            .unwrap();
        overlay.take(&"deleted_key".into()).unwrap();

        let entries: Vec<(String, String)> = overlay.into_iter().unwrap().collect();

        assert_eq!(entries.len(), 2);
        assert!(entries.contains(&("base_key".into(), "base_value".into())));
        assert!(entries.contains(&("overlay_key".into(), "overlay_value".into())));
        assert!(
            !entries
                .iter()
                .any(|(k, _)| k == &String::from("deleted_key"))
        );
    }
}