Skip to main content

shape_vm/mir/
cfg.rs

1//! Control Flow Graph construction and traversal for MIR.
2
3use super::types::{BasicBlockId, MirFunction, TerminatorKind};
4use std::collections::{HashMap, HashSet, VecDeque};
5
6/// A control flow graph for a MIR function.
7/// Provides predecessor/successor queries and traversal orders.
8#[derive(Debug)]
9pub struct ControlFlowGraph {
10    /// Successors of each block.
11    successors: HashMap<BasicBlockId, Vec<BasicBlockId>>,
12    /// Predecessors of each block.
13    predecessors: HashMap<BasicBlockId, Vec<BasicBlockId>>,
14    /// Number of blocks.
15    num_blocks: u32,
16}
17
18impl ControlFlowGraph {
19    /// Build a CFG from a MIR function.
20    pub fn build(mir: &MirFunction) -> Self {
21        let mut successors: HashMap<BasicBlockId, Vec<BasicBlockId>> = HashMap::new();
22        let mut predecessors: HashMap<BasicBlockId, Vec<BasicBlockId>> = HashMap::new();
23
24        for block in &mir.blocks {
25            let succs = Self::terminator_successors(&block.terminator.kind);
26            for &succ in &succs {
27                predecessors.entry(succ).or_default().push(block.id);
28            }
29            successors.insert(block.id, succs);
30        }
31
32        ControlFlowGraph {
33            successors,
34            predecessors,
35            num_blocks: mir.blocks.len() as u32,
36        }
37    }
38
39    /// Get the successors of a block.
40    pub fn successors(&self, block: BasicBlockId) -> &[BasicBlockId] {
41        self.successors.get(&block).map_or(&[], |v| v.as_slice())
42    }
43
44    /// Get the predecessors of a block.
45    pub fn predecessors(&self, block: BasicBlockId) -> &[BasicBlockId] {
46        self.predecessors.get(&block).map_or(&[], |v| v.as_slice())
47    }
48
49    /// Reverse postorder traversal (useful for forward dataflow analysis).
50    pub fn reverse_postorder(&self) -> Vec<BasicBlockId> {
51        let mut visited = HashSet::new();
52        let mut postorder = Vec::new();
53        let entry = BasicBlockId(0);
54
55        self.dfs_postorder(entry, &mut visited, &mut postorder);
56        postorder.reverse();
57        postorder
58    }
59
60    fn dfs_postorder(
61        &self,
62        block: BasicBlockId,
63        visited: &mut HashSet<BasicBlockId>,
64        postorder: &mut Vec<BasicBlockId>,
65    ) {
66        if !visited.insert(block) {
67            return;
68        }
69        for &succ in self.successors(block) {
70            self.dfs_postorder(succ, visited, postorder);
71        }
72        postorder.push(block);
73    }
74
75    /// Compute dominators using the iterative dataflow algorithm.
76    pub fn dominators(&self) -> HashMap<BasicBlockId, BasicBlockId> {
77        let rpo = self.reverse_postorder();
78        let entry = BasicBlockId(0);
79        let mut doms: HashMap<BasicBlockId, BasicBlockId> = HashMap::new();
80        doms.insert(entry, entry);
81
82        let mut changed = true;
83        while changed {
84            changed = false;
85            for &b in &rpo {
86                if b == entry {
87                    continue;
88                }
89                let preds = self.predecessors(b);
90                let mut new_idom = None;
91                for &p in preds {
92                    if doms.contains_key(&p) {
93                        new_idom = Some(match new_idom {
94                            None => p,
95                            Some(current) => self.intersect(current, p, &doms, &rpo),
96                        });
97                    }
98                }
99                if let Some(new_idom) = new_idom {
100                    if doms.get(&b) != Some(&new_idom) {
101                        doms.insert(b, new_idom);
102                        changed = true;
103                    }
104                }
105            }
106        }
107
108        doms
109    }
110
111    fn intersect(
112        &self,
113        mut a: BasicBlockId,
114        mut b: BasicBlockId,
115        doms: &HashMap<BasicBlockId, BasicBlockId>,
116        rpo: &[BasicBlockId],
117    ) -> BasicBlockId {
118        let rpo_index: HashMap<BasicBlockId, usize> =
119            rpo.iter().enumerate().map(|(i, &bb)| (bb, i)).collect();
120        while a != b {
121            while rpo_index.get(&a).copied().unwrap_or(0) > rpo_index.get(&b).copied().unwrap_or(0)
122            {
123                a = *doms.get(&a).unwrap_or(&a);
124            }
125            while rpo_index.get(&b).copied().unwrap_or(0) > rpo_index.get(&a).copied().unwrap_or(0)
126            {
127                b = *doms.get(&b).unwrap_or(&b);
128            }
129        }
130        a
131    }
132
133    /// Check if a block is reachable from the entry.
134    pub fn is_reachable(&self, target: BasicBlockId) -> bool {
135        let mut visited = HashSet::new();
136        let mut queue = VecDeque::new();
137        queue.push_back(BasicBlockId(0));
138        visited.insert(BasicBlockId(0));
139
140        while let Some(block) = queue.pop_front() {
141            if block == target {
142                return true;
143            }
144            for &succ in self.successors(block) {
145                if visited.insert(succ) {
146                    queue.push_back(succ);
147                }
148            }
149        }
150        false
151    }
152
153    fn terminator_successors(kind: &TerminatorKind) -> Vec<BasicBlockId> {
154        match kind {
155            TerminatorKind::Goto(target) => vec![*target],
156            TerminatorKind::SwitchBool {
157                true_bb, false_bb, ..
158            } => vec![*true_bb, *false_bb],
159            TerminatorKind::Call { next, .. } => vec![*next],
160            TerminatorKind::Return | TerminatorKind::Unreachable => vec![],
161        }
162    }
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mir::types::*;

    fn span() -> shape_ast::ast::Span {
        shape_ast::ast::Span { start: 0, end: 1 }
    }

    fn make_terminator(kind: TerminatorKind) -> super::super::types::Terminator {
        super::super::types::Terminator { kind, span: span() }
    }

    /// A block with no statements that ends in the given terminator.
    fn block(id: BasicBlockId, kind: TerminatorKind) -> BasicBlock {
        BasicBlock {
            id,
            statements: vec![],
            terminator: make_terminator(kind),
        }
    }

    /// A two-way branch on the constant `true`.
    fn switch(true_bb: BasicBlockId, false_bb: BasicBlockId) -> TerminatorKind {
        TerminatorKind::SwitchBool {
            operand: Operand::Constant(MirConstant::Bool(true)),
            true_bb,
            false_bb,
        }
    }

    /// A function named "test" with no locals or parameters.
    fn function(blocks: Vec<BasicBlock>) -> MirFunction {
        MirFunction {
            name: "test".to_string(),
            blocks,
            num_locals: 0,
            param_slots: vec![],
            local_types: vec![],
            span: span(),
        }
    }

    #[test]
    fn test_linear_cfg() {
        let mir = function(vec![
            block(BasicBlockId(0), TerminatorKind::Goto(BasicBlockId(1))),
            block(BasicBlockId(1), TerminatorKind::Return),
        ]);
        let cfg = ControlFlowGraph::build(&mir);
        assert_eq!(cfg.successors(BasicBlockId(0)), &[BasicBlockId(1)]);
        assert_eq!(cfg.predecessors(BasicBlockId(1)), &[BasicBlockId(0)]);
        assert_eq!(cfg.reverse_postorder(), vec![BasicBlockId(0), BasicBlockId(1)]);
        assert_eq!(cfg.dominators().get(&BasicBlockId(1)), Some(&BasicBlockId(0)));
    }

    #[test]
    fn test_branch_cfg() {
        let mir = function(vec![
            block(BasicBlockId(0), switch(BasicBlockId(1), BasicBlockId(2))),
            block(BasicBlockId(1), TerminatorKind::Goto(BasicBlockId(3))),
            block(BasicBlockId(2), TerminatorKind::Goto(BasicBlockId(3))),
            block(BasicBlockId(3), TerminatorKind::Return),
        ]);
        let cfg = ControlFlowGraph::build(&mir);
        let rpo = cfg.reverse_postorder();
        assert_eq!(rpo.len(), 4);
        assert_eq!(rpo[0], BasicBlockId(0)); // entry first
        assert_eq!(rpo.last(), Some(&BasicBlockId(3))); // join point last
        assert!(cfg.is_reachable(BasicBlockId(3)));

        // Neither arm dominates the join, so every non-entry block's idom is the entry.
        let doms = cfg.dominators();
        assert_eq!(doms.get(&BasicBlockId(1)), Some(&BasicBlockId(0)));
        assert_eq!(doms.get(&BasicBlockId(2)), Some(&BasicBlockId(0)));
        assert_eq!(doms.get(&BasicBlockId(3)), Some(&BasicBlockId(0)));
    }

    #[test]
    fn test_loop_cfg() {
        let mir = function(vec![
            block(BasicBlockId(0), TerminatorKind::Goto(BasicBlockId(1))),
            block(BasicBlockId(1), switch(BasicBlockId(2), BasicBlockId(3))),
            block(BasicBlockId(2), TerminatorKind::Goto(BasicBlockId(1))),
            block(BasicBlockId(3), TerminatorKind::Return),
        ]);
        let cfg = ControlFlowGraph::build(&mir);
        // Block 1 should have two predecessors: 0 (entry) and 2 (back edge)
        let preds = cfg.predecessors(BasicBlockId(1));
        assert_eq!(preds.len(), 2);
        assert!(preds.contains(&BasicBlockId(0)));
        assert!(preds.contains(&BasicBlockId(2)));

        // The loop header dominates both the body and the exit.
        let doms = cfg.dominators();
        assert_eq!(doms.get(&BasicBlockId(1)), Some(&BasicBlockId(0)));
        assert_eq!(doms.get(&BasicBlockId(2)), Some(&BasicBlockId(1)));
        assert_eq!(doms.get(&BasicBlockId(3)), Some(&BasicBlockId(1)));
    }

    #[test]
    fn test_unreachable_block() {
        let mir = function(vec![
            block(BasicBlockId(0), TerminatorKind::Return),
            block(BasicBlockId(1), TerminatorKind::Goto(BasicBlockId(0))),
        ]);
        let cfg = ControlFlowGraph::build(&mir);
        // Edges out of an unreachable block are still recorded...
        assert_eq!(cfg.predecessors(BasicBlockId(0)), &[BasicBlockId(1)]);
        // ...but traversals starting from the entry never visit it.
        assert!(!cfg.is_reachable(BasicBlockId(1)));
        assert_eq!(cfg.reverse_postorder(), vec![BasicBlockId(0)]);
        assert!(!cfg.dominators().contains_key(&BasicBlockId(1)));
    }
}