1use core::panic;
5use itertools::Itertools;
6use rustc_hash::{FxHashMap, FxHashSet, FxHasher};
7use slotmap::Key;
8use std::{
9 collections::hash_map,
10 fmt::Debug,
11 hash::{Hash, Hasher},
12};
13
14use crate::{
15 AnalysisResults, BinaryOpKind, Context, DebugWithContext, DomTree, Function, InstOp, IrError,
16 Pass, PassMutability, PostOrder, Predicate, ScopedPass, Type, UnaryOpKind, Value,
17 DOMINATORS_NAME, POSTORDER_NAME,
18};
19
20pub const CSE_NAME: &str = "cse";
21
22pub fn create_cse_pass() -> Pass {
23 Pass {
24 name: CSE_NAME,
25 descr: "Common subexpression elimination",
26 runner: ScopedPass::FunctionPass(PassMutability::Transform(cse)),
27 deps: vec![POSTORDER_NAME, DOMINATORS_NAME],
28 }
29}
30
31#[derive(Clone, Copy, Eq, PartialEq, Hash, DebugWithContext)]
32enum ValueNumber {
33 Top,
35 Number(Value),
37}
38
39impl Debug for ValueNumber {
40 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41 match self {
42 Self::Top => write!(f, "Top"),
43 Self::Number(arg0) => write!(f, "v{:?}", arg0.0.data()),
44 }
45 }
46}
47
/// An instruction or PHI rewritten in terms of the value numbers of its
/// operands. Structurally equal `Expr`s are given the same value number
/// (they are the keys of `VNTable::expr_map`).
#[derive(Clone, Debug, Eq, PartialEq, Hash, DebugWithContext)]
enum Expr {
    /// Value numbers flowing into a block argument, one per predecessor.
    Phi(Vec<ValueNumber>),
    UnaryOp {
        op: UnaryOpKind,
        arg: ValueNumber,
    },
    BinaryOp {
        op: BinaryOpKind,
        arg1: ValueNumber,
        arg2: ValueNumber,
    },
    // Types are part of the key so that casts to different types never merge.
    BitCast(ValueNumber, Type),
    CastPtr(ValueNumber, Type),
    Cmp(Predicate, ValueNumber, ValueNumber),
    GetElemPtr {
        base: ValueNumber,
        elem_ptr_ty: Type,
        indices: Vec<ValueNumber>,
    },
    IntToPtr(ValueNumber, Type),
    PtrToInt(ValueNumber, Type),
}
71
72fn instr_to_expr(context: &Context, vntable: &VNTable, instr: Value) -> Option<Expr> {
75 match &instr.get_instruction(context).unwrap().op {
76 InstOp::AsmBlock(_, _) => None,
77 InstOp::UnaryOp { op, arg } => Some(Expr::UnaryOp {
78 op: *op,
79 arg: vntable.value_map.get(arg).cloned().unwrap(),
80 }),
81 InstOp::BinaryOp { op, arg1, arg2 } => Some(Expr::BinaryOp {
82 op: *op,
83 arg1: vntable.value_map.get(arg1).cloned().unwrap(),
84 arg2: vntable.value_map.get(arg2).cloned().unwrap(),
85 }),
86 InstOp::BitCast(val, ty) => Some(Expr::BitCast(
87 vntable.value_map.get(val).cloned().unwrap(),
88 *ty,
89 )),
90 InstOp::Branch(_) => None,
91 InstOp::Call(_, _) => None,
92 InstOp::CastPtr(val, ty) => Some(Expr::CastPtr(
93 vntable.value_map.get(val).cloned().unwrap(),
94 *ty,
95 )),
96 InstOp::Cmp(pred, val1, val2) => Some(Expr::Cmp(
97 *pred,
98 vntable.value_map.get(val1).cloned().unwrap(),
99 vntable.value_map.get(val2).cloned().unwrap(),
100 )),
101 InstOp::ConditionalBranch { .. } => None,
102 InstOp::ContractCall { .. } => None,
103 InstOp::FuelVm(_) => None,
104 InstOp::GetLocal(_) => None,
105 InstOp::GetGlobal(_) => None,
106 InstOp::GetConfig(_, _) => None,
107 InstOp::GetStorageKey(_) => None,
108 InstOp::GetElemPtr {
109 base,
110 elem_ptr_ty,
111 indices,
112 } => Some(Expr::GetElemPtr {
113 base: vntable.value_map.get(base).cloned().unwrap(),
114 elem_ptr_ty: *elem_ptr_ty,
115 indices: indices
116 .iter()
117 .map(|idx| vntable.value_map.get(idx).cloned().unwrap())
118 .collect(),
119 }),
120 InstOp::IntToPtr(val, ty) => Some(Expr::IntToPtr(
121 vntable.value_map.get(val).cloned().unwrap(),
122 *ty,
123 )),
124 InstOp::Load(_) => None,
125 InstOp::Alloc { .. } => None,
126 InstOp::MemCopyBytes { .. } => None,
127 InstOp::MemCopyVal { .. } => None,
128 InstOp::MemClearVal { .. } => None,
129 InstOp::Nop => None,
130 InstOp::PtrToInt(val, ty) => Some(Expr::PtrToInt(
131 vntable.value_map.get(val).cloned().unwrap(),
132 *ty,
133 )),
134 InstOp::Ret(_, _) => None,
135 InstOp::Store { .. } => None,
136 }
137}
138
139fn phi_to_expr(context: &Context, vntable: &VNTable, phi_arg: Value) -> Expr {
141 let phi_arg = phi_arg.get_argument(context).unwrap();
142 let phi_args = phi_arg
143 .block
144 .pred_iter(context)
145 .map(|pred| {
146 let incoming_val = phi_arg
147 .get_val_coming_from(context, pred)
148 .expect("No parameter from predecessor");
149 vntable.value_map.get(&incoming_val).cloned().unwrap()
150 })
151 .collect();
152 Expr::Phi(phi_args)
153}
154
155#[derive(Default)]
156struct VNTable {
157 value_map: FxHashMap<Value, ValueNumber>,
158 expr_map: FxHashMap<Expr, ValueNumber>,
159}
160
161impl Debug for VNTable {
162 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
163 writeln!(f, "value_map:")?;
164 self.value_map.iter().for_each(|(key, value)| {
165 if format!("v{:?}", key.0.data()) == "v620v3" {
166 writeln!(f, "\tv{:?} -> {:?}", key.0.data(), value).expect("writeln! failed");
167 }
168 });
169 Ok(())
170 }
171}
172
173fn dominates(context: &Context, dom_tree: &DomTree, inst1: Value, inst2: Value) -> bool {
176 let block1 = match &context.values[inst1.0].value {
177 crate::ValueDatum::Argument(arg) => arg.block,
178 crate::ValueDatum::Constant(_) => {
179 panic!("Shouldn't be querying dominance info for constants")
180 }
181 crate::ValueDatum::Instruction(i) => i.parent,
182 };
183 let block2 = match &context.values[inst2.0].value {
184 crate::ValueDatum::Argument(arg) => arg.block,
185 crate::ValueDatum::Constant(_) => {
186 panic!("Shouldn't be querying dominance info for constants")
187 }
188 crate::ValueDatum::Instruction(i) => i.parent,
189 };
190
191 if block1 == block2 {
192 let inst1_idx = block1
193 .instruction_iter(context)
194 .position(|inst| inst == inst1)
195 .unwrap_or(0);
197 let inst2_idx = block1
198 .instruction_iter(context)
199 .position(|inst| inst == inst2)
200 .unwrap_or(0);
202 inst1_idx < inst2_idx
203 } else {
204 dom_tree.dominates(block1, block2)
205 }
206}
207
208pub fn cse(
209 context: &mut Context,
210 analyses: &AnalysisResults,
211 function: Function,
212) -> Result<bool, IrError> {
213 let mut vntable = VNTable::default();
214
215 for arg in function.args_iter(context) {
217 vntable.value_map.insert(arg.1, ValueNumber::Number(arg.1));
218 }
219
220 for block in function.block_iter(context).skip(1) {
222 for arg in block.arg_iter(context) {
223 vntable.value_map.insert(*arg, ValueNumber::Top);
224 }
225 }
226
227 let mut const_map = FxHashMap::<u64, Vec<Value>>::default();
231 for (_, inst) in function.instruction_iter(context) {
232 vntable.value_map.insert(inst, ValueNumber::Top);
233 for (const_opd_val, const_opd_const) in inst
234 .get_instruction(context)
235 .unwrap()
236 .op
237 .get_operands()
238 .iter()
239 .filter_map(|opd| opd.get_constant(context).map(|copd| (opd, copd)))
240 {
241 let mut state = FxHasher::default();
242 const_opd_const.hash(&mut state);
243 let hash = state.finish();
244 if let Some(existing_const) = const_map.get(&hash).and_then(|consts| {
245 consts.iter().find(|val| {
246 let c = val
247 .get_constant(context)
248 .expect("const_map can only contain consts");
249 const_opd_const == c
250 })
251 }) {
252 vntable
253 .value_map
254 .insert(*const_opd_val, ValueNumber::Number(*existing_const));
255 } else {
256 const_map
257 .entry(hash)
258 .and_modify(|consts| consts.push(*const_opd_val))
259 .or_insert_with(|| vec![*const_opd_val]);
260 vntable
261 .value_map
262 .insert(*const_opd_val, ValueNumber::Number(*const_opd_val));
263 }
264 }
265 }
266
267 let post_order: &PostOrder = analyses.get_analysis_result(function);
269
270 let mut changed = true;
272 while changed {
273 changed = false;
274 for (block_idx, block) in post_order.po_to_block.iter().rev().enumerate() {
276 if block_idx != 0 {
278 for (phi, expr_opt) in block
280 .arg_iter(context)
281 .map(|arg| (*arg, Some(phi_to_expr(context, &vntable, *arg))))
282 .collect_vec()
283 {
284 let expr = expr_opt.expect("PHIs must always translate to a valid Expr");
285 let vn = {
287 let Expr::Phi(ref phi_args) = expr else {
288 panic!("Expr must be a PHI")
289 };
290 phi_args
291 .iter()
292 .map(|vn| Some(*vn))
293 .reduce(|vn1, vn2| {
294 if let (Some(vn1), Some(vn2)) = (vn1, vn2) {
296 match (vn1, vn2) {
297 (ValueNumber::Top, ValueNumber::Top) => {
298 Some(ValueNumber::Top)
299 }
300 (ValueNumber::Top, ValueNumber::Number(vn))
301 | (ValueNumber::Number(vn), ValueNumber::Top) => {
302 Some(ValueNumber::Number(vn))
303 }
304 (ValueNumber::Number(vn1), ValueNumber::Number(vn2)) => {
305 (vn1 == vn2).then_some(ValueNumber::Number(vn1))
306 }
307 }
308 } else {
309 None
310 }
311 })
312 .flatten()
313 .unwrap_or(ValueNumber::Number(phi))
315 };
316
317 match vntable.value_map.entry(phi) {
318 hash_map::Entry::Occupied(occ) if *occ.get() == vn => {}
319 _ => {
320 changed = true;
321 vntable.value_map.insert(phi, vn);
322 }
323 }
324 }
325 }
326
327 for (inst, expr_opt) in block
328 .instruction_iter(context)
329 .map(|instr| (instr, instr_to_expr(context, &vntable, instr)))
330 .collect_vec()
331 {
332 let vn = if let Some(expr) = expr_opt {
334 match vntable.expr_map.entry(expr) {
335 hash_map::Entry::Occupied(occ) => *occ.get(),
336 hash_map::Entry::Vacant(vac) => *(vac.insert(ValueNumber::Number(inst))),
337 }
338 } else {
339 ValueNumber::Number(inst)
342 };
343 match vntable.value_map.entry(inst) {
344 hash_map::Entry::Occupied(occ) if *occ.get() == vn => {}
345 _ => {
346 changed = true;
347 vntable.value_map.insert(inst, vn);
348 }
349 }
350 }
351 }
352 vntable.expr_map.clear();
353 }
354
355 let mut partition = FxHashMap::<ValueNumber, FxHashSet<Value>>::default();
357 vntable.value_map.iter().for_each(|(v, vn)| {
358 if v.is_constant(context)
361 || matches!(vn, ValueNumber::Top)
362 || matches!(vn, ValueNumber::Number(v2) if (v == v2 || v2.is_constant(context)))
363 {
364 return;
365 }
366 partition
367 .entry(*vn)
368 .and_modify(|part| {
369 part.insert(*v);
370 })
371 .or_insert(vec![*v].into_iter().collect());
372 });
373
374 partition.iter_mut().for_each(|(vn, v_part)| {
376 let ValueNumber::Number(v) = vn else {
377 panic!("We cannot have Top at this point");
378 };
379 v_part.insert(*v);
380 assert!(
381 v_part.len() > 1,
382 "We've only created partitions with size greater than 1"
383 );
384 });
385
386 let dom_tree: &DomTree = analyses.get_analysis_result(function);
392 let mut replace_map = FxHashMap::<Value, Value>::default();
393 let mut modified = false;
394 partition.iter().for_each(|(_leader, vals)| {
396 for v_pair in vals.iter().combinations(2) {
398 let (v1, v2) = (*v_pair[0], *v_pair[1]);
399 if dominates(context, dom_tree, v1, v2) {
400 modified = true;
401 replace_map.insert(v2, v1);
402 } else if dominates(context, dom_tree, v2, v1) {
403 modified = true;
404 replace_map.insert(v1, v2);
405 }
406 }
407 });
408
409 function.replace_values(context, &replace_map, None);
410
411 Ok(modified)
412}