1use std::collections::BTreeMap;
12use std::ops::Bound;
13
14use anyhow::Result;
15use async_trait::async_trait;
16use parking_lot::RwLock;
17
/// Longest key, in bytes (`str::len`, i.e. UTF-8 bytes), that [`BTreeMapStore`] accepts.
///
/// The limit is checked on writes (`set` and `set_batch`); longer keys are rejected with an error.
pub const MAX_KV_KEY_BYTES: usize = 1 << 10;
22
23#[async_trait]
24pub trait KVStore: Send + Sync {
25 async fn get(&self, key: String) -> Result<Option<surrealdb_types::Value>>;
26 async fn set(&self, key: String, value: surrealdb_types::Value) -> Result<()>;
27 async fn del(&self, key: String) -> Result<()>;
28 async fn exists(&self, key: String) -> Result<bool>;
29
30 async fn del_rng(&self, start: Bound<String>, end: Bound<String>) -> Result<()>;
31
32 async fn get_batch(&self, keys: Vec<String>) -> Result<Vec<Option<surrealdb_types::Value>>>;
33 async fn set_batch(&self, entries: Vec<(String, surrealdb_types::Value)>) -> Result<()>;
34 async fn del_batch(&self, keys: Vec<String>) -> Result<()>;
35
36 async fn keys(&self, start: Bound<String>, end: Bound<String>) -> Result<Vec<String>>;
37 async fn values(
38 &self,
39 start: Bound<String>,
40 end: Bound<String>,
41 ) -> Result<Vec<surrealdb_types::Value>>;
42 async fn entries(
43 &self,
44 start: Bound<String>,
45 end: Bound<String>,
46 ) -> Result<Vec<(String, surrealdb_types::Value)>>;
47 async fn count(&self, start: Bound<String>, end: Bound<String>) -> Result<u64>;
48}
49
50pub struct BTreeMapStore {
56 inner: RwLock<BTreeMap<String, surrealdb_types::Value>>,
57 max_entries: Option<usize>,
58 max_value_bytes: Option<usize>,
59}
60
61impl BTreeMapStore {
62 pub fn new() -> Self {
63 Self {
64 inner: RwLock::new(BTreeMap::new()),
65 max_entries: None,
66 max_value_bytes: None,
67 }
68 }
69
70 pub fn with_limits(max_entries: Option<usize>, max_value_bytes: Option<usize>) -> Self {
71 Self {
72 inner: RwLock::new(BTreeMap::new()),
73 max_entries,
74 max_value_bytes,
75 }
76 }
77
78 fn check_key_length(key: &str) -> Result<()> {
79 if key.len() > MAX_KV_KEY_BYTES {
80 anyhow::bail!(
81 "KV key length ({} bytes) exceeds limit ({MAX_KV_KEY_BYTES} bytes)",
82 key.len()
83 );
84 }
85 Ok(())
86 }
87
88 fn check_value_size(&self, value: &surrealdb_types::Value) -> Result<()> {
89 if let Some(max_bytes) = self.max_value_bytes {
90 let size = surrealdb_types::encode(value)?.len();
91 if size > max_bytes {
92 anyhow::bail!("KV value size ({size} bytes) exceeds limit ({max_bytes} bytes)");
93 }
94 }
95 Ok(())
96 }
97
98 fn check_entry_count(
99 &self,
100 map: &BTreeMap<String, surrealdb_types::Value>,
101 adding: usize,
102 ) -> Result<()> {
103 if let Some(max) = self.max_entries {
104 let new_count = map.len() + adding;
105 if new_count > max {
106 anyhow::bail!("KV store entry count ({new_count}) would exceed limit ({max})");
107 }
108 }
109 Ok(())
110 }
111}
112
113impl Default for BTreeMapStore {
114 fn default() -> Self {
115 Self::new()
116 }
117}
118
119#[async_trait]
120impl KVStore for BTreeMapStore {
121 async fn get(&self, key: String) -> Result<Option<surrealdb_types::Value>> {
122 let map = self.inner.read();
123 Ok(map.get(&key).cloned())
124 }
125
126 async fn set(&self, key: String, value: surrealdb_types::Value) -> Result<()> {
127 Self::check_key_length(&key)?;
128 self.check_value_size(&value)?;
129 let mut map = self.inner.write();
130 if !map.contains_key(&key) {
131 self.check_entry_count(&map, 1)?;
132 }
133 map.insert(key, value);
134 Ok(())
135 }
136
137 async fn del(&self, key: String) -> Result<()> {
138 let mut map = self.inner.write();
139 map.remove(&key);
140 Ok(())
141 }
142
143 async fn exists(&self, key: String) -> Result<bool> {
144 let map = self.inner.read();
145 Ok(map.contains_key(&key))
146 }
147
148 async fn del_rng(&self, start: Bound<String>, end: Bound<String>) -> Result<()> {
149 let mut map = self.inner.write();
150 let keys: Vec<String> = map.range((start, end)).map(|(k, _)| k.clone()).collect();
151 for key in keys {
152 map.remove(&key);
153 }
154 Ok(())
155 }
156
157 async fn get_batch(&self, keys: Vec<String>) -> Result<Vec<Option<surrealdb_types::Value>>> {
158 let map = self.inner.read();
159 Ok(keys.iter().map(|key| map.get(key).cloned()).collect())
160 }
161
162 async fn set_batch(&self, entries: Vec<(String, surrealdb_types::Value)>) -> Result<()> {
163 for (key, value) in &entries {
164 Self::check_key_length(key)?;
165 self.check_value_size(value)?;
166 }
167 let mut map = self.inner.write();
168 let new_keys = entries
169 .iter()
170 .map(|(k, _)| k.as_str())
171 .collect::<std::collections::HashSet<_>>()
172 .into_iter()
173 .filter(|k| !map.contains_key(*k))
174 .count();
175 self.check_entry_count(&map, new_keys)?;
176 for (key, value) in entries {
177 map.insert(key, value);
178 }
179 Ok(())
180 }
181
182 async fn del_batch(&self, keys: Vec<String>) -> Result<()> {
183 let mut map = self.inner.write();
184 for key in keys {
185 map.remove(&key);
186 }
187 Ok(())
188 }
189
190 async fn keys(&self, start: Bound<String>, end: Bound<String>) -> Result<Vec<String>> {
191 let map = self.inner.read();
192 Ok(map.range((start, end)).map(|(k, _)| k.clone()).collect())
193 }
194
195 async fn values(
196 &self,
197 start: Bound<String>,
198 end: Bound<String>,
199 ) -> Result<Vec<surrealdb_types::Value>> {
200 let map = self.inner.read();
201 Ok(map.range((start, end)).map(|(_, v)| v.clone()).collect())
202 }
203
204 async fn entries(
205 &self,
206 start: Bound<String>,
207 end: Bound<String>,
208 ) -> Result<Vec<(String, surrealdb_types::Value)>> {
209 let map = self.inner.read();
210 Ok(map.range((start, end)).map(|(k, v)| (k.clone(), v.clone())).collect())
211 }
212
213 async fn count(&self, start: Bound<String>, end: Bound<String>) -> Result<u64> {
214 let map = self.inner.read();
215 Ok(map.range((start, end)).count() as u64)
216 }
217}
218
#[cfg(test)]
mod tests {
    use surrealdb_types::Value;

