// ruvector_core/src/storage.rs

1//! Storage layer with redb for metadata and memory-mapped vectors
2//!
3//! This module is only available when the "storage" feature is enabled.
4//! For WASM builds, use the in-memory storage backend instead.
5
6#[cfg(feature = "storage")]
7use crate::error::{Result, RuvectorError};
8#[cfg(feature = "storage")]
9use crate::types::{VectorEntry, VectorId};
10#[cfg(feature = "storage")]
11use bincode::config;
12#[cfg(feature = "storage")]
13use once_cell::sync::Lazy;
14#[cfg(feature = "storage")]
15use parking_lot::Mutex;
16#[cfg(feature = "storage")]
17use redb::{Database, ReadableTable, ReadableTableMetadata, TableDefinition};
18#[cfg(feature = "storage")]
19use serde_json;
20#[cfg(feature = "storage")]
21use std::collections::HashMap;
22#[cfg(feature = "storage")]
23use std::path::{Path, PathBuf};
24#[cfg(feature = "storage")]
25use std::sync::Arc;
26
// Table definitions and the shared connection pool are only meaningful when
// the "storage" feature is enabled; gate EACH item individually — a single
// #[cfg] attribute only applies to the item immediately following it, so the
// second const and the static below previously leaked into non-storage
// (e.g. WASM) builds where their types are not even imported.
#[cfg(feature = "storage")]
const VECTORS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("vectors");
#[cfg(feature = "storage")]
const METADATA_TABLE: TableDefinition<&str, &str> = TableDefinition::new("metadata");

/// Global database connection pool letting multiple `VectorStorage` instances
/// share the same underlying database file (redb permits only one open
/// `Database` per file at a time).
///
/// NOTE(review): entries are never evicted, so an opened database file stays
/// open for the lifetime of the process even after all handles drop — confirm
/// this is acceptable for long-running hosts.
#[cfg(feature = "storage")]
static DB_POOL: Lazy<Mutex<HashMap<PathBuf, Arc<Database>>>> =
    Lazy::new(|| Mutex::new(HashMap::new()));
36
/// Storage backend for the vector database.
///
/// Holds a pooled `redb::Database` handle plus the fixed dimensionality that
/// every stored vector must match. Gated on the "storage" feature because its
/// field types come from feature-gated imports.
#[cfg(feature = "storage")]
pub struct VectorStorage {
    // Shared handle obtained from the global pool; cloning the Arc is cheap.
    db: Arc<Database>,
    // Required length of every vector inserted through this storage.
    dimensions: usize,
}
42
#[cfg(feature = "storage")]
impl VectorStorage {
    /// Create or open a vector storage at the given path.
    ///
    /// Uses a process-global connection pool so multiple `VectorStorage`
    /// instances can share the same underlying database file, avoiding redb's
    /// "Database already open. Cannot acquire lock" error.
    ///
    /// # Errors
    /// Returns `RuvectorError::InvalidPath` for path-traversal attempts or
    /// when the parent directory cannot be created; propagates redb errors.
    pub fn new<P: AsRef<Path>>(path: P, dimensions: usize) -> Result<Self> {
        let path_ref = path.as_ref();

        // Create parent directories if they don't exist (required before the
        // database file can be created).
        if let Some(parent) = path_ref.parent() {
            if !parent.as_os_str().is_empty() && !parent.exists() {
                std::fs::create_dir_all(parent).map_err(|e| {
                    RuvectorError::InvalidPath(format!("Failed to create directory: {}", e))
                })?;
            }
        }

        // Absolutize relative paths so "db/file" and "./db/file" key the same
        // pool entry.
        let path_buf = if path_ref.is_absolute() {
            path_ref.to_path_buf()
        } else {
            std::env::current_dir()
                .map_err(|e| RuvectorError::InvalidPath(format!("Failed to get cwd: {}", e)))?
                .join(path_ref)
        };

        // SECURITY: reject relative paths whose ".." components would climb
        // above the current working directory. Absolute paths are accepted
        // as-is (the caller explicitly chose them).
        // NOTE(review): the substring test also matches names like "my..db";
        // such paths merely take the (harmless) component walk below.
        let path_str = path_ref.to_string_lossy();
        if path_str.contains("..") && !path_ref.is_absolute() {
            if let Ok(cwd) = std::env::current_dir() {
                // Resolve ".." components against cwd; escaping cwd (pop
                // fails or the prefix is lost) is a traversal attempt.
                let mut normalized = cwd.clone();
                for component in path_ref.components() {
                    match component {
                        std::path::Component::ParentDir => {
                            if !normalized.pop() || !normalized.starts_with(&cwd) {
                                return Err(RuvectorError::InvalidPath(
                                    "Path traversal attempt detected".to_string(),
                                ));
                            }
                        }
                        std::path::Component::Normal(c) => normalized.push(c),
                        _ => {}
                    }
                }
            }
        }

        // Reuse an existing handle for this path, or create and register one.
        let db = {
            let mut pool = DB_POOL.lock();

            if let Some(existing_db) = pool.get(&path_buf) {
                Arc::clone(existing_db)
            } else {
                let new_db = Arc::new(Database::create(&path_buf)?);

                // Eagerly create both tables so later read transactions never
                // fail with "table does not exist".
                let write_txn = new_db.begin_write()?;
                {
                    let _ = write_txn.open_table(VECTORS_TABLE)?;
                    let _ = write_txn.open_table(METADATA_TABLE)?;
                }
                write_txn.commit()?;

                pool.insert(path_buf, Arc::clone(&new_db));
                new_db
            }
        };

        Ok(Self { db, dimensions })
    }

