Skip to main content

vyre_reference/
interp.rs

1//! Top-level interpreter dispatch — the parity engine's executable specification.
2//!
3//! This module exists so every vyre IR program has one deterministic result
4//! defined by pure Rust code. The conform gate treats the output of `run` as
5//! the golden expected value and diffs it against the GPU backend's actual
6//! dispatch output. Any byte-level divergence is a certified bug in the backend.
7
8use std::collections::HashMap;
9
10use vyre::ir::{BufferAccess, BufferDecl, Program};
11
12use vyre::Error;
13
14use crate::{
15    eval_node,
16    oob::Buffer,
17    value::Value,
18    workgroup::{self, Invocation, Memory},
19};
20
21/// Execute a vyre IR program on the pure Rust reference interpreter.
22///
23/// `inputs` are matched to every non-workgroup buffer declaration in
24/// `Program::buffers` order. `ReadWrite` buffers consume an input value as
25/// their initial contents and are returned as raw `Value::Bytes` in declaration
26/// order after dispatch.
27///
28/// # Errors
29///
30/// Returns [`Error::Interp`] if IR validation fails, inputs are missing or
31/// excess, workgroup size is zero, names are unresolved, operand types do not
32/// match, unsupported IR is encountered, or float operations are requested
33/// before the integer-only interpreter grows full float support.
34pub fn run(program: &Program, inputs: &[Value]) -> Result<Vec<Value>, vyre::Error> {
35    let validation_errors = vyre::ir::validate(program);
36    if !validation_errors.is_empty() {
37        let messages = validation_errors
38            .into_iter()
39            .map(|error| error.message().to_string())
40            .collect::<Vec<_>>()
41            .join("; ");
42        return Err(Error::interp(format!(
43            "program failed IR validation: {messages}. Fix: repair the Program before invoking the reference interpreter."
44        )));
45    }
46
47    let Prepared {
48        storage,
49        output_names,
50        max_elements,
51    } = prepare_storage(program, inputs)?;
52    execute_dispatch(program, storage, output_names, max_elements)
53}
54
55struct Prepared {
56    storage: HashMap<String, Buffer>,
57    output_names: Vec<String>,
58    max_elements: u32,
59}
60
61fn prepare_storage(program: &Program, inputs: &[Value]) -> Result<Prepared, vyre::Error> {
62    let mut storage = HashMap::new();
63    let mut input_index = 0usize;
64    let mut output_names = Vec::new();
65    let mut max_elements = 1u32;
66
67    for decl in program.buffers() {
68        if decl.access() == BufferAccess::Workgroup {
69            continue;
70        }
71        let value = inputs
72            .get(input_index)
73            .ok_or_else(|| Error::interp(format!(
74                    "missing input for buffer `{}`. Fix: pass one Value for each non-workgroup buffer in Program::buffers order.",
75                    decl.name()
76            )))?;
77        input_index += 1;
78
79        let bytes = value.to_bytes();
80        max_elements = max_elements.max(element_count(decl, bytes.len())?);
81        if decl.access() == BufferAccess::ReadWrite {
82            output_names.push(decl.name().to_string());
83        }
84        storage.insert(
85            decl.name().to_string(),
86            Buffer {
87                bytes,
88                element: decl.element(),
89            },
90        );
91    }
92
93    if input_index != inputs.len() {
94        return Err(Error::interp(
95            "unused input values supplied. Fix: pass exactly one Value per non-workgroup buffer declaration.",
96        ));
97    }
98
99    Ok(Prepared {
100        storage,
101        output_names,
102        max_elements,
103    })
104}
105
106fn execute_dispatch(
107    program: &Program,
108    mut storage: HashMap<String, Buffer>,
109    output_names: Vec<String>,
110    max_elements: u32,
111) -> Result<Vec<Value>, vyre::Error> {
112    validate_workgroup_size(program)?;
113    let invocations_per_workgroup = invocations_per_workgroup(program);
114    let workgroup_count_x = max_elements.div_ceil(invocations_per_workgroup).max(1);
115
116    for wg_x in 0..workgroup_count_x {
117        run_workgroup(program, &mut storage, [wg_x, 0, 0])?;
118    }
119
120    output_names
121        .into_iter()
122        .map(|name| {
123            storage
124                .remove(&name)
125                .map(|buffer| Value::Bytes(buffer.bytes))
126                .ok_or_else(|| Error::interp(format!(
127                        "missing output buffer `{name}` after dispatch. Fix: keep buffer declarations unique."
128                )))
129        })
130        .collect()
131}
132
133fn validate_workgroup_size(program: &Program) -> Result<(), vyre::Error> {
134    if program.workgroup_size().contains(&0) {
135        return Err(Error::interp(
136            "workgroup size contains zero. Fix: all dimensions must be >= 1.",
137        ));
138    }
139    Ok(())
140}
141
142fn invocations_per_workgroup(program: &Program) -> u32 {
143    program
144        .workgroup_size()
145        .iter()
146        .copied()
147        .fold(1u32, u32::saturating_mul)
148        .max(1)
149}
150
151fn run_workgroup(
152    program: &Program,
153    storage: &mut HashMap<String, Buffer>,
154    workgroup_id: [u32; 3],
155) -> Result<(), vyre::Error> {
156    let mut memory = Memory {
157        storage: std::mem::take(storage),
158        workgroup: workgroup::workgroup_memory(program)?,
159    };
160    let mut invocations = workgroup::create_invocations(program, workgroup_id)?;
161    run_invocations(program, &mut memory, &mut invocations)?;
162    *storage = memory.storage;
163    Ok(())
164}
165
166fn run_invocations<'a>(
167    program: &'a Program,
168    memory: &mut Memory,
169    invocations: &mut [Invocation<'a>],
170) -> Result<(), vyre::Error> {
171    while invocations.iter().any(|invocation| !invocation.done()) {
172        let made_progress = step_round_robin(program, memory, invocations)?;
173        verify_uniform_control_flow(invocations)?;
174        if release_barrier_if_ready(invocations) {
175            continue;
176        }
177        if !made_progress && live_waiting_count(invocations) > 0 {
178            return Err(Error::interp(
179                "program violates uniform-control-flow rule: not every live invocation reached the same barrier. Fix: move Barrier to uniform control flow.",
180            ));
181        }
182    }
183    Ok(())
184}
185
186fn step_round_robin<'a>(
187    program: &'a Program,
188    memory: &mut Memory,
189    invocations: &mut [Invocation<'a>],
190) -> Result<bool, vyre::Error> {
191    let mut made_progress = false;
192    for invocation in invocations {
193        if invocation.done() || invocation.waiting_at_barrier {
194            continue;
195        }
196        eval_node::step(invocation, memory, program)?;
197        made_progress = true;
198    }
199    Ok(made_progress)
200}
201
202fn release_barrier_if_ready(invocations: &mut [Invocation<'_>]) -> bool {
203    let active = invocations
204        .iter()
205        .filter(|invocation| !invocation.done())
206        .count();
207    let waiting = live_waiting_count(invocations);
208    if active > 0 && active == waiting {
209        for invocation in invocations {
210            invocation.waiting_at_barrier = false;
211        }
212        true
213    } else {
214        false
215    }
216}
217
218fn live_waiting_count(invocations: &[Invocation<'_>]) -> usize {
219    invocations
220        .iter()
221        .filter(|invocation| !invocation.done() && invocation.waiting_at_barrier)
222        .count()
223}
224
225fn verify_uniform_control_flow(invocations: &[Invocation<'_>]) -> Result<(), vyre::Error> {
226    // Kimi audit finding #1: filter on `!done()` instead of
227    // `!returned`. A finished invocation that exited normally
228    // (`frames.is_empty()` but `returned == false`) still carries
229    // stale `uniform_checks` entries from its own past branches.
230    // Including them in the cross-invocation comparison produces
231    // false uniform-control-flow violations when a second invocation
232    // legitimately visits the same barrier with a different branch
233    // condition.
234    let mut observed: HashMap<usize, bool> = HashMap::new();
235    for invocation in invocations.iter().filter(|invocation| !invocation.done()) {
236        for (id, value) in &invocation.uniform_checks {
237            if let Some(previous) = observed.insert(*id, *value) {
238                if previous != *value {
239                    return Err(Error::interp(
240                        "program violates uniform-control-flow rule: Barrier appears inside an If whose condition differs across the workgroup. Fix: make the condition uniform or move Barrier outside the branch.",
241                    ));
242                }
243            }
244        }
245    }
246    Ok(())
247}
248
249fn element_count(decl: &BufferDecl, byte_len: usize) -> Result<u32, vyre::Error> {
250    let stride = decl.element().min_bytes();
251    if stride == 0 {
252        return u32::try_from(byte_len).map_err(|_| Error::interp(format!(
253                "buffer `{}` has {} bytes and cannot be indexed within u32 address space. Fix: shrink or split the invocation."
254                , decl.name(),
255                byte_len,
256        )));
257    }
258    let elements = byte_len / stride;
259    u32::try_from(elements).map_err(|_| Error::interp(format!(
260            "buffer `{}` has {} bytes for stride {} and overflows u32 elements. Fix: shrink declaration footprint or split work.",
261            decl.name(),
262            byte_len,
263            stride,
264    )))
265}