1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
/* This file is part of sled-overlay
 *
 * Copyright (C) 2023-2024 Dyne.org foundation
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as
 * published by the Free Software Foundation, either version 3 of the
 * License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with this program.  If not, see <https://www.gnu.org/licenses/>.
 */

use std::collections::{BTreeMap, BTreeSet};

use sled::transaction::{ConflictableTransactionError, TransactionError};
use sled::{IVec, Transactional};

use crate::SledTreeOverlay;

/// Struct representing [`SledDbOverlay`] cache state.
///
/// Holds the tree bookkeeping (initial, new, dropped and protected tree names)
/// together with the opened [`sled::Tree`] handles and their [`SledTreeOverlay`]
/// caches. Instances of this struct are also used as diffs between overlay
/// states, see [`SledDbOverlayState::add_diff`] and
/// [`SledDbOverlayState::remove_diff`].
#[derive(Debug, Clone)]
pub struct SledDbOverlayState {
    /// Existing trees in `db` at the time of instantiation, so we can track newly opened trees.
    /// `remove_diff` keeps this up to date: applied new trees are appended to it and
    /// applied dropped trees are removed from it.
    pub initial_tree_names: Vec<IVec>,
    /// New trees that have been opened, but didn't exist in `db` before.
    pub new_tree_names: Vec<IVec>,
    /// Pointers to sled trees that we have opened, keyed by tree name.
    pub trees: BTreeMap<IVec, sled::Tree>,
    /// Pointers to [`SledTreeOverlay`] instances that have been created, keyed by tree name.
    pub caches: BTreeMap<IVec, SledTreeOverlay>,
    /// Trees that were dropped. A dropped tree can't be reopened.
    pub dropped_tree_names: Vec<IVec>,
    /// Protected trees: they can't be dropped, and their references
    /// are not discarded when they become stale.
    pub protected_tree_names: Vec<IVec>,
}

impl SledDbOverlayState {
    /// Instantiate a new [`SledDbOverlayState`].
    pub fn new(initial_tree_names: Vec<IVec>, protected_tree_names: Vec<IVec>) -> Self {
        Self {
            initial_tree_names,
            new_tree_names: vec![],
            trees: BTreeMap::new(),
            caches: BTreeMap::new(),
            dropped_tree_names: vec![],
            protected_tree_names,
        }
    }

    /// Aggregate all the current overlay changes into [`sled::Batch`] instances and
    /// return vectors of [`sled::Tree`] and their respective [`sled::Batch`] that can
    /// be used for further operations. If there are no changes, both vectors will be empty.
    fn aggregate(&self) -> Result<(Vec<sled::Tree>, Vec<sled::Batch>), sled::Error> {
        let mut trees = vec![];
        let mut batches = vec![];

        for (key, tree) in &self.trees {
            if self.dropped_tree_names.contains(key) {
                return Err(sled::Error::CollectionNotFound(key.into()));
            }

            let Some(cache) = self.caches.get(key) else {
                return Err(sled::Error::CollectionNotFound(key.into()));
            };

            if let Some(batch) = cache.aggregate() {
                trees.push(tree.clone());
                batches.push(batch);
            }
        }

        Ok((trees, batches))
    }

    /// Add provided `db` overlay state changes from our own.
    pub fn add_diff(&mut self, other: &Self) {
        self.initial_tree_names
            .retain(|x| other.initial_tree_names.contains(x));

        for new_tree_name in &other.new_tree_names {
            if self.new_tree_names.contains(new_tree_name) {
                continue;
            }
            self.new_tree_names.push(new_tree_name.clone());
        }

        for (k, v) in other.trees.iter() {
            if self.trees.contains_key(k) {
                continue;
            };
            self.trees.insert(k.clone(), v.clone());
        }

        for (k, v) in other.caches.iter() {
            let Some(tree_overlay) = self.caches.get_mut(k) else {
                self.caches.insert(k.clone(), v.clone());
                continue;
            };

            // If the state is unchanged, we skip it
            if tree_overlay.state == v.state {
                continue;
            }

            // Add the diff from our tree overlay state
            tree_overlay.add_diff(&v.state);
        }

        for dropped_tree_name in &other.dropped_tree_names {
            if self.dropped_tree_names.contains(dropped_tree_name) {
                continue;
            }
            self.new_tree_names.retain(|x| x != dropped_tree_name);
            self.trees.remove(dropped_tree_name);
            self.caches.remove(dropped_tree_name);
            self.dropped_tree_names.push(dropped_tree_name.clone());
        }
    }

