// wave_compiler/optimize/simplify_cfg.rs
use std::collections::HashSet;

use crate::mir::basic_block::Terminator;
use crate::mir::function::MirFunction;
use super::pass::Pass;

16pub struct SimplifyCfg;
18
19impl Pass for SimplifyCfg {
20 fn name(&self) -> &'static str {
21 "simplify_cfg"
22 }
23
24 fn run(&self, func: &mut MirFunction) -> bool {
25 let mut changed = false;
26 changed |= remove_unreachable_blocks(func);
27 changed |= merge_single_predecessor_blocks(func);
28 changed
29 }
30}
31
32fn remove_unreachable_blocks(func: &mut MirFunction) -> bool {
33 let mut reachable = HashSet::new();
34 let mut stack = vec![func.entry];
35 while let Some(bid) = stack.pop() {
36 if !reachable.insert(bid) {
37 continue;
38 }
39 if let Some(block) = func.block(bid) {
40 for succ in block.successors() {
41 stack.push(succ);
42 }
43 }
44 }
45
46 let original_count = func.blocks.len();
47 func.blocks.retain(|b| reachable.contains(&b.id));
48 func.blocks.len() != original_count
49}
50
51fn merge_single_predecessor_blocks(func: &mut MirFunction) -> bool {
52 let preds = func.predecessors();
53 let mut changed = false;
54
55 loop {
56 let mut merge_found = false;
57 for i in 0..func.blocks.len() {
58 let term = func.blocks[i].terminator.clone();
59 if let Terminator::Branch { target } = term {
60 if let Some(pred_list) = preds.get(&target) {
61 if pred_list.len() == 1 && pred_list[0] == func.blocks[i].id {
62 if let Some(target_idx) = func.blocks.iter().position(|b| b.id == target) {
63 if target_idx != i {
64 let target_block = func.blocks.remove(target_idx);
65 let adjusted_i = if target_idx < i { i - 1 } else { i };
66 func.blocks[adjusted_i]
67 .instructions
68 .extend(target_block.instructions);
69 func.blocks[adjusted_i].terminator = target_block.terminator;
70 merge_found = true;
71 changed = true;
72 break;
73 }
74 }
75 }
76 }
77 }
78 }
79 if !merge_found {
80 break;
81 }
82 }
83
84 changed
85}
86
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mir::basic_block::BasicBlock;
    use crate::mir::instruction::{ConstValue, MirInst};
    use crate::mir::value::{BlockId, ValueId};

    /// A fresh function named "test" whose entry is `BlockId(0)`.
    fn empty_function() -> MirFunction {
        MirFunction::new("test".into(), BlockId(0))
    }

    /// A `dest = value` constant instruction.
    fn constant(dest: ValueId, value: ConstValue) -> MirInst {
        MirInst::Const { dest, value }
    }

    #[test]
    fn test_remove_unreachable() {
        let mut func = empty_function();
        func.blocks.push(BasicBlock::new(BlockId(0)));
        func.blocks.push(BasicBlock::new(BlockId(1)));

        // Only the entry block is expected to survive.
        let changed = SimplifyCfg.run(&mut func);
        assert!(changed);
        assert_eq!(func.blocks.len(), 1);
        assert_eq!(func.blocks[0].id, BlockId(0));
    }

    #[test]
    fn test_merge_blocks() {
        let mut head = BasicBlock::new(BlockId(0));
        head.instructions
            .push(constant(ValueId(0), ConstValue::I32(1)));
        head.terminator = Terminator::Branch { target: BlockId(1) };

        let mut tail = BasicBlock::new(BlockId(1));
        tail.instructions
            .push(constant(ValueId(1), ConstValue::I32(2)));
        tail.terminator = Terminator::Return;

        let mut func = empty_function();
        func.blocks.push(head);
        func.blocks.push(tail);

        // `tail` has `head` as its only predecessor, so the two fuse.
        let changed = SimplifyCfg.run(&mut func);
        assert!(changed);
        assert_eq!(func.blocks.len(), 1);
        assert_eq!(func.blocks[0].instructions.len(), 2);
    }
}