wave_compiler/optimize/
sccp.rs1use std::collections::{HashMap, HashSet, VecDeque};
11
12use super::pass::Pass;
13use crate::hir::expr::BinOp;
14use crate::mir::basic_block::Terminator;
15use crate::mir::function::MirFunction;
16use crate::mir::instruction::{ConstValue, MirInst};
17use crate::mir::value::{BlockId, ValueId};
18
/// Abstract value recorded for each `ValueId` while the pass propagates.
///
/// `Constant` carries a 32-bit integer bit pattern (booleans are stored as
/// 0/1). `Bottom` means the value is not a single known constant. `Top` is
/// the state function parameters are seeded with. Every consumer in this pass
/// treats anything other than `Constant` as not foldable.
#[derive(PartialEq, Clone, Debug)]
enum Lattice {
    /// Seed state used for function parameters.
    Top,
    /// The value always equals this bit pattern.
    Constant(i32),
    /// The value is not a compile-time constant.
    Bottom,
}
26
/// Sparse conditional constant propagation pass.
///
/// The pass holds no state of its own; all analysis data is created inside
/// each `run` call, so a single instance can be reused or copied freely.
#[derive(Debug, Default, Clone, Copy)]
pub struct Sccp;
29
30impl Pass for Sccp {
31 fn name(&self) -> &'static str {
32 "sccp"
33 }
34
35 fn run(&self, func: &mut MirFunction) -> bool {
36 let mut lattice: HashMap<ValueId, Lattice> = HashMap::new();
37 let mut executable: HashSet<BlockId> = HashSet::new();
38 let mut worklist: VecDeque<BlockId> = VecDeque::new();
39
40 for param in &func.params {
41 lattice.insert(param.value, Lattice::Top);
42 }
43
44 executable.insert(func.entry);
45 worklist.push_back(func.entry);
46
47 while let Some(bid) = worklist.pop_front() {
48 let block = match func.block(bid) {
49 Some(b) => b.clone(),
50 None => continue,
51 };
52
53 for inst in &block.instructions {
54 evaluate_instruction(inst, &mut lattice);
55 }
56
57 match &block.terminator {
58 Terminator::Branch { target } => {
59 if executable.insert(*target) {
60 worklist.push_back(*target);
61 }
62 }
63 Terminator::CondBranch {
64 cond,
65 true_target,
66 false_target,
67 } => match lattice.get(cond) {
68 Some(Lattice::Constant(v)) if *v != 0 => {
69 if executable.insert(*true_target) {
70 worklist.push_back(*true_target);
71 }
72 }
73 Some(Lattice::Constant(0)) => {
74 if executable.insert(*false_target) {
75 worklist.push_back(*false_target);
76 }
77 }
78 _ => {
79 if executable.insert(*true_target) {
80 worklist.push_back(*true_target);
81 }
82 if executable.insert(*false_target) {
83 worklist.push_back(*false_target);
84 }
85 }
86 },
87 Terminator::Return => {}
88 }
89 }
90
91 let mut changed = false;
92
93 for block in &mut func.blocks {
94 for inst in &mut block.instructions {
95 if let Some(dest) = inst.dest() {
96 if let Some(Lattice::Constant(v)) = lattice.get(&dest) {
97 if !matches!(inst, MirInst::Const { .. }) {
98 *inst = MirInst::Const {
99 dest,
100 value: ConstValue::I32(*v),
101 };
102 changed = true;
103 }
104 }
105 }
106 }
107 }
108
109 let original_count = func.blocks.len();
110 func.blocks
111 .retain(|b| executable.contains(&b.id) || b.id == func.entry);
112 if func.blocks.len() != original_count {
113 changed = true;
114 }
115
116 changed
117 }
118}
119
120fn evaluate_instruction(inst: &MirInst, lattice: &mut HashMap<ValueId, Lattice>) {
121 match inst {
122 MirInst::Const { dest, value } => {
123 let v = match value {
124 ConstValue::I32(v) => *v,
125 ConstValue::U32(v) => i32::from_ne_bytes(v.to_ne_bytes()),
126 ConstValue::Bool(v) => i32::from(*v),
127 ConstValue::F32(_) => return,
128 };
129 lattice.insert(*dest, Lattice::Constant(v));
130 }
131 MirInst::BinOp {
132 dest, op, lhs, rhs, ..
133 } => {
134 let l = lattice.get(lhs).cloned().unwrap_or(Lattice::Bottom);
135 let r = lattice.get(rhs).cloned().unwrap_or(Lattice::Bottom);
136
137 let result = match (&l, &r) {
138 (Lattice::Constant(a), Lattice::Constant(b)) => match op {
139 BinOp::Add => Lattice::Constant(a.wrapping_add(*b)),
140 BinOp::Sub => Lattice::Constant(a.wrapping_sub(*b)),
141 BinOp::Mul => Lattice::Constant(a.wrapping_mul(*b)),
142 BinOp::Lt => Lattice::Constant(i32::from(*a < *b)),
143 BinOp::Eq => Lattice::Constant(i32::from(*a == *b)),
144 _ => Lattice::Bottom,
145 },
146 _ => Lattice::Bottom,
147 };
148 lattice.insert(*dest, result);
149 }
150 _ => {}
151 }
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mir::basic_block::{BasicBlock, Terminator};
    use crate::mir::types::MirType;

    /// Builds a function named `"test"` whose single block, `BlockId(0)`, is
    /// the entry, executes `instructions` in order, and then returns.
    fn single_block_function(instructions: Vec<MirInst>) -> MirFunction {
        let entry = BlockId(0);

        let mut block = BasicBlock::new(entry);
        block.instructions.extend(instructions);
        block.terminator = Terminator::Return;

        let mut func = MirFunction::new("test".into(), entry);
        func.blocks.push(block);
        func
    }

    /// `v2 = v0 + v1` with `I32` type.
    fn add_v0_v1() -> MirInst {
        MirInst::BinOp {
            dest: ValueId(2),
            op: BinOp::Add,
            lhs: ValueId(0),
            rhs: ValueId(1),
            ty: MirType::I32,
        }
    }

    #[test]
    fn test_sccp_folds_constants() {
        let mut func = single_block_function(vec![
            MirInst::Const {
                dest: ValueId(0),
                value: ConstValue::I32(10),
            },
            MirInst::Const {
                dest: ValueId(1),
                value: ConstValue::I32(20),
            },
            add_v0_v1(),
        ]);

        assert!(Sccp.run(&mut func));

        // The addition (third instruction) should now be the constant 10 + 20.
        let folded = &func.blocks[0].instructions[2];
        match folded {
            MirInst::Const { value, .. } => assert_eq!(*value, ConstValue::I32(30)),
            other => panic!("expected Const, got {other:?}"),
        }
    }

    #[test]
    fn test_sccp_no_change_without_constants() {
        // Nothing defines v0 or v1, so the addition cannot be folded.
        let mut func = single_block_function(vec![add_v0_v1()]);
        assert!(!Sccp.run(&mut func));
    }
}