Skip to main content

vyre_reference/workgroup/
mod.rs

1//! Workgroup simulation — the parity engine's model of invocation coordination.
2//!
3//! GPU backends must reproduce the exact barrier synchronization, shared-memory
4//! layout, and invocation-ID arithmetic that this module defines. The conform gate
5//! compares GPU dispatch output against this deterministic CPU simulation; any
6//! divergence in control flow uniformity or workgroup memory semantics is a bug.
7
8use std::collections::{HashMap, HashSet};
9
10use vyre::ir::{BufferAccess, Node, Program};
11
12use vyre::Error;
13
14use crate::{oob::Buffer, value::Value};
15
/// Upper bound, in bytes, on the workgroup memory the reference interpreter
/// will allocate for a single workgroup (64 MiB, i.e. `2^26` bytes).
///
/// This is the interpreter's own allocation budget: the sum of all
/// workgroup-buffer byte sizes in a program must not exceed it.
pub const MAX_WORKGROUP_BYTES: usize = 1 << 26;
18
/// The builtin id triple that identifies a single compute invocation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct InvocationIds {
    /// Id of this invocation across the whole dispatch (`global_invocation_id`).
    pub global: [u32; 3],
    /// Id of the workgroup this invocation belongs to.
    pub workgroup: [u32; 3],
    /// Id of this invocation within its workgroup (`local_invocation_id`).
    pub local: [u32; 3],
}

