1use crate::error::Result;
4use crate::index::flat::FlatIndex;
5
6#[cfg(feature = "hnsw")]
7use crate::index::hnsw::HnswIndex;
8
9use crate::index::VectorIndex;
10use crate::types::*;
11use parking_lot::RwLock;
12use std::sync::Arc;
13
14#[cfg(feature = "storage")]
16use crate::storage::VectorStorage;
17
18#[cfg(not(feature = "storage"))]
19use crate::storage_memory::MemoryStorage as VectorStorage;
20
21pub struct VectorDB {
23 storage: Arc<VectorStorage>,
24 index: Arc<RwLock<Box<dyn VectorIndex>>>,
25 options: DbOptions,
26}
27
28impl VectorDB {
29 #[allow(unused_mut)] pub fn new(mut options: DbOptions) -> Result<Self> {
37 #[cfg(feature = "storage")]
38 let storage = {
39 let temp_storage = VectorStorage::new(&options.storage_path, options.dimensions)?;
42
43 let stored_config = temp_storage.load_config()?;
44
45 if let Some(config) = stored_config {
46 tracing::info!(
48 "Loading existing database with {} dimensions",
49 config.dimensions
50 );
51 options = DbOptions {
52 storage_path: options.storage_path.clone(),
54 dimensions: config.dimensions,
56 distance_metric: config.distance_metric,
57 hnsw_config: config.hnsw_config,
58 quantization: config.quantization,
59 };
60 Arc::new(VectorStorage::new(
62 &options.storage_path,
63 options.dimensions,
64 )?)
65 } else {
66 tracing::info!(
68 "Creating new database with {} dimensions",
69 options.dimensions
70 );
71 temp_storage.save_config(&options)?;
72 Arc::new(temp_storage)
73 }
74 };
75
76 #[cfg(not(feature = "storage"))]
77 let storage = Arc::new(VectorStorage::new(options.dimensions)?);
78
79 #[allow(unused_mut)] let mut index: Box<dyn VectorIndex> = if let Some(hnsw_config) = &options.hnsw_config {
82 #[cfg(feature = "hnsw")]
83 {
84 Box::new(HnswIndex::new(
85 options.dimensions,
86 options.distance_metric,
87 hnsw_config.clone(),
88 )?)
89 }
90 #[cfg(not(feature = "hnsw"))]
91 {
92 tracing::warn!("HNSW requested but not available (WASM build), using flat index");
94 Box::new(FlatIndex::new(options.dimensions, options.distance_metric))
95 }
96 } else {
97 Box::new(FlatIndex::new(options.dimensions, options.distance_metric))
98 };
99
100 #[cfg(feature = "storage")]
103 {
104 let stored_ids = storage.all_ids()?;
105 if !stored_ids.is_empty() {
106 tracing::info!(
107 "Rebuilding index from {} persisted vectors",
108 stored_ids.len()
109 );
110
111 let mut entries = Vec::with_capacity(stored_ids.len());
113 for id in stored_ids {
114 if let Some(entry) = storage.get(&id)? {
115 entries.push((id, entry.vector));
116 }
117 }
118
119 index.add_batch(entries)?;
121
122 tracing::info!("Index rebuilt successfully");
123 }
124 }
125
126 Ok(Self {
127 storage,
128 index: Arc::new(RwLock::new(index)),
129 options,
130 })
131 }
132
133 pub fn with_dimensions(dimensions: usize) -> Result<Self> {
135 let options = DbOptions {
136 dimensions,
137 ..DbOptions::default()
138 };
139 Self::new(options)
140 }
141
142 pub fn insert(&self, entry: VectorEntry) -> Result<VectorId> {
144 let id = self.storage.insert(&entry)?;
145
146 let mut index = self.index.write();
148 index.add(id.clone(), entry.vector)?;
149
150 Ok(id)
151 }
152
153 pub fn insert_batch(&self, entries: impl AsRef<[VectorEntry]>) -> Result<Vec<VectorId>> {
155 let entries = entries.as_ref();
156 let ids = self.storage.insert_batch(entries)?;
157
158 let mut index = self.index.write();
160 let index_entries: Vec<_> = ids
161 .iter()
162 .zip(entries.iter())
163 .map(|(id, entry)| (id.clone(), entry.vector.clone()))
164 .collect();
165
166 index.add_batch(index_entries)?;
167
168 Ok(ids)
169 }
170
171 pub fn search(&self, query: SearchQuery) -> Result<Vec<SearchResult>> {
173 let index = self.index.read();
174 let mut results = index.search(&query.vector, query.k)?;
175
176 for result in &mut results {
178 if let Ok(Some(entry)) = self.storage.get(&result.id) {
179 result.vector = Some(entry.vector);
180 result.metadata = entry.metadata;
181 }
182 }
183
184 if let Some(filter) = &query.filter {
186 results.retain(|r| {
187 if let Some(metadata) = &r.metadata {
188 filter
189 .iter()
190 .all(|(key, value)| metadata.get(key).is_some_and(|v| v == value))
191 } else {
192 false
193 }
194 });
195 }
196
197 Ok(results)
198 }
199
200 pub fn delete(&self, id: &str) -> Result<bool> {
202 let deleted_storage = self.storage.delete(id)?;
203
204 if deleted_storage {
205 let mut index = self.index.write();
206 let _ = index.remove(&id.to_string())?;
207 }
208
209 Ok(deleted_storage)
210 }
211
212 pub fn get(&self, id: &str) -> Result<Option<VectorEntry>> {
214 self.storage.get(id)
215 }
216
217 pub fn len(&self) -> Result<usize> {
219 self.storage.len()
220 }
221
222 pub fn is_empty(&self) -> Result<bool> {
224 self.storage.is_empty()
225 }
226
227 pub fn options(&self) -> &DbOptions {
229 &self.options
230 }
231
232 pub fn keys(&self) -> Result<Vec<String>> {
234 self.storage.all_ids()
235 }
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;
    use tempfile::tempdir;

    /// Options for a 3-dimensional, Euclidean, flat-index database at `path`.
    fn flat_options(path: &Path) -> DbOptions {
        DbOptions {
            storage_path: path.to_string_lossy().into_owned(),
            dimensions: 3,
            distance_metric: DistanceMetric::Euclidean,
            hnsw_config: None,
            ..DbOptions::default()
        }
    }

    #[test]
    fn test_vector_db_creation() -> Result<()> {
        let dir = tempdir().unwrap();
        let options = DbOptions {
            storage_path: dir.path().join("test.db").to_string_lossy().into_owned(),
            dimensions: 3,
            ..DbOptions::default()
        };

        let db = VectorDB::new(options)?;
        assert!(db.is_empty()?);

        Ok(())
    }

    #[test]
    fn test_insert_and_search() -> Result<()> {
        let dir = tempdir().unwrap();
        let db = VectorDB::new(flat_options(&dir.path().join("test.db")))?;

        db.insert(VectorEntry {
            id: Some("v1".to_string()),
            vector: vec![1.0, 0.0, 0.0],
            metadata: None,
        })?;
        db.insert(VectorEntry {
            id: Some("v2".to_string()),
            vector: vec![0.0, 1.0, 0.0],
            metadata: None,
        })?;
        db.insert(VectorEntry {
            id: Some("v3".to_string()),
            vector: vec![0.0, 0.0, 1.0],
            metadata: None,
        })?;

        let results = db.search(SearchQuery {
            vector: vec![1.0, 0.0, 0.0],
            k: 2,
            filter: None,
            ef_search: None,
        })?;

        assert!(!results.is_empty());
        assert_eq!(results[0].id, "v1", "First result should be exact match");
        assert!(
            results[0].score < 0.01,
            "Exact match should have ~0 distance"
        );

        Ok(())
    }

    #[test]
    fn test_delete_removes_from_storage_and_index() -> Result<()> {
        let dir = tempdir().unwrap();
        let db = VectorDB::new(flat_options(&dir.path().join("delete.db")))?;

        db.insert(VectorEntry {
            id: Some("v1".to_string()),
            vector: vec![1.0, 0.0, 0.0],
            metadata: None,
        })?;
        db.insert(VectorEntry {
            id: Some("v2".to_string()),
            vector: vec![0.0, 1.0, 0.0],
            metadata: None,
        })?;

        assert!(db.delete("v1")?, "deleting an existing id should report true");
        assert!(db.get("v1")?.is_none(), "deleted entry should be gone from storage");
        assert_eq!(db.len()?, 1);

        let results = db.search(SearchQuery {
            vector: vec![1.0, 0.0, 0.0],
            k: 2,
            filter: None,
            ef_search: None,
        })?;
        assert!(
            results.iter().all(|r| r.id != "v1"),
            "deleted entry should not be returned by search"
        );

        Ok(())
    }

    #[cfg(feature = "storage")]
    #[test]
    fn test_search_after_restart() -> Result<()> {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("persist.db");

        {
            let db = VectorDB::new(flat_options(&db_path))?;

            db.insert(VectorEntry {
                id: Some("v1".to_string()),
                vector: vec![1.0, 0.0, 0.0],
                metadata: None,
            })?;
            db.insert(VectorEntry {
                id: Some("v2".to_string()),
                vector: vec![0.0, 1.0, 0.0],
                metadata: None,
            })?;
            db.insert(VectorEntry {
                id: Some("v3".to_string()),
                vector: vec![0.7, 0.7, 0.0],
                metadata: None,
            })?;

            let results = db.search(SearchQuery {
                vector: vec![0.8, 0.6, 0.0],
                k: 3,
                filter: None,
                ef_search: None,
            })?;
            assert_eq!(results.len(), 3, "Should find all 3 vectors before restart");
        }

        {
            // Reopening must rebuild the in-memory index from persisted vectors.
            let db = VectorDB::new(flat_options(&db_path))?;

            assert_eq!(db.len()?, 3, "Should have 3 vectors after restart");

            let v1 = db.get("v1")?;
            assert!(v1.is_some(), "get() should work after restart");

            let results = db.search(SearchQuery {
                vector: vec![0.8, 0.6, 0.0],
                k: 3,
                filter: None,
                ef_search: None,
            })?;

            assert_eq!(
                results.len(),
                3,
                "search() should return results after restart (was returning 0 before fix)"
            );
            assert_eq!(
                results[0].id, "v3",
                "v3 [0.7, 0.7, 0.0] should be closest to query [0.8, 0.6, 0.0]"
            );
        }

        Ok(())
    }
}