taskchampion_sync_server_core/
inmemory.rs

1use super::{Client, Snapshot, Storage, StorageTxn, Version};
2use std::collections::HashMap;
3use std::sync::{Mutex, MutexGuard};
4use uuid::Uuid;
5
6struct Inner {
7    /// Clients, indexed by client_id
8    clients: HashMap<Uuid, Client>,
9
10    /// Snapshot data, indexed by client id
11    snapshots: HashMap<Uuid, Vec<u8>>,
12
13    /// Versions, indexed by (client_id, version_id)
14    versions: HashMap<(Uuid, Uuid), Version>,
15
16    /// Child versions, indexed by (client_id, parent_version_id)
17    children: HashMap<(Uuid, Uuid), Uuid>,
18}
19
20/// In-memory storage for testing and experimentation.
21///
22/// This is not for production use, but supports testing of sync server implementations.
23///
24/// NOTE: this panics if changes were made in a transaction that is later dropped without being
25/// committed, as this likely represents a bug that should be exposed in tests.
26pub struct InMemoryStorage(Mutex<Inner>);
27
28impl InMemoryStorage {
29    #[allow(clippy::new_without_default)]
30    pub fn new() -> Self {
31        Self(Mutex::new(Inner {
32            clients: HashMap::new(),
33            snapshots: HashMap::new(),
34            versions: HashMap::new(),
35            children: HashMap::new(),
36        }))
37    }
38}
39
40struct InnerTxn<'a> {
41    client_id: Uuid,
42    guard: MutexGuard<'a, Inner>,
43    written: bool,
44    committed: bool,
45}
46
47#[async_trait::async_trait]
48impl Storage for InMemoryStorage {
49    async fn txn(&self, client_id: Uuid) -> anyhow::Result<Box<dyn StorageTxn + '_>> {
50        Ok(Box::new(InnerTxn {
51            client_id,
52            guard: self.0.lock().expect("poisoned lock"),
53            written: false,
54            committed: false,
55        }))
56    }
57}
58
59#[async_trait::async_trait(?Send)]
60impl StorageTxn for InnerTxn<'_> {
61    async fn get_client(&mut self) -> anyhow::Result<Option<Client>> {
62        Ok(self.guard.clients.get(&self.client_id).cloned())
63    }
64
65    async fn new_client(&mut self, latest_version_id: Uuid) -> anyhow::Result<()> {
66        if self.guard.clients.contains_key(&self.client_id) {
67            return Err(anyhow::anyhow!("Client {} already exists", self.client_id));
68        }
69        self.guard.clients.insert(
70            self.client_id,
71            Client {
72                latest_version_id,
73                snapshot: None,
74            },
75        );
76        self.written = true;
77        Ok(())
78    }
79
80    async fn set_snapshot(&mut self, snapshot: Snapshot, data: Vec<u8>) -> anyhow::Result<()> {
81        let client = self
82            .guard
83            .clients
84            .get_mut(&self.client_id)
85            .ok_or_else(|| anyhow::anyhow!("no such client"))?;
86        client.snapshot = Some(snapshot);
87        self.guard.snapshots.insert(self.client_id, data);
88        self.written = true;
89        Ok(())
90    }
91
92    async fn get_snapshot_data(&mut self, version_id: Uuid) -> anyhow::Result<Option<Vec<u8>>> {
93        // sanity check
94        let client = self.guard.clients.get(&self.client_id);
95        let client = client.ok_or_else(|| anyhow::anyhow!("no such client"))?;
96        if Some(&version_id) != client.snapshot.as_ref().map(|snap| &snap.version_id) {
97            return Err(anyhow::anyhow!("unexpected snapshot_version_id"));
98        }
99        Ok(self.guard.snapshots.get(&self.client_id).cloned())
100    }
101
102    async fn get_version_by_parent(
103        &mut self,
104        parent_version_id: Uuid,
105    ) -> anyhow::Result<Option<Version>> {
106        if let Some(parent_version_id) = self
107            .guard
108            .children
109            .get(&(self.client_id, parent_version_id))
110        {
111            Ok(self
112                .guard
113                .versions
114                .get(&(self.client_id, *parent_version_id))
115                .cloned())
116        } else {
117            Ok(None)
118        }
119    }
120
121    async fn get_version(&mut self, version_id: Uuid) -> anyhow::Result<Option<Version>> {
122        Ok(self
123            .guard
124            .versions
125            .get(&(self.client_id, version_id))
126            .cloned())
127    }
128
129    async fn add_version(
130        &mut self,
131        version_id: Uuid,
132        parent_version_id: Uuid,
133        history_segment: Vec<u8>,
134    ) -> anyhow::Result<()> {
135        let version = Version {
136            version_id,
137            parent_version_id,
138            history_segment,
139        };
140
141        if let Some(client) = self.guard.clients.get_mut(&self.client_id) {
142            client.latest_version_id = version_id;
143            if let Some(ref mut snap) = client.snapshot {
144                snap.versions_since += 1;
145            }
146        } else {
147            anyhow::bail!("Client {} does not exist", self.client_id);
148        }
149
150        if self
151            .guard
152            .children
153            .insert((self.client_id, parent_version_id), version_id)
154            .is_some()
155        {
156            anyhow::bail!(
157                "Client {} already has a child for {}",
158                self.client_id,
159                parent_version_id
160            );
161        }
162        if self
163            .guard
164            .versions
165            .insert((self.client_id, version_id), version)
166            .is_some()
167        {
168            anyhow::bail!(
169                "Client {} already has a version {}",
170                self.client_id,
171                version_id
172            );
173        }
174
175        self.written = true;
176        Ok(())
177    }
178
179    async fn commit(&mut self) -> anyhow::Result<()> {
180        self.committed = true;
181        Ok(())
182    }
183}
184
185impl Drop for InnerTxn<'_> {
186    fn drop(&mut self) {
187        if self.written && !self.committed {
188            panic!("Uncommitted InMemoryStorage transaction dropped without commiting");
189        }
190    }
191}
192
#[cfg(test)]
mod test {
    use super::*;
    use chrono::Utc;

