wave_compiler/analysis/
escape.rs1use std::collections::HashSet;
11
12use crate::hir::types::AddressSpace;
13use crate::mir::function::MirFunction;
14use crate::mir::instruction::MirInst;
15use crate::mir::value::ValueId;
16
17pub struct EscapeInfo {
19 pub escaped: HashSet<ValueId>,
21}
22
23impl EscapeInfo {
24 #[must_use]
26 pub fn compute(func: &MirFunction) -> Self {
27 let mut escaped = HashSet::new();
28
29 for block in &func.blocks {
30 for inst in &block.instructions {
31 match inst {
32 MirInst::Store {
33 value,
34 space: AddressSpace::Device,
35 ..
36 }
37 | MirInst::AtomicRmw { value, .. } => {
38 escaped.insert(*value);
39 }
40 MirInst::Call { args, .. } => {
41 for arg in args {
42 escaped.insert(*arg);
43 }
44 }
45 _ => {}
46 }
47 }
48 }
49
50 Self { escaped }
51 }
52
53 #[must_use]
55 pub fn escapes(&self, value: ValueId) -> bool {
56 self.escaped.contains(&value)
57 }
58}
59
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hir::expr::BinOp;
    use crate::mir::basic_block::{BasicBlock, Terminator};
    use crate::mir::instruction::{ConstValue, MirInst};
    use crate::mir::types::MirType;
    use crate::mir::value::BlockId;

    /// Builds a function named `"test"` whose only block is `BlockId(0)`,
    /// containing `insts` in order and ending in `Terminator::Return`.
    fn single_block_fn(insts: Vec<MirInst>) -> MirFunction {
        let mut entry = BasicBlock::new(BlockId(0));
        entry.instructions.extend(insts);
        entry.terminator = Terminator::Return;

        let mut func = MirFunction::new("test".into(), BlockId(0));
        func.blocks.push(entry);
        func
    }

    #[test]
    fn test_value_escapes_through_store() {
        let func = single_block_fn(vec![
            MirInst::Const {
                dest: ValueId(0),
                value: ConstValue::I32(42),
            },
            MirInst::Store {
                addr: ValueId(1),
                value: ValueId(0),
                space: AddressSpace::Device,
            },
        ]);

        let info = EscapeInfo::compute(&func);

        // The stored value escapes; the address it is written through does not.
        assert!(info.escapes(ValueId(0)));
        assert!(!info.escapes(ValueId(1)));
    }

    #[test]
    fn test_local_value_does_not_escape() {
        let func = single_block_fn(vec![
            MirInst::Const {
                dest: ValueId(0),
                value: ConstValue::I32(1),
            },
            MirInst::Const {
                dest: ValueId(1),
                value: ConstValue::I32(2),
            },
            MirInst::BinOp {
                dest: ValueId(2),
                op: BinOp::Add,
                lhs: ValueId(0),
                rhs: ValueId(1),
                ty: MirType::I32,
            },
        ]);

        let info = EscapeInfo::compute(&func);

        // Pure arithmetic on constants never reaches an escaping position.
        for id in 0..3 {
            assert!(!info.escapes(ValueId(id)), "ValueId({id}) unexpectedly escaped");
        }
    }
}