Skip to main content

wave_compiler/optimize/
loop_unroll.rs

1// Copyright 2026 Ojima Abraham
2// SPDX-License-Identifier: Apache-2.0
3
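//! Loop unrolling pass.
//!
//! Intended to unroll loops with known trip counts when the unrolled body
//! fits within the register budget, using a bounded unroll factor.
//!
//! The implementation in this file is a conservative approximation of that
//! goal: it only considers loops whose total body size is non-zero and at
//! most `MAX_UNROLLED_BODY_SIZE / MAX_UNROLL_FACTOR` instructions, and for
//! loop headers holding at most two instructions it appends one copy of every
//! side-effect-free, value-defining header instruction. Trip counts and
//! register pressure are not analysed here, and the unroll factor is a
//! compile-time constant rather than a runtime setting.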
8
9use super::pass::Pass;
10use crate::analysis::cfg::Cfg;
11use crate::analysis::dominance::DomTree;
12use crate::analysis::loop_analysis::LoopInfo;
13use crate::mir::function::MirFunction;
14use crate::mir::instruction::MirInst;
15
/// Upper bound on the unroll factor. In the code visible in this file it is
/// only used as the divisor of the loop-body size cap
/// (`MAX_UNROLLED_BODY_SIZE / MAX_UNROLL_FACTOR`).
const MAX_UNROLL_FACTOR: u32 = 4;
/// Instruction budget for a loop body after unrolling. Divided by
/// `MAX_UNROLL_FACTOR` it yields the largest original body size (in
/// instructions, summed over all blocks of the loop) the pass will touch.
const MAX_UNROLLED_BODY_SIZE: usize = 128;
18
/// Loop unrolling pass.
///
/// A stateless unit type: all tuning knobs are module-level constants, so
/// every value of this type behaves identically. The common traits are
/// derived so the pass can be logged, copied and default-constructed like
/// any other plain value.
#[derive(Debug, Clone, Copy, Default)]
pub struct LoopUnroll;
21
22impl Pass for LoopUnroll {
23    fn name(&self) -> &'static str {
24        "loop_unroll"
25    }
26
27    fn run(&self, func: &mut MirFunction) -> bool {
28        let cfg = Cfg::build(func);
29        let dom = DomTree::compute(&cfg);
30        let loop_info = LoopInfo::compute(&cfg, &dom);
31
32        let mut changed = false;
33
34        for natural_loop in &loop_info.loops {
35            let body_size: usize = natural_loop
36                .body
37                .iter()
38                .filter_map(|bid| func.block(*bid))
39                .map(|b| b.instructions.len())
40                .sum();
41
42            if body_size == 0 || body_size > MAX_UNROLLED_BODY_SIZE / MAX_UNROLL_FACTOR as usize {
43                continue;
44            }
45
46            let header_insts: Vec<MirInst> = func
47                .block(natural_loop.header)
48                .map(|b| b.instructions.clone())
49                .unwrap_or_default();
50
51            if header_insts.len() <= 2 {
52                if let Some(header_block) = func.block_mut(natural_loop.header) {
53                    let original = header_block.instructions.clone();
54                    for inst in &original {
55                        if !inst.has_side_effects() && inst.dest().is_some() {
56                            let cloned = inst.clone();
57                            header_block.instructions.push(cloned);
58                            changed = true;
59                        }
60                    }
61                }
62            }
63        }
64
65        changed
66    }
67}
68
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mir::basic_block::{BasicBlock, Terminator};
    use crate::mir::instruction::{ConstValue, MirInst};
    use crate::mir::value::{BlockId, ValueId};

    /// Total number of instructions across every block of `func`.
    fn instruction_count(func: &MirFunction) -> usize {
        func.blocks.iter().map(|b| b.instructions.len()).sum()
    }

    /// A function without any loop must be left untouched and reported as
    /// unchanged.
    #[test]
    fn test_loop_unroll_no_loops() {
        let mut func = MirFunction::new("test".into(), BlockId(0));
        let bb = BasicBlock::new(BlockId(0));
        func.blocks.push(bb);

        let insts_before = instruction_count(&func);

        let pass = LoopUnroll;
        assert!(!pass.run(&mut func));
        assert_eq!(instruction_count(&func), insts_before);
    }

    /// CFG under test: bb0 -> bb1, bb1 -> {bb2, bb3}, bb2 -> bb1.
    /// bb1 is the loop header and holds a single constant definition.
    #[test]
    fn test_loop_unroll_simple_loop() {
        let mut func = MirFunction::new("test".into(), BlockId(0));

        let mut bb0 = BasicBlock::new(BlockId(0));
        bb0.terminator = Terminator::Branch { target: BlockId(1) };

        let mut bb1 = BasicBlock::new(BlockId(1));
        bb1.instructions.push(MirInst::Const {
            dest: ValueId(0),
            value: ConstValue::I32(1),
        });
        bb1.terminator = Terminator::CondBranch {
            cond: ValueId(0),
            true_target: BlockId(2),
            false_target: BlockId(3),
        };

        let mut bb2 = BasicBlock::new(BlockId(2));
        bb2.terminator = Terminator::Branch { target: BlockId(1) };

        let bb3 = BasicBlock::new(BlockId(3));

        func.blocks.push(bb0);
        func.blocks.push(bb1);
        func.blocks.push(bb2);
        func.blocks.push(bb3);

        let insts_before = instruction_count(&func);

        let pass = LoopUnroll;
        let changed = pass.run(&mut func);
        let insts_after = instruction_count(&func);

        // The pass must never drop blocks or instructions.
        assert!(func.blocks.len() >= 4);
        assert!(insts_after >= insts_before);
        // The reported `changed` flag must agree with what actually happened
        // to the function.
        assert_eq!(changed, insts_after > insts_before);
    }
}