1use crate::error::Result;
4use crate::index::flat::FlatIndex;
5
6#[cfg(feature = "hnsw")]
7use crate::index::hnsw::HnswIndex;
8
9use crate::index::VectorIndex;
10use crate::types::*;
11use parking_lot::RwLock;
12use std::sync::Arc;
13
14#[cfg(feature = "storage")]
16use crate::storage::VectorStorage;
17
18#[cfg(not(feature = "storage"))]
19use crate::storage_memory::MemoryStorage as VectorStorage;
20
/// A vector database that pairs a storage backend with a vector search index.
///
/// With the `storage` feature the backend is `crate::storage::VectorStorage`;
/// without it, `crate::storage_memory::MemoryStorage` is used under the same name.
pub struct VectorDB {
    /// Backend holding the entries; `get`, `len`, `is_empty` and `keys` read from here.
    storage: Arc<VectorStorage>,
    /// Search index (flat or HNSW), updated on insert/delete and, with the
    /// `storage` feature, rebuilt from persisted vectors when the database is opened.
    index: Arc<RwLock<Box<dyn VectorIndex>>>,
    /// Effective options; for an existing on-disk database these come from the
    /// persisted configuration rather than from the caller.
    options: DbOptions,
}
27
28impl VectorDB {
29 pub fn new(mut options: DbOptions) -> Result<Self> {
36 #[cfg(feature = "storage")]
37 let storage = {
38 let temp_storage = VectorStorage::new(&options.storage_path, options.dimensions)?;
41
42 let stored_config = temp_storage.load_config()?;
43
44 if let Some(config) = stored_config {
45 tracing::info!(
47 "Loading existing database with {} dimensions",
48 config.dimensions
49 );
50 options = DbOptions {
51 storage_path: options.storage_path.clone(),
53 dimensions: config.dimensions,
55 distance_metric: config.distance_metric,
56 hnsw_config: config.hnsw_config,
57 quantization: config.quantization,
58 };
59 Arc::new(VectorStorage::new(
61 &options.storage_path,
62 options.dimensions,
63 )?)
64 } else {
65 tracing::info!(
67 "Creating new database with {} dimensions",
68 options.dimensions
69 );
70 temp_storage.save_config(&options)?;
71 Arc::new(temp_storage)
72 }
73 };
74
75 #[cfg(not(feature = "storage"))]
76 let storage = Arc::new(VectorStorage::new(options.dimensions)?);
77
78 let mut index: Box<dyn VectorIndex> = if let Some(hnsw_config) = &options.hnsw_config {
80 #[cfg(feature = "hnsw")]
81 {
82 Box::new(HnswIndex::new(
83 options.dimensions,
84 options.distance_metric,
85 hnsw_config.clone(),
86 )?)
87 }
88 #[cfg(not(feature = "hnsw"))]
89 {
90 tracing::warn!("HNSW requested but not available (WASM build), using flat index");
92 Box::new(FlatIndex::new(options.dimensions, options.distance_metric))
93 }
94 } else {
95 Box::new(FlatIndex::new(options.dimensions, options.distance_metric))
96 };
97
98 #[cfg(feature = "storage")]
101 {
102 let stored_ids = storage.all_ids()?;
103 if !stored_ids.is_empty() {
104 tracing::info!(
105 "Rebuilding index from {} persisted vectors",
106 stored_ids.len()
107 );
108
109 let mut entries = Vec::with_capacity(stored_ids.len());
111 for id in stored_ids {
112 if let Some(entry) = storage.get(&id)? {
113 entries.push((id, entry.vector));
114 }
115 }
116
117 index.add_batch(entries)?;
119
120 tracing::info!("Index rebuilt successfully");
121 }
122 }
123
124 Ok(Self {
125 storage,
126 index: Arc::new(RwLock::new(index)),
127 options,
128 })
129 }
130
131 pub fn with_dimensions(dimensions: usize) -> Result<Self> {
133 let mut options = DbOptions::default();
134 options.dimensions = dimensions;
135 Self::new(options)
136 }
137
138 pub fn insert(&self, entry: VectorEntry) -> Result<VectorId> {
140 let id = self.storage.insert(&entry)?;
141
142 let mut index = self.index.write();
144 index.add(id.clone(), entry.vector)?;
145
146 Ok(id)
147 }
148
149 pub fn insert_batch(&self, entries: Vec<VectorEntry>) -> Result<Vec<VectorId>> {
151 let ids = self.storage.insert_batch(&entries)?;
152
153 let mut index = self.index.write();
155 let index_entries: Vec<_> = ids
156 .iter()
157 .zip(entries.iter())
158 .map(|(id, entry)| (id.clone(), entry.vector.clone()))
159 .collect();
160
161 index.add_batch(index_entries)?;
162
163 Ok(ids)
164 }
165
166 pub fn search(&self, query: SearchQuery) -> Result<Vec<SearchResult>> {
168 let index = self.index.read();
169 let mut results = index.search(&query.vector, query.k)?;
170
171 for result in &mut results {
173 if let Ok(Some(entry)) = self.storage.get(&result.id) {
174 result.vector = Some(entry.vector);
175 result.metadata = entry.metadata;
176 }
177 }
178
179 if let Some(filter) = &query.filter {
181 results.retain(|r| {
182 if let Some(metadata) = &r.metadata {
183 filter
184 .iter()
185 .all(|(key, value)| metadata.get(key).map_or(false, |v| v == value))
186 } else {
187 false
188 }
189 });
190 }
191
192 Ok(results)
193 }
194
195 pub fn delete(&self, id: &str) -> Result<bool> {
197 let deleted_storage = self.storage.delete(id)?;
198
199 if deleted_storage {
200 let mut index = self.index.write();
201 let _ = index.remove(&id.to_string())?;
202 }
203
204 Ok(deleted_storage)
205 }
206
207 pub fn get(&self, id: &str) -> Result<Option<VectorEntry>> {
209 self.storage.get(id)
210 }
211
212 pub fn len(&self) -> Result<usize> {
214 self.storage.len()
215 }
216
217 pub fn is_empty(&self) -> Result<bool> {
219 self.storage.is_empty()
220 }
221
222 pub fn options(&self) -> &DbOptions {
224 &self.options
225 }
226
227 pub fn keys(&self) -> Result<Vec<String>> {
229 self.storage.all_ids()
230 }
231}
232
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Options for a 3-dimensional, Euclidean, flat-index database at `storage_path`.
    fn flat_euclidean_options(storage_path: String) -> DbOptions {
        DbOptions {
            storage_path,
            dimensions: 3,
            distance_metric: DistanceMetric::Euclidean,
            hnsw_config: None,
            ..DbOptions::default()
        }
    }

    #[test]
    fn test_vector_db_creation() -> Result<()> {
        let dir = tempdir().expect("failed to create temp dir");
        let db = VectorDB::new(DbOptions {
            storage_path: dir.path().join("test.db").to_string_lossy().into_owned(),
            dimensions: 3,
            ..DbOptions::default()
        })?;

        assert!(db.is_empty()?);
        Ok(())
    }

    #[test]
    fn test_insert_and_search() -> Result<()> {
        let dir = tempdir().expect("failed to create temp dir");
        let db = VectorDB::new(flat_euclidean_options(
            dir.path().join("test.db").to_string_lossy().into_owned(),
        ))?;

        for (id, vector) in [
            ("v1", vec![1.0, 0.0, 0.0]),
            ("v2", vec![0.0, 1.0, 0.0]),
            ("v3", vec![0.0, 0.0, 1.0]),
        ] {
            db.insert(VectorEntry {
                id: Some(id.to_string()),
                vector,
                metadata: None,
            })?;
        }

        let results = db.search(SearchQuery {
            vector: vec![1.0, 0.0, 0.0],
            k: 2,
            filter: None,
            ef_search: None,
        })?;

        assert!(!results.is_empty());
        assert_eq!(results[0].id, "v1", "First result should be exact match");
        assert!(
            results[0].score < 0.01,
            "Exact match should have ~0 distance"
        );
        Ok(())
    }

    #[test]
    fn test_delete_removes_from_index() -> Result<()> {
        let dir = tempdir().expect("failed to create temp dir");
        let db = VectorDB::new(flat_euclidean_options(
            dir.path().join("delete.db").to_string_lossy().into_owned(),
        ))?;

        for (id, vector) in [("v1", vec![1.0, 0.0, 0.0]), ("v2", vec![0.0, 1.0, 0.0])] {
            db.insert(VectorEntry {
                id: Some(id.to_string()),
                vector,
                metadata: None,
            })?;
        }

        assert!(db.delete("v1")?, "deleting an existing id should report true");
        assert!(!db.delete("v1")?, "deleting a missing id should report false");
        assert!(db.get("v1")?.is_none());
        assert_eq!(db.len()?, 1);

        let results = db.search(SearchQuery {
            vector: vec![1.0, 0.0, 0.0],
            k: 2,
            filter: None,
            ef_search: None,
        })?;
        assert!(
            results.iter().all(|r| r.id != "v1"),
            "deleted vector must not be returned by search"
        );
        Ok(())
    }

    #[test]
    #[cfg(feature = "storage")]
    fn test_search_after_restart() -> Result<()> {
        let dir = tempdir().expect("failed to create temp dir");
        let db_path = dir.path().join("persist.db").to_string_lossy().into_owned();
        let query = || SearchQuery {
            vector: vec![0.8, 0.6, 0.0],
            k: 3,
            filter: None,
            ef_search: None,
        };

        // First session: populate the database.
        {
            let db = VectorDB::new(flat_euclidean_options(db_path.clone()))?;

            for (id, vector) in [
                ("v1", vec![1.0, 0.0, 0.0]),
                ("v2", vec![0.0, 1.0, 0.0]),
                ("v3", vec![0.7, 0.7, 0.0]),
            ] {
                db.insert(VectorEntry {
                    id: Some(id.to_string()),
                    vector,
                    metadata: None,
                })?;
            }

            let results = db.search(query())?;
            assert_eq!(results.len(), 3, "Should find all 3 vectors before restart");
        }

        // Second session: reopening the same path must rebuild the index from storage.
        {
            let db = VectorDB::new(flat_euclidean_options(db_path))?;

            assert_eq!(db.len()?, 3, "Should have 3 vectors after restart");
            assert!(db.get("v1")?.is_some(), "get() should work after restart");

            let results = db.search(query())?;
            assert_eq!(
                results.len(),
                3,
                "search() should return results after restart (was returning 0 before fix)"
            );
            assert_eq!(
                results[0].id, "v3",
                "v3 [0.7, 0.7, 0.0] should be closest to query [0.8, 0.6, 0.0]"
            );
        }

        Ok(())
    }
}