    /// Insert a single vector entry, generating a UUID id when none is given.
    ///
    /// # Errors
    /// `DimensionMismatch` if the vector length differs from the configured
    /// dimensionality; `SerializationError` on encode failures.
    pub fn insert(&self, entry: &VectorEntry) -> Result<VectorId> {
        if entry.vector.len() != self.dimensions {
            return Err(RuvectorError::DimensionMismatch {
                expected: self.dimensions,
                actual: entry.vector.len(),
            });
        }

        let id = entry
            .id
            .clone()
            .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());

        let write_txn = self.db.begin_write()?;
        {
            let mut table = write_txn.open_table(VECTORS_TABLE)?;

            // Vectors are stored as bincode-encoded Vec<f32>.
            let vector_data = bincode::encode_to_vec(&entry.vector, config::standard())
                .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;

            table.insert(id.as_str(), vector_data.as_slice())?;

            // Metadata, when present, is stored as a JSON string keyed by the
            // same id in a separate table.
            if let Some(metadata) = &entry.metadata {
                let mut meta_table = write_txn.open_table(METADATA_TABLE)?;
                let metadata_json = serde_json::to_string(metadata)
                    .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
                meta_table.insert(id.as_str(), metadata_json.as_str())?;
            }
        }
        write_txn.commit()?;