    /// Remove provided `db` overlay state changes from our own.
    pub fn remove_diff(&mut self, other: &Self) {
        // We have some assertions here to catch catastrophic
        // logic bugs here, as all our fields are depending on each
        // other when checking for differences.
        for initial_tree_name in &other.initial_tree_names {
            assert!(self.initial_tree_names.contains(initial_tree_name));
        }

        for new_tree_name in &other.new_tree_names {
            self.new_tree_names.retain(|x| x != new_tree_name);
            self.initial_tree_names.push(new_tree_name.clone());
        }

        for tree in other.trees.keys() {
            // If the tree pointer is not in our trees,
            // it must exist in our dropped tree names
            if !self.trees.contains_key(tree) {
                assert!(self.dropped_tree_names.contains(tree));
            };
        }

        for (k, v) in other.caches.iter() {
            // If the key is not in the cache, it must
            // be in the dropped tree names
            let Some(tree_overlay) = self.caches.get_mut(k) else {
                assert!(self.dropped_tree_names.contains(k));
                continue;
            };

            // If the state is unchanged, handle the stale tree
            if tree_overlay.state == v.state {
                // If tree is protected, we simply reset its cache
                if self.protected_tree_names.contains(k) {
                    tree_overlay.state.cache = BTreeMap::new();
                    tree_overlay.state.removed = BTreeSet::new();
                    tree_overlay.checkpoint();
                    continue;
                }

                // Drop the stale reference
                self.trees.remove(k);
                self.caches.remove(k);
                continue;
            }

            // Remove the diff from our tree overlay state
            tree_overlay.remove_diff(&v.state);
        }

        // Since we don't allow reopenning dropped trees, we must
        // have all the dropped tree names.
        for dropped_tree_name in &other.dropped_tree_names {
            assert!(!self.trees.contains_key(dropped_tree_name));
            assert!(!self.caches.contains_key(dropped_tree_name));
            assert!(self.dropped_tree_names.contains(dropped_tree_name));
            self.dropped_tree_names.retain(|x| x != dropped_tree_name);
            self.initial_tree_names.retain(|x| x != dropped_tree_name);
        }

        assert_eq!(
            self.initial_tree_names.len(),
            other.initial_tree_names.len() + other.new_tree_names.len()
                - other.dropped_tree_names.len()
        );
    }
}

impl Default for SledDbOverlayState {
    fn default() -> Self {
        Self::new(vec![], vec![])
    }
}

/// An overlay on top of an entire [`sled::Db`] which can span multiple trees.
///
/// Reads and writes go through per-tree [`SledTreeOverlay`] caches held in
/// `state`; changes reach the database only through [`SledDbOverlay::apply`]
/// or [`SledDbOverlay::apply_diff`].
#[derive(Clone)]
pub struct SledDbOverlay {
    /// The [`sled::Db`] that is being overlayed.
    db: sled::Db,
    /// Current overlay cache state
    pub state: SledDbOverlayState,
    /// Checkpointed cache state to revert to, set by `checkpoint`
    /// and restored by `revert_to_checkpoint`.
    checkpoint: SledDbOverlayState,
}

impl SledDbOverlay {
    /// Instantiate a new [`SledDbOverlay`] on top of a given [`sled::Db`].
    /// Note: Provided protected trees don't have to be opened as protected,
    /// as they are setup as protected here.
    pub fn new(db: &sled::Db, protected_tree_names: Vec<&[u8]>) -> Self {
        let initial_tree_names = db.tree_names();
        let protected_tree_names: Vec<IVec> = protected_tree_names
            .into_iter()
            .map(|tree_name| tree_name.into())
            .collect();
        Self {
            db: db.clone(),
            state: SledDbOverlayState::new(
                initial_tree_names.clone(),
                protected_tree_names.clone(),
            ),
            checkpoint: SledDbOverlayState::new(initial_tree_names, protected_tree_names),
        }
    }

    /// Create a new [`SledTreeOverlay`] on top of a given `tree_name`.
    /// This function will also open a new tree inside `db` regardless of if it has
    /// existed before, so for convenience, we also provide [`SledDbOverlay::purge_new_trees`]
    /// in case we decide we don't want to write the batches, and drop the new trees.
    /// Additionally, a boolean flag is passed to mark the oppened tree as protected,
    /// meanning that it can't be removed and its references will never be dropped.
    pub fn open_tree(&mut self, tree_name: &[u8], protected: bool) -> Result<(), sled::Error> {
        let tree_key: IVec = tree_name.into();

        // We don't allow reopening a dropped tree.
        if self.state.dropped_tree_names.contains(&tree_key) {
            return Err(sled::Error::CollectionNotFound(tree_key));
        }

        if self.state.trees.contains_key(&tree_key) {
            // We have already opened this tree.
            return Ok(());
        }

        // Open this tree in sled. In case it hasn't existed before, we also need
        // to track it in `self.new_tree_names`.
        let tree = self.db.open_tree(&tree_key)?;
        let cache = SledTreeOverlay::new(&tree);

        if !self.state.initial_tree_names.contains(&tree_key) {
            self.state.new_tree_names.push(tree_key.clone());
        }

        self.state.trees.insert(tree_key.clone(), tree);
        self.state.caches.insert(tree_key.clone(), cache);

        // Mark tree as protected if requested
        if protected && !self.state.protected_tree_names.contains(&tree_key) {
            self.state.protected_tree_names.push(tree_key);
        }

        Ok(())
    }

