Skip to main content

vyre_reference/execution/
expr.rs

1//! Expression evaluator that gives the parity engine a pure-Rust ground truth
2//! for every `Expr` variant.
3//!
4//! If a backend lowers `Expr::BinOp`, `Expr::Load`, or `Expr::Atomic` differently
5//! than this evaluator, the conform gate reports the exact divergence. This module
6//! exists so IR semantics are defined by Rust code, not by whatever a backend
7//! happens to emit.
8
9use vyre::ir::{AtomicOp, BinOp, BufferAccess, BufferDecl, DataType, Expr, Program, UnOp};
10
11use smallvec::SmallVec;
12use vyre::Error;
13
14use crate::execution::expr_cast::cast_value;
15use crate::{atomics, oob, value::Value, workgroup::Invocation, workgroup::Memory};
16
17/// Re-export the OOB-guarded buffer type used by storage operations.
18pub use crate::oob::Buffer;
19
20/// Evaluate an expression through the single-pass frame evaluator.
21///
22/// # Errors
23///
24/// Returns [`Error::Interp`] when expression lowering or flat execution
25/// fails. The recursive evaluator is retained only as a test oracle.
26pub fn eval(
27    expr: &Expr,
28    invocation: &mut Invocation<'_>,
29    memory: &mut Memory,
30    program: &Program,
31) -> Result<Value, vyre::Error> {
32    eval_frame_oracle(expr, invocation, memory, program)
33}
34
35/// Evaluate an expression for one invocation.
36///
37/// # Errors
38///
39/// Returns [`Error::Interp`] on operand type errors, malformed atomic or call
40/// expressions, unsupported variants, or float operands.
41pub(crate) fn eval_frame_oracle(
42    expr: &Expr,
43    invocation: &mut Invocation<'_>,
44    memory: &mut Memory,
45    program: &Program,
46) -> Result<Value, vyre::Error> {
47    enum Frame<'a> {
48        Expr(&'a Expr),
49        BinOp(BinOp),
50        UnOp(&'a UnOp),
51        Select,
52        Cast(&'a DataType),
53        Fma,
54        Load {
55            buffer: &'a str,
56        },
57        AtomicIndex {
58            op: AtomicOp,
59            buffer: &'a str,
60            expected: Option<&'a Expr>,
61            value: &'a Expr,
62        },
63        AtomicExpected {
64            op: AtomicOp,
65            buffer: &'a str,
66            index: u32,
67            value: &'a Expr,
68            expected_expr: &'a Expr,
69        },
70        AtomicValue {
71            op: AtomicOp,
72            buffer: &'a str,
73            expected: Option<u32>,
74            index: u32,
75        },
76    }
77
78    let mut frames: SmallVec<[Frame<'_>; 32]> = SmallVec::new();
79    frames.push(Frame::Expr(expr));
80    let mut values: SmallVec<[Value; 32]> = SmallVec::new();
81
82    while let Some(frame) = frames.pop() {
83        match frame {
84            Frame::Expr(expr) => match expr {
85                Expr::LitU32(value) => values.push(Value::U32(*value)),
86                Expr::LitI32(value) => values.push(Value::I32(*value)),
87                Expr::LitF32(value) => {
88                    values.push(Value::Float(f64::from(
89                        crate::execution::typed_ops::canonical_f32(*value),
90                    )));
91                }
92                Expr::LitBool(value) => values.push(Value::Bool(*value)),
93                Expr::Var(name) => values.push(eval_var(name, invocation)?),
94                Expr::BufLen { buffer } => values.push(eval_buf_len(buffer, memory, program)?),
95                Expr::InvocationId { axis } => values.push(eval_invocation_id(*axis, invocation)?),
96                Expr::WorkgroupId { axis } => values.push(eval_workgroup_id(*axis, invocation)?),
97                Expr::LocalId { axis } => values.push(eval_local_id(*axis, invocation)?),
98                Expr::Load { buffer, index } => {
99                    frames.push(Frame::Load { buffer });
100                    frames.push(Frame::Expr(index));
101                }
102                Expr::BinOp { op, left, right } => {
103                    frames.push(Frame::BinOp(*op));
104                    frames.push(Frame::Expr(right));
105                    frames.push(Frame::Expr(left));
106                }
107                Expr::UnOp { op, operand } => {
108                    frames.push(Frame::UnOp(op));
109                    frames.push(Frame::Expr(operand));
110                }
111                Expr::Select {
112                    cond,
113                    true_val,
114                    false_val,
115                } => {
116                    frames.push(Frame::Select);
117                    frames.push(Frame::Expr(false_val));
118                    frames.push(Frame::Expr(true_val));
119                    frames.push(Frame::Expr(cond));
120                }
121                Expr::Cast { target, value } => {
122                    frames.push(Frame::Cast(target));
123                    frames.push(Frame::Expr(value));
124                }
125                Expr::Fma { a, b, c } => {
126                    frames.push(Frame::Fma);
127                    frames.push(Frame::Expr(c));
128                    frames.push(Frame::Expr(b));
129                    frames.push(Frame::Expr(a));
130                }
131                Expr::Atomic {
132                    op,
133                    buffer,
134                    index,
135                    expected,
136                    value,
137                    ordering: _,
138                } => {
139                    match (*op, expected.as_deref()) {
140                        (AtomicOp::CompareExchange, None) => {
141                            return Err(Error::interp(
142                                "compare-exchange atomic is missing expected value. Fix: set Expr::Atomic.expected for AtomicOp::CompareExchange.",
143                            ));
144                        }
145                        (AtomicOp::CompareExchange, Some(_)) => {}
146                        (_, Some(_)) => {
147                            return Err(Error::interp(
148                                "non-compare-exchange atomic includes an expected value. Fix: use Expr::Atomic.expected only with AtomicOp::CompareExchange.",
149                            ));
150                        }
151                        (_, None) => {}
152                    }
153                    frames.push(Frame::AtomicIndex {
154                        op: *op,
155                        buffer,
156                        expected: expected.as_deref(),
157                        value,
158                    });
159                    frames.push(Frame::Expr(index));
160                }
161                Expr::Call { op_id, args } => {
162                    let val = crate::execution::call::eval_call(
163                        expr as *const Expr,
164                        op_id,
165                        args,
166                        invocation,
167                        memory,
168                        program,
169                    )?;
170                    values.push(val);
171                }
172                Expr::Opaque(extension) => {
173                    return Err(Error::interp(format!(
174                        "reference interpreter does not support opaque expression extension `{}`/`{}`. Fix: provide a reference evaluator for this ExprNode or lower it to core Expr variants before evaluation.",
175                        extension.extension_kind(),
176                        extension.debug_identity()
177                    )));
178                }
179                _ => {
180                    return Err(Error::interp(
181                        "reference interpreter encountered an unknown expression variant. Fix: add explicit reference semantics for the new ExprNode before dispatch.",
182                    ));
183                }
184            },
185            Frame::BinOp(op) => {
186                let right = values.pop().ok_or_else(|| {
187                    Error::interp("binary op missing right operand. Fix: internal evaluator error.")
188                })?;
189                let left = values.pop().ok_or_else(|| {
190                    Error::interp("binary op missing left operand. Fix: internal evaluator error.")
191                })?;
192                values.push(super::typed_ops::eval_binop(op, left, right)?);
193            }
194            Frame::UnOp(op) => {
195                let operand = values.pop().ok_or_else(|| {
196                    Error::interp("unary op missing operand. Fix: internal evaluator error.")
197                })?;
198                values.push(super::typed_ops::eval_unop(op, operand)?);
199            }
200            Frame::Select => {
201                let false_val = values.pop().ok_or_else(|| {
202                    Error::interp("select missing false branch. Fix: internal evaluator error.")
203                })?;
204                let true_val = values.pop().ok_or_else(|| {
205                    Error::interp("select missing true branch. Fix: internal evaluator error.")
206                })?;
207                let cond = values
208                    .pop()
209                    .ok_or_else(|| {
210                        Error::interp("select missing condition. Fix: internal evaluator error.")
211                    })?
212                    .truthy();
213                values.push(if cond { true_val } else { false_val });
214            }
215            Frame::Cast(target) => {
216                let value = values.pop().ok_or_else(|| {
217                    Error::interp("cast missing value. Fix: internal evaluator error.")
218                })?;
219                values.push(cast_value(target, &value)?);
220            }
221            Frame::Fma => {
222                let c = values
223                    .pop()
224                    .ok_or_else(|| {
225                        Error::interp("fma missing operand c. Fix: internal evaluator error.")
226                    })?
227                    .try_as_f32()
228                    .ok_or_else(|| {
229                        Error::interp(
230                            "fma operand `c` is not a float. Fix: cast to f32 before fma.",
231                        )
232                    })?;
233                let b = values
234                    .pop()
235                    .ok_or_else(|| {
236                        Error::interp("fma missing operand b. Fix: internal evaluator error.")
237                    })?
238                    .try_as_f32()
239                    .ok_or_else(|| {
240                        Error::interp(
241                            "fma operand `b` is not a float. Fix: cast to f32 before fma.",
242                        )
243                    })?;
244                let a = values
245                    .pop()
246                    .ok_or_else(|| {
247                        Error::interp("fma missing operand a. Fix: internal evaluator error.")
248                    })?
249                    .try_as_f32()
250                    .ok_or_else(|| {
251                        Error::interp(
252                            "fma operand `a` is not a float. Fix: cast to f32 before fma.",
253                        )
254                    })?;
255                let a = crate::execution::typed_ops::canonical_f32(a);
256                let b = crate::execution::typed_ops::canonical_f32(b);
257                let c = crate::execution::typed_ops::canonical_f32(c);
258                values.push(Value::Float(f64::from(
259                    crate::execution::typed_ops::canonical_f32(a.mul_add(b, c)),
260                )));
261            }
262            Frame::Load { buffer } => {
263                let value = values.pop().ok_or_else(|| {
264                    Error::interp("load missing index. Fix: internal evaluator error.")
265                })?;
266                let idx = value.try_as_u32().ok_or_else(|| {
267                    Error::interp(format!(
268                        "load index {value:?} cannot be represented as u32. Fix: use a non-negative scalar index within u32."
269                    ))
270                })?;
271                values.push(oob::load(resolve_buffer(memory, program, buffer)?, idx));
272            }
273            Frame::AtomicIndex {
274                op,
275                buffer,
276                expected,
277                value,
278            } => {
279                let val = values.pop().ok_or_else(|| {
280                    Error::interp("atomic missing index. Fix: internal evaluator error.")
281                })?;
282                let idx = val.try_as_u32().ok_or_else(|| {
283                    Error::interp(format!(
284                        "atomic index {val:?} cannot be represented as u32. Fix: use a non-negative scalar index within u32."
285                    ))
286                })?;
287                if let Some(expected_expr) = expected {
288                    frames.push(Frame::AtomicExpected {
289                        op,
290                        buffer,
291                        index: idx,
292                        value,
293                        expected_expr,
294                    });
295                    frames.push(Frame::Expr(expected_expr));
296                } else {
297                    frames.push(Frame::AtomicValue {
298                        op,
299                        buffer,
300                        expected: None,
301                        index: idx,
302                    });
303                    frames.push(Frame::Expr(value));
304                }
305            }
306            Frame::AtomicExpected {
307                op,
308                buffer,
309                index,
310                value,
311                expected_expr,
312            } => {
313                let val = values.pop().ok_or_else(|| {
314                    Error::interp(
315                        "atomic compare-exchange missing expected value. Fix: internal evaluator error.",
316                    )
317                })?;
318                let expected_val = val.try_as_u32().ok_or_else(|| {
319                    Error::interp(format!(
320                        "atomic expected value {expected_expr:?} cannot be represented as u32. Fix: use a scalar u32-compatible argument."
321                    ))
322                })?;
323                frames.push(Frame::AtomicValue {
324                    op,
325                    buffer,
326                    expected: Some(expected_val),
327                    index,
328                });
329                frames.push(Frame::Expr(value));
330            }
331            Frame::AtomicValue {
332                op,
333                buffer,
334                expected,
335                index,
336            } => {
337                let val = values.pop().ok_or_else(|| {
338                    Error::interp("atomic missing value. Fix: internal evaluator error.")
339                })?;
340                let value = val.try_as_u32().ok_or_else(|| {
341                    Error::interp(
342                        "atomic value cannot be represented as u32. Fix: use a scalar u32-compatible argument.",
343                    )
344                })?;
345                let target = atomic_buffer_mut(memory, program, buffer)?;
346                let Some(old) = oob::atomic_load(target, index) else {
347                    values.push(Value::U32(0));
348                    continue;
349                };
350                let (old, new) = atomics::apply(op, old, expected, value)?;
351                oob::atomic_store(target, index, new);
352                values.push(Value::U32(old));
353            }
354        }
355    }
356
357    values.pop().ok_or_else(|| {
358        Error::interp("expression evaluation produced no value. Fix: internal evaluator error.")
359    })
360}
361
362/// Return a mutable buffer only when the program declares it writable.
363///
364/// # Errors
365///
366/// Returns [`Error::Interp`] if the buffer is read-only, uniform,
367/// or does not exist in the program declaration.
368pub fn buffer_mut<'a>(
369    memory: &'a mut Memory,
370    program: &Program,
371    name: &str,
372) -> Result<&'a mut Buffer, vyre::Error> {
373    let decl = buffer_decl(program, name)?;
374    match decl.access() {
375        BufferAccess::ReadWrite | BufferAccess::WriteOnly | BufferAccess::Workgroup => {
376            resolve_buffer_mut(memory, decl)
377        }
378        BufferAccess::ReadOnly | BufferAccess::Uniform => Err(Error::interp(format!(
379            "store target `{name}` is not writable. Fix: declare it ReadWrite, WriteOnly, or Workgroup."
380        ))),
381        _ => Err(Error::interp(format!(
382            "store target `{name}` uses an unsupported access mode. Fix: use a supported BufferAccess."
383        ))),
384    }
385}
386
387fn eval_var(name: &str, invocation: &Invocation<'_>) -> Result<Value, vyre::Error> {
388    invocation.local(name).cloned().ok_or_else(|| {
389        Error::interp(format!(
390            "reference to undeclared variable `{name}`. Fix: add a Let before this use."
391        ))
392    })
393}
394
395fn eval_buf_len(buffer: &str, memory: &Memory, program: &Program) -> Result<Value, vyre::Error> {
396    Ok(Value::U32(resolve_buffer(memory, program, buffer)?.len()))
397}
398
399fn eval_invocation_id(axis: u8, invocation: &Invocation<'_>) -> Result<Value, vyre::Error> {
400    axis_value(invocation.ids.global, axis)
401}
402
403fn eval_workgroup_id(axis: u8, invocation: &Invocation<'_>) -> Result<Value, vyre::Error> {
404    axis_value(invocation.ids.workgroup, axis)
405}
406
407fn eval_local_id(axis: u8, invocation: &Invocation<'_>) -> Result<Value, vyre::Error> {
408    axis_value(invocation.ids.local, axis)
409}
410
411fn resolve_buffer<'a>(
412    memory: &'a Memory,
413    program: &Program,
414    name: &str,
415) -> Result<&'a oob::Buffer, vyre::Error> {
416    let decl = buffer_decl(program, name)?;
417    if decl.access() == BufferAccess::Workgroup {
418        memory.workgroup.get(name)
419    } else {
420        memory.storage.get(name)
421    }
422    .ok_or_else(|| {
423        Error::interp(format!(
424            "missing buffer `{name}`. Fix: initialize all declared buffers."
425        ))
426    })
427}
428
429fn resolve_buffer_mut<'a>(
430    memory: &'a mut Memory,
431    decl: &BufferDecl,
432) -> Result<&'a mut oob::Buffer, vyre::Error> {
433    let name = decl.name();
434    if decl.access() == BufferAccess::Workgroup {
435        memory.workgroup.get_mut(name)
436    } else {
437        memory.storage.get_mut(name)
438    }
439    .ok_or_else(|| {
440        Error::interp(format!(
441            "missing buffer `{name}`. Fix: initialize all declared buffers."
442        ))
443    })
444}
445
446fn atomic_buffer_mut<'a>(
447    memory: &'a mut Memory,
448    program: &Program,
449    name: &str,
450) -> Result<&'a mut oob::Buffer, vyre::Error> {
451    let decl = buffer_decl(program, name)?;
452    match decl.access() {
453        BufferAccess::ReadWrite => resolve_buffer_mut(memory, decl),
454        BufferAccess::Workgroup => Err(Error::interp(format!(
455            "atomic target `{name}` is workgroup memory. Fix: atomics only support ReadWrite storage buffers."
456        ))),
457        BufferAccess::ReadOnly | BufferAccess::Uniform => Err(Error::interp(format!(
458            "atomic target `{name}` is not writable. Fix: atomics only support ReadWrite storage buffers."
459        ))),
460        _ => Err(Error::interp(format!(
461            "atomic target `{name}` uses an unsupported access mode. Fix: use a supported BufferAccess."
462        ))),
463    }
464}
465
466fn buffer_decl<'a>(program: &'a Program, name: &str) -> Result<&'a BufferDecl, vyre::Error> {
467    program.buffer(name).ok_or_else(|| {
468        Error::interp(format!(
469            "unknown buffer `{name}`. Fix: declare it in Program::buffers."
470        ))
471    })
472}
473
474fn axis_value(values: [u32; 3], axis: u8) -> Result<Value, vyre::Error> {
475    values
476        .get(axis as usize)
477        .copied()
478        .map(Value::U32)
479        .ok_or_else(|| {
480            Error::interp(format!(
481                "invocation/workgroup ID axis {axis} out of range. Fix: use 0, 1, or 2."
482            ))
483        })
484}
485
#[cfg(test)]
mod tests {

