ryna/structures/graph.rs

1use std::{collections::{HashMap, HashSet}, hash::Hash};
2
3use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, Default, Serialize, Deserialize)]
6pub struct DirectedGraph<Vertex: Eq + Hash + Clone, Edge: Eq + Hash> {
7    vertices: HashMap<Vertex, usize>,
8    idxs: HashMap<usize, Vertex>,
9
10    connections: HashMap<usize, HashSet<(usize, Edge)>>
11}
12
13impl<Vertex: Eq + Hash + Clone, Edge: Eq + Hash> DirectedGraph<Vertex, Edge> {
14    pub fn new() -> Self {
15        DirectedGraph { 
16            vertices: HashMap::new(), 
17            idxs: HashMap::new(), 
18            connections: HashMap::new() 
19        }
20    }
21
22    pub fn contains(&self, v: &Vertex) -> bool {
23        self.vertices.contains_key(v)
24    }
25
26    fn vertex_by_idx(&self, idx: usize) -> &Vertex {
27        return self.idxs.get(&idx).unwrap();
28    }
29
30    fn vertex_idx_unchecked(&self, v: &Vertex) -> usize {
31        return *self.vertices.get(v).unwrap();
32    }
33
34    fn vertex_idx(&mut self, v: Vertex) -> usize {
35        if let Some(idx) = self.vertices.get(&v) {
36            *idx
37
38        } else {
39            let idx = self.vertices.len();
40            self.vertices.insert(v.clone(), idx);
41            self.idxs.insert(idx, v);
42
43            idx
44        }
45    }
46
47    pub fn connections(&self, v: &Vertex) -> Result<HashSet<(&Vertex, &Edge)>, String> {
48        if self.contains(v) {
49            let idx = self.vertex_idx_unchecked(v);
50
51            if let Some(conn) = self.connections.get(&idx) {
52                return Ok(conn.iter().map(|(k, v)| (self.vertex_by_idx(*k), v)).collect());
53            }
54        }
55        
56        Err("Vertex is not in the graph".into())
57    }
58
59    pub fn connect(&mut self, from: Vertex, to: Vertex, edge: Edge) {
60        let idx_from = self.vertex_idx(from);
61        let idx_to = self.vertex_idx(to);
62
63        self.connections.entry(idx_from).or_default().insert((idx_to, edge));
64    }
65
66    pub fn dfs<F: FnMut(&Vertex)>(&self, start: &Vertex, mut op: F) {
67        if self.vertices.contains_key(start) {
68            let mut seen: HashSet<usize> = HashSet::new();
69            let mut stack = vec!(self.vertex_idx_unchecked(start));
70    
71            while let Some(elem) = stack.pop() {
72                
73                seen.insert(elem);
74    
75                op(self.vertex_by_idx(elem));
76    
77                if self.connections.contains_key(&elem) {
78                    for (i, _) in self.connections.get(&elem).unwrap() {
79                        if !seen.contains(i) {
80                            stack.push(*i);
81                        }
82                    }
83                }
84            }
85        }
86    }
87
88    pub fn bfs<F: FnMut(&Vertex)>(&self, start: &Vertex, mut op: F) {
89        if self.vertices.contains_key(start) {
90            let mut seen: HashSet<usize> = HashSet::new();
91            let mut layer = vec!(self.vertex_idx_unchecked(start));
92
93            while !layer.is_empty() {
94                let mut new_layer = vec!();
95                
96                for elem in &layer {
97                    seen.insert(*elem);
98                    op(self.vertex_by_idx(*elem));
99
100                    if self.connections.contains_key(elem) {
101                        for (i, _) in self.connections.get(elem).unwrap() {
102                            if !seen.contains(i) {
103                                new_layer.push(*i);
104                            }
105                        }
106                    }
107                }
108
109                layer = new_layer;
110            }
111        }
112    }
113
114    pub fn to_dot<F: Fn(&Vertex) -> String>(&self, repr: F) -> String {
115        let mut lines = vec!();
116
117        for (from, dest) in &self.connections {
118            let rf = repr(self.vertex_by_idx(*from));
119            for (to, _) in dest {
120                let rt = repr(self.vertex_by_idx(*to));
121                lines.push(format!("\"{}\" -> \"{}\"", rf, rt));
122            }
123        }
124
125        format!("digraph G {{\n{}\n}}", lines.join("\n"))
126    }
127}
128
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::DirectedGraph;

    /// Builds the fixture shared by every test below. Its edges are
    /// A -> {B:1, C:2, D:3}, B -> {C:2, D:2, C:3} and D -> {D:4}.
    /// C has no outgoing edges.
    fn sample_graph() -> DirectedGraph<&'static str, i32> {
        let edge_list = [
            ("A", "B", 1),
            ("A", "C", 2),
            ("A", "D", 3),
            ("B", "C", 2),
            ("B", "D", 2),
            ("B", "C", 3),
            ("D", "D", 4),
        ];

        let mut graph = DirectedGraph::new();
        for &(from, to, label) in edge_list.iter() {
            graph.connect(from, to, label);
        }
        graph
    }

    #[test]
    fn graph_construction() {
        let graph = sample_graph();

        for name in &["A", "B", "C", "D"] {
            assert!(graph.contains(name));
        }
        assert!(!graph.contains(&"E"));

        let from_a: HashSet<(&&str, &i32)> =
            [(&"B", &1), (&"C", &2), (&"D", &3)].iter().copied().collect();
        assert_eq!(graph.connections(&"A"), Ok(from_a));

        let from_b: HashSet<(&&str, &i32)> =
            [(&"C", &2), (&"D", &2), (&"C", &3)].iter().copied().collect();
        assert_eq!(graph.connections(&"B"), Ok(from_b));

        // C only ever appears as a target, so it has no outgoing edge set.
        assert!(graph.connections(&"C").is_err());

        let from_d: HashSet<(&&str, &i32)> = [(&"D", &4)].iter().copied().collect();
        assert_eq!(graph.connections(&"D"), Ok(from_d));
    }

    #[test]
    fn traversals_reach_every_vertex() {
        let graph = sample_graph();

        let mut reached_by_dfs = HashSet::new();
        graph.dfs(&"A", |vertex| {
            reached_by_dfs.insert(*vertex);
        });
        assert_eq!(reached_by_dfs.len(), 4);

        let mut reached_by_bfs = HashSet::new();
        graph.bfs(&"A", |vertex| {
            reached_by_bfs.insert(*vertex);
        });
        assert_eq!(reached_by_bfs.len(), 4);
    }

    #[test]
    fn dot_output_length() {
        let dot = sample_graph().to_dot(|vertex| vertex.to_string());
        assert_eq!(dot.len(), 90);
    }
}