// swiftide_indexing/persist/memory_storage.rs

1use std::{
2    collections::HashMap,
3    sync::{
4        Arc,
5        atomic::{AtomicUsize, Ordering},
6    },
7};
8
9use anyhow::Result;
10use async_trait::async_trait;
11use derive_builder::Builder;
12use tokio::sync::RwLock;
13
14use swiftide_core::{
15    Persist,
16    indexing::{Chunk, IndexingStream, Node},
17};
18
19#[derive(Debug, Default, Builder, Clone)]
20#[builder(pattern = "owned")]
21/// A simple in-memory storage implementation.
22///
23/// Great for experimentation and testing.
24///
25/// The storage will use a zero indexed, incremental counter as the key for each node if the node id
26/// is not set.
27pub struct MemoryStorage<T: Chunk = String> {
28    data: Arc<RwLock<HashMap<String, Node<T>>>>,
29    #[builder(default)]
30    batch_size: Option<usize>,
31    #[builder(default = Arc::new(AtomicUsize::new(0)))]
32    node_count: Arc<AtomicUsize>,
33}
34
35impl<T: Chunk> MemoryStorage<T> {
36    fn key(&self) -> String {
37        self.node_count.fetch_add(1, Ordering::Relaxed).to_string()
38    }
39
40    /// Retrieve a node by its key
41    pub async fn get(&self, key: impl AsRef<str>) -> Option<Node<T>> {
42        self.data.read().await.get(key.as_ref()).cloned()
43    }
44
45    /// Retrieve all nodes in the storage
46    pub async fn get_all_values(&self) -> Vec<Node<T>> {
47        self.data.read().await.values().cloned().collect()
48    }
49
50    /// Retrieve all nodes in the storage with their keys
51    pub async fn get_all(&self) -> Vec<(String, Node<T>)> {
52        self.data
53            .read()
54            .await
55            .iter()
56            .map(|(k, v)| (k.clone(), v.clone()))
57            .collect()
58    }
59}
60
61#[async_trait]
62impl<T: Chunk> Persist for MemoryStorage<T> {
63    type Input = T;
64    type Output = T;
65    async fn setup(&self) -> Result<()> {
66        Ok(())
67    }
68
69    /// Store a node by its id
70    ///
71    /// If the node does not have an id, a simple counter is used as the key.
72    async fn store(&self, node: Node<T>) -> Result<Node<T>> {
73        self.data.write().await.insert(self.key(), node.clone());
74
75        Ok(node)
76    }
77
78    /// Store multiple nodes at once
79    ///
80    /// If a node does not have an id, a simple counter is used as the key.
81    async fn batch_store(&self, nodes: Vec<Node<T>>) -> IndexingStream<T> {
82        let mut lock = self.data.write().await;
83
84        for node in &nodes {
85            lock.insert(self.key(), node.clone());
86        }
87
88        IndexingStream::iter(nodes.into_iter().map(Ok))
89    }
90
91    fn batch_size(&self) -> Option<usize> {
92        self.batch_size
93    }
94}
95
#[cfg(test)]
mod test {
    use super::*;
    use futures_util::TryStreamExt;
    use swiftide_core::indexing::TextNode;

    /// A single stored node is retrievable under the first counter key.
    #[tokio::test]
    async fn test_memory_storage() {
        let storage = MemoryStorage::default();
        let stored = storage.store(TextNode::default()).await.unwrap();

        assert_eq!(storage.get("0").await, Some(stored));
    }

    /// Successive `store` calls get consecutive keys.
    #[tokio::test]
    async fn test_inserting_multiple_nodes() {
        let storage = MemoryStorage::default();
        let node1 = TextNode::default();
        let node2 = TextNode::default();

        storage.store(node1.clone()).await.unwrap();
        storage.store(node2.clone()).await.unwrap();

        assert_eq!(storage.get_all().await.len(), 2);
        assert_eq!(storage.get("0").await, Some(node1));
        assert_eq!(storage.get("1").await, Some(node2));
    }

    /// `batch_store` both streams the nodes back and persists them under consecutive keys.
    #[tokio::test]
    async fn test_batch_store() {
        let storage = MemoryStorage::default();
        let node1 = TextNode::default();
        let node2 = TextNode::default();

        let stream = storage
            .batch_store(vec![node1.clone(), node2.clone()])
            .await;

        let result: Vec<TextNode> = stream.try_collect().await.unwrap();

        assert_eq!(result.len(), 2);
        assert_eq!(result[0], node1);
        assert_eq!(result[1], node2);

        // The stream only echoes its input; check the nodes actually landed in storage.
        assert_eq!(storage.get_all_values().await.len(), 2);
        assert_eq!(storage.get("0").await, Some(node1));
        assert_eq!(storage.get("1").await, Some(node2));
    }
}