1use std::collections::{BTreeMap, BTreeSet};
20
21use sled::{
22 transaction::{ConflictableTransactionError, TransactionError},
23 IVec, Transactional,
24};
25
26use crate::{SledTreeOverlay, SledTreeOverlayIter, SledTreeOverlayStateDiff};
27
/// In-memory state of a [`SledDbOverlay`]: which trees exist, which were created or
/// dropped through the overlay, and the per-tree caches holding pending changes.
#[derive(Debug, Clone)]
pub struct SledDbOverlayState {
    /// Trees considered to already exist in the database. `SledDbOverlay::new` fills this
    /// from `db.tree_names()`; it is used to tell newly opened trees apart.
    pub initial_tree_names: Vec<IVec>,
    /// Trees opened through the overlay that were not in `initial_tree_names`.
    /// `SledDbOverlay::purge_new_trees` / `revert_to_checkpoint` drop these from the db.
    pub new_tree_names: Vec<IVec>,
    /// Per-tree overlays holding the pending changes, keyed by tree name.
    pub caches: BTreeMap<IVec, SledTreeOverlay>,
    /// Trees dropped through the overlay, mapped to the diff recorded when they were
    /// dropped (`SledTreeOverlayStateDiff::new_dropped` in `SledDbOverlay::drop_tree`).
    pub dropped_trees: BTreeMap<IVec, SledTreeOverlayStateDiff>,
    /// Trees that can't be dropped, and whose caches are reset in place instead of
    /// being discarded when `remove_diff` leaves them without changes.
    pub protected_tree_names: Vec<IVec>,
}
43
impl SledDbOverlayState {
    /// Creates a state from the trees that exist initially and the protected tree names.
    /// It starts with no new trees, no caches and no dropped trees.
    pub fn new(initial_tree_names: Vec<IVec>, protected_tree_names: Vec<IVec>) -> Self {
        Self {
            initial_tree_names,
            new_tree_names: vec![],
            caches: BTreeMap::new(),
            dropped_trees: BTreeMap::new(),
            protected_tree_names,
        }
    }

    /// Collects, for every cached tree, its tree handle and the batch returned by
    /// `SledTreeOverlay::aggregate`. Trees whose `aggregate` returns `None` are skipped.
    /// The two returned vectors are index-aligned.
    ///
    /// # Errors
    ///
    /// `CollectionNotFound` if a cached tree is also recorded as dropped.
    fn aggregate(&self) -> Result<(Vec<sled::Tree>, Vec<sled::Batch>), sled::Error> {
        let mut trees = vec![];
        let mut batches = vec![];

        for (key, cache) in self.caches.iter() {
            // A tree can't be both cached and dropped; treat that as a missing tree.
            if self.dropped_trees.contains_key(key) {
                return Err(sled::Error::CollectionNotFound(key.into()));
            }

            if let Some(batch) = cache.aggregate() {
                trees.push(cache.tree.clone());
                batches.push(batch);
            }
        }

        Ok((trees, batches))
    }

