vyre_reference/execution/
sequential.rs1use vyre::ir::Program;
15
16use crate::workgroup::{InvocationIds, MAX_WORKGROUP_BYTES};
17
/// Workgroup extents used to enumerate a workgroup's invocations one by one.
///
/// The type is a plain `Copy` value. Equality and hashing are derived so that
/// descriptors can be compared with `==` / `assert_eq!` and used as map or set
/// keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SequentialWorkgroup {
    /// Number of invocations along the x, y and z axes, in that order.
    pub size: [u32; 3],
}
30
31impl SequentialWorkgroup {
32 #[must_use]
34 pub fn for_program(program: &Program) -> Self {
35 Self {
36 size: program.workgroup_size(),
37 }
38 }
39
40 #[must_use]
42 pub fn invocation_count(&self) -> u32 {
43 self.size[0]
44 .saturating_mul(self.size[1])
45 .saturating_mul(self.size[2])
46 }
47
48 pub fn invocations(&self, workgroup_id: [u32; 3]) -> impl Iterator<Item = InvocationIds> {
50 let [sx, sy, sz] = self.size;
51 let wg = workgroup_id;
52 (0..sz).flat_map(move |lz| {
53 (0..sy).flat_map(move |ly| {
54 (0..sx).map(move |lx| InvocationIds {
55 global: [
56 wg[0].saturating_mul(sx).saturating_add(lx),
57 wg[1].saturating_mul(sy).saturating_add(ly),
58 wg[2].saturating_mul(sz).saturating_add(lz),
59 ],
60 workgroup: wg,
61 local: [lx, ly, lz],
62 })
63 })
64 })
65 }
66}
67
68pub const MAX_SHARED_BYTES: usize = MAX_WORKGROUP_BYTES;
71
#[cfg(test)]
mod tests {
    use super::*;
    use vyre::ir::{BufferDecl, DataType, Node, Program};

    /// Builds a small program with the given workgroup size: one `u32` output
    /// buffer and a single `let` binding of the x global id.
    fn trivial_program(size: [u32; 3]) -> Program {
        let buffers = vec![BufferDecl::output("out", 0, DataType::U32)];
        let body = vec![Node::let_bind("idx", vyre::ir::Expr::gid_x())];
        Program::wrapped(buffers, size, body)
    }

    /// A 4x2x1 workgroup contains eight invocations.
    #[test]
    fn invocation_count_is_product() {
        let program = trivial_program([4, 2, 1]);
        let workgroup = SequentialWorkgroup::for_program(&program);
        assert_eq!(workgroup.invocation_count(), 8);
    }

    /// Local ids are produced with x varying fastest, then y.
    #[test]
    fn invocation_order_is_canonical() {
        let workgroup = SequentialWorkgroup { size: [2, 2, 1] };
        let locals: Vec<_> = workgroup
            .invocations([0, 0, 0])
            .map(|ids| ids.local)
            .collect();
        assert_eq!(locals, vec![[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0]]);
    }

    /// Global ids are shifted by `workgroup_id * size` on each axis.
    #[test]
    fn invocation_globals_offset_by_workgroup() {
        let workgroup = SequentialWorkgroup { size: [2, 1, 1] };
        let globals: Vec<_> = workgroup
            .invocations([3, 0, 0])
            .map(|ids| ids.global)
            .collect();
        assert_eq!(globals, vec![[6, 0, 0], [7, 0, 0]]);
    }
}