// ruvector_core/storage.rs

//! Storage layer with redb for metadata and memory-mapped vectors
//!
//! This module is only available when the "storage" feature is enabled.
//! For WASM builds, use the in-memory storage backend instead.

#[cfg(feature = "storage")]
use crate::error::{Result, RuvectorError};
#[cfg(feature = "storage")]
use crate::types::{DbOptions, VectorEntry, VectorId};
#[cfg(feature = "storage")]
use bincode::config;
#[cfg(feature = "storage")]
use once_cell::sync::Lazy;
#[cfg(feature = "storage")]
use parking_lot::Mutex;
#[cfg(feature = "storage")]
use redb::{Database, ReadableTable, ReadableTableMetadata, TableDefinition};
#[cfg(feature = "storage")]
use serde_json;
#[cfg(feature = "storage")]
use std::collections::HashMap;
#[cfg(feature = "storage")]
use std::path::{Path, PathBuf};
#[cfg(feature = "storage")]
use std::sync::Arc;

/// Table mapping vector id -> bincode-encoded `Vec<f32>`.
#[cfg(feature = "storage")]
const VECTORS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("vectors");
/// Table mapping vector id -> JSON-encoded metadata.
// FIX: this and the items below were missing the feature gate, which broke
// compilation when the "storage" feature is disabled (their types come from
// feature-gated imports).
#[cfg(feature = "storage")]
const METADATA_TABLE: TableDefinition<&str, &str> = TableDefinition::new("metadata");
/// Table holding database-level configuration entries.
#[cfg(feature = "storage")]
const CONFIG_TABLE: TableDefinition<&str, &str> = TableDefinition::new("config");

/// Key used to store database configuration in CONFIG_TABLE
#[cfg(feature = "storage")]
const DB_CONFIG_KEY: &str = "__ruvector_db_config__";

// Global database connection pool to allow multiple VectorDB instances
// to share the same underlying database file.
// NOTE(review): entries are never evicted, so each pooled Database handle
// lives for the remainder of the process — confirm this is acceptable.
#[cfg(feature = "storage")]
static DB_POOL: Lazy<Mutex<HashMap<PathBuf, Arc<Database>>>> =
    Lazy::new(|| Mutex::new(HashMap::new()));

/// Storage backend for vector database.
///
/// Wraps a pooled redb database handle together with the fixed
/// dimensionality that every stored vector must have.
// FIX: added the feature gate — the struct's fields use the feature-gated
// `Arc`/`Database` imports, so it did not compile with "storage" disabled.
#[cfg(feature = "storage")]
pub struct VectorStorage {
    /// Shared handle to the underlying redb database (pooled by path).
    db: Arc<Database>,
    /// Required length of every stored vector.
    dimensions: usize,
}

