1use crate::error::Result;
2use crate::ir_inner::model::expr::{Expr, GeneratorRef, Ident};
3use crate::ir_inner::model::generated::Node;
4use crate::ir_inner::model::node::NodeExtension;
5use crate::visit::VisitOrder;
6use smallvec::SmallVec;
7use std::ops::ControlFlow;
8
9pub trait Lowerable<Ctx: ?Sized> {
19 fn lower(&self, ctx: &mut Ctx) -> Result<()>;
26}
27
28pub trait Evaluatable<Env: ?Sized> {
35 type Value;
38
39 fn evaluate(&self, env: &mut Env) -> Result<Self::Value>;
46}
47
48pub trait NodeVisitor {
64 type Break;
66
67 fn visit_let(&mut self, node: &Node, name: &Ident, value: &Expr) -> ControlFlow<Self::Break>;
69 fn visit_assign(&mut self, node: &Node, name: &Ident, value: &Expr)
71 -> ControlFlow<Self::Break>;
72 fn visit_store(
74 &mut self,
75 node: &Node,
76 buffer: &Ident,
77 index: &Expr,
78 value: &Expr,
79 ) -> ControlFlow<Self::Break>;
80 fn visit_if(
82 &mut self,
83 node: &Node,
84 cond: &Expr,
85 then_nodes: &[Node],
86 otherwise: &[Node],
87 ) -> ControlFlow<Self::Break>;
88 fn visit_loop(
90 &mut self,
91 node: &Node,
92 var: &Ident,
93 from: &Expr,
94 to: &Expr,
95 body: &[Node],
96 ) -> ControlFlow<Self::Break>;
97 fn visit_indirect_dispatch(
99 &mut self,
100 node: &Node,
101 count_buffer: &Ident,
102 count_offset: u64,
103 ) -> ControlFlow<Self::Break>;
104 fn visit_async_load(
106 &mut self,
107 node: &Node,
108 source: &Ident,
109 destination: &Ident,
110 offset: &Expr,
111 size: &Expr,
112 tag: &Ident,
113 ) -> ControlFlow<Self::Break>;
114 fn visit_async_store(
116 &mut self,
117 node: &Node,
118 source: &Ident,
119 destination: &Ident,
120 offset: &Expr,
121 size: &Expr,
122 tag: &Ident,
123 ) -> ControlFlow<Self::Break>;
124 fn visit_async_wait(&mut self, node: &Node, tag: &Ident) -> ControlFlow<Self::Break>;
126 fn visit_trap(&mut self, node: &Node, address: &Expr, tag: &Ident) -> ControlFlow<Self::Break>;
128 fn visit_resume(&mut self, node: &Node, tag: &Ident) -> ControlFlow<Self::Break>;
130 fn visit_return(&mut self, node: &Node) -> ControlFlow<Self::Break>;
132 fn visit_barrier(&mut self, node: &Node) -> ControlFlow<Self::Break>;
134 fn visit_collective(&mut self, node: &Node) -> ControlFlow<Self::Break> {
136 let _ = node;
137 ControlFlow::Continue(())
138 }
139 fn visit_block(&mut self, node: &Node, body: &[Node]) -> ControlFlow<Self::Break>;
141 fn visit_region(
143 &mut self,
144 node: &Node,
145 generator: &Ident,
146 source_region: &Option<GeneratorRef>,
147 body: &[Node],
148 ) -> ControlFlow<Self::Break>;
149 fn visit_opaque_node(
151 &mut self,
152 node: &Node,
153 extension: &dyn NodeExtension,
154 ) -> ControlFlow<Self::Break>;
155
156 fn walk_children_default(&mut self, node: &Node, order: VisitOrder) -> ControlFlow<Self::Break>
158 where
159 Self: Sized,
160 {
161 walk_node_children_default(self, node, order)
162 }
163}
164
165pub fn visit_node<V: NodeVisitor>(visitor: &mut V, node: &Node) -> ControlFlow<V::Break> {
167 visit_node_preorder(visitor, node)
168}
169
170pub fn visit_node_preorder<V: NodeVisitor>(visitor: &mut V, node: &Node) -> ControlFlow<V::Break> {
172 let mut stack = SmallVec::<[&Node; 32]>::new();
173 stack.push(node);
174 while let Some(current) = stack.pop() {
175 dispatch_node(visitor, current)?;
176 match current {
177 Node::If {
178 then, otherwise, ..
179 } => {
180 for n in otherwise.iter().rev() {
181 stack.push(n);
182 }
183 for n in then.iter().rev() {
184 stack.push(n);
185 }
186 }
187 Node::Loop { body, .. } | Node::Block(body) => {
188 for n in body.iter().rev() {
189 stack.push(n);
190 }
191 }
192 Node::Region { body, .. } => {
193 for n in body.iter().rev() {
194 stack.push(n);
195 }
196 }
197 _ => {}
198 }
199 }
200 ControlFlow::Continue(())
201}
202
203pub fn visit_node_postorder<V: NodeVisitor>(visitor: &mut V, node: &Node) -> ControlFlow<V::Break> {
205 enum Task<'a> {
206 Visit(&'a Node),
207 Dispatch(&'a Node),
208 }
209 let mut stack = SmallVec::<[Task<'_>; 32]>::new();
210 stack.push(Task::Visit(node));
211 while let Some(task) = stack.pop() {
212 match task {
213 Task::Visit(n) => {
214 stack.push(Task::Dispatch(n));
215 match n {
216 Node::If {
217 then, otherwise, ..
218 } => {
219 for child in otherwise.iter().rev() {
220 stack.push(Task::Visit(child));
221 }
222 for child in then.iter().rev() {
223 stack.push(Task::Visit(child));
224 }
225 }
226 Node::Loop { body, .. } | Node::Block(body) => {
227 for child in body.iter().rev() {
228 stack.push(Task::Visit(child));
229 }
230 }
231 Node::Region { body, .. } => {
232 for child in body.iter().rev() {
233 stack.push(Task::Visit(child));
234 }
235 }
236 _ => {}
237 }
238 }
239 Task::Dispatch(n) => {
240 dispatch_node(visitor, n)?;
241 }
242 }
243 }
244 ControlFlow::Continue(())
245}
246
247pub fn walk_node_children_default<V: NodeVisitor>(
249 visitor: &mut V,
250 node: &Node,
251 order: VisitOrder,
252) -> ControlFlow<V::Break> {
253 match node {
254 Node::If {
255 then, otherwise, ..
256 } => {
257 for child in then {
258 visit_node_with_order(visitor, child, order)?;
259 }
260 for child in otherwise {
261 visit_node_with_order(visitor, child, order)?;
262 }
263 }
264 Node::Loop { body, .. } | Node::Block(body) => {
265 for child in body {
266 visit_node_with_order(visitor, child, order)?;
267 }
268 }
269 Node::Region { body, .. } => {
270 for child in body.iter() {
271 visit_node_with_order(visitor, child, order)?;
272 }
273 }
274 _ => {}
275 }
276 ControlFlow::Continue(())
277}
278
279fn visit_node_with_order<V: NodeVisitor>(
280 visitor: &mut V,
281 node: &Node,
282 order: VisitOrder,
283) -> ControlFlow<V::Break> {
284 match order {
285 VisitOrder::Preorder => visit_node_preorder(visitor, node),
286 VisitOrder::Postorder => visit_node_postorder(visitor, node),
287 }
288}
289
290pub(crate) fn dispatch_node<V: NodeVisitor>(visitor: &mut V, node: &Node) -> ControlFlow<V::Break> {
291 match node {
292 Node::Let { name, value } => visitor.visit_let(node, name, value),
293 Node::Assign { name, value } => visitor.visit_assign(node, name, value),
294 Node::Store {
295 buffer,
296 index,
297 value,
298 } => visitor.visit_store(node, buffer, index, value),
299 Node::If {
300 cond,
301 then,
302 otherwise,
303 } => visitor.visit_if(node, cond, then, otherwise),
304 Node::Loop {
305 var,
306 from,
307 to,
308 body,
309 } => visitor.visit_loop(node, var, from, to, body),
310 Node::IndirectDispatch {
311 count_buffer,
312 count_offset,
313 } => visitor.visit_indirect_dispatch(node, count_buffer, *count_offset),
314 Node::AsyncLoad {
315 source,
316 destination,
317 offset,
318 size,
319 tag,
320 } => visitor.visit_async_load(node, source, destination, offset, size, tag),
321 Node::AsyncStore {
322 source,
323 destination,
324 offset,
325 size,
326 tag,
327 } => visitor.visit_async_store(node, source, destination, offset, size, tag),
328 Node::AsyncWait { tag } => visitor.visit_async_wait(node, tag),
329 Node::Trap { address, tag } => visitor.visit_trap(node, address, tag),
330 Node::Resume { tag } => visitor.visit_resume(node, tag),
331 Node::AllReduce { .. }
332 | Node::AllGather { .. }
333 | Node::ReduceScatter { .. }
334 | Node::Broadcast { .. } => visitor.visit_collective(node),
335 Node::Return => visitor.visit_return(node),
336 Node::Barrier { .. } => visitor.visit_barrier(node),
337 Node::Block(body) => visitor.visit_block(node, body),
338 Node::Region {
339 generator,
340 source_region,
341 body,
342 } => visitor.visit_region(node, generator, source_region, body),
343 Node::Opaque(extension) => visitor.visit_opaque_node(node, extension.as_ref()),
344 }
345}