wave_compiler/analysis/
alias.rs1use std::collections::{HashMap, HashSet};
11
12use crate::hir::types::AddressSpace;
13use crate::mir::function::MirFunction;
14use crate::mir::instruction::MirInst;
15use crate::mir::value::ValueId;
16
/// Outcome of asking whether two memory operations can touch the same memory.
///
/// Produced by `AliasInfo::query`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AliasResult {
    /// The two operations are known never to overlap.
    NoAlias,
    /// The analysis cannot rule out an overlap (the conservative answer).
    MayAlias,
    /// The two operations are known to access the same location.
    MustAlias,
}
27
28#[derive(Debug, Clone)]
30pub struct MemOp {
31 pub addr: ValueId,
33 pub space: AddressSpace,
35}
36
37pub struct AliasInfo {
39 mem_ops: Vec<MemOp>,
41 addr_spaces: HashMap<ValueId, AddressSpace>,
43}
44
45impl AliasInfo {
46 #[must_use]
48 pub fn compute(func: &MirFunction) -> Self {
49 let mut mem_ops = Vec::new();
50 let mut addr_spaces: HashMap<ValueId, AddressSpace> = HashMap::new();
51
52 for block in &func.blocks {
53 for inst in &block.instructions {
54 match inst {
55 MirInst::Load { addr, space, .. } | MirInst::Store { addr, space, .. } => {
56 addr_spaces.insert(*addr, *space);
57 mem_ops.push(MemOp {
58 addr: *addr,
59 space: *space,
60 });
61 }
62 _ => {}
63 }
64 }
65 }
66
67 Self {
68 mem_ops,
69 addr_spaces,
70 }
71 }
72
73 #[must_use]
75 pub fn query(&self, a: &MemOp, b: &MemOp) -> AliasResult {
76 if a.space != b.space {
77 return AliasResult::NoAlias;
78 }
79 if a.addr == b.addr {
80 return AliasResult::MustAlias;
81 }
82 AliasResult::MayAlias
83 }
84
85 #[must_use]
87 pub fn mem_ops(&self) -> &[MemOp] {
88 &self.mem_ops
89 }
90
91 #[must_use]
93 pub fn addr_space_of(&self, value: ValueId) -> Option<AddressSpace> {
94 self.addr_spaces.get(&value).copied()
95 }
96
97 #[must_use]
99 pub fn may_alias_set(&self, op: &MemOp) -> HashSet<ValueId> {
100 let mut result = HashSet::new();
101 for mop in &self.mem_ops {
102 if self.query(op, mop) != AliasResult::NoAlias {
103 result.insert(mop.addr);
104 }
105 }
106 result
107 }
108}
109
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `MemOp` without the struct-literal noise.
    fn mem_op(addr: ValueId, space: AddressSpace) -> MemOp {
        MemOp { addr, space }
    }

    /// Builds an `AliasInfo` over `ops` with an empty address-space table.
    /// This bypasses `AliasInfo::compute`, which needs a full `MirFunction`.
    fn info_for(ops: &[MemOp]) -> AliasInfo {
        AliasInfo {
            mem_ops: ops.to_vec(),
            addr_spaces: HashMap::new(),
        }
    }

    #[test]
    fn test_different_address_spaces_no_alias() {
        let a = mem_op(ValueId(0), AddressSpace::Device);
        let b = mem_op(ValueId(1), AddressSpace::Local);
        let info = info_for(&[a.clone(), b.clone()]);
        assert_eq!(info.query(&a, &b), AliasResult::NoAlias);
        // The query must be symmetric.
        assert_eq!(info.query(&b, &a), AliasResult::NoAlias);
    }

    #[test]
    fn test_same_addr_different_space_no_alias() {
        // Address-space disagreement is checked before address equality.
        let a = mem_op(ValueId(0), AddressSpace::Device);
        let b = mem_op(ValueId(0), AddressSpace::Local);
        let info = info_for(&[a.clone(), b.clone()]);
        assert_eq!(info.query(&a, &b), AliasResult::NoAlias);
    }

    #[test]
    fn test_same_addr_must_alias() {
        let a = mem_op(ValueId(0), AddressSpace::Device);
        let b = mem_op(ValueId(0), AddressSpace::Device);
        let info = info_for(&[a.clone(), b.clone()]);
        assert_eq!(info.query(&a, &b), AliasResult::MustAlias);
        assert_eq!(info.query(&b, &a), AliasResult::MustAlias);
    }

    #[test]
    fn test_same_space_different_addr_may_alias() {
        let a = mem_op(ValueId(0), AddressSpace::Device);
        let b = mem_op(ValueId(1), AddressSpace::Device);
        let info = info_for(&[a.clone(), b.clone()]);
        assert_eq!(info.query(&a, &b), AliasResult::MayAlias);
        assert_eq!(info.query(&b, &a), AliasResult::MayAlias);
    }

    #[test]
    fn test_may_alias_set() {
        let ops = vec![
            mem_op(ValueId(0), AddressSpace::Device),
            mem_op(ValueId(1), AddressSpace::Device),
            mem_op(ValueId(2), AddressSpace::Local),
        ];
        let info = info_for(&ops);
        let aliases = info.may_alias_set(&ops[0]);
        assert!(aliases.contains(&ValueId(0)));
        assert!(aliases.contains(&ValueId(1)));
        assert!(!aliases.contains(&ValueId(2)));
        assert_eq!(aliases.len(), 2);
    }

    #[test]
    fn test_may_alias_set_empty() {
        let info = info_for(&[]);
        let probe = mem_op(ValueId(0), AddressSpace::Device);
        assert!(info.may_alias_set(&probe).is_empty());
    }

    #[test]
    fn test_mem_ops_accessor() {
        let ops = vec![
            mem_op(ValueId(0), AddressSpace::Device),
            mem_op(ValueId(1), AddressSpace::Local),
        ];
        let info = info_for(&ops);
        let recorded = info.mem_ops();
        assert_eq!(recorded.len(), 2);
        assert_eq!(recorded[0].addr, ValueId(0));
        assert_eq!(recorded[1].addr, ValueId(1));
        assert_eq!(recorded[1].space, AddressSpace::Local);
    }

    #[test]
    fn test_addr_space_of() {
        let mut addr_spaces = HashMap::new();
        addr_spaces.insert(ValueId(0), AddressSpace::Device);
        let info = AliasInfo {
            mem_ops: vec![mem_op(ValueId(0), AddressSpace::Device)],
            addr_spaces,
        };
        assert_eq!(info.addr_space_of(ValueId(0)), Some(AddressSpace::Device));
        assert_eq!(info.addr_space_of(ValueId(1)), None);
    }
}