// FIX: the impl was missing the feature gate; every type it touches is
// feature-gated, so the crate failed to compile with "storage" disabled.
#[cfg(feature = "storage")]
impl VectorStorage {
    /// Create or open a vector storage at the given path.
    ///
    /// This method uses a global connection pool to allow multiple VectorDB
    /// instances to share the same underlying database file, fixing the
    /// "Database already open. Cannot acquire lock" error.
    ///
    /// NOTE(review): when a pooled handle is reused, `dimensions` is NOT
    /// checked against what previous instances used for the same file —
    /// confirm callers always agree on dimensionality per path.
    ///
    /// # Errors
    /// Returns `RuvectorError::InvalidPath` if the path cannot be resolved
    /// or looks like a directory-traversal attempt, or a database error if
    /// redb fails to create/open the file.
    pub fn new<P: AsRef<Path>>(path: P, dimensions: usize) -> Result<Self> {
        // SECURITY: Validate path to prevent directory traversal attacks
        let path_ref = path.as_ref();

        // Create parent directories if they don't exist (needed for canonicalize)
        if let Some(parent) = path_ref.parent() {
            if !parent.as_os_str().is_empty() && !parent.exists() {
                std::fs::create_dir_all(parent).map_err(|e| {
                    RuvectorError::InvalidPath(format!("Failed to create directory: {}", e))
                })?;
            }
        }

        // Convert to absolute path first, then validate. The absolute path is
        // also the pool key, so "db.redb" and "./db.redb" share one handle.
        let path_buf = if path_ref.is_absolute() {
            path_ref.to_path_buf()
        } else {
            std::env::current_dir()
                .map_err(|e| RuvectorError::InvalidPath(format!("Failed to get cwd: {}", e)))?
                .join(path_ref)
        };

        // SECURITY: Check for path traversal attempts (e.g., "../../../etc/passwd")
        // Only reject paths that contain ".." components trying to escape
        let path_str = path_ref.to_string_lossy();
        if path_str.contains("..") {
            // Verify the resolved path doesn't escape intended boundaries
            // For absolute paths, we allow them as-is (user explicitly specified)
            // For relative paths with "..", check they don't escape cwd
            if !path_ref.is_absolute() {
                if let Ok(cwd) = std::env::current_dir() {
                    // Normalize the path by resolving .. components; if a ".."
                    // would climb above cwd, reject the whole path.
                    let mut normalized = cwd.clone();
                    for component in path_ref.components() {
                        match component {
                            std::path::Component::ParentDir => {
                                if !normalized.pop() || !normalized.starts_with(&cwd) {
                                    return Err(RuvectorError::InvalidPath(
                                        "Path traversal attempt detected".to_string(),
                                    ));
                                }
                            }
                            std::path::Component::Normal(c) => normalized.push(c),
                            _ => {}
                        }
                    }
                }
            }
        }

        // Check if we already have a Database instance for this path
        let db = {
            let mut pool = DB_POOL.lock();

            if let Some(existing_db) = pool.get(&path_buf) {
                // Reuse existing database connection
                Arc::clone(existing_db)
            } else {
                // Create new database and add to pool
                let new_db = Arc::new(Database::create(&path_buf)?);

                // Initialize tables so subsequent read transactions can open
                // them without racing table creation.
                let write_txn = new_db.begin_write()?;
                {
                    let _ = write_txn.open_table(VECTORS_TABLE)?;
                    let _ = write_txn.open_table(METADATA_TABLE)?;
                    let _ = write_txn.open_table(CONFIG_TABLE)?;
                }
                write_txn.commit()?;

                pool.insert(path_buf, Arc::clone(&new_db));
                new_db
            }
        };

        Ok(Self { db, dimensions })
    }

    /// Insert a vector entry, generating a UUID id when the entry has none.
    ///
    /// # Errors
    /// Returns `DimensionMismatch` if the vector length differs from the
    /// storage's configured dimensions, or a serialization/database error.
    pub fn insert(&self, entry: &VectorEntry) -> Result<VectorId> {
        if entry.vector.len() != self.dimensions {
            return Err(RuvectorError::DimensionMismatch {
                expected: self.dimensions,
                actual: entry.vector.len(),
            });
        }

        let id = entry
            .id
            .clone()
            .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());

        let write_txn = self.db.begin_write()?;
        {
            let mut table = write_txn.open_table(VECTORS_TABLE)?;

            // Serialize vector data
            let vector_data = bincode::encode_to_vec(&entry.vector, config::standard())
                .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;

            table.insert(id.as_str(), vector_data.as_slice())?;

            // Store metadata if present
            if let Some(metadata) = &entry.metadata {
                let mut meta_table = write_txn.open_table(METADATA_TABLE)?;
                let metadata_json = serde_json::to_string(metadata)
                    .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
                meta_table.insert(id.as_str(), metadata_json.as_str())?;
            }
        }
        write_txn.commit()?;

