1use crate::error::Result;
4use crate::index::flat::FlatIndex;
5
6#[cfg(feature = "hnsw")]
7use crate::index::hnsw::HnswIndex;
8
9use crate::index::VectorIndex;
10use crate::types::*;
11use parking_lot::RwLock;
12use std::sync::Arc;
13
14#[cfg(feature = "storage")]
16use crate::storage::VectorStorage;
17
18#[cfg(not(feature = "storage"))]
19use crate::storage_memory::MemoryStorage as VectorStorage;
20
21pub struct VectorDB {
23 storage: Arc<VectorStorage>,
24 index: Arc<RwLock<Box<dyn VectorIndex>>>,
25 options: DbOptions,
26}
27
28impl VectorDB {
29 #[allow(unused_mut)] pub fn new(mut options: DbOptions) -> Result<Self> {
37 #[cfg(feature = "storage")]
38 let storage = {
39 let temp_storage = VectorStorage::new(&options.storage_path, options.dimensions)?;
42
43 let stored_config = temp_storage.load_config()?;
44
45 if let Some(config) = stored_config {
46 tracing::info!(
48 "Loading existing database with {} dimensions",
49 config.dimensions
50 );
51 options = DbOptions {
52 storage_path: options.storage_path.clone(),
54 dimensions: config.dimensions,
56 distance_metric: config.distance_metric,
57 hnsw_config: config.hnsw_config,
58 quantization: config.quantization,
59 };
60 Arc::new(VectorStorage::new(
62 &options.storage_path,
63 options.dimensions,
64 )?)
65 } else {
66 tracing::info!(
68 "Creating new database with {} dimensions",
69 options.dimensions
70 );
71 temp_storage.save_config(&options)?;
72 Arc::new(temp_storage)
73 }
74 };
75
76 #[cfg(not(feature = "storage"))]
77 let storage = Arc::new(VectorStorage::new(options.dimensions)?);
78
79 #[allow(unused_mut)] let mut index: Box<dyn VectorIndex> = if let Some(hnsw_config) = &options.hnsw_config {
82 #[cfg(feature = "hnsw")]
83 {
84 Box::new(HnswIndex::new(
85 options.dimensions,
86 options.distance_metric,
87 hnsw_config.clone(),
88 )?)
89 }
90 #[cfg(not(feature = "hnsw"))]
91 {
92 tracing::warn!("HNSW requested but not available (WASM build), using flat index");
94 Box::new(FlatIndex::new(options.dimensions, options.distance_metric))
95 }
96 } else {
97 Box::new(FlatIndex::new(options.dimensions, options.distance_metric))
98 };
99
100 #[cfg(feature = "storage")]
103 {
104 let stored_ids = storage.all_ids()?;
105 if !stored_ids.is_empty() {
106 tracing::info!(
107 "Rebuilding index from {} persisted vectors",
108 stored_ids.len()
109 );
110
111 let mut entries = Vec::with_capacity(stored_ids.len());
113 for id in stored_ids {
114 if let Some(entry) = storage.get(&id)? {
115 entries.push((id, entry.vector));
116 }
117 }
118
119 index.add_batch(entries)?;
121
122 tracing::info!("Index rebuilt successfully");
123 }
124 }
125
126 Ok(Self {
127 storage,
128 index: Arc::new(RwLock::new(index)),
129 options,
130 })
131 }
132
133 pub fn with_dimensions(dimensions: usize) -> Result<Self> {
135 let options = DbOptions {
136 dimensions,
137 ..DbOptions::default()
138 };
139 Self::new(options)
140 }
141
142 pub fn insert(&self, entry: VectorEntry) -> Result<VectorId> {
144 let id = self.storage.insert(&entry)?;
145
146 let mut index = self.index.write();
148 index.add(id.clone(), entry.vector)?;
149
150 Ok(id)
151 }
152
153 pub fn insert_batch(&self, entries: Vec<VectorEntry>) -> Result<Vec<VectorId>> {
155 let ids = self.storage.insert_batch(&entries)?;
156
157 let mut index = self.index.write();
159 let index_entries: Vec<_> = ids
160 .iter()
161 .zip(entries.iter())
162 .map(|(id, entry)| (id.clone(), entry.vector.clone()))
163 .collect();
164
165 index.add_batch(index_entries)?;
166
167 Ok(ids)
168 }
169
170 pub fn search(&self, query: SearchQuery) -> Result<Vec<SearchResult>> {
172 let index = self.index.read();
173 let mut results = index.search(&query.vector, query.k)?;
174
175 for result in &mut results {
177 if let Ok(Some(entry)) = self.storage.get(&result.id) {
178 result.vector = Some(entry.vector);
179 result.metadata = entry.metadata;
180 }
181 }
182
183 if let Some(filter) = &query.filter {
185 results.retain(|r| {
186 if let Some(metadata) = &r.metadata {
187 filter
188 .iter()
189 .all(|(key, value)| metadata.get(key).is_some_and(|v| v == value))
190 } else {
191 false
192 }
193 });
194 }
195
196 Ok(results)
197 }
198
199 pub fn delete(&self, id: &str) -> Result<bool> {
201 let deleted_storage = self.storage.delete(id)?;
202
203 if deleted_storage {
204 let mut index = self.index.write();
205 let _ = index.remove(&id.to_string())?;
206 }
207
208 Ok(deleted_storage)
209 }
210
211 pub fn get(&self, id: &str) -> Result<Option<VectorEntry>> {
213 self.storage.get(id)
214 }
215
216 pub fn len(&self) -> Result<usize> {
218 self.storage.len()
219 }
220
221 pub fn is_empty(&self) -> Result<bool> {
223 self.storage.is_empty()
224 }
225
226 pub fn options(&self) -> &DbOptions {
228 &self.options
229 }
230
231 pub fn keys(&self) -> Result<Vec<String>> {
233 self.storage.all_ids()
234 }
235}
236
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Options for a 3-dimensional, Euclidean, flat-index database stored at
    /// `storage_path`.
    fn flat_euclidean_options(storage_path: String) -> DbOptions {
        DbOptions {
            storage_path,
            dimensions: 3,
            distance_metric: DistanceMetric::Euclidean,
            hnsw_config: None,
            ..DbOptions::default()
        }
    }

    /// A freshly created database reports itself as empty.
    #[test]
    fn test_vector_db_creation() -> Result<()> {
        let dir = tempdir().unwrap();
        let options = DbOptions {
            storage_path: dir.path().join("test.db").to_string_lossy().to_string(),
            dimensions: 3,
            ..DbOptions::default()
        };

        let db = VectorDB::new(options)?;
        assert!(db.is_empty()?);

        Ok(())
    }

    /// Inserted vectors are searchable and the exact match ranks first.
    #[test]
    fn test_insert_and_search() -> Result<()> {
        let dir = tempdir().unwrap();
        let storage_path = dir.path().join("test.db").to_string_lossy().to_string();
        let db = VectorDB::new(flat_euclidean_options(storage_path))?;

        for (id, vector) in [
            ("v1", vec![1.0, 0.0, 0.0]),
            ("v2", vec![0.0, 1.0, 0.0]),
            ("v3", vec![0.0, 0.0, 1.0]),
        ] {
            db.insert(VectorEntry {
                id: Some(id.to_string()),
                vector,
                metadata: None,
            })?;
        }

        let results = db.search(SearchQuery {
            vector: vec![1.0, 0.0, 0.0],
            k: 2,
            filter: None,
            ef_search: None,
        })?;

        assert!(!results.is_empty());
        assert_eq!(results[0].id, "v1", "First result should be exact match");
        assert!(
            results[0].score < 0.01,
            "Exact match should have ~0 distance"
        );

        Ok(())
    }

    /// Reopening a persisted database rebuilds the index so that search
    /// keeps returning the stored vectors.
    #[test]
    #[cfg(feature = "storage")]
    fn test_search_after_restart() -> Result<()> {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("persist.db").to_string_lossy().to_string();

        // First session: populate the database, then drop it.
        {
            let db = VectorDB::new(flat_euclidean_options(db_path.clone()))?;

            for (id, vector) in [
                ("v1", vec![1.0, 0.0, 0.0]),
                ("v2", vec![0.0, 1.0, 0.0]),
                ("v3", vec![0.7, 0.7, 0.0]),
            ] {
                db.insert(VectorEntry {
                    id: Some(id.to_string()),
                    vector,
                    metadata: None,
                })?;
            }

            let results = db.search(SearchQuery {
                vector: vec![0.8, 0.6, 0.0],
                k: 3,
                filter: None,
                ef_search: None,
            })?;
            assert_eq!(results.len(), 3, "Should find all 3 vectors before restart");
        }

        // Second session: reopen the same path and query again.
        {
            let db = VectorDB::new(flat_euclidean_options(db_path))?;

            assert_eq!(db.len()?, 3, "Should have 3 vectors after restart");

            let v1 = db.get("v1")?;
            assert!(v1.is_some(), "get() should work after restart");

            let results = db.search(SearchQuery {
                vector: vec![0.8, 0.6, 0.0],
                k: 3,
                filter: None,
                ef_search: None,
            })?;

            assert_eq!(
                results.len(),
                3,
                "search() should return results after restart (was returning 0 before fix)"
            );

            assert_eq!(
                results[0].id, "v3",
                "v3 [0.7, 0.7, 0.0] should be closest to query [0.8, 0.6, 0.0]"
            );
        }

        Ok(())
    }
}