Skip to main content

vyre_reference/
eval_node.rs

1//! Statement executor that gives the parity engine a pure-Rust ground truth
2//! for every `Node` variant.
3//!
4//! This module simulates the exact control-flow, memory, and barrier behavior
5//! that a correct GPU backend must produce. Any divergence in `If`, `Loop`,
6//! `Barrier`, or `Store` semantics is caught by the conform gate as a concrete
7//! counterexample.
8
9use vyre::ir::{Expr, Node, Program};
10
11use crate::{
12    eval_expr, oob,
13    workgroup::{Frame, Invocation, Memory},
14};
15use vyre::Error;
16
17/// Execute one scheduling step for an invocation.
18///
19/// # Errors
20///
21/// Returns [`Error::Interp`] for uniform-control-flow violations,
22/// out-of-bounds stores, malformed loops, or expression evaluation failures.
23pub fn step<'a>(
24    invocation: &mut Invocation<'a>,
25    memory: &mut Memory,
26    program: &'a Program,
27) -> Result<(), vyre::Error> {
28    if invocation.done() || invocation.waiting_at_barrier {
29        return Ok(());
30    }
31
32    loop {
33        let Some(frame) = invocation.frames_mut().pop() else {
34            return Ok(());
35        };
36        match frame {
37            Frame::Nodes {
38                nodes,
39                index,
40                scoped,
41            } => {
42                if step_nodes_frame(invocation, memory, program, nodes, index, scoped)? {
43                    return Ok(());
44                }
45            }
46            Frame::Loop {
47                var,
48                next,
49                to,
50                body,
51            } => step_loop_frame(invocation, var, next, to, body)?,
52        }
53    }
54}
55
56fn step_nodes_frame<'a>(
57    invocation: &mut Invocation<'a>,
58    memory: &mut Memory,
59    program: &'a Program,
60    nodes: &'a [Node],
61    index: usize,
62    scoped: bool,
63) -> Result<bool, vyre::Error> {
64    if index >= nodes.len() {
65        if scoped {
66            invocation.pop_scope();
67        }
68        return Ok(false);
69    }
70
71    invocation.frames_mut().push(Frame::Nodes {
72        nodes,
73        index: index + 1,
74        scoped,
75    });
76    execute_node(&nodes[index], invocation, memory, program)?;
77    Ok(true)
78}
79
80fn step_loop_frame<'a>(
81    invocation: &mut Invocation<'a>,
82    var: &'a str,
83    next: u32,
84    to: u32,
85    body: &'a [Node],
86) -> Result<(), vyre::Error> {
87    if next >= to {
88        return Ok(());
89    }
90    invocation.frames_mut().push(Frame::Loop {
91        var,
92        next: next.wrapping_add(1),
93        to,
94        body,
95    });
96    invocation.push_scope();
97    invocation.bind_loop_var(var, crate::value::Value::U32(next))?;
98    invocation.frames_mut().push(Frame::Nodes {
99        nodes: body,
100        index: 0,
101        scoped: true,
102    });
103    Ok(())
104}
105
106fn execute_node<'a>(
107    node: &'a Node,
108    invocation: &mut Invocation<'a>,
109    memory: &mut Memory,
110    program: &'a Program,
111) -> Result<(), vyre::Error> {
112    match node {
113        Node::Let { name, value } => eval_let(name, value, invocation, memory, program),
114        Node::Assign { name, value } => eval_assign(name, value, invocation, memory, program),
115        Node::Store {
116            buffer,
117            index,
118            value,
119        } => eval_store(buffer, index, value, invocation, memory, program),
120        Node::If {
121            cond,
122            then,
123            otherwise,
124        } => eval_if(cond, then, otherwise, node, invocation, memory, program),
125        Node::Loop {
126            var,
127            from,
128            to,
129            body,
130        } => eval_loop(var, from, to, body, invocation, memory, program),
131        Node::Return => eval_return(invocation),
132        Node::Block(nodes) => eval_block(nodes, invocation),
133        Node::Barrier => eval_barrier(invocation),
134    }
135}
136
137fn eval_let(
138    name: &str,
139    value: &Expr,
140    invocation: &mut Invocation<'_>,
141    memory: &mut Memory,
142    program: &Program,
143) -> Result<(), vyre::Error> {
144    let value = eval_expr::eval(value, invocation, memory, program)?;
145    invocation.bind(name, value)
146}
147
148fn eval_assign(
149    name: &str,
150    value: &Expr,
151    invocation: &mut Invocation<'_>,
152    memory: &mut Memory,
153    program: &Program,
154) -> Result<(), vyre::Error> {
155    let value = eval_expr::eval(value, invocation, memory, program)?;
156    invocation.assign(name, value)
157}
158
159fn eval_store(
160    buffer: &str,
161    index: &Expr,
162    value: &Expr,
163    invocation: &mut Invocation<'_>,
164    memory: &mut Memory,
165    program: &Program,
166) -> Result<(), vyre::Error> {
167    let index = eval_expr::eval(index, invocation, memory, program)?;
168    let index = index
169        .try_as_u32()
170        .ok_or_else(|| Error::interp(format!(
171                "store index {index:?} cannot be represented as u32. Fix: use a non-negative scalar index within u32."
172        )))?;
173    let value = eval_expr::eval(value, invocation, memory, program)?;
174    let target = eval_expr::buffer_mut(memory, program, buffer)?;
175    oob::store(target, index, &value);
176    Ok(())
177}
178
179fn eval_if<'a>(
180    cond: &Expr,
181    then: &'a [Node],
182    otherwise: &'a [Node],
183    node: &Node,
184    invocation: &mut Invocation<'a>,
185    memory: &mut Memory,
186    program: &Program,
187) -> Result<(), vyre::Error> {
188    let cond_value = eval_expr::eval(cond, invocation, memory, program)?.truthy();
189    if contains_barrier(then) || contains_barrier(otherwise) {
190        invocation.uniform_checks.push((node_id(node), cond_value));
191    }
192    let branch = if cond_value { then } else { otherwise };
193    invocation.push_scope();
194    invocation.frames_mut().push(Frame::Nodes {
195        nodes: branch,
196        index: 0,
197        scoped: true,
198    });
199    Ok(())
200}
201
202fn eval_loop<'a>(
203    var: &'a str,
204    from: &Expr,
205    to: &Expr,
206    body: &'a [Node],
207    invocation: &mut Invocation<'a>,
208    memory: &mut Memory,
209    program: &Program,
210) -> Result<(), vyre::Error> {
211    let from_value = eval_expr::eval(from, invocation, memory, program)?;
212    let to_value = eval_expr::eval(to, invocation, memory, program)?;
213    let from = from_value.try_as_u32().ok_or_else(|| {
214        Error::interp(format!(
215                "loop lower bound {from_value:?} cannot be represented as u32. Fix: use an in-range unsigned loop bound."
216        ))
217    })?;
218    let to = to_value.try_as_u32().ok_or_else(|| Error::interp(format!(
219            "loop upper bound {to_value:?} cannot be represented as u32. Fix: use an in-range unsigned loop bound."
220    )))?;
221    invocation.frames_mut().push(Frame::Loop {
222        var,
223        next: from,
224        to,
225        body,
226    });
227    Ok(())
228}
229
230fn eval_return(invocation: &mut Invocation<'_>) -> Result<(), vyre::Error> {
231    invocation.frames_mut().clear();
232    invocation.returned = true;
233    Ok(())
234}
235
236fn eval_block<'a>(nodes: &'a [Node], invocation: &mut Invocation<'a>) -> Result<(), vyre::Error> {
237    invocation.push_scope();
238    invocation.frames_mut().push(Frame::Nodes {
239        nodes,
240        index: 0,
241        scoped: true,
242    });
243    Ok(())
244}
245
246fn eval_barrier(invocation: &mut Invocation<'_>) -> Result<(), vyre::Error> {
247    invocation.waiting_at_barrier = true;
248    Ok(())
249}
250
251/// Whether any statement in `nodes` may reach a [`Node::Barrier`], scanning
252/// child statement lists recursively with an exhaustive [`Node`] match.
253fn contains_barrier(nodes: &[Node]) -> bool {
254    nodes.iter().any(node_contains_barrier)
255}
256
257fn node_contains_barrier(node: &Node) -> bool {
258    match node {
259        Node::Barrier => true,
260        Node::Let { .. } | Node::Assign { .. } | Node::Store { .. } | Node::Return => false,
261        Node::If {
262            then, otherwise, ..
263        } => contains_barrier(then) || contains_barrier(otherwise),
264        Node::Loop { body, .. } => contains_barrier(body),
265        Node::Block(body) => contains_barrier(body),
266    }
267}
268
269fn node_id(node: &Node) -> usize {
270    std::ptr::from_ref(node).addr()
271}