1use vyre::ir::{AtomicOp, BinOp, BufferAccess, BufferDecl, DataType, Expr, Program, UnOp};
10
11use smallvec::SmallVec;
12use vyre::Error;
13
14use crate::execution::expr_cast::cast_value;
15use crate::{atomics, oob, value::Value, workgroup::Invocation, workgroup::Memory};
16
17pub use crate::oob::Buffer;
19
20pub fn eval(
27 expr: &Expr,
28 invocation: &mut Invocation<'_>,
29 memory: &mut Memory,
30 program: &Program,
31) -> Result<Value, vyre::Error> {
32 eval_frame_oracle(expr, invocation, memory, program)
33}
34
/// Evaluates `expr` for one invocation using explicit work and value stacks instead of host
/// recursion.
///
/// * `frames` holds pending work. `Frame::Expr` means "evaluate this sub-expression". Every
///   other variant is a continuation that consumes operands already produced on `values`.
/// * `values` holds the results of finished sub-expressions.
///
/// A compound expression pushes its continuation first and its operands afterwards in reverse
/// order. As a result, operands are evaluated left to right, and the continuation pops them back
/// off in reverse. `Expr::Call` is handed to `crate::execution::call::eval_call` in a single
/// step. Whether that helper recurses on the host stack is not visible from this function.
///
/// # Errors
///
/// Returns an `Error::interp` error for any of the following:
/// * undeclared variables;
/// * unknown or missing buffers;
/// * out-of-range ID axes;
/// * load/atomic operands that cannot be represented as `u32`;
/// * non-float `Fma` operands;
/// * an `expected` operand that does not match the atomic op;
/// * atomics on non-`ReadWrite` buffers;
/// * opaque or unknown expression variants;
/// * an unexpectedly short value stack (the "internal evaluator error" messages).
///
/// Errors from the typed-op, cast, call, and atomic helpers are propagated unchanged.
pub(crate) fn eval_frame_oracle(
    expr: &Expr,
    invocation: &mut Invocation<'_>,
    memory: &mut Memory,
    program: &Program,
) -> Result<Value, vyre::Error> {
    // One unit of pending work on the frame stack.
    enum Frame<'a> {
        // Evaluate this sub-expression and push its result onto `values`.
        Expr(&'a Expr),
        // Pop `right`, then `left`; push `left op right`.
        BinOp(BinOp),
        // Pop one operand; push the unary result.
        UnOp(&'a UnOp),
        // Pop `false_val`, `true_val`, `cond`; push the branch chosen by `cond.truthy()`.
        Select,
        // Pop one value; push it converted to the target type.
        Cast(&'a DataType),
        // Pop `c`, `b`, `a` (each as f32); push the fused `a * b + c`.
        Fma,
        // Pop the index; push the element loaded from `buffer`.
        Load {
            buffer: &'a str,
        },
        // Pop the atomic index, then schedule the `expected` operand (if any) and `value`.
        AtomicIndex {
            op: AtomicOp,
            buffer: &'a str,
            expected: Option<&'a Expr>,
            value: &'a Expr,
        },
        // Pop the compare-exchange `expected` operand, then schedule `value`.
        // `expected_expr` is only used to build the error message.
        AtomicExpected {
            op: AtomicOp,
            buffer: &'a str,
            index: u32,
            value: &'a Expr,
            expected_expr: &'a Expr,
        },
        // Pop the atomic `value` operand and perform the read-modify-write on `buffer[index]`.
        AtomicValue {
            op: AtomicOp,
            buffer: &'a str,
            expected: Option<u32>,
            index: u32,
        },
    }

    // Both stacks keep 32 entries inline and spill to the heap for deeper trees.
    let mut frames: SmallVec<[Frame<'_>; 32]> = SmallVec::new();
    frames.push(Frame::Expr(expr));
    let mut values: SmallVec<[Value; 32]> = SmallVec::new();

    while let Some(frame) = frames.pop() {
        match frame {
            Frame::Expr(expr) => match expr {
                // Leaf expressions push their value immediately.
                Expr::LitU32(value) => values.push(Value::U32(*value)),
                Expr::LitI32(value) => values.push(Value::I32(*value)),
                Expr::LitF32(value) => {
                    // f32 literals are canonicalized, then widened into the f64-backed
                    // `Value::Float`.
                    values.push(Value::Float(f64::from(
                        crate::execution::typed_ops::canonical_f32(*value),
                    )));
                }
                Expr::LitBool(value) => values.push(Value::Bool(*value)),
                Expr::Var(name) => values.push(eval_var(name, invocation)?),
                Expr::BufLen { buffer } => values.push(eval_buf_len(buffer, memory, program)?),
                Expr::InvocationId { axis } => values.push(eval_invocation_id(*axis, invocation)?),
                Expr::WorkgroupId { axis } => values.push(eval_workgroup_id(*axis, invocation)?),
                Expr::LocalId { axis } => values.push(eval_local_id(*axis, invocation)?),
                // Compound expressions: continuation first, then operands in reverse order so
                // the leftmost operand is popped (evaluated) first.
                Expr::Load { buffer, index } => {
                    frames.push(Frame::Load { buffer });
                    frames.push(Frame::Expr(index));
                }
                Expr::BinOp { op, left, right } => {
                    frames.push(Frame::BinOp(*op));
                    frames.push(Frame::Expr(right));
                    frames.push(Frame::Expr(left));
                }
                Expr::UnOp { op, operand } => {
                    frames.push(Frame::UnOp(op));
                    frames.push(Frame::Expr(operand));
                }
                Expr::Select {
                    cond,
                    true_val,
                    false_val,
                } => {
                    // Both branches are always evaluated (no short-circuit). Side effects in the
                    // unselected branch, such as atomics, therefore still take place.
                    frames.push(Frame::Select);
                    frames.push(Frame::Expr(false_val));
                    frames.push(Frame::Expr(true_val));
                    frames.push(Frame::Expr(cond));
                }
                Expr::Cast { target, value } => {
                    frames.push(Frame::Cast(target));
                    frames.push(Frame::Expr(value));
                }
                Expr::Fma { a, b, c } => {
                    frames.push(Frame::Fma);
                    frames.push(Frame::Expr(c));
                    frames.push(Frame::Expr(b));
                    frames.push(Frame::Expr(a));
                }
                Expr::Atomic {
                    op,
                    buffer,
                    index,
                    expected,
                    value,
                    ordering: _,
                } => {
                    // `expected` must be present exactly when the op is CompareExchange. This is
                    // checked before any operand is evaluated. The memory `ordering` is ignored.
                    match (*op, expected.as_deref()) {
                        (AtomicOp::CompareExchange, None) => {
                            return Err(Error::interp(
                                "compare-exchange atomic is missing expected value. Fix: set Expr::Atomic.expected for AtomicOp::CompareExchange.",
                            ));
                        }
                        (AtomicOp::CompareExchange, Some(_)) => {}
                        (_, Some(_)) => {
                            return Err(Error::interp(
                                "non-compare-exchange atomic includes an expected value. Fix: use Expr::Atomic.expected only with AtomicOp::CompareExchange.",
                            ));
                        }
                        (_, None) => {}
                    }
                    // Operand order: index, then expected (if any), then value.
                    frames.push(Frame::AtomicIndex {
                        op: *op,
                        buffer,
                        expected: expected.as_deref(),
                        value,
                    });
                    frames.push(Frame::Expr(index));
                }
                Expr::Call { op_id, args } => {
                    // Calls are evaluated in one step by the call helper. `expr` is passed as a
                    // raw pointer. NOTE(review): presumably used as a call-site identity key;
                    // confirm in `call::eval_call`.
                    let val = crate::execution::call::eval_call(
                        expr as *const Expr,
                        op_id,
                        args,
                        invocation,
                        memory,
                        program,
                    )?;
                    values.push(val);
                }
                Expr::Opaque(extension) => {
                    return Err(Error::interp(format!(
                        "reference interpreter does not support opaque expression extension `{}`/`{}`. Fix: provide a reference evaluator for this ExprNode or lower it to core Expr variants before evaluation.",
                        extension.extension_kind(),
                        extension.debug_identity()
                    )));
                }
                _ => {
                    return Err(Error::interp(
                        "reference interpreter encountered an unknown expression variant. Fix: add explicit reference semantics for the new ExprNode before dispatch.",
                    ));
                }
            },
            // Continuations: pop operands in reverse of the order they were evaluated.
            Frame::BinOp(op) => {
                let right = values.pop().ok_or_else(|| {
                    Error::interp("binary op missing right operand. Fix: internal evaluator error.")
                })?;
                let left = values.pop().ok_or_else(|| {
                    Error::interp("binary op missing left operand. Fix: internal evaluator error.")
                })?;
                values.push(super::typed_ops::eval_binop(op, left, right)?);
            }
            Frame::UnOp(op) => {
                let operand = values.pop().ok_or_else(|| {
                    Error::interp("unary op missing operand. Fix: internal evaluator error.")
                })?;
                values.push(super::typed_ops::eval_unop(op, operand)?);
            }
            Frame::Select => {
                let false_val = values.pop().ok_or_else(|| {
                    Error::interp("select missing false branch. Fix: internal evaluator error.")
                })?;
                let true_val = values.pop().ok_or_else(|| {
                    Error::interp("select missing true branch. Fix: internal evaluator error.")
                })?;
                let cond = values
                    .pop()
                    .ok_or_else(|| {
                        Error::interp("select missing condition. Fix: internal evaluator error.")
                    })?
                    .truthy();
                values.push(if cond { true_val } else { false_val });
            }
            Frame::Cast(target) => {
                let value = values.pop().ok_or_else(|| {
                    Error::interp("cast missing value. Fix: internal evaluator error.")
                })?;
                values.push(cast_value(target, &value)?);
            }
            Frame::Fma => {
                // `c` is on top of the value stack, then `b`, then `a`.
                let c = values
                    .pop()
                    .ok_or_else(|| {
                        Error::interp("fma missing operand c. Fix: internal evaluator error.")
                    })?
                    .try_as_f32()
                    .ok_or_else(|| {
                        Error::interp(
                            "fma operand `c` is not a float. Fix: cast to f32 before fma.",
                        )
                    })?;
                let b = values
                    .pop()
                    .ok_or_else(|| {
                        Error::interp("fma missing operand b. Fix: internal evaluator error.")
                    })?
                    .try_as_f32()
                    .ok_or_else(|| {
                        Error::interp(
                            "fma operand `b` is not a float. Fix: cast to f32 before fma.",
                        )
                    })?;
                let a = values
                    .pop()
                    .ok_or_else(|| {
                        Error::interp("fma missing operand a. Fix: internal evaluator error.")
                    })?
                    .try_as_f32()
                    .ok_or_else(|| {
                        Error::interp(
                            "fma operand `a` is not a float. Fix: cast to f32 before fma.",
                        )
                    })?;
                // Inputs and the single-rounding `mul_add` result are all canonicalized in f32.
                let a = crate::execution::typed_ops::canonical_f32(a);
                let b = crate::execution::typed_ops::canonical_f32(b);
                let c = crate::execution::typed_ops::canonical_f32(c);
                values.push(Value::Float(f64::from(
                    crate::execution::typed_ops::canonical_f32(a.mul_add(b, c)),
                )));
            }
            Frame::Load { buffer } => {
                let value = values.pop().ok_or_else(|| {
                    Error::interp("load missing index. Fix: internal evaluator error.")
                })?;
                let idx = value.try_as_u32().ok_or_else(|| {
                    Error::interp(format!(
                        "load index {value:?} cannot be represented as u32. Fix: use a non-negative scalar index within u32."
                    ))
                })?;
                // Any out-of-range handling is delegated to `oob::load`.
                values.push(oob::load(resolve_buffer(memory, program, buffer)?, idx));
            }
            Frame::AtomicIndex {
                op,
                buffer,
                expected,
                value,
            } => {
                let val = values.pop().ok_or_else(|| {
                    Error::interp("atomic missing index. Fix: internal evaluator error.")
                })?;
                let idx = val.try_as_u32().ok_or_else(|| {
                    Error::interp(format!(
                        "atomic index {val:?} cannot be represented as u32. Fix: use a non-negative scalar index within u32."
                    ))
                })?;
                if let Some(expected_expr) = expected {
                    // CompareExchange: evaluate `expected` next, then `value`.
                    frames.push(Frame::AtomicExpected {
                        op,
                        buffer,
                        index: idx,
                        value,
                        expected_expr,
                    });
                    frames.push(Frame::Expr(expected_expr));
                } else {
                    frames.push(Frame::AtomicValue {
                        op,
                        buffer,
                        expected: None,
                        index: idx,
                    });
                    frames.push(Frame::Expr(value));
                }
            }
            Frame::AtomicExpected {
                op,
                buffer,
                index,
                value,
                expected_expr,
            } => {
                let val = values.pop().ok_or_else(|| {
                    Error::interp(
                        "atomic compare-exchange missing expected value. Fix: internal evaluator error.",
                    )
                })?;
                // NOTE(review): this message formats the `expected` *expression*. The load and
                // atomic-index messages format the evaluated value instead.
                let expected_val = val.try_as_u32().ok_or_else(|| {
                    Error::interp(format!(
                        "atomic expected value {expected_expr:?} cannot be represented as u32. Fix: use a scalar u32-compatible argument."
                    ))
                })?;
                frames.push(Frame::AtomicValue {
                    op,
                    buffer,
                    expected: Some(expected_val),
                    index,
                });
                frames.push(Frame::Expr(value));
            }
            Frame::AtomicValue {
                op,
                buffer,
                expected,
                index,
            } => {
                let val = values.pop().ok_or_else(|| {
                    Error::interp("atomic missing value. Fix: internal evaluator error.")
                })?;
                let value = val.try_as_u32().ok_or_else(|| {
                    Error::interp(
                        "atomic value cannot be represented as u32. Fix: use a scalar u32-compatible argument.",
                    )
                })?;
                // Only ReadWrite storage buffers are accepted by `atomic_buffer_mut`.
                let target = atomic_buffer_mut(memory, program, buffer)?;
                // If `atomic_load` yields nothing (NOTE(review): presumably an out-of-bounds
                // index), the atomic does not store anything and returns 0.
                let Some(old) = oob::atomic_load(target, index) else {
                    values.push(Value::U32(0));
                    continue;
                };
                let (old, new) = atomics::apply(op, old, expected, value)?;
                oob::atomic_store(target, index, new);
                // The atomic expression evaluates to the value returned by `atomics::apply` as
                // `old`.
                values.push(Value::U32(old));
            }
        }
    }

    // The root expression's result is the top of the value stack.
    // NOTE(review): any extra leftover values would be silently ignored here.
    values.pop().ok_or_else(|| {
        Error::interp("expression evaluation produced no value. Fix: internal evaluator error.")
    })
}
361
362pub fn buffer_mut<'a>(
369 memory: &'a mut Memory,
370 program: &Program,
371 name: &str,
372) -> Result<&'a mut Buffer, vyre::Error> {
373 let decl = buffer_decl(program, name)?;
374 match decl.access() {
375 BufferAccess::ReadWrite | BufferAccess::Workgroup => resolve_buffer_mut(memory, decl),
376 BufferAccess::ReadOnly | BufferAccess::Uniform => Err(Error::interp(format!(
377 "store target `{name}` is not writable. Fix: declare it ReadWrite or Workgroup."
378 ))),
379 _ => Err(Error::interp(format!(
380 "store target `{name}` uses an unsupported access mode. Fix: use a supported BufferAccess."
381 ))),
382 }
383}
384
385fn eval_var(name: &str, invocation: &Invocation<'_>) -> Result<Value, vyre::Error> {
386 invocation.local(name).cloned().ok_or_else(|| {
387 Error::interp(format!(
388 "reference to undeclared variable `{name}`. Fix: add a Let before this use."
389 ))
390 })
391}
392
393fn eval_buf_len(buffer: &str, memory: &Memory, program: &Program) -> Result<Value, vyre::Error> {
394 Ok(Value::U32(resolve_buffer(memory, program, buffer)?.len()))
395}
396
397fn eval_invocation_id(axis: u8, invocation: &Invocation<'_>) -> Result<Value, vyre::Error> {
398 axis_value(invocation.ids.global, axis)
399}
400
401fn eval_workgroup_id(axis: u8, invocation: &Invocation<'_>) -> Result<Value, vyre::Error> {
402 axis_value(invocation.ids.workgroup, axis)
403}
404
405fn eval_local_id(axis: u8, invocation: &Invocation<'_>) -> Result<Value, vyre::Error> {
406 axis_value(invocation.ids.local, axis)
407}
408
409fn resolve_buffer<'a>(
410 memory: &'a Memory,
411 program: &Program,
412 name: &str,
413) -> Result<&'a oob::Buffer, vyre::Error> {
414 let decl = buffer_decl(program, name)?;
415 if decl.access() == BufferAccess::Workgroup {
416 memory.workgroup.get(name)
417 } else {
418 memory.storage.get(name)
419 }
420 .ok_or_else(|| {
421 Error::interp(format!(
422 "missing buffer `{name}`. Fix: initialize all declared buffers."
423 ))
424 })
425}
426
427fn resolve_buffer_mut<'a>(
428 memory: &'a mut Memory,
429 decl: &BufferDecl,
430) -> Result<&'a mut oob::Buffer, vyre::Error> {
431 let name = decl.name();
432 if decl.access() == BufferAccess::Workgroup {
433 memory.workgroup.get_mut(name)
434 } else {
435 memory.storage.get_mut(name)
436 }
437 .ok_or_else(|| {
438 Error::interp(format!(
439 "missing buffer `{name}`. Fix: initialize all declared buffers."
440 ))
441 })
442}
443
444fn atomic_buffer_mut<'a>(
445 memory: &'a mut Memory,
446 program: &Program,
447 name: &str,
448) -> Result<&'a mut oob::Buffer, vyre::Error> {
449 let decl = buffer_decl(program, name)?;
450 match decl.access() {
451 BufferAccess::ReadWrite => resolve_buffer_mut(memory, decl),
452 BufferAccess::Workgroup => Err(Error::interp(format!(
453 "atomic target `{name}` is workgroup memory. Fix: atomics only support ReadWrite storage buffers."
454 ))),
455 BufferAccess::ReadOnly | BufferAccess::Uniform => Err(Error::interp(format!(
456 "atomic target `{name}` is not writable. Fix: atomics only support ReadWrite storage buffers."
457 ))),
458 _ => Err(Error::interp(format!(
459 "atomic target `{name}` uses an unsupported access mode. Fix: use a supported BufferAccess."
460 ))),
461 }
462}
463
464fn buffer_decl<'a>(program: &'a Program, name: &str) -> Result<&'a BufferDecl, vyre::Error> {
465 program.buffer(name).ok_or_else(|| {
466 Error::interp(format!(
467 "unknown buffer `{name}`. Fix: declare it in Program::buffers."
468 ))
469 })
470}
471
472fn axis_value(values: [u32; 3], axis: u8) -> Result<Value, vyre::Error> {
473 values
474 .get(axis as usize)
475 .copied()
476 .map(Value::U32)
477 .ok_or_else(|| {
478 Error::interp(format!(
479 "invocation/workgroup ID axis {axis} out of range. Fix: use 0, 1, or 2."
480 ))
481 })
482}
483
// Tests for the frame-stack evaluator: agreement with a recursive reference on random inputs,
// and evaluation of a very deep expression tree.
#[cfg(test)]
mod tests {