    /// Merges `diff` into this state.
    ///
    /// `db` is used to open trees that don't have a cache yet.
    ///
    /// # Errors
    ///
    /// Propagates errors from `db.open_tree`.
    ///
    /// # Panics
    ///
    /// If `diff` drops or restores a protected tree.
    pub fn add_diff(
        &mut self,
        db: &sled::Db,
        diff: &SledDbOverlayStateDiff,
    ) -> Result<(), sled::Error> {
        // Keep only the initial tree names that `diff` also considers initial.
        self.initial_tree_names
            .retain(|x| diff.initial_tree_names.contains(x));

        for (k, (cache, drop)) in diff.caches.iter() {
            // Tree marked for dropping: forget its cache and record it as dropped.
            if *drop {
                assert!(!self.protected_tree_names.contains(k));
                self.new_tree_names.retain(|x| x != k);
                self.caches.remove(k);
                self.dropped_trees.insert(k.clone(), cache.clone());
                continue;
            }

            // No cache yet: register the tree as new if unknown, open it and seed a
            // fresh overlay with the diff.
            let Some(tree_overlay) = self.caches.get_mut(k) else {
                if !self.initial_tree_names.contains(k) && !self.new_tree_names.contains(k) {
                    self.new_tree_names.push(k.clone());
                }
                let mut overlay = SledTreeOverlay::new(&db.open_tree(k)?);
                overlay.add_diff(cache);
                self.caches.insert(k.clone(), overlay);
                continue;
            };

            tree_overlay.add_diff(cache);
        }

        for (k, (cache, restored)) in &diff.dropped_trees {
            // Not restored: record the drop, unless the tree is already dropped.
            if !restored {
                if self.dropped_trees.contains_key(k) {
                    continue;
                }
                self.new_tree_names.retain(|x| x != k);
                self.caches.remove(k);
                self.dropped_trees.insert(k.clone(), cache.clone());
                continue;
            }
            assert!(!self.protected_tree_names.contains(k));

            // Restored trees are tracked as new trees rather than initial ones.
            self.initial_tree_names.retain(|x| x != k);
            if !self.new_tree_names.contains(k) {
                self.new_tree_names.push(k.clone());
            }

            // NOTE(review): `k` is not removed from `self.dropped_trees` here. If it was
            // already dropped in this state, the tree ends up both cached and dropped,
            // which `aggregate` rejects. Confirm this is intended.
            let mut overlay = SledTreeOverlay::new(&db.open_tree(k)?);
            overlay.add_diff(cache);
            self.caches.insert(k.clone(), overlay);
        }

        Ok(())
    }

