1use std::{
6 collections::HashMap,
7 fs::{File, OpenOptions},
8 hash::Hash,
9 io::{Read, Seek, SeekFrom, Write},
10 path::Path,
11 sync::Arc,
12};
13
14use serde::{de::DeserializeOwned, Deserialize, Serialize};
15use serde_bytes::ByteBuf;
16use snafu::ResultExt;
17use tracing::{debug, instrument, warn};
18
19use crate::{
20 crypto::{CipherText, ClearText, Key},
21 error::{BackendError, EntryIO},
22 file::segment::Segment,
23};
24mod segment;
25
26#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
28struct Entry {
29 items: HashMap<ByteBuf, ByteBuf>,
31}
32
33pub struct LsmFile<T, K>
38where
39 T: Read + Write + Seek,
40 K: Key,
41{
42 file: T,
44 offsets: HashMap<Vec<u8>, u64>,
46 cache: Option<Entry>,
48 compression: Option<i32>,
50 key: Arc<K>,
52 max_cache_entries: usize,
54}
55
56impl<T, K> LsmFile<T, K>
57where
58 T: Read + Write + Seek,
59 K: Key,
60{
61 #[instrument(skip(file, key))]
71 pub fn new(
72 mut file: T,
73 compression: Option<i32>,
74 key: Arc<K>,
75 max_cache_entries: Option<usize>,
76 ) -> Result<Self, BackendError> {
77 file.seek(SeekFrom::Start(0)).context(EntryIO)?;
79 let mut offset: u64 = 0;
81 let mut offsets: HashMap<Vec<u8>, u64> = HashMap::new();
82 while let Ok(segment) = Segment::read_owned(&mut file) {
83 match CipherText::try_from(segment) {
84 Ok(x) => {
85 let cleartext = x.decrypt(&*key)?;
86 let entry: Entry = cleartext.deserialize()?;
87 for (k, _) in entry.items {
88 offsets.insert(k.to_vec(), offset);
89 }
90 }
91 Err(e) => {
92 warn!(?e, "Failed to decode a segment");
93 }
94 }
95 offset = file.stream_position().context(EntryIO)?;
96 }
97 Ok(Self {
98 file,
99 offsets,
100 cache: None,
101 compression,
102 key,
103 max_cache_entries: max_cache_entries.unwrap_or(100),
104 })
105 }
106
107 pub fn get<C, V>(&mut self, key: &C) -> Result<Option<V>, BackendError>
115 where
116 C: Serialize,
117 V: DeserializeOwned,
118 {
119 let key = match serde_cbor::to_vec(key) {
121 Ok(k) => ByteBuf::from(k),
122 Err(_) => return Err(BackendError::ItemSerialization),
125 };
126 if let Some(cache) = self.cache.as_ref() {
128 if let Some(bytes) = cache.items.get(&key) {
129 match serde_cbor::from_slice(bytes) {
130 Ok(v) => return Ok(Some(v)),
131 Err(_) => return Err(BackendError::ItemDeserialization),
134 }
135 }
136 }
137 if let Some(offset) = self.offsets.get(&*key) {
139 self.file.seek(SeekFrom::Start(*offset)).context(EntryIO)?;
141 let segment = Segment::read_owned(&mut self.file)?;
142 let ciphertext = CipherText::try_from(segment)?;
144 let cleartext = ciphertext.decrypt(&*self.key)?;
145 let entry: Entry = cleartext.deserialize()?;
146 if let Some(bytes) = entry.items.get(&key) {
147 match serde_cbor::from_slice(bytes) {
148 Ok(v) => Ok(Some(v)),
149 Err(_) => Err(BackendError::ItemDeserialization),
152 }
153 } else {
154 warn!("Offsets table contained a offset fora nonexistent value");
155 Ok(None)
156 }
157 } else {
158 Ok(None)
159 }
160 }
161
162 #[instrument(skip(self, key, value))]
171 pub fn insert<C, V>(&mut self, key: &C, value: &V) -> Result<(), BackendError>
172 where
173 C: Serialize,
174 V: Serialize,
175 {
176 let mut cache = self.cache.take().unwrap_or_else(|| Entry {
178 items: HashMap::new(),
179 });
180 let key = serde_cbor::to_vec(key).map_err(|_| BackendError::ItemSerialization)?;
183 let value = serde_cbor::to_vec(value).map_err(|_| BackendError::ItemSerialization)?;
184 cache.items.insert(ByteBuf::from(key), ByteBuf::from(value));
186 let length = cache.items.len();
187 self.cache = Some(cache);
188 if length >= self.max_cache_entries {
190 debug!("Flushing cache");
191 self.flush()?;
192 }
193 Ok(())
194 }
195
196 pub fn flush(&mut self) -> Result<(), BackendError> {
202 let offset = self.file.seek(SeekFrom::End(0)).context(EntryIO)?;
204 if let Some(cache) = self.cache.take() {
206 for k in cache.items.keys() {
208 self.offsets.insert((&*k).to_vec(), offset);
209 }
210 let plaintext = ClearText::new(&cache)?;
212 let ciphertext = plaintext.encrypt(&*self.key, self.compression)?;
213 let segment: Segment<'_> = ciphertext.into();
214 segment.write(&mut self.file)?;
215 self.file.flush().context(EntryIO)?;
216
217 Ok(())
218 } else {
219 Ok(())
220 }
221 }
222
223 pub fn to_hashmap<C, V>(&mut self) -> Result<HashMap<C, V>, BackendError>
232 where
233 C: DeserializeOwned + Serialize + Hash + Eq,
234 V: DeserializeOwned,
235 {
236 let mut ret = HashMap::new();
237 let keys: Vec<Vec<u8>> = self.offsets.keys().cloned().collect();
238 for key in keys {
240 let key: C =
242 serde_cbor::from_slice(&key).map_err(|_| BackendError::ItemDeserialization)?;
243 let value = self.get(&key)?.ok_or(BackendError::InvalidLsmState)?;
245 ret.insert(key, value);
246 }
247 if let Some(cache) = self.cache.as_ref() {
249 let keys: Vec<ByteBuf> = cache.items.keys().cloned().collect();
250 for key in keys {
251 let key_deser: C =
253 serde_cbor::from_slice(&key).map_err(|_| BackendError::ItemDeserialization)?;
254 let value = cache.items.get(&key).ok_or(BackendError::InvalidLsmState)?;
256 let value: V =
257 serde_cbor::from_slice(value).map_err(|_| BackendError::ItemDeserialization)?;
258 ret.insert(key_deser, value);
259 }
260 }
261 Ok(ret)
262 }
263 pub fn to_pairs<C, V>(&mut self) -> Result<Vec<(C, V)>, BackendError>
272 where
273 C: DeserializeOwned + Serialize + Eq + Clone,
274 V: DeserializeOwned + Clone,
275 {
276 let mut ret: Vec<(C, V)> = Vec::new();
277 let keys: Vec<Vec<u8>> = self.offsets.keys().cloned().collect();
278 for key in keys {
280 let key: C =
282 serde_cbor::from_slice(&key).map_err(|_| BackendError::ItemDeserialization)?;
283 let value = self.get(&key)?.ok_or(BackendError::InvalidLsmState)?;
285 if let Some(index) =
286 ret.iter()
287 .cloned()
288 .enumerate()
289 .find_map(|(x, (k, _))| if k == key { Some(x) } else { None })
290 {
291 ret[index] = (key, value);
292 } else {
293 ret.push((key, value));
294 }
295 }
296 if let Some(cache) = self.cache.as_ref() {
298 let keys: Vec<ByteBuf> = cache.items.keys().cloned().collect();
299 for key in keys {
300 let key_deser: C =
302 serde_cbor::from_slice(&key).map_err(|_| BackendError::ItemDeserialization)?;
303 let value = cache.items.get(&key).ok_or(BackendError::InvalidLsmState)?;
305 let value: V =
306 serde_cbor::from_slice(value).map_err(|_| BackendError::ItemDeserialization)?;
307 if let Some(index) = ret.iter().cloned().enumerate().find_map(|(x, (k, _))| {
308 if k == key_deser {
309 Some(x)
310 } else {
311 None
312 }
313 }) {
314 ret[index] = (key_deser, value);
315 } else {
316 ret.push((key_deser, value));
317 }
318 }
319 }
320 Ok(ret)
321 }
322 pub fn into_inner(self) -> T {
324 self.file
325 }
326}
327
328impl<K> LsmFile<File, K>
329where
330 K: Key,
331{
332 #[instrument(skip(key))]
341 pub fn open(
342 path: impl AsRef<Path> + std::fmt::Debug,
343 compression: Option<i32>,
344 key: Arc<K>,
345 max_cache_entries: Option<usize>,
346 ) -> Result<Self, BackendError> {
347 let file = OpenOptions::new()
349 .read(true)
350 .write(true)
351 .create(false)
352 .open(path.as_ref())
353 .context(EntryIO)?;
354 Self::new(file, compression, key, max_cache_entries)
355 }
356
357 #[instrument(skip(key))]
365 pub fn create(
366 path: impl AsRef<Path> + std::fmt::Debug,
367 compression: Option<i32>,
368 key: Arc<K>,
369 max_cache_entries: Option<usize>,
370 ) -> Result<Self, BackendError> {
371 let file = OpenOptions::new()
373 .read(true)
374 .write(true)
375 .create_new(true)
376 .open(path.as_ref())
377 .context(EntryIO)?;
378 Self::new(file, compression, key, max_cache_entries)
379 }
380}
381
#[cfg(test)]
mod tests {
    use super::*;
    use crate::crypto::RootKey;
    use proptest::prelude::*;
    use std::io::Cursor;
    use tempfile::{tempdir, tempfile};

