1use vyre::ir::{AtomicOp, BinOp, BufferAccess, BufferDecl, DataType, Expr, Program, UnOp};
10
11use smallvec::SmallVec;
12use vyre::Error;
13
14use crate::execution::expr_cast::cast_value;
15use crate::{atomics, oob, value::Value, workgroup::Invocation, workgroup::Memory};
16
17pub use crate::oob::Buffer;
19
20pub fn eval(
27 expr: &Expr,
28 invocation: &mut Invocation<'_>,
29 memory: &mut Memory,
30 program: &Program,
31) -> Result<Value, vyre::Error> {
32 eval_frame_oracle(expr, invocation, memory, program)
33}
34
35pub(crate) fn eval_frame_oracle(
42 expr: &Expr,
43 invocation: &mut Invocation<'_>,
44 memory: &mut Memory,
45 program: &Program,
46) -> Result<Value, vyre::Error> {
47 enum Frame<'a> {
48 Expr(&'a Expr),
49 BinOp(BinOp),
50 UnOp(&'a UnOp),
51 Select,
52 Cast(&'a DataType),
53 Fma,
54 Load {
55 buffer: &'a str,
56 },
57 AtomicIndex {
58 op: AtomicOp,
59 buffer: &'a str,
60 expected: Option<&'a Expr>,
61 value: &'a Expr,
62 },
63 AtomicExpected {
64 op: AtomicOp,
65 buffer: &'a str,
66 index: u32,
67 value: &'a Expr,
68 expected_expr: &'a Expr,
69 },
70 AtomicValue {
71 op: AtomicOp,
72 buffer: &'a str,
73 expected: Option<u32>,
74 index: u32,
75 },
76 }
77
78 let mut frames: SmallVec<[Frame<'_>; 32]> = SmallVec::new();
79 frames.push(Frame::Expr(expr));
80 let mut values: SmallVec<[Value; 32]> = SmallVec::new();
81
82 while let Some(frame) = frames.pop() {
83 match frame {
84 Frame::Expr(expr) => match expr {
85 Expr::LitU32(value) => values.push(Value::U32(*value)),
86 Expr::LitI32(value) => values.push(Value::I32(*value)),
87 Expr::LitF32(value) => {
88 values.push(Value::Float(f64::from(
89 crate::execution::typed_ops::canonical_f32(*value),
90 )));
91 }
92 Expr::LitBool(value) => values.push(Value::Bool(*value)),
93 Expr::Var(name) => values.push(eval_var(name, invocation)?),
94 Expr::BufLen { buffer } => values.push(eval_buf_len(buffer, memory, program)?),
95 Expr::InvocationId { axis } => values.push(eval_invocation_id(*axis, invocation)?),
96 Expr::WorkgroupId { axis } => values.push(eval_workgroup_id(*axis, invocation)?),
97 Expr::LocalId { axis } => values.push(eval_local_id(*axis, invocation)?),
98 Expr::Load { buffer, index } => {
99 frames.push(Frame::Load { buffer });
100 frames.push(Frame::Expr(index));
101 }
102 Expr::BinOp { op, left, right } => {
103 frames.push(Frame::BinOp(*op));
104 frames.push(Frame::Expr(right));
105 frames.push(Frame::Expr(left));
106 }
107 Expr::UnOp { op, operand } => {
108 frames.push(Frame::UnOp(op));
109 frames.push(Frame::Expr(operand));
110 }
111 Expr::Select {
112 cond,
113 true_val,
114 false_val,
115 } => {
116 frames.push(Frame::Select);
117 frames.push(Frame::Expr(false_val));
118 frames.push(Frame::Expr(true_val));
119 frames.push(Frame::Expr(cond));
120 }
121 Expr::Cast { target, value } => {
122 frames.push(Frame::Cast(target));
123 frames.push(Frame::Expr(value));
124 }
125 Expr::Fma { a, b, c } => {
126 frames.push(Frame::Fma);
127 frames.push(Frame::Expr(c));
128 frames.push(Frame::Expr(b));
129 frames.push(Frame::Expr(a));
130 }
131 Expr::Atomic {
132 op,
133 buffer,
134 index,
135 expected,
136 value,
137 ordering: _,
138 } => {
139 match (*op, expected.as_deref()) {
140 (AtomicOp::CompareExchange, None) => {
141 return Err(Error::interp(
142 "compare-exchange atomic is missing expected value. Fix: set Expr::Atomic.expected for AtomicOp::CompareExchange.",
143 ));
144 }
145 (AtomicOp::CompareExchange, Some(_)) => {}
146 (_, Some(_)) => {
147 return Err(Error::interp(
148 "non-compare-exchange atomic includes an expected value. Fix: use Expr::Atomic.expected only with AtomicOp::CompareExchange.",
149 ));
150 }
151 (_, None) => {}
152 }
153 frames.push(Frame::AtomicIndex {
154 op: *op,
155 buffer,
156 expected: expected.as_deref(),
157 value,
158 });
159 frames.push(Frame::Expr(index));
160 }
161 Expr::Call { op_id, args } => {
162 let val = crate::execution::call::eval_call(
163 expr as *const Expr,
164 op_id,
165 args,
166 invocation,
167 memory,
168 program,
169 )?;
170 values.push(val);
171 }
172 Expr::Opaque(extension) => {
173 return Err(Error::interp(format!(
174 "reference interpreter does not support opaque expression extension `{}`/`{}`. Fix: provide a reference evaluator for this ExprNode or lower it to core Expr variants before evaluation.",
175 extension.extension_kind(),
176 extension.debug_identity()
177 )));
178 }
179 _ => {
180 return Err(Error::interp(
181 "reference interpreter encountered an unknown expression variant. Fix: add explicit reference semantics for the new ExprNode before dispatch.",
182 ));
183 }
184 },
185 Frame::BinOp(op) => {
186 let right = values.pop().ok_or_else(|| {
187 Error::interp("binary op missing right operand. Fix: internal evaluator error.")
188 })?;
189 let left = values.pop().ok_or_else(|| {
190 Error::interp("binary op missing left operand. Fix: internal evaluator error.")
191 })?;
192 values.push(super::typed_ops::eval_binop(op, left, right)?);
193 }
194 Frame::UnOp(op) => {
195 let operand = values.pop().ok_or_else(|| {
196 Error::interp("unary op missing operand. Fix: internal evaluator error.")
197 })?;
198 values.push(super::typed_ops::eval_unop(op, operand)?);
199 }
200 Frame::Select => {
201 let false_val = values.pop().ok_or_else(|| {
202 Error::interp("select missing false branch. Fix: internal evaluator error.")
203 })?;
204 let true_val = values.pop().ok_or_else(|| {
205 Error::interp("select missing true branch. Fix: internal evaluator error.")
206 })?;
207 let cond = values
208 .pop()
209 .ok_or_else(|| {
210 Error::interp("select missing condition. Fix: internal evaluator error.")
211 })?
212 .truthy();
213 values.push(if cond { true_val } else { false_val });
214 }
215 Frame::Cast(target) => {
216 let value = values.pop().ok_or_else(|| {
217 Error::interp("cast missing value. Fix: internal evaluator error.")
218 })?;
219 values.push(cast_value(target, &value)?);
220 }
221 Frame::Fma => {
222 let c = values
223 .pop()
224 .ok_or_else(|| {
225 Error::interp("fma missing operand c. Fix: internal evaluator error.")
226 })?
227 .try_as_f32()
228 .ok_or_else(|| {
229 Error::interp(
230 "fma operand `c` is not a float. Fix: cast to f32 before fma.",
231 )
232 })?;
233 let b = values
234 .pop()
235 .ok_or_else(|| {
236 Error::interp("fma missing operand b. Fix: internal evaluator error.")
237 })?
238 .try_as_f32()
239 .ok_or_else(|| {
240 Error::interp(
241 "fma operand `b` is not a float. Fix: cast to f32 before fma.",
242 )
243 })?;
244 let a = values
245 .pop()
246 .ok_or_else(|| {
247 Error::interp("fma missing operand a. Fix: internal evaluator error.")
248 })?
249 .try_as_f32()
250 .ok_or_else(|| {
251 Error::interp(
252 "fma operand `a` is not a float. Fix: cast to f32 before fma.",
253 )
254 })?;
255 let a = crate::execution::typed_ops::canonical_f32(a);
256 let b = crate::execution::typed_ops::canonical_f32(b);
257 let c = crate::execution::typed_ops::canonical_f32(c);
258 values.push(Value::Float(f64::from(
259 crate::execution::typed_ops::canonical_f32(a.mul_add(b, c)),
260 )));
261 }
262 Frame::Load { buffer } => {
263 let value = values.pop().ok_or_else(|| {
264 Error::interp("load missing index. Fix: internal evaluator error.")
265 })?;
266 let idx = value.try_as_u32().ok_or_else(|| {
267 Error::interp(format!(
268 "load index {value:?} cannot be represented as u32. Fix: use a non-negative scalar index within u32."
269 ))
270 })?;
271 values.push(oob::load(resolve_buffer(memory, program, buffer)?, idx));
272 }
273 Frame::AtomicIndex {
274 op,
275 buffer,
276 expected,
277 value,
278 } => {
279 let val = values.pop().ok_or_else(|| {
280 Error::interp("atomic missing index. Fix: internal evaluator error.")
281 })?;
282 let idx = val.try_as_u32().ok_or_else(|| {
283 Error::interp(format!(
284 "atomic index {val:?} cannot be represented as u32. Fix: use a non-negative scalar index within u32."
285 ))
286 })?;
287 if let Some(expected_expr) = expected {
288 frames.push(Frame::AtomicExpected {
289 op,
290 buffer,
291 index: idx,
292 value,
293 expected_expr,
294 });
295 frames.push(Frame::Expr(expected_expr));
296 } else {
297 frames.push(Frame::AtomicValue {
298 op,
299 buffer,
300 expected: None,
301 index: idx,
302 });
303 frames.push(Frame::Expr(value));
304 }
305 }
306 Frame::AtomicExpected {
307 op,
308 buffer,
309 index,
310 value,
311 expected_expr,
312 } => {
313 let val = values.pop().ok_or_else(|| {
314 Error::interp(
315 "atomic compare-exchange missing expected value. Fix: internal evaluator error.",
316 )
317 })?;
318 let expected_val = val.try_as_u32().ok_or_else(|| {
319 Error::interp(format!(
320 "atomic expected value {expected_expr:?} cannot be represented as u32. Fix: use a scalar u32-compatible argument."
321 ))
322 })?;
323 frames.push(Frame::AtomicValue {
324 op,
325 buffer,
326 expected: Some(expected_val),
327 index,
328 });
329 frames.push(Frame::Expr(value));
330 }
331 Frame::AtomicValue {
332 op,
333 buffer,
334 expected,
335 index,
336 } => {
337 let val = values.pop().ok_or_else(|| {
338 Error::interp("atomic missing value. Fix: internal evaluator error.")
339 })?;
340 let value = val.try_as_u32().ok_or_else(|| {
341 Error::interp(
342 "atomic value cannot be represented as u32. Fix: use a scalar u32-compatible argument.",
343 )
344 })?;
345 let target = atomic_buffer_mut(memory, program, buffer)?;
346 let Some(old) = oob::atomic_load(target, index) else {
347 values.push(Value::U32(0));
348 continue;
349 };
350 let (old, new) = atomics::apply(op, old, expected, value)?;
351 oob::atomic_store(target, index, new);
352 values.push(Value::U32(old));
353 }
354 }
355 }
356
357 values.pop().ok_or_else(|| {
358 Error::interp("expression evaluation produced no value. Fix: internal evaluator error.")
359 })
360}
361
362pub fn buffer_mut<'a>(
369 memory: &'a mut Memory,
370 program: &Program,
371 name: &str,
372) -> Result<&'a mut Buffer, vyre::Error> {
373 let decl = buffer_decl(program, name)?;
374 match decl.access() {
375 BufferAccess::ReadWrite | BufferAccess::WriteOnly | BufferAccess::Workgroup => {
376 resolve_buffer_mut(memory, decl)
377 }
378 BufferAccess::ReadOnly | BufferAccess::Uniform => Err(Error::interp(format!(
379 "store target `{name}` is not writable. Fix: declare it ReadWrite, WriteOnly, or Workgroup."
380 ))),
381 _ => Err(Error::interp(format!(
382 "store target `{name}` uses an unsupported access mode. Fix: use a supported BufferAccess."
383 ))),
384 }
385}
386
387fn eval_var(name: &str, invocation: &Invocation<'_>) -> Result<Value, vyre::Error> {
388 invocation.local(name).cloned().ok_or_else(|| {
389 Error::interp(format!(
390 "reference to undeclared variable `{name}`. Fix: add a Let before this use."
391 ))
392 })
393}
394
395fn eval_buf_len(buffer: &str, memory: &Memory, program: &Program) -> Result<Value, vyre::Error> {
396 Ok(Value::U32(resolve_buffer(memory, program, buffer)?.len()))
397}
398
399fn eval_invocation_id(axis: u8, invocation: &Invocation<'_>) -> Result<Value, vyre::Error> {
400 axis_value(invocation.ids.global, axis)
401}
402
403fn eval_workgroup_id(axis: u8, invocation: &Invocation<'_>) -> Result<Value, vyre::Error> {
404 axis_value(invocation.ids.workgroup, axis)
405}
406
407fn eval_local_id(axis: u8, invocation: &Invocation<'_>) -> Result<Value, vyre::Error> {
408 axis_value(invocation.ids.local, axis)
409}
410
411fn resolve_buffer<'a>(
412 memory: &'a Memory,
413 program: &Program,
414 name: &str,
415) -> Result<&'a oob::Buffer, vyre::Error> {
416 let decl = buffer_decl(program, name)?;
417 if decl.access() == BufferAccess::Workgroup {
418 memory.workgroup.get(name)
419 } else {
420 memory.storage.get(name)
421 }
422 .ok_or_else(|| {
423 Error::interp(format!(
424 "missing buffer `{name}`. Fix: initialize all declared buffers."
425 ))
426 })
427}
428
429fn resolve_buffer_mut<'a>(
430 memory: &'a mut Memory,
431 decl: &BufferDecl,
432) -> Result<&'a mut oob::Buffer, vyre::Error> {
433 let name = decl.name();
434 if decl.access() == BufferAccess::Workgroup {
435 memory.workgroup.get_mut(name)
436 } else {
437 memory.storage.get_mut(name)
438 }
439 .ok_or_else(|| {
440 Error::interp(format!(
441 "missing buffer `{name}`. Fix: initialize all declared buffers."
442 ))
443 })
444}
445
446fn atomic_buffer_mut<'a>(
447 memory: &'a mut Memory,
448 program: &Program,
449 name: &str,
450) -> Result<&'a mut oob::Buffer, vyre::Error> {
451 let decl = buffer_decl(program, name)?;
452 match decl.access() {
453 BufferAccess::ReadWrite => resolve_buffer_mut(memory, decl),
454 BufferAccess::Workgroup => Err(Error::interp(format!(
455 "atomic target `{name}` is workgroup memory. Fix: atomics only support ReadWrite storage buffers."
456 ))),
457 BufferAccess::ReadOnly | BufferAccess::Uniform => Err(Error::interp(format!(
458 "atomic target `{name}` is not writable. Fix: atomics only support ReadWrite storage buffers."
459 ))),
460 _ => Err(Error::interp(format!(
461 "atomic target `{name}` uses an unsupported access mode. Fix: use a supported BufferAccess."
462 ))),
463 }
464}
465
466fn buffer_decl<'a>(program: &'a Program, name: &str) -> Result<&'a BufferDecl, vyre::Error> {
467 program.buffer(name).ok_or_else(|| {
468 Error::interp(format!(
469 "unknown buffer `{name}`. Fix: declare it in Program::buffers."
470 ))
471 })
472}
473
474fn axis_value(values: [u32; 3], axis: u8) -> Result<Value, vyre::Error> {
475 values
476 .get(axis as usize)
477 .copied()
478 .map(Value::U32)
479 .ok_or_else(|| {
480 Error::interp(format!(
481 "invocation/workgroup ID axis {axis} out of range. Fix: use 0, 1, or 2."
482 ))
483 })
484}
485
#[cfg(test)]
mod tests {
    // Tests for the frame-based evaluator: a property test comparing it with
    // a small recursive reference evaluator, and a deep-nesting regression test.

