// xgraph/graph/algorithms/search.rs

//! Graph search algorithms and operations
//!
//! This module provides functionality for performing various search operations on graphs,
//! including path finding, cycle detection, and node existence checks. It works with both
//! directed and undirected graphs, offering robust implementations of Depth-First Search
//! (DFS) and Breadth-First Search (BFS).
//!
//! # Features
//! - Path existence checking using DFS
//! - Shortest (fewest-edges) path finding using BFS
//! - Cycle detection tailored for directed and undirected graphs
//! - Node existence verification
//!
//! # Examples
//!
//! Basic usage with a directed graph:
//! ```rust
//! use xgraph::graph::graph::Graph;
//! use xgraph::graph::algorithms::search::Search;
//!
//! let mut graph = Graph::<u32, &str, &str>::new(true);
//! let a = graph.add_node("A");
//! let b = graph.add_node("B");
//! graph.add_edge(a, b, 1, "link").unwrap();
//!
//! assert!(graph.has_path(a, b).unwrap());
//! assert_eq!(graph.bfs_path(a, b).unwrap(), Some(vec![a, b]));
//! ```

30use crate::graph::graph::Graph;
31use std::collections::{HashMap, HashSet, VecDeque};
32use std::fmt::Debug;
33use std::hash::Hash;
34
/// Error type for search operation failures.
///
/// Represents errors that may occur during graph search operations.
///
/// The type derives `Clone`, `PartialEq` and `Eq`, so errors can be compared directly,
/// e.g. `assert_eq!(err, SearchError::InvalidNode(3))`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SearchError {
    /// Indicates that a node referenced in the search does not exist in the graph.
    ///
    /// The payload is the offending node ID.
    InvalidNode(usize),
}

impl std::fmt::Display for SearchError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            SearchError::InvalidNode(id) => {
                write!(f, "Invalid node reference: node ID {} not found", id)
            }
        }
    }
}

impl std::error::Error for SearchError {}

