wave_compiler/optimize/
strength_reduce.rs1use std::collections::HashMap;
11
12use super::pass::Pass;
13use crate::hir::expr::BinOp;
14use crate::mir::function::MirFunction;
15use crate::mir::instruction::{ConstValue, MirInst};
16use crate::mir::value::ValueId;
17
/// Strength-reduction pass over MIR.
///
/// Replaces multiplication, division and modulo by a constant power of two with
/// cheaper shift / mask instructions (the exact rules live on the [`Pass`] impl).
/// The pass is stateless, so it is trivially `Copy` and `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct StrengthReduce;
20
21impl Pass for StrengthReduce {
22 fn name(&self) -> &'static str {
23 "strength_reduce"
24 }
25
26 fn run(&self, func: &mut MirFunction) -> bool {
27 let mut constants: HashMap<ValueId, u32> = HashMap::new();
28 let mut changed = false;
29
30 for block in &mut func.blocks {
31 for inst in &block.instructions {
32 if let MirInst::Const { dest, value } = inst {
33 match value {
34 ConstValue::I32(v) => {
35 constants.insert(*dest, u32::from_ne_bytes(v.to_ne_bytes()));
36 }
37 ConstValue::U32(v) => {
38 constants.insert(*dest, *v);
39 }
40 _ => {}
41 }
42 }
43 }
44
45 let mut replacements: Vec<(usize, MirInst, MirInst)> = Vec::new();
46
47 for (idx, inst) in block.instructions.iter().enumerate() {
48 if let MirInst::BinOp {
49 dest,
50 op,
51 lhs,
52 rhs,
53 ty,
54 } = inst
55 {
56 if let Some(&rhs_val) = constants.get(rhs) {
57 if rhs_val.is_power_of_two() && rhs_val > 1 {
58 let shift = rhs_val.trailing_zeros();
59 let new_const_dest = ValueId(dest.0 + 10000);
60
61 match op {
62 BinOp::Mul => {
63 replacements.push((
64 idx,
65 MirInst::BinOp {
66 dest: *dest,
67 op: BinOp::Shl,
68 lhs: *lhs,
69 rhs: new_const_dest,
70 ty: *ty,
71 },
72 MirInst::Const {
73 dest: new_const_dest,
74 value: ConstValue::U32(shift),
75 },
76 ));
77 }
78 BinOp::Div | BinOp::FloorDiv => {
79 replacements.push((
80 idx,
81 MirInst::BinOp {
82 dest: *dest,
83 op: BinOp::Shr,
84 lhs: *lhs,
85 rhs: new_const_dest,
86 ty: *ty,
87 },
88 MirInst::Const {
89 dest: new_const_dest,
90 value: ConstValue::U32(shift),
91 },
92 ));
93 }
94 BinOp::Mod => {
95 replacements.push((
96 idx,
97 MirInst::BinOp {
98 dest: *dest,
99 op: BinOp::BitAnd,
100 lhs: *lhs,
101 rhs: new_const_dest,
102 ty: *ty,
103 },
104 MirInst::Const {
105 dest: new_const_dest,
106 value: ConstValue::U32(rhs_val - 1),
107 },
108 ));
109 }
110 _ => {}
111 }
112 }
113 }
114 }
115 }
116
117 for (idx, replacement, new_const) in replacements.into_iter().rev() {
118 block.instructions[idx] = replacement;
119 block.instructions.insert(idx, new_const);
120 changed = true;
121 }
122 }
123
124 changed
125 }
126}
127
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mir::basic_block::{BasicBlock, Terminator};
    use crate::mir::types::MirType;
    use crate::mir::value::BlockId;

    /// Builds a function named `test` with one returning block that computes
    /// `v2 = v0 * v1` (typed `I32`), where `v1` is the `U32` constant `factor`.
    fn mul_by_constant(factor: u32) -> MirFunction {
        let mut func = MirFunction::new("test".into(), BlockId(0));
        let mut entry = BasicBlock::new(BlockId(0));
        entry.instructions.push(MirInst::Const {
            dest: ValueId(1),
            value: ConstValue::U32(factor),
        });
        entry.instructions.push(MirInst::BinOp {
            dest: ValueId(2),
            op: BinOp::Mul,
            lhs: ValueId(0),
            rhs: ValueId(1),
            ty: MirType::I32,
        });
        entry.terminator = Terminator::Return;
        func.blocks.push(entry);
        func
    }

    #[test]
    fn test_strength_reduce_mul_by_power_of_two() {
        let mut func = mul_by_constant(8);
        assert!(StrengthReduce.run(&mut func));

        let reduced_to_shl = func.blocks[0]
            .instructions
            .iter()
            .any(|inst| matches!(inst, MirInst::BinOp { op: BinOp::Shl, .. }));
        assert!(reduced_to_shl);
    }

    #[test]
    fn test_no_reduce_non_power_of_two() {
        let mut func = mul_by_constant(7);
        assert!(!StrengthReduce.run(&mut func));
    }
}