wave_compiler/optimize/
loop_unroll.rs1use super::pass::Pass;
10use crate::analysis::cfg::Cfg;
11use crate::analysis::dominance::DomTree;
12use crate::analysis::loop_analysis::LoopInfo;
13use crate::mir::function::MirFunction;
14use crate::mir::instruction::MirInst;
15
/// Instruction budget for a loop body once it has been unrolled.
const MAX_UNROLLED_BODY_SIZE: usize = 128;
/// Largest unroll factor; the body-size budget above is divided by this to
/// decide which loops are small enough to consider.
const MAX_UNROLL_FACTOR: u32 = 4;
18
/// The loop-unrolling optimization pass; its `Pass::name` is `"loop_unroll"`.
///
/// Stateless unit struct, so it is trivially `Copy` and `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct LoopUnroll;
21
22impl Pass for LoopUnroll {
23 fn name(&self) -> &'static str {
24 "loop_unroll"
25 }
26
27 fn run(&self, func: &mut MirFunction) -> bool {
28 let cfg = Cfg::build(func);
29 let dom = DomTree::compute(&cfg);
30 let loop_info = LoopInfo::compute(&cfg, &dom);
31
32 let mut changed = false;
33
34 for natural_loop in &loop_info.loops {
35 let body_size: usize = natural_loop
36 .body
37 .iter()
38 .filter_map(|bid| func.block(*bid))
39 .map(|b| b.instructions.len())
40 .sum();
41
42 if body_size == 0 || body_size > MAX_UNROLLED_BODY_SIZE / MAX_UNROLL_FACTOR as usize {
43 continue;
44 }
45
46 let header_insts: Vec<MirInst> = func
47 .block(natural_loop.header)
48 .map(|b| b.instructions.clone())
49 .unwrap_or_default();
50
51 if header_insts.len() <= 2 {
52 if let Some(header_block) = func.block_mut(natural_loop.header) {
53 let original = header_block.instructions.clone();
54 for inst in &original {
55 if !inst.has_side_effects() && inst.dest().is_some() {
56 let cloned = inst.clone();
57 header_block.instructions.push(cloned);
58 changed = true;
59 }
60 }
61 }
62 }
63 }
64
65 changed
66 }
67}
68
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mir::basic_block::{BasicBlock, Terminator};
    use crate::mir::instruction::{ConstValue, MirInst};
    use crate::mir::value::{BlockId, ValueId};

    /// Creates the function every test starts from: named "test", entry at block 0.
    fn new_test_function() -> MirFunction {
        MirFunction::new("test".into(), BlockId(0))
    }

    /// A lone block has no back edge, so the pass must report no change.
    #[test]
    fn test_loop_unroll_no_loops() {
        let mut func = new_test_function();
        func.blocks.push(BasicBlock::new(BlockId(0)));

        let changed = LoopUnroll.run(&mut func);
        assert!(!changed);
    }

    /// CFG: bb0 -> bb1; bb1 -> bb2 | bb3; bb2 -> bb1 (back edge).
    /// Running the pass must not drop any of the four blocks.
    #[test]
    fn test_loop_unroll_simple_loop() {
        let mut func = new_test_function();

        let mut entry = BasicBlock::new(BlockId(0));
        entry.terminator = Terminator::Branch { target: BlockId(1) };

        let mut header = BasicBlock::new(BlockId(1));
        header.instructions.push(MirInst::Const {
            dest: ValueId(0),
            value: ConstValue::I32(1),
        });
        header.terminator = Terminator::CondBranch {
            cond: ValueId(0),
            true_target: BlockId(2),
            false_target: BlockId(3),
        };

        let mut latch = BasicBlock::new(BlockId(2));
        latch.terminator = Terminator::Branch { target: BlockId(1) };

        let exit = BasicBlock::new(BlockId(3));

        for block in [entry, header, latch, exit] {
            func.blocks.push(block);
        }

        LoopUnroll.run(&mut func);
        assert!(func.blocks.len() >= 4);
    }
}