    /// Creates an `LsmFile` over `backing` (cache size 10) and inserts every pair into it.
    fn filled<T>(backing: T, root_key: Arc<RootKey>, pairs: Vec<(u64, u64)>) -> LsmFile<T, RootKey>
    where
        T: Read + Write + Seek,
    {
        let mut lsm_file: LsmFile<T, RootKey> =
            LsmFile::new(backing, None, root_key, Some(10)).expect("Failed to open file");
        for (k, v) in pairs {
            println!("k: {} v: {}", k, v);
            lsm_file.insert(&k, &v).expect("Unable to insert k/v pair");
        }
        lsm_file
    }

    /// Flushes `lsm_file`, then rebuilds a fresh `LsmFile` from its backing store alone.
    fn reopened<T>(mut lsm_file: LsmFile<T, RootKey>, root_key: Arc<RootKey>) -> LsmFile<T, RootKey>
    where
        T: Read + Write + Seek,
    {
        lsm_file.flush().expect("Failed to flush lsm file");
        let backing = lsm_file.into_inner();
        LsmFile::new(backing, None, root_key, Some(10)).expect("Failed to open file")
    }

    /// Panics unless `input` and `output` hold exactly the same k/v pairs.
    fn assert_same_pairs(input: &HashMap<u64, u64>, output: &HashMap<u64, u64>) {
        for (k, v) in output {
            if Some(v) != input.get(k) {
                panic!("Output hashmap contains k/v pair not in input: k: {} v: {}", k, v);
            }
        }
        for (k, v) in input {
            if Some(v) != output.get(k) {
                panic!("Input hashmap contains k/v pair not in output: k: {} v: {}", k, v);
            }
        }
    }