    /// Removes the effects of `diff` from this state.
    ///
    /// `SledDbOverlay::apply_diff` calls this after writing `diff` to the database, so
    /// trees touched by `diff.caches` are moved into `initial_tree_names`.
    ///
    /// # Panics
    ///
    /// If `diff` references a tree this state doesn't know about (not initial, new or
    /// dropped), or if it drops a protected tree.
    pub fn remove_diff(&mut self, diff: &SledDbOverlayStateDiff) {
        for (k, (cache, drop)) in diff.caches.iter() {
            // Every tree in the diff must be known to this state.
            assert!(
                self.initial_tree_names.contains(k)
                    || self.new_tree_names.contains(k)
                    || self.dropped_trees.contains_key(k)
            );
            // The tree is now considered pre-existing.
            if !self.initial_tree_names.contains(k) {
                self.initial_tree_names.push(k.clone());
            }
            self.new_tree_names.retain(|x| x != k);

            // Dropped by the diff: forget the tree entirely.
            if *drop {
                assert!(!self.protected_tree_names.contains(k));
                self.initial_tree_names.retain(|x| x != k);
                self.new_tree_names.retain(|x| x != k);
                self.caches.remove(k);
                self.dropped_trees.remove(k);
                continue;
            }

            // No cache: if the tree is dropped here, update the values of its
            // dropped-tree diff; otherwise there is nothing to do.
            let Some(tree_overlay) = self.caches.get_mut(k) else {
                let Some(tree_overlay) = self.dropped_trees.get_mut(k) else {
                    continue;
                };
                tree_overlay.update_values(cache);
                continue;
            };

            // The cache state equals exactly this diff, so removing it leaves no changes.
            if tree_overlay.state == cache.into() {
                // Protected trees keep their cache, reset in place.
                if self.protected_tree_names.contains(k) {
                    tree_overlay.state.cache = BTreeMap::new();
                    tree_overlay.state.removed = BTreeSet::new();
                    tree_overlay.checkpoint();
                    continue;
                }

                // Other trees drop their now-empty cache.
                self.caches.remove(k);
                continue;
            }

            tree_overlay.remove_diff(cache);
        }

        for (k, (cache, restored)) in diff.dropped_trees.iter() {
            // Every tree in the diff must be known to this state.
            assert!(
                self.initial_tree_names.contains(k)
                    || self.new_tree_names.contains(k)
                    || self.dropped_trees.contains_key(k)
            );

            // Not restored: the tree is gone, forget it entirely.
            if !restored {
                assert!(!self.protected_tree_names.contains(k));
                self.initial_tree_names.retain(|x| x != k);
                self.new_tree_names.retain(|x| x != k);
                self.caches.remove(k);
                self.dropped_trees.remove(k);
                continue;
            }

            // Restored trees are tracked as new trees. This may leave a name in
            // `new_tree_names` without a cache, when the lookup below fails.
            self.initial_tree_names.retain(|x| x != k);
            if !self.new_tree_names.contains(k) {
                self.new_tree_names.push(k.clone());
            }

            let Some(tree_overlay) = self.caches.get_mut(k) else {
                continue;
            };

            // Same empty-after-removal handling as for `diff.caches` above.
            if tree_overlay.state == cache.into() {
                if self.protected_tree_names.contains(k) {
                    tree_overlay.state.cache = BTreeMap::new();
                    tree_overlay.state.removed = BTreeSet::new();
                    tree_overlay.checkpoint();
                    continue;
                }

                self.caches.remove(k);
                continue;
            }

            tree_overlay.remove_diff(cache);
        }
    }
}
242
243impl Default for SledDbOverlayState {
244 fn default() -> Self {
245 Self::new(vec![], vec![])
246 }
247}
248
/// Serializable-style snapshot of the changes held by a [`SledDbOverlayState`].
#[derive(Debug, Default, Clone, PartialEq)]
pub struct SledDbOverlayStateDiff {
    /// Initial tree names, copied from the source state by `new`.
    pub initial_tree_names: Vec<IVec>,
    /// Per-tree changes. The `bool` is a drop flag: `aggregate` skips flagged entries,
    /// and `SledDbOverlay::apply_diff` drops the tree from the database instead.
    pub caches: BTreeMap<IVec, (SledTreeOverlayStateDiff, bool)>,
    /// Dropped trees with the diff recorded when they were dropped. The `bool` is a
    /// restored flag: restored entries are written back by `aggregate`, while
    /// non-restored ones are dropped by `SledDbOverlay::apply_diff`.
    pub dropped_trees: BTreeMap<IVec, (SledTreeOverlayStateDiff, bool)>,
}
266
impl SledDbOverlayStateDiff {
    /// Builds the diff describing `state`.
    ///
    /// Trees with no changes are omitted unless they are new trees. For new trees the
    /// `removed` part is cleared (presumably because a new tree has nothing to remove
    /// from; confirm). Every entry starts with its drop/restored flag set to `false`.
    ///
    /// # Errors
    ///
    /// Propagates errors from `SledTreeOverlay::diff`.
    pub fn new(state: &SledDbOverlayState) -> Result<Self, sled::Error> {
        let mut caches = BTreeMap::new();
        let mut dropped_trees = BTreeMap::new();

        for (key, cache) in state.caches.iter() {
            let mut diff = cache.diff(&[])?;

            // Skip unchanged trees, but keep new ones so they get created.
            if diff.cache.is_empty()
                && diff.removed.is_empty()
                && !state.new_tree_names.contains(key)
            {
                continue;
            }

            if state.new_tree_names.contains(key) {
                diff.removed = BTreeMap::new();
            }

            caches.insert(key.clone(), (diff, false));
        }

        for (key, cache) in state.dropped_trees.iter() {
            dropped_trees.insert(key.clone(), (cache.clone(), false));
        }

        Ok(Self {
            initial_tree_names: state.initial_tree_names.clone(),
            caches,
            dropped_trees,
        })
    }

