// velesdb_core/collection/core/graph_api.rs

//! Graph API methods for Collection (EPIC-015 US-001).
//!
//! Exposes Knowledge Graph operations on Collection for use by
//! Tauri plugin, REST API, and other consumers.

use crate::collection::graph::{GraphEdge, GraphSchema, TraversalConfig, TraversalResult};
use crate::collection::types::Collection;
use crate::error::{Error, Result};
use crate::index::VectorIndex;
use crate::point::{Point, SearchResult};
use crate::storage::{PayloadStorage, VectorStorage};

impl Collection {
14    /// Adds an edge to the collection's knowledge graph.
15    ///
16    /// # Arguments
17    ///
18    /// * `edge` - The edge to add (id, source, target, label, properties)
19    ///
20    /// # Errors
21    ///
22    /// Returns `Error::EdgeExists` if an edge with the same ID already exists.
23    ///
24    /// # Example
25    ///
26    /// ```rust,ignore
27    /// use velesdb_core::collection::graph::GraphEdge;
28    ///
29    /// let edge = GraphEdge::new(1, 100, 200, "KNOWS")?;
30    /// collection.add_edge(edge)?;
31    /// ```
32    pub fn add_edge(&self, edge: GraphEdge) -> Result<()> {
33        self.edge_store.write().add_edge(edge)?;
34        // Bump write generation so any cached plan for this collection is
35        // invalidated on the next query (CACHE-01).
36        self.write_generation
37            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
38        Ok(())
39    }
40
41    /// Gets all edges from the collection's knowledge graph.
42    ///
43    /// Note: This iterates through all stored edges. For large graphs,
44    /// consider using `get_edges_by_label` or `get_outgoing_edges` for
45    /// more targeted queries.
46    ///
47    /// # Returns
48    ///
49    /// Vector of all edges in the graph (cloned).
50    #[must_use]
51    pub fn get_all_edges(&self) -> Vec<GraphEdge> {
52        let store = self.edge_store.read();
53        store.all_edges().into_iter().cloned().collect()
54    }
55
56    /// Gets edges filtered by label.
57    ///
58    /// # Arguments
59    ///
60    /// * `label` - The edge label (relationship type) to filter by
61    ///
62    /// # Returns
63    ///
64    /// Vector of edges with the specified label (cloned).
65    #[must_use]
66    pub fn get_edges_by_label(&self, label: &str) -> Vec<GraphEdge> {
67        self.edge_store
68            .read()
69            .get_edges_by_label(label)
70            .into_iter()
71            .cloned()
72            .collect()
73    }
74
75    /// Gets outgoing edges from a specific node.
76    ///
77    /// # Arguments
78    ///
79    /// * `node_id` - The source node ID
80    ///
81    /// # Returns
82    ///
83    /// Vector of edges originating from the specified node (cloned).
84    #[must_use]
85    pub fn get_outgoing_edges(&self, node_id: u64) -> Vec<GraphEdge> {
86        self.edge_store
87            .read()
88            .get_outgoing(node_id)
89            .into_iter()
90            .cloned()
91            .collect()
92    }
93
94    /// Gets incoming edges to a specific node.
95    ///
96    /// # Arguments
97    ///
98    /// * `node_id` - The target node ID
99    ///
100    /// # Returns
101    ///
102    /// Vector of edges pointing to the specified node (cloned).
103    #[must_use]
104    pub fn get_incoming_edges(&self, node_id: u64) -> Vec<GraphEdge> {
105        self.edge_store
106            .read()
107            .get_incoming(node_id)
108            .into_iter()
109            .cloned()
110            .collect()
111    }
112
113    /// Traverses the graph using BFS from a source node.
114    ///
115    /// # Arguments
116    ///
117    /// * `source` - Starting node ID
118    /// * `max_depth` - Maximum traversal depth
119    /// * `rel_types` - Optional filter by relationship types
120    /// * `limit` - Maximum number of results
121    ///
122    /// # Returns
123    ///
124    /// Vector of traversal results with target nodes and paths.
125    ///
126    /// # Errors
127    ///
128    /// Returns an error if traversal fails.
129    pub fn traverse_bfs(
130        &self,
131        source: u64,
132        max_depth: u32,
133        rel_types: Option<&[&str]>,
134        limit: usize,
135    ) -> Result<Vec<TraversalResult>> {
136        use std::collections::{HashSet, VecDeque};
137
138        let store = self.edge_store.read();
139        let mut visited = HashSet::new();
140        let mut queue = VecDeque::new();
141        let mut results = Vec::new();
142
143        visited.insert(source);
144        queue.push_back((source, 0u32, Vec::new()));
145
146        while let Some((node, depth, path)) = queue.pop_front() {
147            if results.len() >= limit {
148                break;
149            }
150
151            if depth >= max_depth {
152                continue;
153            }
154
155            for edge in store.get_outgoing(node) {
156                // Filter by relationship type if specified
157                if let Some(types) = rel_types {
158                    if !types.contains(&edge.label()) {
159                        continue;
160                    }
161                }
162
163                let target = edge.target();
164                if !visited.contains(&target) {
165                    visited.insert(target);
166                    let mut new_path = path.clone();
167                    new_path.push(edge.id());
168
169                    results.push(TraversalResult {
170                        target_id: target,
171                        depth: depth + 1,
172                        path: new_path.clone(),
173                    });
174
175                    if results.len() < limit {
176                        queue.push_back((target, depth + 1, new_path));
177                    }
178                }
179            }
180        }
181
182        Ok(results)
183    }
184
185    /// Traverses the graph using DFS from a source node.
186    ///
187    /// # Arguments
188    ///
189    /// * `source` - Starting node ID
190    /// * `max_depth` - Maximum traversal depth
191    /// * `rel_types` - Optional filter by relationship types
192    /// * `limit` - Maximum number of results
193    ///
194    /// # Returns
195    ///
196    /// Vector of traversal results with target nodes and paths.
197    ///
198    /// # Errors
199    ///
200    /// Returns an error if traversal fails.
201    pub fn traverse_dfs(
202        &self,
203        source: u64,
204        max_depth: u32,
205        rel_types: Option<&[&str]>,
206        limit: usize,
207    ) -> Result<Vec<TraversalResult>> {
208        use std::collections::HashSet;
209
210        let store = self.edge_store.read();
211        let mut visited = HashSet::new();
212        let mut stack = Vec::new();
213        let mut results = Vec::new();
214
215        visited.insert(source);
216        stack.push((source, 0u32, Vec::new()));
217
218        while let Some((node, depth, path)) = stack.pop() {
219            if results.len() >= limit {
220                break;
221            }
222
223            if depth >= max_depth {
224                continue;
225            }
226
227            for edge in store.get_outgoing(node) {
228                // Filter by relationship type if specified
229                if let Some(types) = rel_types {
230                    if !types.contains(&edge.label()) {
231                        continue;
232                    }
233                }
234
235                let target = edge.target();
236                if !visited.contains(&target) {
237                    visited.insert(target);
238                    let mut new_path = path.clone();
239                    new_path.push(edge.id());
240
241                    results.push(TraversalResult {
242                        target_id: target,
243                        depth: depth + 1,
244                        path: new_path.clone(),
245                    });
246
247                    if results.len() < limit {
248                        stack.push((target, depth + 1, new_path));
249                    }
250                }
251            }
252        }
253
254        Ok(results)
255    }
256
257    /// Gets the in-degree and out-degree of a node.
258    ///
259    /// # Arguments
260    ///
261    /// * `node_id` - The node ID
262    ///
263    /// # Returns
264    ///
265    /// Tuple of (`in_degree`, `out_degree`).
266    #[must_use]
267    pub fn get_node_degree(&self, node_id: u64) -> (usize, usize) {
268        let store = self.edge_store.read();
269        let in_degree = store.get_incoming(node_id).len();
270        let out_degree = store.get_outgoing(node_id).len();
271        (in_degree, out_degree)
272    }
273
274    /// Removes an edge from the graph by ID.
275    ///
276    /// # Arguments
277    ///
278    /// * `edge_id` - The edge ID to remove
279    ///
280    /// # Returns
281    ///
282    /// `true` if the edge existed and was removed, `false` if it didn't exist.
283    #[must_use]
284    pub fn remove_edge(&self, edge_id: u64) -> bool {
285        let mut store = self.edge_store.write();
286        if store.contains_edge(edge_id) {
287            store.remove_edge(edge_id);
288            // Bump only when a mutation actually occurred (CACHE-01).
289            // Releasing the write lock before the atomic bump is intentional:
290            // the bump is a best-effort cache invalidation hint, not part of
291            // the edge-store transaction.
292            drop(store);
293            self.write_generation
294                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
295            true
296        } else {
297            false
298        }
299    }
300
301    /// Returns the total number of edges in the graph.
302    #[must_use]
303    pub fn edge_count(&self) -> usize {
304        self.edge_store.read().len()
305    }
306
    // -------------------------------------------------------------------------
    // Graph schema
    // -------------------------------------------------------------------------

311    /// Returns the graph schema stored in the collection config, if any.
312    #[must_use]
313    pub fn graph_schema(&self) -> Option<GraphSchema> {
314        self.config.read().graph_schema.clone()
315    }
316
317    /// Returns `true` if this collection was created as a graph collection.
318    #[must_use]
319    pub fn is_graph(&self) -> bool {
320        self.config.read().graph_schema.is_some()
321    }
322
323    /// Returns `true` if this graph collection stores node embeddings.
324    #[must_use]
325    pub fn has_embeddings(&self) -> bool {
326        self.config.read().embedding_dimension.is_some()
327    }
328
    // -------------------------------------------------------------------------
    // Node payload (graph node properties)
    // -------------------------------------------------------------------------

333    /// Stores a JSON payload for a graph node.
334    ///
335    /// # Errors
336    ///
337    /// Returns an error if storage fails.
338    pub fn store_node_payload(&self, node_id: u64, payload: &serde_json::Value) -> Result<()> {
339        let mut storage = self.payload_storage.write();
340        storage.store(node_id, payload).map_err(Error::Io)?;
341        // Bump write generation so any cached plan for this collection is
342        // invalidated on the next query (CACHE-01).
343        self.write_generation
344            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
345        Ok(())
346    }
347
348    /// Retrieves the JSON payload for a graph node.
349    ///
350    /// # Errors
351    ///
352    /// Returns an error if retrieval fails.
353    pub fn get_node_payload(&self, node_id: u64) -> Result<Option<serde_json::Value>> {
354        self.payload_storage
355            .read()
356            .retrieve(node_id)
357            .map_err(Error::Io)
358    }
359
    // -------------------------------------------------------------------------
    // Graph traversal with TraversalConfig
    // -------------------------------------------------------------------------

364    /// BFS traversal using the core `bfs_stream` iterator.
365    #[must_use]
366    pub fn traverse_bfs_config(
367        &self,
368        source_id: u64,
369        config: &TraversalConfig,
370    ) -> Vec<TraversalResult> {
371        use crate::collection::graph::{bfs_stream, StreamingConfig};
372        let store = self.edge_store.read();
373        let streaming = StreamingConfig {
374            max_depth: config.max_depth,
375            rel_types: config.rel_types.clone(),
376            limit: Some(config.limit),
377            max_visited_size: 100_000,
378        };
379        bfs_stream(&store, source_id, streaming)
380            .filter(|result| result.depth >= config.min_depth)
381            .take(config.limit)
382            .collect()
383    }
384
385    /// DFS traversal (iterative) using `TraversalConfig`.
386    #[must_use]
387    pub fn traverse_dfs_config(
388        &self,
389        source_id: u64,
390        config: &TraversalConfig,
391    ) -> Vec<TraversalResult> {
392        use std::collections::HashSet;
393        let store = self.edge_store.read();
394        let rel_filter: HashSet<&str> = config.rel_types.iter().map(String::as_str).collect();
395
396        let mut results = Vec::new();
397        let mut visited: HashSet<u64> = HashSet::new();
398        let mut stack: Vec<(u64, u32, Vec<u64>)> = vec![(source_id, 0, Vec::new())];
399
400        while let Some((node_id, depth, path)) = stack.pop() {
401            if results.len() >= config.limit {
402                break;
403            }
404            if !visited.insert(node_id) {
405                continue;
406            }
407            if depth >= config.min_depth && depth > 0 {
408                results.push(TraversalResult::new(node_id, path.clone(), depth));
409                if results.len() >= config.limit {
410                    break;
411                }
412            }
413            if depth < config.max_depth {
414                for edge in store.get_outgoing(node_id).into_iter().rev() {
415                    if !rel_filter.is_empty() && !rel_filter.contains(edge.label()) {
416                        continue;
417                    }
418                    if visited.contains(&edge.target()) {
419                        continue;
420                    }
421                    let mut new_path = path.clone();
422                    // Use edge IDs in path, consistent with bfs_traverse/bfs_stream.
423                    new_path.push(edge.id());
424                    stack.push((edge.target(), depth + 1, new_path));
425                }
426            }
427        }
428        results
429    }
430
    // -------------------------------------------------------------------------
    // Embedding search on graph nodes
    // -------------------------------------------------------------------------

435    /// Searches for similar graph nodes by embedding vector.
436    ///
437    /// Only available if `has_embeddings()` returns `true`.
438    ///
439    /// # Errors
440    ///
441    /// Returns `Error::VectorNotAllowed` if no embeddings are configured,
442    /// or `Error::DimensionMismatch` if the query dimension is wrong.
443    pub fn search_by_embedding(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
444        let config = self.config.read();
445        let emb_dim = config
446            .embedding_dimension
447            .ok_or_else(|| Error::VectorNotAllowed(config.name.clone()))?;
448        drop(config);
449
450        if query.len() != emb_dim {
451            return Err(Error::DimensionMismatch {
452                expected: emb_dim,
453                actual: query.len(),
454            });
455        }
456
457        // Reason: we reuse the existing HNSW index (dimension == emb_dim when created
458        // via create_graph_collection_with_embeddings). For graph-without-embeddings
459        // the HNSW has dimension 0 and the guard above already rejected the call.
460        let metric = self.config.read().metric;
461        let ids = self.index.search(query, k);
462        let ids = self.merge_delta(ids, query, k, metric);
463
464        // Acquire each lock once: collect vector data, then collect payload data.
465        // This avoids holding vector_storage while locking payload_storage per item.
466        let vectors: Vec<(u64, f32, Option<Vec<f32>>)> = {
467            let vector_storage = self.vector_storage.read();
468            ids.into_iter()
469                .map(|(id, score)| {
470                    let vec = vector_storage.retrieve(id).ok().flatten();
471                    (id, score, vec)
472                })
473                .collect()
474        };
475        let results = {
476            let payload_storage = self.payload_storage.read();
477            vectors
478                .into_iter()
479                .filter_map(|(id, score, vector)| {
480                    let vector = vector?;
481                    let payload = payload_storage.retrieve(id).ok().flatten();
482                    Some(SearchResult::new(
483                        Point {
484                            id,
485                            vector,
486                            payload,
487                            sparse_vectors: None,
488                        },
489                        score,
490                    ))
491                })
492                .collect()
493        };
494        Ok(results)
495    }
496}