Skip to main content

vyre_reference/execution/
expr.rs

1//! Expression evaluator that gives the parity engine a pure-Rust ground truth
2//! for every `Expr` variant.
3//!
4//! If a backend lowers `Expr::BinOp`, `Expr::Load`, or `Expr::Atomic` differently
5//! than this evaluator, the conform gate reports the exact divergence. This module
6//! exists so IR semantics are defined by Rust code, not by whatever a backend
7//! happens to emit.
8
9use vyre::ir::{AtomicOp, BinOp, BufferAccess, BufferDecl, DataType, Expr, Program, UnOp};
10
11use smallvec::SmallVec;
12use vyre::Error;
13
14use crate::execution::expr_cast::cast_value;
15use crate::{atomics, oob, value::Value, workgroup::Invocation, workgroup::Memory};
16
17/// Re-export the OOB-guarded buffer type used by storage operations.
18pub use crate::oob::Buffer;
19
20/// Evaluate an expression through the single-pass frame evaluator.
21///
22/// # Errors
23///
24/// Returns [`Error::Interp`] when expression lowering or flat execution
25/// fails. The recursive evaluator is retained only as a test oracle.
26pub fn eval(
27    expr: &Expr,
28    invocation: &mut Invocation<'_>,
29    memory: &mut Memory,
30    program: &Program,
31) -> Result<Value, vyre::Error> {
32    eval_frame_oracle(expr, invocation, memory, program)
33}
34
35/// Evaluate an expression for one invocation.
36///
37/// # Errors
38///
39/// Returns [`Error::Interp`] on operand type errors, malformed atomic or call
40/// expressions, unsupported variants, or float operands.
41pub(crate) fn eval_frame_oracle(
42    expr: &Expr,
43    invocation: &mut Invocation<'_>,
44    memory: &mut Memory,
45    program: &Program,
46) -> Result<Value, vyre::Error> {
47    enum Frame<'a> {
48        Expr(&'a Expr),
49        BinOp(BinOp),
50        UnOp(&'a UnOp),
51        Select,
52        Cast(&'a DataType),
53        Fma,
54        Load {
55            buffer: &'a str,
56        },
57        AtomicIndex {
58            op: AtomicOp,
59            buffer: &'a str,
60            expected: Option<&'a Expr>,
61            value: &'a Expr,
62        },
63        AtomicExpected {
64            op: AtomicOp,
65            buffer: &'a str,
66            index: u32,
67            value: &'a Expr,
68            expected_expr: &'a Expr,
69        },
70        AtomicValue {
71            op: AtomicOp,
72            buffer: &'a str,
73            expected: Option<u32>,
74            index: u32,
75        },
76    }
77
78    let mut frames: SmallVec<[Frame<'_>; 32]> = SmallVec::new();
79    frames.push(Frame::Expr(expr));
80    let mut values: SmallVec<[Value; 32]> = SmallVec::new();
81
82    while let Some(frame) = frames.pop() {
83        match frame {
84            Frame::Expr(expr) => match expr {
85                Expr::LitU32(value) => values.push(Value::U32(*value)),
86                Expr::LitI32(value) => values.push(Value::I32(*value)),
87                Expr::LitF32(value) => {
88                    values.push(Value::Float(f64::from(
89                        crate::execution::typed_ops::canonical_f32(*value),
90                    )));
91                }
92                Expr::LitBool(value) => values.push(Value::Bool(*value)),
93                Expr::Var(name) => values.push(eval_var(name, invocation)?),
94                Expr::BufLen { buffer } => values.push(eval_buf_len(buffer, memory, program)?),
95                Expr::InvocationId { axis } => values.push(eval_invocation_id(*axis, invocation)?),
96                Expr::WorkgroupId { axis } => values.push(eval_workgroup_id(*axis, invocation)?),
97                Expr::LocalId { axis } => values.push(eval_local_id(*axis, invocation)?),
98                Expr::Load { buffer, index } => {
99                    frames.push(Frame::Load { buffer });
100                    frames.push(Frame::Expr(index));
101                }
102                Expr::BinOp { op, left, right } => {
103                    frames.push(Frame::BinOp(*op));
104                    frames.push(Frame::Expr(right));
105                    frames.push(Frame::Expr(left));
106                }
107                Expr::UnOp { op, operand } => {
108                    frames.push(Frame::UnOp(op));
109                    frames.push(Frame::Expr(operand));
110                }
111                Expr::Select {
112                    cond,
113                    true_val,
114                    false_val,
115                } => {
116                    frames.push(Frame::Select);
117                    frames.push(Frame::Expr(false_val));
118                    frames.push(Frame::Expr(true_val));
119                    frames.push(Frame::Expr(cond));
120                }
121                Expr::Cast { target, value } => {
122                    frames.push(Frame::Cast(target));
123                    frames.push(Frame::Expr(value));
124                }
125                Expr::Fma { a, b, c } => {
126                    frames.push(Frame::Fma);
127                    frames.push(Frame::Expr(c));
128                    frames.push(Frame::Expr(b));
129                    frames.push(Frame::Expr(a));
130                }
131                Expr::Atomic {
132                    op,
133                    buffer,
134                    index,
135                    expected,
136                    value,
137                    ordering: _,
138                } => {
139                    match (*op, expected.as_deref()) {
140                        (AtomicOp::CompareExchange, None) => {
141                            return Err(Error::interp(
142                                "compare-exchange atomic is missing expected value. Fix: set Expr::Atomic.expected for AtomicOp::CompareExchange.",
143                            ));
144                        }
145                        (AtomicOp::CompareExchange, Some(_)) => {}
146                        (_, Some(_)) => {
147                            return Err(Error::interp(
148                                "non-compare-exchange atomic includes an expected value. Fix: use Expr::Atomic.expected only with AtomicOp::CompareExchange.",
149                            ));
150                        }
151                        (_, None) => {}
152                    }
153                    frames.push(Frame::AtomicIndex {
154                        op: *op,
155                        buffer,
156                        expected: expected.as_deref(),
157                        value,
158                    });
159                    frames.push(Frame::Expr(index));
160                }
161                Expr::Call { op_id, args } => {
162                    let val = crate::execution::call::eval_call(
163                        expr as *const Expr,
164                        op_id,
165                        args,
166                        invocation,
167                        memory,
168                        program,
169                    )?;
170                    values.push(val);
171                }
172                Expr::Opaque(extension) => {
173                    return Err(Error::interp(format!(
174                        "reference interpreter does not support opaque expression extension `{}`/`{}`. Fix: provide a reference evaluator for this ExprNode or lower it to core Expr variants before evaluation.",
175                        extension.extension_kind(),
176                        extension.debug_identity()
177                    )));
178                }
179                _ => {
180                    return Err(Error::interp(
181                        "reference interpreter encountered an unknown expression variant. Fix: add explicit reference semantics for the new ExprNode before dispatch.",
182                    ));
183                }
184            },
185            Frame::BinOp(op) => {
186                let right = values.pop().ok_or_else(|| {
187                    Error::interp("binary op missing right operand. Fix: internal evaluator error.")
188                })?;
189                let left = values.pop().ok_or_else(|| {
190                    Error::interp("binary op missing left operand. Fix: internal evaluator error.")
191                })?;
192                values.push(super::typed_ops::eval_binop(op, left, right)?);
193            }
194            Frame::UnOp(op) => {
195                let operand = values.pop().ok_or_else(|| {
196                    Error::interp("unary op missing operand. Fix: internal evaluator error.")
197                })?;
198                values.push(super::typed_ops::eval_unop(op, operand)?);
199            }
200            Frame::Select => {
201                let false_val = values.pop().ok_or_else(|| {
202                    Error::interp("select missing false branch. Fix: internal evaluator error.")
203                })?;
204                let true_val = values.pop().ok_or_else(|| {
205                    Error::interp("select missing true branch. Fix: internal evaluator error.")
206                })?;
207                let cond = values
208                    .pop()
209                    .ok_or_else(|| {
210                        Error::interp("select missing condition. Fix: internal evaluator error.")
211                    })?
212                    .truthy();
213                values.push(if cond { true_val } else { false_val });
214            }
215            Frame::Cast(target) => {
216                let value = values.pop().ok_or_else(|| {
217                    Error::interp("cast missing value. Fix: internal evaluator error.")
218                })?;
219                values.push(cast_value(target, &value)?);
220            }
221            Frame::Fma => {
222                let c = values
223                    .pop()
224                    .ok_or_else(|| {
225                        Error::interp("fma missing operand c. Fix: internal evaluator error.")
226                    })?
227                    .try_as_f32()
228                    .ok_or_else(|| {
229                        Error::interp(
230                            "fma operand `c` is not a float. Fix: cast to f32 before fma.",
231                        )
232                    })?;
233                let b = values
234                    .pop()
235                    .ok_or_else(|| {
236                        Error::interp("fma missing operand b. Fix: internal evaluator error.")
237                    })?
238                    .try_as_f32()
239                    .ok_or_else(|| {
240                        Error::interp(
241                            "fma operand `b` is not a float. Fix: cast to f32 before fma.",
242                        )
243                    })?;
244                let a = values
245                    .pop()
246                    .ok_or_else(|| {
247                        Error::interp("fma missing operand a. Fix: internal evaluator error.")
248                    })?
249                    .try_as_f32()
250                    .ok_or_else(|| {
251                        Error::interp(
252                            "fma operand `a` is not a float. Fix: cast to f32 before fma.",
253                        )
254                    })?;
255                let a = crate::execution::typed_ops::canonical_f32(a);
256                let b = crate::execution::typed_ops::canonical_f32(b);
257                let c = crate::execution::typed_ops::canonical_f32(c);
258                values.push(Value::Float(f64::from(
259                    crate::execution::typed_ops::canonical_f32(a.mul_add(b, c)),
260                )));
261            }
262            Frame::Load { buffer } => {
263                let value = values.pop().ok_or_else(|| {
264                    Error::interp("load missing index. Fix: internal evaluator error.")
265                })?;
266                let idx = value.try_as_u32().ok_or_else(|| {
267                    Error::interp(format!(
268                        "load index {value:?} cannot be represented as u32. Fix: use a non-negative scalar index within u32."
269                    ))
270                })?;
271                values.push(oob::load(resolve_buffer(memory, program, buffer)?, idx));
272            }
273            Frame::AtomicIndex {
274                op,
275                buffer,
276                expected,
277                value,
278            } => {
279                let val = values.pop().ok_or_else(|| {
280                    Error::interp("atomic missing index. Fix: internal evaluator error.")
281                })?;
282                let idx = val.try_as_u32().ok_or_else(|| {
283                    Error::interp(format!(
284                        "atomic index {val:?} cannot be represented as u32. Fix: use a non-negative scalar index within u32."
285                    ))
286                })?;
287                if let Some(expected_expr) = expected {
288                    frames.push(Frame::AtomicExpected {
289                        op,
290                        buffer,
291                        index: idx,
292                        value,
293                        expected_expr,
294                    });
295                    frames.push(Frame::Expr(expected_expr));
296                } else {
297                    frames.push(Frame::AtomicValue {
298                        op,
299                        buffer,
300                        expected: None,
301                        index: idx,
302                    });
303                    frames.push(Frame::Expr(value));
304                }
305            }
306            Frame::AtomicExpected {
307                op,
308                buffer,
309                index,
310                value,
311                expected_expr,
312            } => {
313                let val = values.pop().ok_or_else(|| {
314                    Error::interp(
315                        "atomic compare-exchange missing expected value. Fix: internal evaluator error.",
316                    )
317                })?;
318                let expected_val = val.try_as_u32().ok_or_else(|| {
319                    Error::interp(format!(
320                        "atomic expected value {expected_expr:?} cannot be represented as u32. Fix: use a scalar u32-compatible argument."
321                    ))
322                })?;
323                frames.push(Frame::AtomicValue {
324                    op,
325                    buffer,
326                    expected: Some(expected_val),
327                    index,
328                });
329                frames.push(Frame::Expr(value));
330            }
331            Frame::AtomicValue {
332                op,
333                buffer,
334                expected,
335                index,
336            } => {
337                let val = values.pop().ok_or_else(|| {
338                    Error::interp("atomic missing value. Fix: internal evaluator error.")
339                })?;
340                let value = val.try_as_u32().ok_or_else(|| {
341                    Error::interp(
342                        "atomic value cannot be represented as u32. Fix: use a scalar u32-compatible argument.",
343                    )
344                })?;
345                let target = atomic_buffer_mut(memory, program, buffer)?;
346                let Some(old) = oob::atomic_load(target, index) else {
347                    values.push(Value::U32(0));
348                    continue;
349                };
350                let (old, new) = atomics::apply(op, old, expected, value)?;
351                oob::atomic_store(target, index, new);
352                values.push(Value::U32(old));
353            }
354        }
355    }
356
357    values.pop().ok_or_else(|| {
358        Error::interp("expression evaluation produced no value. Fix: internal evaluator error.")
359    })
360}
361
362/// Return a mutable buffer only when the program declares it writable.
363///
364/// # Errors
365///
366/// Returns [`Error::Interp`] if the buffer is read-only, uniform,
367/// or does not exist in the program declaration.
368pub fn buffer_mut<'a>(
369    memory: &'a mut Memory,
370    program: &Program,
371    name: &str,
372) -> Result<&'a mut Buffer, vyre::Error> {
373    let decl = buffer_decl(program, name)?;
374    match decl.access() {
375        BufferAccess::ReadWrite | BufferAccess::Workgroup => resolve_buffer_mut(memory, decl),
376        BufferAccess::ReadOnly | BufferAccess::Uniform => Err(Error::interp(format!(
377            "store target `{name}` is not writable. Fix: declare it ReadWrite or Workgroup."
378        ))),
379        _ => Err(Error::interp(format!(
380            "store target `{name}` uses an unsupported access mode. Fix: use a supported BufferAccess."
381        ))),
382    }
383}
384
385fn eval_var(name: &str, invocation: &Invocation<'_>) -> Result<Value, vyre::Error> {
386    invocation.local(name).cloned().ok_or_else(|| {
387        Error::interp(format!(
388            "reference to undeclared variable `{name}`. Fix: add a Let before this use."
389        ))
390    })
391}
392
393fn eval_buf_len(buffer: &str, memory: &Memory, program: &Program) -> Result<Value, vyre::Error> {
394    Ok(Value::U32(resolve_buffer(memory, program, buffer)?.len()))
395}
396
397fn eval_invocation_id(axis: u8, invocation: &Invocation<'_>) -> Result<Value, vyre::Error> {
398    axis_value(invocation.ids.global, axis)
399}
400
401fn eval_workgroup_id(axis: u8, invocation: &Invocation<'_>) -> Result<Value, vyre::Error> {
402    axis_value(invocation.ids.workgroup, axis)
403}
404
405fn eval_local_id(axis: u8, invocation: &Invocation<'_>) -> Result<Value, vyre::Error> {
406    axis_value(invocation.ids.local, axis)
407}
408
409fn resolve_buffer<'a>(
410    memory: &'a Memory,
411    program: &Program,
412    name: &str,
413) -> Result<&'a oob::Buffer, vyre::Error> {
414    let decl = buffer_decl(program, name)?;
415    if decl.access() == BufferAccess::Workgroup {
416        memory.workgroup.get(name)
417    } else {
418        memory.storage.get(name)
419    }
420    .ok_or_else(|| {
421        Error::interp(format!(
422            "missing buffer `{name}`. Fix: initialize all declared buffers."
423        ))
424    })
425}
426
427fn resolve_buffer_mut<'a>(
428    memory: &'a mut Memory,
429    decl: &BufferDecl,
430) -> Result<&'a mut oob::Buffer, vyre::Error> {
431    let name = decl.name();
432    if decl.access() == BufferAccess::Workgroup {
433        memory.workgroup.get_mut(name)
434    } else {
435        memory.storage.get_mut(name)
436    }
437    .ok_or_else(|| {
438        Error::interp(format!(
439            "missing buffer `{name}`. Fix: initialize all declared buffers."
440        ))
441    })
442}
443
444fn atomic_buffer_mut<'a>(
445    memory: &'a mut Memory,
446    program: &Program,
447    name: &str,
448) -> Result<&'a mut oob::Buffer, vyre::Error> {
449    let decl = buffer_decl(program, name)?;
450    match decl.access() {
451        BufferAccess::ReadWrite => resolve_buffer_mut(memory, decl),
452        BufferAccess::Workgroup => Err(Error::interp(format!(
453            "atomic target `{name}` is workgroup memory. Fix: atomics only support ReadWrite storage buffers."
454        ))),
455        BufferAccess::ReadOnly | BufferAccess::Uniform => Err(Error::interp(format!(
456            "atomic target `{name}` is not writable. Fix: atomics only support ReadWrite storage buffers."
457        ))),
458        _ => Err(Error::interp(format!(
459            "atomic target `{name}` uses an unsupported access mode. Fix: use a supported BufferAccess."
460        ))),
461    }
462}
463
464fn buffer_decl<'a>(program: &'a Program, name: &str) -> Result<&'a BufferDecl, vyre::Error> {
465    program.buffer(name).ok_or_else(|| {
466        Error::interp(format!(
467            "unknown buffer `{name}`. Fix: declare it in Program::buffers."
468        ))
469    })
470}
471
472fn axis_value(values: [u32; 3], axis: u8) -> Result<Value, vyre::Error> {
473    values
474        .get(axis as usize)
475        .copied()
476        .map(Value::U32)
477        .ok_or_else(|| {
478            Error::interp(format!(
479                "invocation/workgroup ID axis {axis} out of range. Fix: use 0, 1, or 2."
480            ))
481        })
482}
483
#[cfg(test)]
mod tests {
    use proptest::prelude::*;
    use vyre::ir::{Expr, Program};

