1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::fmt::Debug;
4use std::sync::RwLock;
5
6use crate::error::StoreError;
7use crate::traits::{StoreLifecycle, VectorStore};
8use crate::types::{MetadataFilter, SearchResult, StoreStats, VectorEntry};
9
10#[derive(Debug)]
12pub struct InMemoryVectorStore {
13 entries: RwLock<Vec<VectorEntry>>,
14 dimensions: usize,
15 closed: RwLock<bool>,
16}
17
18impl InMemoryVectorStore {
19 pub fn new(dimensions: usize) -> Self {
20 Self {
21 entries: RwLock::new(Vec::new()),
22 dimensions,
23 closed: RwLock::new(false),
24 }
25 }
26
27 fn check_closed(&self) -> Result<(), StoreError> {
28 if *self.closed.read().unwrap() {
29 return Err(StoreError::Closed);
30 }
31 Ok(())
32 }
33
34 fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
35 let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
36 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
37 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
38 if norm_a == 0.0 || norm_b == 0.0 {
39 return 0.0;
40 }
41 dot / (norm_a * norm_b)
42 }
43
44 fn matches_filter(entry: &VectorEntry, filter: &MetadataFilter) -> bool {
45 match filter {
46 MetadataFilter::Eq { key, value } => {
47 entry.metadata.get(key).map(|v| v == value).unwrap_or(false)
48 }
49 MetadataFilter::Ne { key, value } => {
50 entry.metadata.get(key).map(|v| v != value).unwrap_or(true)
51 }
52 MetadataFilter::In { key, values } => {
53 entry.metadata.get(key).map(|v| values.contains(v)).unwrap_or(false)
54 }
55 MetadataFilter::NotIn { key, values } => {
56 entry.metadata.get(key).map(|v| !values.contains(v)).unwrap_or(true)
57 }
58 MetadataFilter::Exists { key } => entry.metadata.contains_key(key),
59 MetadataFilter::Contains { key, value } => {
60 entry.metadata.get(key).map(|v| v.contains(value)).unwrap_or(false)
61 }
62 MetadataFilter::Range { key, min, max } => {
63 if let Some(v) = entry.metadata.get(key) {
64 if let Ok(num) = v.parse::<f64>() {
65 return min.map_or(true, |m| num >= m) && max.map_or(true, |m| num <= m);
66 }
67 }
68 false
69 }
70 MetadataFilter::And(filters) => filters.iter().all(|f| Self::matches_filter(entry, f)),
71 MetadataFilter::Or(filters) => filters.iter().any(|f| Self::matches_filter(entry, f)),
72 MetadataFilter::Not(filter) => !Self::matches_filter(entry, filter),
73 }
74 }
75}
76
77#[async_trait]
78impl VectorStore for InMemoryVectorStore {
79 async fn insert(&self, entry: VectorEntry) -> Result<(), StoreError> {
80 self.insert_batch(vec![entry]).await
81 }
82
83 async fn insert_batch(&self, entries: Vec<VectorEntry>) -> Result<(), StoreError> {
84 self.check_closed()?;
85 for entry in &entries {
86 if entry.vector.len() != self.dimensions {
87 return Err(StoreError::DimensionMismatch {
88 expected: self.dimensions,
89 actual: entry.vector.len(),
90 });
91 }
92 }
93 self.entries.write().unwrap().extend(entries);
94 Ok(())
95 }
96
97 async fn search(&self, query: &[f32], limit: usize) -> Result<Vec<SearchResult>, StoreError> {
98 self.check_closed()?;
99 if query.len() != self.dimensions {
100 return Err(StoreError::DimensionMismatch {
101 expected: self.dimensions,
102 actual: query.len(),
103 });
104 }
105
106 let entries = self.entries.read().unwrap();
107 let mut scored: Vec<(SearchResult, f32)> = entries
108 .iter()
109 .map(|entry| {
110 let similarity = Self::cosine_similarity(query, &entry.vector);
111 (
112 SearchResult {
113 id: entry.id.clone(),
114 score: similarity,
115 metadata: entry.metadata.clone(),
116 content: entry.content.clone(),
117 channel: entry.channel.clone(),
118 },
119 similarity,
120 )
121 })
122 .collect();
123
124 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
125 scored.truncate(limit);
126
127 Ok(scored.into_iter().map(|(r, _)| r).collect())
128 }
129
130 async fn search_with_filter(
131 &self,
132 query: &[f32],
133 filter: &MetadataFilter,
134 limit: usize,
135 ) -> Result<Vec<SearchResult>, StoreError> {
136 self.check_closed()?;
137 if query.len() != self.dimensions {
138 return Err(StoreError::DimensionMismatch {
139 expected: self.dimensions,
140 actual: query.len(),
141 });
142 }
143
144 let entries = self.entries.read().unwrap();
145 let mut scored: Vec<(SearchResult, f32)> = entries
146 .iter()
147 .filter(|entry| Self::matches_filter(entry, filter))
148 .map(|entry| {
149 let similarity = Self::cosine_similarity(query, &entry.vector);
150 (
151 SearchResult {
152 id: entry.id.clone(),
153 score: similarity,
154 metadata: entry.metadata.clone(),
155 content: entry.content.clone(),
156 channel: entry.channel.clone(),
157 },
158 similarity,
159 )
160 })
161 .collect();
162
163 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
164 scored.truncate(limit);
165
166 Ok(scored.into_iter().map(|(r, _)| r).collect())
167 }
168
169 async fn delete(&self, ids: &[String]) -> Result<usize, StoreError> {
170 self.check_closed()?;
171 let mut entries = self.entries.write().unwrap();
172 let before = entries.len();
173 entries.retain(|e| !ids.contains(&e.id));
174 Ok(before - entries.len())
175 }
176
177 async fn delete_by_filter(&self, filter: &MetadataFilter) -> Result<usize, StoreError> {
178 self.check_closed()?;
179 let mut entries = self.entries.write().unwrap();
180 let before = entries.len();
181 entries.retain(|e| !Self::matches_filter(e, filter));
182 Ok(before - entries.len())
183 }
184
185 async fn clear(&self) -> Result<(), StoreError> {
186 self.check_closed()?;
187 self.entries.write().unwrap().clear();
188 Ok(())
189 }
190
191 async fn count(&self) -> Result<usize, StoreError> {
192 self.check_closed()?;
193 Ok(self.entries.read().unwrap().len())
194 }
195
196 async fn rebuild_index(&self) -> Result<(), StoreError> {
197 Ok(())
198 }
199
200 async fn stats(&self) -> Result<StoreStats, StoreError> {
201 self.check_closed()?;
202 let count = self.entries.read().unwrap().len();
203 Ok(StoreStats {
204 total_vectors: count,
205 total_dimensions: self.dimensions,
206 index_size_bytes: 0,
207 data_size_bytes: 0,
208 last_indexed_at: None,
209 })
210 }
211}
212
213#[async_trait]
214impl StoreLifecycle for InMemoryVectorStore {
215 async fn initialize(&self) -> Result<(), StoreError> {
216 Ok(())
217 }
218
219 async fn close(&self) -> Result<(), StoreError> {
220 let mut closed = self.closed.write().unwrap();
221 *closed = true;
222 Ok(())
223 }
224
225 async fn checkpoint(&self) -> Result<(), StoreError> {
226 Ok(())
227 }
228
229 async fn health_check(&self) -> Result<bool, StoreError> {
230 Ok(!*self.closed.read().unwrap())
231 }
232}