Skip to main content

ruvector_graph/
storage.rs

1//! Persistent storage layer with redb and memory-mapped vectors
2//!
3//! Provides ACID-compliant storage for graph nodes, edges, and hyperedges
4
5#[cfg(feature = "storage")]
6use crate::edge::Edge;
7#[cfg(feature = "storage")]
8use crate::hyperedge::{Hyperedge, HyperedgeId};
9#[cfg(feature = "storage")]
10use crate::node::Node;
11#[cfg(feature = "storage")]
12use crate::types::{EdgeId, NodeId};
13#[cfg(feature = "storage")]
14use anyhow::Result;
15#[cfg(feature = "storage")]
16use bincode::config;
17#[cfg(feature = "storage")]
18use once_cell::sync::Lazy;
19#[cfg(feature = "storage")]
20use parking_lot::Mutex;
21#[cfg(feature = "storage")]
22use redb::{Database, ReadableTable, TableDefinition};
23#[cfg(feature = "storage")]
24use std::collections::HashMap;
25#[cfg(feature = "storage")]
26use std::path::{Path, PathBuf};
27#[cfg(feature = "storage")]
28use std::sync::Arc;
29
/// redb table holding node records: node ID -> bincode-encoded `Node` bytes.
#[cfg(feature = "storage")]
const NODES_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("nodes");

/// redb table holding edge records: edge ID -> bincode-encoded `Edge` bytes.
#[cfg(feature = "storage")]
const EDGES_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("edges");

/// redb table holding hyperedge records: hyperedge ID -> bincode-encoded `Hyperedge` bytes.
#[cfg(feature = "storage")]
const HYPEREDGES_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("hyperedges");

/// redb table holding free-form string metadata: key -> value.
#[cfg(feature = "storage")]
const METADATA_TABLE: TableDefinition<&str, &str> = TableDefinition::new("metadata");
39
/// Process-wide registry of open databases, keyed by path.
///
/// Every `GraphStorage` opened for the same key receives a clone of the same
/// `Arc<Database>` instead of opening the file a second time. Entries are never
/// removed, so a database stays open for the lifetime of the process once it
/// has been registered here.
#[cfg(feature = "storage")]
static DB_POOL: Lazy<Mutex<HashMap<PathBuf, Arc<Database>>>> = Lazy::new(Default::default);
45
/// Persistent, transactional store for graph nodes, edges, hyperedges and
/// string metadata, backed by a redb [`Database`].
///
/// The database handle is reference-counted and may be shared with other
/// `GraphStorage` values opened for the same file (see [`GraphStorage::new`]).
#[cfg(feature = "storage")]
pub struct GraphStorage {
    /// Shared handle to the underlying redb database.
    db: Arc<Database>,
}
51
#[cfg(feature = "storage")]
impl GraphStorage {
    /// Create or open a graph storage at the given path.
    ///
    /// Relative paths are resolved against the current working directory and
    /// must not climb above it through `..` components; absolute paths are used
    /// as given. Missing parent directories are created.
    ///
    /// Uses a global connection pool so that multiple `GraphStorage` instances
    /// opened for the same file share one underlying [`Database`]. The pool key
    /// is normalised (canonical parent directory + file name), so different
    /// spellings of the same file path also share one connection.
    ///
    /// # Errors
    ///
    /// Fails if a relative path escapes the working directory ("Path traversal
    /// attempt detected"), if the working directory cannot be determined, if
    /// the parent directory cannot be created, or if the database cannot be
    /// created or its tables initialised.
    pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
        let path_ref = path.as_ref();

        // SECURITY: validate relative paths *before* touching the filesystem,
        // so a rejected path never causes directories to be created.
        let path_buf = if path_ref.is_absolute() {
            path_ref.to_path_buf()
        } else {
            let cwd = std::env::current_dir()?;
            Self::reject_traversal(&cwd, path_ref)?;
            cwd.join(path_ref)
        };

        // Create parent directories if they don't exist.
        if let Some(parent) = path_buf.parent() {
            if !parent.as_os_str().is_empty() && !parent.exists() {
                std::fs::create_dir_all(parent)?;
            }
        }

        let key = Self::pool_key(&path_buf);

        // The pool lock is held while creating the database so two threads
        // cannot both create a connection for the same key.
        let db = {
            let mut pool = DB_POOL.lock();
            if let Some(existing) = pool.get(&key) {
                Arc::clone(existing)
            } else {
                let created = Arc::new(Database::create(&path_buf)?);
                Self::init_tables(&created)?;
                pool.insert(key, Arc::clone(&created));
                created
            }
        };

