1use super::types::{BasicBlockId, MirFunction, TerminatorKind};
4use std::collections::{HashMap, HashSet, VecDeque};
5
6#[derive(Debug)]
9pub struct ControlFlowGraph {
10 successors: HashMap<BasicBlockId, Vec<BasicBlockId>>,
12 predecessors: HashMap<BasicBlockId, Vec<BasicBlockId>>,
14 num_blocks: u32,
16}
17
18impl ControlFlowGraph {
19 pub fn build(mir: &MirFunction) -> Self {
21 let mut successors: HashMap<BasicBlockId, Vec<BasicBlockId>> = HashMap::new();
22 let mut predecessors: HashMap<BasicBlockId, Vec<BasicBlockId>> = HashMap::new();
23
24 for block in &mir.blocks {
25 let succs = Self::terminator_successors(&block.terminator.kind);
26 for &succ in &succs {
27 predecessors.entry(succ).or_default().push(block.id);
28 }
29 successors.insert(block.id, succs);
30 }
31
32 ControlFlowGraph {
33 successors,
34 predecessors,
35 num_blocks: mir.blocks.len() as u32,
36 }
37 }
38
39 pub fn successors(&self, block: BasicBlockId) -> &[BasicBlockId] {
41 self.successors.get(&block).map_or(&[], |v| v.as_slice())
42 }
43
44 pub fn predecessors(&self, block: BasicBlockId) -> &[BasicBlockId] {
46 self.predecessors.get(&block).map_or(&[], |v| v.as_slice())
47 }
48
49 pub fn reverse_postorder(&self) -> Vec<BasicBlockId> {
51 let mut visited = HashSet::new();
52 let mut postorder = Vec::new();
53 let entry = BasicBlockId(0);
54
55 self.dfs_postorder(entry, &mut visited, &mut postorder);
56 postorder.reverse();
57 postorder
58 }
59
60 fn dfs_postorder(
61 &self,
62 block: BasicBlockId,
63 visited: &mut HashSet<BasicBlockId>,
64 postorder: &mut Vec<BasicBlockId>,
65 ) {
66 if !visited.insert(block) {
67 return;
68 }
69 for &succ in self.successors(block) {
70 self.dfs_postorder(succ, visited, postorder);
71 }
72 postorder.push(block);
73 }
74
75 pub fn dominators(&self) -> HashMap<BasicBlockId, BasicBlockId> {
77 let rpo = self.reverse_postorder();
78 let entry = BasicBlockId(0);
79 let mut doms: HashMap<BasicBlockId, BasicBlockId> = HashMap::new();
80 doms.insert(entry, entry);
81
82 let mut changed = true;
83 while changed {
84 changed = false;
85 for &b in &rpo {
86 if b == entry {
87 continue;
88 }
89 let preds = self.predecessors(b);
90 let mut new_idom = None;
91 for &p in preds {
92 if doms.contains_key(&p) {
93 new_idom = Some(match new_idom {
94 None => p,
95 Some(current) => self.intersect(current, p, &doms, &rpo),
96 });
97 }
98 }
99 if let Some(new_idom) = new_idom {
100 if doms.get(&b) != Some(&new_idom) {
101 doms.insert(b, new_idom);
102 changed = true;
103 }
104 }
105 }
106 }
107
108 doms
109 }
110
111 fn intersect(
112 &self,
113 mut a: BasicBlockId,
114 mut b: BasicBlockId,
115 doms: &HashMap<BasicBlockId, BasicBlockId>,
116 rpo: &[BasicBlockId],
117 ) -> BasicBlockId {
118 let rpo_index: HashMap<BasicBlockId, usize> =
119 rpo.iter().enumerate().map(|(i, &bb)| (bb, i)).collect();
120 while a != b {
121 while rpo_index.get(&a).copied().unwrap_or(0) > rpo_index.get(&b).copied().unwrap_or(0)
122 {
123 a = *doms.get(&a).unwrap_or(&a);
124 }
125 while rpo_index.get(&b).copied().unwrap_or(0) > rpo_index.get(&a).copied().unwrap_or(0)
126 {
127 b = *doms.get(&b).unwrap_or(&b);
128 }
129 }
130 a
131 }
132
133 pub fn is_reachable(&self, target: BasicBlockId) -> bool {
135 let mut visited = HashSet::new();
136 let mut queue = VecDeque::new();
137 queue.push_back(BasicBlockId(0));
138 visited.insert(BasicBlockId(0));
139
140 while let Some(block) = queue.pop_front() {
141 if block == target {
142 return true;
143 }
144 for &succ in self.successors(block) {
145 if visited.insert(succ) {
146 queue.push_back(succ);
147 }
148 }
149 }
150 false
151 }
152
153 fn terminator_successors(kind: &TerminatorKind) -> Vec<BasicBlockId> {
154 match kind {
155 TerminatorKind::Goto(target) => vec![*target],
156 TerminatorKind::SwitchBool {
157 true_bb, false_bb, ..
158 } => vec![*true_bb, *false_bb],
159 TerminatorKind::Call { next, .. } => vec![*next],
160 TerminatorKind::Return | TerminatorKind::Unreachable => vec![],
161 }
162 }
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mir::types::*;

    /// Dummy source span; none of the assertions below depend on its value.
    fn span() -> shape_ast::ast::Span {
        shape_ast::ast::Span { start: 0, end: 1 }
    }

    /// Wraps a terminator kind with the dummy span.
    fn make_terminator(kind: TerminatorKind) -> super::super::types::Terminator {
        super::super::types::Terminator { kind, span: span() }
    }

    /// Statement-free block `id` that ends in `kind`.
    fn block(id: BasicBlockId, kind: TerminatorKind) -> BasicBlock {
        BasicBlock {
            id,
            statements: vec![],
            terminator: make_terminator(kind),
        }
    }

    /// Two-way branch on a constant condition; only the targets matter here.
    fn branch(true_bb: BasicBlockId, false_bb: BasicBlockId) -> TerminatorKind {
        TerminatorKind::SwitchBool {
            operand: Operand::Constant(MirConstant::Bool(true)),
            true_bb,
            false_bb,
        }
    }

    /// Function named "test" with the given blocks and no locals or parameters.
    fn function(blocks: Vec<BasicBlock>) -> MirFunction {
        MirFunction {
            name: "test".to_string(),
            blocks,
            num_locals: 0,
            param_slots: vec![],
            local_types: vec![],
            span: span(),
        }
    }

    /// bb0 -> bb1 -> return
    #[test]
    fn test_linear_cfg() {
        let mir = function(vec![
            block(BasicBlockId(0), TerminatorKind::Goto(BasicBlockId(1))),
            block(BasicBlockId(1), TerminatorKind::Return),
        ]);
        let cfg = ControlFlowGraph::build(&mir);

        assert_eq!(cfg.successors(BasicBlockId(0)), &[BasicBlockId(1)]);
        assert_eq!(cfg.predecessors(BasicBlockId(1)), &[BasicBlockId(0)]);
        assert!(cfg.successors(BasicBlockId(1)).is_empty());
        assert!(cfg.predecessors(BasicBlockId(0)).is_empty());
    }

    /// Diamond: bb0 branches to bb1 / bb2, both of which join in bb3.
    #[test]
    fn test_branch_cfg() {
        let mir = function(vec![
            block(BasicBlockId(0), branch(BasicBlockId(1), BasicBlockId(2))),
            block(BasicBlockId(1), TerminatorKind::Goto(BasicBlockId(3))),
            block(BasicBlockId(2), TerminatorKind::Goto(BasicBlockId(3))),
            block(BasicBlockId(3), TerminatorKind::Return),
        ]);
        let cfg = ControlFlowGraph::build(&mir);

        let rpo = cfg.reverse_postorder();
        assert_eq!(rpo[0], BasicBlockId(0));
        assert_eq!(rpo.len(), 4);
        assert!(cfg.is_reachable(BasicBlockId(3)));

        // Neither arm dominates the join block; only the branch block does.
        let doms = cfg.dominators();
        assert_eq!(doms.get(&BasicBlockId(1)), Some(&BasicBlockId(0)));
        assert_eq!(doms.get(&BasicBlockId(2)), Some(&BasicBlockId(0)));
        assert_eq!(doms.get(&BasicBlockId(3)), Some(&BasicBlockId(0)));
    }

    /// Loop: bb0 -> bb1 (header); bb1 -> bb2 (body) or bb3 (exit); bb2 -> bb1.
    #[test]
    fn test_loop_cfg() {
        let mir = function(vec![
            block(BasicBlockId(0), TerminatorKind::Goto(BasicBlockId(1))),
            block(BasicBlockId(1), branch(BasicBlockId(2), BasicBlockId(3))),
            block(BasicBlockId(2), TerminatorKind::Goto(BasicBlockId(1))),
            block(BasicBlockId(3), TerminatorKind::Return),
        ]);
        let cfg = ControlFlowGraph::build(&mir);

        // The header is entered from bb0 and through the back edge from bb2.
        let preds = cfg.predecessors(BasicBlockId(1));
        assert_eq!(preds.len(), 2);

        // The back edge must not change the header's dominator.
        let doms = cfg.dominators();
        assert_eq!(doms.get(&BasicBlockId(1)), Some(&BasicBlockId(0)));
        assert_eq!(doms.get(&BasicBlockId(2)), Some(&BasicBlockId(1)));
        assert_eq!(doms.get(&BasicBlockId(3)), Some(&BasicBlockId(1)));
    }

    /// bb1 jumps to bb0, but nothing ever jumps to bb1.
    #[test]
    fn test_unreachable_block() {
        let mir = function(vec![
            block(BasicBlockId(0), TerminatorKind::Return),
            block(BasicBlockId(1), TerminatorKind::Goto(BasicBlockId(0))),
        ]);
        let cfg = ControlFlowGraph::build(&mir);

        assert!(cfg.is_reachable(BasicBlockId(0)));
        assert!(!cfg.is_reachable(BasicBlockId(1)));
        assert_eq!(cfg.predecessors(BasicBlockId(0)), &[BasicBlockId(1)]);
        assert_eq!(cfg.reverse_postorder(), vec![BasicBlockId(0)]);
        assert!(!cfg.dominators().contains_key(&BasicBlockId(1)));
    }
}