wave_compiler/optimize/
licm.rs1use std::collections::HashSet;
10
11use super::pass::Pass;
12use crate::analysis::cfg::Cfg;
13use crate::analysis::dominance::DomTree;
14use crate::analysis::loop_analysis::LoopInfo;
15use crate::mir::function::MirFunction;
16use crate::mir::instruction::MirInst;
17use crate::mir::value::ValueId;
18
19pub struct Licm;
21
22impl Pass for Licm {
23 fn name(&self) -> &'static str {
24 "licm"
25 }
26
27 fn run(&self, func: &mut MirFunction) -> bool {
28 let cfg = Cfg::build(func);
29 let dom = DomTree::compute(&cfg);
30 let loop_info = LoopInfo::compute(&cfg, &dom);
31
32 let mut changed = false;
33
34 for natural_loop in &loop_info.loops {
35 let mut defs_in_loop: HashSet<ValueId> = HashSet::new();
36 for &bid in &natural_loop.body {
37 if let Some(block) = func.block(bid) {
38 for inst in &block.instructions {
39 if let Some(dest) = inst.dest() {
40 defs_in_loop.insert(dest);
41 }
42 }
43 }
44 }
45
46 let mut invariant_insts: Vec<(usize, MirInst)> = Vec::new();
47
48 for &bid in &natural_loop.body {
49 if let Some(block) = func.block(bid) {
50 for (idx, inst) in block.instructions.iter().enumerate() {
51 if inst.has_side_effects() {
52 continue;
53 }
54 let all_operands_invariant =
55 inst.operands().iter().all(|op| !defs_in_loop.contains(op));
56 if all_operands_invariant {
57 if let Some(dest) = inst.dest() {
58 invariant_insts.push((idx, inst.clone()));
59 defs_in_loop.remove(&dest);
60 }
61 }
62 }
63 }
64 }
65
66 if !invariant_insts.is_empty() {
67 let preds = cfg.preds(natural_loop.header);
68 let preheader = preds
69 .iter()
70 .find(|p| !natural_loop.body.contains(p))
71 .copied();
72
73 if let Some(pre_bid) = preheader {
74 if let Some(pre_block) = func.block_mut(pre_bid) {
75 for (_, inst) in &invariant_insts {
76 pre_block.instructions.push(inst.clone());
77 }
78 changed = true;
79 }
80 }
81 }
82 }
83
84 changed
85 }
86}
87
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hir::expr::BinOp;
    use crate::mir::basic_block::{BasicBlock, Terminator};
    use crate::mir::instruction::{ConstValue, MirInst};
    use crate::mir::types::MirType;
    use crate::mir::value::BlockId;

    /// Builds an instruction-free block `id` that ends in `terminator`.
    fn block_ending_in(id: BlockId, terminator: Terminator) -> BasicBlock {
        let mut block = BasicBlock::new(id);
        block.terminator = terminator;
        block
    }

    #[test]
    fn test_licm_no_loops_no_change() {
        // One straight-line block with no back edges: LICM has nothing to do.
        let mut entry = block_ending_in(BlockId(0), Terminator::Return);
        entry.instructions.push(MirInst::Const {
            dest: ValueId(0),
            value: ConstValue::I32(42),
        });

        let mut func = MirFunction::new("test".into(), BlockId(0));
        func.blocks.push(entry);

        assert!(!Licm.run(&mut func));
    }

    #[test]
    fn test_licm_hoists_invariant() {
        // CFG: bb0 -> bb1; bb1 -> bb2 | bb3; bb2 -> bb1 (back edge).
        // bb1 computes v2 = v0 + v1 from values not defined in the loop,
        // so the add is invariant and should end up in bb0.
        let entry = block_ending_in(BlockId(0), Terminator::Branch { target: BlockId(1) });

        let mut header = block_ending_in(
            BlockId(1),
            Terminator::CondBranch {
                cond: ValueId(2),
                true_target: BlockId(2),
                false_target: BlockId(3),
            },
        );
        header.instructions.push(MirInst::BinOp {
            dest: ValueId(2),
            op: BinOp::Add,
            lhs: ValueId(0),
            rhs: ValueId(1),
            ty: MirType::I32,
        });

        let latch = block_ending_in(BlockId(2), Terminator::Branch { target: BlockId(1) });
        let exit = BasicBlock::new(BlockId(3));

        let mut func = MirFunction::new("test".into(), BlockId(0));
        for bb in vec![entry, header, latch, exit] {
            func.blocks.push(bb);
        }

        assert!(Licm.run(&mut func));
        assert!(!func.blocks[0].instructions.is_empty());
    }
}