swiftide_indexing/persist/
memory_storage.rs1use std::{
2 collections::HashMap,
3 sync::{
4 Arc,
5 atomic::{AtomicUsize, Ordering},
6 },
7};
8
9use anyhow::Result;
10use async_trait::async_trait;
11use derive_builder::Builder;
12use tokio::sync::RwLock;
13
14use swiftide_core::{
15 Persist,
16 indexing::{IndexingStream, Node},
17};
18
19#[derive(Debug, Default, Builder, Clone)]
20#[builder(pattern = "owned")]
21pub struct MemoryStorage {
28 data: Arc<RwLock<HashMap<String, Node>>>,
29 #[builder(default)]
30 batch_size: Option<usize>,
31 #[builder(default = Arc::new(AtomicUsize::new(0)))]
32 node_count: Arc<AtomicUsize>,
33}
34
35impl MemoryStorage {
36 fn key(&self) -> String {
37 self.node_count.fetch_add(1, Ordering::Relaxed).to_string()
38 }
39
40 pub async fn get(&self, key: impl AsRef<str>) -> Option<Node> {
42 self.data.read().await.get(key.as_ref()).cloned()
43 }
44
45 pub async fn get_all_values(&self) -> Vec<Node> {
47 self.data.read().await.values().cloned().collect()
48 }
49
50 pub async fn get_all(&self) -> Vec<(String, Node)> {
52 self.data
53 .read()
54 .await
55 .iter()
56 .map(|(k, v)| (k.clone(), v.clone()))
57 .collect()
58 }
59}
60
61#[async_trait]
62impl Persist for MemoryStorage {
63 async fn setup(&self) -> Result<()> {
64 Ok(())
65 }
66
67 async fn store(&self, node: Node) -> Result<Node> {
71 self.data.write().await.insert(self.key(), node.clone());
72
73 Ok(node)
74 }
75
76 async fn batch_store(&self, nodes: Vec<Node>) -> IndexingStream {
80 let mut lock = self.data.write().await;
81
82 for node in &nodes {
83 lock.insert(self.key(), node.clone());
84 }
85
86 IndexingStream::iter(nodes.into_iter().map(Ok))
87 }
88
89 fn batch_size(&self) -> Option<usize> {
90 self.batch_size
91 }
92}
93
#[cfg(test)]
mod test {
    use super::*;
    use futures_util::TryStreamExt;
    use swiftide_core::indexing::Node;

    /// A single stored node is retrievable under the first key, `"0"`.
    #[tokio::test]
    async fn test_memory_storage() {
        let storage = MemoryStorage::default();
        let node = Node::default();
        let node = storage.store(node.clone()).await.unwrap();
        assert_eq!(storage.get("0").await, Some(node));
    }

    /// Consecutive stores receive consecutive keys starting at `"0"`.
    #[tokio::test]
    async fn test_inserting_multiple_nodes() {
        let storage = MemoryStorage::default();
        let node1 = Node::default();
        let node2 = Node::default();

        storage.store(node1.clone()).await.unwrap();
        storage.store(node2.clone()).await.unwrap();

        assert_eq!(storage.get("0").await, Some(node1));
        assert_eq!(storage.get("1").await, Some(node2));
    }

    /// Batch storing yields the nodes back in order and persists each of them.
    #[tokio::test]
    async fn test_batch_store() {
        let storage = MemoryStorage::default();
        let node1 = Node::default();
        let node2 = Node::default();

        let stream = storage
            .batch_store(vec![node1.clone(), node2.clone()])
            .await;

        let result: Vec<Node> = stream.try_collect().await.unwrap();

        assert_eq!(result.len(), 2);
        assert_eq!(result[0], node1);
        assert_eq!(result[1], node2);

        // The returned stream alone does not prove the nodes were written to the map.
        assert_eq!(storage.get_all_values().await.len(), 2);
        assert_eq!(storage.get("0").await, Some(node1));
        assert_eq!(storage.get("1").await, Some(node2));
    }
}