Skip to main content

sakuhiki_memdb/
lib.rs

1use std::{
2    collections::BTreeMap,
3    future::{self, Ready, ready},
4    ops::{Bound, RangeBounds},
5    sync::Mutex,
6};
7
8use async_lock::{Mutex as AsyncMutex, MutexGuard as AsyncMutexGuard};
9use futures_util::stream;
10use sakuhiki_core::{
11    Backend, CfError,
12    backend::{BackendBuilder, BackendCf, Builder, BuilderConfig},
13};
14
15#[derive(Debug, thiserror::Error)]
16pub enum Error {
17    #[error("Column family does not exist in memory database")]
18    NonExistentColumnFamily,
19}
20
/// Contents of a single column family: raw keys mapped to raw values, kept in key order so
/// that range scans can be served directly by [`BTreeMap::range`].
type ColumnFamily = std::collections::BTreeMap<Vec<u8>, Vec<u8>>;
22
23pub struct TransactionCf<'t> {
24    cf: Mutex<AsyncMutexGuard<'t, ColumnFamily>>,
25    name: &'static str,
26}
27
28impl BackendCf for TransactionCf<'_> {
29    fn name(&self) -> &'static str {
30        self.name
31    }
32}
33
34pub struct MemDb {
35    db: BTreeMap<String, AsyncMutex<ColumnFamily>>,
36}
37
38impl MemDb {
39    pub fn builder() -> Builder<MemDb> {
40        Builder::new(MemDbBuilder { _private: () })
41    }
42}
43
44#[warn(clippy::missing_trait_methods)]
45impl Backend for MemDb {
46    type Error = Error;
47
48    type Builder = MemDbBuilder;
49
50    type Cf<'db> = &'static str;
51
52    type CfHandleFuture<'op> = Ready<Result<Self::Cf<'op>, Self::Error>>;
53
54    fn cf_handle<'db>(&'db self, name: &'static str) -> Self::CfHandleFuture<'db> {
55        ready(Ok(name))
56    }
57
58    type Transaction<'t> = Transaction;
59    type TransactionCf<'t> = TransactionCf<'t>;
60
61    fn ro_transaction<'fut, 'db, F, Ret>(
62        &'fut self,
63        cfs: &'fut [&'fut Self::Cf<'db>],
64        actions: F,
65    ) -> waaa::BoxFuture<'fut, Result<Ret, CfError<Self::Error>>>
66    where
67        F: 'fut
68            + waaa::Send
69            + for<'t> FnOnce(&'t (), Transaction, Vec<TransactionCf<'t>>) -> waaa::BoxFuture<'t, Ret>,
70    {
71        Box::pin(async move {
72            let t = Transaction { _private: () };
73            let mut cfs = cfs.iter().enumerate().collect::<Vec<_>>();
74            cfs.sort_by_key(|e| e.1);
75            let mut transaction_cfs = Vec::with_capacity(cfs.len());
76            for (i, &name) in cfs {
77                let cf = self
78                    .db
79                    .get(*name)
80                    .ok_or_else(|| CfError::cf(name, Error::NonExistentColumnFamily))?;
81                let cf = Mutex::new(cf.lock().await);
82                transaction_cfs.push((i, TransactionCf { name, cf }));
83            }
84            transaction_cfs.sort_by_key(|e| e.0);
85            let transaction_cfs = transaction_cfs
86                .into_iter()
87                .map(|(_, cf)| cf)
88                .collect::<Vec<_>>();
89            Ok(actions(&(), t, transaction_cfs).await)
90        })
91    }
92
93    fn rw_transaction<'fut, 'db, F, Ret>(
94        &'fut self,
95        cfs: &'fut [&'fut Self::Cf<'db>],
96        actions: F,
97    ) -> waaa::BoxFuture<'fut, Result<Ret, CfError<Self::Error>>>
98    where
99        F: 'fut
100            + waaa::Send
101            + for<'t> FnOnce(
102                &'t (),
103                Self::Transaction<'t>,
104                Vec<Self::TransactionCf<'t>>,
105            ) -> waaa::BoxFuture<'t, Ret>,
106    {
107        self.ro_transaction(cfs, actions)
108    }
109
110    type Key<'op> = Vec<u8>;
111    type Value<'op> = Vec<u8>;
112}
113
/// Transaction handle for [`MemDb`].
///
/// Carries no state of its own: all the locking is done up front by the `TransactionCf`s.
/// The private field prevents construction outside this module.
pub struct Transaction {
    _private: (),
}
117
118// #[warn(clippy::missing_trait_methods)] // MemDb is used only for tests, we can use default impls
119impl<'t> sakuhiki_core::backend::Transaction<'t, MemDb> for Transaction {
120    type ExclusiveLock<'op>
121        = ()
122    where
123        't: 'op;
124
125    fn take_exclusive_lock<'op>(
126        &'op self,
127        _cf: &'op <MemDb as Backend>::TransactionCf<'t>,
128    ) -> waaa::BoxFuture<'op, Result<Self::ExclusiveLock<'op>, <MemDb as Backend>::Error>>
129    where
130        't: 'op,
131    {
132        // MemDb already locks literally all the CFs when starting the transaction anyway
133        Box::pin(future::ready(Ok(())))
134    }
135
136    fn get<'op, 'key>(
137        &'op self,
138        cf: &'op TransactionCf<'t>,
139        key: &'key [u8],
140    ) -> waaa::BoxFuture<'key, Result<Option<Vec<u8>>, Error>>
141    where
142        'op: 'key,
143    {
144        Box::pin(ready(Ok(cf
145            .cf
146            .lock()
147            .unwrap()
148            .get(key)
149            .map(|v| v.to_owned()))))
150    }
151
152    fn scan<'op, 'keys, R>(
153        &'op self,
154        cf: &'op TransactionCf<'t>,
155        keys: impl 'keys + RangeBounds<R>,
156    ) -> waaa::BoxStream<'keys, Result<(Vec<u8>, Vec<u8>), Error>>
157    where
158        't: 'op,
159        'op: 'keys,
160        R: ?Sized + AsRef<[u8]>,
161    {
162        let start: Bound<&[u8]> = keys.start_bound().map(|k| k.as_ref());
163        let end: Bound<&[u8]> = keys.end_bound().map(|k| k.as_ref());
164        Box::pin(stream::iter(
165            cf.cf
166                .lock()
167                .unwrap()
168                .range::<[u8], _>((start, end))
169                .map(|(k, v)| Ok((k.to_owned(), v.to_owned())))
170                .collect::<Vec<_>>(),
171        ))
172    }
173
174    fn put<'op, 'kv>(
175        &'op self,
176        cf: &'op TransactionCf<'t>,
177        key: &'kv [u8],
178        value: &'kv [u8],
179    ) -> waaa::BoxFuture<'kv, Result<Option<Vec<u8>>, Error>>
180    where
181        't: 'op,
182        'op: 'kv,
183    {
184        let data = cf.cf.lock().unwrap().insert(key.to_vec(), value.to_vec());
185        Box::pin(ready(Ok(data)))
186    }
187
188    fn delete<'op, 'key>(
189        &'op self,
190        cf: &'op TransactionCf<'t>,
191        key: &'key [u8],
192    ) -> waaa::BoxFuture<'key, Result<Option<Vec<u8>>, Error>>
193    where
194        't: 'op,
195        'op: 'key,
196    {
197        let data = cf.cf.lock().unwrap().remove(key);
198        Box::pin(ready(Ok(data)))
199    }
200
201    fn clear<'op>(
202        &'op self,
203        cf: &'op <MemDb as Backend>::TransactionCf<'t>,
204    ) -> waaa::BoxFuture<'op, Result<(), <MemDb as Backend>::Error>> {
205        cf.cf.lock().unwrap().clear();
206        Box::pin(ready(Ok(())))
207    }
208}
209
/// [`BackendBuilder`] for [`MemDb`]; obtain one through [`MemDb::builder`].
///
/// The private field prevents construction outside this module.
pub struct MemDbBuilder {
    _private: (),
}
213
214impl BackendBuilder for MemDbBuilder {
215    type Target = MemDb;
216    type CfOptions = (); // TODO(blocked): should be !
217
218    type BuildFuture = waaa::BoxFuture<'static, anyhow::Result<Self::Target>>;
219
220    fn build(self, config: BuilderConfig<MemDb>) -> Self::BuildFuture {
221        Box::pin(async move {
222            let mut db = MemDb {
223                db: BTreeMap::new(),
224            };
225            for (cf, _opts) in config.cfs {
226                db.db
227                    .insert(cf.to_string(), AsyncMutex::new(ColumnFamily::new()));
228            }
229            // Note: drop_unknown_cfs currently has no impact as we're always starting from scratch, though it could be useful in tests to check db recovery
230            for i in config.index_rebuilders {
231                let mut index_cfs = Vec::with_capacity(i.index_cfs.len());
232                for cf in i.index_cfs {
233                    index_cfs.push(TransactionCf {
234                        name: cf,
235                        cf: Mutex::new(db.db.get(*cf).unwrap().lock().await),
236                    });
237                }
238                let datum_cf = TransactionCf {
239                    name: i.datum_cf,
240                    cf: Mutex::new(db.db.get(i.datum_cf).unwrap().lock().await),
241                };
242                let t = Transaction { _private: () };
243                (i.rebuilder)(&t, &index_cfs, &datum_cf).await?;
244            }
245            Ok(db)
246        })
247    }
248}