    /// Builds index-aligned tree and batch vectors for writing this diff.
    ///
    /// Tree handles are looked up in `state_trees`. Caches flagged for dropping and
    /// non-restored dropped trees are skipped, as are entries whose `aggregate` returns
    /// `None`.
    ///
    /// # Errors
    ///
    /// `CollectionNotFound` if a tree that has to be written is missing from `state_trees`.
    fn aggregate(
        &self,
        state_trees: &BTreeMap<IVec, sled::Tree>,
    ) -> Result<(Vec<sled::Tree>, Vec<sled::Batch>), sled::Error> {
        let mut trees = vec![];
        let mut batches = vec![];

        for (key, (cache, drop)) in self.caches.iter() {
            if *drop {
                continue;
            }

            let Some(tree) = state_trees.get(key) else {
                return Err(sled::Error::CollectionNotFound(key.into()));
            };

            if let Some(batch) = cache.aggregate() {
                trees.push(tree.clone());
                batches.push(batch);
            }
        }

        for (key, (cache, restored)) in self.dropped_trees.iter() {
            if !restored {
                continue;
            }

            let Some(tree) = state_trees.get(key) else {
                return Err(sled::Error::CollectionNotFound(key.into()));
            };

            if let Some(batch) = cache.aggregate() {
                trees.push(tree.clone());
                batches.push(batch);
            }
        }

        Ok((trees, batches))
    }

    /// Produces the diff that undoes this one.
    ///
    /// Each cached tree diff is inverted with `SledTreeOverlayStateDiff::inverse`.
    /// Dropped trees that were initial get their restored flag flipped. Dropped trees
    /// that were not initial are left out.
    pub fn inverse(&self) -> Self {
        let mut diff = Self {
            initial_tree_names: self.initial_tree_names.clone(),
            ..Default::default()
        };

        for (key, (cache, drop)) in self.caches.iter() {
            let inverse = cache.inverse();
            // Initial trees are never dropped by the inverse. For other (new) trees:
            // - an entirely empty inverse flips the original drop flag;
            // - otherwise the tree is dropped when the inverse inserts nothing.
            let drop = if inverse.cache.is_empty()
                && inverse.removed.is_empty()
                && !self.initial_tree_names.contains(key)
            {
                !drop
            } else {
                inverse.cache.is_empty() && !self.initial_tree_names.contains(key)
            };
            diff.caches.insert(key.clone(), (inverse, drop));
        }

        for (key, (cache, restored)) in self.dropped_trees.iter() {
            if !self.initial_tree_names.contains(key) {
                continue;
            }
            diff.dropped_trees
                .insert(key.clone(), (cache.clone(), !restored));
        }

        diff
    }

