1use crate::error::Result;
4use crate::index::flat::FlatIndex;
5
6#[cfg(feature = "hnsw")]
7use crate::index::hnsw::HnswIndex;
8
9use crate::index::VectorIndex;
10use crate::types::*;
11use parking_lot::RwLock;
12use std::sync::Arc;
13
14#[cfg(feature = "storage")]
16use crate::storage::VectorStorage;
17
18#[cfg(not(feature = "storage"))]
19use crate::storage_memory::MemoryStorage as VectorStorage;
20
21pub struct VectorDB {
23 storage: Arc<VectorStorage>,
24 index: Arc<RwLock<Box<dyn VectorIndex>>>,
25 options: DbOptions,
26}
27
28impl VectorDB {
29 pub fn new(mut options: DbOptions) -> Result<Self> {
36 #[cfg(feature = "storage")]
37 let storage = {
38 let temp_storage = VectorStorage::new(
41 &options.storage_path,
42 options.dimensions,
43 )?;
44
45 let stored_config = temp_storage.load_config()?;
46
47 if let Some(config) = stored_config {
48 tracing::info!(
50 "Loading existing database with {} dimensions",
51 config.dimensions
52 );
53 options = DbOptions {
54 storage_path: options.storage_path.clone(),
56 dimensions: config.dimensions,
58 distance_metric: config.distance_metric,
59 hnsw_config: config.hnsw_config,
60 quantization: config.quantization,
61 };
62 Arc::new(VectorStorage::new(
64 &options.storage_path,
65 options.dimensions,
66 )?)
67 } else {
68 tracing::info!(
70 "Creating new database with {} dimensions",
71 options.dimensions
72 );
73 temp_storage.save_config(&options)?;
74 Arc::new(temp_storage)
75 }
76 };
77
78 #[cfg(not(feature = "storage"))]
79 let storage = Arc::new(VectorStorage::new(options.dimensions)?);
80
81 let mut index: Box<dyn VectorIndex> = if let Some(hnsw_config) = &options.hnsw_config {
83 #[cfg(feature = "hnsw")]
84 {
85 Box::new(HnswIndex::new(
86 options.dimensions,
87 options.distance_metric,
88 hnsw_config.clone(),
89 )?)
90 }
91 #[cfg(not(feature = "hnsw"))]
92 {
93 tracing::warn!("HNSW requested but not available (WASM build), using flat index");
95 Box::new(FlatIndex::new(options.dimensions, options.distance_metric))
96 }
97 } else {
98 Box::new(FlatIndex::new(options.dimensions, options.distance_metric))
99 };
100
101 #[cfg(feature = "storage")]
104 {
105 let stored_ids = storage.all_ids()?;
106 if !stored_ids.is_empty() {
107 tracing::info!(
108 "Rebuilding index from {} persisted vectors",
109 stored_ids.len()
110 );
111
112 let mut entries = Vec::with_capacity(stored_ids.len());
114 for id in stored_ids {
115 if let Some(entry) = storage.get(&id)? {
116 entries.push((id, entry.vector));
117 }
118 }
119
120 index.add_batch(entries)?;
122
123 tracing::info!("Index rebuilt successfully");
124 }
125 }
126
127 Ok(Self {
128 storage,
129 index: Arc::new(RwLock::new(index)),
130 options,
131 })
132 }
133
134 pub fn with_dimensions(dimensions: usize) -> Result<Self> {
136 let mut options = DbOptions::default();
137 options.dimensions = dimensions;
138 Self::new(options)
139 }
140
141 pub fn insert(&self, entry: VectorEntry) -> Result<VectorId> {
143 let id = self.storage.insert(&entry)?;
144
145 let mut index = self.index.write();
147 index.add(id.clone(), entry.vector)?;
148
149 Ok(id)
150 }
151
152 pub fn insert_batch(&self, entries: Vec<VectorEntry>) -> Result<Vec<VectorId>> {
154 let ids = self.storage.insert_batch(&entries)?;
155
156 let mut index = self.index.write();
158 let index_entries: Vec<_> = ids
159 .iter()
160 .zip(entries.iter())
161 .map(|(id, entry)| (id.clone(), entry.vector.clone()))
162 .collect();
163
164 index.add_batch(index_entries)?;
165
166 Ok(ids)
167 }
168
169 pub fn search(&self, query: SearchQuery) -> Result<Vec<SearchResult>> {
171 let index = self.index.read();
172 let mut results = index.search(&query.vector, query.k)?;
173
174 for result in &mut results {
176 if let Ok(Some(entry)) = self.storage.get(&result.id) {
177 result.vector = Some(entry.vector);
178 result.metadata = entry.metadata;
179 }
180 }
181
182 if let Some(filter) = &query.filter {
184 results.retain(|r| {
185 if let Some(metadata) = &r.metadata {
186 filter
187 .iter()
188 .all(|(key, value)| metadata.get(key).map_or(false, |v| v == value))
189 } else {
190 false
191 }
192 });
193 }
194
195 Ok(results)
196 }
197
198 pub fn delete(&self, id: &str) -> Result<bool> {
200 let deleted_storage = self.storage.delete(id)?;
201
202 if deleted_storage {
203 let mut index = self.index.write();
204 let _ = index.remove(&id.to_string())?;
205 }
206
207 Ok(deleted_storage)
208 }
209
210 pub fn get(&self, id: &str) -> Result<Option<VectorEntry>> {
212 self.storage.get(id)
213 }
214
215 pub fn len(&self) -> Result<usize> {
217 self.storage.len()
218 }
219
220 pub fn is_empty(&self) -> Result<bool> {
222 self.storage.is_empty()
223 }
224
225 pub fn options(&self) -> &DbOptions {
227 &self.options
228 }
229
230 pub fn keys(&self) -> Result<Vec<String>> {
232 self.storage.all_ids()
233 }
234}
235
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::{tempdir, TempDir};

    /// Returns a database path inside `dir` as the `String` that `DbOptions` expects.
    fn db_path(dir: &TempDir, file_name: &str) -> String {
        dir.path().join(file_name).to_string_lossy().into_owned()
    }

    /// Options for a 3-dimensional database at `storage_path`.
    ///
    /// Uses Euclidean distance and no HNSW config, so the flat index is used and
    /// results do not depend on approximate search.
    fn flat_euclidean_options(storage_path: String) -> DbOptions {
        DbOptions {
            storage_path,
            dimensions: 3,
            distance_metric: DistanceMetric::Euclidean,
            hnsw_config: None,
            ..DbOptions::default()
        }
    }

    #[test]
    fn test_vector_db_creation() -> Result<()> {
        let dir = tempdir().expect("failed to create temp dir");
        let options = DbOptions {
            storage_path: db_path(&dir, "test.db"),
            dimensions: 3,
            ..DbOptions::default()
        };

        let db = VectorDB::new(options)?;
        assert!(db.is_empty()?);

        Ok(())
    }

    #[test]
    fn test_insert_and_search() -> Result<()> {
        let dir = tempdir().expect("failed to create temp dir");
        let db = VectorDB::new(flat_euclidean_options(db_path(&dir, "test.db")))?;

        for (id, vector) in [
            ("v1", vec![1.0, 0.0, 0.0]),
            ("v2", vec![0.0, 1.0, 0.0]),
            ("v3", vec![0.0, 0.0, 1.0]),
        ] {
            db.insert(VectorEntry {
                id: Some(id.to_string()),
                vector,
                metadata: None,
            })?;
        }

        let results = db.search(SearchQuery {
            vector: vec![1.0, 0.0, 0.0],
            k: 2,
            filter: None,
            ef_search: None,
        })?;

        assert!(!results.is_empty(), "search should return at least one result");
        assert_eq!(results[0].id, "v1", "First result should be exact match");
        assert!(
            results[0].score < 0.01,
            "Exact match should have ~0 distance"
        );

        Ok(())
    }

    #[test]
    fn test_delete_removes_from_get_and_search() -> Result<()> {
        let dir = tempdir().expect("failed to create temp dir");
        let db = VectorDB::new(flat_euclidean_options(db_path(&dir, "delete.db")))?;

        for (id, vector) in [("v1", vec![1.0, 0.0, 0.0]), ("v2", vec![0.0, 1.0, 0.0])] {
            db.insert(VectorEntry {
                id: Some(id.to_string()),
                vector,
                metadata: None,
            })?;
        }

        assert!(db.delete("v1")?, "deleting an existing id should report success");
        assert!(db.get("v1")?.is_none(), "deleted entry should not be retrievable");
        assert_eq!(db.len()?, 1, "only v2 should remain");

        let results = db.search(SearchQuery {
            vector: vec![1.0, 0.0, 0.0],
            k: 2,
            filter: None,
            ef_search: None,
        })?;
        assert!(
            results.iter().all(|r| r.id != "v1"),
            "deleted entry should not appear in search results"
        );

        Ok(())
    }

    #[test]
    #[cfg(feature = "storage")]
    fn test_search_after_restart() -> Result<()> {
        let dir = tempdir().expect("failed to create temp dir");
        let path = db_path(&dir, "persist.db");

        // Query chosen so v3 [0.7, 0.7, 0.0] is the unique nearest neighbour.
        let query = || SearchQuery {
            vector: vec![0.8, 0.6, 0.0],
            k: 3,
            filter: None,
            ef_search: None,
        };

        // First session: populate and search; the database is dropped at the end of the scope.
        {
            let db = VectorDB::new(flat_euclidean_options(path.clone()))?;

            for (id, vector) in [
                ("v1", vec![1.0, 0.0, 0.0]),
                ("v2", vec![0.0, 1.0, 0.0]),
                ("v3", vec![0.7, 0.7, 0.0]),
            ] {
                db.insert(VectorEntry {
                    id: Some(id.to_string()),
                    vector,
                    metadata: None,
                })?;
            }

            let results = db.search(query())?;
            assert_eq!(results.len(), 3, "Should find all 3 vectors before restart");
        }

        // Second session: reopen the same path; the index must be rebuilt from storage.
        {
            let db = VectorDB::new(flat_euclidean_options(path.clone()))?;

            assert_eq!(db.len()?, 3, "Should have 3 vectors after restart");
            assert!(db.get("v1")?.is_some(), "get() should work after restart");

            let results = db.search(query())?;
            assert_eq!(
                results.len(),
                3,
                "search() should return results after restart (was returning 0 before fix)"
            );
            assert_eq!(
                results[0].id, "v3",
                "v3 [0.7, 0.7, 0.0] should be closest to query [0.8, 0.6, 0.0]"
            );
        }

        Ok(())
    }
}