// Source path: wave_compiler/optimize/strength_reduce.rs

// Copyright 2026 Ojima Abraham
// SPDX-License-Identifier: Apache-2.0

//! Strength reduction pass.
//!
//! Replaces expensive operations with cheaper equivalents:
//! multiply by a power of two → shift left, unsigned divide by a power of two
//! → shift right, unsigned modulo by a power of two → bitwise AND.

use std::collections::HashMap;

use super::pass::Pass;
use crate::hir::expr::BinOp;
use crate::mir::function::MirFunction;
use crate::mir::instruction::{ConstValue, MirInst};
use crate::mir::types::MirType;
use crate::mir::value::ValueId;

/// Strength reduction pass.
///
/// The pass carries no state: all of its work happens in `Pass::run`, so it
/// can be constructed either as `StrengthReduce` or `StrengthReduce::default()`
/// and copied freely.
#[derive(Debug, Clone, Copy, Default)]
pub struct StrengthReduce;

21impl Pass for StrengthReduce {
22    fn name(&self) -> &'static str {
23        "strength_reduce"
24    }
25
26    fn run(&self, func: &mut MirFunction) -> bool {
27        let mut constants: HashMap<ValueId, u32> = HashMap::new();
28        let mut changed = false;
29
30        for block in &mut func.blocks {
31            for inst in &block.instructions {
32                if let MirInst::Const { dest, value } = inst {
33                    match value {
34                        ConstValue::I32(v) => {
35                            constants.insert(*dest, u32::from_ne_bytes(v.to_ne_bytes()));
36                        }
37                        ConstValue::U32(v) => {
38                            constants.insert(*dest, *v);
39                        }
40                        _ => {}
41                    }
42                }
43            }
44
45            let mut replacements: Vec<(usize, MirInst, MirInst)> = Vec::new();
46
47            for (idx, inst) in block.instructions.iter().enumerate() {
48                if let MirInst::BinOp {
49                    dest,
50                    op,
51                    lhs,
52                    rhs,
53                    ty,
54                } = inst
55                {
56                    if let Some(&rhs_val) = constants.get(rhs) {
57                        if rhs_val.is_power_of_two() && rhs_val > 1 {
58                            let shift = rhs_val.trailing_zeros();
59                            let new_const_dest = ValueId(dest.0 + 10000);
60
61                            match op {
62                                BinOp::Mul => {
63                                    replacements.push((
64                                        idx,
65                                        MirInst::BinOp {
66                                            dest: *dest,
67                                            op: BinOp::Shl,
68                                            lhs: *lhs,
69                                            rhs: new_const_dest,
70                                            ty: *ty,
71                                        },
72                                        MirInst::Const {
73                                            dest: new_const_dest,
74                                            value: ConstValue::U32(shift),
75                                        },
76                                    ));
77                                }
78                                BinOp::Div | BinOp::FloorDiv => {
79                                    replacements.push((
80                                        idx,
81                                        MirInst::BinOp {
82                                            dest: *dest,
83                                            op: BinOp::Shr,
84                                            lhs: *lhs,
85                                            rhs: new_const_dest,
86                                            ty: *ty,
87                                        },
88                                        MirInst::Const {
89                                            dest: new_const_dest,
90                                            value: ConstValue::U32(shift),
91                                        },
92                                    ));
93                                }
94                                BinOp::Mod => {
95                                    replacements.push((
96                                        idx,
97                                        MirInst::BinOp {
98                                            dest: *dest,
99                                            op: BinOp::BitAnd,
100                                            lhs: *lhs,
101                                            rhs: new_const_dest,
102                                            ty: *ty,
103                                        },
104                                        MirInst::Const {
105                                            dest: new_const_dest,
106                                            value: ConstValue::U32(rhs_val - 1),
107                                        },
108                                    ));
109                                }
110                                _ => {}
111                            }
112                        }
113                    }
114                }
115            }
116
117            for (idx, replacement, new_const) in replacements.into_iter().rev() {
118                block.instructions[idx] = replacement;
119                block.instructions.insert(idx, new_const);
120                changed = true;
121            }
122        }
123
124        changed
125    }
126}
127
128#[cfg(test)]
129mod tests {
130    use super::*;
131    use crate::mir::basic_block::{BasicBlock, Terminator};
132    use crate::mir::types::MirType;
133    use crate::mir::value::BlockId;
134
135    #[test]
136    fn test_strength_reduce_mul_by_power_of_two() {
137        let mut func = MirFunction::new("test".into(), BlockId(0));
138        let mut bb = BasicBlock::new(BlockId(0));
139        bb.instructions.push(MirInst::Const {
140            dest: ValueId(1),
141            value: ConstValue::U32(8),
142        });
143        bb.instructions.push(MirInst::BinOp {
144            dest: ValueId(2),
145            op: BinOp::Mul,
146            lhs: ValueId(0),
147            rhs: ValueId(1),
148            ty: MirType::I32,
149        });
150        bb.terminator = Terminator::Return;
151        func.blocks.push(bb);
152
153        let pass = StrengthReduce;
154        assert!(pass.run(&mut func));
155        let has_shl = func.blocks[0]
156            .instructions
157            .iter()
158            .any(|i| matches!(i, MirInst::BinOp { op: BinOp::Shl, .. }));
159        assert!(has_shl);
160    }
161
162    #[test]
163    fn test_no_reduce_non_power_of_two() {
164        let mut func = MirFunction::new("test".into(), BlockId(0));
165        let mut bb = BasicBlock::new(BlockId(0));
166        bb.instructions.push(MirInst::Const {
167            dest: ValueId(1),
168            value: ConstValue::U32(7),
169        });
170        bb.instructions.push(MirInst::BinOp {
171            dest: ValueId(2),
172            op: BinOp::Mul,
173            lhs: ValueId(0),
174            rhs: ValueId(1),
175            ty: MirType::I32,
176        });
177        bb.terminator = Terminator::Return;
178        func.blocks.push(bb);
179
180        let pass = StrengthReduce;
181        assert!(!pass.run(&mut func));
182    }
183}