Skip to main content

vyre_reference/
workgroup.rs

1//! Workgroup simulation  -  the parity engine's model of invocation coordination.
2//!
3//! GPU backends must reproduce the exact barrier synchronization, shared-memory
4//! layout, and invocation-ID arithmetic that this module defines. The conform gate
5//! compares GPU dispatch output against this deterministic CPU simulation; any
6//! divergence in control flow uniformity or workgroup memory semantics is a bug.
7
8use std::convert::Infallible;
9use std::ops::ControlFlow::{self, Continue};
10use std::sync::Arc;
11
12use rustc_hash::FxHashMap;
13use smallvec::SmallVec;
14#[cfg(test)]
15use vyre::ir::BufferAccess;
16use vyre::ir::{Expr, Node, Program};
17use vyre::visit::{visit_node_preorder, visit_preorder, ExprVisitor, NodeVisitor};
18use vyre::OpDef;
19use vyre_foundation::ir::model::expr::GeneratorRef;
20
21use vyre::Error;
22
23use crate::{oob::Buffer, value::Value};
24
/// Upper bound on the total workgroup (shared) memory, in bytes, that the
/// reference interpreter will allocate for one workgroup: 64 MiB.
pub const MAX_WORKGROUP_BYTES: usize = 64 << 20;
27
28/// Small-N buffer lookup keyed by interned `Arc<str>` names.
29///
30/// Typical reference interpreter programs have ≤ 8 declared buffers. A
31/// linear scan over 8 entries is branch-predicted and hits L1 cache; hashing
32/// each access (as `HashMap<String, Buffer>` did) burned a SipHash-1-3 on
33/// every load/store in the inner interpreter loop. This struct preserves
34/// the public `get` / `get_mut` / `insert` shape consumers depend on while
35/// eliminating the per-lookup hash + heap traffic.
36#[derive(Debug, Default, Clone)]
37pub struct BufferMap {
38    entries: SmallVec<[(Arc<str>, Buffer); 8]>,
39}
40
41impl BufferMap {
42    /// Construct an empty map.
43    #[must_use]
44    pub fn new() -> Self {
45        Self {
46            entries: SmallVec::new(),
47        }
48    }
49
50    /// Look up a buffer by name.
51    #[must_use]
52    pub fn get(&self, name: &str) -> Option<&Buffer> {
53        self.entries
54            .iter()
55            .find(|(key, _)| key.as_ref() == name)
56            .map(|(_, buffer)| buffer)
57    }
58
59    /// Look up a mutable buffer by name.
60    pub fn get_mut(&mut self, name: &str) -> Option<&mut Buffer> {
61        self.entries
62            .iter_mut()
63            .find(|(key, _)| key.as_ref() == name)
64            .map(|(_, buffer)| buffer)
65    }
66
67    /// Insert or overwrite a buffer. Returns the previous value when the
68    /// key already existed.
69    pub fn insert(&mut self, name: impl Into<Arc<str>>, buffer: Buffer) -> Option<Buffer> {
70        let name = name.into();
71        if let Some(entry) = self
72            .entries
73            .iter_mut()
74            .find(|(key, _)| key.as_ref() == name.as_ref())
75        {
76            return Some(std::mem::replace(&mut entry.1, buffer));
77        }
78        self.entries.push((name, buffer));
79        None
80    }
81
82    /// Iterate `(name, buffer)` pairs in insertion order.
83    pub fn iter(&self) -> impl Iterator<Item = (&str, &Buffer)> {
84        self.entries
85            .iter()
86            .map(|(name, buffer)| (name.as_ref(), buffer))
87    }
88
89    /// Move-iterate `(name, buffer)` pairs.
90    pub fn into_iter_pairs(self) -> impl Iterator<Item = (Arc<str>, Buffer)> {
91        self.entries.into_iter()
92    }
93
94    /// Number of entries.
95    #[must_use]
96    pub fn len(&self) -> usize {
97        self.entries.len()
98    }
99
100    /// True when empty.
101    #[must_use]
102    pub fn is_empty(&self) -> bool {
103        self.entries.is_empty()
104    }
105}
106
/// Builtin identifiers of one compute invocation, one `u32` per axis.
///
/// `Default` produces the same all-zero ids as [`InvocationIds::ZERO`]. `Hash`
/// lets ids key hash maps and sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct InvocationIds {
    /// Global invocation id.
    pub global: [u32; 3],
    /// Workgroup id.
    pub workgroup: [u32; 3],
    /// Local invocation id within the workgroup.
    pub local: [u32; 3],
}

