swiftide_indexing/persist/memory_storage.rs

1use std::{
2    collections::HashMap,
3    sync::{
4        Arc,
5        atomic::{AtomicUsize, Ordering},
6    },
7};
8
9use anyhow::Result;
10use async_trait::async_trait;
11use derive_builder::Builder;
12use tokio::sync::RwLock;
13
14use swiftide_core::{
15    Persist,
16    indexing::{IndexingStream, Node},
17};
18
19#[derive(Debug, Default, Builder, Clone)]
20#[builder(pattern = "owned")]
21/// A simple in-memory storage implementation.
22///
23/// Great for experimentation and testing.
24///
25/// The storage will use a zero indexed, incremental counter as the key for each node if the node id
26/// is not set.
27pub struct MemoryStorage {
28    data: Arc<RwLock<HashMap<String, Node>>>,
29    #[builder(default)]
30    batch_size: Option<usize>,
31    #[builder(default = Arc::new(AtomicUsize::new(0)))]
32    node_count: Arc<AtomicUsize>,
33}
34
35impl MemoryStorage {
36    fn key(&self) -> String {
37        self.node_count.fetch_add(1, Ordering::Relaxed).to_string()
38    }
39
40    /// Retrieve a node by its key
41    pub async fn get(&self, key: impl AsRef<str>) -> Option<Node> {
42        self.data.read().await.get(key.as_ref()).cloned()
43    }
44
45    /// Retrieve all nodes in the storage
46    pub async fn get_all_values(&self) -> Vec<Node> {
47        self.data.read().await.values().cloned().collect()
48    }
49
50    /// Retrieve all nodes in the storage with their keys
51    pub async fn get_all(&self) -> Vec<(String, Node)> {
52        self.data
53            .read()
54            .await
55            .iter()
56            .map(|(k, v)| (k.clone(), v.clone()))
57            .collect()
58    }
59}
60
61#[async_trait]
62impl Persist for MemoryStorage {
63    async fn setup(&self) -> Result<()> {
64        Ok(())
65    }
66
67    /// Store a node by its id
68    ///
69    /// If the node does not have an id, a simple counter is used as the key.
70    async fn store(&self, node: Node) -> Result<Node> {
71        self.data.write().await.insert(self.key(), node.clone());
72
73        Ok(node)
74    }
75
76    /// Store multiple nodes at once
77    ///
78    /// If a node does not have an id, a simple counter is used as the key.
79    async fn batch_store(&self, nodes: Vec<Node>) -> IndexingStream {
80        let mut lock = self.data.write().await;
81
82        for node in &nodes {
83            lock.insert(self.key(), node.clone());
84        }
85
86        IndexingStream::iter(nodes.into_iter().map(Ok))
87    }
88
89    fn batch_size(&self) -> Option<usize> {
90        self.batch_size
91    }
92}
93
#[cfg(test)]
mod test {
    use super::*;
    use futures_util::TryStreamExt;
    use swiftide_core::indexing::Node;

    /// A single stored node is retrievable under the first counter key, `"0"`.
    #[tokio::test]
    async fn test_memory_storage() {
        let storage = MemoryStorage::default();
        let node = storage.store(Node::default()).await.unwrap();
        assert_eq!(storage.get("0").await, Some(node));
    }

    /// Successive `store` calls receive successive counter keys.
    #[tokio::test]
    async fn test_inserting_multiple_nodes() {
        let storage = MemoryStorage::default();
        let node1 = Node::default();
        let node2 = Node::default();

        storage.store(node1.clone()).await.unwrap();
        storage.store(node2.clone()).await.unwrap();

        assert_eq!(storage.get("0").await, Some(node1));
        assert_eq!(storage.get("1").await, Some(node2));
        assert_eq!(storage.get_all().await.len(), 2);
    }

    /// `batch_store` echoes the input nodes in order and actually persists them.
    #[tokio::test]
    async fn test_batch_store() {
        let storage = MemoryStorage::default();
        let node1 = Node::default();
        let node2 = Node::default();

        let stream = storage
            .batch_store(vec![node1.clone(), node2.clone()])
            .await;

        let result: Vec<Node> = stream.try_collect().await.unwrap();

        assert_eq!(result.len(), 2);
        assert_eq!(result[0], node1);
        assert_eq!(result[1], node2);

        // The stream only echoes the input; verify the nodes were persisted as well.
        assert_eq!(storage.get("0").await, Some(node1));
        assert_eq!(storage.get("1").await, Some(node2));
        assert_eq!(storage.get_all_values().await.len(), 2);
    }
}
138}