        Ok(Self { db })
    }

    /// Reject a relative `path` whose `..` components would, at any point
    /// while walking it lexically from `base`, leave `base`.
    fn reject_traversal(base: &Path, path: &Path) -> Result<()> {
        let mut cursor = base.to_path_buf();
        for component in path.components() {
            match component {
                std::path::Component::ParentDir => {
                    if !cursor.pop() || !cursor.starts_with(base) {
                        anyhow::bail!("Path traversal attempt detected");
                    }
                }
                std::path::Component::Normal(part) => cursor.push(part),
                _ => {}
            }
        }
        Ok(())
    }

    /// Compute the pool key for `path`: the canonicalised parent directory
    /// joined with the file name, so `dir/../x.db` and `x.db` (or a path through
    /// a symlinked directory) map to the same entry. Falls back to `path`
    /// unchanged when there is no parent/file name or canonicalisation fails.
    fn pool_key(path: &Path) -> PathBuf {
        match (path.parent(), path.file_name()) {
            (Some(parent), Some(name)) => std::fs::canonicalize(parent)
                .map(|dir| dir.join(name))
                .unwrap_or_else(|_| path.to_path_buf()),
            _ => path.to_path_buf(),
        }
    }

    /// Open every table once in a committed write transaction so that later
    /// read transactions can open them.
    fn init_tables(db: &Database) -> Result<()> {
        let write_txn = db.begin_write()?;
        {
            let _ = write_txn.open_table(NODES_TABLE)?;
            let _ = write_txn.open_table(EDGES_TABLE)?;
            let _ = write_txn.open_table(HYPEREDGES_TABLE)?;
            let _ = write_txn.open_table(METADATA_TABLE)?;
        }
        write_txn.commit()?;
        Ok(())
    }

    // Node operations

    /// Insert a node under its own ID in a single committed write transaction.
    ///
    /// Returns the node's ID.
    pub fn insert_node(&self, node: &Node) -> Result<NodeId> {
        let write_txn = self.db.begin_write()?;
        {
            let mut table = write_txn.open_table(NODES_TABLE)?;
            let node_data = bincode::encode_to_vec(node, config::standard())?;
            table.insert(node.id.as_str(), node_data.as_slice())?;
        }
        write_txn.commit()?;

        Ok(node.id.clone())
    }

    /// Insert multiple nodes in one write transaction.
    ///
    /// If any node fails to encode or insert, the function returns the error
    /// before committing. Returns the IDs in input order.
    pub fn insert_nodes_batch(&self, nodes: &[Node]) -> Result<Vec<NodeId>> {
        let write_txn = self.db.begin_write()?;
        let mut ids = Vec::with_capacity(nodes.len());

        {
            let mut table = write_txn.open_table(NODES_TABLE)?;
            for node in nodes {
                let node_data = bincode::encode_to_vec(node, config::standard())?;
                table.insert(node.id.as_str(), node_data.as_slice())?;
                ids.push(node.id.clone());
            }
        }

        write_txn.commit()?;
        Ok(ids)
    }

    /// Get a node by ID, or `Ok(None)` if no node is stored under `id`.
    ///
    /// # Errors
    ///
    /// Fails on storage errors or if the stored bytes do not decode as a `Node`.
    pub fn get_node(&self, id: &str) -> Result<Option<Node>> {
        let read_txn = self.db.begin_read()?;
        let table = read_txn.open_table(NODES_TABLE)?;

        let Some(node_data) = table.get(id)? else {
            return Ok(None);
        };

        let (node, _): (Node, usize) =
            bincode::decode_from_slice(node_data.value(), config::standard())?;
        Ok(Some(node))
    }

    /// Delete a node by ID. Returns `true` if a node was removed.
    ///
    /// Only the node record is removed; edges and hyperedges referencing it
    /// are left untouched.
    pub fn delete_node(&self, id: &str) -> Result<bool> {
        let write_txn = self.db.begin_write()?;
        let deleted;
        {
            let mut table = write_txn.open_table(NODES_TABLE)?;
            let removed = table.remove(id)?;
            deleted = removed.is_some();
        }
        write_txn.commit()?;
        Ok(deleted)
    }

    /// Get all node IDs, in the table's key order.
    pub fn all_node_ids(&self) -> Result<Vec<NodeId>> {
        let read_txn = self.db.begin_read()?;
        let table = read_txn.open_table(NODES_TABLE)?;

        let ids = table
            .iter()?
            .map(|entry| entry.map(|(key, _)| key.value().to_string()))
            .collect::<std::result::Result<Vec<_>, _>>()?;
        Ok(ids)
    }

    // Edge operations

    /// Insert an edge under its own ID in a single committed write transaction.
    ///
    /// Returns the edge's ID.
    pub fn insert_edge(&self, edge: &Edge) -> Result<EdgeId> {
        let write_txn = self.db.begin_write()?;
        {
            let mut table = write_txn.open_table(EDGES_TABLE)?;
            let edge_data = bincode::encode_to_vec(edge, config::standard())?;
            table.insert(edge.id.as_str(), edge_data.as_slice())?;
        }
        write_txn.commit()?;

        Ok(edge.id.clone())
    }

    /// Insert multiple edges in one write transaction.
    ///
    /// If any edge fails to encode or insert, the function returns the error
    /// before committing. Returns the IDs in input order.
    pub fn insert_edges_batch(&self, edges: &[Edge]) -> Result<Vec<EdgeId>> {
        let write_txn = self.db.begin_write()?;
        let mut ids = Vec::with_capacity(edges.len());

        {
            let mut table = write_txn.open_table(EDGES_TABLE)?;
            for edge in edges {
                let edge_data = bincode::encode_to_vec(edge, config::standard())?;
                table.insert(edge.id.as_str(), edge_data.as_slice())?;
                ids.push(edge.id.clone());
            }
        }

        write_txn.commit()?;
        Ok(ids)
    }

    /// Get an edge by ID, or `Ok(None)` if no edge is stored under `id`.
    ///
    /// # Errors
    ///
    /// Fails on storage errors or if the stored bytes do not decode as an `Edge`.
    pub fn get_edge(&self, id: &str) -> Result<Option<Edge>> {
        let read_txn = self.db.begin_read()?;
        let table = read_txn.open_table(EDGES_TABLE)?;

        let Some(edge_data) = table.get(id)? else {
            return Ok(None);
        };

        let (edge, _): (Edge, usize) =
            bincode::decode_from_slice(edge_data.value(), config::standard())?;
        Ok(Some(edge))
    }

    /// Delete an edge by ID. Returns `true` if an edge was removed.
    pub fn delete_edge(&self, id: &str) -> Result<bool> {
        let write_txn = self.db.begin_write()?;
        let deleted;
        {
            let mut table = write_txn.open_table(EDGES_TABLE)?;
            let removed = table.remove(id)?;
            deleted = removed.is_some();
        }
        write_txn.commit()?;
        Ok(deleted)
    }

    /// Delete several edges in one write transaction.
    ///
    /// IDs that are not present are skipped. Returns the number of edges
    /// actually removed.
    pub fn delete_edges_batch(&self, ids: &[impl AsRef<str>]) -> Result<usize> {
        let write_txn = self.db.begin_write()?;
        let mut deleted = 0;
        {
            let mut table = write_txn.open_table(EDGES_TABLE)?;
            for id in ids {
                if table.remove(id.as_ref())?.is_some() {
                    deleted += 1;
                }
            }
        }

        write_txn.commit()?;
        Ok(deleted)
    }

    /// Get all edge IDs, in the table's key order.
    pub fn all_edge_ids(&self) -> Result<Vec<EdgeId>> {
        let read_txn = self.db.begin_read()?;
        let table = read_txn.open_table(EDGES_TABLE)?;

        let ids = table
            .iter()?
            .map(|entry| entry.map(|(key, _)| key.value().to_string()))
            .collect::<std::result::Result<Vec<_>, _>>()?;
        Ok(ids)
    }

    // Hyperedge operations

    /// Insert a hyperedge under its own ID in a single committed write
    /// transaction. Returns the hyperedge's ID.
    pub fn insert_hyperedge(&self, hyperedge: &Hyperedge) -> Result<HyperedgeId> {
        let write_txn = self.db.begin_write()?;
        {
            let mut table = write_txn.open_table(HYPEREDGES_TABLE)?;
            let hyperedge_data = bincode::encode_to_vec(hyperedge, config::standard())?;
            table.insert(hyperedge.id.as_str(), hyperedge_data.as_slice())?;
        }
        write_txn.commit()?;

        Ok(hyperedge.id.clone())
    }

    /// Insert multiple hyperedges in one write transaction.
    ///
    /// If any hyperedge fails to encode or insert, the function returns the
    /// error before committing. Returns the IDs in input order.
    pub fn insert_hyperedges_batch(&self, hyperedges: &[Hyperedge]) -> Result<Vec<HyperedgeId>> {
        let write_txn = self.db.begin_write()?;
        let mut ids = Vec::with_capacity(hyperedges.len());

        {
            let mut table = write_txn.open_table(HYPEREDGES_TABLE)?;
            for hyperedge in hyperedges {
                let hyperedge_data = bincode::encode_to_vec(hyperedge, config::standard())?;
                table.insert(hyperedge.id.as_str(), hyperedge_data.as_slice())?;
                ids.push(hyperedge.id.clone());
            }
        }

        write_txn.commit()?;
        Ok(ids)
    }

    /// Get a hyperedge by ID, or `Ok(None)` if none is stored under `id`.
    ///
    /// # Errors
    ///
    /// Fails on storage errors or if the stored bytes do not decode as a
    /// `Hyperedge`.
    pub fn get_hyperedge(&self, id: &str) -> Result<Option<Hyperedge>> {
        let read_txn = self.db.begin_read()?;
        let table = read_txn.open_table(HYPEREDGES_TABLE)?;

        let Some(hyperedge_data) = table.get(id)? else {
            return Ok(None);
        };

        let (hyperedge, _): (Hyperedge, usize) =
            bincode::decode_from_slice(hyperedge_data.value(), config::standard())?;
        Ok(Some(hyperedge))
    }

    /// Delete a hyperedge by ID. Returns `true` if a hyperedge was removed.
    pub fn delete_hyperedge(&self, id: &str) -> Result<bool> {
        let write_txn = self.db.begin_write()?;
        let deleted;
        {
            let mut table = write_txn.open_table(HYPEREDGES_TABLE)?;
            let removed = table.remove(id)?;
            deleted = removed.is_some();
        }
        write_txn.commit()?;
        Ok(deleted)
    }

    /// Get all hyperedge IDs, in the table's key order.
    pub fn all_hyperedge_ids(&self) -> Result<Vec<HyperedgeId>> {
        let read_txn = self.db.begin_read()?;
        let table = read_txn.open_table(HYPEREDGES_TABLE)?;

        let ids = table
            .iter()?
            .map(|entry| entry.map(|(key, _)| key.value().to_string()))
            .collect::<std::result::Result<Vec<_>, _>>()?;
        Ok(ids)
    }

    // Metadata operations

    /// Set (insert or overwrite) a metadata key in its own write transaction.
    pub fn set_metadata(&self, key: &str, value: &str) -> Result<()> {
        let write_txn = self.db.begin_write()?;
        {
            let mut table = write_txn.open_table(METADATA_TABLE)?;
            table.insert(key, value)?;
        }
        write_txn.commit()?;
        Ok(())
    }

    /// Get a metadata value, or `Ok(None)` if the key is absent.
    pub fn get_metadata(&self, key: &str) -> Result<Option<String>> {
        let read_txn = self.db.begin_read()?;
        let table = read_txn.open_table(METADATA_TABLE)?;

        let value = table.get(key)?.map(|v| v.value().to_string());
        Ok(value)
    }

    // Statistics
    //
    // Counts walk the whole table (O(n)). Per-entry storage errors are
    // propagated instead of being counted as entries.

    /// Get the number of nodes.
    pub fn node_count(&self) -> Result<usize> {
        let read_txn = self.db.begin_read()?;
        let table = read_txn.open_table(NODES_TABLE)?;
        let mut count = 0;
        for entry in table.iter()? {
            entry?;
            count += 1;
        }
        Ok(count)
    }

    /// Get the number of edges.
    pub fn edge_count(&self) -> Result<usize> {
        let read_txn = self.db.begin_read()?;
        let table = read_txn.open_table(EDGES_TABLE)?;
        let mut count = 0;
        for entry in table.iter()? {
            entry?;
            count += 1;
        }
        Ok(count)
    }

    /// Get the number of hyperedges.
    pub fn hyperedge_count(&self) -> Result<usize> {
        let read_txn = self.db.begin_read()?;
        let table = read_txn.open_table(HYPEREDGES_TABLE)?;
        let mut count = 0;
        for entry in table.iter()? {
            entry?;
            count += 1;
        }
        Ok(count)
    }
}
418
// Gated on the `storage` feature too: without it, `GraphStorage` and the
// other items under test are compiled out and this module would not build.
#[cfg(all(test, feature = "storage"))]
mod tests {
    use super::*;
    use crate::edge::EdgeBuilder;
    use crate::hyperedge::HyperedgeBuilder;
    use crate::node::NodeBuilder;
    use tempfile::tempdir;