    use proptest::prelude::*;
    use vyre::ir::{Expr, Program};

    use super::eval;
    use crate::value::Value;
    use crate::workgroup::{Invocation, InvocationIds, Memory};

    // Memory with no storage and no workgroup buffers.
    // The expressions built in these tests never reference a buffer.
    fn empty_memory() -> Memory {
        Memory {
            storage: Default::default(),
            workgroup: Default::default(),
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(256))]

        // For random operands, `eval` must produce exactly what `eval_recursive_contract`
        // produces on two expressions:
        // * an integer select over add/mul/sub;
        // * an f32 fused multiply-add.
        #[test]
        fn prop_frame_evaluator_matches_recursive_contract(a in any::<u32>(), b in any::<u32>(), c in any::<u32>(), pick_left in any::<bool>()) {
            let program = Program::wrapped(Vec::new(), [1, 1, 1], Vec::new());
            let int_expr = Expr::select(
                Expr::bool(pick_left),
                Expr::add(Expr::u32(a), Expr::mul(Expr::u32(b), Expr::u32(c))),
                Expr::sub(Expr::u32(a), Expr::u32(b)),
            );
            // Masking keeps the f32 operands small; every value built here is exactly
            // representable in f32.
            let float_expr = Expr::fma(
                Expr::f32(((a & 0xffff) as f32) * 0.5),
                Expr::f32(((b & 0xff) as f32) + 1.0),
                Expr::f32(((c & 0xffff) as f32) * 0.25),
            );

            // Each expression is evaluated with a fresh invocation and fresh memory.
            for expr in [&int_expr, &float_expr] {
                let mut invocation = Invocation::new(InvocationIds::ZERO, program.entry());
                let mut memory = empty_memory();

                let frame = eval(expr, &mut invocation, &mut memory, &program)
                    .expect("Fix: frame evaluator must evaluate generated expression");
                let recursive = eval_recursive_contract(expr)
                    .expect("Fix: recursive contract must evaluate generated expression");
                prop_assert_eq!(frame, recursive);
            }
        }
    }

    // Builds `((0 + 1) + 1) + ...` with 4096 left-nested additions and checks that `eval`
    // handles the depth and returns 4096.
    #[test]
    fn deeply_nested_expression_uses_frame_stack_not_host_recursion() {
        let program = Program::wrapped(Vec::new(), [1, 1, 1], Vec::new());
        let mut expr = Expr::u32(0);
        for _ in 0..4096 {
            expr = Expr::add(expr, Expr::u32(1));
        }

        let mut invocation = Invocation::new(InvocationIds::ZERO, program.entry());
        let mut memory = empty_memory();
        let value = eval(&expr, &mut invocation, &mut memory, &program).expect(
            "Fix: frame evaluator must handle deep generated expressions without recursion",
        );

        assert_eq!(value, Value::U32(4096));
    }

    // Plain recursive reference evaluator for the expression subset the property test generates
    // (literals, BinOp, Select, Fma). Any other variant is an error.
    // Its Select evaluates only the chosen branch, whereas the frame evaluator evaluates both.
    // Results still match because the generated branches are side-effect free.
    fn eval_recursive_contract(expr: &Expr) -> Result<Value, vyre::Error> {
        match expr {
            Expr::LitU32(value) => Ok(Value::U32(*value)),
            Expr::LitI32(value) => Ok(Value::I32(*value)),
            Expr::LitF32(value) => Ok(Value::Float(f64::from(
                crate::execution::typed_ops::canonical_f32(*value),
            ))),
            Expr::LitBool(value) => Ok(Value::Bool(*value)),
            Expr::BinOp { op, left, right } => {
                let left = eval_recursive_contract(left)?;
                let right = eval_recursive_contract(right)?;
                crate::execution::typed_ops::eval_binop(*op, left, right)
            }
            Expr::Select {
                cond,
                true_val,
                false_val,
            } => {
                if eval_recursive_contract(cond)?.truthy() {
                    eval_recursive_contract(true_val)
                } else {
                    eval_recursive_contract(false_val)
                }
            }
            Expr::Fma { a, b, c } => {
                // Same canonicalize -> mul_add -> canonicalize pipeline as the frame evaluator.
                let a = eval_recursive_contract(a)?.try_as_f32().ok_or_else(|| {
                    vyre::Error::interp("fma operand `a` is not a float in recursive contract")
                })?;
                let b = eval_recursive_contract(b)?.try_as_f32().ok_or_else(|| {
                    vyre::Error::interp("fma operand `b` is not a float in recursive contract")
                })?;
                let c = eval_recursive_contract(c)?.try_as_f32().ok_or_else(|| {
                    vyre::Error::interp("fma operand `c` is not a float in recursive contract")
                })?;
                let a = crate::execution::typed_ops::canonical_f32(a);
                let b = crate::execution::typed_ops::canonical_f32(b);
                let c = crate::execution::typed_ops::canonical_f32(c);
                Ok(Value::Float(f64::from(
                    crate::execution::typed_ops::canonical_f32(a.mul_add(b, c)),
                )))
            }
            _ => Err(vyre::Error::interp(
                "recursive test contract received an expression outside its generated subset",
            )),
        }
    }
}