wave_compiler/analysis/
loop_analysis.rs1use std::collections::{HashMap, HashSet};
11
12use super::cfg::Cfg;
13use super::dominance::DomTree;
14use crate::mir::value::BlockId;
15
16#[derive(Debug, Clone)]
18pub struct NaturalLoop {
19 pub header: BlockId,
21 pub latch: BlockId,
23 pub body: HashSet<BlockId>,
25 pub depth: u32,
27}
28
29pub struct LoopInfo {
31 pub loops: Vec<NaturalLoop>,
33 pub back_edges: Vec<(BlockId, BlockId)>,
35 pub block_depth: HashMap<BlockId, u32>,
37}
38
39impl LoopInfo {
40 #[must_use]
42 pub fn compute(cfg: &Cfg, dom: &DomTree) -> Self {
43 let back_edges = detect_back_edges(cfg, dom);
44 let mut loops = Vec::new();
45
46 for &(latch, header) in &back_edges {
47 let body = compute_loop_body(cfg, header, latch);
48 loops.push(NaturalLoop {
49 header,
50 latch,
51 body,
52 depth: 0,
53 });
54 }
55
56 compute_nesting_depths(&mut loops);
57
58 let mut block_depth: HashMap<BlockId, u32> = HashMap::new();
59 for &bid in &cfg.blocks {
60 block_depth.insert(bid, 0);
61 }
62 for nl in &loops {
63 for &bid in &nl.body {
64 let current = block_depth.get(&bid).copied().unwrap_or(0);
65 if nl.depth + 1 > current {
66 block_depth.insert(bid, nl.depth + 1);
67 }
68 }
69 }
70
71 Self {
72 loops,
73 back_edges,
74 block_depth,
75 }
76 }
77
78 #[must_use]
80 pub fn depth(&self, block: BlockId) -> u32 {
81 self.block_depth.get(&block).copied().unwrap_or(0)
82 }
83
84 #[must_use]
86 pub fn containing_loop(&self, block: BlockId) -> Option<&NaturalLoop> {
87 self.loops
88 .iter()
89 .filter(|l| l.body.contains(&block))
90 .max_by_key(|l| l.depth)
91 }
92}
93
94fn detect_back_edges(cfg: &Cfg, dom: &DomTree) -> Vec<(BlockId, BlockId)> {
95 let mut back_edges = Vec::new();
96 for &bid in &cfg.blocks {
97 for &succ in cfg.succs(bid) {
98 if dom.dominates(succ, bid) {
99 back_edges.push((bid, succ));
100 }
101 }
102 }
103 back_edges
104}
105
106fn compute_loop_body(cfg: &Cfg, header: BlockId, latch: BlockId) -> HashSet<BlockId> {
107 let mut body = HashSet::new();
108 body.insert(header);
109 if header == latch {
110 return body;
111 }
112
113 let mut stack = vec![latch];
114 body.insert(latch);
115
116 while let Some(block) = stack.pop() {
117 for &pred in cfg.preds(block) {
118 if body.insert(pred) {
119 stack.push(pred);
120 }
121 }
122 }
123
124 body
125}
126
127fn compute_nesting_depths(loops: &mut [NaturalLoop]) {
128 let n = loops.len();
129 for i in 0..n {
130 let mut depth = 0u32;
131 for j in 0..n {
132 if i != j && loops[j].body.is_superset(&loops[i].body) {
133 depth += 1;
134 }
135 }
136 loops[i].depth = depth;
137 }
138}
139
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mir::basic_block::{BasicBlock, Terminator};
    use crate::mir::function::MirFunction;
    use crate::mir::value::ValueId;

    /// Runs the full pipeline (CFG build, dominator tree, loop analysis) on `func`.
    fn analyze(func: &MirFunction) -> LoopInfo {
        let cfg = Cfg::build(func);
        let dom = DomTree::compute(&cfg);
        LoopInfo::compute(&cfg, &dom)
    }

    /// Builds a function with a single while-style loop:
    /// bb0 -> bb1; bb1 -> {bb2, bb3} on `ValueId(0)`; bb2 -> bb1.
    /// bb3 keeps whatever terminator `BasicBlock::new` assigns.
    fn make_simple_loop() -> MirFunction {
        let mut func = MirFunction::new("test".into(), BlockId(0));

        let mut bb0 = BasicBlock::new(BlockId(0));
        bb0.terminator = Terminator::Branch { target: BlockId(1) };

        let mut bb1 = BasicBlock::new(BlockId(1));
        bb1.terminator = Terminator::CondBranch {
            cond: ValueId(0),
            true_target: BlockId(2),
            false_target: BlockId(3),
        };

        let mut bb2 = BasicBlock::new(BlockId(2));
        bb2.terminator = Terminator::Branch { target: BlockId(1) };

        let bb3 = BasicBlock::new(BlockId(3));

        func.blocks.push(bb0);
        func.blocks.push(bb1);
        func.blocks.push(bb2);
        func.blocks.push(bb3);
        func
    }

    #[test]
    fn test_detect_loop() {
        let loop_info = analyze(&make_simple_loop());

        assert_eq!(loop_info.back_edges.len(), 1);
        assert_eq!(loop_info.back_edges[0], (BlockId(2), BlockId(1)));
        assert_eq!(loop_info.loops.len(), 1);
        assert_eq!(loop_info.loops[0].header, BlockId(1));
        assert!(loop_info.loops[0].body.contains(&BlockId(1)));
        assert!(loop_info.loops[0].body.contains(&BlockId(2)));
    }

    #[test]
    fn test_loop_depth() {
        let loop_info = analyze(&make_simple_loop());

        assert_eq!(loop_info.depth(BlockId(0)), 0);
        assert_eq!(loop_info.depth(BlockId(1)), 1);
        assert_eq!(loop_info.depth(BlockId(2)), 1);
        assert_eq!(loop_info.depth(BlockId(3)), 0);
    }

    #[test]
    fn test_containing_loop() {
        let loop_info = analyze(&make_simple_loop());

        // Both the header and the latch resolve to the single loop.
        assert_eq!(
            loop_info.containing_loop(BlockId(1)).map(|l| l.header),
            Some(BlockId(1))
        );
        assert_eq!(
            loop_info.containing_loop(BlockId(2)).map(|l| l.latch),
            Some(BlockId(2))
        );
        // The preheader and the exit block are outside every loop.
        assert!(loop_info.containing_loop(BlockId(0)).is_none());
        assert!(loop_info.containing_loop(BlockId(3)).is_none());
    }

    #[test]
    fn test_no_loops() {
        let mut func = MirFunction::new("test".into(), BlockId(0));
        let mut bb0 = BasicBlock::new(BlockId(0));
        bb0.terminator = Terminator::Branch { target: BlockId(1) };
        let bb1 = BasicBlock::new(BlockId(1));
        func.blocks.push(bb0);
        func.blocks.push(bb1);

        let loop_info = analyze(&func);

        assert!(loop_info.loops.is_empty());
        assert!(loop_info.back_edges.is_empty());
    }
}