    /// Drop a sled tree from the overlay.
    pub fn drop_tree(&mut self, tree_name: &[u8]) -> Result<(), sled::Error> {
        let tree_key: IVec = tree_name.into();

        // Check if tree is protected
        if self.state.protected_tree_names.contains(&tree_key) {
            return Err(sled::Error::Unsupported(
                "Protected tree can't be dropped".to_string(),
            ));
        }

        // Check if already removed
        if self.state.dropped_tree_names.contains(&tree_key) {
            return Err(sled::Error::CollectionNotFound(tree_key));
        }

        // Check if its a new tree we created
        if self.state.new_tree_names.contains(&tree_key) {
            self.state.new_tree_names.retain(|x| *x != tree_key);
            self.state.trees.remove(&tree_key);
            self.state.caches.remove(&tree_key);
            self.state.dropped_tree_names.push(tree_key);

            return Ok(());
        }

        // Check if tree existed in the database
        if !self.state.initial_tree_names.contains(&tree_key) {
            return Err(sled::Error::CollectionNotFound(tree_key));
        }

        self.state.trees.remove(&tree_key);
        self.state.caches.remove(&tree_key);
        self.state.dropped_tree_names.push(tree_key);

        Ok(())
    }

    /// Drop newly created trees from the sled database. This is a convenience
    /// function that should be used when we decide that we don't want to apply
    /// any cache changes, and we want to revert back to the initial state.
    pub fn purge_new_trees(&self) -> Result<(), sled::Error> {
        for i in &self.state.new_tree_names {
            self.db.drop_tree(i)?;
        }

        Ok(())
    }

    /// Fetch the cache for a given tree.
    fn get_cache(&self, tree_key: &IVec) -> Result<&SledTreeOverlay, sled::Error> {
        if self.state.dropped_tree_names.contains(tree_key) {
            return Err(sled::Error::CollectionNotFound(tree_key.into()));
        }

        if let Some(v) = self.state.caches.get(tree_key) {
            return Ok(v);
        }

        Err(sled::Error::CollectionNotFound(tree_key.into()))
    }

    /// Fetch a mutable reference to the cache for a given tree.
    fn get_cache_mut(&mut self, tree_key: &IVec) -> Result<&mut SledTreeOverlay, sled::Error> {
        if self.state.dropped_tree_names.contains(tree_key) {
            return Err(sled::Error::CollectionNotFound(tree_key.into()));
        }

        if let Some(v) = self.state.caches.get_mut(tree_key) {
            return Ok(v);
        }
        Err(sled::Error::CollectionNotFound(tree_key.clone()))
    }

    /// Returns `true` if the overlay contains a value for a specified key in the specified
    /// tree cache.
    pub fn contains_key(&self, tree_key: &[u8], key: &[u8]) -> Result<bool, sled::Error> {
        let cache = self.get_cache(&tree_key.into())?;
        cache.contains_key(key)
    }

    /// Retrieve a value from the overlay if it exists in the specified tree cache.
    pub fn get(&self, tree_key: &[u8], key: &[u8]) -> Result<Option<IVec>, sled::Error> {
        let cache = self.get_cache(&tree_key.into())?;
        cache.get(key)
    }

    /// Returns `true` if specified tree cache is empty.
    pub fn is_empty(&self, tree_key: &[u8]) -> Result<bool, sled::Error> {
        let cache = self.get_cache(&tree_key.into())?;
        Ok(cache.is_empty())
    }

    /// Returns last value from the overlay if the specified tree cache is not empty.
    pub fn last(&self, tree_key: &[u8]) -> Result<Option<(IVec, IVec)>, sled::Error> {
        let cache = self.get_cache(&tree_key.into())?;
        cache.last()
    }

    /// Insert a key to a new value in the specified tree cache, returning the last value
    /// if it was set.
    pub fn insert(
        &mut self,
        tree_key: &[u8],
        key: &[u8],
        value: &[u8],
    ) -> Result<Option<IVec>, sled::Error> {
        let cache = self.get_cache_mut(&tree_key.into())?;
        cache.insert(key, value)
    }