    use super::*;

    /// Wraps an integer in a `Value`.
    fn int_val(n: i64) -> Value {
        Value::Number(surrealdb_types::Number::Int(n))
    }

    /// Wraps a string slice in a `Value`.
    fn str_val(s: &str) -> Value {
        Value::String(s.into())
    }

    /// Returns an unlimited store holding keys `"a"` through `"e"`. Each key maps to its
    /// letter's ASCII code.
    async fn letters_a_to_e() -> BTreeMapStore {
        let store = BTreeMapStore::new();
        for byte in b'a'..=b'e' {
            store.set(char::from(byte).to_string(), int_val(i64::from(byte))).await.unwrap();
        }
        store
    }

    /// Asserts that `exists` returns the expected answer for each `(key, expected)` pair.
    async fn assert_presence(store: &BTreeMapStore, expectations: &[(&str, bool)]) {
        for &(key, expected) in expectations {
            assert_eq!(
                store.exists(key.into()).await.unwrap(),
                expected,
                "presence of key {key:?}"
            );
        }
    }

    #[tokio::test]
    async fn get_set_del() {
        let store = BTreeMapStore::new();
        assert_eq!(store.get("k".into()).await.unwrap(), None);

        store.set("k".into(), int_val(42)).await.unwrap();
        assert_eq!(store.get("k".into()).await.unwrap(), Some(int_val(42)));

        store.del("k".into()).await.unwrap();
        assert_eq!(store.get("k".into()).await.unwrap(), None);
    }