    /// Subtracts `other`'s changes from this diff.
    ///
    /// `SledDbOverlay::diff` uses this to strip a sequence of earlier diffs from the
    /// current state diff.
    ///
    /// # Panics
    ///
    /// - If `other` lists an initial tree this diff doesn't.
    /// - If a tree `other` dropped is still cached here.
    /// - If a tree `other` dropped is not dropped here.
    pub fn remove_diff(&mut self, other: &Self) {
        for initial_tree_name in &other.initial_tree_names {
            assert!(self.initial_tree_names.contains(initial_tree_name));
        }

        for (key, cache_pair) in other.caches.iter() {
            // Trees touched by `other` are considered pre-existing from now on.
            if !self.initial_tree_names.contains(key) {
                self.initial_tree_names.push(key.clone());
            }

            // No cache: update the dropped-tree diff if there is one; otherwise nothing
            // to do.
            let Some(tree_overlay) = self.caches.get_mut(key) else {
                let Some((tree_overlay, _)) = self.dropped_trees.get_mut(key) else {
                    continue;
                };
                tree_overlay.update_values(&cache_pair.0);
                continue;
            };

            // Identical entry (changes and drop flag): nothing is left for this tree.
            if tree_overlay == cache_pair {
                self.caches.remove(key);
                continue;
            }

            tree_overlay.0.remove_diff(&cache_pair.0);
        }

        for (key, (cache, restored)) in other.dropped_trees.iter() {
            assert!(!self.caches.contains_key(key));
            assert!(self.dropped_trees.contains_key(key));

            // A restored tree comes back as a regular (non-dropped) cache entry.
            if *restored {
                self.caches.insert(key.clone(), (cache.clone(), false));
            }

            self.initial_tree_names.retain(|x| x != key);
            self.dropped_trees.remove(key);
        }
    }
}
437
/// Overlay over a `sled::Db` that buffers tree changes in per-tree caches until they are
/// written with `apply` (or written directly from a diff with `apply_diff`).
#[derive(Clone)]
pub struct SledDbOverlay {
    /// The underlying database.
    db: sled::Db,
    /// Current overlay state.
    pub state: SledDbOverlayState,
    /// State saved by `checkpoint` and restored by `revert_to_checkpoint`.
    checkpoint: SledDbOverlayState,
}
448
449impl SledDbOverlay {
450 pub fn new(db: &sled::Db, protected_tree_names: Vec<&[u8]>) -> Self {
454 let initial_tree_names = db.tree_names();
455 let protected_tree_names: Vec<IVec> = protected_tree_names
456 .into_iter()
457 .map(|tree_name| tree_name.into())
458 .collect();
459 Self {
460 db: db.clone(),
461 state: SledDbOverlayState::new(
462 initial_tree_names.clone(),
463 protected_tree_names.clone(),
464 ),
465 checkpoint: SledDbOverlayState::new(initial_tree_names, protected_tree_names),
466 }
467 }
468
469 pub fn open_tree(&mut self, tree_name: &[u8], protected: bool) -> Result<(), sled::Error> {
476 let tree_key: IVec = tree_name.into();
477
478 if self.state.dropped_trees.contains_key(&tree_key) {
480 return Err(sled::Error::CollectionNotFound(tree_key));
481 }
482
483 if self.state.caches.contains_key(&tree_key) {
484 return Ok(());
486 }
487
488 let tree = self.db.open_tree(&tree_key)?;
491 let cache = SledTreeOverlay::new(&tree);
492
493 if !self.state.initial_tree_names.contains(&tree_key) {
494 self.state.new_tree_names.push(tree_key.clone());
495 }
496
497 self.state.caches.insert(tree_key.clone(), cache);
498
499 if protected && !self.state.protected_tree_names.contains(&tree_key) {
501 self.state.protected_tree_names.push(tree_key);
502 }
503
504 Ok(())
505 }
506
507 pub fn drop_tree(&mut self, tree_name: &[u8]) -> Result<(), sled::Error> {
509 let tree_key: IVec = tree_name.into();
510
511 if self.state.protected_tree_names.contains(&tree_key) {
513 return Err(sled::Error::Unsupported(
514 "Protected tree can't be dropped".to_string(),
515 ));
516 }
517
518 if self.state.dropped_trees.contains_key(&tree_key) {
520 return Err(sled::Error::CollectionNotFound(tree_key));
521 }
522
523 if self.state.new_tree_names.contains(&tree_key) {
525 self.state.new_tree_names.retain(|x| *x != tree_key);
526 let tree = match self.get_cache(&tree_key) {
527 Ok(cache) => &cache.tree,
528 _ => &self.db.open_tree(&tree_key)?,
529 };
530 let diff = SledTreeOverlayStateDiff::new_dropped(tree);
531 self.state.caches.remove(&tree_key);
532 self.state.dropped_trees.insert(tree_key, diff);
533
534 return Ok(());
535 }
536
537 if !self.state.initial_tree_names.contains(&tree_key) {
539 return Err(sled::Error::CollectionNotFound(tree_key));
540 }
541
542 let tree = match self.get_cache(&tree_key) {
543 Ok(cache) => &cache.tree,
544 _ => &self.db.open_tree(&tree_key)?,
545 };
546 let diff = SledTreeOverlayStateDiff::new_dropped(tree);
547 self.state.caches.remove(&tree_key);
548 self.state.dropped_trees.insert(tree_key, diff);
549
550 Ok(())
551 }
552
553 pub fn purge_new_trees(&self) -> Result<(), sled::Error> {
557 for i in &self.state.new_tree_names {
558 self.db.drop_tree(i)?;
559 }
560
561 Ok(())
562 }
563
564 fn get_cache(&self, tree_key: &IVec) -> Result<&SledTreeOverlay, sled::Error> {
566 if self.state.dropped_trees.contains_key(tree_key) {
567 return Err(sled::Error::CollectionNotFound(tree_key.into()));
568 }
569
570 if let Some(v) = self.state.caches.get(tree_key) {
571 return Ok(v);
572 }
573
574 Err(sled::Error::CollectionNotFound(tree_key.into()))
575 }
576
577 fn get_cache_mut(&mut self, tree_key: &IVec) -> Result<&mut SledTreeOverlay, sled::Error> {
579 if self.state.dropped_trees.contains_key(tree_key) {
580 return Err(sled::Error::CollectionNotFound(tree_key.into()));
581 }
582
583 if let Some(v) = self.state.caches.get_mut(tree_key) {
584 return Ok(v);
585 }
586 Err(sled::Error::CollectionNotFound(tree_key.clone()))
587 }
588
589 pub fn get_state_trees(&self) -> BTreeMap<IVec, sled::Tree> {
591 let mut state_trees = BTreeMap::new();
593 for (key, cache) in self.state.caches.iter() {
594 state_trees.insert(key.clone(), cache.tree.clone());
595 }
596
597 state_trees
598 }
599
600 pub fn contains_key(&self, tree_key: &[u8], key: &[u8]) -> Result<bool, sled::Error> {
603 let cache = self.get_cache(&tree_key.into())?;
604 cache.contains_key(key)
605 }
606
607 pub fn get(&self, tree_key: &[u8], key: &[u8]) -> Result<Option<IVec>, sled::Error> {
609 let cache = self.get_cache(&tree_key.into())?;
610 cache.get(key)
611 }
612
613 pub fn is_empty(&self, tree_key: &[u8]) -> Result<bool, sled::Error> {
615 let cache = self.get_cache(&tree_key.into())?;
616 Ok(cache.is_empty())
617 }
618
619 pub fn last(&self, tree_key: &[u8]) -> Result<Option<(IVec, IVec)>, sled::Error> {
621 let cache = self.get_cache(&tree_key.into())?;
622 cache.last()
623 }
624
625 pub fn insert(
628 &mut self,
629 tree_key: &[u8],
630 key: &[u8],
631 value: &[u8],
632 ) -> Result<Option<IVec>, sled::Error> {
633 let cache = self.get_cache_mut(&tree_key.into())?;
634 cache.insert(key, value)
635 }
636
637 pub fn remove(&mut self, tree_key: &[u8], key: &[u8]) -> Result<Option<IVec>, sled::Error> {
639 let cache = self.get_cache_mut(&tree_key.into())?;
640 cache.remove(key)
641 }
642
643 pub fn clear(&mut self, tree_key: &[u8]) -> Result<(), sled::Error> {
646 let cache = self.get_cache_mut(&tree_key.into())?;
647 cache.clear()
648 }
649
650 fn aggregate(&self) -> Result<(Vec<sled::Tree>, Vec<sled::Batch>), sled::Error> {
654 self.state.aggregate()
655 }
656
657 pub fn apply(&mut self) -> Result<(), TransactionError<sled::Error>> {
664 let new_tree_names = self.state.new_tree_names.clone();
666 for tree_key in &new_tree_names {
667 let tree = self.db.open_tree(tree_key)?;
668 let cache = self.get_cache_mut(tree_key)?;
670 cache.tree = tree;
671 }
672
673 for tree in self.state.dropped_trees.keys() {
675 self.db.drop_tree(tree)?;
676 }
677
678 let (trees, batches) = self.aggregate()?;
680 if trees.is_empty() {
681 return Ok(());
682 }
683
684 trees.transaction(|trees| {
687 for (index, tree) in trees.iter().enumerate() {
688 tree.apply_batch(&batches[index])?;
689 }
690
691 Ok::<(), ConflictableTransactionError<sled::Error>>(())
692 })?;
693
694 Ok(())
695 }
696
697 pub fn checkpoint(&mut self) {
699 self.checkpoint = self.state.clone();
700 }
701
702 pub fn revert_to_checkpoint(&mut self) -> Result<(), sled::Error> {
704 let new_trees: Vec<_> = self
706 .state
707 .new_tree_names
708 .iter()
709 .filter(|tree| !self.checkpoint.new_tree_names.contains(tree))
710 .collect();
711 for tree in &new_trees {
712 self.db.drop_tree(tree)?;
713 }
714
715 self.state = self.checkpoint.clone();
716
717 Ok(())
718 }
719
720 pub fn diff(
726 &self,
727 sequence: &[SledDbOverlayStateDiff],
728 ) -> Result<SledDbOverlayStateDiff, sled::Error> {
729 let mut current = SledDbOverlayStateDiff::new(&self.state)?;
731
732 for diff in sequence {
734 current.remove_diff(diff);
735 }
736
737 Ok(current)
738 }
739
740 pub fn add_diff(&mut self, diff: &SledDbOverlayStateDiff) -> Result<(), sled::Error> {
742 self.state.add_diff(&self.db, diff)
743 }
744
745 pub fn remove_diff(&mut self, diff: &SledDbOverlayStateDiff) {
747 self.state.remove_diff(diff)
748 }
749
750 pub fn apply_diff(
758 &mut self,
759 diff: &SledDbOverlayStateDiff,
760 ) -> Result<(), TransactionError<sled::Error>> {
761 for tree in diff.dropped_trees.keys() {
763 if self.state.protected_tree_names.contains(tree) {
764 return Err(TransactionError::Storage(sled::Error::Unsupported(
765 "Protected tree can't be dropped".to_string(),
766 )));
767 }
768 }
769 for (tree_key, (_, drop)) in diff.caches.iter() {
770 if *drop && self.state.protected_tree_names.contains(tree_key) {
771 return Err(TransactionError::Storage(sled::Error::Unsupported(
772 "Protected tree can't be dropped".to_string(),
773 )));
774 }
775 }
776
777 let mut state_trees = self.get_state_trees();
779
780 for (tree_key, (_, drop)) in diff.caches.iter() {
782 if !self.state.initial_tree_names.contains(tree_key)
784 && !self.state.new_tree_names.contains(tree_key)
785 {
786 self.state.new_tree_names.push(tree_key.clone());
787 }
788
789 if *drop {
791 self.db.drop_tree(tree_key)?;
792 continue;
793 }
794
795 if !state_trees.contains_key(tree_key) {
796 let tree = self.db.open_tree(tree_key)?;
797 state_trees.insert(tree_key.clone(), tree);
798 }
799 }
800
801 for (tree_key, (_, restored)) in diff.dropped_trees.iter() {
803 if !restored {
804 state_trees.remove(tree_key);
805 self.db.drop_tree(tree_key)?;
806 continue;
807 }
808
809 if !self.state.initial_tree_names.contains(tree_key)
811 && !self.state.new_tree_names.contains(tree_key)
812 {
813 self.state.new_tree_names.push(tree_key.clone());
814 }
815
816 if !state_trees.contains_key(tree_key) {
817 let tree = self.db.open_tree(tree_key)?;
818 state_trees.insert(tree_key.clone(), tree);
819 }
820 }
821
822 let (trees, batches) = diff.aggregate(&state_trees)?;
824 if trees.is_empty() {
825 self.remove_diff(diff);
826 return Ok(());
827 }
828
829 trees.transaction(|trees| {
832 for (index, tree) in trees.iter().enumerate() {
833 tree.apply_batch(&batches[index])?;
834 }
835
836 Ok::<(), ConflictableTransactionError<sled::Error>>(())
837 })?;
838
839 self.remove_diff(diff);
841
842 Ok(())
843 }
844
845 pub fn iter(&self, tree_key: &[u8]) -> Result<SledTreeOverlayIter<'_>, sled::Error> {
847 let cache = self.get_cache(&tree_key.into())?;
848 Ok(cache.iter())
849 }
850}