ruvector_core/storage.rs

//! Storage layer with redb for metadata and memory-mapped vectors
//!
//! This module is only available when the "storage" feature is enabled.
//! For WASM builds, use the in-memory storage backend instead.
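//!
//! How the parent crate wires this up is not shown here; the following is only a
//! minimal sketch of the kind of gating the note above implies (the `lib.rs`
//! excerpt and module names are illustrative assumptions, not taken from this
//! crate), so it is not compiled as a doctest.
//!
//! ```ignore
//! // lib.rs (illustrative): compile the redb-backed storage only when the
//! // "storage" feature is enabled, otherwise fall back to an in-memory backend.
//! #[cfg(feature = "storage")]
//! pub mod storage;
//!
//! #[cfg(not(feature = "storage"))]
//! pub mod memory_storage; // hypothetical name for the in-memory backend
//! ```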

#[cfg(feature = "storage")]
use crate::error::{Result, RuvectorError};
#[cfg(feature = "storage")]
use crate::types::{DbOptions, VectorEntry, VectorId};
#[cfg(feature = "storage")]
use bincode::config;
#[cfg(feature = "storage")]
use once_cell::sync::Lazy;
#[cfg(feature = "storage")]
use parking_lot::Mutex;
#[cfg(feature = "storage")]
use redb::{Database, ReadableTable, ReadableTableMetadata, TableDefinition};
#[cfg(feature = "storage")]
use serde_json;
#[cfg(feature = "storage")]
use std::collections::HashMap;
#[cfg(feature = "storage")]
use std::path::{Path, PathBuf};
#[cfg(feature = "storage")]
use std::sync::Arc;

#[cfg(feature = "storage")]
const VECTORS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("vectors");
const METADATA_TABLE: TableDefinition<&str, &str> = TableDefinition::new("metadata");
const CONFIG_TABLE: TableDefinition<&str, &str> = TableDefinition::new("config");

/// Key used to store database configuration in CONFIG_TABLE
const DB_CONFIG_KEY: &str = "__ruvector_db_config__";

// Global database connection pool to allow multiple VectorDB instances
// to share the same underlying database file
static DB_POOL: Lazy<Mutex<HashMap<PathBuf, Arc<Database>>>> =
    Lazy::new(|| Mutex::new(HashMap::new()));

/// Storage backend for vector database
pub struct VectorStorage {
    db: Arc<Database>,
    dimensions: usize,
}

impl VectorStorage {
    /// Create or open a vector storage at the given path
    ///
    /// This method uses a global connection pool to allow multiple VectorDB
    /// instances to share the same underlying database file, fixing the
    /// "Database already open. Cannot acquire lock" error.
    pub fn new<P: AsRef<Path>>(path: P, dimensions: usize) -> Result<Self> {
        // SECURITY: Validate path to prevent directory traversal attacks
        let path_ref = path.as_ref();

        // Create parent directories if they don't exist
        if let Some(parent) = path_ref.parent() {
            if !parent.as_os_str().is_empty() && !parent.exists() {
                std::fs::create_dir_all(parent).map_err(|e| {
                    RuvectorError::InvalidPath(format!("Failed to create directory: {}", e))
                })?;
            }
        }

        // Convert to absolute path first, then validate
        let path_buf = if path_ref.is_absolute() {
            path_ref.to_path_buf()
        } else {
            std::env::current_dir()
                .map_err(|e| RuvectorError::InvalidPath(format!("Failed to get cwd: {}", e)))?
                .join(path_ref)
        };

        // SECURITY: Check for path traversal attempts (e.g., "../../../etc/passwd")
        // Only reject paths that contain ".." components trying to escape
        let path_str = path_ref.to_string_lossy();
        if path_str.contains("..") {
            // Verify the resolved path doesn't escape intended boundaries
            // For absolute paths, we allow them as-is (user explicitly specified)
            // For relative paths with "..", check they don't escape cwd
            if !path_ref.is_absolute() {
                if let Ok(cwd) = std::env::current_dir() {
                    // Normalize the path by resolving .. components
                    let mut normalized = cwd.clone();
                    for component in path_ref.components() {
                        match component {
                            std::path::Component::ParentDir => {
                                if !normalized.pop() || !normalized.starts_with(&cwd) {
                                    return Err(RuvectorError::InvalidPath(
                                        "Path traversal attempt detected".to_string()
                                    ));
                                }
                            }
                            std::path::Component::Normal(c) => normalized.push(c),
                            _ => {}
                        }
                    }
                }
            }
        }

        // Check if we already have a Database instance for this path
        let db = {
            let mut pool = DB_POOL.lock();

            if let Some(existing_db) = pool.get(&path_buf) {
                // Reuse existing database connection
                Arc::clone(existing_db)
            } else {
                // Create new database and add to pool
                let new_db = Arc::new(Database::create(&path_buf)?);

                // Initialize tables
                let write_txn = new_db.begin_write()?;
                {
                    let _ = write_txn.open_table(VECTORS_TABLE)?;
                    let _ = write_txn.open_table(METADATA_TABLE)?;
                    let _ = write_txn.open_table(CONFIG_TABLE)?;
                }
                write_txn.commit()?;

                pool.insert(path_buf, Arc::clone(&new_db));
                new_db
            }
        };

        Ok(Self { db, dimensions })
    }

    /// Insert a vector entry
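    ///
    /// A minimal usage sketch (the storage path below is illustrative, so the
    /// example is not compiled as a doctest):
    ///
    /// ```ignore
    /// let storage = VectorStorage::new("data/vectors.db", 3)?;
    /// let id = storage.insert(&VectorEntry {
    ///     id: None, // a UUID is generated when no id is supplied
    ///     vector: vec![1.0, 2.0, 3.0],
    ///     metadata: None,
    /// })?;
    /// ```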
    pub fn insert(&self, entry: &VectorEntry) -> Result<VectorId> {
        if entry.vector.len() != self.dimensions {
            return Err(RuvectorError::DimensionMismatch {
                expected: self.dimensions,
                actual: entry.vector.len(),
            });
        }

        let id = entry
            .id
            .clone()
            .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());

        let write_txn = self.db.begin_write()?;
        {
            let mut table = write_txn.open_table(VECTORS_TABLE)?;

            // Serialize vector data
            let vector_data = bincode::encode_to_vec(&entry.vector, config::standard())
                .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;

            table.insert(id.as_str(), vector_data.as_slice())?;

            // Store metadata if present
            if let Some(metadata) = &entry.metadata {
                let mut meta_table = write_txn.open_table(METADATA_TABLE)?;
                let metadata_json = serde_json::to_string(metadata)
                    .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
                meta_table.insert(id.as_str(), metadata_json.as_str())?;
            }
        }
        write_txn.commit()?;

        Ok(id)
    }

    /// Insert multiple vectors in a batch
    pub fn insert_batch(&self, entries: &[VectorEntry]) -> Result<Vec<VectorId>> {
        let write_txn = self.db.begin_write()?;
        let mut ids = Vec::with_capacity(entries.len());

        {
            let mut table = write_txn.open_table(VECTORS_TABLE)?;
            let mut meta_table = write_txn.open_table(METADATA_TABLE)?;

            for entry in entries {
                if entry.vector.len() != self.dimensions {
                    return Err(RuvectorError::DimensionMismatch {
                        expected: self.dimensions,
                        actual: entry.vector.len(),
                    });
                }

                let id = entry
                    .id
                    .clone()
                    .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());

                // Serialize and insert vector
                let vector_data = bincode::encode_to_vec(&entry.vector, config::standard())
                    .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
                table.insert(id.as_str(), vector_data.as_slice())?;

                // Insert metadata if present
                if let Some(metadata) = &entry.metadata {
                    let metadata_json = serde_json::to_string(metadata)
                        .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
                    meta_table.insert(id.as_str(), metadata_json.as_str())?;
                }

                ids.push(id);
            }
        }

        write_txn.commit()?;
        Ok(ids)
    }

    /// Get a vector by ID
    pub fn get(&self, id: &str) -> Result<Option<VectorEntry>> {
        let read_txn = self.db.begin_read()?;
        let table = read_txn.open_table(VECTORS_TABLE)?;

        let Some(vector_data) = table.get(id)? else {
            return Ok(None);
        };

        let (vector, _): (Vec<f32>, usize) =
            bincode::decode_from_slice(vector_data.value(), config::standard())
                .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;

        // Try to get metadata
        let meta_table = read_txn.open_table(METADATA_TABLE)?;
        let metadata = if let Some(meta_data) = meta_table.get(id)? {
            let meta_str = meta_data.value();
            Some(
                serde_json::from_str(meta_str)
                    .map_err(|e| RuvectorError::SerializationError(e.to_string()))?,
            )
        } else {
            None
        };

        Ok(Some(VectorEntry {
            id: Some(id.to_string()),
            vector,
            metadata,
        }))
    }

    /// Delete a vector by ID
    pub fn delete(&self, id: &str) -> Result<bool> {
        let write_txn = self.db.begin_write()?;
        let deleted;

        {
            let mut table = write_txn.open_table(VECTORS_TABLE)?;
            deleted = table.remove(id)?.is_some();

            let mut meta_table = write_txn.open_table(METADATA_TABLE)?;
            let _ = meta_table.remove(id)?;
        }

        write_txn.commit()?;
        Ok(deleted)
    }

    /// Get the number of vectors stored
    pub fn len(&self) -> Result<usize> {
        let read_txn = self.db.begin_read()?;
        let table = read_txn.open_table(VECTORS_TABLE)?;
        Ok(table.len()? as usize)
    }

    /// Check if storage is empty
    pub fn is_empty(&self) -> Result<bool> {
        Ok(self.len()? == 0)
    }

    /// Get all vector IDs
    pub fn all_ids(&self) -> Result<Vec<VectorId>> {
        let read_txn = self.db.begin_read()?;
        let table = read_txn.open_table(VECTORS_TABLE)?;

        let mut ids = Vec::new();
        let iter = table.iter()?;
        for item in iter {
            let (key, _) = item?;
            ids.push(key.value().to_string());
        }

        Ok(ids)
    }

    /// Save database configuration to persistent storage
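    ///
    /// A minimal round-trip sketch given an open `VectorStorage` handle named
    /// `storage`; how `DbOptions` is constructed depends on `crate::types`
    /// (the `Default` impl below is an assumption), so the example is not
    /// compiled as a doctest.
    ///
    /// ```ignore
    /// let options = DbOptions::default(); // assumes DbOptions implements Default
    /// storage.save_config(&options)?;
    /// let loaded = storage.load_config()?; // -> Ok(Some(..)) once a config is saved
    /// ```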
    pub fn save_config(&self, options: &DbOptions) -> Result<()> {
        let config_json = serde_json::to_string(options)
            .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;

        let write_txn = self.db.begin_write()?;
        {
            let mut table = write_txn.open_table(CONFIG_TABLE)?;
            table.insert(DB_CONFIG_KEY, config_json.as_str())?;
        }
        write_txn.commit()?;

        Ok(())
    }

    /// Load database configuration from persistent storage
    pub fn load_config(&self) -> Result<Option<DbOptions>> {
        let read_txn = self.db.begin_read()?;

        // Try to open config table - may not exist in older databases
        let table = match read_txn.open_table(CONFIG_TABLE) {
            Ok(t) => t,
            Err(_) => return Ok(None),
        };

        let Some(config_data) = table.get(DB_CONFIG_KEY)? else {
            return Ok(None);
        };

        let config: DbOptions = serde_json::from_str(config_data.value())
            .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;

        Ok(Some(config))
    }

    /// Get the stored dimensions
    pub fn dimensions(&self) -> usize {
        self.dimensions
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    #[test]
    fn test_insert_and_get() -> Result<()> {
        let dir = tempdir().unwrap();
        let storage = VectorStorage::new(dir.path().join("test.db"), 3)?;

        let entry = VectorEntry {
            id: Some("test1".to_string()),
            vector: vec![1.0, 2.0, 3.0],
            metadata: None,
        };

        let id = storage.insert(&entry)?;
        assert_eq!(id, "test1");

        let retrieved = storage.get("test1")?;
        assert!(retrieved.is_some());
        let retrieved = retrieved.unwrap();
        assert_eq!(retrieved.vector, vec![1.0, 2.0, 3.0]);

        Ok(())
    }

    #[test]
    fn test_batch_insert() -> Result<()> {
        let dir = tempdir().unwrap();
        let storage = VectorStorage::new(dir.path().join("test.db"), 3)?;

        let entries = vec![
            VectorEntry {
                id: None,
                vector: vec![1.0, 2.0, 3.0],
                metadata: None,
            },
            VectorEntry {
                id: None,
                vector: vec![4.0, 5.0, 6.0],
                metadata: None,
            },
        ];

        let ids = storage.insert_batch(&entries)?;
        assert_eq!(ids.len(), 2);
        assert_eq!(storage.len()?, 2);

        Ok(())
    }

    #[test]
    fn test_delete() -> Result<()> {
        let dir = tempdir().unwrap();
        let storage = VectorStorage::new(dir.path().join("test.db"), 3)?;

        let entry = VectorEntry {
            id: Some("test1".to_string()),
            vector: vec![1.0, 2.0, 3.0],
            metadata: None,
        };

        storage.insert(&entry)?;
        assert_eq!(storage.len()?, 1);

        let deleted = storage.delete("test1")?;
        assert!(deleted);
        assert_eq!(storage.len()?, 0);

        Ok(())
    }

    #[test]
    fn test_multiple_instances_same_path() -> Result<()> {
        // This test verifies the fix for the database locking bug
        // Multiple VectorStorage instances should be able to share the same database file
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("shared.db");

        // Create first instance
        let storage1 = VectorStorage::new(&db_path, 3)?;

        // Insert data with first instance
        storage1.insert(&VectorEntry {
            id: Some("test1".to_string()),
            vector: vec![1.0, 2.0, 3.0],
            metadata: None,
        })?;

        // Create second instance with SAME path - this should NOT fail
        let storage2 = VectorStorage::new(&db_path, 3)?;

        // Both instances should see the same data
        assert_eq!(storage1.len()?, 1);
        assert_eq!(storage2.len()?, 1);

        // Insert with second instance
        storage2.insert(&VectorEntry {
            id: Some("test2".to_string()),
            vector: vec![4.0, 5.0, 6.0],
            metadata: None,
        })?;

        // Both instances should see both records
        assert_eq!(storage1.len()?, 2);
        assert_eq!(storage2.len()?, 2);

        // Verify data integrity
        let retrieved1 = storage1.get("test1")?;
        assert!(retrieved1.is_some());

        let retrieved2 = storage2.get("test2")?;
        assert!(retrieved2.is_some());

        Ok(())
    }
}