1use std::{collections::{HashMap, HashSet}, hash::Hash};
2
3use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, Default, Serialize, Deserialize)]
6pub struct DirectedGraph<Vertex: Eq + Hash + Clone, Edge: Eq + Hash> {
7 vertices: HashMap<Vertex, usize>,
8 idxs: HashMap<usize, Vertex>,
9
10 connections: HashMap<usize, HashSet<(usize, Edge)>>
11}
12
13impl<Vertex: Eq + Hash + Clone, Edge: Eq + Hash> DirectedGraph<Vertex, Edge> {
14 pub fn new() -> Self {
15 DirectedGraph {
16 vertices: HashMap::new(),
17 idxs: HashMap::new(),
18 connections: HashMap::new()
19 }
20 }
21
22 pub fn contains(&self, v: &Vertex) -> bool {
23 self.vertices.contains_key(v)
24 }
25
26 fn vertex_by_idx(&self, idx: usize) -> &Vertex {
27 return self.idxs.get(&idx).unwrap();
28 }
29
30 fn vertex_idx_unchecked(&self, v: &Vertex) -> usize {
31 return *self.vertices.get(v).unwrap();
32 }
33
34 fn vertex_idx(&mut self, v: Vertex) -> usize {
35 if let Some(idx) = self.vertices.get(&v) {
36 *idx
37
38 } else {
39 let idx = self.vertices.len();
40 self.vertices.insert(v.clone(), idx);
41 self.idxs.insert(idx, v);
42
43 idx
44 }
45 }
46
47 pub fn connections(&self, v: &Vertex) -> Result<HashSet<(&Vertex, &Edge)>, String> {
48 if self.contains(v) {
49 let idx = self.vertex_idx_unchecked(v);
50
51 if let Some(conn) = self.connections.get(&idx) {
52 return Ok(conn.iter().map(|(k, v)| (self.vertex_by_idx(*k), v)).collect());
53 }
54 }
55
56 Err("Vertex is not in the graph".into())
57 }
58
59 pub fn connect(&mut self, from: Vertex, to: Vertex, edge: Edge) {
60 let idx_from = self.vertex_idx(from);
61 let idx_to = self.vertex_idx(to);
62
63 self.connections.entry(idx_from).or_default().insert((idx_to, edge));
64 }
65
66 pub fn dfs<F: FnMut(&Vertex)>(&self, start: &Vertex, mut op: F) {
67 if self.vertices.contains_key(start) {
68 let mut seen: HashSet<usize> = HashSet::new();
69 let mut stack = vec!(self.vertex_idx_unchecked(start));
70
71 while let Some(elem) = stack.pop() {
72
73 seen.insert(elem);
74
75 op(self.vertex_by_idx(elem));
76
77 if self.connections.contains_key(&elem) {
78 for (i, _) in self.connections.get(&elem).unwrap() {
79 if !seen.contains(i) {
80 stack.push(*i);
81 }
82 }
83 }
84 }
85 }
86 }
87
88 pub fn bfs<F: FnMut(&Vertex)>(&self, start: &Vertex, mut op: F) {
89 if self.vertices.contains_key(start) {
90 let mut seen: HashSet<usize> = HashSet::new();
91 let mut layer = vec!(self.vertex_idx_unchecked(start));
92
93 while !layer.is_empty() {
94 let mut new_layer = vec!();
95
96 for elem in &layer {
97 seen.insert(*elem);
98 op(self.vertex_by_idx(*elem));
99
100 if self.connections.contains_key(elem) {
101 for (i, _) in self.connections.get(elem).unwrap() {
102 if !seen.contains(i) {
103 new_layer.push(*i);
104 }
105 }
106 }
107 }
108
109 layer = new_layer;
110 }
111 }
112 }
113
114 pub fn to_dot<F: Fn(&Vertex) -> String>(&self, repr: F) -> String {
115 let mut lines = vec!();
116
117 for (from, dest) in &self.connections {
118 let rf = repr(self.vertex_by_idx(*from));
119 for (to, _) in dest {
120 let rt = repr(self.vertex_by_idx(*to));
121 lines.push(format!("\"{}\" -> \"{}\"", rf, rt));
122 }
123 }
124
125 format!("digraph G {{\n{}\n}}", lines.join("\n"))
126 }
127}
128
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::DirectedGraph;

    /// Fixture edges as `(from, to, label)`. Note the two distinct `B -> C`
    /// edges and the `D -> D` self-loop.
    const EDGES: [(&str, &str, i32); 7] = [
        ("A", "B", 1),
        ("A", "C", 2),
        ("A", "D", 3),
        ("B", "C", 2),
        ("B", "D", 2),
        ("B", "C", 3),
        ("D", "D", 4),
    ];

    #[test]
    fn graph_construction() {
        let mut g = DirectedGraph::new();
        for &(from, to, label) in EDGES.iter() {
            g.connect(from, to, label);
        }

        // Every endpoint becomes a vertex; unrelated values do not.
        for vertex in &["A", "B", "C", "D"] {
            assert!(g.contains(vertex));
        }
        assert!(!g.contains(&"E"));

        assert_eq!(
            g.connections(&"A"),
            Ok([(&"B", &1), (&"C", &2), (&"D", &3)].iter().copied().collect())
        );
        assert_eq!(
            g.connections(&"B"),
            Ok([(&"C", &2), (&"D", &2), (&"C", &3)].iter().copied().collect())
        );
        // "C" is only ever an edge target, so it has no outgoing set.
        assert!(g.connections(&"C").is_err());
        assert_eq!(
            g.connections(&"D"),
            Ok([(&"D", &4)].iter().copied().collect())
        );

        // Both traversals starting from "A" reach all four vertices.
        let mut reached_depth_first = HashSet::new();
        g.dfs(&"A", |vertex| {
            reached_depth_first.insert(*vertex);
        });
        assert_eq!(reached_depth_first.len(), 4);

        let mut reached_breadth_first = HashSet::new();
        g.bfs(&"A", |vertex| {
            reached_breadth_first.insert(*vertex);
        });
        assert_eq!(reached_breadth_first.len(), 4);

        // Seven `"x" -> "y"` lines wrapped in `digraph G { ... }`.
        let dot = g.to_dot(|vertex| vertex.to_string());
        assert_eq!(dot.len(), 90);
    }
}