vyre_reference/workgroup/
mod.rs1use std::collections::{HashMap, HashSet};
9
10use vyre::ir::{BufferAccess, Node, Program};
11
12use vyre::Error;
13
14use crate::{oob::Buffer, value::Value};
15
/// Upper bound, in bytes, on the total size of all workgroup buffers a single
/// program may declare in the reference interpreter (64 MiB).
pub const MAX_WORKGROUP_BYTES: usize = 64 << 20;
18
/// The identifiers of one invocation, each stored as an `[x, y, z]` triple.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct InvocationIds {
    /// Global invocation id. `create_invocations` sets each component to
    /// `workgroup * workgroup_size + local`.
    pub global: [u32; 3],
    /// Id of the workgroup this invocation belongs to.
    pub workgroup: [u32; 3],
    /// Id of this invocation within its workgroup.
    pub local: [u32; 3],
}

impl InvocationIds {
    /// Ids that are zero on every axis of every id space.
    pub const ZERO: Self = {
        const ORIGIN: [u32; 3] = [0; 3];
        Self {
            global: ORIGIN,
            workgroup: ORIGIN,
            local: ORIGIN,
        }
    };
}
38
/// Buffer storage used by the reference interpreter, split into two maps keyed
/// by buffer name.
#[derive(Debug)]
pub struct Memory {
    /// Non-workgroup buffers keyed by name.
    /// NOTE(review): presumably storage-class buffers; confirm against the code
    /// that populates this map.
    pub(crate) storage: HashMap<String, Buffer>,
    /// Workgroup-shared buffers keyed by name.
    /// NOTE(review): presumably built by `workgroup_memory`, which returns a map
    /// of this type from the `BufferAccess::Workgroup` declarations; confirm.
    pub(crate) workgroup: HashMap<String, Buffer>,
}
45
/// Execution state of a single invocation within a workgroup: its ids, its
/// local variables and scopes, and its stack of pending control-flow frames.
pub struct Invocation<'a> {
    /// Global, workgroup, and local ids of this invocation.
    pub ids: InvocationIds,
    /// Currently bound local variables, keyed by name.
    pub(crate) locals: HashMap<String, Value>,
    /// Names bound through `bind_loop_var`. `assign` rejects writes to these.
    immutable: HashSet<String>,
    /// Stack of lexical scopes. Each entry records the names bound in that scope
    /// so `pop_scope` can unbind them.
    scopes: Vec<Vec<String>>,
    /// Stack of pending control-flow frames. `done` reports true once this is
    /// empty or `returned` is set.
    frames: Vec<Frame<'a>>,
    /// Set when the invocation has returned; `done` then reports true.
    pub returned: bool,
    /// NOTE(review): presumably set while this invocation is parked at a barrier
    /// waiting for the rest of the workgroup; confirm against the scheduler.
    pub waiting_at_barrier: bool,
    /// NOTE(review): the meaning of each `(usize, bool)` pair is not visible
    /// here. It may be (check site, observed condition) for uniformity
    /// checking; confirm.
    pub uniform_checks: Vec<(usize, bool)>,
}
61
/// One pending unit of control flow on an invocation's frame stack.
#[non_exhaustive]
pub enum Frame<'a> {
    /// A sequence of IR nodes being executed in order.
    Nodes {
        /// The nodes of this sequence.
        nodes: &'a [Node],
        /// Position within `nodes`. `Invocation::new` starts it at 0.
        /// NOTE(review): presumably the index of the next node to run; confirm.
        index: usize,
        /// NOTE(review): presumably records whether a scope was pushed for this
        /// frame and must be popped when it finishes. `Invocation::new` sets it
        /// to false for the entry frame. Confirm against the stepper.
        scoped: bool,
    },
    /// A counted loop over `body`.
    Loop {
        /// Name of the loop variable.
        var: &'a str,
        /// NOTE(review): presumably the loop variable's value for the next
        /// iteration; confirm.
        next: u32,
        /// Loop bound.
        /// NOTE(review): whether it is inclusive or exclusive is not visible
        /// here; confirm.
        to: u32,
        /// Nodes executed for each iteration.
        body: &'a [Node],
    },
}
86
87impl<'a> Invocation<'a> {
88 pub fn new(ids: InvocationIds, entry: &'a [Node]) -> Self {
90 Self {
91 ids,
92 locals: HashMap::new(),
93 immutable: HashSet::new(),
94 scopes: vec![Vec::new()],
95 frames: vec![Frame::Nodes {
96 nodes: entry,
97 index: 0,
98 scoped: false,
99 }],
100 returned: false,
101 waiting_at_barrier: false,
102 uniform_checks: Vec::new(),
103 }
104 }
105
106 pub fn done(&self) -> bool {
108 self.returned || self.frames.is_empty()
109 }
110
111 pub fn push_scope(&mut self) {
120 self.scopes.push(Vec::new());
121 }
122
123 pub fn pop_scope(&mut self) {
132 if let Some(names) = self.scopes.pop() {
133 for name in names {
134 self.locals.remove(&name);
135 self.immutable.remove(&name);
136 }
137 }
138 }
139
140 pub fn bind(&mut self, name: &str, value: Value) -> Result<(), vyre::Error> {
149 if self.locals.contains_key(name) {
150 return Err(Error::interp(format!(
151 "duplicate local binding `{name}`. Fix: choose a unique local name; shadowing is not allowed."
152 )));
153 }
154 self.locals.insert(name.to_string(), value);
155 if let Some(scope) = self.scopes.last_mut() {
156 scope.push(name.to_string());
157 }
158 Ok(())
159 }
160
161 pub fn bind_loop_var(&mut self, name: &str, value: Value) -> Result<(), vyre::Error> {
170 self.bind(name, value)?;
171 self.immutable.insert(name.to_string());
172 Ok(())
173 }
174
175 pub fn assign(&mut self, name: &str, value: Value) -> Result<(), vyre::Error> {
177 if self.immutable.contains(name) {
178 return Err(Error::interp(format!(
179 "assignment to loop variable `{name}`. Fix: loop variables are immutable."
180 )));
181 }
182 let Some(slot) = self.locals.get_mut(name) else {
183 return Err(Error::interp(format!(
184 "assignment to undeclared variable `{name}`. Fix: add a Let before assigning it."
185 )));
186 };
187 *slot = value;
188 Ok(())
189 }
190
191 pub(crate) fn frames_mut(&mut self) -> &mut Vec<Frame<'a>> {
192 &mut self.frames
193 }
194}
195
196pub(crate) fn create_invocations(
197 program: &Program,
198 workgroup: [u32; 3],
199) -> Result<Vec<Invocation<'_>>, vyre::Error> {
200 let global_dim = |wgid: u32, size: u32, local: u32| {
201 wgid
202 .checked_mul(size)
203 .and_then(|base| base.checked_add(local))
204 .ok_or_else(|| Error::interp(
205 "workgroup * dispatch dimensions overflow u32 global id. Fix: reduce workgroup id or workgroup size so each global_invocation_id component fits in u32.",
206 ))
207 };
208 let [sx, sy, sz] = program.workgroup_size();
209 let mut invocations = Vec::with_capacity((sx * sy * sz) as usize);
210 for z in 0..sz {
211 for y in 0..sy {
212 for x in 0..sx {
213 let local = [x, y, z];
214 let global = [
215 global_dim(workgroup[0], sx, x)?,
216 global_dim(workgroup[1], sy, y)?,
217 global_dim(workgroup[2], sz, z)?,
218 ];
219 invocations.push(Invocation::new(
220 InvocationIds {
221 global,
222 workgroup,
223 local,
224 },
225 program.entry(),
226 ));
227 }
228 }
229 }
230 Ok(invocations)
231}
232
233pub(crate) fn workgroup_memory(program: &Program) -> Result<HashMap<String, Buffer>, vyre::Error> {
234 let mut workgroup = HashMap::new();
235 let mut allocated = 0usize;
236 for decl in program
237 .buffers()
238 .iter()
239 .filter(|decl| decl.access() == BufferAccess::Workgroup)
240 {
241 let element_size = decl.element().min_bytes();
242 let len = (decl.count() as usize)
243 .checked_mul(element_size)
244 .ok_or_else(|| Error::interp(format!(
245 "workgroup buffer `{}` byte size overflows usize. Fix: reduce count or element size.",
246 decl.name()
247 )))?;
248 allocated = allocated
249 .checked_add(len)
250 .ok_or_else(|| Error::interp(
251 "total workgroup memory byte size overflows usize. Fix: reduce workgroup buffer declarations.",
252 ))?;
253 if allocated > MAX_WORKGROUP_BYTES {
254 return Err(Error::interp(format!(
255 "workgroup memory requires {allocated} bytes, exceeding the {MAX_WORKGROUP_BYTES}-byte reference budget. Fix: reduce workgroup buffer counts."
256 )));
257 }
258 workgroup.insert(
259 decl.name().to_string(),
260 Buffer {
261 bytes: vec![0; len],
262 element: decl.element(),
263 },
264 );
265 }
266 Ok(workgroup)
267}