Skip to main content

wave_compiler/optimize/
simplify_cfg.rs

1// Copyright 2026 Ojima Abraham
2// SPDX-License-Identifier: Apache-2.0
3
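//! CFG simplification pass.
//!
//! First removes every block that is unreachable from the entry block, then
//! merges basic blocks connected by unconditional branches when the target
//! has a single predecessor. There is no dedicated empty-block removal step.
//! An empty block is eliminated only if it is unreachable or is merged into
//! its sole predecessor.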
9
10use std::collections::HashSet;
11
12use super::pass::Pass;
13use crate::mir::basic_block::Terminator;
14use crate::mir::function::MirFunction;
15
/// CFG simplification pass.
///
/// Stateless unit struct: construct it directly (`SimplifyCfg`) and invoke it
/// through the [`Pass`] trait.
#[derive(Debug, Clone, Copy, Default)]
pub struct SimplifyCfg;
18
19impl Pass for SimplifyCfg {
20    fn name(&self) -> &'static str {
21        "simplify_cfg"
22    }
23
24    fn run(&self, func: &mut MirFunction) -> bool {
25        let mut changed = false;
26        changed |= remove_unreachable_blocks(func);
27        changed |= merge_single_predecessor_blocks(func);
28        changed
29    }
30}
31
32fn remove_unreachable_blocks(func: &mut MirFunction) -> bool {
33    let mut reachable = HashSet::new();
34    let mut stack = vec![func.entry];
35    while let Some(bid) = stack.pop() {
36        if !reachable.insert(bid) {
37            continue;
38        }
39        if let Some(block) = func.block(bid) {
40            for succ in block.successors() {
41                stack.push(succ);
42            }
43        }
44    }
45
46    let original_count = func.blocks.len();
47    func.blocks.retain(|b| reachable.contains(&b.id));
48    func.blocks.len() != original_count
49}
50
51fn merge_single_predecessor_blocks(func: &mut MirFunction) -> bool {
52    let preds = func.predecessors();
53    let mut changed = false;
54
55    loop {
56        let mut merge_found = false;
57        for i in 0..func.blocks.len() {
58            let term = func.blocks[i].terminator.clone();
59            if let Terminator::Branch { target } = term {
60                if let Some(pred_list) = preds.get(&target) {
61                    if pred_list.len() == 1 && pred_list[0] == func.blocks[i].id {
62                        if let Some(target_idx) = func.blocks.iter().position(|b| b.id == target) {
63                            if target_idx != i {
64                                let target_block = func.blocks.remove(target_idx);
65                                let adjusted_i = if target_idx < i { i - 1 } else { i };
66                                func.blocks[adjusted_i]
67                                    .instructions
68                                    .extend(target_block.instructions);
69                                func.blocks[adjusted_i].terminator = target_block.terminator;
70                                merge_found = true;
71                                changed = true;
72                                break;
73                            }
74                        }
75                    }
76                }
77            }
78        }
79        if !merge_found {
80            break;
81        }
82    }
83
84    changed
85}
86
87#[cfg(test)]
88mod tests {
89    use super::*;
90    use crate::mir::basic_block::BasicBlock;
91    use crate::mir::instruction::{ConstValue, MirInst};
92    use crate::mir::value::BlockId;
93    use crate::mir::value::ValueId;
94
95    #[test]
96    fn test_remove_unreachable() {
97        let mut func = MirFunction::new("test".into(), BlockId(0));
98        let bb0 = BasicBlock::new(BlockId(0));
99        let bb1 = BasicBlock::new(BlockId(1));
100        func.blocks.push(bb0);
101        func.blocks.push(bb1);
102
103        let pass = SimplifyCfg;
104        assert!(pass.run(&mut func));
105        assert_eq!(func.blocks.len(), 1);
106        assert_eq!(func.blocks[0].id, BlockId(0));
107    }
108
109    #[test]
110    fn test_merge_blocks() {
111        let mut func = MirFunction::new("test".into(), BlockId(0));
112        let mut bb0 = BasicBlock::new(BlockId(0));
113        bb0.instructions.push(MirInst::Const {
114            dest: ValueId(0),
115            value: ConstValue::I32(1),
116        });
117        bb0.terminator = Terminator::Branch { target: BlockId(1) };
118
119        let mut bb1 = BasicBlock::new(BlockId(1));
120        bb1.instructions.push(MirInst::Const {
121            dest: ValueId(1),
122            value: ConstValue::I32(2),
123        });
124        bb1.terminator = Terminator::Return;
125
126        func.blocks.push(bb0);
127        func.blocks.push(bb1);
128
129        let pass = SimplifyCfg;
130        assert!(pass.run(&mut func));
131        assert_eq!(func.blocks.len(), 1);
132        assert_eq!(func.blocks[0].instructions.len(), 2);
133    }
134}