// ss_graph_rs/graph.rs

use std::collections::{HashMap, HashSet};
use std::hash::Hash;

/// A graph stored as an adjacency list, generic over the vertex type `T`.
///
/// Neighbours are kept in a `HashSet`, so repeated edges collapse into one and the
/// iteration order over a vertex's neighbours is unspecified.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Graph<T: Eq + Hash + Clone> {
    /// `false` makes `add_edge` store every edge in both directions.
    is_directed: bool,
    /// Maps a vertex to the set of vertices reachable from it by a single edge.
    adjacency_list: HashMap<T, HashSet<T>>,
}

impl<T> Default for Graph<T>
where
    T: Eq + Hash + Clone,
{
    /// Returns an empty undirected graph, equivalent to `Graph::new(None)`.
    fn default() -> Self {
        Self {
            is_directed: false,
            adjacency_list: HashMap::new(),
        }
    }
}

impl<T> Graph<T>
where
    T: Eq + Hash + Clone,
{
    /// Creates an empty graph.
    ///
    /// `is_directed` selects the edge semantics; `None` falls back to an undirected graph.
    pub fn new(is_directed: Option<bool>) -> Self {
        Self {
            is_directed: is_directed.unwrap_or(false),
            adjacency_list: HashMap::new(),
        }
    }

    /// Adds an edge from vertex `vector_x` to vertex `vector_y`.
    ///
    /// In an undirected graph the reverse edge `vector_y -> vector_x` is stored as well.
    /// In a directed graph only `vector_x` receives an adjacency entry, so `vector_y` has
    /// no key of its own until it is the source of some edge. Re-adding an existing edge
    /// has no effect.
    pub fn add_edge(&mut self, vector_x: T, vector_y: T) {
        if !self.is_directed {
            self.adjacency_list
                .entry(vector_y.clone())
                .or_default()
                .insert(vector_x.clone());
        }
        // The forward edge is stored last so both endpoints can be moved instead of cloned.
        self.adjacency_list
            .entry(vector_x)
            .or_default()
            .insert(vector_y);
    }

    /// Returns every simple path (no repeated vertex) from `start` to `end`.
    ///
    /// Each path lists its vertices in order, beginning with `start` and ending with `end`.
    /// If `start == end`, the only path returned is `[start]`. An empty result means `end`
    /// is not reachable from `start`. Path order follows `HashSet` iteration order and is
    /// therefore unspecified.
    ///
    /// The search is a recursive depth-first enumeration. Its recursion depth grows with
    /// path length, and its running time grows with the number of simple paths, which can
    /// be exponential in the size of the graph.
    pub fn find_all_paths(&self, start: T, end: T) -> Vec<Vec<T>> {
        // A path can never hold `usize::MAX` vertices, so this search is unbounded.
        self.collect_paths(start, &end, usize::MAX)
    }

    /// Like [`Graph::find_all_paths`], but only returns paths with at most `max_steps`
    /// vertices (counting `start`).
    ///
    /// NOTE(review): despite its name, `max_steps` bounds vertices, not edges, so a
    /// returned path has at most `max_steps - 1` edges. With edges 1-2, 2-3, 3-4, the path
    /// `[1, 2, 3, 4]` (three edges) is found for `max_steps = 4` but not for `max_steps = 3`.
    /// This is kept for backward compatibility. When `start == end`, `[start]` is returned
    /// regardless of `max_steps`.
    pub fn find_paths_with_max_steps(&self, start: T, end: T, max_steps: usize) -> Vec<Vec<T>> {
        self.collect_paths(start, &end, max_steps)
    }

    /// Seeds the search state with `start`, runs the DFS and returns every path found.
    fn collect_paths(&self, start: T, end: &T, max_vertices: usize) -> Vec<Vec<T>> {
        let mut paths = Vec::new();
        let mut visited = HashSet::new();
        visited.insert(start.clone());
        let mut path = vec![start.clone()];
        self.dfs_paths(&start, end, max_vertices, &mut visited, &mut path, &mut paths);
        paths
    }

    /// Recursive DFS step.
    ///
    /// `current` is the last vertex of `path`. `visited` holds exactly the vertices on
    /// `path`, which keeps every emitted path simple.
    fn dfs_paths(
        &self,
        current: &T,
        end: &T,
        max_vertices: usize,
        visited: &mut HashSet<T>,
        path: &mut Vec<T>,
        paths: &mut Vec<Vec<T>>,
    ) {
        if current == end {
            paths.push(path.clone());
            return;
        }
        // The end check runs first, so a path of exactly `max_vertices` vertices is still
        // recorded; only extending a path beyond that length is refused.
        if path.len() >= max_vertices {
            return;
        }
        if let Some(neighbors) = self.adjacency_list.get(current) {
            for neighbor in neighbors {
                if !visited.contains(neighbor) {
                    visited.insert(neighbor.clone());
                    path.push(neighbor.clone());
                    self.dfs_paths(neighbor, end, max_vertices, visited, path, paths);
                    // Backtrack so sibling branches may reuse this vertex.
                    path.pop();
                    visited.remove(neighbor);
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned `Vec<String>` path from string literals.
    fn owned_path(nodes: &[&str]) -> Vec<String> {
        nodes.iter().map(|&s| s.to_string()).collect()
    }

    // Tests with integer vertices.

    #[test]
    fn test_add_edge() {
        let mut graph = Graph::new(Some(false));
        graph.add_edge(1, 2);
        assert!(graph.adjacency_list.get(&1).unwrap().contains(&2));
        // An undirected edge is stored in both directions.
        assert!(graph.adjacency_list.get(&2).unwrap().contains(&1));

        let mut graph = Graph::new(Some(true));
        graph.add_edge(1, 2);
        assert!(graph.adjacency_list.get(&1).unwrap().contains(&2));
        // A directed edge does not create an adjacency entry for its target.
        assert!(graph.adjacency_list.get(&2).is_none());
    }

    #[test]
    fn test_find_all_paths() {
        let mut graph = Graph::new(Some(false));
        graph.add_edge(1, 2);
        graph.add_edge(2, 3);
        graph.add_edge(3, 4);
        let paths = graph.find_all_paths(1, 4);
        assert_eq!(paths, vec![vec![1, 2, 3, 4]]);
    }

    #[test]
    fn test_find_all_paths_multiple_routes() {
        // Undirected 4-cycle 1-2-4-3-1: two simple routes from 1 to 4.
        let mut graph = Graph::new(None);
        graph.add_edge(1, 2);
        graph.add_edge(2, 4);
        graph.add_edge(1, 3);
        graph.add_edge(3, 4);
        let mut paths = graph.find_all_paths(1, 4);
        // Path order follows HashSet iteration order, so sort before comparing.
        paths.sort();
        assert_eq!(paths, vec![vec![1, 2, 4], vec![1, 3, 4]]);
    }

    #[test]
    fn test_find_all_paths_start_equals_end() {
        let mut graph = Graph::new(None);
        graph.add_edge(1, 2);
        assert_eq!(graph.find_all_paths(1, 1), vec![vec![1]]);
        assert_eq!(graph.find_paths_with_max_steps(1, 1, 0), vec![vec![1]]);
    }

    #[test]
    fn test_find_all_paths_directed_and_unreachable() {
        let mut graph = Graph::new(Some(true));
        graph.add_edge(1, 2);
        graph.add_edge(2, 3);
        assert_eq!(graph.find_all_paths(1, 3), vec![vec![1, 2, 3]]);
        // Edges are one-way, so 1 is not reachable from 3.
        assert!(graph.find_all_paths(3, 1).is_empty());
        assert!(graph.find_all_paths(1, 99).is_empty());
    }

    #[test]
    fn test_find_paths_with_max_steps() {
        let mut graph = Graph::new(None);
        graph.add_edge(1, 2);
        graph.add_edge(2, 3);
        graph.add_edge(3, 4);
        let paths = graph.find_paths_with_max_steps(1, 4, 3);
        assert_eq!(paths, Vec::<Vec<i64>>::new());
        let paths = graph.find_paths_with_max_steps(1, 4, 4);
        assert_eq!(paths, vec![vec![1, 2, 3, 4]]);
    }

    #[test]
    fn test_find_paths_with_max_steps_prunes_longer_routes() {
        // Triangle 1-2-3: the direct edge survives a 2-vertex limit, the detour does not.
        let mut graph = Graph::new(None);
        graph.add_edge(1, 2);
        graph.add_edge(2, 3);
        graph.add_edge(1, 3);
        assert_eq!(graph.find_paths_with_max_steps(1, 3, 2), vec![vec![1, 3]]);
        let mut paths = graph.find_paths_with_max_steps(1, 3, 3);
        paths.sort();
        assert_eq!(paths, vec![vec![1, 2, 3], vec![1, 3]]);
    }

    // Tests with string vertices.

    #[test]
    fn test_add_edge_string() {
        let mut graph = Graph::new(Some(false));
        graph.add_edge("Node1".to_string(), "Node2".to_string());
        assert!(graph
            .adjacency_list
            .get(&"Node1".to_string())
            .unwrap()
            .contains(&"Node2".to_string()));
        assert!(graph
            .adjacency_list
            .get(&"Node2".to_string())
            .unwrap()
            .contains(&"Node1".to_string()));

        let mut graph = Graph::new(Some(true));
        graph.add_edge("Node1".to_string(), "Node2".to_string());
        assert!(graph
            .adjacency_list
            .get(&"Node1".to_string())
            .unwrap()
            .contains(&"Node2".to_string()));
        assert!(graph.adjacency_list.get(&"Node2".to_string()).is_none());
    }

    #[test]
    fn test_find_all_paths_string() {
        let mut graph = Graph::new(Some(false));
        graph.add_edge("Node1".to_string(), "Node2".to_string());
        graph.add_edge("Node2".to_string(), "Node3".to_string());
        graph.add_edge("Node3".to_string(), "Node4".to_string());
        let paths = graph.find_all_paths("Node1".to_string(), "Node4".to_string());
        assert_eq!(
            paths,
            vec![owned_path(&["Node1", "Node2", "Node3", "Node4"])]
        );
    }

    #[test]
    fn test_find_paths_with_max_steps_string() {
        let mut graph = Graph::new(Some(false));
        graph.add_edge("Node1".to_string(), "Node2".to_string());
        graph.add_edge("Node2".to_string(), "Node3".to_string());
        graph.add_edge("Node3".to_string(), "Node4".to_string());
        let paths = graph.find_paths_with_max_steps("Node1".to_string(), "Node4".to_string(), 3);
        assert_eq!(paths, Vec::<Vec<String>>::new());
        let paths = graph.find_paths_with_max_steps("Node1".to_string(), "Node4".to_string(), 4);
        assert_eq!(
            paths,
            vec![owned_path(&["Node1", "Node2", "Node3", "Node4"])]
        );
    }
}