// ruvector_router_core/storage.rs
1//! Storage layer with redb and memory-mapped files
2
3use crate::error::{Result, VectorDbError};
4use crate::types::VectorEntry;
5use parking_lot::RwLock;
6use redb::{Database, ReadableTable, ReadableTableMetadata, TableDefinition};
7use std::collections::HashMap;
8use std::path::Path;
9use std::sync::Arc;
10
// Table definitions.
// `vectors`:  entry id -> bincode-encoded Vec<f32> (see `insert` / `get`)
// `metadata`: entry id -> JSON-encoded HashMap<String, serde_json::Value>
// `index`:    opaque index blobs keyed by caller-chosen name (see `store_index`)
const VECTORS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("vectors");
const METADATA_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("metadata");
const INDEX_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("index");
15
16/// Storage backend for vector database
/// Storage backend for vector database.
///
/// Wraps a `redb` database plus an in-memory write-through cache of raw
/// vectors. Both handles are `Arc`-shared so the struct is cheap to share
/// across threads.
pub struct Storage {
    // Underlying redb database holding the vectors/metadata/index tables.
    db: Arc<Database>,
    // Write-through cache: entry id -> raw vector. Populated on insert and
    // on cache-miss reads in `get`; guarded by an RwLock for concurrent reads.
    vector_cache: Arc<RwLock<HashMap<String, Vec<f32>>>>,
}
21
22impl Storage {
23    /// Create a new storage instance
24    pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
25        // SECURITY: Validate path to prevent directory traversal attacks
26        let path_ref = path.as_ref();
27
28        // Create parent directories if they don't exist
29        if let Some(parent) = path_ref.parent() {
30            if !parent.as_os_str().is_empty() && !parent.exists() {
31                std::fs::create_dir_all(parent).map_err(|e| {
32                    VectorDbError::InvalidPath(format!("Failed to create directory: {}", e))
33                })?;
34            }
35        }
36
37        // Convert to absolute path
38        let canonical_path = if path_ref.is_absolute() {
39            path_ref.to_path_buf()
40        } else {
41            std::env::current_dir()
42                .map_err(|e| VectorDbError::InvalidPath(format!("Failed to get cwd: {}", e)))?
43                .join(path_ref)
44        };
45
46        // SECURITY: Check for path traversal attempts
47        let path_str = path_ref.to_string_lossy();
48        if path_str.contains("..") && !path_ref.is_absolute() {
49            if let Ok(cwd) = std::env::current_dir() {
50                let mut normalized = cwd.clone();
51                for component in path_ref.components() {
52                    match component {
53                        std::path::Component::ParentDir => {
54                            if !normalized.pop() || !normalized.starts_with(&cwd) {
55                                return Err(VectorDbError::InvalidPath(
56                                    "Path traversal attempt detected".to_string()
57                                ));
58                            }
59                        }
60                        std::path::Component::Normal(c) => normalized.push(c),
61                        _ => {}
62                    }
63                }
64            }
65        }
66
67        let db = Database::create(canonical_path)?;
68
69        Ok(Self {
70            db: Arc::new(db),
71            vector_cache: Arc::new(RwLock::new(HashMap::new())),
72        })
73    }
74
75    /// Open an existing storage instance
76    pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
77        // SECURITY: Validate path to prevent directory traversal attacks
78        let path_ref = path.as_ref();
79
80        // Convert to absolute path - file must exist for open
81        let canonical_path = path_ref.canonicalize().map_err(|e| {
82            VectorDbError::InvalidPath(format!("Path does not exist or cannot be resolved: {}", e))
83        })?;
84
85        // SECURITY: Check for path traversal attempts
86        let path_str = path_ref.to_string_lossy();
87        if path_str.contains("..") && !path_ref.is_absolute() {
88            if let Ok(cwd) = std::env::current_dir() {
89                if !canonical_path.starts_with(&cwd) {
90                    return Err(VectorDbError::InvalidPath(
91                        "Path traversal attempt detected".to_string()
92                    ));
93                }
94            }
95        }
96
97        let db = Database::open(canonical_path)?;
98
99        Ok(Self {
100            db: Arc::new(db),
101            vector_cache: Arc::new(RwLock::new(HashMap::new())),
102        })
103    }
104
105    /// Insert a vector entry
106    pub fn insert(&self, entry: &VectorEntry) -> Result<()> {
107        let write_txn = self.db.begin_write()?;
108
109        {
110            let mut table = write_txn.open_table(VECTORS_TABLE)?;
111
112            // Serialize vector as bytes
113            let vector_bytes = bincode::encode_to_vec(&entry.vector, bincode::config::standard())
114                .map_err(|e| VectorDbError::Serialization(e.to_string()))?;
115
116            table.insert(entry.id.as_str(), vector_bytes.as_slice())?;
117        }
118
119        {
120            let mut table = write_txn.open_table(METADATA_TABLE)?;
121
122            // Serialize metadata (use JSON for serde_json::Value compatibility)
123            let metadata_bytes = serde_json::to_vec(&entry.metadata)
124                .map_err(|e| VectorDbError::Serialization(e.to_string()))?;
125
126            table.insert(entry.id.as_str(), metadata_bytes.as_slice())?;
127        }
128
129        write_txn.commit()?;
130
131        // Update cache
132        self.vector_cache
133            .write()
134            .insert(entry.id.clone(), entry.vector.clone());
135
136        Ok(())
137    }
138
139    /// Insert multiple vector entries in a batch
140    pub fn insert_batch(&self, entries: &[VectorEntry]) -> Result<()> {
141        let write_txn = self.db.begin_write()?;
142
143        {
144            let mut vectors_table = write_txn.open_table(VECTORS_TABLE)?;
145            let mut metadata_table = write_txn.open_table(METADATA_TABLE)?;
146
147            for entry in entries {
148                // Serialize vector
149                let vector_bytes =
150                    bincode::encode_to_vec(&entry.vector, bincode::config::standard())
151                        .map_err(|e| VectorDbError::Serialization(e.to_string()))?;
152
153                vectors_table.insert(entry.id.as_str(), vector_bytes.as_slice())?;
154
155                // Serialize metadata (use JSON for serde_json::Value compatibility)
156                let metadata_bytes = serde_json::to_vec(&entry.metadata)
157                    .map_err(|e| VectorDbError::Serialization(e.to_string()))?;
158
159                metadata_table.insert(entry.id.as_str(), metadata_bytes.as_slice())?;
160            }
161        }
162
163        write_txn.commit()?;
164
165        // Update cache
166        let mut cache = self.vector_cache.write();
167        for entry in entries {
168            cache.insert(entry.id.clone(), entry.vector.clone());
169        }
170
171        Ok(())
172    }
173
174    /// Get a vector by ID
175    pub fn get(&self, id: &str) -> Result<Option<Vec<f32>>> {
176        // Check cache first
177        if let Some(vector) = self.vector_cache.read().get(id) {
178            return Ok(Some(vector.clone()));
179        }
180
181        // Read from database
182        let read_txn = self.db.begin_read()?;
183        let table = read_txn.open_table(VECTORS_TABLE)?;
184
185        if let Some(bytes) = table.get(id)? {
186            let (vector, _): (Vec<f32>, usize) =
187                bincode::decode_from_slice(bytes.value(), bincode::config::standard())
188                    .map_err(|e| VectorDbError::Serialization(e.to_string()))?;
189
190            // Update cache
191            self.vector_cache
192                .write()
193                .insert(id.to_string(), vector.clone());
194
195            Ok(Some(vector))
196        } else {
197            Ok(None)
198        }
199    }
200
201    /// Get metadata for a vector
202    pub fn get_metadata(&self, id: &str) -> Result<Option<HashMap<String, serde_json::Value>>> {
203        let read_txn = self.db.begin_read()?;
204        let table = read_txn.open_table(METADATA_TABLE)?;
205
206        if let Some(bytes) = table.get(id)? {
207            let metadata: HashMap<String, serde_json::Value> =
208                serde_json::from_slice(bytes.value())
209                    .map_err(|e| VectorDbError::Serialization(e.to_string()))?;
210
211            Ok(Some(metadata))
212        } else {
213            Ok(None)
214        }
215    }
216
217    /// Delete a vector by ID
218    pub fn delete(&self, id: &str) -> Result<bool> {
219        let write_txn = self.db.begin_write()?;
220
221        let deleted;
222
223        {
224            let mut table = write_txn.open_table(VECTORS_TABLE)?;
225            deleted = table.remove(id)?.is_some();
226        }
227
228        {
229            let mut table = write_txn.open_table(METADATA_TABLE)?;
230            table.remove(id)?;
231        }
232
233        write_txn.commit()?;
234
235        // Remove from cache
236        self.vector_cache.write().remove(id);
237
238        Ok(deleted)
239    }
240
241    /// Get all vector IDs
242    pub fn get_all_ids(&self) -> Result<Vec<String>> {
243        let read_txn = self.db.begin_read()?;
244        let table = read_txn.open_table(VECTORS_TABLE)?;
245
246        let mut ids = Vec::new();
247        let iter = table.iter()?;
248        for item in iter {
249            let (key, _) = item?;
250            ids.push(key.value().to_string());
251        }
252
253        Ok(ids)
254    }
255
256    /// Count total vectors
257    pub fn count(&self) -> Result<usize> {
258        let read_txn = self.db.begin_read()?;
259        let table = read_txn.open_table(VECTORS_TABLE)?;
260        Ok(table.len()? as usize)
261    }
262
263    /// Store index data
264    pub fn store_index(&self, key: &str, data: &[u8]) -> Result<()> {
265        let write_txn = self.db.begin_write()?;
266
267        {
268            let mut table = write_txn.open_table(INDEX_TABLE)?;
269            table.insert(key, data)?;
270        }
271
272        write_txn.commit()?;
273        Ok(())
274    }
275
276    /// Load index data
277    pub fn load_index(&self, key: &str) -> Result<Option<Vec<u8>>> {
278        let read_txn = self.db.begin_read()?;
279        let table = read_txn.open_table(INDEX_TABLE)?;
280
281        if let Some(bytes) = table.get(key)? {
282            Ok(Some(bytes.value().to_vec()))
283        } else {
284            Ok(None)
285        }
286    }
287}
288
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Build a minimal test entry with the given id (shared by all tests to
    /// avoid duplicating the fixture).
    fn sample_entry(id: &str) -> VectorEntry {
        VectorEntry {
            id: id.to_string(),
            vector: vec![1.0, 2.0, 3.0],
            metadata: HashMap::new(),
            timestamp: 0,
        }
    }

    #[test]
    fn test_storage_insert_and_get() {
        let dir = tempdir().unwrap();
        let storage = Storage::new(dir.path().join("test.db")).unwrap();

        storage.insert(&sample_entry("test1")).unwrap();

        let retrieved = storage.get("test1").unwrap();
        assert_eq!(retrieved, Some(vec![1.0, 2.0, 3.0]));
    }

    #[test]
    fn test_storage_delete() {
        let dir = tempdir().unwrap();
        let storage = Storage::new(dir.path().join("test.db")).unwrap();

        storage.insert(&sample_entry("test1")).unwrap();
        assert!(storage.delete("test1").unwrap());
        assert!(storage.get("test1").unwrap().is_none());
    }
}
331}