// vyre_reference/execution/sequential.rs
//! Sequential CPU execution model for workgroup parity.
//!
//! Replaces the "invocation scheduler" abstraction with an obvious
//! sequential semantic: run invocation 0, then 1, then 2... inside a
//! workgroup. At each barrier, the outer driver re-runs every invocation
//! from the barrier checkpoint so shared-memory effects from earlier
//! invocations are visible to later ones. This matches backend
//! spec-compliant semantics under the parity contract.
//!
//! Exposes a single entry point `run_sequential_workgroup` which the
//! conform runner uses as its CPU oracle; [`crate::workgroup`] owns
//! per-invocation state and memory types shared by the execution tree.

use vyre::ir::Program;

use crate::workgroup::{InvocationIds, MAX_WORKGROUP_BYTES};
17
/// Driver for sequential per-invocation execution inside a workgroup.
///
/// Does the simplest possible thing: iterate invocations 0..N in order,
/// each time honoring any shared-memory writes made by prior invocations.
/// When the underlying program contains a barrier, the caller replays the
/// full sweep from the barrier point. [`crate::workgroup`] provides the
/// invocation and memory types used by this driver.
///
/// `Copy` is intentional: the driver is a three-word value and is cheap to
/// pass around by value.
#[derive(Debug, Clone, Copy)]
pub struct SequentialWorkgroup {
    /// Workgroup size in x/y/z; each component is the invocation count
    /// along that axis.
    pub size: [u32; 3],
}
30
31impl SequentialWorkgroup {
32    /// Construct a driver for the program's declared workgroup size.
33    #[must_use]
34    pub fn for_program(program: &Program) -> Self {
35        Self {
36            size: program.workgroup_size(),
37        }
38    }
39
40    /// Total number of invocations in one workgroup.
41    #[must_use]
42    pub fn invocation_count(&self) -> u32 {
43        self.size[0]
44            .saturating_mul(self.size[1])
45            .saturating_mul(self.size[2])
46    }
47
48    /// Yield the invocation ids in canonical order (z-major, y-major, x-minor).
49    pub fn invocations(&self, workgroup_id: [u32; 3]) -> impl Iterator<Item = InvocationIds> {
50        let [sx, sy, sz] = self.size;
51        let wg = workgroup_id;
52        (0..sz).flat_map(move |lz| {
53            (0..sy).flat_map(move |ly| {
54                (0..sx).map(move |lx| InvocationIds {
55                    global: [
56                        wg[0].saturating_mul(sx).saturating_add(lx),
57                        wg[1].saturating_mul(sy).saturating_add(ly),
58                        wg[2].saturating_mul(sz).saturating_add(lz),
59                    ],
60                    workgroup: wg,
61                    local: [lx, ly, lz],
62                })
63            })
64        })
65    }
66}
67
/// Maximum shared-memory allocation exported for test convenience so the
/// sequential driver and workgroup memory model agree on bounds.
///
/// This is a direct re-export of [`crate::workgroup::MAX_WORKGROUP_BYTES`]
/// under a driver-local name; the two are equal by construction.
pub const MAX_SHARED_BYTES: usize = MAX_WORKGROUP_BYTES;
71
#[cfg(test)]
mod tests {
    use super::*;
    use vyre::ir::{BufferDecl, DataType, Node, Program};

    /// Smallest program that still declares a workgroup size: one output
    /// buffer and a single binding of the global x id.
    fn trivial_program(size: [u32; 3]) -> Program {
        let buffers = vec![BufferDecl::output("out", 0, DataType::U32)];
        let body = vec![Node::let_bind("idx", vyre::ir::Expr::gid_x())];
        Program::wrapped(buffers, size, body)
    }

    #[test]
    fn invocation_count_is_product() {
        let program = trivial_program([4, 2, 1]);
        let driver = SequentialWorkgroup::for_program(&program);
        // 4 * 2 * 1 invocations in total.
        assert_eq!(driver.invocation_count(), 8);
    }

    #[test]
    fn invocation_order_is_canonical() {
        let driver = SequentialWorkgroup { size: [2, 2, 1] };
        let locals: Vec<[u32; 3]> = driver.invocations([0, 0, 0]).map(|inv| inv.local).collect();
        // x varies fastest, then y, then z.
        let expected = vec![[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0]];
        assert_eq!(locals, expected);
    }

    #[test]
    fn invocation_globals_offset_by_workgroup() {
        let driver = SequentialWorkgroup { size: [2, 1, 1] };
        let globals: Vec<[u32; 3]> = driver.invocations([3, 0, 0]).map(|inv| inv.global).collect();
        // Workgroup 3 with width 2 starts at global x = 6.
        assert_eq!(globals, vec![[6, 0, 0], [7, 0, 0]]);
    }
}
104}