typed_store/
lib.rs

1// Copyright (c) 2021, Facebook, Inc. and its affiliates
2// Copyright (c) 2022, Mysten Labs, Inc.
3// SPDX-License-Identifier: Apache-2.0
4#![warn(
5    future_incompatible,
6    nonstandard_style,
7    rust_2018_idioms,
8    rust_2021_compatibility
9)]
10
11use eyre::Result;
12use rocksdb::MultiThreaded;
13use serde::{de::DeserializeOwned, Serialize};
14use std::{
15    cmp::Eq,
16    collections::{HashMap, VecDeque},
17    hash::Hash,
18    sync::Arc,
19};
20use tokio::sync::{
21    mpsc::{channel, Sender},
22    oneshot,
23};
24
25pub mod traits;
26pub use traits::Map;
27pub mod metrics;
28pub mod rocks;
29pub use metrics::DBMetrics;
30
31#[cfg(test)]
32#[path = "tests/store_tests.rs"]
33pub mod store_tests;
34
35pub type StoreError = rocks::TypedStoreError;
36
37type StoreResult<T> = Result<T, StoreError>;
38
39pub enum StoreCommand<Key, Value> {
40    Write(Key, Value, Option<oneshot::Sender<StoreResult<()>>>),
41    WriteAll(Vec<(Key, Value)>, oneshot::Sender<StoreResult<()>>),
42    Delete(Key),
43    DeleteAll(Vec<Key>, oneshot::Sender<StoreResult<()>>),
44    Read(Key, oneshot::Sender<StoreResult<Option<Value>>>),
45    ReadRawBytes(Key, oneshot::Sender<StoreResult<Option<Vec<u8>>>>),
46    ReadAll(Vec<Key>, oneshot::Sender<StoreResult<Vec<Option<Value>>>>),
47    NotifyRead(Key, oneshot::Sender<StoreResult<Option<Value>>>),
48    Iter(
49        Option<Box<dyn Fn(&(Key, Value)) -> bool + Send>>,
50        oneshot::Sender<HashMap<Key, Value>>,
51    ),
52}
53
54#[derive(Clone)]
55pub struct Store<K, V> {
56    channel: Sender<StoreCommand<K, V>>,
57    pub rocksdb: Arc<rocksdb::DBWithThreadMode<MultiThreaded>>,
58}
59
60impl<Key, Value> Store<Key, Value>
61where
62    Key: Hash + Eq + Serialize + DeserializeOwned + Send + 'static,
63    Value: Serialize + DeserializeOwned + Send + Clone + 'static,
64{
65    pub fn new(keyed_db: rocks::DBMap<Key, Value>) -> Self {
66        let mut obligations = HashMap::<Key, VecDeque<oneshot::Sender<_>>>::new();
67        let clone_db = keyed_db.rocksdb.clone();
68        let (tx, mut rx) = channel(100);
69        tokio::spawn(async move {
70            while let Some(command) = rx.recv().await {
71                match command {
72                    StoreCommand::Write(key, value, sender) => {
73                        let response = keyed_db.insert(&key, &value);
74                        if response.is_ok() {
75                            if let Some(mut senders) = obligations.remove(&key) {
76                                while let Some(s) = senders.pop_front() {
77                                    let _ = s.send(Ok(Some(value.clone())));
78                                }
79                            }
80                        }
81                        if let Some(replier) = sender {
82                            let _ = replier.send(response);
83                        }
84                    }
85                    StoreCommand::WriteAll(key_values, sender) => {
86                        let response =
87                            keyed_db.multi_insert(key_values.iter().map(|(k, v)| (k, v)));
88
89                        if response.is_ok() {
90                            for (key, _) in key_values {
91                                if let Some(mut senders) = obligations.remove(&key) {
92                                    while let Some(s) = senders.pop_front() {
93                                        let _ = s.send(Ok(None));
94                                    }
95                                }
96                            }
97                        }
98                        let _ = sender.send(response);
99                    }
100                    StoreCommand::Delete(key) => {
101                        let _ = keyed_db.remove(&key);
102                        if let Some(mut senders) = obligations.remove(&key) {
103                            while let Some(s) = senders.pop_front() {
104                                let _ = s.send(Ok(None));
105                            }
106                        }
107                    }
108                    StoreCommand::DeleteAll(keys, sender) => {
109                        let response = keyed_db.multi_remove(keys.iter());
110                        // notify the obligations only when the delete was successful
111                        if response.is_ok() {
112                            for key in keys {
113                                if let Some(mut senders) = obligations.remove(&key) {
114                                    while let Some(s) = senders.pop_front() {
115                                        let _ = s.send(Ok(None));
116                                    }
117                                }
118                            }
119                        }
120                        let _ = sender.send(response);
121                    }
122                    StoreCommand::Read(key, sender) => {
123                        let response = keyed_db.get(&key);
124                        let _ = sender.send(response);
125                    }
126                    StoreCommand::ReadAll(keys, sender) => {
127                        let response = keyed_db.multi_get(keys.as_slice());
128                        let _ = sender.send(response);
129                    }
130                    StoreCommand::NotifyRead(key, sender) => {
131                        let response = keyed_db.get(&key);
132                        if let Ok(Some(_)) = response {
133                            let _ = sender.send(response);
134                        } else {
135                            obligations
136                                .entry(key)
137                                .or_insert_with(VecDeque::new)
138                                .push_back(sender)
139                        }
140                    }
141                    StoreCommand::Iter(predicate, sender) => {
142                        let response = if let Some(func) = predicate {
143                            keyed_db.iter().filter(func).collect()
144                        } else {
145                            // Beware, we may overload the memory with a large table!
146                            keyed_db.iter().collect()
147                        };
148
149                        let _ = sender.send(response);
150                    }
151                    StoreCommand::ReadRawBytes(key, sender) => {
152                        let response = keyed_db.get_raw_bytes(&key);
153                        let _ = sender.send(response);
154                    }
155                }
156            }
157        });
158        Self {
159            channel: tx,
160            rocksdb: clone_db,
161        }
162    }
163}
164
165impl<Key, Value> Store<Key, Value>
166where
167    Key: Serialize + DeserializeOwned + Send,
168    Value: Serialize + DeserializeOwned + Send,
169{
170    pub async fn async_write(&self, key: Key, value: Value) {
171        if let Err(e) = self
172            .channel
173            .send(StoreCommand::Write(key, value, None))
174            .await
175        {
176            panic!("Failed to send Write command to store: {e}");
177        }
178    }
179
180    pub async fn sync_write(&self, key: Key, value: Value) -> StoreResult<()> {
181        let (sender, receiver) = oneshot::channel();
182        if let Err(e) = self
183            .channel
184            .send(StoreCommand::Write(key, value, Some(sender)))
185            .await
186        {
187            panic!("Failed to send Write command to store: {e}");
188        }
189        receiver
190            .await
191            .expect("Failed to receive reply to Write command from store")
192    }
193
194    /// Atomically writes all the key-value pairs in storage.
195    /// If the operation is successful, then the result will be a non
196    /// error empty result. Otherwise the error is returned.
197    pub async fn sync_write_all(
198        &self,
199        key_value_pairs: impl IntoIterator<Item = (Key, Value)>,
200    ) -> StoreResult<()> {
201        let (sender, receiver) = oneshot::channel();
202        if let Err(e) = self
203            .channel
204            .send(StoreCommand::WriteAll(
205                key_value_pairs.into_iter().collect(),
206                sender,
207            ))
208            .await
209        {
210            panic!("Failed to send WriteAll command to store: {e}");
211        }
212        receiver
213            .await
214            .expect("Failed to receive reply to WriteAll command from store")
215    }
216
217    pub async fn remove(&self, key: Key) {
218        if let Err(e) = self.channel.send(StoreCommand::Delete(key)).await {
219            panic!("Failed to send Delete command to store: {e}");
220        }
221    }
222
223    /// Atomically removes all the data referenced by the provided keys.
224    /// If the operation is successful, then the result will be a non
225    /// error empty result. Otherwise the error is returned.
226    pub async fn remove_all(&self, keys: impl IntoIterator<Item = Key>) -> StoreResult<()> {
227        let (sender, receiver) = oneshot::channel();
228        if let Err(e) = self
229            .channel
230            .send(StoreCommand::DeleteAll(keys.into_iter().collect(), sender))
231            .await
232        {
233            panic!("Failed to send DeleteAll command to store: {e}");
234        }
235        receiver
236            .await
237            .expect("Failed to receive reply to RemoveAll command from store")
238    }
239
240    /// Returns the read value in raw bincode bytes
241    pub async fn read_raw_bytes(&self, key: Key) -> StoreResult<Option<Vec<u8>>> {
242        let (sender, receiver) = oneshot::channel();
243        if let Err(e) = self
244            .channel
245            .send(StoreCommand::ReadRawBytes(key, sender))
246            .await
247        {
248            panic!("Failed to send ReadRawBytes command to store: {e}");
249        }
250        receiver
251            .await
252            .expect("Failed to receive reply to ReadRawBytes command from store")
253    }
254
255    pub async fn read(&self, key: Key) -> StoreResult<Option<Value>> {
256        let (sender, receiver) = oneshot::channel();
257        if let Err(e) = self.channel.send(StoreCommand::Read(key, sender)).await {
258            panic!("Failed to send Read command to store: {e}");
259        }
260        receiver
261            .await
262            .expect("Failed to receive reply to Read command from store")
263    }
264
265    /// Fetches all the values for the provided keys.
266    pub async fn read_all(
267        &self,
268        keys: impl IntoIterator<Item = Key>,
269    ) -> StoreResult<Vec<Option<Value>>> {
270        let (sender, receiver) = oneshot::channel();
271        if let Err(e) = self
272            .channel
273            .send(StoreCommand::ReadAll(keys.into_iter().collect(), sender))
274            .await
275        {
276            panic!("Failed to send ReadAll command to store: {e}");
277        }
278        receiver
279            .await
280            .expect("Failed to receive reply to ReadAll command from store")
281    }
282
283    pub async fn notify_read(&self, key: Key) -> StoreResult<Option<Value>> {
284        let (sender, receiver) = oneshot::channel();
285        if let Err(e) = self
286            .channel
287            .send(StoreCommand::NotifyRead(key, sender))
288            .await
289        {
290            panic!("Failed to send NotifyRead command to store: {e}");
291        }
292        receiver
293            .await
294            .expect("Failed to receive reply to NotifyRead command from store")
295    }
296
297    pub async fn iter(
298        &self,
299        predicate: Option<Box<dyn Fn(&(Key, Value)) -> bool + Send>>,
300    ) -> HashMap<Key, Value> {
301        let (sender, receiver) = oneshot::channel();
302        if let Err(e) = self
303            .channel
304            .send(StoreCommand::Iter(predicate, sender))
305            .await
306        {
307            panic!("Failed to send Iter command to store: {e}");
308        }
309        receiver
310            .await
311            .expect("Failed to receive reply to Iter command from store")
312    }
313}