    #[tokio::test]
    async fn test_get_client_empty() -> anyhow::Result<()> {
        let storage = InMemoryStorage::new();
        let mut txn = storage.txn(Uuid::new_v4()).await?;

        // A client id that was never created has no client record.
        assert!(txn.get_client().await?.is_none());
        Ok(())
    }

    #[tokio::test]
    async fn test_client_storage() -> anyhow::Result<()> {
        let storage = InMemoryStorage::new();
        let mut txn = storage.txn(Uuid::new_v4()).await?;

        // A new client starts at the given version, without a snapshot.
        let first_version = Uuid::new_v4();
        txn.new_client(first_version).await?;
        let stored = txn.get_client().await?.unwrap();
        assert_eq!(stored.latest_version_id, first_version);
        assert!(stored.snapshot.is_none());

        // Adding a version moves latest_version_id forward.
        let second_version = Uuid::new_v4();
        txn.add_version(second_version, Uuid::new_v4(), vec![1, 1])
            .await?;
        let stored = txn.get_client().await?.unwrap();
        assert_eq!(stored.latest_version_id, second_version);
        assert!(stored.snapshot.is_none());

        // Setting a snapshot attaches it to the client and leaves latest_version_id alone.
        let snapshot = Snapshot {
            version_id: Uuid::new_v4(),
            timestamp: Utc::now(),
            versions_since: 4,
        };
        txn.set_snapshot(snapshot.clone(), vec![1, 2, 3]).await?;
        let stored = txn.get_client().await?.unwrap();
        assert_eq!(stored.latest_version_id, second_version);
        assert_eq!(stored.snapshot.unwrap(), snapshot);

        txn.commit().await?;
        Ok(())
    }

    #[tokio::test]
    async fn test_gvbp_empty() -> anyhow::Result<()> {
        let storage = InMemoryStorage::new();
        let mut txn = storage.txn(Uuid::new_v4()).await?;

        // With no versions stored, no parent has a child.
        let child = txn.get_version_by_parent(Uuid::new_v4()).await?;
        assert!(child.is_none());
        Ok(())
    }

    #[tokio::test]
    async fn test_add_version_and_get_version() -> anyhow::Result<()> {
        let storage = InMemoryStorage::new();
        let mut txn = storage.txn(Uuid::new_v4()).await?;

        let parent_version_id = Uuid::new_v4();
        let version_id = Uuid::new_v4();
        let expected = Version {
            version_id,
            parent_version_id,
            history_segment: b"abc".to_vec(),
        };

        txn.new_client(parent_version_id).await?;
        txn.add_version(
            version_id,
            parent_version_id,
            expected.history_segment.clone(),
        )
        .await?;

        // The version can be reached both through its parent and directly by id.
        let by_parent = txn.get_version_by_parent(parent_version_id).await?.unwrap();
        assert_eq!(by_parent, expected);
        let by_id = txn.get_version(version_id).await?.unwrap();
        assert_eq!(by_id, expected);

        txn.commit().await?;
        Ok(())
    }

    #[tokio::test]
    async fn test_add_version_exists() -> anyhow::Result<()> {
        let storage = InMemoryStorage::new();
        let mut txn = storage.txn(Uuid::new_v4()).await?;

        let parent_version_id = Uuid::new_v4();
        let version_id = Uuid::new_v4();
        let segment = b"abc".to_vec();

        txn.new_client(parent_version_id).await?;
        txn.add_version(version_id, parent_version_id, segment.clone())
            .await?;

        // Adding the same version a second time is rejected.
        let duplicate = txn
            .add_version(version_id, parent_version_id, segment)
            .await;
        assert!(duplicate.is_err());

        txn.commit().await?;
        Ok(())
    }

    #[tokio::test]
    async fn test_snapshots() -> anyhow::Result<()> {
        let storage = InMemoryStorage::new();
        let mut txn = storage.txn(Uuid::new_v4()).await?;

        txn.new_client(Uuid::new_v4()).await?;
        assert!(txn.get_client().await?.unwrap().snapshot.is_none());

        // Each set_snapshot replaces the previous snapshot and its data.
        for (versions_since, data) in [(3, vec![9, 8, 9]), (10, vec![0, 2, 4, 6])] {
            let snapshot = Snapshot {
                version_id: Uuid::new_v4(),
                timestamp: Utc::now(),
                versions_since,
            };
            txn.set_snapshot(snapshot.clone(), data.clone()).await?;

            assert_eq!(
                txn.get_snapshot_data(snapshot.version_id).await?.unwrap(),
                data
            );
            assert_eq!(txn.get_client().await?.unwrap().snapshot, Some(snapshot));
        }

        // Asking for data under a version that is not the current snapshot is an error.
        assert!(txn.get_snapshot_data(Uuid::new_v4()).await.is_err());

        txn.commit().await?;
        Ok(())
    }
}