wave_compiler/optimize/
cse.rs1use std::collections::HashMap;
11
12use super::pass::Pass;
13use crate::mir::function::MirFunction;
14use crate::mir::instruction::MirInst;
15use crate::mir::value::ValueId;
16
/// Common subexpression elimination pass over MIR functions.
///
/// The pass carries no state of its own; all bookkeeping lives inside a
/// single invocation of its `Pass::run` implementation.
#[derive(Debug, Clone, Copy, Default)]
pub struct Cse;
19
20#[derive(Hash, PartialEq, Eq, Clone)]
21struct ExprKey {
22 op: u8,
23 lhs: ValueId,
24 rhs: ValueId,
25}
26
27impl Pass for Cse {
28 fn name(&self) -> &'static str {
29 "cse"
30 }
31
32 fn run(&self, func: &mut MirFunction) -> bool {
33 let mut available: HashMap<ExprKey, ValueId> = HashMap::new();
34 let mut replacements: HashMap<ValueId, ValueId> = HashMap::new();
35 let mut changed = false;
36
37 for block in &mut func.blocks {
38 for inst in &mut block.instructions {
39 if let MirInst::BinOp {
40 dest, op, lhs, rhs, ..
41 } = inst
42 {
43 let actual_lhs = *replacements.get(lhs).unwrap_or(lhs);
44 let actual_rhs = *replacements.get(rhs).unwrap_or(rhs);
45 *lhs = actual_lhs;
46 *rhs = actual_rhs;
47
48 let key = ExprKey {
49 op: *op as u8,
50 lhs: actual_lhs,
51 rhs: actual_rhs,
52 };
53
54 if let Some(&existing) = available.get(&key) {
55 replacements.insert(*dest, existing);
56 *inst = MirInst::Const {
57 dest: *dest,
58 value: crate::mir::instruction::ConstValue::I32(0),
59 };
60 changed = true;
61 } else {
62 available.insert(key, *dest);
63 }
64 }
65 }
66 }
67
68 changed
69 }
70}
71
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hir::expr::BinOp;
    use crate::mir::basic_block::{BasicBlock, Terminator};
    use crate::mir::types::MirType;
    use crate::mir::value::BlockId;

    /// Builds the `I32` binary operation `dest = lhs <op> rhs`.
    fn binop(dest: ValueId, op: BinOp, lhs: ValueId, rhs: ValueId) -> MirInst {
        MirInst::BinOp {
            dest,
            op,
            lhs,
            rhs,
            ty: MirType::I32,
        }
    }

    /// Wraps `instructions` in one returning block pushed onto a fresh function.
    fn single_block_function(instructions: Vec<MirInst>) -> MirFunction {
        let mut func = MirFunction::new("test".into(), BlockId(0));
        let mut bb = BasicBlock::new(BlockId(0));
        for inst in instructions {
            bb.instructions.push(inst);
        }
        bb.terminator = Terminator::Return;
        func.blocks.push(bb);
        func
    }

    /// Collects every instruction of `func` in block order.
    fn all_instructions(func: &MirFunction) -> Vec<&MirInst> {
        func.blocks
            .iter()
            .flat_map(|bb| bb.instructions.iter())
            .collect()
    }

    #[test]
    fn test_cse_eliminates_duplicate() {
        let mut func = single_block_function(vec![
            binop(ValueId(2), BinOp::Add, ValueId(0), ValueId(1)),
            binop(ValueId(3), BinOp::Add, ValueId(0), ValueId(1)),
        ]);

        let pass = Cse;
        assert!(pass.run(&mut func));

        // The first computation survives; the duplicate is overwritten in place.
        let insts = all_instructions(&func);
        assert_eq!(insts.len(), 2);
        assert!(matches!(
            insts[0],
            MirInst::BinOp {
                dest: ValueId(2),
                ..
            }
        ));
        assert!(matches!(
            insts[1],
            MirInst::Const {
                dest: ValueId(3),
                ..
            }
        ));
    }

    #[test]
    fn test_cse_forwards_replacement_to_later_operands() {
        let mut func = single_block_function(vec![
            binop(ValueId(2), BinOp::Add, ValueId(0), ValueId(1)),
            binop(ValueId(3), BinOp::Add, ValueId(0), ValueId(1)),
            binop(ValueId(4), BinOp::Sub, ValueId(3), ValueId(0)),
        ]);

        let pass = Cse;
        assert!(pass.run(&mut func));

        // The use of the eliminated value must now name the surviving one.
        let insts = all_instructions(&func);
        assert_eq!(insts.len(), 3);
        assert!(matches!(
            insts[2],
            MirInst::BinOp {
                dest: ValueId(4),
                lhs: ValueId(2),
                rhs: ValueId(0),
                ..
            }
        ));
    }

    #[test]
    fn test_cse_no_change_different_ops() {
        let mut func = single_block_function(vec![
            binop(ValueId(2), BinOp::Add, ValueId(0), ValueId(1)),
            binop(ValueId(3), BinOp::Sub, ValueId(0), ValueId(1)),
        ]);

        let pass = Cse;
        assert!(!pass.run(&mut func));

        // Neither instruction may be touched when the operators differ.
        let insts = all_instructions(&func);
        assert_eq!(insts.len(), 2);
        assert!(insts
            .iter()
            .all(|inst| matches!(inst, MirInst::BinOp { .. })));
    }
}