Skip to main content

wave_compiler/analysis/
loop_analysis.rs

1// Copyright 2026 Ojima Abraham
2// SPDX-License-Identifier: Apache-2.0
3
//! Loop detection and nesting-depth analysis.
//!
//! Detects natural loops by finding back edges in the CFG (edges whose
//! target dominates their source), then determines loop bodies and
//! nesting relationships. Induction variable analysis is not yet
//! implemented in this module.
9
10use std::collections::{HashMap, HashSet};
11
12use super::cfg::Cfg;
13use super::dominance::DomTree;
14use crate::mir::value::BlockId;
15
16/// A natural loop in the CFG.
17#[derive(Debug, Clone)]
18pub struct NaturalLoop {
19    /// Loop header block (dominates all blocks in the loop).
20    pub header: BlockId,
21    /// Back edge source block.
22    pub latch: BlockId,
23    /// All blocks in the loop body (including header).
24    pub body: HashSet<BlockId>,
25    /// Nesting depth (0 for outermost loops).
26    pub depth: u32,
27}
28
29/// Result of loop analysis.
30pub struct LoopInfo {
31    /// Detected natural loops.
32    pub loops: Vec<NaturalLoop>,
33    /// Back edges in the CFG.
34    pub back_edges: Vec<(BlockId, BlockId)>,
35    /// Loop depth for each block (0 if not in any loop).
36    pub block_depth: HashMap<BlockId, u32>,
37}
38
39impl LoopInfo {
40    /// Perform loop analysis on a CFG with dominator tree.
41    #[must_use]
42    pub fn compute(cfg: &Cfg, dom: &DomTree) -> Self {
43        let back_edges = detect_back_edges(cfg, dom);
44        let mut loops = Vec::new();
45
46        for &(latch, header) in &back_edges {
47            let body = compute_loop_body(cfg, header, latch);
48            loops.push(NaturalLoop {
49                header,
50                latch,
51                body,
52                depth: 0,
53            });
54        }
55
56        compute_nesting_depths(&mut loops);
57
58        let mut block_depth: HashMap<BlockId, u32> = HashMap::new();
59        for &bid in &cfg.blocks {
60            block_depth.insert(bid, 0);
61        }
62        for nl in &loops {
63            for &bid in &nl.body {
64                let current = block_depth.get(&bid).copied().unwrap_or(0);
65                if nl.depth + 1 > current {
66                    block_depth.insert(bid, nl.depth + 1);
67                }
68            }
69        }
70
71        Self {
72            loops,
73            back_edges,
74            block_depth,
75        }
76    }
77
78    /// Returns the loop depth of a block (0 if not in any loop).
79    #[must_use]
80    pub fn depth(&self, block: BlockId) -> u32 {
81        self.block_depth.get(&block).copied().unwrap_or(0)
82    }
83
84    /// Returns the innermost loop containing a block, if any.
85    #[must_use]
86    pub fn containing_loop(&self, block: BlockId) -> Option<&NaturalLoop> {
87        self.loops
88            .iter()
89            .filter(|l| l.body.contains(&block))
90            .max_by_key(|l| l.depth)
91    }
92}
93
94fn detect_back_edges(cfg: &Cfg, dom: &DomTree) -> Vec<(BlockId, BlockId)> {
95    let mut back_edges = Vec::new();
96    for &bid in &cfg.blocks {
97        for &succ in cfg.succs(bid) {
98            if dom.dominates(succ, bid) {
99                back_edges.push((bid, succ));
100            }
101        }
102    }
103    back_edges
104}
105
106fn compute_loop_body(cfg: &Cfg, header: BlockId, latch: BlockId) -> HashSet<BlockId> {
107    let mut body = HashSet::new();
108    body.insert(header);
109    if header == latch {
110        return body;
111    }
112
113    let mut stack = vec![latch];
114    body.insert(latch);
115
116    while let Some(block) = stack.pop() {
117        for &pred in cfg.preds(block) {
118            if body.insert(pred) {
119                stack.push(pred);
120            }
121        }
122    }
123
124    body
125}
126
127fn compute_nesting_depths(loops: &mut [NaturalLoop]) {
128    let n = loops.len();
129    for i in 0..n {
130        let mut depth = 0u32;
131        for j in 0..n {
132            if i != j && loops[j].body.is_superset(&loops[i].body) {
133                depth += 1;
134            }
135        }
136        loops[i].depth = depth;
137    }
138}
139
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mir::basic_block::{BasicBlock, Terminator};
    use crate::mir::function::MirFunction;
    use crate::mir::value::ValueId;

    /// Runs the whole pipeline (CFG, then dominators, then loops) on `func`.
    fn analyze(func: &MirFunction) -> LoopInfo {
        let cfg = Cfg::build(func);
        let dom = DomTree::compute(&cfg);
        LoopInfo::compute(&cfg, &dom)
    }

    /// Builds a four-block function whose only back edge is bb2 -> bb1:
    ///
    /// ```text
    /// bb0 -> bb1 -(true)-> bb2 -> bb1
    ///          \-(false)-> bb3
    /// ```
    fn make_simple_loop() -> MirFunction {
        let mut func = MirFunction::new("test".into(), BlockId(0));

        let mut entry = BasicBlock::new(BlockId(0));
        entry.terminator = Terminator::Branch { target: BlockId(1) };

        let mut header = BasicBlock::new(BlockId(1));
        header.terminator = Terminator::CondBranch {
            cond: ValueId(0),
            true_target: BlockId(2),
            false_target: BlockId(3),
        };

        let mut body = BasicBlock::new(BlockId(2));
        body.terminator = Terminator::Branch { target: BlockId(1) };

        let exit = BasicBlock::new(BlockId(3));

        for bb in [entry, header, body, exit] {
            func.blocks.push(bb);
        }
        func
    }

    #[test]
    fn test_detect_loop() {
        let info = analyze(&make_simple_loop());

        assert_eq!(info.back_edges, vec![(BlockId(2), BlockId(1))]);
        assert_eq!(info.loops.len(), 1);

        let only = &info.loops[0];
        assert_eq!(only.header, BlockId(1));
        assert!(only.body.contains(&BlockId(1)));
        assert!(only.body.contains(&BlockId(2)));
    }

    #[test]
    fn test_loop_depth() {
        let info = analyze(&make_simple_loop());

        for (block, expected) in [(0, 0), (1, 1), (2, 1), (3, 0)] {
            assert_eq!(info.depth(BlockId(block)), expected, "depth of bb{}", block);
        }
    }

    #[test]
    fn test_no_loops() {
        let mut func = MirFunction::new("test".into(), BlockId(0));
        let mut first = BasicBlock::new(BlockId(0));
        first.terminator = Terminator::Branch { target: BlockId(1) };
        func.blocks.push(first);
        func.blocks.push(BasicBlock::new(BlockId(1)));

        let info = analyze(&func);

        assert!(info.loops.is_empty());
        assert!(info.back_edges.is_empty());
    }
}