    #[test]
    fn test_node_storage() -> Result<()> {
        let dir = tempdir()?;
        let storage = GraphStorage::new(dir.path().join("test.db"))?;

        let node = NodeBuilder::new()
            .label("Person")
            .property("name", "Alice")
            .build();

        let id = storage.insert_node(&node)?;
        assert_eq!(id, node.id);

        let Some(retrieved) = storage.get_node(&id)? else {
            panic!("inserted node should be retrievable");
        };
        assert_eq!(retrieved.id, node.id);
        assert!(retrieved.has_label("Person"));

        Ok(())
    }

    #[test]
    fn test_delete_node() -> Result<()> {
        let dir = tempdir()?;
        let storage = GraphStorage::new(dir.path().join("test.db"))?;

        let node = NodeBuilder::new().label("Person").build();
        storage.insert_node(&node)?;

        assert!(storage.delete_node(&node.id)?);
        assert!(!storage.delete_node(&node.id)?, "second delete finds nothing");
        assert!(storage.get_node(&node.id)?.is_none());
        assert_eq!(storage.node_count()?, 0);

        Ok(())
    }

    #[test]
    fn test_edge_storage() -> Result<()> {
        let dir = tempdir()?;
        let storage = GraphStorage::new(dir.path().join("test.db"))?;

        let edge = EdgeBuilder::new("n1".to_string(), "n2".to_string(), "KNOWS")
            .property("since", 2020i64)
            .build();

        let id = storage.insert_edge(&edge)?;
        assert_eq!(id, edge.id);

        let retrieved = storage.get_edge(&id)?;
        assert!(retrieved.is_some());

        Ok(())
    }