    use proptest::prelude::*;
    use vyre::ir::{Expr, Program};

    use super::eval;
    use crate::value::Value;
    use crate::workgroup::{Invocation, InvocationIds, Memory};

    /// Builds a `Memory` with default (empty) storage and workgroup maps.
    fn empty_memory() -> Memory {
        Memory {
            storage: Default::default(),
            workgroup: Default::default(),
        }
    }

    /// Runs `eval` on `expr` with a fresh zero-ID invocation and empty memory.
    fn eval_fresh(expr: &Expr, program: &Program) -> Result<Value, vyre::Error> {
        let mut invocation = Invocation::new(InvocationIds::ZERO, program.entry());
        let mut memory = empty_memory();
        eval(expr, &mut invocation, &mut memory, program)
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(256))]

        // The frame evaluator must agree with the recursive contract on a
        // generated integer select expression and a generated float fma.
        #[test]
        fn prop_frame_evaluator_matches_recursive_contract(a in any::<u32>(), b in any::<u32>(), c in any::<u32>(), pick_left in any::<bool>()) {
            let program = Program::wrapped(Vec::new(), [1, 1, 1], Vec::new());

            let when_true = Expr::add(Expr::u32(a), Expr::mul(Expr::u32(b), Expr::u32(c)));
            let when_false = Expr::sub(Expr::u32(a), Expr::u32(b));
            let int_expr = Expr::select(Expr::bool(pick_left), when_true, when_false);

            let float_expr = Expr::fma(
                Expr::f32(((a & 0xffff) as f32) * 0.5),
                Expr::f32(((b & 0xff) as f32) + 1.0),
                Expr::f32(((c & 0xffff) as f32) * 0.25),
            );

            for expr in [&int_expr, &float_expr] {
                let frame = eval_fresh(expr, &program)
                    .expect("Fix: frame evaluator must evaluate generated expression");
                let recursive = eval_recursive_contract(expr)
                    .expect("Fix: recursive contract must evaluate generated expression");
                prop_assert_eq!(frame, recursive);
            }
        }
    }

    /// A 4096-level left-nested addition chain must evaluate to 4096.
    #[test]
    fn deeply_nested_expression_uses_frame_stack_not_host_recursion() {
        let program = Program::wrapped(Vec::new(), [1, 1, 1], Vec::new());
        let expr = (0..4096).fold(Expr::u32(0), |nested, _| Expr::add(nested, Expr::u32(1)));

        let value = eval_fresh(&expr, &program).expect(
            "Fix: frame evaluator must handle deep generated expressions without recursion",
        );

        assert_eq!(value, Value::U32(4096));
    }

    /// Straightforward recursive evaluator for the expression subset the
    /// property test generates (literals, `BinOp`, `Select`, `Fma`).
    ///
    /// Unlike the frame evaluator, `Select` here evaluates only the chosen branch.
    fn eval_recursive_contract(expr: &Expr) -> Result<Value, vyre::Error> {
        use crate::execution::typed_ops::{canonical_f32, eval_binop};

        match expr {
            Expr::LitU32(value) => Ok(Value::U32(*value)),
            Expr::LitI32(value) => Ok(Value::I32(*value)),
            Expr::LitF32(value) => Ok(Value::Float(f64::from(canonical_f32(*value)))),
            Expr::LitBool(value) => Ok(Value::Bool(*value)),
            Expr::BinOp { op, left, right } => {
                let lhs = eval_recursive_contract(left)?;
                let rhs = eval_recursive_contract(right)?;
                eval_binop(*op, lhs, rhs)
            }
            Expr::Select {
                cond,
                true_val,
                false_val,
            } => {
                let chosen = if eval_recursive_contract(cond)?.truthy() {
                    true_val
                } else {
                    false_val
                };
                eval_recursive_contract(chosen)
            }
            Expr::Fma { a, b, c } => {
                let Some(a) = eval_recursive_contract(a)?.try_as_f32() else {
                    return Err(vyre::Error::interp(
                        "fma operand `a` is not a float in recursive contract",
                    ));
                };
                let Some(b) = eval_recursive_contract(b)?.try_as_f32() else {
                    return Err(vyre::Error::interp(
                        "fma operand `b` is not a float in recursive contract",
                    ));
                };
                let Some(c) = eval_recursive_contract(c)?.try_as_f32() else {
                    return Err(vyre::Error::interp(
                        "fma operand `c` is not a float in recursive contract",
                    ));
                };
                let (a, b, c) = (canonical_f32(a), canonical_f32(b), canonical_f32(c));
                let fused = canonical_f32(a.mul_add(b, c));
                Ok(Value::Float(f64::from(fused)))
            }
            _ => Err(vyre::Error::interp(
                "recursive test contract received an expression outside its generated subset",
            )),
        }
    }
}