wave_compiler/optimize/
mod.rs1pub mod constant_fold;
12pub mod cse;
13pub mod dce;
14pub mod licm;
15pub mod loop_unroll;
16pub mod mem2reg;
17pub mod pass;
18pub mod sccp;
19pub mod simplify_cfg;
20pub mod strength_reduce;
21
22use crate::driver::config::OptLevel;
23use crate::mir::function::MirFunction;
24use pass::Pass;
25
26pub fn optimize(func: &mut MirFunction, opt_level: OptLevel) {
28 let passes: Vec<Box<dyn Pass>> = match opt_level {
29 OptLevel::O0 => vec![],
30 OptLevel::O1 => vec![
31 Box::new(mem2reg::Mem2Reg),
32 Box::new(constant_fold::ConstantFold),
33 Box::new(dce::Dce),
34 Box::new(simplify_cfg::SimplifyCfg),
35 ],
36 OptLevel::O2 => vec![
37 Box::new(mem2reg::Mem2Reg),
38 Box::new(sccp::Sccp),
39 Box::new(dce::Dce),
40 Box::new(cse::Cse),
41 Box::new(licm::Licm),
42 Box::new(strength_reduce::StrengthReduce),
43 Box::new(simplify_cfg::SimplifyCfg),
44 Box::new(dce::Dce),
45 ],
46 OptLevel::O3 => vec![
47 Box::new(mem2reg::Mem2Reg),
48 Box::new(sccp::Sccp),
49 Box::new(dce::Dce),
50 Box::new(cse::Cse),
51 Box::new(licm::Licm),
52 Box::new(loop_unroll::LoopUnroll),
53 Box::new(strength_reduce::StrengthReduce),
54 Box::new(simplify_cfg::SimplifyCfg),
55 Box::new(dce::Dce),
56 ],
57 };
58
59 let mut changed = true;
60 let mut iterations = 0;
61 while changed && iterations < 10 {
62 changed = false;
63 for pass in &passes {
64 changed |= pass.run(func);
65 }
66 iterations += 1;
67 }
68}
69
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mir::basic_block::{BasicBlock, Terminator};
    use crate::mir::instruction::{ConstValue, MirInst};
    use crate::mir::value::{BlockId, ValueId};

    /// Builds a function named "test" whose entry block `BlockId(0)` holds
    /// `insts` (in order) and ends with a plain `Return`.
    fn single_block_fn(insts: Vec<MirInst>) -> MirFunction {
        let mut func = MirFunction::new("test".into(), BlockId(0));
        let mut block = BasicBlock::new(BlockId(0));
        for inst in insts {
            block.instructions.push(inst);
        }
        block.terminator = Terminator::Return;
        func.blocks.push(block);
        func
    }

    #[test]
    fn test_optimize_o0_no_changes() {
        let mut func = single_block_fn(vec![MirInst::Const {
            dest: ValueId(0),
            value: ConstValue::I32(42),
        }]);

        optimize(&mut func, OptLevel::O0);

        // O0 selects no passes, so even this unused constant must survive.
        assert_eq!(func.blocks[0].instructions.len(), 1);
    }

    #[test]
    fn test_optimize_o1_removes_dead_code() {
        let mut func = single_block_fn(vec![
            MirInst::Const {
                dest: ValueId(0),
                value: ConstValue::I32(42),
            },
            MirInst::Const {
                dest: ValueId(1),
                value: ConstValue::I32(99),
            },
        ]);

        optimize(&mut func, OptLevel::O1);

        // Neither constant is used (the block just returns), so O1 should
        // strip the block down to nothing but its terminator.
        assert!(func.blocks[0].instructions.is_empty());
    }
}