Skip to main content

torsh_distributed/store/
memory.rs

1//! In-memory store implementation for testing and single-node coordination
2
3use super::{
4    store_trait::Store,
5    types::{StoreValue, DEFAULT_TIMEOUT},
6};
7use crate::{TorshDistributedError, TorshResult};
8use async_trait::async_trait;
9use dashmap::DashMap;
10use std::sync::Arc;
11use std::time::{Duration, Instant};
12
13/// In-memory store implementation
14#[derive(Debug)]
15pub struct MemoryStore {
16    data: Arc<DashMap<String, StoreValue>>,
17}
18
19impl MemoryStore {
20    pub fn new() -> Self {
21        Self {
22            data: Arc::new(DashMap::new()),
23        }
24    }
25}
26
27impl Default for MemoryStore {
28    fn default() -> Self {
29        Self::new()
30    }
31}
32
33#[async_trait]
34impl Store for MemoryStore {
35    async fn set(&self, key: &str, value: &[u8]) -> TorshResult<()> {
36        let store_value = StoreValue::new(value.to_vec());
37        self.data.insert(key.to_string(), store_value);
38        Ok(())
39    }
40
41    async fn get(&self, key: &str) -> TorshResult<Option<Vec<u8>>> {
42        Ok(self.data.get(key).map(|v| v.data().to_vec()))
43    }
44
45    async fn wait(&self, keys: &[String]) -> TorshResult<()> {
46        let start = Instant::now();
47
48        loop {
49            let all_present = keys.iter().all(|key| self.data.contains_key(key));
50
51            if all_present {
52                return Ok(());
53            }
54
55            if start.elapsed() > DEFAULT_TIMEOUT {
56                return Err(TorshDistributedError::communication_error(
57                    "Store wait",
58                    "Timeout waiting for keys",
59                ));
60            }
61
62            tokio::time::sleep(Duration::from_millis(10)).await;
63        }
64    }
65
66    async fn delete(&self, key: &str) -> TorshResult<()> {
67        self.data.remove(key);
68        Ok(())
69    }
70
71    async fn num_keys(&self) -> TorshResult<usize> {
72        Ok(self.data.len())
73    }
74
75    async fn contains(&self, key: &str) -> TorshResult<bool> {
76        Ok(self.data.contains_key(key))
77    }
78
79    async fn set_with_expiry(&self, key: &str, value: &[u8], _ttl: Duration) -> TorshResult<()> {
80        // Memory store doesn't support TTL, just set normally
81        self.set(key, value).await
82    }
83
84    async fn compare_and_swap(
85        &self,
86        key: &str,
87        expected: Option<&[u8]>,
88        value: &[u8],
89    ) -> TorshResult<bool> {
90        match expected {
91            Some(expected_val) => {
92                if let Some(current) = self.data.get(key) {
93                    if current.data() == expected_val {
94                        let store_value = StoreValue::new(value.to_vec());
95                        self.data.insert(key.to_string(), store_value);
96                        Ok(true)
97                    } else {
98                        Ok(false)
99                    }
100                } else {
101                    Ok(false)
102                }
103            }
104            None => {
105                // Expected value is None, so set only if key doesn't exist
106                if self.data.contains_key(key) {
107                    Ok(false)
108                } else {
109                    let store_value = StoreValue::new(value.to_vec());
110                    self.data.insert(key.to_string(), store_value);
111                    Ok(true)
112                }
113            }
114        }
115    }
116
117    async fn add(&self, key: &str, value: i64) -> TorshResult<i64> {
118        let new_value = if let Some(existing) = self.data.get(key) {
119            let current = i64::from_le_bytes(existing.data()[..8].try_into().map_err(|_| {
120                TorshDistributedError::invalid_argument(
121                    "value",
122                    "Failed to convert stored bytes to i64",
123                    "8 bytes representing a valid i64 value",
124                )
125            })?);
126            current + value
127        } else {
128            value
129        };
130
131        let store_value = StoreValue::new(new_value.to_le_bytes().to_vec());
132        self.data.insert(key.to_string(), store_value);
133        Ok(new_value)
134    }
135}