swiftide_indexing/persist/
memory_storage.rs1use std::{
2 collections::HashMap,
3 sync::{
4 Arc,
5 atomic::{AtomicUsize, Ordering},
6 },
7};
8
9use anyhow::Result;
10use async_trait::async_trait;
11use derive_builder::Builder;
12use tokio::sync::RwLock;
13
14use swiftide_core::{
15 Persist,
16 indexing::{Chunk, IndexingStream, Node},
17};
18
19#[derive(Debug, Default, Builder, Clone)]
20#[builder(pattern = "owned")]
21pub struct MemoryStorage<T: Chunk = String> {
28 data: Arc<RwLock<HashMap<String, Node<T>>>>,
29 #[builder(default)]
30 batch_size: Option<usize>,
31 #[builder(default = Arc::new(AtomicUsize::new(0)))]
32 node_count: Arc<AtomicUsize>,
33}
34
35impl<T: Chunk> MemoryStorage<T> {
36 fn key(&self) -> String {
37 self.node_count.fetch_add(1, Ordering::Relaxed).to_string()
38 }
39
40 pub async fn get(&self, key: impl AsRef<str>) -> Option<Node<T>> {
42 self.data.read().await.get(key.as_ref()).cloned()
43 }
44
45 pub async fn get_all_values(&self) -> Vec<Node<T>> {
47 self.data.read().await.values().cloned().collect()
48 }
49
50 pub async fn get_all(&self) -> Vec<(String, Node<T>)> {
52 self.data
53 .read()
54 .await
55 .iter()
56 .map(|(k, v)| (k.clone(), v.clone()))
57 .collect()
58 }
59}
60
61#[async_trait]
62impl<T: Chunk> Persist for MemoryStorage<T> {
63 type Input = T;
64 type Output = T;
65 async fn setup(&self) -> Result<()> {
66 Ok(())
67 }
68
69 async fn store(&self, node: Node<T>) -> Result<Node<T>> {
73 self.data.write().await.insert(self.key(), node.clone());
74
75 Ok(node)
76 }
77
78 async fn batch_store(&self, nodes: Vec<Node<T>>) -> IndexingStream<T> {
82 let mut lock = self.data.write().await;
83
84 for node in &nodes {
85 lock.insert(self.key(), node.clone());
86 }
87
88 IndexingStream::iter(nodes.into_iter().map(Ok))
89 }
90
91 fn batch_size(&self) -> Option<usize> {
92 self.batch_size
93 }
94}
95
#[cfg(test)]
mod test {
    use super::*;
    use futures_util::TryStreamExt;
    use swiftide_core::indexing::TextNode;

    /// A single stored node is retrievable under the first counter key, `"0"`.
    #[tokio::test]
    async fn test_memory_storage() {
        let storage = MemoryStorage::default();
        let node = storage.store(TextNode::default()).await.unwrap();

        assert_eq!(storage.get("0").await, Some(node));
    }

    /// Successive `store` calls receive increasing keys and do not overwrite each other.
    #[tokio::test]
    async fn test_inserting_multiple_nodes() {
        let storage = MemoryStorage::default();
        let node1 = TextNode::default();
        let node2 = TextNode::default();

        storage.store(node1.clone()).await.unwrap();
        storage.store(node2.clone()).await.unwrap();

        assert_eq!(storage.get_all().await.len(), 2);
        assert_eq!(storage.get("0").await, Some(node1));
        assert_eq!(storage.get("1").await, Some(node2));
    }

    /// `batch_store` yields the input nodes in order and actually persists them.
    #[tokio::test]
    async fn test_batch_store() {
        let storage = MemoryStorage::default();
        let node1 = TextNode::default();
        let node2 = TextNode::default();

        let stream = storage
            .batch_store(vec![node1.clone(), node2.clone()])
            .await;

        let result: Vec<TextNode> = stream.try_collect().await.unwrap();

        assert_eq!(result, vec![node1.clone(), node2.clone()]);
        assert_eq!(storage.get_all_values().await.len(), 2);
        assert_eq!(storage.get("0").await, Some(node1));
        assert_eq!(storage.get("1").await, Some(node2));
    }

    /// Clones are handles to the same map and key counter, not independent copies.
    #[tokio::test]
    async fn test_clones_share_storage() {
        let storage = MemoryStorage::default();
        let handle = storage.clone();

        let first = handle.store(TextNode::default()).await.unwrap();
        let second = storage.store(TextNode::default()).await.unwrap();

        assert_eq!(storage.get("0").await, Some(first));
        assert_eq!(handle.get("1").await, Some(second));
    }
}