Skip to main content

vyre_reference/execution/
node.rs

1//! Statement executor that gives the parity engine a pure-Rust ground truth
2//! for every `Node` variant.
3//!
4//! This module simulates the exact control-flow, memory, and barrier behavior
5//! that a correct GPU backend must produce. Any divergence in `If`, `Loop`,
6//! `Barrier`, or `Store` semantics is caught by the conform gate as a concrete
7//! counterexample.
8
9use vyre::ir::{Expr, Node, Program};
10
11use crate::{
12    execution::expr as eval_expr,
13    oob,
14    workgroup::{AsyncTransfer, Frame, Invocation, Memory},
15};
16use vyre::Error;
17
18/// Execute one scheduling step for an invocation.
19///
20/// # Errors
21///
22/// Returns [`Error::Interp`] for uniform-control-flow violations,
23/// out-of-bounds stores, malformed loops, or expression evaluation failures.
24pub fn step<'a>(
25    invocation: &mut Invocation<'a>,
26    memory: &mut Memory,
27    program: &'a Program,
28) -> Result<(), vyre::Error> {
29    if invocation.done() || invocation.waiting_at_barrier {
30        return Ok(());
31    }
32
33    loop {
34        let Some(frame) = invocation.frames_mut().pop() else {
35            return Ok(());
36        };
37        match frame {
38            Frame::Nodes {
39                nodes,
40                index,
41                scoped,
42            } => {
43                if step_nodes_frame(invocation, memory, program, nodes, index, scoped)? {
44                    return Ok(());
45                }
46            }
47            Frame::Loop {
48                var,
49                next,
50                to,
51                body,
52            } => step_loop_frame(invocation, var, next, to, body)?,
53        }
54    }
55}
56
57fn step_nodes_frame<'a>(
58    invocation: &mut Invocation<'a>,
59    memory: &mut Memory,
60    program: &'a Program,
61    nodes: &'a [Node],
62    index: usize,
63    scoped: bool,
64) -> Result<bool, vyre::Error> {
65    if index >= nodes.len() {
66        if scoped {
67            invocation.pop_scope();
68        }
69        return Ok(false);
70    }
71
72    invocation.frames_mut().push(Frame::Nodes {
73        nodes,
74        index: index + 1,
75        scoped,
76    });
77    execute_node(&nodes[index], invocation, memory, program)?;
78    Ok(true)
79}
80
81fn step_loop_frame<'a>(
82    invocation: &mut Invocation<'a>,
83    var: &'a str,
84    next: u32,
85    to: u32,
86    body: &'a [Node],
87) -> Result<(), vyre::Error> {
88    if next >= to {
89        return Ok(());
90    }
91    invocation.frames_mut().push(Frame::Loop {
92        var,
93        next: next.wrapping_add(1),
94        to,
95        body,
96    });
97    invocation.push_scope();
98    invocation.bind_loop_var(var, crate::value::Value::U32(next))?;
99    invocation.frames_mut().push(Frame::Nodes {
100        nodes: body,
101        index: 0,
102        scoped: true,
103    });
104    Ok(())
105}
106
107fn execute_node<'a>(
108    node: &'a Node,
109    invocation: &mut Invocation<'a>,
110    memory: &mut Memory,
111    program: &'a Program,
112) -> Result<(), vyre::Error> {
113    match node {
114        Node::Let { name, value } => eval_let(name, value, invocation, memory, program),
115        Node::Assign { name, value } => eval_assign(name, value, invocation, memory, program),
116        Node::Store {
117            buffer,
118            index,
119            value,
120        } => eval_store(buffer, index, value, invocation, memory, program),
121        Node::If {
122            cond,
123            then,
124            otherwise,
125        } => eval_if(cond, then, otherwise, node, invocation, memory, program),
126        Node::Loop {
127            var,
128            from,
129            to,
130            body,
131        } => eval_loop(var, from, to, body, invocation, memory, program),
132        Node::Return => eval_return(invocation),
133        Node::Block(nodes) => eval_block(nodes, invocation),
134        Node::Barrier { .. } => eval_barrier(invocation),
135        Node::IndirectDispatch {
136            count_buffer,
137            count_offset,
138        } => eval_indirect_dispatch(count_buffer, *count_offset, memory, program),
139        Node::AsyncLoad {
140            source,
141            destination,
142            offset,
143            size,
144            tag,
145        } => eval_async_load(
146            AsyncLoadEval {
147                source,
148                destination,
149                offset,
150                size,
151                tag,
152            },
153            invocation,
154            memory,
155            program,
156        ),
157        Node::AsyncStore {
158            source,
159            destination,
160            offset,
161            size,
162            tag,
163        } => eval_async_store(
164            AsyncStoreEval {
165                source,
166                destination,
167                offset,
168                size,
169                tag,
170            },
171            invocation,
172            memory,
173            program,
174        ),
175        Node::AsyncWait { tag } => eval_async_wait(tag, invocation, memory, program),
176        Node::Trap { address, tag } => {
177            let address = eval_expr::eval(address, invocation, memory, program)?
178                .try_as_u32()
179                .ok_or_else(|| {
180                    Error::interp(format!(
181                        "reference trap `{tag}` address is not a u32. Fix: pass a scalar u32 trap address."
182                    ))
183                })?;
184            Err(vyre::Error::interp(format!(
185                "reference dispatch trapped: address={address}, tag=`{tag}`. Fix: handle the trap condition or route this Program through a backend/runtime with replay support."
186            )))
187        }
188        Node::Resume { tag } => Err(vyre::Error::interp(format!(
189            "reference dispatch reached Resume `{tag}` without a replay runtime. Fix: lower Resume through a runtime-owned replay path before reference execution."
190        ))),
191        Node::Region { body, .. } => eval_block(body, invocation),
192        Node::Opaque(extension) => Err(vyre::Error::interp(format!(
193            "reference interpreter does not support opaque node extension `{}`/`{}`. Fix: provide a reference evaluator for this NodeExtension or lower it to core Node variants before evaluation.",
194            extension.extension_kind(),
195            extension.debug_identity()
196        ))),
197        _ => Err(vyre::Error::interp(
198            "reference interpreter encountered an unknown Node variant. Fix: update vyre-reference before executing this IR.",
199        )),
200    }
201}
202
203fn eval_let(
204    name: &str,
205    value: &Expr,
206    invocation: &mut Invocation<'_>,
207    memory: &mut Memory,
208    program: &Program,
209) -> Result<(), vyre::Error> {
210    let value = eval_expr::eval(value, invocation, memory, program)?;
211    invocation.bind(name, value)
212}
213
214fn eval_assign(
215    name: &str,
216    value: &Expr,
217    invocation: &mut Invocation<'_>,
218    memory: &mut Memory,
219    program: &Program,
220) -> Result<(), vyre::Error> {
221    let value = eval_expr::eval(value, invocation, memory, program)?;
222    invocation.assign(name, value)
223}
224
225fn eval_store(
226    buffer: &str,
227    index: &Expr,
228    value: &Expr,
229    invocation: &mut Invocation<'_>,
230    memory: &mut Memory,
231    program: &Program,
232) -> Result<(), vyre::Error> {
233    let index = eval_expr::eval(index, invocation, memory, program)?;
234    let index = index
235        .try_as_u32()
236        .ok_or_else(|| Error::interp(format!(
237                "store index {index:?} cannot be represented as u32. Fix: use a non-negative scalar index within u32."
238        )))?;
239    let value = eval_expr::eval(value, invocation, memory, program)?;
240    let target = eval_expr::buffer_mut(memory, program, buffer)?;
241    oob::store(target, index, &value);
242    Ok(())
243}
244
245fn eval_indirect_dispatch(
246    count_buffer: &str,
247    count_offset: u64,
248    memory: &Memory,
249    program: &Program,
250) -> Result<(), vyre::Error> {
251    if count_offset % 4 != 0 {
252        return Err(Error::interp(format!(
253            "indirect dispatch offset {count_offset} is not 4-byte aligned. Fix: use a u32-aligned dispatch tuple."
254        )));
255    }
256    let decl = program.buffer(count_buffer).ok_or_else(|| {
257        Error::interp(format!(
258            "indirect dispatch references unknown buffer `{count_buffer}`. Fix: declare the count buffer before execution."
259        ))
260    })?;
261    let buffer = if decl.access() == vyre::ir::BufferAccess::Workgroup {
262        memory.workgroup.get(count_buffer)
263    } else {
264        memory.storage.get(count_buffer)
265    }
266    .ok_or_else(|| {
267        Error::interp(format!(
268            "indirect dispatch buffer `{count_buffer}` is missing. Fix: initialize the count buffer before execution."
269        ))
270    })?;
271    let required_end = count_offset.checked_add(12).ok_or_else(|| {
272        Error::interp(
273            "indirect dispatch byte range overflowed u64. Fix: shrink the count offset."
274                .to_string(),
275        )
276    })?;
277    let byte_len = buffer
278        .bytes
279        .read()
280        .map_err(|_| {
281            Error::interp(format!(
282                "indirect dispatch buffer `{count_buffer}` lock is poisoned. Fix: rebuild the interpreter memory state before execution."
283            ))
284        })?
285        .len();
286    if u64::try_from(byte_len).unwrap_or(u64::MAX) < required_end {
287        return Err(Error::interp(format!(
288            "indirect dispatch buffer `{count_buffer}` is too short for a 3-word dispatch tuple at byte offset {count_offset}. Fix: provide 12 readable bytes starting at that offset."
289        )));
290    }
291    Ok(())
292}
293
/// Borrowed operands of a `Node::AsyncLoad` statement, bundled so that
/// `eval_async_load` receives them as a single request argument.
struct AsyncLoadEval<'a> {
    /// Name of the buffer the payload is read from.
    source: &'a str,
    /// Name of the buffer that receives the payload when the transfer is applied.
    destination: &'a str,
    /// Expression giving the byte offset into `source` where reading starts.
    offset: &'a Expr,
    /// Expression giving the number of bytes to transfer.
    size: &'a Expr,
    /// Tag under which the pending transfer is registered on the invocation.
    tag: &'a str,
}
301
/// Borrowed operands of a `Node::AsyncStore` statement, bundled so that
/// `eval_async_store` receives them as a single request argument.
struct AsyncStoreEval<'a> {
    /// Name of the buffer the payload is read from (always from byte 0).
    source: &'a str,
    /// Name of the buffer that receives the payload when the transfer is applied.
    destination: &'a str,
    /// Expression giving the byte offset into `destination` where writing starts.
    offset: &'a Expr,
    /// Expression giving the number of bytes to transfer.
    size: &'a Expr,
    /// Tag under which the pending transfer is registered on the invocation.
    tag: &'a str,
}
309
310fn eval_async_load(
311    request: AsyncLoadEval<'_>,
312    invocation: &mut Invocation<'_>,
313    memory: &mut Memory,
314    program: &Program,
315) -> Result<(), vyre::Error> {
316    let start = eval_byte_count(
317        request.offset,
318        "async load source offset",
319        invocation,
320        memory,
321        program,
322    )?;
323    let byte_count = eval_byte_count(request.size, "async load size", invocation, memory, program)?;
324    let payload = read_bytes(memory, program, request.source, start, byte_count)?;
325    ensure_writable_buffer(memory, program, request.destination)?;
326    invocation.begin_async(
327        request.tag,
328        AsyncTransfer::Copy {
329            destination: request.destination.into(),
330            start: 0,
331            payload,
332        },
333    )
334}
335
336fn eval_async_store(
337    request: AsyncStoreEval<'_>,
338    invocation: &mut Invocation<'_>,
339    memory: &mut Memory,
340    program: &Program,
341) -> Result<(), vyre::Error> {
342    let start = eval_byte_count(
343        request.offset,
344        "async store destination offset",
345        invocation,
346        memory,
347        program,
348    )?;
349    let byte_count = eval_byte_count(
350        request.size,
351        "async store size",
352        invocation,
353        memory,
354        program,
355    )?;
356    let payload = read_bytes(memory, program, request.source, 0, byte_count)?;
357    ensure_writable_buffer(memory, program, request.destination)?;
358    invocation.begin_async(
359        request.tag,
360        AsyncTransfer::Copy {
361            destination: request.destination.into(),
362            start,
363            payload,
364        },
365    )
366}
367
368fn eval_async_wait(
369    tag: &str,
370    invocation: &mut Invocation<'_>,
371    memory: &mut Memory,
372    program: &Program,
373) -> Result<(), vyre::Error> {
374    apply_async_transfer(invocation.finish_async(tag)?, memory, program)
375}
376
377fn eval_byte_count(
378    expr: &Expr,
379    label: &str,
380    invocation: &mut Invocation<'_>,
381    memory: &mut Memory,
382    program: &Program,
383) -> Result<usize, Error> {
384    let value = eval_expr::eval(expr, invocation, memory, program)?;
385    usize::try_from(value.try_as_u64().ok_or_else(|| {
386        Error::interp(format!(
387            "{label} cannot be represented as u64. Fix: use an in-range non-negative byte count."
388        ))
389    })?)
390    .map_err(|_| {
391        Error::interp(format!(
392            "{label} exceeds host usize. Fix: reduce the async transfer span."
393        ))
394    })
395}
396
397fn read_bytes(
398    memory: &Memory,
399    program: &Program,
400    source: &str,
401    start: usize,
402    byte_count: usize,
403) -> Result<Vec<u8>, Error> {
404    let buffer = resolve_buffer(memory, program, source)?;
405    let bytes = buffer
406        .bytes
407        .read()
408        .unwrap_or_else(|error| error.into_inner());
409    let mut payload = vec![0; byte_count];
410    if start < bytes.len() {
411        let available = (bytes.len() - start).min(byte_count);
412        payload[..available].copy_from_slice(&bytes[start..start + available]);
413    }
414    Ok(payload)
415}
416
417fn ensure_writable_buffer(memory: &mut Memory, program: &Program, name: &str) -> Result<(), Error> {
418    eval_expr::buffer_mut(memory, program, name).map(|_| ())
419}
420
421fn apply_async_transfer(
422    transfer: AsyncTransfer,
423    memory: &mut Memory,
424    program: &Program,
425) -> Result<(), Error> {
426    match transfer {
427        AsyncTransfer::Copy {
428            destination,
429            start,
430            payload,
431        } => {
432            let buffer = eval_expr::buffer_mut(memory, program, &destination)?;
433            let mut bytes = buffer
434                .bytes
435                .write()
436                .unwrap_or_else(|error| error.into_inner());
437            if start >= bytes.len() {
438                return Ok(());
439            }
440            let write_len = payload.len().min(bytes.len() - start);
441            bytes[start..start + write_len].copy_from_slice(&payload[..write_len]);
442            Ok(())
443        }
444    }
445}
446
447fn resolve_buffer<'a>(
448    memory: &'a Memory,
449    program: &Program,
450    name: &str,
451) -> Result<&'a oob::Buffer, Error> {
452    let decl = program.buffer(name).ok_or_else(|| {
453        Error::interp(format!(
454            "missing buffer declaration `{name}`. Fix: declare every async transfer buffer."
455        ))
456    })?;
457    if decl.access() == vyre::ir::BufferAccess::Workgroup {
458        memory.workgroup.get(name)
459    } else {
460        memory.storage.get(name)
461    }
462    .ok_or_else(|| {
463        Error::interp(format!(
464            "missing buffer `{name}`. Fix: initialize every declared async transfer buffer."
465        ))
466    })
467}
468
469fn eval_if<'a>(
470    cond: &Expr,
471    then: &'a [Node],
472    otherwise: &'a [Node],
473    node: &Node,
474    invocation: &mut Invocation<'a>,
475    memory: &mut Memory,
476    program: &Program,
477) -> Result<(), vyre::Error> {
478    let cond_value = eval_expr::eval(cond, invocation, memory, program)?.truthy();
479    if contains_barrier(then) || contains_barrier(otherwise) {
480        invocation.uniform_checks.push((node_id(node), cond_value));
481    }
482    let branch = if cond_value { then } else { otherwise };
483    invocation.push_scope();
484    invocation.frames_mut().push(Frame::Nodes {
485        nodes: branch,
486        index: 0,
487        scoped: true,
488    });
489    Ok(())
490}
491
492fn eval_loop<'a>(
493    var: &'a str,
494    from: &Expr,
495    to: &Expr,
496    body: &'a [Node],
497    invocation: &mut Invocation<'a>,
498    memory: &mut Memory,
499    program: &Program,
500) -> Result<(), vyre::Error> {
501    let from_value = eval_expr::eval(from, invocation, memory, program)?;
502    let to_value = eval_expr::eval(to, invocation, memory, program)?;
503    let from = from_value.try_as_u32().ok_or_else(|| {
504        Error::interp(format!(
505                "loop lower bound {from_value:?} cannot be represented as u32. Fix: use an in-range unsigned loop bound."
506        ))
507    })?;
508    let to = to_value.try_as_u32().ok_or_else(|| Error::interp(format!(
509            "loop upper bound {to_value:?} cannot be represented as u32. Fix: use an in-range unsigned loop bound."
510    )))?;
511    invocation.frames_mut().push(Frame::Loop {
512        var,
513        next: from,
514        to,
515        body,
516    });
517    Ok(())
518}
519
520fn eval_return(invocation: &mut Invocation<'_>) -> Result<(), vyre::Error> {
521    invocation.frames_mut().clear();
522    invocation.returned = true;
523    Ok(())
524}
525
526fn eval_block<'a>(nodes: &'a [Node], invocation: &mut Invocation<'a>) -> Result<(), vyre::Error> {
527    invocation.push_scope();
528    invocation.frames_mut().push(Frame::Nodes {
529        nodes,
530        index: 0,
531        scoped: true,
532    });
533    Ok(())
534}
535
/// Park the invocation at a barrier.
///
/// `step` returns immediately for an invocation whose `waiting_at_barrier`
/// flag is set; the flag is not cleared anywhere in the code visible here.
fn eval_barrier(invocation: &mut Invocation<'_>) -> Result<(), vyre::Error> {
    invocation.waiting_at_barrier = true;
    Ok(())
}
540
541/// Whether any statement in `nodes` may reach a [`Node::Barrier { ordering: vyre::memory_model::MemoryOrdering::SeqCst }`], scanning
542/// child statement lists recursively with an exhaustive [`Node`] match.
543fn contains_barrier(nodes: &[Node]) -> bool {
544    nodes.iter().any(node_contains_barrier)
545}
546
547fn node_contains_barrier(node: &Node) -> bool {
548    match node {
549        Node::Barrier { .. } => true,
550        Node::Let { .. }
551        | Node::Assign { .. }
552        | Node::Store { .. }
553        | Node::Return
554        | Node::IndirectDispatch { .. }
555        | Node::AsyncLoad { .. }
556        | Node::AsyncStore { .. }
557        | Node::AsyncWait { .. }
558        | Node::Trap { .. }
559        | Node::Resume { .. }
560        | Node::Opaque(_) => false,
561        Node::If {
562            then, otherwise, ..
563        } => contains_barrier(then) || contains_barrier(otherwise),
564        Node::Loop { body, .. } => contains_barrier(body),
565        Node::Block(body) => contains_barrier(body),
566        _ => false,
567    }
568}
569
570fn node_id(node: &Node) -> usize {
571    std::ptr::from_ref(node).addr()
572}
573
#[cfg(test)]
mod tests {
    use super::*;
    use crate::oob::Buffer;
    use crate::workgroup::InvocationIds;
    use vyre::ir::{BufferDecl, DataType};

