snapper_box/file.rs

//! Middle-level abstraction logic.
//!
//! This module models a file as an LSM, with encrypted updates.

use std::{
    collections::HashMap,
    fs::{File, OpenOptions},
    hash::Hash,
    io::{Read, Seek, SeekFrom, Write},
    path::Path,
    sync::Arc,
};

use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_bytes::ByteBuf;
use snafu::ResultExt;
use tracing::{debug, instrument, warn};

use crate::{
    crypto::{CipherText, ClearText, Key},
    error::{BackendError, EntryIO},
    file::segment::Segment,
};
mod segment;

/// Internal, data-carrying struct used for serializing log entries
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
struct Entry {
    /// Items being stored
    items: HashMap<ByteBuf, ByteBuf>,
}

/// LSM abstraction over something [`File`]-like (i.e. implementing [`Read`], [`Write`], and [`Seek`]).
///
/// This type maintains a write cache in memory, and you _must_ call the [`flush`](Self::flush) method for
/// the changes to be persisted. This type _will not_ flush automatically on drop, though
/// [`insert`](Self::insert) triggers a flush once the cache reaches `max_cache_entries`.
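///
/// A minimal usage sketch (marked `ignore` so it is not compiled as a doctest; it assumes
/// `LsmFile` and the `RootKey` key type from this crate's `crypto` module are in scope, as in
/// the unit tests at the bottom of this file):
///
/// ```ignore
/// use std::{io::Cursor, sync::Arc};
///
/// // Any `Read + Write + Seek` value works as backing storage; a `Cursor` keeps it in memory.
/// let backing = Cursor::new(Vec::<u8>::new());
/// let key = Arc::new(RootKey::random());
/// let mut store: LsmFile<_, RootKey> = LsmFile::new(backing, None, key, Some(10))?;
///
/// store.insert(&1_u64, &2_u64)?;
/// assert_eq!(store.get::<u64, u64>(&1_u64)?, Some(2));
///
/// // Nothing reaches the backing store until `flush` is called (or the cache fills up).
/// store.flush()?;
/// ```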
pub struct LsmFile<T, K>
where
    T: Read + Write + Seek,
    K: Key,
{
    /// Underlying [`File`](std::fs::File) like object
    file: T,
    /// Segment offset table
    offsets: HashMap<Vec<u8>, u64>,
    /// Cache containing yet-to-be-written items
    cache: Option<Entry>,
    /// Compression level for the file
    compression: Option<i32>,
    /// Key for the file
    key: Arc<K>,
    /// Maximum number of cache entries
    max_cache_entries: usize,
}

impl<T, K> LsmFile<T, K>
where
    T: Read + Write + Seek,
    K: Key,
{
    /// Creates a new `LsmFile` from the provided, already existing, [`File`]-like object.
    ///
    /// Accepts the compression level for the file, the key to be used for encryption/decryption,
    /// and the maximum number of cache entries before a mandatory flush occurs (defaults to 100).
    ///
    /// # Errors
    ///
    ///   * `BackendError::EntryIO` if an error occurs while reading the entries
    ///   * Will bubble up any segment errors
    #[instrument(skip(file, key))]
    pub fn new(
        mut file: T,
        compression: Option<i32>,
        key: Arc<K>,
        max_cache_entries: Option<usize>,
    ) -> Result<Self, BackendError> {
        // Seek to start of file
        file.seek(SeekFrom::Start(0)).context(EntryIO)?;
        // Start trying to read segments
        let mut offset: u64 = 0;
        let mut offsets: HashMap<Vec<u8>, u64> = HashMap::new();
        while let Ok(segment) = Segment::read_owned(&mut file) {
            match CipherText::try_from(segment) {
                Ok(x) => {
                    let cleartext = x.decrypt(&*key)?;
                    let entry: Entry = cleartext.deserialize()?;
                    for (k, _) in entry.items {
                        offsets.insert(k.to_vec(), offset);
                    }
                }
                Err(e) => {
                    warn!(?e, "Failed to decode a segment");
                }
            }
            offset = file.stream_position().context(EntryIO)?;
        }
        Ok(Self {
            file,
            offsets,
            cache: None,
            compression,
            key,
            max_cache_entries: max_cache_entries.unwrap_or(100),
        })
    }

    /// Retrieves a value from the store given a key, if the value exists in the store.
    ///
    /// The most recently inserted version of the value will be returned.
    ///
    /// # Errors
    ///
    /// Will return an error if any IO error occurs, or if the value fails to deserialize.
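    ///
    /// A sketch of the "most recent wins" behavior (marked `ignore`; `store` is assumed to be an
    /// `LsmFile` set up as in the type-level example):
    ///
    /// ```ignore
    /// store.insert(&1_u64, &10_u64)?;
    /// store.insert(&1_u64, &20_u64)?;
    /// // The later insertion shadows the earlier one for the same key.
    /// assert_eq!(store.get::<u64, u64>(&1_u64)?, Some(20));
    /// ```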
    pub fn get<C, V>(&mut self, key: &C) -> Result<Option<V>, BackendError>
    where
        C: Serialize,
        V: DeserializeOwned,
    {
        // Attempt to serialize the key
        let key = match serde_cbor::to_vec(key) {
            Ok(k) => ByteBuf::from(k),
            // Intentionally hide the underlying serde error, to avoid leaking sensitive data into
            // the logs
            Err(_) => return Err(BackendError::ItemSerialization),
        };
        // Look up the key in the cache, return from it if possible
        if let Some(cache) = self.cache.as_ref() {
            if let Some(bytes) = cache.items.get(&key) {
                match serde_cbor::from_slice(bytes) {
                    Ok(v) => return Ok(Some(v)),
                    // Intentionally hide the underlying serde error, to avoid leaking sensitive
                    // data into the logs
                    Err(_) => return Err(BackendError::ItemDeserialization),
                }
            }
        }
        // Look up the key in the offsets table
        if let Some(offset) = self.offsets.get(&*key) {
            // Find the segment
            self.file.seek(SeekFrom::Start(*offset)).context(EntryIO)?;
            let segment = Segment::read_owned(&mut self.file)?;
            // Decrypt it
            let ciphertext = CipherText::try_from(segment)?;
            let cleartext = ciphertext.decrypt(&*self.key)?;
            let entry: Entry = cleartext.deserialize()?;
            if let Some(bytes) = entry.items.get(&key) {
                match serde_cbor::from_slice(bytes) {
                    Ok(v) => Ok(Some(v)),
                    // Intentionally hide the underlying serde error, to avoid leaking sensitive
                    // data into the logs
                    Err(_) => Err(BackendError::ItemDeserialization),
                }
            } else {
                warn!("Offsets table contained an offset for a nonexistent value");
                Ok(None)
            }
        } else {
            Ok(None)
        }
    }

    /// Inserts an object into the store
    ///
    /// This method writes to the cache; however, if this insertion makes the cache size greater than or
    /// equal to the maximum number of entries, the current cache will be flushed.
    ///
    /// # Errors
    ///   * Will return an error if either the key or value fails to serialize
    ///   * Will bubble up any IO errors that occur, if a flush occurs
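    ///
    /// A sketch of the cache-size behavior (marked `ignore`; `store` is assumed to have been
    /// created with `max_cache_entries = Some(2)`):
    ///
    /// ```ignore
    /// store.insert(&1_u64, &1_u64)?; // cached in memory only
    /// store.insert(&2_u64, &2_u64)?; // cache reaches 2 entries, so it is flushed to the file
    /// ```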
    #[instrument(skip(self, key, value))]
    pub fn insert<C, V>(&mut self, key: &C, value: &V) -> Result<(), BackendError>
    where
        C: Serialize,
        V: Serialize,
    {
        // Get the cache
        let mut cache = self.cache.take().unwrap_or_else(|| Entry {
            items: HashMap::new(),
        });
        // Serialize the values
        // Intentionally hide the serde errors, to avoid leaking sensitive information
        let key = serde_cbor::to_vec(key).map_err(|_| BackendError::ItemSerialization)?;
        let value = serde_cbor::to_vec(value).map_err(|_| BackendError::ItemSerialization)?;
        // Insert it into the cache
        cache.items.insert(ByteBuf::from(key), ByteBuf::from(value));
        let length = cache.items.len();
        self.cache = Some(cache);
        // Flush if needed
        if length >= self.max_cache_entries {
            debug!("Flushing cache");
            self.flush()?;
        }
        Ok(())
    }

    /// Flushes the cache to disk
    ///
    /// # Errors
    ///
    /// Will bubble up any IO errors
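    ///
    /// A sketch of flushing to an in-memory backing store (marked `ignore`; `store` is assumed to
    /// be an `LsmFile` over a `Cursor<Vec<u8>>` as in the type-level example):
    ///
    /// ```ignore
    /// store.insert(&1_u64, &2_u64)?;
    /// store.flush()?;
    /// // After the flush, the backing buffer holds at least one encrypted segment.
    /// assert!(!store.into_inner().into_inner().is_empty());
    /// ```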
    pub fn flush(&mut self) -> Result<(), BackendError> {
        // Get the end of the file
        let offset = self.file.seek(SeekFrom::End(0)).context(EntryIO)?;
        // Take the cache
        if let Some(cache) = self.cache.take() {
            // Update the offset table
            for k in cache.items.keys() {
                self.offsets.insert(k.to_vec(), offset);
            }
            // Serialize it and encrypt it
            let plaintext = ClearText::new(&cache)?;
            let ciphertext = plaintext.encrypt(&*self.key, self.compression)?;
            let segment: Segment<'_> = ciphertext.into();
            segment.write(&mut self.file)?;
            self.file.flush().context(EntryIO)?;
        }
        Ok(())
    }

    /// Converts the current contents of the store to a [`HashMap`].
    ///
    /// Iterates through this `LsmFile`'s keys and inserts them all into the returned [`HashMap`]. The most
    /// up-to-date value will be used for each key.
    ///
    /// # Errors
    ///
    /// Will bubble up any error that occurs during a `get` operation
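    ///
    /// A sketch (marked `ignore`; `store` is assumed to hold `u64 -> u64` pairs as in the
    /// examples above):
    ///
    /// ```ignore
    /// store.insert(&1_u64, &2_u64)?;
    /// store.insert(&3_u64, &4_u64)?;
    /// let map: HashMap<u64, u64> = store.to_hashmap()?;
    /// assert_eq!(map.get(&1), Some(&2));
    /// assert_eq!(map.get(&3), Some(&4));
    /// ```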
    pub fn to_hashmap<C, V>(&mut self) -> Result<HashMap<C, V>, BackendError>
    where
        C: DeserializeOwned + Serialize + Hash + Eq,
        V: DeserializeOwned,
    {
        let mut ret = HashMap::new();
        let keys: Vec<Vec<u8>> = self.offsets.keys().cloned().collect();
        // Insert from storage
        for key in keys {
            // Deserialize the key
            let key: C =
                serde_cbor::from_slice(&key).map_err(|_| BackendError::ItemDeserialization)?;
            // Get the value
            let value = self.get(&key)?.ok_or(BackendError::InvalidLsmState)?;
            ret.insert(key, value);
        }
        // Insert from cache
        if let Some(cache) = self.cache.as_ref() {
            let keys: Vec<ByteBuf> = cache.items.keys().cloned().collect();
            for key in keys {
                // Deserialize the key
                let key_deser: C =
                    serde_cbor::from_slice(&key).map_err(|_| BackendError::ItemDeserialization)?;
                // Get the value
                let value = cache.items.get(&key).ok_or(BackendError::InvalidLsmState)?;
                let value: V =
                    serde_cbor::from_slice(value).map_err(|_| BackendError::ItemDeserialization)?;
                ret.insert(key_deser, value);
            }
        }
        Ok(ret)
    }
    /// Converts the current contents of the store to a [`Vec`] of pairs.
    ///
    /// Iterates through this `LsmFile`'s keys and inserts them all into the returned [`Vec`]. The most
    /// up-to-date value will be used for each key.
    ///
    /// # Errors
    ///
    /// Will bubble up any error that occurs during a `get` operation
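    ///
    /// A sketch (marked `ignore`; same assumptions as the [`to_hashmap`](Self::to_hashmap)
    /// example):
    ///
    /// ```ignore
    /// store.insert(&1_u64, &2_u64)?;
    /// let pairs: Vec<(u64, u64)> = store.to_pairs()?;
    /// assert_eq!(pairs, vec![(1, 2)]);
    /// ```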
    pub fn to_pairs<C, V>(&mut self) -> Result<Vec<(C, V)>, BackendError>
    where
        C: DeserializeOwned + Serialize + Eq + Clone,
        V: DeserializeOwned + Clone,
    {
        let mut ret: Vec<(C, V)> = Vec::new();
        let keys: Vec<Vec<u8>> = self.offsets.keys().cloned().collect();
        // Insert from storage
        for key in keys {
            // Deserialize the key
            let key: C =
                serde_cbor::from_slice(&key).map_err(|_| BackendError::ItemDeserialization)?;
            // Get the value
            let value = self.get(&key)?.ok_or(BackendError::InvalidLsmState)?;
            // Overwrite an existing pair for this key, otherwise append a new one
            if let Some(index) = ret.iter().position(|(k, _)| *k == key) {
                ret[index] = (key, value);
            } else {
                ret.push((key, value));
            }
        }
        // Insert from cache
        if let Some(cache) = self.cache.as_ref() {
            let keys: Vec<ByteBuf> = cache.items.keys().cloned().collect();
            for key in keys {
                // Deserialize the key
                let key_deser: C =
                    serde_cbor::from_slice(&key).map_err(|_| BackendError::ItemDeserialization)?;
                // Get the value
                let value = cache.items.get(&key).ok_or(BackendError::InvalidLsmState)?;
                let value: V =
                    serde_cbor::from_slice(value).map_err(|_| BackendError::ItemDeserialization)?;
                // Overwrite an existing pair for this key, otherwise append a new one
                if let Some(index) = ret.iter().position(|(k, _)| *k == key_deser) {
                    ret[index] = (key_deser, value);
                } else {
                    ret.push((key_deser, value));
                }
            }
        }
        Ok(ret)
    }
    /// Consumes self and returns the inner [`File`] like object
    pub fn into_inner(self) -> T {
        self.file
    }
}

impl<K> LsmFile<File, K>
where
    K: Key,
{
    /// Opens an existing `LsmFile` from the provided path.
    ///
    /// Will open the file in read/write mode. Will fail if the file does not exist.
    ///
    /// # Errors
    ///   * If the file does not exist
    ///   * If any other IO error occurs
    ///   * If any of the decryption or deserialization of control structures fails
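    ///
    /// A sketch of the create-then-reopen round trip (marked `ignore`; it assumes a writable
    /// `path` and the `RootKey` key type from this crate's `crypto` module):
    ///
    /// ```ignore
    /// let key = Arc::new(RootKey::random());
    /// {
    ///     // `create` refuses to overwrite an existing file.
    ///     let mut store = LsmFile::create(&path, None, key.clone(), None)?;
    ///     store.insert(&1_u64, &2_u64)?;
    ///     store.flush()?;
    /// }
    /// // `open` rebuilds the offset table from the segments already written to disk.
    /// let mut store = LsmFile::open(&path, None, key, None)?;
    /// assert_eq!(store.get::<u64, u64>(&1_u64)?, Some(2));
    /// ```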
    #[instrument(skip(key))]
    pub fn open(
        path: impl AsRef<Path> + std::fmt::Debug,
        compression: Option<i32>,
        key: Arc<K>,
        max_cache_entries: Option<usize>,
    ) -> Result<Self, BackendError> {
        // Open the file read/write
        let file = OpenOptions::new()
            .read(true)
            .write(true)
            .create(false)
            .open(path.as_ref())
            .context(EntryIO)?;
        Self::new(file, compression, key, max_cache_entries)
    }

    /// Creates a new `LsmFile` at the provided path.
    ///
    /// Will create the file in read/write mode, failing if the file already exists.
    ///
    /// # Errors
    ///   * If the file already exists
    ///   * If any other IO error occurs
    #[instrument(skip(key))]
    pub fn create(
        path: impl AsRef<Path> + std::fmt::Debug,
        compression: Option<i32>,
        key: Arc<K>,
        max_cache_entries: Option<usize>,
    ) -> Result<Self, BackendError> {
        // Create the file read/write, failing if it already exists
        let file = OpenOptions::new()
            .read(true)
            .write(true)
            .create_new(true)
            .open(path.as_ref())
            .context(EntryIO)?;
        Self::new(file, compression, key, max_cache_entries)
    }
}

/// Unit tests for the module
#[cfg(test)]
mod tests {
    use super::*;
    use crate::crypto::RootKey;
    use proptest::prelude::*;
    use std::io::Cursor;
    use tempfile::{tempdir, tempfile};
    /// Make sure that `LsmFile` can create and open files
    #[test]
    fn file_create_open() -> Result<(), BackendError> {
        let dir = tempdir().expect("Unable to make tempdir");
        let file_path = dir.path().join("test_lsm");
        println!("{:?}", file_path);
        // Create the comparison hashmap
        let hashmap: HashMap<u64, u64> = [(1, 2), (3, 4), (5, 6), (7, 8)].into_iter().collect();
        // Get a key
        let root_key = Arc::new(RootKey::random());
        // Create the LsmFile
        let mut lsm_file = LsmFile::create(&file_path, None, root_key.clone(), Some(3))?;
        // Insert all the k/v pairs into the LsmFile
        for (k, v) in &hashmap {
            println!("k: {} v: {}", k, v);
            lsm_file.insert(&k, &v).expect("Unable to insert k/v pair");
        }
        // Flush the file and drop it
        lsm_file.flush()?;
        std::mem::drop(lsm_file);
        // Reopen it
        let mut lsm_file = LsmFile::open(&file_path, None, root_key, Some(3))?;
        // Get the output hashmap
        let output: HashMap<u64, u64> = lsm_file.to_hashmap()?;
        assert_eq!(output, hashmap);
        Ok(())
    }

    proptest! {
        /// Use a random list of pairs to insert, and test the behavior compared to a hash map
        #[test]
        fn round_trip_hashmap_smoke(pairs: Vec<(u64,u64)>) {
            // Get the comparison HashMap
            let hashmap = pairs.iter().copied().collect::<HashMap<u64,u64>>();
            // Open up the `LsmFile`
            let backing = Cursor::new(Vec::<u8>::new());
            let root_key = Arc::new(RootKey::random());
            let mut lsm_file: LsmFile<_, RootKey> =
                LsmFile::new(backing, None, root_key, Some(10))
                .expect("Failed to open file");
            // Insert all the k/v pairs into the LsmFile
            for (k,v) in pairs {
                println!("k: {} v: {}", k, v);
                lsm_file.insert(&k, &v).expect("Unable to insert k/v pair");
            }
            // Get the output hashmap
            let output: HashMap<u64,u64> = lsm_file.to_hashmap().expect("Failed to convert to hashmap");
            for (k,v) in &output {
                if Some(v) != hashmap.get(k) {
                    panic!("Output hashmap contains k/v pair not in input: k: {} v: {}", k, v);
                }
            }
            for (k,v) in &hashmap {
                if Some(v) != output.get(k) {
                    panic!("Input hashmap contains k/v pair not in output: k: {} v: {}", k, v);
                }
            }
        }
        /// Use a random list of pairs to insert, and test the behavior compared to a hash map, deconstructing
        /// and reconstructing the `LsmFile` before reading the output hashmap.
        #[test]
        fn round_trip_hashmap_flush(pairs: Vec<(u64,u64)>) {
            // Get the comparison HashMap
            let hashmap = pairs.iter().copied().collect::<HashMap<u64,u64>>();
            // Open up the `LsmFile`
            let backing = Cursor::new(Vec::<u8>::new());
            let root_key = Arc::new(RootKey::random());
            let mut lsm_file: LsmFile<_, RootKey> =
                LsmFile::new(backing, None, root_key.clone(), Some(10))
                .expect("Failed to open file");
            // Insert all the k/v pairs into the LsmFile
            for (k,v) in pairs {
                println!("k: {} v: {}", k, v);
                lsm_file.insert(&k, &v).expect("Unable to insert k/v pair");
            }
            // Flush and reconstruct the lsm_file
            lsm_file.flush().expect("Failed to flush lsm file");
            let backing = lsm_file.into_inner();
            let mut lsm_file: LsmFile<_, RootKey> =
                LsmFile::new(backing, None, root_key, Some(10))
                .expect("Failed to open file");
            // Get the output hashmap
            let output: HashMap<u64,u64> = lsm_file.to_hashmap().expect("Failed to convert to hashmap");
            for (k,v) in &output {
                if Some(v) != hashmap.get(k) {
                    panic!("Output hashmap contains k/v pair not in input: k: {} v: {}", k, v);
                }
            }
            for (k,v) in &hashmap {
                if Some(v) != output.get(k) {
                    panic!("Input hashmap contains k/v pair not in output: k: {} v: {}", k, v);
                }
            }
        }
        /// Use a random list of pairs to insert, and test the behavior compared to a hash map, deconstructing
        /// and reconstructing the `LsmFile` before reading the output hashmap.
        ///
        /// This variant of the test uses an actual temporary file
        #[test]
        fn round_trip_hashmap_file(pairs: Vec<(u64,u64)>) {
            // Get the comparison HashMap
            let hashmap = pairs.iter().copied().collect::<HashMap<u64,u64>>();
            // Open up the `LsmFile`
            let backing = tempfile().expect("Failed to open tempfile");
            let root_key = Arc::new(RootKey::random());
            let mut lsm_file: LsmFile<_, RootKey> =
                LsmFile::new(backing, None, root_key.clone(), Some(10))
                .expect("Failed to open file");
            // Insert all the k/v pairs into the LsmFile
            for (k,v) in pairs {
                println!("k: {} v: {}", k, v);
                lsm_file.insert(&k, &v).expect("Unable to insert k/v pair");
            }
            // Flush and reconstruct the lsm_file
            lsm_file.flush().expect("Failed to flush lsm file");
            let backing = lsm_file.into_inner();
            let mut lsm_file: LsmFile<_, RootKey> =
                LsmFile::new(backing, None, root_key, Some(10))
                .expect("Failed to open file");
            // Get the output hashmap
            let output: HashMap<u64,u64> = lsm_file.to_hashmap().expect("Failed to convert to hashmap");
            for (k,v) in &output {
                if Some(v) != hashmap.get(k) {
                    panic!("Output hashmap contains k/v pair not in input: k: {} v: {}", k, v);
                }
            }
            for (k,v) in &hashmap {
                if Some(v) != output.get(k) {
                    panic!("Input hashmap contains k/v pair not in output: k: {} v: {}", k, v);
                }
            }
        }
    }
}