impl InvocationIds {
    /// All-zero ids (every component of `global`, `workgroup`, and `local`
    /// is `0`); handy for doc examples and unit tests.
    pub const ZERO: Self = Self {
        global: [0; 3],
        workgroup: [0; 3],
        local: [0; 3],
    };
}
38
39/// Shared execution memory for storage and current workgroup buffers.
40#[derive(Debug)]
41pub struct Memory {
42    pub(crate) storage: HashMap<String, Buffer>,
43    pub(crate) workgroup: HashMap<String, Buffer>,
44}
45
46/// One paused or running invocation.
47pub struct Invocation<'a> {
48    /// Builtin ids for this invocation.
49    pub ids: InvocationIds,
50    pub(crate) locals: HashMap<String, Value>,
51    immutable: HashSet<String>,
52    scopes: Vec<Vec<String>>,
53    frames: Vec<Frame<'a>>,
54    /// True after `return`.
55    pub returned: bool,
56    /// True when paused at a barrier.
57    pub waiting_at_barrier: bool,
58    /// Uniform-if observations for branches that contain a barrier.
59    pub uniform_checks: Vec<(usize, bool)>,
60}
61
62/// Interpreter continuation stack.
63#[non_exhaustive]
64pub enum Frame<'a> {
65    /// Sequence of nodes.
66    Nodes {
67        /// Nodes being executed.
68        nodes: &'a [Node],
69        /// Next node index.
70        index: usize,
71        /// Whether completion pops a lexical scope.
72        scoped: bool,
73    },
74    /// Bounded `u32` loop.
75    Loop {
76        /// Loop variable name.
77        var: &'a str,
78        /// Next induction value.
79        next: u32,
80        /// Exclusive upper bound.
81        to: u32,
82        /// Loop body.
83        body: &'a [Node],
84    },
85}
86
87impl<'a> Invocation<'a> {
88    /// Create an invocation at the start of the entry point.
89    pub fn new(ids: InvocationIds, entry: &'a [Node]) -> Self {
90        Self {
91            ids,
92            locals: HashMap::new(),
93            immutable: HashSet::new(),
94            scopes: vec![Vec::new()],
95            frames: vec![Frame::Nodes {
96                nodes: entry,
97                index: 0,
98                scoped: false,
99            }],
100            returned: false,
101            waiting_at_barrier: false,
102            uniform_checks: Vec::new(),
103        }
104    }
105
106    /// Return true when no further execution can occur.
107    pub fn done(&self) -> bool {
108        self.returned || self.frames.is_empty()
109    }
110
111    /// Push a lexical scope.
112    ///
113    ///
114    /// ```rust,no_run
115    /// use vyre_reference::workgroup::{Invocation, InvocationIds};
116    /// let mut invocation = Invocation::new(InvocationIds::ZERO, &[]);
117    /// invocation.push_scope();
118    /// ```
119    pub fn push_scope(&mut self) {
120        self.scopes.push(Vec::new());
121    }
122
123    /// Pop a lexical scope and remove bindings declared in it.
124    ///
125    ///
126    /// ```rust,no_run
127    /// use vyre_reference::workgroup::{Invocation, InvocationIds};
128    /// let mut invocation = Invocation::new(InvocationIds::ZERO, &[]);
129    /// invocation.pop_scope();
130    /// ```
131    pub fn pop_scope(&mut self) {
132        if let Some(names) = self.scopes.pop() {
133            for name in names {
134                self.locals.remove(&name);
135                self.immutable.remove(&name);
136            }
137        }
138    }
139
140    /// Bind a mutable local.
141    ///
142    ///
143    /// ```rust,no_run
144    /// use vyre_reference::{value::Value, workgroup::{Invocation, InvocationIds}};
145    /// let mut invocation = Invocation::new(InvocationIds::ZERO, &[]);
146    /// invocation.bind("example", Value::U32(1)).unwrap();
147    /// ```
148    pub fn bind(&mut self, name: &str, value: Value) -> Result<(), vyre::Error> {
149        if self.locals.contains_key(name) {
150            return Err(Error::interp(format!(
151                "duplicate local binding `{name}`. Fix: choose a unique local name; shadowing is not allowed."
152            )));
153        }
154        self.locals.insert(name.to_string(), value);
155        if let Some(scope) = self.scopes.last_mut() {
156            scope.push(name.to_string());
157        }
158        Ok(())
159    }
160
161    /// Bind an immutable loop variable.
162    ///
163    ///
164    /// ```rust,no_run
165    /// use vyre_reference::{value::Value, workgroup::{Invocation, InvocationIds}};
166    /// let mut invocation = Invocation::new(InvocationIds::ZERO, &[]);
167    /// invocation.bind_loop_var("example", Value::U32(1)).unwrap();
168    /// ```
169    pub fn bind_loop_var(&mut self, name: &str, value: Value) -> Result<(), vyre::Error> {
170        self.bind(name, value)?;
171        self.immutable.insert(name.to_string());
172        Ok(())
173    }
174
175    /// Assign an existing mutable local.
176    pub fn assign(&mut self, name: &str, value: Value) -> Result<(), vyre::Error> {
177        if self.immutable.contains(name) {
178            return Err(Error::interp(format!(
179                "assignment to loop variable `{name}`. Fix: loop variables are immutable."
180            )));
181        }
182        let Some(slot) = self.locals.get_mut(name) else {
183            return Err(Error::interp(format!(
184                "assignment to undeclared variable `{name}`. Fix: add a Let before assigning it."
185            )));
186        };
187        *slot = value;
188        Ok(())
189    }
190
191    pub(crate) fn frames_mut(&mut self) -> &mut Vec<Frame<'a>> {
192        &mut self.frames
193    }
194}
195
196pub(crate) fn create_invocations(
197    program: &Program,
198    workgroup: [u32; 3],
199) -> Result<Vec<Invocation<'_>>, vyre::Error> {
200    let global_dim = |wgid: u32, size: u32, local: u32| {
201        wgid
202            .checked_mul(size)
203            .and_then(|base| base.checked_add(local))
204            .ok_or_else(|| Error::interp(
205                "workgroup * dispatch dimensions overflow u32 global id. Fix: reduce workgroup id or workgroup size so each global_invocation_id component fits in u32.",
206            ))
207    };
208    let [sx, sy, sz] = program.workgroup_size();
209    let mut invocations = Vec::with_capacity((sx * sy * sz) as usize);
210    for z in 0..sz {
211        for y in 0..sy {
212            for x in 0..sx {
213                let local = [x, y, z];
214                let global = [
215                    global_dim(workgroup[0], sx, x)?,
216                    global_dim(workgroup[1], sy, y)?,
217                    global_dim(workgroup[2], sz, z)?,
218                ];
219                invocations.push(Invocation::new(
220                    InvocationIds {
221                        global,
222                        workgroup,
223                        local,
224                    },
225                    program.entry(),
226                ));
227            }
228        }
229    }
230    Ok(invocations)
231}
232
233pub(crate) fn workgroup_memory(program: &Program) -> Result<HashMap<String, Buffer>, vyre::Error> {
234    let mut workgroup = HashMap::new();
235    let mut allocated = 0usize;
236    for decl in program
237        .buffers()
238        .iter()
239        .filter(|decl| decl.access() == BufferAccess::Workgroup)
240    {
241        let element_size = decl.element().min_bytes();
242        let len = (decl.count() as usize)
243            .checked_mul(element_size)
244            .ok_or_else(|| Error::interp(format!(
245                    "workgroup buffer `{}` byte size overflows usize. Fix: reduce count or element size.",
246                    decl.name()
247            )))?;
248        allocated = allocated
249            .checked_add(len)
250            .ok_or_else(|| Error::interp(
251                "total workgroup memory byte size overflows usize. Fix: reduce workgroup buffer declarations.",
252            ))?;
253        if allocated > MAX_WORKGROUP_BYTES {
254            return Err(Error::interp(format!(
255                "workgroup memory requires {allocated} bytes, exceeding the {MAX_WORKGROUP_BYTES}-byte reference budget. Fix: reduce workgroup buffer counts."
256            )));
257        }
258        workgroup.insert(
259            decl.name().to_string(),
260            Buffer {
261                bytes: vec![0; len],
262                element: decl.element(),
263            },
264        );
265    }
266    Ok(workgroup)
267}