    /// Delete a value in the specified tree cache, returning the old value if it existed.
    pub fn remove(&mut self, tree_key: &[u8], key: &[u8]) -> Result<Option<IVec>, sled::Error> {
        let cache = self.get_cache_mut(&tree_key.into())?;
        cache.remove(key)
    }

    /// Aggregate all the current overlay changes into [`sled::Batch`] instances and
    /// return vectors of [`sled::Tree`] and their respective [`sled::Batch`] that can
    /// be used for further operations. If there are no changes, both vectors will be empty.
    fn aggregate(&self) -> Result<(Vec<sled::Tree>, Vec<sled::Batch>), sled::Error> {
        self.state.aggregate()
    }

    /// Ensure all new trees that have been opened exist in sled by reopening them,
    /// atomically apply all batches on all trees as a transaction, and drop dropped
    /// trees from sled.
    /// This function **does not** perform a db flush. This should be done externally,
    /// since then there is a choice to perform either blocking or async IO.
    /// After execution is successful, caller should *NOT* use the overlay again.
    pub fn apply(&mut self) -> Result<(), TransactionError<sled::Error>> {
        // Ensure new trees exist
        for tree_key in &self.state.new_tree_names {
            let tree = self.db.open_tree(tree_key)?;
            self.state.trees.insert(tree_key.clone(), tree);
        }

        // Drop removed trees
        for tree in &self.state.dropped_tree_names {
            self.db.drop_tree(tree)?;
        }

        // Aggregate batches
        let (trees, batches) = self.aggregate()?;
        if trees.is_empty() {
            return Ok(());
        }

        // Perform an atomic transaction over all the collected trees and
        // apply the batches.
        trees.transaction(|trees| {
            for (index, tree) in trees.iter().enumerate() {
                tree.apply_batch(&batches[index])?;
            }

            Ok::<(), ConflictableTransactionError<sled::Error>>(())
        })?;

        Ok(())
    }

    /// Checkpoint current cache state so we can revert to it, if needed.
    pub fn checkpoint(&mut self) {
        self.checkpoint = self.state.clone();
    }

    /// Revert to current cache state checkpoint.
    pub fn revert_to_checkpoint(&mut self) -> Result<(), sled::Error> {
        // We first check if any new trees were opened, so we can remove them.
        let new_trees: Vec<_> = self
            .state
            .new_tree_names
            .iter()
            .filter(|tree| !self.checkpoint.new_tree_names.contains(tree))
            .collect();
        for tree in &new_trees {
            self.db.drop_tree(tree)?;
        }

        self.state = self.checkpoint.clone();

        Ok(())
    }

    /// Calculate differences from provided overlay state changes
    /// sequence. This can be used when we want to keep track of
    /// consecutive individual changes performed over the current
    /// overlay state. If the sequence is empty, current state
    /// is returned as the diff.
    pub fn diff(&self, sequence: &[SledDbOverlayState]) -> SledDbOverlayState {
        // Grab current state
        let mut current = self.state.clone();

        // Remove provided diffs sequence
        for diff in sequence {
            current.remove_diff(diff);
        }

        current
    }

    /// Add provided `db` overlay state changes from our own.
    pub fn add_diff(&mut self, other: &SledDbOverlayState) {
        self.state.add_diff(other)
    }

    /// Remove provided `db` overlay state changes from our own.
    pub fn remove_diff(&mut self, other: &SledDbOverlayState) {
        self.state.remove_diff(other)
    }

    /// For a provided `SledDbOverlayState`, ensure all new trees that have been
    /// opened exist in sled by reopening them, atomically apply all batches on
    /// all trees as a transaction, and drop dropped trees from sled.
    /// After that, remove the state changes from our own. This is will also mutate
    /// the initial trees, based on what was oppened and/or dropped.
    /// This function **does not** perform a db flush. This should be done externally,
    /// since then there is a choice to perform either blocking or async IO.
    pub fn apply_diff(
        &mut self,
        other: &mut SledDbOverlayState,
    ) -> Result<(), TransactionError<sled::Error>> {
        // Ensure new trees exist
        for tree_key in &other.new_tree_names {
            let tree = self.db.open_tree(tree_key)?;
            other.trees.insert(tree_key.clone(), tree);
        }

        // Drop removed trees
        for tree in &other.dropped_tree_names {
            self.db.drop_tree(tree)?;
        }

        // Aggregate batches
        let (trees, batches) = other.aggregate()?;
        if trees.is_empty() {
            return Ok(());
        }

        // Perform an atomic transaction over all the collected trees and
        // apply the batches.
        trees.transaction(|trees| {
            for (index, tree) in trees.iter().enumerate() {
                tree.apply_batch(&batches[index])?;
            }

            Ok::<(), ConflictableTransactionError<sled::Error>>(())
        })?;

        // Remove changes from our current state
        self.remove_diff(other);

        Ok(())
    }
}