    /// Drive a single invocation (all ids zero) of `program` to completion,
    /// one scheduling step at a time.
    fn run_program(program: &Program, memory: &mut Memory) -> Result<(), vyre::Error> {
        let mut invocation = Invocation::new(InvocationIds::ZERO, program.entry());
        loop {
            if invocation.done() {
                return Ok(());
            }
            step(&mut invocation, memory, program)?;
        }
    }

    /// Snapshot the current contents of storage buffer `name`.
    fn bytes(memory: &Memory, name: &str) -> Vec<u8> {
        let buffer = memory.storage.get(name).expect("test buffer exists");
        let guard = buffer
            .bytes
            .read()
            .unwrap_or_else(|error| error.into_inner());
        guard.clone()
    }

    /// An async load reads 4 bytes at source offset 2 and, once waited on,
    /// lands them at the start of the destination.
    #[test]
    fn async_load_wait_copies_payload_into_destination() {
        let buffers = vec![
            BufferDecl::read("src", 0, DataType::Bytes).with_count(8),
            BufferDecl::output("dst", 1, DataType::Bytes).with_count(8),
        ];
        let entry = vec![
            Node::async_load_ext("src", "dst", Expr::u32(2), Expr::u32(4), "copy"),
            Node::AsyncWait { tag: "copy".into() },
        ];
        let program = Program::wrapped(buffers, [1, 1, 1], entry);

        let source = Buffer::new(vec![10, 11, 12, 13, 14, 15, 16, 17], DataType::Bytes);
        let destination = Buffer::new(vec![0; 8], DataType::Bytes);
        let mut memory = Memory::empty()
            .with_storage("src", source)
            .with_storage("dst", destination);

        run_program(&program, &mut memory).unwrap();

        let expected = vec![12, 13, 14, 15, 0, 0, 0, 0];
        assert_eq!(bytes(&memory, "dst"), expected);
    }

    /// An async store reads 4 bytes from the start of the source and, once
    /// waited on, lands them at destination offset 3.
    #[test]
    fn async_store_wait_copies_payload_at_destination_offset() {
        let buffers = vec![
            BufferDecl::read("src", 0, DataType::Bytes).with_count(4),
            BufferDecl::output("dst", 1, DataType::Bytes).with_count(8),
        ];
        let entry = vec![
            Node::async_store("src", "dst", Expr::u32(3), Expr::u32(4), "store"),
            Node::AsyncWait {
                tag: "store".into(),
            },
        ];
        let program = Program::wrapped(buffers, [1, 1, 1], entry);

        let source = Buffer::new(vec![21, 22, 23, 24], DataType::Bytes);
        let destination = Buffer::new(vec![0; 8], DataType::Bytes);
        let mut memory = Memory::empty()
            .with_storage("src", source)
            .with_storage("dst", destination);

        run_program(&program, &mut memory).unwrap();

        let expected = vec![0, 0, 0, 21, 22, 23, 24, 0];
        assert_eq!(bytes(&memory, "dst"), expected);
    }
}