        Ok(id)
    }

    /// Insert multiple vectors in one write transaction.
    ///
    /// All-or-nothing: on any error the transaction is dropped (aborted),
    /// so no partial batch is committed.
    ///
    /// # Errors
    /// Returns `DimensionMismatch` for the first entry of the wrong length,
    /// or a serialization/database error.
    pub fn insert_batch(&self, entries: &[VectorEntry]) -> Result<Vec<VectorId>> {
        let write_txn = self.db.begin_write()?;
        let mut ids = Vec::with_capacity(entries.len());

        {
            let mut table = write_txn.open_table(VECTORS_TABLE)?;
            let mut meta_table = write_txn.open_table(METADATA_TABLE)?;

            for entry in entries {
                if entry.vector.len() != self.dimensions {
                    // Early return drops `write_txn`, aborting the batch.
                    return Err(RuvectorError::DimensionMismatch {
                        expected: self.dimensions,
                        actual: entry.vector.len(),
                    });
                }

                let id = entry
                    .id
                    .clone()
                    .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());

                // Serialize and insert vector
                let vector_data = bincode::encode_to_vec(&entry.vector, config::standard())
                    .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
                table.insert(id.as_str(), vector_data.as_slice())?;

                // Insert metadata if present
                if let Some(metadata) = &entry.metadata {
                    let metadata_json = serde_json::to_string(metadata)
                        .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
                    meta_table.insert(id.as_str(), metadata_json.as_str())?;
                }

                ids.push(id);
            }
        }

        write_txn.commit()?;
        Ok(ids)
    }

    /// Get a vector (and its metadata, if any) by ID.
    ///
    /// Returns `Ok(None)` when the id is not present.
    pub fn get(&self, id: &str) -> Result<Option<VectorEntry>> {
        let read_txn = self.db.begin_read()?;
        let table = read_txn.open_table(VECTORS_TABLE)?;

        let Some(vector_data) = table.get(id)? else {
            return Ok(None);
        };

        // decode_from_slice returns (value, bytes_consumed); we only need
        // the decoded vector.
        let (vector, _): (Vec<f32>, usize) =
            bincode::decode_from_slice(vector_data.value(), config::standard())
                .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;

        // Try to get metadata
        let meta_table = read_txn.open_table(METADATA_TABLE)?;
        let metadata = if let Some(meta_data) = meta_table.get(id)? {
            let meta_str = meta_data.value();
            Some(
                serde_json::from_str(meta_str)
                    .map_err(|e| RuvectorError::SerializationError(e.to_string()))?,
            )
        } else {
            None
        };

        Ok(Some(VectorEntry {
            id: Some(id.to_string()),
            vector,
            metadata,
        }))
    }

    /// Delete a vector by ID.
    ///
    /// Returns `Ok(true)` if a vector was removed, `Ok(false)` if the id
    /// did not exist. Any associated metadata row is removed as well.
    pub fn delete(&self, id: &str) -> Result<bool> {
        let write_txn = self.db.begin_write()?;
        let deleted;

        {
            let mut table = write_txn.open_table(VECTORS_TABLE)?;
            deleted = table.remove(id)?.is_some();

            let mut meta_table = write_txn.open_table(METADATA_TABLE)?;
            let _ = meta_table.remove(id)?;
        }

        write_txn.commit()?;
        Ok(deleted)
    }

    /// Get the number of vectors stored.
    pub fn len(&self) -> Result<usize> {
        let read_txn = self.db.begin_read()?;
        let table = read_txn.open_table(VECTORS_TABLE)?;
        Ok(table.len()? as usize)
    }

    /// Check if storage is empty.
    pub fn is_empty(&self) -> Result<bool> {
        Ok(self.len()? == 0)
    }

    /// Get all vector IDs, in the table's key order.
    pub fn all_ids(&self) -> Result<Vec<VectorId>> {
        let read_txn = self.db.begin_read()?;
        let table = read_txn.open_table(VECTORS_TABLE)?;

        let mut ids = Vec::new();
        let iter = table.iter()?;
        for item in iter {
            let (key, _) = item?;
            ids.push(key.value().to_string());
        }

        Ok(ids)
    }

    /// Save database configuration to persistent storage (overwrites any
    /// previously stored configuration).
    pub fn save_config(&self, options: &DbOptions) -> Result<()> {
        let config_json = serde_json::to_string(options)
            .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;

        let write_txn = self.db.begin_write()?;
        {
            let mut table = write_txn.open_table(CONFIG_TABLE)?;
            table.insert(DB_CONFIG_KEY, config_json.as_str())?;
        }
        write_txn.commit()?;

        Ok(())
    }

    /// Load database configuration from persistent storage.
    ///
    /// Returns `Ok(None)` when no configuration has been saved, or when the
    /// config table does not exist (databases created by older versions).
    pub fn load_config(&self) -> Result<Option<DbOptions>> {
        let read_txn = self.db.begin_read()?;

        // Try to open config table - may not exist in older databases, in
        // which case we deliberately treat the error as "no config".
        let table = match read_txn.open_table(CONFIG_TABLE) {
            Ok(t) => t,
            Err(_) => return Ok(None),
        };

        let Some(config_data) = table.get(DB_CONFIG_KEY)? else {
            return Ok(None);
        };

        let config: DbOptions = serde_json::from_str(config_data.value())
            .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;

        Ok(Some(config))
    }

    /// Get the configured vector dimensionality.
    pub fn dimensions(&self) -> usize {
        self.dimensions
    }
}

