// torsh_distributed/store/memory.rs
use super::{
    store_trait::Store,
    types::{StoreValue, DEFAULT_TIMEOUT},
};
use crate::{TorshDistributedError, TorshResult};
use async_trait::async_trait;
use dashmap::{mapref::entry::Entry, DashMap};
use std::sync::Arc;
use std::time::{Duration, Instant};

13#[derive(Debug)]
15pub struct MemoryStore {
16 data: Arc<DashMap<String, StoreValue>>,
17}
18
19impl MemoryStore {
20 pub fn new() -> Self {
21 Self {
22 data: Arc::new(DashMap::new()),
23 }
24 }
25}
26
27impl Default for MemoryStore {
28 fn default() -> Self {
29 Self::new()
30 }
31}
32
33#[async_trait]
34impl Store for MemoryStore {
35 async fn set(&self, key: &str, value: &[u8]) -> TorshResult<()> {
36 let store_value = StoreValue::new(value.to_vec());
37 self.data.insert(key.to_string(), store_value);
38 Ok(())
39 }
40
41 async fn get(&self, key: &str) -> TorshResult<Option<Vec<u8>>> {
42 Ok(self.data.get(key).map(|v| v.data().to_vec()))
43 }
44
45 async fn wait(&self, keys: &[String]) -> TorshResult<()> {
46 let start = Instant::now();
47
48 loop {
49 let all_present = keys.iter().all(|key| self.data.contains_key(key));
50
51 if all_present {
52 return Ok(());
53 }
54
55 if start.elapsed() > DEFAULT_TIMEOUT {
56 return Err(TorshDistributedError::communication_error(
57 "Store wait",
58 "Timeout waiting for keys",
59 ));
60 }
61
62 tokio::time::sleep(Duration::from_millis(10)).await;
63 }
64 }
65
66 async fn delete(&self, key: &str) -> TorshResult<()> {
67 self.data.remove(key);
68 Ok(())
69 }
70
71 async fn num_keys(&self) -> TorshResult<usize> {
72 Ok(self.data.len())
73 }
74
75 async fn contains(&self, key: &str) -> TorshResult<bool> {
76 Ok(self.data.contains_key(key))
77 }
78
79 async fn set_with_expiry(&self, key: &str, value: &[u8], _ttl: Duration) -> TorshResult<()> {
80 self.set(key, value).await
82 }
83
84 async fn compare_and_swap(
85 &self,
86 key: &str,
87 expected: Option<&[u8]>,
88 value: &[u8],
89 ) -> TorshResult<bool> {
90 match expected {
91 Some(expected_val) => {
92 if let Some(current) = self.data.get(key) {
93 if current.data() == expected_val {
94 let store_value = StoreValue::new(value.to_vec());
95 self.data.insert(key.to_string(), store_value);
96 Ok(true)
97 } else {
98 Ok(false)
99 }
100 } else {
101 Ok(false)
102 }
103 }
104 None => {
105 if self.data.contains_key(key) {
107 Ok(false)
108 } else {
109 let store_value = StoreValue::new(value.to_vec());
110 self.data.insert(key.to_string(), store_value);
111 Ok(true)
112 }
113 }
114 }
115 }
116
117 async fn add(&self, key: &str, value: i64) -> TorshResult<i64> {
118 let new_value = if let Some(existing) = self.data.get(key) {
119 let current = i64::from_le_bytes(existing.data()[..8].try_into().map_err(|_| {
120 TorshDistributedError::invalid_argument(
121 "value",
122 "Failed to convert stored bytes to i64",
123 "8 bytes representing a valid i64 value",
124 )
125 })?);
126 current + value
127 } else {
128 value
129 };
130
131 let store_value = StoreValue::new(new_value.to_le_bytes().to_vec());
132 self.data.insert(key.to_string(), store_value);
133 Ok(new_value)
134 }
135}