        Ok(id)
    }

    /// Insert multiple vectors atomically in a single write transaction.
    ///
    /// Either every entry is committed or none is; dimension validation is
    /// performed up front so a doomed transaction is never even started.
    pub fn insert_batch(&self, entries: &[VectorEntry]) -> Result<Vec<VectorId>> {
        // Validate all dimensions before touching the database.
        for entry in entries {
            if entry.vector.len() != self.dimensions {
                return Err(RuvectorError::DimensionMismatch {
                    expected: self.dimensions,
                    actual: entry.vector.len(),
                });
            }
        }

        let write_txn = self.db.begin_write()?;
        let mut ids = Vec::with_capacity(entries.len());

        {
            let mut table = write_txn.open_table(VECTORS_TABLE)?;
            let mut meta_table = write_txn.open_table(METADATA_TABLE)?;

            for entry in entries {
                let id = entry
                    .id
                    .clone()
                    .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());

                // Serialize and insert the vector payload.
                let vector_data = bincode::encode_to_vec(&entry.vector, config::standard())
                    .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
                table.insert(id.as_str(), vector_data.as_slice())?;

                // Insert metadata if present.
                if let Some(metadata) = &entry.metadata {
                    let metadata_json = serde_json::to_string(metadata)
                        .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
                    meta_table.insert(id.as_str(), metadata_json.as_str())?;
                }

                ids.push(id);
            }
        }

        write_txn.commit()?;
        Ok(ids)
    }

    /// Get a vector (and its metadata, if any) by ID.
    ///
    /// Returns `Ok(None)` when the id is unknown.
    pub fn get(&self, id: &str) -> Result<Option<VectorEntry>> {
        let read_txn = self.db.begin_read()?;
        let table = read_txn.open_table(VECTORS_TABLE)?;

        let Some(vector_data) = table.get(id)? else {
            return Ok(None);
        };

        // decode_from_slice returns (value, bytes_read); the byte count is
        // not needed here.
        let (vector, _): (Vec<f32>, usize) =
            bincode::decode_from_slice(vector_data.value(), config::standard())
                .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;

        // Metadata is optional — absence in the metadata table is not an error.
        let meta_table = read_txn.open_table(METADATA_TABLE)?;
        let metadata = if let Some(meta_data) = meta_table.get(id)? {
            let meta_str = meta_data.value();
            Some(
                serde_json::from_str(meta_str)
                    .map_err(|e| RuvectorError::SerializationError(e.to_string()))?,
            )
        } else {
            None
        };

        Ok(Some(VectorEntry {
            id: Some(id.to_string()),
            vector,
            metadata,
        }))
    }

    /// Delete a vector (and any metadata) by ID.
    ///
    /// Returns `Ok(true)` if a vector was actually removed.
    pub fn delete(&self, id: &str) -> Result<bool> {
        let write_txn = self.db.begin_write()?;

        // Assigned exactly once inside the table scope; no mut-with-dummy-init.
        let deleted;
        {
            let mut table = write_txn.open_table(VECTORS_TABLE)?;
            deleted = table.remove(id)?.is_some();

            // Metadata removal is best-effort: absence is fine.
            let mut meta_table = write_txn.open_table(METADATA_TABLE)?;
            let _ = meta_table.remove(id)?;
        }

        write_txn.commit()?;
        Ok(deleted)
    }

    /// Get the number of vectors stored.
    pub fn len(&self) -> Result<usize> {
        let read_txn = self.db.begin_read()?;
        let table = read_txn.open_table(VECTORS_TABLE)?;
        // redb reports u64; the cast is safe on 64-bit targets and matches the
        // crate's usize-based API.
        Ok(table.len()? as usize)
    }

    /// Check if storage is empty.
    pub fn is_empty(&self) -> Result<bool> {
        Ok(self.len()? == 0)
    }

    /// Get all vector IDs in key order.
    pub fn all_ids(&self) -> Result<Vec<VectorId>> {
        let read_txn = self.db.begin_read()?;
        let table = read_txn.open_table(VECTORS_TABLE)?;

        let mut ids = Vec::new();
        for item in table.iter()? {
            let (key, _) = item?;
            ids.push(key.value().to_string());
        }

        Ok(ids)
    }
}
281
// NOTE(review): a `use` statement does not add a dependency (that lives in
// Cargo.toml); this import is redundant for the fully-qualified
// `uuid::Uuid::new_v4()` calls above but kept for explicitness.
283use uuid;
284
// Gate the tests on the "storage" feature as well as `test`: `VectorStorage`
// and friends only exist when the feature is enabled, so a bare #[cfg(test)]
// breaks `cargo test` for non-storage (e.g. WASM) configurations.
#[cfg(all(test, feature = "storage"))]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Round-trip: insert a vector with an explicit id and read it back.
    #[test]
    fn test_insert_and_get() -> Result<()> {
        let dir = tempdir().unwrap();
        let storage = VectorStorage::new(dir.path().join("test.db"), 3)?;

        let entry = VectorEntry {
            id: Some("test1".to_string()),
            vector: vec![1.0, 2.0, 3.0],
            metadata: None,
        };

        let id = storage.insert(&entry)?;
        assert_eq!(id, "test1");

        let retrieved = storage.get("test1")?;
        assert!(retrieved.is_some());
        let retrieved = retrieved.unwrap();
        assert_eq!(retrieved.vector, vec![1.0, 2.0, 3.0]);

        Ok(())
    }

    /// Batch insert generates UUIDs for entries without ids and stores all
    /// of them in one transaction.
    #[test]
    fn test_batch_insert() -> Result<()> {
        let dir = tempdir().unwrap();
        let storage = VectorStorage::new(dir.path().join("test.db"), 3)?;

        let entries = vec![
            VectorEntry {
                id: None,
                vector: vec![1.0, 2.0, 3.0],
                metadata: None,
            },
            VectorEntry {
                id: None,
                vector: vec![4.0, 5.0, 6.0],
                metadata: None,
            },
        ];

        let ids = storage.insert_batch(&entries)?;
        assert_eq!(ids.len(), 2);
        assert_eq!(storage.len()?, 2);

        Ok(())
    }

    /// Deleting an existing id returns true and shrinks the store.
    #[test]
    fn test_delete() -> Result<()> {
        let dir = tempdir().unwrap();
        let storage = VectorStorage::new(dir.path().join("test.db"), 3)?;

        let entry = VectorEntry {
            id: Some("test1".to_string()),
            vector: vec![1.0, 2.0, 3.0],
            metadata: None,
        };

        storage.insert(&entry)?;
        assert_eq!(storage.len()?, 1);

        let deleted = storage.delete("test1")?;
        assert!(deleted);
        assert_eq!(storage.len()?, 0);

        Ok(())
    }

    /// Regression test for the database-locking bug: two storage instances
    /// opened on the same path must share one pooled connection and see each
    /// other's writes.
    #[test]
    fn test_multiple_instances_same_path() -> Result<()> {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("shared.db");

        // Create first instance.
        let storage1 = VectorStorage::new(&db_path, 3)?;

        // Insert data with first instance.
        storage1.insert(&VectorEntry {
            id: Some("test1".to_string()),
            vector: vec![1.0, 2.0, 3.0],
            metadata: None,
        })?;

        // Create second instance with the SAME path — this must NOT fail.
        let storage2 = VectorStorage::new(&db_path, 3)?;

        // Both instances should see the same data.
        assert_eq!(storage1.len()?, 1);
        assert_eq!(storage2.len()?, 1);

        // Insert with second instance.
        storage2.insert(&VectorEntry {
            id: Some("test2".to_string()),
            vector: vec![4.0, 5.0, 6.0],
            metadata: None,
        })?;

        // Both instances should see both records.
        assert_eq!(storage1.len()?, 2);
        assert_eq!(storage2.len()?, 2);

        // Verify data integrity across handles.
        let retrieved1 = storage1.get("test1")?;
        assert!(retrieved1.is_some());

        let retrieved2 = storage2.get("test2")?;
        assert!(retrieved2.is_some());

        Ok(())
    }
}