    use proptest::prelude::*;
    use vyre::ir::{Expr, Program};

    use super::eval;
    use crate::execution::typed_ops;
    use crate::value::Value;
    use crate::workgroup::{Invocation, InvocationIds, Memory};

    // Run the frame evaluator with a zero-ID invocation over empty memory.
    fn run_frame_evaluator(expr: &Expr, program: &Program) -> Result<Value, vyre::Error> {
        let mut invocation = Invocation::new(InvocationIds::ZERO, program.entry());
        let mut memory = Memory {
            storage: Default::default(),
            workgroup: Default::default(),
        };
        eval(expr, &mut invocation, &mut memory, program)
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(256))]

        // The frame evaluator must agree with the recursive contract on
        // integer select/arithmetic trees and on canonicalised f32 fma.
        #[test]
        fn prop_frame_evaluator_matches_recursive_contract(
            a in any::<u32>(),
            b in any::<u32>(),
            c in any::<u32>(),
            pick_left in any::<bool>()
        ) {
            let program = Program::wrapped(Vec::new(), [1, 1, 1], Vec::new());
            let branch_expr = Expr::select(
                Expr::bool(pick_left),
                Expr::add(Expr::u32(a), Expr::mul(Expr::u32(b), Expr::u32(c))),
                Expr::sub(Expr::u32(a), Expr::u32(b)),
            );
            let fused_expr = Expr::fma(
                Expr::f32((a & 0xffff) as f32 * 0.5),
                Expr::f32((b & 0xff) as f32 + 1.0),
                Expr::f32((c & 0xffff) as f32 * 0.25),
            );

            for expr in [branch_expr, fused_expr].iter() {
                let via_frames = run_frame_evaluator(expr, &program)
                    .expect("Fix: frame evaluator must evaluate generated expression");
                let via_recursion = eval_recursive_contract(expr)
                    .expect("Fix: recursive contract must evaluate generated expression");
                prop_assert_eq!(via_frames, via_recursion);
            }
        }
    }

    #[test]
    fn deeply_nested_expression_uses_frame_stack_not_host_recursion() {
        let program = Program::wrapped(Vec::new(), [1, 1, 1], Vec::new());
        // Left-nested chain ((((0 + 1) + 1) + 1) ...) of depth 4096.
        let expr = (0..4096).fold(Expr::u32(0), |acc, _| Expr::add(acc, Expr::u32(1)));

        let value = run_frame_evaluator(&expr, &program).expect(
            "Fix: frame evaluator must handle deep generated expressions without recursion",
        );

        assert_eq!(value, Value::U32(4096));
    }

    // Evaluate an `Fma` operand through the recursive contract and
    // canonicalise it as an f32.
    fn recursive_f32_operand(operand: &Expr, label: &str) -> Result<f32, vyre::Error> {
        let raw = eval_recursive_contract(operand)?.try_as_f32().ok_or_else(|| {
            vyre::Error::interp(format!(
                "fma operand `{label}` is not a float in recursive contract"
            ))
        })?;
        Ok(typed_ops::canonical_f32(raw))
    }

    // Straightforward recursive evaluator for the subset of `Expr` the tests
    // generate; used as an oracle for the frame evaluator.
    fn eval_recursive_contract(expr: &Expr) -> Result<Value, vyre::Error> {
        match expr {
            Expr::LitU32(v) => Ok(Value::U32(*v)),
            Expr::LitI32(v) => Ok(Value::I32(*v)),
            Expr::LitF32(v) => Ok(Value::Float(f64::from(typed_ops::canonical_f32(*v)))),
            Expr::LitBool(v) => Ok(Value::Bool(*v)),
            Expr::BinOp { op, left, right } => {
                let lhs = eval_recursive_contract(left)?;
                let rhs = eval_recursive_contract(right)?;
                typed_ops::eval_binop(*op, lhs, rhs)
            }
            Expr::Select {
                cond,
                true_val,
                false_val,
            } => {
                let chosen = if eval_recursive_contract(cond)?.truthy() {
                    true_val
                } else {
                    false_val
                };
                eval_recursive_contract(chosen)
            }
            Expr::Fma { a, b, c } => {
                let a = recursive_f32_operand(a, "a")?;
                let b = recursive_f32_operand(b, "b")?;
                let c = recursive_f32_operand(c, "c")?;
                Ok(Value::Float(f64::from(typed_ops::canonical_f32(
                    a.mul_add(b, c),
                ))))
            }
            _ => Err(vyre::Error::interp(
                "recursive test contract received an expression outside its generated subset",
            )),
        }
    }
}