Skip to main content

shape_vm/mir/
cfg.rs

1//! Control Flow Graph construction and traversal for MIR.
2
3use super::types::{BasicBlockId, MirFunction, TerminatorKind};
4use std::collections::{HashMap, HashSet, VecDeque};
5
6/// A control flow graph for a MIR function.
7/// Provides predecessor/successor queries and traversal orders.
8#[derive(Debug)]
9pub struct ControlFlowGraph {
10    /// Successors of each block.
11    successors: HashMap<BasicBlockId, Vec<BasicBlockId>>,
12    /// Predecessors of each block.
13    predecessors: HashMap<BasicBlockId, Vec<BasicBlockId>>,
14}
15
16impl ControlFlowGraph {
17    /// Build a CFG from a MIR function.
18    pub fn build(mir: &MirFunction) -> Self {
19        let mut successors: HashMap<BasicBlockId, Vec<BasicBlockId>> = HashMap::new();
20        let mut predecessors: HashMap<BasicBlockId, Vec<BasicBlockId>> = HashMap::new();
21
22        for block in &mir.blocks {
23            let succs = Self::terminator_successors(&block.terminator.kind);
24            for &succ in &succs {
25                predecessors.entry(succ).or_default().push(block.id);
26            }
27            successors.insert(block.id, succs);
28        }
29
30        ControlFlowGraph {
31            successors,
32            predecessors,
33        }
34    }
35
36    /// Get the successors of a block.
37    pub fn successors(&self, block: BasicBlockId) -> &[BasicBlockId] {
38        self.successors.get(&block).map_or(&[], |v| v.as_slice())
39    }
40
41    /// Get the predecessors of a block.
42    pub fn predecessors(&self, block: BasicBlockId) -> &[BasicBlockId] {
43        self.predecessors.get(&block).map_or(&[], |v| v.as_slice())
44    }
45
46    /// Reverse postorder traversal (useful for forward dataflow analysis).
47    pub fn reverse_postorder(&self) -> Vec<BasicBlockId> {
48        let mut visited = HashSet::new();
49        let mut postorder = Vec::new();
50        let entry = BasicBlockId(0);
51
52        self.dfs_postorder(entry, &mut visited, &mut postorder);
53        postorder.reverse();
54        postorder
55    }
56
57    fn dfs_postorder(
58        &self,
59        block: BasicBlockId,
60        visited: &mut HashSet<BasicBlockId>,
61        postorder: &mut Vec<BasicBlockId>,
62    ) {
63        if !visited.insert(block) {
64            return;
65        }
66        for &succ in self.successors(block) {
67            self.dfs_postorder(succ, visited, postorder);
68        }
69        postorder.push(block);
70    }
71
72    /// Compute dominators using the iterative dataflow algorithm.
73    pub fn dominators(&self) -> HashMap<BasicBlockId, BasicBlockId> {
74        let rpo = self.reverse_postorder();
75        let entry = BasicBlockId(0);
76        let mut doms: HashMap<BasicBlockId, BasicBlockId> = HashMap::new();
77        doms.insert(entry, entry);
78
79        let mut changed = true;
80        while changed {
81            changed = false;
82            for &b in &rpo {
83                if b == entry {
84                    continue;
85                }
86                let preds = self.predecessors(b);
87                let mut new_idom = None;
88                for &p in preds {
89                    if doms.contains_key(&p) {
90                        new_idom = Some(match new_idom {
91                            None => p,
92                            Some(current) => self.intersect(current, p, &doms, &rpo),
93                        });
94                    }
95                }
96                if let Some(new_idom) = new_idom {
97                    if doms.get(&b) != Some(&new_idom) {
98                        doms.insert(b, new_idom);
99                        changed = true;
100                    }
101                }
102            }
103        }
104
105        doms
106    }
107
108    fn intersect(
109        &self,
110        mut a: BasicBlockId,
111        mut b: BasicBlockId,
112        doms: &HashMap<BasicBlockId, BasicBlockId>,
113        rpo: &[BasicBlockId],
114    ) -> BasicBlockId {
115        let rpo_index: HashMap<BasicBlockId, usize> =
116            rpo.iter().enumerate().map(|(i, &bb)| (bb, i)).collect();
117        while a != b {
118            while rpo_index.get(&a).copied().unwrap_or(0) > rpo_index.get(&b).copied().unwrap_or(0)
119            {
120                a = *doms.get(&a).unwrap_or(&a);
121            }
122            while rpo_index.get(&b).copied().unwrap_or(0) > rpo_index.get(&a).copied().unwrap_or(0)
123            {
124                b = *doms.get(&b).unwrap_or(&b);
125            }
126        }
127        a
128    }
129
130    /// Check if a block is reachable from the entry.
131    pub fn is_reachable(&self, target: BasicBlockId) -> bool {
132        let mut visited = HashSet::new();
133        let mut queue = VecDeque::new();
134        queue.push_back(BasicBlockId(0));
135        visited.insert(BasicBlockId(0));
136
137        while let Some(block) = queue.pop_front() {
138            if block == target {
139                return true;
140            }
141            for &succ in self.successors(block) {
142                if visited.insert(succ) {
143                    queue.push_back(succ);
144                }
145            }
146        }
147        false
148    }
149
150    fn terminator_successors(kind: &TerminatorKind) -> Vec<BasicBlockId> {
151        match kind {
152            TerminatorKind::Goto(target) => vec![*target],
153            TerminatorKind::SwitchBool {
154                true_bb, false_bb, ..
155            } => vec![*true_bb, *false_bb],
156            TerminatorKind::Call { next, .. } => vec![*next],
157            TerminatorKind::Return | TerminatorKind::Unreachable => vec![],
158        }
159    }
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mir::types::*;

    /// Dummy source span; the CFG never inspects spans.
    fn span() -> shape_ast::ast::Span {
        shape_ast::ast::Span { start: 0, end: 1 }
    }

    /// Wrap a terminator kind with the dummy span.
    fn make_terminator(kind: TerminatorKind) -> super::super::types::Terminator {
        super::super::types::Terminator { kind, span: span() }
    }

    /// A statement-free block that ends with `kind`.
    fn make_block(id: BasicBlockId, kind: TerminatorKind) -> BasicBlock {
        BasicBlock {
            id,
            statements: vec![],
            terminator: make_terminator(kind),
        }
    }

    /// A function with no locals or parameters whose body is `blocks`.
    fn make_mir(blocks: Vec<BasicBlock>) -> MirFunction {
        MirFunction {
            name: "test".to_string(),
            blocks,
            num_locals: 0,
            param_slots: vec![],
            param_reference_kinds: vec![],
            local_types: vec![],
            span: span(),
        }
    }

    /// A two-way branch on a constant condition.
    fn branch(true_bb: BasicBlockId, false_bb: BasicBlockId) -> TerminatorKind {
        TerminatorKind::SwitchBool {
            operand: Operand::Constant(MirConstant::Bool(true)),
            true_bb,
            false_bb,
        }
    }

    /// 0 -> {1, 2} -> 3
    fn diamond() -> MirFunction {
        make_mir(vec![
            make_block(BasicBlockId(0), branch(BasicBlockId(1), BasicBlockId(2))),
            make_block(BasicBlockId(1), TerminatorKind::Goto(BasicBlockId(3))),
            make_block(BasicBlockId(2), TerminatorKind::Goto(BasicBlockId(3))),
            make_block(BasicBlockId(3), TerminatorKind::Return),
        ])
    }

    /// 0 -> 1, 1 -> {2, 3}, 2 -> 1 (back edge), 3 returns
    fn simple_loop() -> MirFunction {
        make_mir(vec![
            make_block(BasicBlockId(0), TerminatorKind::Goto(BasicBlockId(1))),
            make_block(BasicBlockId(1), branch(BasicBlockId(2), BasicBlockId(3))),
            make_block(BasicBlockId(2), TerminatorKind::Goto(BasicBlockId(1))),
            make_block(BasicBlockId(3), TerminatorKind::Return),
        ])
    }

    #[test]
    fn test_linear_cfg() {
        let mir = make_mir(vec![
            make_block(BasicBlockId(0), TerminatorKind::Goto(BasicBlockId(1))),
            make_block(BasicBlockId(1), TerminatorKind::Return),
        ]);
        let cfg = ControlFlowGraph::build(&mir);
        assert_eq!(cfg.successors(BasicBlockId(0)), &[BasicBlockId(1)]);
        assert_eq!(cfg.predecessors(BasicBlockId(1)), &[BasicBlockId(0)]);
        // The exit block has no outgoing edges and the entry no incoming ones.
        assert!(cfg.successors(BasicBlockId(1)).is_empty());
        assert!(cfg.predecessors(BasicBlockId(0)).is_empty());
    }

    #[test]
    fn test_branch_cfg() {
        let cfg = ControlFlowGraph::build(&diamond());
        let rpo = cfg.reverse_postorder();
        assert_eq!(rpo[0], BasicBlockId(0)); // entry first
        assert_eq!(rpo.len(), 4);
        assert_eq!(rpo[3], BasicBlockId(3)); // join block last
        assert!(cfg.is_reachable(BasicBlockId(3)));
    }

    #[test]
    fn test_loop_cfg() {
        let cfg = ControlFlowGraph::build(&simple_loop());
        // Block 1 should have two predecessors: 0 (entry) and 2 (back edge)
        let preds = cfg.predecessors(BasicBlockId(1));
        assert_eq!(preds.len(), 2);
        assert_eq!(preds, &[BasicBlockId(0), BasicBlockId(2)]);
    }

    #[test]
    fn test_branch_dominators() {
        let cfg = ControlFlowGraph::build(&diamond());
        let doms = cfg.dominators();
        assert_eq!(doms.get(&BasicBlockId(0)), Some(&BasicBlockId(0)));
        assert_eq!(doms.get(&BasicBlockId(1)), Some(&BasicBlockId(0)));
        assert_eq!(doms.get(&BasicBlockId(2)), Some(&BasicBlockId(0)));
        // Neither arm dominates the join; only the branch block does.
        assert_eq!(doms.get(&BasicBlockId(3)), Some(&BasicBlockId(0)));
    }

    #[test]
    fn test_loop_dominators() {
        let cfg = ControlFlowGraph::build(&simple_loop());
        let doms = cfg.dominators();
        assert_eq!(doms.get(&BasicBlockId(1)), Some(&BasicBlockId(0)));
        assert_eq!(doms.get(&BasicBlockId(2)), Some(&BasicBlockId(1)));
        assert_eq!(doms.get(&BasicBlockId(3)), Some(&BasicBlockId(1)));
    }

    #[test]
    fn test_unreachable_block() {
        // Block 2 jumps into the graph but nothing jumps to it.
        let mir = make_mir(vec![
            make_block(BasicBlockId(0), TerminatorKind::Goto(BasicBlockId(1))),
            make_block(BasicBlockId(1), TerminatorKind::Return),
            make_block(BasicBlockId(2), TerminatorKind::Goto(BasicBlockId(1))),
        ]);
        let cfg = ControlFlowGraph::build(&mir);
        assert!(!cfg.is_reachable(BasicBlockId(2)));
        assert_eq!(
            cfg.reverse_postorder(),
            vec![BasicBlockId(0), BasicBlockId(1)]
        );
        let doms = cfg.dominators();
        assert_eq!(doms.get(&BasicBlockId(1)), Some(&BasicBlockId(0)));
        assert_eq!(doms.get(&BasicBlockId(2)), None);
    }
}
285}