1use crate::error::Result;
4use crate::index::flat::FlatIndex;
5
6#[cfg(feature = "hnsw")]
7use crate::index::hnsw::HnswIndex;
8
9use crate::index::VectorIndex;
10use crate::types::*;
11use parking_lot::RwLock;
12use std::sync::Arc;
13
14#[cfg(feature = "storage")]
16use crate::storage::VectorStorage;
17
18#[cfg(not(feature = "storage"))]
19use crate::storage_memory::MemoryStorage as VectorStorage;
20
21pub struct VectorDB {
23 storage: Arc<VectorStorage>,
24 index: Arc<RwLock<Box<dyn VectorIndex>>>,
25 options: DbOptions,
26}
27
28impl VectorDB {
29 pub fn new(mut options: DbOptions) -> Result<Self> {
36 #[cfg(feature = "storage")]
37 let storage = {
38 let temp_storage = VectorStorage::new(
41 &options.storage_path,
42 options.dimensions,
43 )?;
44
45 let stored_config = temp_storage.load_config()?;
46
47 if let Some(config) = stored_config {
48 tracing::info!(
50 "Loading existing database with {} dimensions",
51 config.dimensions
52 );
53 options = DbOptions {
54 storage_path: options.storage_path.clone(),
56 dimensions: config.dimensions,
58 distance_metric: config.distance_metric,
59 hnsw_config: config.hnsw_config,
60 quantization: config.quantization,
61 };
62 Arc::new(VectorStorage::new(
64 &options.storage_path,
65 options.dimensions,
66 )?)
67 } else {
68 tracing::info!(
70 "Creating new database with {} dimensions",
71 options.dimensions
72 );
73 temp_storage.save_config(&options)?;
74 Arc::new(temp_storage)
75 }
76 };
77
78 #[cfg(not(feature = "storage"))]
79 let storage = Arc::new(VectorStorage::new(options.dimensions)?);
80
81 let mut index: Box<dyn VectorIndex> = if let Some(hnsw_config) = &options.hnsw_config {
83 #[cfg(feature = "hnsw")]
84 {
85 Box::new(HnswIndex::new(
86 options.dimensions,
87 options.distance_metric,
88 hnsw_config.clone(),
89 )?)
90 }
91 #[cfg(not(feature = "hnsw"))]
92 {
93 tracing::warn!("HNSW requested but not available (WASM build), using flat index");
95 Box::new(FlatIndex::new(options.dimensions, options.distance_metric))
96 }
97 } else {
98 Box::new(FlatIndex::new(options.dimensions, options.distance_metric))
99 };
100
101 #[cfg(feature = "storage")]
104 {
105 let stored_ids = storage.all_ids()?;
106 if !stored_ids.is_empty() {
107 tracing::info!(
108 "Rebuilding index from {} persisted vectors",
109 stored_ids.len()
110 );
111
112 let mut entries = Vec::with_capacity(stored_ids.len());
114 for id in stored_ids {
115 if let Some(entry) = storage.get(&id)? {
116 entries.push((id, entry.vector));
117 }
118 }
119
120 index.add_batch(entries)?;
122
123 tracing::info!("Index rebuilt successfully");
124 }
125 }
126
127 Ok(Self {
128 storage,
129 index: Arc::new(RwLock::new(index)),
130 options,
131 })
132 }
133
134 pub fn with_dimensions(dimensions: usize) -> Result<Self> {
136 let mut options = DbOptions::default();
137 options.dimensions = dimensions;
138 Self::new(options)
139 }
140
141 pub fn insert(&self, entry: VectorEntry) -> Result<VectorId> {
143 let id = self.storage.insert(&entry)?;
144
145 let mut index = self.index.write();
147 index.add(id.clone(), entry.vector)?;
148
149 Ok(id)
150 }
151
152 pub fn insert_batch(&self, entries: Vec<VectorEntry>) -> Result<Vec<VectorId>> {
154 let ids = self.storage.insert_batch(&entries)?;
155
156 let mut index = self.index.write();
158 let index_entries: Vec<_> = ids
159 .iter()
160 .zip(entries.iter())
161 .map(|(id, entry)| (id.clone(), entry.vector.clone()))
162 .collect();
163
164 index.add_batch(index_entries)?;
165
166 Ok(ids)
167 }
168
169 pub fn search(&self, query: SearchQuery) -> Result<Vec<SearchResult>> {
171 let index = self.index.read();
172 let mut results = index.search(&query.vector, query.k)?;
173
174 for result in &mut results {
176 if let Ok(Some(entry)) = self.storage.get(&result.id) {
177 result.vector = Some(entry.vector);
178 result.metadata = entry.metadata;
179 }
180 }
181
182 if let Some(filter) = &query.filter {
184 results.retain(|r| {
185 if let Some(metadata) = &r.metadata {
186 filter
187 .iter()
188 .all(|(key, value)| metadata.get(key).map_or(false, |v| v == value))
189 } else {
190 false
191 }
192 });
193 }
194
195 Ok(results)
196 }
197
198 pub fn delete(&self, id: &str) -> Result<bool> {
200 let deleted_storage = self.storage.delete(id)?;
201
202 if deleted_storage {
203 let mut index = self.index.write();
204 let _ = index.remove(&id.to_string())?;
205 }
206
207 Ok(deleted_storage)
208 }
209
210 pub fn get(&self, id: &str) -> Result<Option<VectorEntry>> {
212 self.storage.get(id)
213 }
214
215 pub fn len(&self) -> Result<usize> {
217 self.storage.len()
218 }
219
220 pub fn is_empty(&self) -> Result<bool> {
222 self.storage.is_empty()
223 }
224
225 pub fn options(&self) -> &DbOptions {
227 &self.options
228 }
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Options for a 3-dimensional database using Euclidean distance and a flat index,
    /// stored at `storage_path`.
    fn flat_euclidean_options(storage_path: String) -> DbOptions {
        DbOptions {
            storage_path,
            dimensions: 3,
            distance_metric: DistanceMetric::Euclidean,
            hnsw_config: None,
            ..DbOptions::default()
        }
    }

    #[test]
    fn test_vector_db_creation() -> Result<()> {
        let dir = tempdir().unwrap();
        let mut options = DbOptions::default();
        options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
        options.dimensions = 3;

        let db = VectorDB::new(options)?;
        assert!(db.is_empty()?);

        Ok(())
    }

    #[test]
    fn test_insert_and_search() -> Result<()> {
        let dir = tempdir().unwrap();
        let db = VectorDB::new(flat_euclidean_options(
            dir.path().join("test.db").to_string_lossy().to_string(),
        ))?;

        db.insert(VectorEntry {
            id: Some("v1".to_string()),
            vector: vec![1.0, 0.0, 0.0],
            metadata: None,
        })?;

        db.insert(VectorEntry {
            id: Some("v2".to_string()),
            vector: vec![0.0, 1.0, 0.0],
            metadata: None,
        })?;

        db.insert(VectorEntry {
            id: Some("v3".to_string()),
            vector: vec![0.0, 0.0, 1.0],
            metadata: None,
        })?;

        let results = db.search(SearchQuery {
            vector: vec![1.0, 0.0, 0.0],
            k: 2,
            filter: None,
            ef_search: None,
        })?;

        assert!(!results.is_empty());
        assert_eq!(results[0].id, "v1", "First result should be exact match");
        assert!(
            results[0].score < 0.01,
            "Exact match should have ~0 distance"
        );

        Ok(())
    }

    #[test]
    fn test_delete_removes_from_get_and_search() -> Result<()> {
        let dir = tempdir().unwrap();
        let db = VectorDB::new(flat_euclidean_options(
            dir.path().join("delete.db").to_string_lossy().to_string(),
        ))?;

        db.insert(VectorEntry {
            id: Some("v1".to_string()),
            vector: vec![1.0, 0.0, 0.0],
            metadata: None,
        })?;

        db.insert(VectorEntry {
            id: Some("v2".to_string()),
            vector: vec![0.0, 1.0, 0.0],
            metadata: None,
        })?;

        assert!(db.delete("v1")?, "deleting an existing id should report true");
        assert!(db.get("v1")?.is_none(), "deleted entry should not be retrievable");
        assert_eq!(db.len()?, 1, "only v2 should remain");

        let results = db.search(SearchQuery {
            vector: vec![1.0, 0.0, 0.0],
            k: 2,
            filter: None,
            ef_search: None,
        })?;

        assert!(
            results.iter().all(|r| r.id != "v1"),
            "deleted entry should not appear in search results"
        );

        Ok(())
    }

    #[test]
    #[cfg(feature = "storage")]
    fn test_search_after_restart() -> Result<()> {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("persist.db").to_string_lossy().to_string();

        // First session: populate the database and confirm search works.
        {
            let db = VectorDB::new(flat_euclidean_options(db_path.clone()))?;

            db.insert(VectorEntry {
                id: Some("v1".to_string()),
                vector: vec![1.0, 0.0, 0.0],
                metadata: None,
            })?;

            db.insert(VectorEntry {
                id: Some("v2".to_string()),
                vector: vec![0.0, 1.0, 0.0],
                metadata: None,
            })?;

            db.insert(VectorEntry {
                id: Some("v3".to_string()),
                vector: vec![0.7, 0.7, 0.0],
                metadata: None,
            })?;

            let results = db.search(SearchQuery {
                vector: vec![0.8, 0.6, 0.0],
                k: 3,
                filter: None,
                ef_search: None,
            })?;
            assert_eq!(results.len(), 3, "Should find all 3 vectors before restart");
        }

        // Second session: reopen the same path; the index must be rebuilt from storage.
        {
            let db = VectorDB::new(flat_euclidean_options(db_path.clone()))?;

            assert_eq!(db.len()?, 3, "Should have 3 vectors after restart");

            let v1 = db.get("v1")?;
            assert!(v1.is_some(), "get() should work after restart");

            let results = db.search(SearchQuery {
                vector: vec![0.8, 0.6, 0.0],
                k: 3,
                filter: None,
                ef_search: None,
            })?;

            assert_eq!(
                results.len(),
                3,
                "search() should return results after restart (was returning 0 before fix)"
            );

            assert_eq!(
                results[0].id, "v3",
                "v3 [0.7, 0.7, 0.0] should be closest to query [0.8, 0.6, 0.0]"
            );
        }

        Ok(())
    }
}