Skip to main content

vyre_driver/program_walks/
indirect.rs

1//! Indirect-dispatch discovery over backend-neutral IR.
2
3use std::ops::ControlFlow::{self, Continue};
4
5use vyre_foundation::ir::model::expr::GeneratorRef;
6use vyre_foundation::ir::{Expr, Ident, Node, Program};
7use vyre_foundation::visit::{visit_node_preorder, NodeVisitor};
8
9use crate::backend::BackendError;
10
11/// Command-buffer indirect dispatch source.
12#[derive(Clone, Debug, Eq, PartialEq)]
13pub struct IndirectDispatch {
14    /// Buffer containing the indirect x/y/z workgroup tuple.
15    pub count_buffer: Ident,
16    /// Byte offset of the tuple in the buffer.
17    pub count_offset: u64,
18}
19
20/// Locates the single [`Node::IndirectDispatch`] in a program, if any.
21///
22/// # Errors
23///
24/// Returns when the program is inconsistent (e.g. multiple indirect
25/// sources, or a misaligned offset).
26pub fn find_indirect_dispatch(program: &Program) -> Result<Option<IndirectDispatch>, BackendError> {
27    if !program.has_indirect_dispatch() {
28        return Ok(None);
29    }
30    let mut found = None;
31    let mut collector = IndirectDispatchCollector { found: &mut found };
32    for node in program.entry() {
33        if let ControlFlow::Break(err) = visit_node_preorder(&mut collector, node) {
34            return Err(err);
35        }
36    }
37    Ok(found)
38}
39
40struct IndirectDispatchCollector<'a> {
41    found: &'a mut Option<IndirectDispatch>,
42}
43
44impl NodeVisitor for IndirectDispatchCollector<'_> {
45    type Break = BackendError;
46
47    fn visit_let(&mut self, _: &Node, _: &Ident, _: &Expr) -> ControlFlow<Self::Break> {
48        Continue(())
49    }
50
51    fn visit_assign(&mut self, _: &Node, _: &Ident, _: &Expr) -> ControlFlow<Self::Break> {
52        Continue(())
53    }
54
55    fn visit_store(&mut self, _: &Node, _: &Ident, _: &Expr, _: &Expr) -> ControlFlow<Self::Break> {
56        Continue(())
57    }
58
59    fn visit_if(&mut self, _: &Node, _: &Expr, _: &[Node], _: &[Node]) -> ControlFlow<Self::Break> {
60        Continue(())
61    }
62
63    fn visit_loop(
64        &mut self,
65        _: &Node,
66        _: &Ident,
67        _: &Expr,
68        _: &Expr,
69        _: &[Node],
70    ) -> ControlFlow<Self::Break> {
71        Continue(())
72    }
73
74    fn visit_indirect_dispatch(
75        &mut self,
76        _: &Node,
77        count_buffer: &Ident,
78        count_offset: u64,
79    ) -> ControlFlow<Self::Break> {
80        if count_offset % 4 != 0 {
81            return ControlFlow::Break(BackendError::new(format!(
82                "indirect dispatch offset {count_offset} is not 4-byte aligned. Fix: use a u32-aligned dispatch tuple."
83            )));
84        }
85        let next = IndirectDispatch {
86            count_buffer: count_buffer.clone(),
87            count_offset,
88        };
89        if self.found.replace(next).is_some() {
90            return ControlFlow::Break(BackendError::new(
91                "program declares more than one indirect dispatch source. Fix: keep exactly one Node::IndirectDispatch per Program.",
92            ));
93        }
94        Continue(())
95    }
96
97    fn visit_async_load(
98        &mut self,
99        _: &Node,
100        _: &Ident,
101        _: &Ident,
102        _: &Expr,
103        _: &Expr,
104        _: &Ident,
105    ) -> ControlFlow<Self::Break> {
106        Continue(())
107    }
108
109    fn visit_async_store(
110        &mut self,
111        _: &Node,
112        _: &Ident,
113        _: &Ident,
114        _: &Expr,
115        _: &Expr,
116        _: &Ident,
117    ) -> ControlFlow<Self::Break> {
118        Continue(())
119    }
120
121    fn visit_async_wait(&mut self, _: &Node, _: &Ident) -> ControlFlow<Self::Break> {
122        Continue(())
123    }
124
125    fn visit_trap(&mut self, _: &Node, _: &Expr, _: &Ident) -> ControlFlow<Self::Break> {
126        Continue(())
127    }
128
129    fn visit_resume(&mut self, _: &Node, _: &Ident) -> ControlFlow<Self::Break> {
130        Continue(())
131    }
132
133    fn visit_return(&mut self, _: &Node) -> ControlFlow<Self::Break> {
134        Continue(())
135    }
136
137    fn visit_barrier(&mut self, _: &Node) -> ControlFlow<Self::Break> {
138        Continue(())
139    }
140
141    fn visit_block(&mut self, _: &Node, _: &[Node]) -> ControlFlow<Self::Break> {
142        Continue(())
143    }
144
145    fn visit_region(
146        &mut self,
147        _: &Node,
148        _: &Ident,
149        _: &Option<GeneratorRef>,
150        _: &[Node],
151    ) -> ControlFlow<Self::Break> {
152        Continue(())
153    }
154
155    fn visit_opaque_node(
156        &mut self,
157        _: &Node,
158        _: &dyn vyre_foundation::ir::NodeExtension,
159    ) -> ControlFlow<Self::Break> {
160        Continue(())
161    }
162}