Skip to main content

vyre_foundation/visit/
expr.rs

1use crate::ir_inner::model::expr::{Expr, ExprNode, Ident};
2use crate::ir_inner::model::types::{AtomicOp, BinOp, DataType, UnOp};
3use crate::visit::VisitOrder;
4use smallvec::SmallVec;
5use std::ops::ControlFlow;
6
7/// Visitor over [`Expr`] trees.
8///
9/// Implementors must handle every core variant explicitly. This is
10/// intentional: `Expr` is `#[non_exhaustive]`, so a new variant must
11/// become a compile error in every visitor instead of silently
12/// disappearing behind a default body.
13///
14/// Traversal order is explicit:
15/// - [`visit_preorder`] visits the current expression before its children.
16/// - [`visit_postorder`] visits children before the current expression.
17///
18/// Visitors that want pass-through recursion can call
19/// [`ExprVisitor::walk_children_default`] from a variant method.
20pub trait ExprVisitor {
21    /// Break payload returned when traversal short-circuits.
22    type Break;
23
24    /// Integer literal (`u32`).
25    fn visit_lit_u32(&mut self, _expr: &Expr, _value: u32) -> ControlFlow<Self::Break> {
26        ControlFlow::Continue(())
27    }
28    /// Integer literal (`i32`).
29    fn visit_lit_i32(&mut self, _expr: &Expr, _value: i32) -> ControlFlow<Self::Break> {
30        ControlFlow::Continue(())
31    }
32    /// Float literal (`f32`).
33    fn visit_lit_f32(&mut self, _expr: &Expr, _value: f32) -> ControlFlow<Self::Break> {
34        ControlFlow::Continue(())
35    }
36    /// Bool literal.
37    fn visit_lit_bool(&mut self, _expr: &Expr, _value: bool) -> ControlFlow<Self::Break> {
38        ControlFlow::Continue(())
39    }
40    /// Variable reference.
41    fn visit_var(&mut self, _expr: &Expr, _name: &Ident) -> ControlFlow<Self::Break> {
42        ControlFlow::Continue(())
43    }
44    /// Buffer load (`buffer[index]`).
45    fn visit_load(
46        &mut self,
47        _expr: &Expr,
48        _buffer: &Ident,
49        _index: &Expr,
50    ) -> ControlFlow<Self::Break> {
51        ControlFlow::Continue(())
52    }
53    /// Buffer length.
54    fn visit_buf_len(&mut self, _expr: &Expr, _buffer: &Ident) -> ControlFlow<Self::Break> {
55        ControlFlow::Continue(())
56    }
57    /// Invocation id axis (`gid.{x,y,z}`).
58    fn visit_invocation_id(&mut self, _expr: &Expr, _axis: u32) -> ControlFlow<Self::Break> {
59        ControlFlow::Continue(())
60    }
61    /// Workgroup id axis.
62    fn visit_workgroup_id(&mut self, _expr: &Expr, _axis: u32) -> ControlFlow<Self::Break> {
63        ControlFlow::Continue(())
64    }
65    /// Local id axis within the workgroup.
66    fn visit_local_id(&mut self, _expr: &Expr, _axis: u32) -> ControlFlow<Self::Break> {
67        ControlFlow::Continue(())
68    }
69    /// Subgroup invocation id (lane index within subgroup).
70    fn visit_subgroup_local_id(&mut self, _expr: &Expr) -> ControlFlow<Self::Break> {
71        ControlFlow::Continue(())
72    }
73    /// Subgroup size.
74    fn visit_subgroup_size(&mut self, _expr: &Expr) -> ControlFlow<Self::Break> {
75        ControlFlow::Continue(())
76    }
77    /// Binary operation.
78    fn visit_bin_op(
79        &mut self,
80        _expr: &Expr,
81        _op: &BinOp,
82        _left: &Expr,
83        _right: &Expr,
84    ) -> ControlFlow<Self::Break> {
85        ControlFlow::Continue(())
86    }
87    /// Unary operation.
88    fn visit_un_op(
89        &mut self,
90        _expr: &Expr,
91        _op: &UnOp,
92        _operand: &Expr,
93    ) -> ControlFlow<Self::Break> {
94        ControlFlow::Continue(())
95    }
96    /// Function call.
97    fn visit_call(
98        &mut self,
99        _expr: &Expr,
100        _op_id: &str,
101        _args: &[Expr],
102    ) -> ControlFlow<Self::Break> {
103        ControlFlow::Continue(())
104    }
105    /// Sequence-valued extension hook.
106    ///
107    /// Core IR does not currently emit a dedicated `Expr::Sequence`
108    /// variant, but downstream visitor implementations must still opt in
109    /// explicitly so a sequence node cannot compile behind a silent
110    /// default body.
111    fn visit_sequence(&mut self, _parts: &[Expr]) -> ControlFlow<Self::Break> {
112        ControlFlow::Continue(())
113    }
114    /// Fused multiply-add (`a * b + c`).
115    fn visit_fma(
116        &mut self,
117        _expr: &Expr,
118        _a: &Expr,
119        _b: &Expr,
120        _c: &Expr,
121    ) -> ControlFlow<Self::Break> {
122        ControlFlow::Continue(())
123    }
124    /// Ternary `select(cond, true_val, false_val)`.
125    fn visit_select(
126        &mut self,
127        _expr: &Expr,
128        _cond: &Expr,
129        _true_val: &Expr,
130        _false_val: &Expr,
131    ) -> ControlFlow<Self::Break> {
132        ControlFlow::Continue(())
133    }
134    /// Numeric cast.
135    fn visit_cast(
136        &mut self,
137        _expr: &Expr,
138        _target: &DataType,
139        _value: &Expr,
140    ) -> ControlFlow<Self::Break> {
141        ControlFlow::Continue(())
142    }
143    /// Atomic operation on a shared buffer.
144    fn visit_atomic(
145        &mut self,
146        _expr: &Expr,
147        _op: &AtomicOp,
148        _buffer: &Ident,
149        _index: &Expr,
150        _expected: Option<&Expr>,
151        _value: &Expr,
152    ) -> ControlFlow<Self::Break> {
153        ControlFlow::Continue(())
154    }
155    /// Subgroup ballot.
156    fn visit_subgroup_ballot(&mut self, _expr: &Expr, _cond: &Expr) -> ControlFlow<Self::Break> {
157        ControlFlow::Continue(())
158    }
159    /// Subgroup shuffle.
160    fn visit_subgroup_shuffle(
161        &mut self,
162        _expr: &Expr,
163        _value: &Expr,
164        _lane: &Expr,
165    ) -> ControlFlow<Self::Break> {
166        ControlFlow::Continue(())
167    }
168    /// Subgroup add.
169    fn visit_subgroup_add(&mut self, _expr: &Expr, _value: &Expr) -> ControlFlow<Self::Break> {
170        ControlFlow::Continue(())
171    }
172    /// Downstream opaque expression extension.
173    fn visit_opaque_expr(
174        &mut self,
175        _expr: &Expr,
176        _extension: &dyn ExprNode,
177    ) -> ControlFlow<Self::Break> {
178        ControlFlow::Continue(())
179    }
180
181    /// Recursively walk this expression's children using the requested order.
182    fn walk_children_default(&mut self, expr: &Expr, order: VisitOrder) -> ControlFlow<Self::Break>
183    where
184        Self: Sized,
185    {
186        walk_expr_children_default(self, expr, order)
187    }
188}
189
190/// Visit an expression tree in pre-order.
191///
192/// This is the historical default entry point for expression traversal.
193pub fn visit_expr<V: ExprVisitor>(visitor: &mut V, expr: &Expr) -> ControlFlow<V::Break> {
194    visit_preorder(visitor, expr)
195}
196
197/// Visit an expression tree in pre-order.
198pub fn visit_preorder<V: ExprVisitor>(visitor: &mut V, expr: &Expr) -> ControlFlow<V::Break> {
199    let mut stack = SmallVec::<[&Expr; 32]>::new();
200    stack.push(expr);
201    while let Some(current) = stack.pop() {
202        dispatch_expr(visitor, current)?;
203        push_expr_children_reverse(&mut stack, current);
204    }
205    ControlFlow::Continue(())
206}
207
208/// Visit an expression tree in post-order.
209pub fn visit_postorder<V: ExprVisitor>(visitor: &mut V, expr: &Expr) -> ControlFlow<V::Break> {
210    let mut stack = SmallVec::<[ExprVisitTask<'_>; 32]>::new();
211    stack.push(ExprVisitTask::Visit(expr));
212    while let Some(task) = stack.pop() {
213        match task {
214            ExprVisitTask::Visit(current) => {
215                stack.push(ExprVisitTask::Dispatch(current));
216                push_expr_child_tasks_reverse(&mut stack, current);
217            }
218            ExprVisitTask::Dispatch(current) => dispatch_expr(visitor, current)?,
219        }
220    }
221    ControlFlow::Continue(())
222}
223
224/// Walk only the children of `expr`, leaving the current node to the caller.
225pub fn walk_expr_children_default<V: ExprVisitor>(
226    visitor: &mut V,
227    expr: &Expr,
228    order: VisitOrder,
229) -> ControlFlow<V::Break> {
230    match expr {
231        Expr::LitU32(_)
232        | Expr::LitI32(_)
233        | Expr::LitF32(_)
234        | Expr::LitBool(_)
235        | Expr::Var(_)
236        | Expr::BufLen { .. }
237        | Expr::InvocationId { .. }
238        | Expr::WorkgroupId { .. }
239        | Expr::LocalId { .. }
240        | Expr::SubgroupLocalId
241        | Expr::SubgroupSize
242        | Expr::Opaque(_) => ControlFlow::Continue(()),
243        Expr::Load { index, .. } | Expr::UnOp { operand: index, .. } => {
244            visit_with_order(visitor, index, order)
245        }
246        Expr::BinOp { left, right, .. } => {
247            visit_with_order(visitor, left, order)?;
248            visit_with_order(visitor, right, order)
249        }
250        Expr::Call { args, .. } => {
251            for arg in args {
252                visit_with_order(visitor, arg, order)?;
253            }
254            ControlFlow::Continue(())
255        }
256        Expr::Select {
257            cond,
258            true_val,
259            false_val,
260        } => {
261            visit_with_order(visitor, cond, order)?;
262            visit_with_order(visitor, true_val, order)?;
263            visit_with_order(visitor, false_val, order)
264        }
265        Expr::Cast { value, .. }
266        | Expr::SubgroupBallot { cond: value }
267        | Expr::SubgroupAdd { value } => visit_with_order(visitor, value, order),
268        Expr::Fma { a, b, c } => {
269            visit_with_order(visitor, a, order)?;
270            visit_with_order(visitor, b, order)?;
271            visit_with_order(visitor, c, order)
272        }
273        Expr::Atomic {
274            index,
275            expected,
276            value,
277            ..
278        } => {
279            visit_with_order(visitor, index, order)?;
280            if let Some(expected) = expected.as_deref() {
281                visit_with_order(visitor, expected, order)?;
282            }
283            visit_with_order(visitor, value, order)
284        }
285        Expr::SubgroupShuffle { value, lane } => {
286            visit_with_order(visitor, value, order)?;
287            visit_with_order(visitor, lane, order)
288        }
289    }
290}
291
292fn visit_with_order<V: ExprVisitor>(
293    visitor: &mut V,
294    expr: &Expr,
295    order: VisitOrder,
296) -> ControlFlow<V::Break> {
297    match order {
298        VisitOrder::Preorder => visit_preorder(visitor, expr),
299        VisitOrder::Postorder => visit_postorder(visitor, expr),
300    }
301}
302
303fn push_expr_children_reverse<'a>(stack: &mut SmallVec<[&'a Expr; 32]>, expr: &'a Expr) {
304    match expr {
305        Expr::LitU32(_)
306        | Expr::LitI32(_)
307        | Expr::LitF32(_)
308        | Expr::LitBool(_)
309        | Expr::Var(_)
310        | Expr::BufLen { .. }
311        | Expr::InvocationId { .. }
312        | Expr::WorkgroupId { .. }
313        | Expr::LocalId { .. }
314        | Expr::SubgroupLocalId
315        | Expr::SubgroupSize
316        | Expr::Opaque(_) => {}
317        Expr::Load { index, .. }
318        | Expr::UnOp { operand: index, .. }
319        | Expr::Cast { value: index, .. }
320        | Expr::SubgroupBallot { cond: index }
321        | Expr::SubgroupAdd { value: index } => stack.push(index),
322        Expr::BinOp { left, right, .. } => {
323            stack.push(right);
324            stack.push(left);
325        }
326        Expr::Call { args, .. } => {
327            for arg in args.iter().rev() {
328                stack.push(arg);
329            }
330        }
331        Expr::Fma { a, b, c } => {
332            stack.push(c);
333            stack.push(b);
334            stack.push(a);
335        }
336        Expr::Select {
337            cond,
338            true_val,
339            false_val,
340        } => {
341            stack.push(false_val);
342            stack.push(true_val);
343            stack.push(cond);
344        }
345        Expr::Atomic {
346            index,
347            expected,
348            value,
349            ..
350        } => {
351            stack.push(value);
352            if let Some(expected) = expected.as_deref() {
353                stack.push(expected);
354            }
355            stack.push(index);
356        }
357        Expr::SubgroupShuffle { value, lane } => {
358            stack.push(lane);
359            stack.push(value);
360        }
361    }
362}
363
364fn push_expr_child_tasks_reverse<'a>(
365    stack: &mut SmallVec<[ExprVisitTask<'a>; 32]>,
366    expr: &'a Expr,
367) {
368    match expr {
369        Expr::LitU32(_)
370        | Expr::LitI32(_)
371        | Expr::LitF32(_)
372        | Expr::LitBool(_)
373        | Expr::Var(_)
374        | Expr::BufLen { .. }
375        | Expr::InvocationId { .. }
376        | Expr::WorkgroupId { .. }
377        | Expr::LocalId { .. }
378        | Expr::SubgroupLocalId
379        | Expr::SubgroupSize
380        | Expr::Opaque(_) => {}
381        Expr::Load { index, .. }
382        | Expr::UnOp { operand: index, .. }
383        | Expr::Cast { value: index, .. }
384        | Expr::SubgroupBallot { cond: index }
385        | Expr::SubgroupAdd { value: index } => stack.push(ExprVisitTask::Visit(index)),
386        Expr::BinOp { left, right, .. } => {
387            stack.push(ExprVisitTask::Visit(right));
388            stack.push(ExprVisitTask::Visit(left));
389        }
390        Expr::Call { args, .. } => {
391            for arg in args.iter().rev() {
392                stack.push(ExprVisitTask::Visit(arg));
393            }
394        }
395        Expr::Fma { a, b, c } => {
396            stack.push(ExprVisitTask::Visit(c));
397            stack.push(ExprVisitTask::Visit(b));
398            stack.push(ExprVisitTask::Visit(a));
399        }
400        Expr::Select {
401            cond,
402            true_val,
403            false_val,
404        } => {
405            stack.push(ExprVisitTask::Visit(false_val));
406            stack.push(ExprVisitTask::Visit(true_val));
407            stack.push(ExprVisitTask::Visit(cond));
408        }
409        Expr::Atomic {
410            index,
411            expected,
412            value,
413            ..
414        } => {
415            stack.push(ExprVisitTask::Visit(value));
416            if let Some(expected) = expected.as_deref() {
417                stack.push(ExprVisitTask::Visit(expected));
418            }
419            stack.push(ExprVisitTask::Visit(index));
420        }
421        Expr::SubgroupShuffle { value, lane } => {
422            stack.push(ExprVisitTask::Visit(lane));
423            stack.push(ExprVisitTask::Visit(value));
424        }
425    }
426}
427
428enum ExprVisitTask<'a> {
429    Visit(&'a Expr),
430    Dispatch(&'a Expr),
431}
432
433fn dispatch_expr<V: ExprVisitor>(visitor: &mut V, expr: &Expr) -> ControlFlow<V::Break> {
434    match expr {
435        Expr::LitU32(value) => visitor.visit_lit_u32(expr, *value),
436        Expr::LitI32(value) => visitor.visit_lit_i32(expr, *value),
437        Expr::LitF32(value) => visitor.visit_lit_f32(expr, *value),
438        Expr::LitBool(value) => visitor.visit_lit_bool(expr, *value),
439        Expr::Var(name) => visitor.visit_var(expr, name),
440        Expr::Load { buffer, index } => visitor.visit_load(expr, buffer, index),
441        Expr::BufLen { buffer } => visitor.visit_buf_len(expr, buffer),
442        Expr::InvocationId { axis } => visitor.visit_invocation_id(expr, (*axis).into()),
443        Expr::WorkgroupId { axis } => visitor.visit_workgroup_id(expr, (*axis).into()),
444        Expr::LocalId { axis } => visitor.visit_local_id(expr, (*axis).into()),
445        Expr::BinOp { op, left, right } => visitor.visit_bin_op(expr, op, left, right),
446        Expr::UnOp { op, operand } => visitor.visit_un_op(expr, op, operand),
447        Expr::Call { op_id, args } => visitor.visit_call(expr, op_id, args),
448        Expr::Fma { a, b, c } => visitor.visit_fma(expr, a, b, c),
449        Expr::Select {
450            cond,
451            true_val,
452            false_val,
453        } => visitor.visit_select(expr, cond, true_val, false_val),
454        Expr::Cast { target, value } => visitor.visit_cast(expr, target, value),
455        Expr::Atomic {
456            op,
457            buffer,
458            index,
459            expected,
460            value,
461            ordering: _,
462        } => visitor.visit_atomic(expr, op, buffer, index, expected.as_deref(), value),
463        Expr::SubgroupBallot { cond } => visitor.visit_subgroup_ballot(expr, cond),
464        Expr::SubgroupShuffle { value, lane } => visitor.visit_subgroup_shuffle(expr, value, lane),
465        Expr::SubgroupAdd { value } => visitor.visit_subgroup_add(expr, value),
466        Expr::SubgroupLocalId => visitor.visit_subgroup_local_id(expr),
467        Expr::SubgroupSize => visitor.visit_subgroup_size(expr),
468        Expr::Opaque(extension) => visitor.visit_opaque_expr(expr, extension.as_ref()),
469    }
470}