1use crate::error::Result;
4use crate::index::flat::FlatIndex;
5
6#[cfg(feature = "hnsw")]
7use crate::index::hnsw::HnswIndex;
8
9use crate::index::VectorIndex;
10use crate::types::*;
11use parking_lot::RwLock;
12use std::sync::Arc;
13
14#[cfg(feature = "storage")]
16use crate::storage::VectorStorage;
17
18#[cfg(not(feature = "storage"))]
19use crate::storage_memory::MemoryStorage as VectorStorage;
20
/// Vector database facade that pairs a storage backend with a search index.
///
/// Entries are written to `storage` first and then mirrored into `index`;
/// searches consult `index` and read vectors and metadata back from `storage`.
pub struct VectorDB {
    /// Storage backend. `VectorStorage` resolves to `crate::storage::VectorStorage`
    /// when the `storage` feature is enabled and to
    /// `crate::storage_memory::MemoryStorage` otherwise.
    storage: Arc<VectorStorage>,
    /// Search index behind a read/write lock; the concrete implementation is
    /// selected in `VectorDB::new`.
    index: Arc<RwLock<Box<dyn VectorIndex>>>,
    /// Options this database was created with, exposed through `options()`.
    options: DbOptions,
}
27
28impl VectorDB {
29 pub fn new(options: DbOptions) -> Result<Self> {
34 #[cfg(feature = "storage")]
35 let storage = Arc::new(VectorStorage::new(
36 &options.storage_path,
37 options.dimensions,
38 )?);
39
40 #[cfg(not(feature = "storage"))]
41 let storage = Arc::new(VectorStorage::new(options.dimensions)?);
42
43 let mut index: Box<dyn VectorIndex> = if let Some(hnsw_config) = &options.hnsw_config {
45 #[cfg(feature = "hnsw")]
46 {
47 Box::new(HnswIndex::new(
48 options.dimensions,
49 options.distance_metric,
50 hnsw_config.clone(),
51 )?)
52 }
53 #[cfg(not(feature = "hnsw"))]
54 {
55 tracing::warn!("HNSW requested but not available (WASM build), using flat index");
57 Box::new(FlatIndex::new(options.dimensions, options.distance_metric))
58 }
59 } else {
60 Box::new(FlatIndex::new(options.dimensions, options.distance_metric))
61 };
62
63 #[cfg(feature = "storage")]
66 {
67 let stored_ids = storage.all_ids()?;
68 if !stored_ids.is_empty() {
69 tracing::info!(
70 "Rebuilding index from {} persisted vectors",
71 stored_ids.len()
72 );
73
74 let mut entries = Vec::with_capacity(stored_ids.len());
76 for id in stored_ids {
77 if let Some(entry) = storage.get(&id)? {
78 entries.push((id, entry.vector));
79 }
80 }
81
82 index.add_batch(entries)?;
84
85 tracing::info!("Index rebuilt successfully");
86 }
87 }
88
89 Ok(Self {
90 storage,
91 index: Arc::new(RwLock::new(index)),
92 options,
93 })
94 }
95
96 pub fn with_dimensions(dimensions: usize) -> Result<Self> {
98 let mut options = DbOptions::default();
99 options.dimensions = dimensions;
100 Self::new(options)
101 }
102
103 pub fn insert(&self, entry: VectorEntry) -> Result<VectorId> {
105 let id = self.storage.insert(&entry)?;
106
107 let mut index = self.index.write();
109 index.add(id.clone(), entry.vector)?;
110
111 Ok(id)
112 }
113
114 pub fn insert_batch(&self, entries: Vec<VectorEntry>) -> Result<Vec<VectorId>> {
116 let ids = self.storage.insert_batch(&entries)?;
117
118 let mut index = self.index.write();
120 let index_entries: Vec<_> = ids
121 .iter()
122 .zip(entries.iter())
123 .map(|(id, entry)| (id.clone(), entry.vector.clone()))
124 .collect();
125
126 index.add_batch(index_entries)?;
127
128 Ok(ids)
129 }
130
131 pub fn search(&self, query: SearchQuery) -> Result<Vec<SearchResult>> {
133 let index = self.index.read();
134 let mut results = index.search(&query.vector, query.k)?;
135
136 for result in &mut results {
138 if let Ok(Some(entry)) = self.storage.get(&result.id) {
139 result.vector = Some(entry.vector);
140 result.metadata = entry.metadata;
141 }
142 }
143
144 if let Some(filter) = &query.filter {
146 results.retain(|r| {
147 if let Some(metadata) = &r.metadata {
148 filter
149 .iter()
150 .all(|(key, value)| metadata.get(key).map_or(false, |v| v == value))
151 } else {
152 false
153 }
154 });
155 }
156
157 Ok(results)
158 }
159
160 pub fn delete(&self, id: &str) -> Result<bool> {
162 let deleted_storage = self.storage.delete(id)?;
163
164 if deleted_storage {
165 let mut index = self.index.write();
166 let _ = index.remove(&id.to_string())?;
167 }
168
169 Ok(deleted_storage)
170 }
171
172 pub fn get(&self, id: &str) -> Result<Option<VectorEntry>> {
174 self.storage.get(id)
175 }
176
177 pub fn len(&self) -> Result<usize> {
179 self.storage.len()
180 }
181
182 pub fn is_empty(&self) -> Result<bool> {
184 self.storage.is_empty()
185 }
186
187 pub fn options(&self) -> &DbOptions {
189 &self.options
190 }
191}
192
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Builds options for a 3-dimensional, Euclidean database at
    /// `storage_path` with no HNSW configuration, so `VectorDB::new` picks the
    /// flat index.
    fn flat_options(storage_path: String) -> DbOptions {
        let mut options = DbOptions::default();
        options.storage_path = storage_path;
        options.dimensions = 3;
        options.distance_metric = DistanceMetric::Euclidean;
        options.hnsw_config = None;
        options
    }

    /// A freshly created database reports itself as empty.
    #[test]
    fn test_vector_db_creation() -> Result<()> {
        let dir = tempdir().unwrap();
        let mut options = DbOptions::default();
        options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
        options.dimensions = 3;

        let db = VectorDB::new(options)?;
        assert!(db.is_empty()?);

        Ok(())
    }

    /// Inserting three unit vectors and querying with one of them returns that
    /// vector first with a near-zero distance.
    #[test]
    fn test_insert_and_search() -> Result<()> {
        let dir = tempdir().unwrap();
        let db = VectorDB::new(flat_options(
            dir.path().join("test.db").to_string_lossy().to_string(),
        ))?;

        db.insert(VectorEntry {
            id: Some("v1".to_string()),
            vector: vec![1.0, 0.0, 0.0],
            metadata: None,
        })?;

        db.insert(VectorEntry {
            id: Some("v2".to_string()),
            vector: vec![0.0, 1.0, 0.0],
            metadata: None,
        })?;

        db.insert(VectorEntry {
            id: Some("v3".to_string()),
            vector: vec![0.0, 0.0, 1.0],
            metadata: None,
        })?;

        let results = db.search(SearchQuery {
            vector: vec![1.0, 0.0, 0.0],
            k: 2,
            filter: None,
            ef_search: None,
        })?;

        assert!(!results.is_empty());
        assert_eq!(results[0].id, "v1", "First result should be exact match");
        assert!(
            results[0].score < 0.01,
            "Exact match should have ~0 distance"
        );

        Ok(())
    }

    /// Vectors written by one `VectorDB` instance are searchable from a second
    /// instance opened on the same storage path.
    #[test]
    #[cfg(feature = "storage")]
    fn test_search_after_restart() -> Result<()> {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("persist.db").to_string_lossy().to_string();

        // First instance: populate the database, then drop it at the end of the scope.
        {
            let db = VectorDB::new(flat_options(db_path.clone()))?;

            db.insert(VectorEntry {
                id: Some("v1".to_string()),
                vector: vec![1.0, 0.0, 0.0],
                metadata: None,
            })?;

            db.insert(VectorEntry {
                id: Some("v2".to_string()),
                vector: vec![0.0, 1.0, 0.0],
                metadata: None,
            })?;

            db.insert(VectorEntry {
                id: Some("v3".to_string()),
                vector: vec![0.7, 0.7, 0.0],
                metadata: None,
            })?;

            let results = db.search(SearchQuery {
                vector: vec![0.8, 0.6, 0.0],
                k: 3,
                filter: None,
                ef_search: None,
            })?;
            assert_eq!(results.len(), 3, "Should find all 3 vectors before restart");
        }

        // Second instance: reopen the same path and check that lookups and search still work.
        {
            let db = VectorDB::new(flat_options(db_path.clone()))?;

            assert_eq!(db.len()?, 3, "Should have 3 vectors after restart");

            let v1 = db.get("v1")?;
            assert!(v1.is_some(), "get() should work after restart");

            let results = db.search(SearchQuery {
                vector: vec![0.8, 0.6, 0.0],
                k: 3,
                filter: None,
                ef_search: None,
            })?;

            assert_eq!(
                results.len(),
                3,
                "search() should return results after restart (was returning 0 before fix)"
            );

            assert_eq!(
                results[0].id, "v3",
                "v3 [0.7, 0.7, 0.0] should be closest to query [0.8, 0.6, 0.0]"
            );
        }

        Ok(())
    }
}