// `uuid` is used via fully-qualified paths (`uuid::Uuid`) in the storage
// implementation; gate the import like the rest of this module so builds
// without the "storage" feature don't require it in scope.
// (The dependency itself is declared in Cargo.toml, not here.)
#[cfg(feature = "storage")]
use uuid;

// FIX: tests exercise `VectorStorage`, which only exists with the "storage"
// feature; gating on `test` alone broke `cargo test` with the feature off.
#[cfg(all(test, feature = "storage"))]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Round-trip: a single insert must be retrievable with identical data.
    #[test]
    fn test_insert_and_get() -> Result<()> {
        let dir = tempdir().unwrap();
        let storage = VectorStorage::new(dir.path().join("test.db"), 3)?;

        let entry = VectorEntry {
            id: Some("test1".to_string()),
            vector: vec![1.0, 2.0, 3.0],
            metadata: None,
        };

        let id = storage.insert(&entry)?;
        assert_eq!(id, "test1");

        let retrieved = storage.get("test1")?;
        assert!(retrieved.is_some());
        let retrieved = retrieved.unwrap();
        assert_eq!(retrieved.vector, vec![1.0, 2.0, 3.0]);

        Ok(())
    }

    /// Batch insert generates ids for entries without one and stores all rows.
    #[test]
    fn test_batch_insert() -> Result<()> {
        let dir = tempdir().unwrap();
        let storage = VectorStorage::new(dir.path().join("test.db"), 3)?;

        let entries = vec![
            VectorEntry {
                id: None,
                vector: vec![1.0, 2.0, 3.0],
                metadata: None,
            },
            VectorEntry {
                id: None,
                vector: vec![4.0, 5.0, 6.0],
                metadata: None,
            },
        ];

        let ids = storage.insert_batch(&entries)?;
        assert_eq!(ids.len(), 2);
        assert_eq!(storage.len()?, 2);

        Ok(())
    }

    /// Delete reports whether a row existed and shrinks the count.
    #[test]
    fn test_delete() -> Result<()> {
        let dir = tempdir().unwrap();
        let storage = VectorStorage::new(dir.path().join("test.db"), 3)?;

        let entry = VectorEntry {
            id: Some("test1".to_string()),
            vector: vec![1.0, 2.0, 3.0],
            metadata: None,
        };

        storage.insert(&entry)?;
        assert_eq!(storage.len()?, 1);

        let deleted = storage.delete("test1")?;
        assert!(deleted);
        assert_eq!(storage.len()?, 0);

        Ok(())
    }

    #[test]
    fn test_multiple_instances_same_path() -> Result<()> {
        // This test verifies the fix for the database locking bug
        // Multiple VectorStorage instances should be able to share the same database file
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("shared.db");

        // Create first instance
        let storage1 = VectorStorage::new(&db_path, 3)?;

        // Insert data with first instance
        storage1.insert(&VectorEntry {
            id: Some("test1".to_string()),
            vector: vec![1.0, 2.0, 3.0],
            metadata: None,
        })?;

        // Create second instance with SAME path - this should NOT fail
        let storage2 = VectorStorage::new(&db_path, 3)?;

        // Both instances should see the same data
        assert_eq!(storage1.len()?, 1);
        assert_eq!(storage2.len()?, 1);

        // Insert with second instance
        storage2.insert(&VectorEntry {
            id: Some("test2".to_string()),
            vector: vec![4.0, 5.0, 6.0],
            metadata: None,
        })?;

        // Both instances should see both records
        assert_eq!(storage1.len()?, 2);
        assert_eq!(storage2.len()?, 2);

        // Verify data integrity
        let retrieved1 = storage1.get("test1")?;
        assert!(retrieved1.is_some());

        let retrieved2 = storage2.get("test2")?;
        assert!(retrieved2.is_some());

        Ok(())
    }
}
446}