Skip to main content

wave_compiler/optimize/
cse.rs

1// Copyright 2026 Ojima Abraham
2// SPDX-License-Identifier: Apache-2.0
3
4//! Common subexpression elimination pass.
5//!
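//! For each binary operation (`MirInst::BinOp`), checks whether an equivalent
//! computation (same opcode, same operands) is already available. If so, the
//! redundant computation is removed and the operands of later binary
//! operations that referred to it are redirected to the existing value.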
9
10use std::collections::HashMap;
11
12use super::pass::Pass;
13use crate::mir::function::MirFunction;
14use crate::mir::instruction::MirInst;
15use crate::mir::value::ValueId;
16
/// Common subexpression elimination (CSE) over `BinOp` instructions.
///
/// Stateless: construct it as `Cse` and drive it through the [`Pass`] trait.
pub struct Cse;
19
20#[derive(Hash, PartialEq, Eq, Clone)]
21struct ExprKey {
22    op: u8,
23    lhs: ValueId,
24    rhs: ValueId,
25}
26
27impl Pass for Cse {
28    fn name(&self) -> &'static str {
29        "cse"
30    }
31
32    fn run(&self, func: &mut MirFunction) -> bool {
33        let mut available: HashMap<ExprKey, ValueId> = HashMap::new();
34        let mut replacements: HashMap<ValueId, ValueId> = HashMap::new();
35        let mut changed = false;
36
37        for block in &mut func.blocks {
38            for inst in &mut block.instructions {
39                if let MirInst::BinOp {
40                    dest, op, lhs, rhs, ..
41                } = inst
42                {
43                    let actual_lhs = *replacements.get(lhs).unwrap_or(lhs);
44                    let actual_rhs = *replacements.get(rhs).unwrap_or(rhs);
45                    *lhs = actual_lhs;
46                    *rhs = actual_rhs;
47
48                    let key = ExprKey {
49                        op: *op as u8,
50                        lhs: actual_lhs,
51                        rhs: actual_rhs,
52                    };
53
54                    if let Some(&existing) = available.get(&key) {
55                        replacements.insert(*dest, existing);
56                        *inst = MirInst::Const {
57                            dest: *dest,
58                            value: crate::mir::instruction::ConstValue::I32(0),
59                        };
60                        changed = true;
61                    } else {
62                        available.insert(key, *dest);
63                    }
64                }
65            }
66        }
67
68        changed
69    }
70}
71
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hir::expr::BinOp;
    use crate::mir::basic_block::{BasicBlock, Terminator};
    use crate::mir::types::MirType;
    use crate::mir::value::BlockId;

    /// Builds the `i32` binary operation `dest = lhs <op> rhs`.
    fn binop(dest: ValueId, op: BinOp, lhs: ValueId, rhs: ValueId) -> MirInst {
        MirInst::BinOp {
            dest,
            op,
            lhs,
            rhs,
            ty: MirType::I32,
        }
    }

    /// Wraps `instructions` in the single, returning entry block of a fresh function.
    fn single_block_func(instructions: Vec<MirInst>) -> MirFunction {
        let mut func = MirFunction::new("test".into(), BlockId(0));
        let mut bb = BasicBlock::new(BlockId(0));
        bb.instructions.extend(instructions);
        bb.terminator = Terminator::Return;
        func.blocks.push(bb);
        func
    }

    #[test]
    fn test_cse_eliminates_duplicate() {
        let mut func = single_block_func(vec![
            binop(ValueId(2), BinOp::Add, ValueId(0), ValueId(1)),
            binop(ValueId(3), BinOp::Add, ValueId(0), ValueId(1)),
        ]);

        assert!(Cse.run(&mut func));
    }

    #[test]
    fn test_cse_no_change_different_ops() {
        let mut func = single_block_func(vec![
            binop(ValueId(2), BinOp::Add, ValueId(0), ValueId(1)),
            binop(ValueId(3), BinOp::Sub, ValueId(0), ValueId(1)),
        ]);

        assert!(!Cse.run(&mut func));
    }

    /// A later `BinOp` that reads the eliminated value must be redirected to
    /// the value that survives, not left pointing at the removed computation.
    #[test]
    fn test_cse_forwards_uses_of_eliminated_value() {
        let mut func = single_block_func(vec![
            binop(ValueId(2), BinOp::Add, ValueId(0), ValueId(1)),
            binop(ValueId(3), BinOp::Add, ValueId(0), ValueId(1)),
            binop(ValueId(4), BinOp::Sub, ValueId(3), ValueId(1)),
        ]);

        assert!(Cse.run(&mut func));

        match &func.blocks[0].instructions[2] {
            MirInst::BinOp { lhs, rhs, .. } => {
                assert!(*lhs == ValueId(2), "use of %3 should be forwarded to %2");
                assert!(*rhs == ValueId(1), "untouched operand must be preserved");
            }
            _ => panic!("the dependent BinOp must be kept"),
        }
    }
}
129}