Skip to main content

xz_embed/store/
memory.rs

use std::cmp::Ordering;
use std::collections::HashMap;
use std::collections::HashSet;
use std::fmt::Debug;
use std::sync::RwLock;

use async_trait::async_trait;

use crate::error::StoreError;
use crate::traits::{StoreLifecycle, VectorStore};
use crate::types::{MetadataFilter, SearchResult, StoreStats, VectorEntry};
9
10/// 内存向量存储(测试用)
11#[derive(Debug)]
12pub struct InMemoryVectorStore {
13    entries: RwLock<Vec<VectorEntry>>,
14    dimensions: usize,
15    closed: RwLock<bool>,
16}
17
18impl InMemoryVectorStore {
19    pub fn new(dimensions: usize) -> Self {
20        Self {
21            entries: RwLock::new(Vec::new()),
22            dimensions,
23            closed: RwLock::new(false),
24        }
25    }
26
27    fn check_closed(&self) -> Result<(), StoreError> {
28        if *self.closed.read().unwrap() {
29            return Err(StoreError::Closed);
30        }
31        Ok(())
32    }
33
34    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
35        let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
36        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
37        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
38        if norm_a == 0.0 || norm_b == 0.0 {
39            return 0.0;
40        }
41        dot / (norm_a * norm_b)
42    }
43
44    fn matches_filter(entry: &VectorEntry, filter: &MetadataFilter) -> bool {
45        match filter {
46            MetadataFilter::Eq { key, value } => {
47                entry.metadata.get(key).map(|v| v == value).unwrap_or(false)
48            }
49            MetadataFilter::Ne { key, value } => {
50                entry.metadata.get(key).map(|v| v != value).unwrap_or(true)
51            }
52            MetadataFilter::In { key, values } => {
53                entry.metadata.get(key).map(|v| values.contains(v)).unwrap_or(false)
54            }
55            MetadataFilter::NotIn { key, values } => {
56                entry.metadata.get(key).map(|v| !values.contains(v)).unwrap_or(true)
57            }
58            MetadataFilter::Exists { key } => entry.metadata.contains_key(key),
59            MetadataFilter::Contains { key, value } => {
60                entry.metadata.get(key).map(|v| v.contains(value)).unwrap_or(false)
61            }
62            MetadataFilter::Range { key, min, max } => {
63                if let Some(v) = entry.metadata.get(key) {
64                    if let Ok(num) = v.parse::<f64>() {
65                        return min.map_or(true, |m| num >= m) && max.map_or(true, |m| num <= m);
66                    }
67                }
68                false
69            }
70            MetadataFilter::And(filters) => filters.iter().all(|f| Self::matches_filter(entry, f)),
71            MetadataFilter::Or(filters) => filters.iter().any(|f| Self::matches_filter(entry, f)),
72            MetadataFilter::Not(filter) => !Self::matches_filter(entry, filter),
73        }
74    }
75}
76
77#[async_trait]
78impl VectorStore for InMemoryVectorStore {
79    async fn insert(&self, entry: VectorEntry) -> Result<(), StoreError> {
80        self.insert_batch(vec![entry]).await
81    }
82
83    async fn insert_batch(&self, entries: Vec<VectorEntry>) -> Result<(), StoreError> {
84        self.check_closed()?;
85        for entry in &entries {
86            if entry.vector.len() != self.dimensions {
87                return Err(StoreError::DimensionMismatch {
88                    expected: self.dimensions,
89                    actual: entry.vector.len(),
90                });
91            }
92        }
93        self.entries.write().unwrap().extend(entries);
94        Ok(())
95    }
96
97    async fn search(&self, query: &[f32], limit: usize) -> Result<Vec<SearchResult>, StoreError> {
98        self.check_closed()?;
99        if query.len() != self.dimensions {
100            return Err(StoreError::DimensionMismatch {
101                expected: self.dimensions,
102                actual: query.len(),
103            });
104        }
105
106        let entries = self.entries.read().unwrap();
107        let mut scored: Vec<(SearchResult, f32)> = entries
108            .iter()
109            .map(|entry| {
110                let similarity = Self::cosine_similarity(query, &entry.vector);
111                (
112                    SearchResult {
113                        id: entry.id.clone(),
114                        score: similarity,
115                        metadata: entry.metadata.clone(),
116                        content: entry.content.clone(),
117                        channel: entry.channel.clone(),
118                    },
119                    similarity,
120                )
121            })
122            .collect();
123
124        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
125        scored.truncate(limit);
126
127        Ok(scored.into_iter().map(|(r, _)| r).collect())
128    }
129
130    async fn search_with_filter(
131        &self,
132        query: &[f32],
133        filter: &MetadataFilter,
134        limit: usize,
135    ) -> Result<Vec<SearchResult>, StoreError> {
136        self.check_closed()?;
137        if query.len() != self.dimensions {
138            return Err(StoreError::DimensionMismatch {
139                expected: self.dimensions,
140                actual: query.len(),
141            });
142        }
143
144        let entries = self.entries.read().unwrap();
145        let mut scored: Vec<(SearchResult, f32)> = entries
146            .iter()
147            .filter(|entry| Self::matches_filter(entry, filter))
148            .map(|entry| {
149                let similarity = Self::cosine_similarity(query, &entry.vector);
150                (
151                    SearchResult {
152                        id: entry.id.clone(),
153                        score: similarity,
154                        metadata: entry.metadata.clone(),
155                        content: entry.content.clone(),
156                        channel: entry.channel.clone(),
157                    },
158                    similarity,
159                )
160            })
161            .collect();
162
163        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
164        scored.truncate(limit);
165
166        Ok(scored.into_iter().map(|(r, _)| r).collect())
167    }
168
169    async fn delete(&self, ids: &[String]) -> Result<usize, StoreError> {
170        self.check_closed()?;
171        let mut entries = self.entries.write().unwrap();
172        let before = entries.len();
173        entries.retain(|e| !ids.contains(&e.id));
174        Ok(before - entries.len())
175    }
176
177    async fn delete_by_filter(&self, filter: &MetadataFilter) -> Result<usize, StoreError> {
178        self.check_closed()?;
179        let mut entries = self.entries.write().unwrap();
180        let before = entries.len();
181        entries.retain(|e| !Self::matches_filter(e, filter));
182        Ok(before - entries.len())
183    }
184
185    async fn clear(&self) -> Result<(), StoreError> {
186        self.check_closed()?;
187        self.entries.write().unwrap().clear();
188        Ok(())
189    }
190
191    async fn count(&self) -> Result<usize, StoreError> {
192        self.check_closed()?;
193        Ok(self.entries.read().unwrap().len())
194    }
195
196    async fn rebuild_index(&self) -> Result<(), StoreError> {
197        Ok(())
198    }
199
200    async fn stats(&self) -> Result<StoreStats, StoreError> {
201        self.check_closed()?;
202        let count = self.entries.read().unwrap().len();
203        Ok(StoreStats {
204            total_vectors: count,
205            total_dimensions: self.dimensions,
206            index_size_bytes: 0,
207            data_size_bytes: 0,
208            last_indexed_at: None,
209        })
210    }
211}
212
213#[async_trait]
214impl StoreLifecycle for InMemoryVectorStore {
215    async fn initialize(&self) -> Result<(), StoreError> {
216        Ok(())
217    }
218
219    async fn close(&self) -> Result<(), StoreError> {
220        let mut closed = self.closed.write().unwrap();
221        *closed = true;
222        Ok(())
223    }
224
225    async fn checkpoint(&self) -> Result<(), StoreError> {
226        Ok(())
227    }
228
229    async fn health_check(&self) -> Result<bool, StoreError> {
230        Ok(!*self.closed.read().unwrap())
231    }
232}