1use vyre::ir::{Expr, Node, Program};
10
11use crate::{
12 eval_expr, oob,
13 workgroup::{Frame, Invocation, Memory},
14};
15use vyre::Error;
16
17pub fn step<'a>(
24 invocation: &mut Invocation<'a>,
25 memory: &mut Memory,
26 program: &'a Program,
27) -> Result<(), vyre::Error> {
28 if invocation.done() || invocation.waiting_at_barrier {
29 return Ok(());
30 }
31
32 loop {
33 let Some(frame) = invocation.frames_mut().pop() else {
34 return Ok(());
35 };
36 match frame {
37 Frame::Nodes {
38 nodes,
39 index,
40 scoped,
41 } => {
42 if step_nodes_frame(invocation, memory, program, nodes, index, scoped)? {
43 return Ok(());
44 }
45 }
46 Frame::Loop {
47 var,
48 next,
49 to,
50 body,
51 } => step_loop_frame(invocation, var, next, to, body)?,
52 }
53 }
54}
55
56fn step_nodes_frame<'a>(
57 invocation: &mut Invocation<'a>,
58 memory: &mut Memory,
59 program: &'a Program,
60 nodes: &'a [Node],
61 index: usize,
62 scoped: bool,
63) -> Result<bool, vyre::Error> {
64 if index >= nodes.len() {
65 if scoped {
66 invocation.pop_scope();
67 }
68 return Ok(false);
69 }
70
71 invocation.frames_mut().push(Frame::Nodes {
72 nodes,
73 index: index + 1,
74 scoped,
75 });
76 execute_node(&nodes[index], invocation, memory, program)?;
77 Ok(true)
78}
79
80fn step_loop_frame<'a>(
81 invocation: &mut Invocation<'a>,
82 var: &'a str,
83 next: u32,
84 to: u32,
85 body: &'a [Node],
86) -> Result<(), vyre::Error> {
87 if next >= to {
88 return Ok(());
89 }
90 invocation.frames_mut().push(Frame::Loop {
91 var,
92 next: next.wrapping_add(1),
93 to,
94 body,
95 });
96 invocation.push_scope();
97 invocation.bind_loop_var(var, crate::value::Value::U32(next))?;
98 invocation.frames_mut().push(Frame::Nodes {
99 nodes: body,
100 index: 0,
101 scoped: true,
102 });
103 Ok(())
104}
105
106fn execute_node<'a>(
107 node: &'a Node,
108 invocation: &mut Invocation<'a>,
109 memory: &mut Memory,
110 program: &'a Program,
111) -> Result<(), vyre::Error> {
112 match node {
113 Node::Let { name, value } => eval_let(name, value, invocation, memory, program),
114 Node::Assign { name, value } => eval_assign(name, value, invocation, memory, program),
115 Node::Store {
116 buffer,
117 index,
118 value,
119 } => eval_store(buffer, index, value, invocation, memory, program),
120 Node::If {
121 cond,
122 then,
123 otherwise,
124 } => eval_if(cond, then, otherwise, node, invocation, memory, program),
125 Node::Loop {
126 var,
127 from,
128 to,
129 body,
130 } => eval_loop(var, from, to, body, invocation, memory, program),
131 Node::Return => eval_return(invocation),
132 Node::Block(nodes) => eval_block(nodes, invocation),
133 Node::Barrier => eval_barrier(invocation),
134 }
135}
136
137fn eval_let(
138 name: &str,
139 value: &Expr,
140 invocation: &mut Invocation<'_>,
141 memory: &mut Memory,
142 program: &Program,
143) -> Result<(), vyre::Error> {
144 let value = eval_expr::eval(value, invocation, memory, program)?;
145 invocation.bind(name, value)
146}
147
148fn eval_assign(
149 name: &str,
150 value: &Expr,
151 invocation: &mut Invocation<'_>,
152 memory: &mut Memory,
153 program: &Program,
154) -> Result<(), vyre::Error> {
155 let value = eval_expr::eval(value, invocation, memory, program)?;
156 invocation.assign(name, value)
157}
158
159fn eval_store(
160 buffer: &str,
161 index: &Expr,
162 value: &Expr,
163 invocation: &mut Invocation<'_>,
164 memory: &mut Memory,
165 program: &Program,
166) -> Result<(), vyre::Error> {
167 let index = eval_expr::eval(index, invocation, memory, program)?;
168 let index = index
169 .try_as_u32()
170 .ok_or_else(|| Error::interp(format!(
171 "store index {index:?} cannot be represented as u32. Fix: use a non-negative scalar index within u32."
172 )))?;
173 let value = eval_expr::eval(value, invocation, memory, program)?;
174 let target = eval_expr::buffer_mut(memory, program, buffer)?;
175 oob::store(target, index, &value);
176 Ok(())
177}
178
179fn eval_if<'a>(
180 cond: &Expr,
181 then: &'a [Node],
182 otherwise: &'a [Node],
183 node: &Node,
184 invocation: &mut Invocation<'a>,
185 memory: &mut Memory,
186 program: &Program,
187) -> Result<(), vyre::Error> {
188 let cond_value = eval_expr::eval(cond, invocation, memory, program)?.truthy();
189 if contains_barrier(then) || contains_barrier(otherwise) {
190 invocation.uniform_checks.push((node_id(node), cond_value));
191 }
192 let branch = if cond_value { then } else { otherwise };
193 invocation.push_scope();
194 invocation.frames_mut().push(Frame::Nodes {
195 nodes: branch,
196 index: 0,
197 scoped: true,
198 });
199 Ok(())
200}
201
202fn eval_loop<'a>(
203 var: &'a str,
204 from: &Expr,
205 to: &Expr,
206 body: &'a [Node],
207 invocation: &mut Invocation<'a>,
208 memory: &mut Memory,
209 program: &Program,
210) -> Result<(), vyre::Error> {
211 let from_value = eval_expr::eval(from, invocation, memory, program)?;
212 let to_value = eval_expr::eval(to, invocation, memory, program)?;
213 let from = from_value.try_as_u32().ok_or_else(|| {
214 Error::interp(format!(
215 "loop lower bound {from_value:?} cannot be represented as u32. Fix: use an in-range unsigned loop bound."
216 ))
217 })?;
218 let to = to_value.try_as_u32().ok_or_else(|| Error::interp(format!(
219 "loop upper bound {to_value:?} cannot be represented as u32. Fix: use an in-range unsigned loop bound."
220 )))?;
221 invocation.frames_mut().push(Frame::Loop {
222 var,
223 next: from,
224 to,
225 body,
226 });
227 Ok(())
228}
229
230fn eval_return(invocation: &mut Invocation<'_>) -> Result<(), vyre::Error> {
231 invocation.frames_mut().clear();
232 invocation.returned = true;
233 Ok(())
234}
235
236fn eval_block<'a>(nodes: &'a [Node], invocation: &mut Invocation<'a>) -> Result<(), vyre::Error> {
237 invocation.push_scope();
238 invocation.frames_mut().push(Frame::Nodes {
239 nodes,
240 index: 0,
241 scoped: true,
242 });
243 Ok(())
244}
245
246fn eval_barrier(invocation: &mut Invocation<'_>) -> Result<(), vyre::Error> {
247 invocation.waiting_at_barrier = true;
248 Ok(())
249}
250
251fn contains_barrier(nodes: &[Node]) -> bool {
254 nodes.iter().any(node_contains_barrier)
255}
256
257fn node_contains_barrier(node: &Node) -> bool {
258 match node {
259 Node::Barrier => true,
260 Node::Let { .. } | Node::Assign { .. } | Node::Store { .. } | Node::Return => false,
261 Node::If {
262 then, otherwise, ..
263 } => contains_barrier(then) || contains_barrier(otherwise),
264 Node::Loop { body, .. } => contains_barrier(body),
265 Node::Block(body) => contains_barrier(body),
266 }
267}
268
269fn node_id(node: &Node) -> usize {
270 std::ptr::from_ref(node).addr()
271}