Skip to main content

wave_compiler/analysis/
escape.rs

1// Copyright 2026 Ojima Abraham
2// SPDX-License-Identifier: Apache-2.0
3
//! Escape analysis for stack-to-register promotion.
//!
//! Determines which locally-allocated values escape the current scope
//! (e.g., through stores to device memory, atomic read-modify-write
//! operations, or function calls). Non-escaping values can be promoted
//! from memory to registers.
9
10use std::collections::HashSet;
11
12use crate::hir::types::AddressSpace;
13use crate::mir::function::MirFunction;
14use crate::mir::instruction::MirInst;
15use crate::mir::value::ValueId;
16
17/// Result of escape analysis.
18pub struct EscapeInfo {
19    /// Values that escape the local scope.
20    pub escaped: HashSet<ValueId>,
21}
22
23impl EscapeInfo {
24    /// Compute escape information for a MIR function.
25    #[must_use]
26    pub fn compute(func: &MirFunction) -> Self {
27        let mut escaped = HashSet::new();
28
29        for block in &func.blocks {
30            for inst in &block.instructions {
31                match inst {
32                    MirInst::Store {
33                        value,
34                        space: AddressSpace::Device,
35                        ..
36                    }
37                    | MirInst::AtomicRmw { value, .. } => {
38                        escaped.insert(*value);
39                    }
40                    MirInst::Call { args, .. } => {
41                        for arg in args {
42                            escaped.insert(*arg);
43                        }
44                    }
45                    _ => {}
46                }
47            }
48        }
49
50        Self { escaped }
51    }
52
53    /// Returns true if a value escapes the local scope.
54    #[must_use]
55    pub fn escapes(&self, value: ValueId) -> bool {
56        self.escaped.contains(&value)
57    }
58}
59
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hir::expr::BinOp;
    use crate::mir::basic_block::{BasicBlock, Terminator};
    use crate::mir::instruction::{ConstValue, MirInst};
    use crate::mir::types::MirType;
    use crate::mir::value::BlockId;

    /// Builds a function named `test` whose only block, `BlockId(0)`, holds
    /// `insts` in order and ends with `Terminator::Return`.
    fn single_block_fn(insts: Vec<MirInst>) -> MirFunction {
        let mut entry = BasicBlock::new(BlockId(0));
        entry.instructions.extend(insts);
        entry.terminator = Terminator::Return;

        let mut func = MirFunction::new("test".into(), BlockId(0));
        func.blocks.push(entry);
        func
    }

    #[test]
    fn test_value_escapes_through_store() {
        let func = single_block_fn(vec![
            MirInst::Const {
                dest: ValueId(0),
                value: ConstValue::I32(42),
            },
            MirInst::Store {
                addr: ValueId(1),
                value: ValueId(0),
                space: AddressSpace::Device,
            },
        ]);

        let info = EscapeInfo::compute(&func);
        // The stored value leaves the scope through device memory...
        assert!(info.escapes(ValueId(0)));
        // ...but the address operand itself is not recorded.
        assert!(!info.escapes(ValueId(1)));
    }

    #[test]
    fn test_local_value_does_not_escape() {
        let func = single_block_fn(vec![
            MirInst::Const {
                dest: ValueId(0),
                value: ConstValue::I32(1),
            },
            MirInst::Const {
                dest: ValueId(1),
                value: ConstValue::I32(2),
            },
            MirInst::BinOp {
                dest: ValueId(2),
                op: BinOp::Add,
                lhs: ValueId(0),
                rhs: ValueId(1),
                ty: MirType::I32,
            },
        ]);

        let info = EscapeInfo::compute(&func);
        // Pure arithmetic on constants never leaves the function.
        for id in 0..3 {
            assert!(!info.escapes(ValueId(id)));
        }
    }
}