    use super::eval;
    use crate::execution::typed_ops::{canonical_f32, eval_binop};
    use crate::value::Value;
    use crate::workgroup::{Invocation, InvocationIds, Memory};

    /// Memory with no storage or workgroup buffers.
    fn empty_memory() -> Memory {
        Memory {
            storage: Default::default(),
            workgroup: Default::default(),
        }
    }

    /// Program with no buffers and an empty body, used only to host invocations.
    fn trivial_program() -> Program {
        Program::wrapped(Vec::new(), [1, 1, 1], Vec::new())
    }

    /// Run `expr` through the public frame evaluator for a zero-ID invocation.
    fn eval_with_frames(expr: &Expr, program: &Program) -> Result<Value, vyre::Error> {
        let mut invocation = Invocation::new(InvocationIds::ZERO, program.entry());
        let mut memory = empty_memory();
        eval(expr, &mut invocation, &mut memory, program)
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(256))]

        #[test]
        fn prop_frame_evaluator_matches_recursive_contract(a in any::<u32>(), b in any::<u32>(), c in any::<u32>(), pick_left in any::<bool>()) {
            let program = trivial_program();
            let int_expr = Expr::select(
                Expr::bool(pick_left),
                Expr::add(Expr::u32(a), Expr::mul(Expr::u32(b), Expr::u32(c))),
                Expr::sub(Expr::u32(a), Expr::u32(b)),
            );
            let float_expr = Expr::fma(
                Expr::f32(((a & 0xffff) as f32) * 0.5),
                Expr::f32(((b & 0xff) as f32) + 1.0),
                Expr::f32(((c & 0xffff) as f32) * 0.25),
            );

            for expr in [&int_expr, &float_expr] {
                let frame = eval_with_frames(expr, &program)
                    .expect("Fix: frame evaluator must evaluate generated expression");
                let recursive = eval_recursive_contract(expr)
                    .expect("Fix: recursive contract must evaluate generated expression");
                prop_assert_eq!(frame, recursive);
            }
        }
    }

    #[test]
    fn deeply_nested_expression_uses_frame_stack_not_host_recursion() {
        const DEPTH: u32 = 4096;
        let program = trivial_program();
        // Left-nested chain: ((0 + 1) + 1) + ... with DEPTH additions.
        let expr = (0..DEPTH).fold(Expr::u32(0), |acc, _| Expr::add(acc, Expr::u32(1)));

        let value = eval_with_frames(&expr, &program).expect(
            "Fix: frame evaluator must handle deep generated expressions without recursion",
        );

        assert_eq!(value, Value::U32(DEPTH));
    }

    /// Evaluate one `Fma` operand recursively and canonicalize it as an `f32`.
    fn fma_operand(operand: &Expr, label: &str) -> Result<f32, vyre::Error> {
        let value = eval_recursive_contract(operand)?;
        value.try_as_f32().map(canonical_f32).ok_or_else(|| {
            vyre::Error::interp(format!(
                "fma operand `{label}` is not a float in recursive contract"
            ))
        })
    }

    /// Plain recursive evaluator for the literal / BinOp / Select / Fma subset
    /// produced by the property test. It serves as the reference the frame
    /// evaluator is compared against.
    fn eval_recursive_contract(expr: &Expr) -> Result<Value, vyre::Error> {
        match expr {
            Expr::LitU32(value) => Ok(Value::U32(*value)),
            Expr::LitI32(value) => Ok(Value::I32(*value)),
            Expr::LitF32(value) => Ok(Value::Float(f64::from(canonical_f32(*value)))),
            Expr::LitBool(value) => Ok(Value::Bool(*value)),
            Expr::BinOp { op, left, right } => {
                let lhs = eval_recursive_contract(left)?;
                let rhs = eval_recursive_contract(right)?;
                eval_binop(*op, lhs, rhs)
            }
            Expr::Select {
                cond,
                true_val,
                false_val,
            } => {
                let chosen = if eval_recursive_contract(cond)?.truthy() {
                    true_val
                } else {
                    false_val
                };
                eval_recursive_contract(chosen)
            }
            Expr::Fma { a, b, c } => {
                let a = fma_operand(a, "a")?;
                let b = fma_operand(b, "b")?;
                let c = fma_operand(c, "c")?;
                Ok(Value::Float(f64::from(canonical_f32(a.mul_add(b, c)))))
            }
            _ => Err(vyre::Error::interp(
                "recursive test contract received an expression outside its generated subset",
            )),
        }
    }
}