1use crate::ir_inner::model::expr::{Expr, ExprNode, Ident};
2use crate::ir_inner::model::types::{AtomicOp, BinOp, DataType, UnOp};
3use crate::visit::VisitOrder;
4use smallvec::SmallVec;
5use std::ops::ControlFlow;
6
7pub trait ExprVisitor {
21 type Break;
23
24 fn visit_lit_u32(&mut self, _expr: &Expr, _value: u32) -> ControlFlow<Self::Break> {
26 ControlFlow::Continue(())
27 }
28 fn visit_lit_i32(&mut self, _expr: &Expr, _value: i32) -> ControlFlow<Self::Break> {
30 ControlFlow::Continue(())
31 }
32 fn visit_lit_f32(&mut self, _expr: &Expr, _value: f32) -> ControlFlow<Self::Break> {
34 ControlFlow::Continue(())
35 }
36 fn visit_lit_bool(&mut self, _expr: &Expr, _value: bool) -> ControlFlow<Self::Break> {
38 ControlFlow::Continue(())
39 }
40 fn visit_var(&mut self, _expr: &Expr, _name: &Ident) -> ControlFlow<Self::Break> {
42 ControlFlow::Continue(())
43 }
44 fn visit_load(
46 &mut self,
47 _expr: &Expr,
48 _buffer: &Ident,
49 _index: &Expr,
50 ) -> ControlFlow<Self::Break> {
51 ControlFlow::Continue(())
52 }
53 fn visit_buf_len(&mut self, _expr: &Expr, _buffer: &Ident) -> ControlFlow<Self::Break> {
55 ControlFlow::Continue(())
56 }
57 fn visit_invocation_id(&mut self, _expr: &Expr, _axis: u32) -> ControlFlow<Self::Break> {
59 ControlFlow::Continue(())
60 }
61 fn visit_workgroup_id(&mut self, _expr: &Expr, _axis: u32) -> ControlFlow<Self::Break> {
63 ControlFlow::Continue(())
64 }
65 fn visit_local_id(&mut self, _expr: &Expr, _axis: u32) -> ControlFlow<Self::Break> {
67 ControlFlow::Continue(())
68 }
69 fn visit_subgroup_local_id(&mut self, _expr: &Expr) -> ControlFlow<Self::Break> {
71 ControlFlow::Continue(())
72 }
73 fn visit_subgroup_size(&mut self, _expr: &Expr) -> ControlFlow<Self::Break> {
75 ControlFlow::Continue(())
76 }
77 fn visit_bin_op(
79 &mut self,
80 _expr: &Expr,
81 _op: &BinOp,
82 _left: &Expr,
83 _right: &Expr,
84 ) -> ControlFlow<Self::Break> {
85 ControlFlow::Continue(())
86 }
87 fn visit_un_op(
89 &mut self,
90 _expr: &Expr,
91 _op: &UnOp,
92 _operand: &Expr,
93 ) -> ControlFlow<Self::Break> {
94 ControlFlow::Continue(())
95 }
96 fn visit_call(
98 &mut self,
99 _expr: &Expr,
100 _op_id: &str,
101 _args: &[Expr],
102 ) -> ControlFlow<Self::Break> {
103 ControlFlow::Continue(())
104 }
105 fn visit_sequence(&mut self, _parts: &[Expr]) -> ControlFlow<Self::Break> {
112 ControlFlow::Continue(())
113 }
114 fn visit_fma(
116 &mut self,
117 _expr: &Expr,
118 _a: &Expr,
119 _b: &Expr,
120 _c: &Expr,
121 ) -> ControlFlow<Self::Break> {
122 ControlFlow::Continue(())
123 }
124 fn visit_select(
126 &mut self,
127 _expr: &Expr,
128 _cond: &Expr,
129 _true_val: &Expr,
130 _false_val: &Expr,
131 ) -> ControlFlow<Self::Break> {
132 ControlFlow::Continue(())
133 }
134 fn visit_cast(
136 &mut self,
137 _expr: &Expr,
138 _target: &DataType,
139 _value: &Expr,
140 ) -> ControlFlow<Self::Break> {
141 ControlFlow::Continue(())
142 }
143 fn visit_atomic(
145 &mut self,
146 _expr: &Expr,
147 _op: &AtomicOp,
148 _buffer: &Ident,
149 _index: &Expr,
150 _expected: Option<&Expr>,
151 _value: &Expr,
152 ) -> ControlFlow<Self::Break> {
153 ControlFlow::Continue(())
154 }
155 fn visit_subgroup_ballot(&mut self, _expr: &Expr, _cond: &Expr) -> ControlFlow<Self::Break> {
157 ControlFlow::Continue(())
158 }
159 fn visit_subgroup_shuffle(
161 &mut self,
162 _expr: &Expr,
163 _value: &Expr,
164 _lane: &Expr,
165 ) -> ControlFlow<Self::Break> {
166 ControlFlow::Continue(())
167 }
168 fn visit_subgroup_add(&mut self, _expr: &Expr, _value: &Expr) -> ControlFlow<Self::Break> {
170 ControlFlow::Continue(())
171 }
172 fn visit_opaque_expr(
174 &mut self,
175 _expr: &Expr,
176 _extension: &dyn ExprNode,
177 ) -> ControlFlow<Self::Break> {
178 ControlFlow::Continue(())
179 }
180
181 fn walk_children_default(&mut self, expr: &Expr, order: VisitOrder) -> ControlFlow<Self::Break>
183 where
184 Self: Sized,
185 {
186 walk_expr_children_default(self, expr, order)
187 }
188}
189
190pub fn visit_expr<V: ExprVisitor>(visitor: &mut V, expr: &Expr) -> ControlFlow<V::Break> {
194 visit_preorder(visitor, expr)
195}
196
197pub fn visit_preorder<V: ExprVisitor>(visitor: &mut V, expr: &Expr) -> ControlFlow<V::Break> {
199 let mut stack = SmallVec::<[&Expr; 32]>::new();
200 stack.push(expr);
201 while let Some(current) = stack.pop() {
202 dispatch_expr(visitor, current)?;
203 push_expr_children_reverse(&mut stack, current);
204 }
205 ControlFlow::Continue(())
206}
207
208pub fn visit_postorder<V: ExprVisitor>(visitor: &mut V, expr: &Expr) -> ControlFlow<V::Break> {
210 let mut stack = SmallVec::<[ExprVisitTask<'_>; 32]>::new();
211 stack.push(ExprVisitTask::Visit(expr));
212 while let Some(task) = stack.pop() {
213 match task {
214 ExprVisitTask::Visit(current) => {
215 stack.push(ExprVisitTask::Dispatch(current));
216 push_expr_child_tasks_reverse(&mut stack, current);
217 }
218 ExprVisitTask::Dispatch(current) => dispatch_expr(visitor, current)?,
219 }
220 }
221 ControlFlow::Continue(())
222}
223
224pub fn walk_expr_children_default<V: ExprVisitor>(
226 visitor: &mut V,
227 expr: &Expr,
228 order: VisitOrder,
229) -> ControlFlow<V::Break> {
230 match expr {
231 Expr::LitU32(_)
232 | Expr::LitI32(_)
233 | Expr::LitF32(_)
234 | Expr::LitBool(_)
235 | Expr::Var(_)
236 | Expr::BufLen { .. }
237 | Expr::InvocationId { .. }
238 | Expr::WorkgroupId { .. }
239 | Expr::LocalId { .. }
240 | Expr::SubgroupLocalId
241 | Expr::SubgroupSize
242 | Expr::Opaque(_) => ControlFlow::Continue(()),
243 Expr::Load { index, .. } | Expr::UnOp { operand: index, .. } => {
244 visit_with_order(visitor, index, order)
245 }
246 Expr::BinOp { left, right, .. } => {
247 visit_with_order(visitor, left, order)?;
248 visit_with_order(visitor, right, order)
249 }
250 Expr::Call { args, .. } => {
251 for arg in args {
252 visit_with_order(visitor, arg, order)?;
253 }
254 ControlFlow::Continue(())
255 }
256 Expr::Select {
257 cond,
258 true_val,
259 false_val,
260 } => {
261 visit_with_order(visitor, cond, order)?;
262 visit_with_order(visitor, true_val, order)?;
263 visit_with_order(visitor, false_val, order)
264 }
265 Expr::Cast { value, .. }
266 | Expr::SubgroupBallot { cond: value }
267 | Expr::SubgroupAdd { value } => visit_with_order(visitor, value, order),
268 Expr::Fma { a, b, c } => {
269 visit_with_order(visitor, a, order)?;
270 visit_with_order(visitor, b, order)?;
271 visit_with_order(visitor, c, order)
272 }
273 Expr::Atomic {
274 index,
275 expected,
276 value,
277 ..
278 } => {
279 visit_with_order(visitor, index, order)?;
280 if let Some(expected) = expected.as_deref() {
281 visit_with_order(visitor, expected, order)?;
282 }
283 visit_with_order(visitor, value, order)
284 }
285 Expr::SubgroupShuffle { value, lane } => {
286 visit_with_order(visitor, value, order)?;
287 visit_with_order(visitor, lane, order)
288 }
289 }
290}
291
292fn visit_with_order<V: ExprVisitor>(
293 visitor: &mut V,
294 expr: &Expr,
295 order: VisitOrder,
296) -> ControlFlow<V::Break> {
297 match order {
298 VisitOrder::Preorder => visit_preorder(visitor, expr),
299 VisitOrder::Postorder => visit_postorder(visitor, expr),
300 }
301}
302
303fn push_expr_children_reverse<'a>(stack: &mut SmallVec<[&'a Expr; 32]>, expr: &'a Expr) {
304 match expr {
305 Expr::LitU32(_)
306 | Expr::LitI32(_)
307 | Expr::LitF32(_)
308 | Expr::LitBool(_)
309 | Expr::Var(_)
310 | Expr::BufLen { .. }
311 | Expr::InvocationId { .. }
312 | Expr::WorkgroupId { .. }
313 | Expr::LocalId { .. }
314 | Expr::SubgroupLocalId
315 | Expr::SubgroupSize
316 | Expr::Opaque(_) => {}
317 Expr::Load { index, .. }
318 | Expr::UnOp { operand: index, .. }
319 | Expr::Cast { value: index, .. }
320 | Expr::SubgroupBallot { cond: index }
321 | Expr::SubgroupAdd { value: index } => stack.push(index),
322 Expr::BinOp { left, right, .. } => {
323 stack.push(right);
324 stack.push(left);
325 }
326 Expr::Call { args, .. } => {
327 for arg in args.iter().rev() {
328 stack.push(arg);
329 }
330 }
331 Expr::Fma { a, b, c } => {
332 stack.push(c);
333 stack.push(b);
334 stack.push(a);
335 }
336 Expr::Select {
337 cond,
338 true_val,
339 false_val,
340 } => {
341 stack.push(false_val);
342 stack.push(true_val);
343 stack.push(cond);
344 }
345 Expr::Atomic {
346 index,
347 expected,
348 value,
349 ..
350 } => {
351 stack.push(value);
352 if let Some(expected) = expected.as_deref() {
353 stack.push(expected);
354 }
355 stack.push(index);
356 }
357 Expr::SubgroupShuffle { value, lane } => {
358 stack.push(lane);
359 stack.push(value);
360 }
361 }
362}
363
364fn push_expr_child_tasks_reverse<'a>(
365 stack: &mut SmallVec<[ExprVisitTask<'a>; 32]>,
366 expr: &'a Expr,
367) {
368 match expr {
369 Expr::LitU32(_)
370 | Expr::LitI32(_)
371 | Expr::LitF32(_)
372 | Expr::LitBool(_)
373 | Expr::Var(_)
374 | Expr::BufLen { .. }
375 | Expr::InvocationId { .. }
376 | Expr::WorkgroupId { .. }
377 | Expr::LocalId { .. }
378 | Expr::SubgroupLocalId
379 | Expr::SubgroupSize
380 | Expr::Opaque(_) => {}
381 Expr::Load { index, .. }
382 | Expr::UnOp { operand: index, .. }
383 | Expr::Cast { value: index, .. }
384 | Expr::SubgroupBallot { cond: index }
385 | Expr::SubgroupAdd { value: index } => stack.push(ExprVisitTask::Visit(index)),
386 Expr::BinOp { left, right, .. } => {
387 stack.push(ExprVisitTask::Visit(right));
388 stack.push(ExprVisitTask::Visit(left));
389 }
390 Expr::Call { args, .. } => {
391 for arg in args.iter().rev() {
392 stack.push(ExprVisitTask::Visit(arg));
393 }
394 }
395 Expr::Fma { a, b, c } => {
396 stack.push(ExprVisitTask::Visit(c));
397 stack.push(ExprVisitTask::Visit(b));
398 stack.push(ExprVisitTask::Visit(a));
399 }
400 Expr::Select {
401 cond,
402 true_val,
403 false_val,
404 } => {
405 stack.push(ExprVisitTask::Visit(false_val));
406 stack.push(ExprVisitTask::Visit(true_val));
407 stack.push(ExprVisitTask::Visit(cond));
408 }
409 Expr::Atomic {
410 index,
411 expected,
412 value,
413 ..
414 } => {
415 stack.push(ExprVisitTask::Visit(value));
416 if let Some(expected) = expected.as_deref() {
417 stack.push(ExprVisitTask::Visit(expected));
418 }
419 stack.push(ExprVisitTask::Visit(index));
420 }
421 Expr::SubgroupShuffle { value, lane } => {
422 stack.push(ExprVisitTask::Visit(lane));
423 stack.push(ExprVisitTask::Visit(value));
424 }
425 }
426}
427
428enum ExprVisitTask<'a> {
429 Visit(&'a Expr),
430 Dispatch(&'a Expr),
431}
432
433fn dispatch_expr<V: ExprVisitor>(visitor: &mut V, expr: &Expr) -> ControlFlow<V::Break> {
434 match expr {
435 Expr::LitU32(value) => visitor.visit_lit_u32(expr, *value),
436 Expr::LitI32(value) => visitor.visit_lit_i32(expr, *value),
437 Expr::LitF32(value) => visitor.visit_lit_f32(expr, *value),
438 Expr::LitBool(value) => visitor.visit_lit_bool(expr, *value),
439 Expr::Var(name) => visitor.visit_var(expr, name),
440 Expr::Load { buffer, index } => visitor.visit_load(expr, buffer, index),
441 Expr::BufLen { buffer } => visitor.visit_buf_len(expr, buffer),
442 Expr::InvocationId { axis } => visitor.visit_invocation_id(expr, (*axis).into()),
443 Expr::WorkgroupId { axis } => visitor.visit_workgroup_id(expr, (*axis).into()),
444 Expr::LocalId { axis } => visitor.visit_local_id(expr, (*axis).into()),
445 Expr::BinOp { op, left, right } => visitor.visit_bin_op(expr, op, left, right),
446 Expr::UnOp { op, operand } => visitor.visit_un_op(expr, op, operand),
447 Expr::Call { op_id, args } => visitor.visit_call(expr, op_id, args),
448 Expr::Fma { a, b, c } => visitor.visit_fma(expr, a, b, c),
449 Expr::Select {
450 cond,
451 true_val,
452 false_val,
453 } => visitor.visit_select(expr, cond, true_val, false_val),
454 Expr::Cast { target, value } => visitor.visit_cast(expr, target, value),
455 Expr::Atomic {
456 op,
457 buffer,
458 index,
459 expected,
460 value,
461 ordering: _,
462 } => visitor.visit_atomic(expr, op, buffer, index, expected.as_deref(), value),
463 Expr::SubgroupBallot { cond } => visitor.visit_subgroup_ballot(expr, cond),
464 Expr::SubgroupShuffle { value, lane } => visitor.visit_subgroup_shuffle(expr, value, lane),
465 Expr::SubgroupAdd { value } => visitor.visit_subgroup_add(expr, value),
466 Expr::SubgroupLocalId => visitor.visit_subgroup_local_id(expr),
467 Expr::SubgroupSize => visitor.visit_subgroup_size(expr),
468 Expr::Opaque(extension) => visitor.visit_opaque_expr(expr, extension.as_ref()),
469 }
470}