56/// Trait for graph search operations.
57///
58/// Provides fundamental algorithms for graph traversal and analysis with error handling.
59///
60/// # Type Parameters
61/// - `W`: Edge weight type, must be copyable, have a default value, and support partial equality.
62/// - `N`: Node data type, must be clonable, equatable, hashable, and debuggable.
63/// - `E`: Edge data type, must be clonable and debuggable.
64///
65/// # Examples
66/// ```rust
67/// use xgraph::graph::graph::Graph;
68/// use xgraph::graph::algorithms::search::Search;
69///
70/// let mut graph = Graph::<u32, u32, ()>::new(true);
71/// let a = graph.add_node(0);
72/// let b = graph.add_node(1);
73/// graph.add_edge(a, b, 1, ()).unwrap();
74///
75/// assert!(graph.has_path(a, b).unwrap());
76/// assert_eq!(graph.bfs_path(a, b).unwrap(), Some(vec![a, b]));
77/// ```
78pub trait Search<W, N, E>
79where
80    W: Copy + Default + PartialEq,
81    N: Clone + Eq + Hash + Debug,
82    E: Clone + Default + Debug,
83{
84    /// Checks if a path exists between two nodes using DFS.
85    ///
86    /// Traverses the graph depth-first to determine if there is a valid path from `start` to `target`.
87    ///
88    /// # Arguments
89    /// - `start`: The ID of the starting node.
90    /// - `target`: The ID of the target node.
91    ///
92    /// # Returns
93    /// - `Ok(bool)`: `true` if a path exists, `false` otherwise.
94    /// - `Err(SearchError)`: If either `start` or `target` is not a valid node.
95    ///
96    /// # Examples
97    /// ```rust
98    /// use xgraph::graph::graph::Graph;
99    /// use xgraph::graph::algorithms::search::Search;
100    ///
101    /// let mut graph = Graph::<u32, u32, ()>::new(true);
102    /// let a = graph.add_node(0);
103    /// let b = graph.add_node(1);
104    /// let c = graph.add_node(2);
105    /// graph.add_edge(a, b, 1, ()).unwrap();
106    /// graph.add_edge(b, c, 1, ()).unwrap();
107    ///
108    /// assert!(graph.has_path(a, c).unwrap());
109    /// assert!(!graph.has_path(c, a).unwrap()); // Directed graph
110    /// ```
111    fn has_path(&self, start: usize, target: usize) -> Result<bool, SearchError>;
112
113    /// Finds the shortest path between nodes using BFS.
114    ///
115    /// Uses breadth-first search to find the shortest unweighted path from `start` to `target`.
116    ///
117    /// # Arguments
118    /// - `start`: The ID of the starting node.
119    /// - `target`: The ID of the target node.
120    ///
121    /// # Returns
122    /// - `Ok(Option<Vec<usize>>)`: A vector of node IDs representing the shortest path, or `None` if no path exists.
123    /// - `Err(SearchError)`: If either `start` or `target` is not a valid node.
124    ///
125    /// # Examples
126    /// ```rust
127    /// use xgraph::graph::graph::Graph;
128    /// use xgraph::graph::algorithms::search::Search;
129    ///
130    /// let mut graph = Graph::<u32, (), ()>::new(true);
131    /// let nodes = (0..4).map(|i| graph.add_node(())).collect::<Vec<_>>();
132    /// graph.add_edge(nodes[0], nodes[1], 1, ()).unwrap();
133    /// graph.add_edge(nodes[1], nodes[2], 1, ()).unwrap();
134    /// graph.add_edge(nodes[0], nodes[3], 1, ()).unwrap();
135    ///
136    /// assert_eq!(
137    ///     graph.bfs_path(nodes[0], nodes[2]).unwrap(),
138    ///     Some(vec![nodes[0], nodes[1], nodes[2]])
139    /// );
140    /// ```
141    fn bfs_path(&self, start: usize, target: usize) -> Result<Option<Vec<usize>>, SearchError>;
142
143    /// Performs a recursive DFS to check for a path between nodes.
144    ///
145    /// Internal helper method for path checking, not typically called directly by users.
146    ///
147    /// # Arguments
148    /// - `current`: The current node ID in the recursion.
149    /// - `target`: The target node ID to find.
150    /// - `visited`: A mutable reference to a set tracking visited nodes.
151    ///
152    /// # Returns
153    /// - `Ok(bool)`: `true` if a path to `target` is found, `false` otherwise.
154    /// - `Err(SearchError)`: If `current` or `target` is not a valid node.
155    fn dfs(
156        &self,
157        current: usize,
158        target: usize,
159        visited: &mut HashSet<usize>,
160    ) -> Result<bool, SearchError>;
161
162    /// Checks if a node exists in the graph.
163    ///
164    /// Verifies whether a node with the given ID is present in the graph.
165    ///
166    /// # Arguments
167    /// - `node`: The ID of the node to check.
168    ///
169    /// # Returns
170    /// - `Ok(bool)`: `true` if the node exists, `false` otherwise.
171    /// - `Err(SearchError)`: Never returned in this implementation, but included for consistency.
172    ///
173    /// # Examples
174    /// ```rust
175    /// use xgraph::graph::graph::Graph;
176    /// use xgraph::graph::algorithms::search::Search;
177    ///
178    /// let mut graph = Graph::<i32, (), ()>::new(false);
179    /// let n0 = graph.add_node(());
180    /// assert!(graph.has_node(n0).unwrap());
181    /// assert!(!graph.has_node(999).unwrap());
182    /// ```
183    fn has_node(&self, node: usize) -> Result<bool, SearchError>;
184
185    /// Detects whether the graph contains any cycles.
186    ///
187    /// Uses DFS-based algorithms tailored to the graph's directionality:
188    /// - Directed graphs: Uses a recursion stack to detect back edges.
189    /// - Undirected graphs: Uses parent tracking to avoid false positives from bidirectional edges.
190    ///
191    /// # Returns
192    /// - `Ok(bool)`: `true` if a cycle is detected, `false` otherwise.
193    /// - `Err(SearchError)`: If an invalid node is encountered during traversal.
194    ///
195    /// # Examples
196    /// ```rust
197    /// use xgraph::graph::graph::Graph;
198    /// use xgraph::graph::algorithms::search::Search;
199    ///
200    /// let mut graph = Graph::<u32, (), ()>::new(true);
201    /// let nodes = (0..3).map(|i| graph.add_node(())).collect::<Vec<_>>();
202    /// graph.add_edge(nodes[0], nodes[1], 1, ()).unwrap();
203    /// graph.add_edge(nodes[1], nodes[2], 1, ()).unwrap();
204    /// graph.add_edge(nodes[2], nodes[0], 1, ()).unwrap();
205    /// assert!(graph.has_cycle().unwrap());
206    /// ```
207    fn has_cycle(&self) -> Result<bool, SearchError>;
208
209    /// Helper method for cycle detection in directed graphs.
210    ///
211    /// Uses DFS with a recursion stack to detect cycles in directed graphs.
212    ///
213    /// # Arguments
214    /// - `node`: The current node ID.
215    /// - `visited`: A mutable reference to a set of visited nodes.
216    /// - `recursion_stack`: A mutable reference to a set tracking the current recursion path.
217    ///
218    /// # Returns
219    /// - `Ok(bool)`: `true` if a cycle is detected, `false` otherwise.
220    /// - `Err(SearchError)`: If `node` is not a valid node.
221    fn has_cycle_directed(
222        &self,
223        node: usize,
224        visited: &mut HashSet<usize>,
225        recursion_stack: &mut HashSet<usize>,
226    ) -> Result<bool, SearchError>;
227
228    /// Helper method for cycle detection in undirected graphs.
229    ///
230    /// Uses DFS with parent tracking to detect cycles in undirected graphs.
231    ///
232    /// # Arguments
233    /// - `node`: The current node ID.
234    /// - `parent`: The ID of the parent node in the DFS tree, or `None` if root.
235    /// - `visited`: A mutable reference to a set of visited nodes.
236    ///
237    /// # Returns
238    /// - `Ok(bool)`: `true` if a cycle is detected, `false` otherwise.
239    /// - `Err(SearchError)`: If `node` is not a valid node.
240    fn has_cycle_undirected(
241        &self,
242        node: usize,
243        parent: Option<usize>,
244        visited: &mut HashSet<usize>,
245    ) -> Result<bool, SearchError>;
246}
247
248impl<W, N, E> Search<W, N, E> for Graph<W, N, E>
249where
250    W: Copy + Default + PartialEq,
251    N: Clone + Eq + Hash + Debug,
252    E: Clone + Default + Debug,
253{
254    fn has_path(&self, start: usize, target: usize) -> Result<bool, SearchError> {
255        if !self.nodes.contains(start) {
256            return Err(SearchError::InvalidNode(start));
257        }
258        if !self.nodes.contains(target) {
259            return Err(SearchError::InvalidNode(target));
260        }
261        let mut visited = HashSet::new();
262        self.dfs(start, target, &mut visited)
263    }
264
265    fn bfs_path(&self, start: usize, target: usize) -> Result<Option<Vec<usize>>, SearchError> {
266        if !self.nodes.contains(start) {
267            return Err(SearchError::InvalidNode(start));
268        }
269        if !self.nodes.contains(target) {
270            return Err(SearchError::InvalidNode(target));
271        }
272
273        let mut visited = HashSet::new();
274        let mut queue = VecDeque::new();
275        let mut parent = HashMap::new();
276
277        queue.push_back(start);
278        visited.insert(start);
279
280        while let Some(current) = queue.pop_front() {
281            if current == target {
282                let mut path = vec![current];
283                let mut node = current;
284                while let Some(&p) = parent.get(&node) {
285                    path.push(p);
286                    node = p;
287                }
288                path.reverse();
289                return Ok(Some(path));
290            }
291
292            for &(neighbor, _) in &self.nodes[current].neighbors {
293                if !visited.contains(&neighbor) {
294                    if !self.nodes.contains(neighbor) {
295                        return Err(SearchError::InvalidNode(neighbor));
296                    }
297                    visited.insert(neighbor);
298                    parent.insert(neighbor, current);
299                    queue.push_back(neighbor);
300                }
301            }
302        }
303        Ok(None)
304    }
305
306    fn dfs(
307        &self,
308        current: usize,
309        target: usize,
310        visited: &mut HashSet<usize>,
311    ) -> Result<bool, SearchError> {
312        if !self.nodes.contains(current) {
313            return Err(SearchError::InvalidNode(current));
314        }
315        if !self.nodes.contains(target) {
316            return Err(SearchError::InvalidNode(target));
317        }
318
319        if current == target {
320            return Ok(true);
321        }
322
323        visited.insert(current);
324
325        for &(neighbor, _) in &self.nodes[current].neighbors {
326            if !visited.contains(&neighbor) {
327                if !self.nodes.contains(neighbor) {
328                    return Err(SearchError::InvalidNode(neighbor));
329                }
330                if self.dfs(neighbor, target, visited)? {
331                    return Ok(true);
332                }
333            }
334        }
335        Ok(false)
336    }
337
338    fn has_node(&self, node: usize) -> Result<bool, SearchError> {
339        Ok(self.nodes.contains(node))
340    }
341
342    fn has_cycle(&self) -> Result<bool, SearchError> {
343        if self.directed {
344            let mut visited = HashSet::new();
345            let mut recursion_stack = HashSet::new();
346
347            for (node_id, _) in self.nodes.iter() {
348                if !visited.contains(&node_id)
349                    && self.has_cycle_directed(node_id, &mut visited, &mut recursion_stack)?
350                {
351                    return Ok(true);
352                }
353            }
354            Ok(false)
355        } else {
356            let mut visited = HashSet::new();
357
358            for (node_id, _) in self.nodes.iter() {
359                if !visited.contains(&node_id)
360                    && self.has_cycle_undirected(node_id, None, &mut visited)?
361                {
362                    return Ok(true);
363                }
364            }
365            Ok(false)
366        }
367    }
368
369    fn has_cycle_directed(
370        &self,
371        node: usize,
372        visited: &mut HashSet<usize>,
373        recursion_stack: &mut HashSet<usize>,
374    ) -> Result<bool, SearchError> {
375        if !self.nodes.contains(node) {
376            return Err(SearchError::InvalidNode(node));
377        }
378
379        if recursion_stack.contains(&node) {
380            return Ok(true);
381        }
382
383        if visited.contains(&node) {
384            return Ok(false);
385        }
386
387        visited.insert(node);
388        recursion_stack.insert(node);
389
390        if let Some(neighbors) = self.nodes.get(node).map(|n| &n.neighbors) {
391            for &(neighbor, _) in neighbors {
392                if !self.nodes.contains(neighbor) {
393                    return Err(SearchError::InvalidNode(neighbor));
394                }
395                if self.has_cycle_directed(neighbor, visited, recursion_stack)? {
396                    return Ok(true);
397                }
398            }
399        }
400
401        recursion_stack.remove(&node);
402        Ok(false)
403    }
404
405    fn has_cycle_undirected(
406        &self,
407        node: usize,
408        parent: Option<usize>,
409        visited: &mut HashSet<usize>,
410    ) -> Result<bool, SearchError> {
411        if !self.nodes.contains(node) {
412            return Err(SearchError::InvalidNode(node));
413        }
414
415        if visited.contains(&node) {
416            return Ok(true);
417        }
418
419        visited.insert(node);
420
421        if let Some(neighbors) = self.nodes.get(node).map(|n| &n.neighbors) {
422            for &(neighbor, _) in neighbors {
423                if Some(neighbor) == parent {
424                    continue;
425                }
426                if !self.nodes.contains(neighbor) {
427                    return Err(SearchError::InvalidNode(neighbor));
428                }
429                if self.has_cycle_undirected(neighbor, Some(node), visited)? {
430                    return Ok(true);
431                }
432            }
433        }
434
435        Ok(false)
436    }
437}
438
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_has_path() {
        let mut graph = Graph::<u32, (), ()>::new(false);
        let n0 = graph.add_node(());
        let n1 = graph.add_node(());
        graph.add_edge(n0, n1, 1, ()).unwrap();

        assert!(graph.has_path(n0, n1).unwrap());
        assert!(graph.has_path(n1, n0).unwrap());
        assert!(matches!(
            graph.has_path(999, n0),
            Err(SearchError::InvalidNode(999))
        ));
    }

    #[test]
    fn test_has_path_directed_is_one_way() {
        let mut graph = Graph::<u32, (), ()>::new(true);
        let n0 = graph.add_node(());
        let n1 = graph.add_node(());
        graph.add_edge(n0, n1, 1, ()).unwrap();

        assert!(graph.has_path(n0, n1).unwrap());
        assert!(!graph.has_path(n1, n0).unwrap());
        // Every valid node trivially reaches itself.
        assert!(graph.has_path(n1, n1).unwrap());
        assert!(matches!(
            graph.has_path(n0, 999),
            Err(SearchError::InvalidNode(999))
        ));
    }

    #[test]
    fn test_bfs_path() {
        let mut graph = Graph::<u32, (), ()>::new(true);
        let n0 = graph.add_node(());
        let n1 = graph.add_node(());
        let n2 = graph.add_node(());
        graph.add_edge(n0, n1, 1, ()).unwrap();
        graph.add_edge(n1, n2, 1, ()).unwrap();

        assert_eq!(graph.bfs_path(n0, n2).unwrap(), Some(vec![n0, n1, n2]));
        assert!(matches!(
            graph.bfs_path(999, n0),
            Err(SearchError::InvalidNode(999))
        ));
    }

    #[test]
    fn test_bfs_path_unreachable_and_trivial() {
        let mut graph = Graph::<u32, (), ()>::new(true);
        let n0 = graph.add_node(());
        let n1 = graph.add_node(());
        let n2 = graph.add_node(());
        graph.add_edge(n0, n1, 1, ()).unwrap();

        // Against edge direction, and to an isolated node: no path.
        assert_eq!(graph.bfs_path(n1, n0).unwrap(), None);
        assert_eq!(graph.bfs_path(n0, n2).unwrap(), None);
        // The path from a node to itself is just that node.
        assert_eq!(graph.bfs_path(n0, n0).unwrap(), Some(vec![n0]));
    }

    #[test]
    fn test_bfs_path_prefers_fewest_edges() {
        let mut graph = Graph::<u32, (), ()>::new(true);
        let n0 = graph.add_node(());
        let n1 = graph.add_node(());
        let n2 = graph.add_node(());
        let n3 = graph.add_node(());
        graph.add_edge(n0, n1, 1, ()).unwrap();
        graph.add_edge(n1, n2, 1, ()).unwrap();
        graph.add_edge(n2, n3, 1, ()).unwrap();
        graph.add_edge(n0, n3, 1, ()).unwrap();

        assert_eq!(graph.bfs_path(n0, n3).unwrap(), Some(vec![n0, n3]));
    }

    #[test]
    fn test_dfs() {
        let mut graph = Graph::<u32, (), ()>::new(false);
        let n0 = graph.add_node(());
        let n1 = graph.add_node(());
        let n2 = graph.add_node(());
        graph.add_edge(n0, n1, 1, ()).unwrap();
        graph.add_edge(n1, n2, 1, ()).unwrap();

        let mut visited = HashSet::new();
        assert!(graph.dfs(n0, n2, &mut visited).unwrap());
        assert!(matches!(
            graph.dfs(999, n0, &mut visited),
            Err(SearchError::InvalidNode(999))
        ));
    }

    #[test]
    fn test_invalid_nodes() {
        let graph = Graph::<u32, (), ()>::new(false);
        assert!(!graph.has_node(0).unwrap());
        assert!(matches!(
            graph.has_path(0, 1),
            Err(SearchError::InvalidNode(0))
        ));
        assert!(matches!(
            graph.bfs_path(0, 1),
            Err(SearchError::InvalidNode(0))
        ));
    }

    #[test]
    fn test_cycle_detection_directed() {
        let mut graph = Graph::<u32, (), ()>::new(true);
        let n0 = graph.add_node(());
        let n1 = graph.add_node(());
        let n2 = graph.add_node(());

        graph.add_edge(n0, n1, 1, ()).unwrap();
        graph.add_edge(n1, n2, 1, ()).unwrap();
        graph.add_edge(n2, n0, 1, ()).unwrap();

        assert!(graph.has_cycle().unwrap());
    }

    #[test]
    fn test_cycle_detection_undirected() {
        let mut graph = Graph::<u32, (), ()>::new(false);
        let n0 = graph.add_node(());
        let n1 = graph.add_node(());
        let n2 = graph.add_node(());

        graph.add_edge(n0, n1, 1, ()).unwrap();
        graph.add_edge(n1, n2, 1, ()).unwrap();
        graph.add_edge(n2, n0, 1, ()).unwrap();

        assert!(graph.has_cycle().unwrap());
    }

    #[test]
    fn test_cycle_in_later_component_directed() {
        // n0 is isolated; the cycle n1 -> n2 -> n3 -> n1 is not reachable from it.
        let mut graph = Graph::<u32, (), ()>::new(true);
        let _n0 = graph.add_node(());
        let n1 = graph.add_node(());
        let n2 = graph.add_node(());
        let n3 = graph.add_node(());

        graph.add_edge(n1, n2, 1, ()).unwrap();
        graph.add_edge(n2, n3, 1, ()).unwrap();
        graph.add_edge(n3, n1, 1, ()).unwrap();

        assert!(graph.has_cycle().unwrap());
    }

    #[test]
    fn test_no_cycle_directed() {
        let mut graph = Graph::<u32, (), ()>::new(true);
        let n0 = graph.add_node(());
        let n1 = graph.add_node(());
        let n2 = graph.add_node(());

        graph.add_edge(n0, n1, 1, ()).unwrap();
        graph.add_edge(n1, n2, 1, ()).unwrap();

        assert!(!graph.has_cycle().unwrap());
    }

    #[test]
    fn test_no_cycle_directed_diamond() {
        // Two routes converge on n3. This revisits n3 but is not a directed cycle.
        let mut graph = Graph::<u32, (), ()>::new(true);
        let n0 = graph.add_node(());
        let n1 = graph.add_node(());
        let n2 = graph.add_node(());
        let n3 = graph.add_node(());

        graph.add_edge(n0, n1, 1, ()).unwrap();
        graph.add_edge(n0, n2, 1, ()).unwrap();
        graph.add_edge(n1, n3, 1, ()).unwrap();
        graph.add_edge(n2, n3, 1, ()).unwrap();

        assert!(!graph.has_cycle().unwrap());
    }

    #[test]
    fn test_no_cycle_undirected() {
        let mut graph = Graph::<u32, (), ()>::new(false);
        let n0 = graph.add_node(());
        let n1 = graph.add_node(());
        let n2 = graph.add_node(());

        graph.add_edge(n0, n1, 1, ()).unwrap();
        graph.add_edge(n1, n2, 1, ()).unwrap();

        assert!(!graph.has_cycle().unwrap());
    }

    #[test]
    fn test_no_cycle_undirected_star() {
        let mut graph = Graph::<u32, (), ()>::new(false);
        let center = graph.add_node(());
        let leaves = (0..3).map(|_| graph.add_node(())).collect::<Vec<_>>();
        for &leaf in &leaves {
            graph.add_edge(center, leaf, 1, ()).unwrap();
        }

        assert!(!graph.has_cycle().unwrap());
    }

    #[test]
    fn test_empty_graph_has_no_cycle() {
        assert!(!Graph::<u32, (), ()>::new(true).has_cycle().unwrap());
        assert!(!Graph::<u32, (), ()>::new(false).has_cycle().unwrap());
    }
}