    #[test]
    fn file_create_open() -> Result<(), BackendError> {
        let dir = tempdir().expect("Unable to make tempdir");
        let file_path = dir.path().join("test_lsm");
        println!("{:?}", file_path);
        let hashmap: HashMap<u64, u64> = [(1, 2), (3, 4), (5, 6), (7, 8)].into_iter().collect();
        let root_key = Arc::new(RootKey::random());
        let mut lsm_file = LsmFile::create(&file_path, None, root_key.clone(), Some(3))?;
        for (k, v) in &hashmap {
            println!("k: {} v: {}", k, v);
            lsm_file.insert(&k, &v).expect("Unable to insert k/v pair");
        }
        lsm_file.flush()?;
        std::mem::drop(lsm_file);
        let mut lsm_file = LsmFile::open(&file_path, None, root_key, Some(3))?;
        let output: HashMap<u64, u64> = lsm_file.to_hashmap()?;
        assert_eq!(output, hashmap);
        Ok(())
    }

    /// A newer value must shadow an older one, both from the buffer and across segments.
    #[test]
    fn later_writes_win() -> Result<(), BackendError> {
        let root_key = Arc::new(RootKey::random());
        let mut lsm_file: LsmFile<_, RootKey> =
            LsmFile::new(Cursor::new(Vec::<u8>::new()), None, root_key.clone(), Some(10))?;
        lsm_file.insert(&1_u64, &2_u64)?;
        lsm_file.flush()?;
        lsm_file.insert(&1_u64, &3_u64)?;
        let buffered: HashMap<u64, u64> = lsm_file.to_hashmap()?;
        assert_eq!(buffered.get(&1), Some(&3));
        lsm_file.flush()?;

        let mut lsm_file: LsmFile<_, RootKey> =
            LsmFile::new(lsm_file.into_inner(), None, root_key, Some(10))?;
        assert_eq!(lsm_file.get::<u64, u64>(&1)?, Some(3));
        let pairs: Vec<(u64, u64)> = lsm_file.to_pairs()?;
        assert_eq!(pairs, vec![(1, 3)]);
        Ok(())
    }

    proptest! {
        #[test]
        fn round_trip_hashmap_smoke(pairs: Vec<(u64,u64)>) {
            let hashmap = pairs.iter().copied().collect::<HashMap<u64,u64>>();
            let root_key = Arc::new(RootKey::random());
            let mut lsm_file = filled(Cursor::new(Vec::<u8>::new()), root_key, pairs);
            let output: HashMap<u64,u64> =
                lsm_file.to_hashmap().expect("Failed to convert to hashmap");
            assert_same_pairs(&hashmap, &output);
        }

        #[test]
        fn round_trip_hashmap_flush(pairs: Vec<(u64,u64)>) {
            let hashmap = pairs.iter().copied().collect::<HashMap<u64,u64>>();
            let root_key = Arc::new(RootKey::random());
            let lsm_file = filled(Cursor::new(Vec::<u8>::new()), root_key.clone(), pairs);
            let mut lsm_file = reopened(lsm_file, root_key);
            let output: HashMap<u64,u64> =
                lsm_file.to_hashmap().expect("Failed to convert to hashmap");
            assert_same_pairs(&hashmap, &output);
        }

        #[test]
        fn round_trip_hashmap_file(pairs: Vec<(u64,u64)>) {
            let hashmap = pairs.iter().copied().collect::<HashMap<u64,u64>>();
            let root_key = Arc::new(RootKey::random());
            let backing = tempfile().expect("Failed to open tempfile");
            let lsm_file = filled(backing, root_key.clone(), pairs);
            let mut lsm_file = reopened(lsm_file, root_key);
            let output: HashMap<u64,u64> =
                lsm_file.to_hashmap().expect("Failed to convert to hashmap");
            assert_same_pairs(&hashmap, &output);
        }
    }
}