    #[test]
    fn test_delete_edges_batch() -> Result<()> {
        let dir = tempdir()?;
        let storage = GraphStorage::new(dir.path().join("test.db"))?;

        let edges = vec![
            EdgeBuilder::new("n1".to_string(), "n2".to_string(), "KNOWS").build(),
            EdgeBuilder::new("n2".to_string(), "n3".to_string(), "KNOWS").build(),
        ];
        storage.insert_edges_batch(&edges)?;
        assert_eq!(storage.edge_count()?, 2);

        // Unknown IDs are skipped rather than counted.
        let mut ids: Vec<String> = edges.iter().map(|e| e.id.clone()).collect();
        ids.push("missing-edge".to_string());
        assert_eq!(storage.delete_edges_batch(&ids)?, 2);
        assert_eq!(storage.edge_count()?, 0);

        Ok(())
    }

    #[test]
    fn test_batch_insert() -> Result<()> {
        let dir = tempdir()?;
        let storage = GraphStorage::new(dir.path().join("test.db"))?;

        let nodes = vec![
            NodeBuilder::new().label("Person").build(),
            NodeBuilder::new().label("Person").build(),
        ];

        let ids = storage.insert_nodes_batch(&nodes)?;
        assert_eq!(ids.len(), 2);
        assert_eq!(storage.node_count()?, 2);
        assert_eq!(storage.all_node_ids()?.len(), 2);

        Ok(())
    }