    #[tokio::test]
    async fn exists() {
        let store = BTreeMapStore::new();
        assert_presence(&store, &[("k", false)]).await;

        store.set("k".into(), int_val(1)).await.unwrap();
        assert_presence(&store, &[("k", true)]).await;
    }

    #[tokio::test]
    async fn overwrite() {
        let store = BTreeMapStore::new();
        for n in [1, 2] {
            store.set("k".into(), int_val(n)).await.unwrap();
        }
        assert_eq!(store.get("k".into()).await.unwrap(), Some(int_val(2)));
    }

    #[tokio::test]
    async fn batch_ops() {
        let store = BTreeMapStore::new();
        let entries: Vec<(String, Value)> = [("a", 1), ("b", 2), ("c", 3)]
            .iter()
            .map(|&(k, n)| (k.to_string(), int_val(n)))
            .collect();
        store.set_batch(entries).await.unwrap();

        let fetched = store.get_batch(vec!["a".into(), "c".into(), "z".into()]).await.unwrap();
        assert_eq!(fetched, vec![Some(int_val(1)), Some(int_val(3)), None]);

        store.del_batch(vec!["a".into(), "c".into()]).await.unwrap();
        assert_presence(&store, &[("a", false), ("b", true), ("c", false)]).await;
    }

    #[tokio::test]
    async fn range_keys_values_entries() {
        let store = letters_a_to_e().await;

        let keys =
            store.keys(Bound::Included("b".into()), Bound::Excluded("d".into())).await.unwrap();
        assert_eq!(keys, vec!["b".to_string(), "c".to_string()]);

        let vals = store.values(Bound::Included("d".into()), Bound::Unbounded).await.unwrap();
        assert_eq!(vals, vec![int_val(i64::from(b'd')), int_val(i64::from(b'e'))]);

        assert_eq!(store.count(Bound::Unbounded, Bound::Unbounded).await.unwrap(), 5);
    }

    #[tokio::test]
    async fn del_rng() {
        let store = letters_a_to_e().await;

        store.del_rng(Bound::Included("b".into()), Bound::Excluded("e".into())).await.unwrap();

        assert_presence(
            &store,
            &[("a", true), ("b", false), ("c", false), ("d", false), ("e", true)],
        )
        .await;
    }

    #[tokio::test]
    async fn max_entries_limit() {
        let store = BTreeMapStore::with_limits(Some(2), None);
        store.set("a".into(), int_val(1)).await.unwrap();
        store.set("b".into(), int_val(2)).await.unwrap();

        let err = store
            .set("c".into(), int_val(3))
            .await
            .expect_err("a third distinct key must exceed max_entries = 2");
        assert!(err.to_string().contains("exceed limit"));

        // Overwriting an existing key adds no entry, so the store still accepts it.
        store.set("a".into(), int_val(10)).await.unwrap();
    }

    #[tokio::test]
    async fn max_entries_batch_limit() {
        let store = BTreeMapStore::with_limits(Some(2), None);
        store.set("a".into(), int_val(1)).await.unwrap();

        let outcome =
            store.set_batch(vec![("b".into(), int_val(2)), ("c".into(), int_val(3))]).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn max_value_bytes_limit() {
        let store = BTreeMapStore::with_limits(None, Some(128));
        store.set("k".into(), str_val("hi")).await.unwrap();

        let oversized = str_val(&"x".repeat(1024));
        let err = store
            .set("k2".into(), oversized)
            .await
            .expect_err("a 1024-byte string must exceed the 128-byte value limit");
        assert!(err.to_string().contains("exceeds limit"));
    }

    #[tokio::test]
    async fn del_nonexistent_is_ok() {
        let store = BTreeMapStore::new();
        store.del("nope".into()).await.unwrap();
    }

    #[tokio::test]
    async fn max_key_length_limit() {
        let store = BTreeMapStore::new();
        store.set("k".repeat(MAX_KV_KEY_BYTES), int_val(1)).await.unwrap();

        let too_long = "k".repeat(MAX_KV_KEY_BYTES + 1);
        let err = store
            .set(too_long.clone(), int_val(2))
            .await
            .expect_err("a key one byte over the limit must be rejected");
        assert!(err.to_string().contains("exceeds limit"));

        assert!(store.set_batch(vec![(too_long, int_val(3))]).await.is_err());
    }
}