impl InvocationIds {
    /// Zero-valued invocation ids for examples and unit tests (equal to
    /// `InvocationIds::default()`).
    pub const ZERO: Self = Self {
        global: [0, 0, 0],
        workgroup: [0, 0, 0],
        local: [0, 0, 0],
    };
}
126
127/// Shared execution memory for storage and current workgroup buffers.
128#[derive(Debug, Default, Clone)]
129pub struct Memory {
130    pub(crate) storage: BufferMap,
131    pub(crate) workgroup: BufferMap,
132}
133
134impl Memory {
135    /// Create empty memory for test fixtures.
136    #[must_use]
137    pub fn empty() -> Self {
138        Self::default()
139    }
140
141    /// Add a storage buffer.
142    #[must_use]
143    pub fn with_storage(mut self, name: impl Into<Arc<str>>, buffer: Buffer) -> Self {
144        self.storage.insert(name, buffer);
145        self
146    }
147
148    /// Add a workgroup buffer.
149    #[must_use]
150    pub fn with_workgroup(mut self, name: impl Into<Arc<str>>, buffer: Buffer) -> Self {
151        self.workgroup.insert(name, buffer);
152        self
153    }
154
155    /// Build a single byte payload memory used by canonical primitive evaluators.
156    #[must_use]
157    pub fn from_bytes(bytes: Vec<u8>) -> Self {
158        let mut storage = BufferMap::new();
159        storage.insert("__value", Buffer::new(bytes, vyre::ir::DataType::Bytes));
160        Self {
161            storage,
162            workgroup: BufferMap::new(),
163        }
164    }
165
166    /// Return the byte payload for canonical primitive evaluators.
167    #[must_use]
168    pub fn bytes(&self) -> Vec<u8> {
169        self.storage.get("__value").map_or_else(Vec::new, |buffer| {
170            buffer
171                .bytes
172                .read()
173                .unwrap_or_else(|error| error.into_inner())
174                .clone()
175        })
176    }
177
178    /// Consume this memory and return the byte payload for canonical primitives.
179    #[must_use]
180    pub fn into_bytes(self) -> Vec<u8> {
181        self.storage
182            .into_iter_pairs()
183            .find_map(|(name, buffer)| {
184                (name.as_ref() == "__value").then(|| {
185                    std::sync::Arc::try_unwrap(buffer.bytes)
186                        .map(|rw| rw.into_inner().unwrap_or_else(|error| error.into_inner()))
187                        .unwrap_or_else(|a| {
188                            a.read().unwrap_or_else(|error| error.into_inner()).clone()
189                        })
190                })
191            })
192            .unwrap_or_default()
193    }
194}
195
196/// Shared slot layout for all locals in one program.
197#[derive(Debug, Default)]
198pub struct LocalSlots {
199    names: rustc_hash::FxHashMap<Arc<str>, usize>,
200    slot_names: Vec<Arc<str>>,
201}
202
203impl LocalSlots {
204    /// Build a slot layout from every binding site in a program.
205    #[must_use]
206    pub fn for_program(program: &Program) -> Self {
207        Self::for_nodes(program.entry())
208    }
209
210    /// Build a slot layout from a node slice.
211    #[must_use]
212    pub fn for_nodes(nodes: &[Node]) -> Self {
213        let mut slots = Self::default();
214        for node in nodes {
215            match visit_node_preorder(&mut slots, node) {
216                Continue(()) => {}
217                ControlFlow::Break(never) => match never {},
218            }
219        }
220        slots
221    }
222
223    fn slot(&self, name: &str) -> Option<usize> {
224        self.names.get(name).copied()
225    }
226
227    fn len(&self) -> usize {
228        self.slot_names.len()
229    }
230
231    fn intern(&mut self, name: &str) {
232        if self.names.contains_key(name) {
233            return;
234        }
235        let slot = self.slot_names.len();
236        let name: Arc<str> = Arc::from(name);
237        self.slot_names.push(Arc::clone(&name));
238        self.names.insert(name, slot);
239    }
240}
241
242impl ExprVisitor for LocalSlots {
243    type Break = Infallible;
244}
245
246impl NodeVisitor for LocalSlots {
247    type Break = Infallible;
248
249    fn visit_let(
250        &mut self,
251        _: &Node,
252        name: &vyre::ir::Ident,
253        value: &Expr,
254    ) -> ControlFlow<Self::Break> {
255        self.intern(name);
256        visit_preorder(self, value)
257    }
258
259    fn visit_assign(
260        &mut self,
261        _: &Node,
262        _: &vyre::ir::Ident,
263        value: &Expr,
264    ) -> ControlFlow<Self::Break> {
265        visit_preorder(self, value)
266    }
267
268    fn visit_store(
269        &mut self,
270        _: &Node,
271        _: &vyre::ir::Ident,
272        index: &Expr,
273        value: &Expr,
274    ) -> ControlFlow<Self::Break> {
275        visit_preorder(self, index)?;
276        visit_preorder(self, value)
277    }
278
279    fn visit_if(
280        &mut self,
281        _: &Node,
282        cond: &Expr,
283        _: &[Node],
284        _: &[Node],
285    ) -> ControlFlow<Self::Break> {
286        visit_preorder(self, cond)
287    }
288
289    fn visit_loop(
290        &mut self,
291        _: &Node,
292        var: &vyre::ir::Ident,
293        from: &Expr,
294        to: &Expr,
295        _: &[Node],
296    ) -> ControlFlow<Self::Break> {
297        self.intern(var);
298        visit_preorder(self, from)?;
299        visit_preorder(self, to)
300    }
301
302    fn visit_indirect_dispatch(
303        &mut self,
304        _: &Node,
305        _: &vyre::ir::Ident,
306        _: u64,
307    ) -> ControlFlow<Self::Break> {
308        Continue(())
309    }
310
311    fn visit_async_load(
312        &mut self,
313        _: &Node,
314        _: &vyre::ir::Ident,
315        _: &vyre::ir::Ident,
316        offset: &Expr,
317        size: &Expr,
318        _: &vyre::ir::Ident,
319    ) -> ControlFlow<Self::Break> {
320        visit_preorder(self, offset)?;
321        visit_preorder(self, size)
322    }
323
324    fn visit_async_store(
325        &mut self,
326        _: &Node,
327        _: &vyre::ir::Ident,
328        _: &vyre::ir::Ident,
329        offset: &Expr,
330        size: &Expr,
331        _: &vyre::ir::Ident,
332    ) -> ControlFlow<Self::Break> {
333        visit_preorder(self, offset)?;
334        visit_preorder(self, size)
335    }
336
337    fn visit_async_wait(&mut self, _: &Node, _: &vyre::ir::Ident) -> ControlFlow<Self::Break> {
338        Continue(())
339    }
340
341    fn visit_trap(
342        &mut self,
343        _: &Node,
344        address: &Expr,
345        _: &vyre::ir::Ident,
346    ) -> ControlFlow<Self::Break> {
347        visit_preorder(self, address)
348    }
349
350    fn visit_resume(&mut self, _: &Node, _: &vyre::ir::Ident) -> ControlFlow<Self::Break> {
351        Continue(())
352    }
353
354    fn visit_return(&mut self, _: &Node) -> ControlFlow<Self::Break> {
355        Continue(())
356    }
357
358    fn visit_barrier(&mut self, _: &Node) -> ControlFlow<Self::Break> {
359        Continue(())
360    }
361
362    fn visit_block(&mut self, _: &Node, _: &[Node]) -> ControlFlow<Self::Break> {
363        Continue(())
364    }
365
366    fn visit_region(
367        &mut self,
368        _: &Node,
369        _: &vyre::ir::Ident,
370        _: &Option<GeneratorRef>,
371        _: &[Node],
372    ) -> ControlFlow<Self::Break> {
373        Continue(())
374    }
375
376    fn visit_opaque_node(
377        &mut self,
378        _: &Node,
379        _: &dyn vyre::ir::NodeExtension,
380    ) -> ControlFlow<Self::Break> {
381        Continue(())
382    }
383}
384
385/// One paused or running invocation.
386pub struct Invocation<'a> {
387    /// Builtin ids for this invocation.
388    pub ids: InvocationIds,
389    slots: Arc<LocalSlots>,
390    locals: Vec<Option<Value>>,
391    immutable: Vec<bool>,
392    scopes: Vec<Vec<usize>>,
393    frames: Vec<Frame<'a>>,
394    /// True after `return`.
395    pub returned: bool,
396    /// True when paused at a barrier.
397    pub waiting_at_barrier: bool,
398    /// Uniform-if observations for branches that contain a barrier.
399    pub uniform_checks: Vec<(usize, bool)>,
400    /// Async transfers started by `AsyncLoad`/`AsyncStore` and pending
401    /// observation by `AsyncWait`.
402    pub(crate) pending_async: FxHashMap<Arc<str>, AsyncTransfer>,
403    pub(crate) op_cache: FxHashMap<*const Expr, ResolvedCall>,
404}
405
406#[derive(Debug, Clone, Copy)]
407pub(crate) struct ResolvedCall {
408    pub(crate) def: &'static OpDef,
409}
410
411/// Interpreter continuation stack.
412#[non_exhaustive]
413pub enum Frame<'a> {
414    /// Sequence of nodes.
415    Nodes {
416        /// Nodes being executed.
417        nodes: &'a [Node],
418        /// Next node index.
419        index: usize,
420        /// Whether completion pops a lexical scope.
421        scoped: bool,
422    },
423    /// Bounded `u32` loop.
424    Loop {
425        /// Loop variable name.
426        var: &'a str,
427        /// Next induction value.
428        next: u32,
429        /// Exclusive upper bound.
430        to: u32,
431        /// Loop body.
432        body: &'a [Node],
433    },
434}
435
436impl<'a> Invocation<'a> {
437    /// Create an invocation at the start of the entry point.
438    pub fn new(ids: InvocationIds, entry: &'a [Node]) -> Self {
439        Self::with_slots(ids, entry, Arc::new(LocalSlots::for_nodes(entry)))
440    }
441
442    pub(crate) fn with_slots(
443        ids: InvocationIds,
444        entry: &'a [Node],
445        slots: Arc<LocalSlots>,
446    ) -> Self {
447        let slot_count = slots.len();
448        Self {
449            ids,
450            slots,
451
452            locals: vec![None; slot_count],
453            immutable: vec![false; slot_count],
454            scopes: vec![Vec::new()],
455            frames: vec![Frame::Nodes {
456                nodes: entry,
457                index: 0,
458                scoped: false,
459            }],
460            returned: false,
461            waiting_at_barrier: false,
462            uniform_checks: Vec::new(),
463            pending_async: FxHashMap::default(),
464            op_cache: FxHashMap::default(),
465        }
466    }
467
468    /// Return true when no further execution can occur.
469    pub fn done(&self) -> bool {
470        self.returned || self.frames.is_empty()
471    }
472
473    /// Push a lexical scope.
474    ///
475    ///
476    /// ```rust,no_run
477    /// use vyre_reference::workgroup::{Invocation, InvocationIds};
478    /// let mut invocation = Invocation::new(InvocationIds::ZERO, &[]);
479    /// invocation.push_scope();
480    /// ```
481    pub fn push_scope(&mut self) {
482        self.scopes.push(Vec::new());
483    }
484
485    /// Pop a lexical scope and remove bindings declared in it.
486    ///
487    ///
488    /// ```rust,no_run
489    /// use vyre_reference::workgroup::{Invocation, InvocationIds};
490    /// let mut invocation = Invocation::new(InvocationIds::ZERO, &[]);
491    /// invocation.pop_scope();
492    /// ```
493    pub fn pop_scope(&mut self) {
494        if let Some(names) = self.scopes.pop() {
495            for slot in names {
496                self.locals[slot] = None;
497                self.immutable[slot] = false;
498            }
499        }
500    }
501
502    pub(crate) fn begin_async(&mut self, tag: &str, transfer: AsyncTransfer) -> Result<(), Error> {
503        let tag: Arc<str> = Arc::from(tag);
504        if self.pending_async.insert(tag.clone(), transfer).is_some() {
505            return Err(Error::interp(format!(
506                "async tag `{}` was started more than once before a matching wait. \
507                 Fix: reuse the tag only after AsyncWait completes.",
508                tag
509            )));
510        }
511        Ok(())
512    }
513
514    pub(crate) fn finish_async(&mut self, tag: &str) -> Result<AsyncTransfer, Error> {
515        self.pending_async.remove(tag).ok_or_else(|| Error::interp(format!(
516            "async wait for tag `{tag}` has no matching async load. Fix: emit AsyncLoad before AsyncWait."
517        )))
518    }
519
520    /// Look up an active local by name.
521    pub fn local(&self, name: &str) -> Option<&Value> {
522        self.slots
523            .slot(name)
524            .and_then(|slot| self.locals.get(slot))
525            .and_then(Option::as_ref)
526    }
527
528    /// Bind a mutable local.
529    ///
530    ///
531    /// ```rust,no_run
532    /// use vyre_reference::{value::Value, workgroup::{Invocation, InvocationIds}};
533    /// fn main() -> Result<(), vyre_foundation::Error> {
534    ///     let mut invocation = Invocation::new(InvocationIds::ZERO, &[]);
535    ///     invocation.bind("example", Value::U32(1))?;
536    ///     Ok(())
537    /// }
538    /// ```
539    pub fn bind(&mut self, name: &str, value: Value) -> Result<(), vyre::Error> {
540        let slot = self.slots.slot(name).ok_or_else(|| {
541            Error::interp(format!(
542                "local binding `{name}` has no preassigned slot. Fix: rebuild the local slot layout from the full Program before interpretation."
543            ))
544        })?;
545        if self.locals[slot].is_some() {
546            return Err(Error::interp(format!(
547                "duplicate local binding `{name}`. Fix: choose a unique local name; shadowing is not allowed."
548            )));
549        }
550        self.locals[slot] = Some(value);
551        if let Some(scope) = self.scopes.last_mut() {
552            scope.push(slot);
553        }
554        Ok(())
555    }
556
557    /// Bind an immutable loop variable.
558    ///
559    ///
560    /// ```rust,no_run
561    /// use vyre_reference::{value::Value, workgroup::{Invocation, InvocationIds}};
562    /// fn main() -> Result<(), vyre_foundation::Error> {
563    ///     let mut invocation = Invocation::new(InvocationIds::ZERO, &[]);
564    ///     invocation.bind_loop_var("example", Value::U32(1))?;
565    ///     Ok(())
566    /// }
567    /// ```
568    pub fn bind_loop_var(&mut self, name: &str, value: Value) -> Result<(), vyre::Error> {
569        self.bind(name, value)?;
570        let slot = self.slots.slot(name).ok_or_else(|| {
571            Error::interp(format!(
572                "local binding `{name}` disappeared after bind. Fix: keep local slot layout immutable during interpretation."
573            ))
574        })?;
575        self.immutable[slot] = true;
576        Ok(())
577    }
578
579    /// Assign an existing mutable local.
580    pub fn assign(&mut self, name: &str, value: Value) -> Result<(), vyre::Error> {
581        let slot = self.slots.slot(name).ok_or_else(|| {
582            Error::interp(format!(
583                "assignment to undeclared variable `{name}`. Fix: add a Let before assigning it."
584            ))
585        })?;
586        if self.immutable[slot] {
587            return Err(Error::interp(format!(
588                "assignment to loop variable `{name}`. Fix: loop variables are immutable."
589            )));
590        }
591        let Some(local) = self.locals.get_mut(slot).and_then(Option::as_mut) else {
592            return Err(Error::interp(format!(
593                "assignment to undeclared variable `{name}`. Fix: add a Let before assigning it."
594            )));
595        };
596        *local = value;
597        Ok(())
598    }
599
600    pub(crate) fn frames_mut(&mut self) -> &mut Vec<Frame<'a>> {
601        &mut self.frames
602    }
603}
604
/// Byte-copy transfer deferred by the workgroup reference scheduler.
pub(crate) enum AsyncTransfer {
    /// Write `payload` into the buffer named `destination`, beginning at byte
    /// offset `start`.
    Copy {
        destination: Arc<str>,
        start: usize,
        payload: Vec<u8>,
    },
}
614
/// Build every invocation of one workgroup. `x` varies fastest, then `y`,
/// then `z`.
///
/// Each global id component is `workgroup * workgroup_size + local`.
///
/// # Errors
///
/// Fails if the invocation count overflows `u32` or host `usize`, or if any
/// global id component overflows `u32`.
#[cfg(test)]
#[allow(dead_code)]
pub(crate) fn create_invocations(
    program: &Program,
    workgroup: [u32; 3],
    slots: Arc<LocalSlots>,
) -> Result<Vec<Invocation<'_>>, vyre::Error> {
    /// One axis of the global invocation id, reporting `u32` overflow as an error.
    fn global_component(group: u32, size: u32, local: u32) -> Result<u32, vyre::Error> {
        match group.checked_mul(size).and_then(|base| base.checked_add(local)) {
            Some(id) => Ok(id),
            None => Err(Error::interp(
                "workgroup * dispatch dimensions overflow u32 global id. Fix: reduce workgroup id or workgroup size so each global_invocation_id component fits in u32.",
            )),
        }
    }

    let [sx, sy, sz] = program.workgroup_size();
    let Some(count) = sx.checked_mul(sy).and_then(|xy| xy.checked_mul(sz)) else {
        return Err(Error::interp(
            "workgroup invocation count overflows u32. Fix: reduce workgroup dimensions before reference execution.",
        ));
    };
    let Ok(capacity) = usize::try_from(count) else {
        return Err(Error::interp(
            "workgroup invocation count exceeds host usize. Fix: reduce workgroup dimensions before reference execution.",
        ));
    };

    let mut invocations = Vec::with_capacity(capacity);
    for z in 0..sz {
        for y in 0..sy {
            for x in 0..sx {
                let global = [
                    global_component(workgroup[0], sx, x)?,
                    global_component(workgroup[1], sy, y)?,
                    global_component(workgroup[2], sz, z)?,
                ];
                let ids = InvocationIds {
                    global,
                    workgroup,
                    local: [x, y, z],
                };
                invocations.push(Invocation::with_slots(
                    ids,
                    program.entry(),
                    Arc::clone(&slots),
                ));
            }
        }
    }
    Ok(invocations)
}
667
/// Allocate a zero-filled buffer for every `Workgroup`-access buffer that
/// `program` declares.
///
/// # Errors
///
/// Fails if a declared element count does not fit in host `usize`, if a
/// buffer's byte size or the running total overflows `usize`, or if the
/// total exceeds [`MAX_WORKGROUP_BYTES`]. The budget check runs before each
/// buffer is allocated.
#[cfg(test)]
#[allow(dead_code)]
pub(crate) fn workgroup_memory(program: &Program) -> Result<BufferMap, vyre::Error> {
    let mut workgroup = BufferMap::new();
    let mut allocated = 0usize;
    for decl in program
        .buffers()
        .iter()
        .filter(|decl| decl.access() == BufferAccess::Workgroup)
    {
        let element_size = decl.element().min_bytes();
        // Checked conversion: an `as` cast would silently truncate a count
        // that does not fit the host `usize` and allocate a too-small buffer.
        let count = usize::try_from(decl.count()).map_err(|_| {
            Error::interp(format!(
                "workgroup buffer `{}` element count does not fit in host usize. Fix: reduce count.",
                decl.name()
            ))
        })?;
        let len = count.checked_mul(element_size).ok_or_else(|| {
            Error::interp(format!(
                "workgroup buffer `{}` byte size overflows usize. Fix: reduce count or element size.",
                decl.name()
            ))
        })?;
        allocated = allocated.checked_add(len).ok_or_else(|| {
            Error::interp(
                "total workgroup memory byte size overflows usize. Fix: reduce workgroup buffer declarations.",
            )
        })?;
        if allocated > MAX_WORKGROUP_BYTES {
            return Err(Error::interp(format!(
                "workgroup memory requires {allocated} bytes, exceeding the {MAX_WORKGROUP_BYTES}-byte reference budget. Fix: reduce workgroup buffer counts."
            )));
        }
        workgroup.insert(decl.name(), Buffer::new(vec![0; len], decl.element()));
    }
    Ok(workgroup)
}