    #[test]
    fn test_hyperedge_storage() -> Result<()> {
        let dir = tempdir()?;
        let storage = GraphStorage::new(dir.path().join("test.db"))?;

        let hyperedge = HyperedgeBuilder::new(
            vec!["n1".to_string(), "n2".to_string(), "n3".to_string()],
            "MEETING",
        )
        .description("Team meeting")
        .build();

        let id = storage.insert_hyperedge(&hyperedge)?;
        assert_eq!(id, hyperedge.id);

        let retrieved = storage.get_hyperedge(&id)?;
        assert!(retrieved.is_some());

        Ok(())
    }

    #[test]
    fn test_metadata() -> Result<()> {
        let dir = tempdir()?;
        let storage = GraphStorage::new(dir.path().join("test.db"))?;

        assert_eq!(storage.get_metadata("version")?, None);
        storage.set_metadata("version", "1")?;
        assert_eq!(storage.get_metadata("version")?.as_deref(), Some("1"));
        storage.set_metadata("version", "2")?;
        assert_eq!(storage.get_metadata("version")?.as_deref(), Some("2"));

        Ok(())
    }

    #[test]
    fn test_instances_share_database() -> Result<()> {
        let dir = tempdir()?;
        let path = dir.path().join("shared.db");

        // Opening the same file twice must reuse the pooled connection.
        let first = GraphStorage::new(&path)?;
        let second = GraphStorage::new(&path)?;

        let node = NodeBuilder::new().label("Person").build();
        first.insert_node(&node)?;
        assert!(second.get_node(&node.id)?.is_some());

        Ok(())
    }
}