1use vyre::ir::{Expr, Node, Program};
10
11use crate::{
12 execution::expr as eval_expr,
13 execution::node_tree::{contains_barrier, node_id},
14 oob,
15 workgroup::{AsyncTransfer, Frame, Invocation, Memory},
16};
17use vyre::Error;
18
19pub fn step<'a>(
26 invocation: &mut Invocation<'a>,
27 memory: &mut Memory,
28 program: &'a Program,
29) -> Result<(), vyre::Error> {
30 if invocation.done() || invocation.waiting_at_barrier {
31 return Ok(());
32 }
33
34 loop {
35 let Some(frame) = invocation.frames_mut().pop() else {
36 return Ok(());
37 };
38 match frame {
39 Frame::Nodes {
40 nodes,
41 index,
42 scoped,
43 } => {
44 if step_nodes_frame(invocation, memory, program, nodes, index, scoped)? {
45 return Ok(());
46 }
47 }
48 Frame::Loop {
49 var,
50 next,
51 to,
52 body,
53 } => step_loop_frame(invocation, var, next, to, body)?,
54 }
55 }
56}
57
58fn step_nodes_frame<'a>(
59 invocation: &mut Invocation<'a>,
60 memory: &mut Memory,
61 program: &'a Program,
62 nodes: &'a [Node],
63 index: usize,
64 scoped: bool,
65) -> Result<bool, vyre::Error> {
66 if index >= nodes.len() {
67 if scoped {
68 invocation.pop_scope();
69 }
70 return Ok(false);
71 }
72
73 invocation.frames_mut().push(Frame::Nodes {
74 nodes,
75 index: index + 1,
76 scoped,
77 });
78 execute_node(&nodes[index], invocation, memory, program)?;
79 Ok(true)
80}
81
82fn step_loop_frame<'a>(
83 invocation: &mut Invocation<'a>,
84 var: &'a str,
85 next: u32,
86 to: u32,
87 body: &'a [Node],
88) -> Result<(), vyre::Error> {
89 if next >= to {
90 return Ok(());
91 }
92 invocation.frames_mut().push(Frame::Loop {
93 var,
94 next: next.wrapping_add(1),
95 to,
96 body,
97 });
98 invocation.push_scope();
99 invocation.bind_loop_var(var, crate::value::Value::U32(next))?;
100 invocation.frames_mut().push(Frame::Nodes {
101 nodes: body,
102 index: 0,
103 scoped: true,
104 });
105 Ok(())
106}
107
108fn execute_node<'a>(
109 node: &'a Node,
110 invocation: &mut Invocation<'a>,
111 memory: &mut Memory,
112 program: &'a Program,
113) -> Result<(), vyre::Error> {
114 match node {
115 Node::Let { name, value } => eval_let(name, value, invocation, memory, program),
116 Node::Assign { name, value } => eval_assign(name, value, invocation, memory, program),
117 Node::Store {
118 buffer,
119 index,
120 value,
121 } => eval_store(buffer, index, value, invocation, memory, program),
122 Node::If {
123 cond,
124 then,
125 otherwise,
126 } => eval_if(cond, then, otherwise, node, invocation, memory, program),
127 Node::Loop {
128 var,
129 from,
130 to,
131 body,
132 } => eval_loop(var, from, to, body, invocation, memory, program),
133 Node::Return => eval_return(invocation),
134 Node::Block(nodes) => eval_block(nodes, invocation),
135 Node::Barrier { .. } => eval_barrier(invocation),
136 Node::IndirectDispatch {
137 count_buffer,
138 count_offset,
139 } => eval_indirect_dispatch(count_buffer, *count_offset, memory, program),
140 Node::AsyncLoad {
141 source,
142 destination,
143 offset,
144 size,
145 tag,
146 } => eval_async_load(
147 AsyncLoadEval {
148 source,
149 destination,
150 offset,
151 size,
152 tag,
153 },
154 invocation,
155 memory,
156 program,
157 ),
158 Node::AsyncStore {
159 source,
160 destination,
161 offset,
162 size,
163 tag,
164 } => eval_async_store(
165 AsyncStoreEval {
166 source,
167 destination,
168 offset,
169 size,
170 tag,
171 },
172 invocation,
173 memory,
174 program,
175 ),
176 Node::AsyncWait { tag } => eval_async_wait(tag, invocation, memory, program),
177 Node::Trap { address, tag } => {
178 let address = eval_expr::eval(address, invocation, memory, program)?
179 .try_as_u32()
180 .ok_or_else(|| {
181 Error::interp(format!(
182 "reference trap `{tag}` address is not a u32. Fix: pass a scalar u32 trap address."
183 ))
184 })?;
185 Err(vyre::Error::interp(format!(
186 "reference dispatch trapped: address={address}, tag=`{tag}`. Fix: handle the trap condition or route this Program through a backend/runtime with replay support."
187 )))
188 }
189 Node::Resume { tag } => Err(vyre::Error::interp(format!(
190 "reference dispatch reached Resume `{tag}` without a replay runtime. Fix: lower Resume through a runtime-owned replay path before reference execution."
191 ))),
192 Node::AllReduce { buffer, group, .. } => Err(vyre::Error::interp(format!(
193 "reference dispatch reached AllReduce on buffer `{buffer}` for group {}. Fix: run this Program on a distributed backend with collective support or lower the single-rank collective before reference execution.",
194 group.as_u32()
195 ))),
196 Node::AllGather {
197 input,
198 output,
199 group,
200 } => Err(vyre::Error::interp(format!(
201 "reference dispatch reached AllGather `{input}` -> `{output}` for group {}. Fix: run this Program on a distributed backend with collective support or lower the single-rank collective before reference execution.",
202 group.as_u32()
203 ))),
204 Node::ReduceScatter {
205 input,
206 output,
207 group,
208 ..
209 } => Err(vyre::Error::interp(format!(
210 "reference dispatch reached ReduceScatter `{input}` -> `{output}` for group {}. Fix: run this Program on a distributed backend with collective support or lower the single-rank collective before reference execution.",
211 group.as_u32()
212 ))),
213 Node::Broadcast {
214 buffer,
215 root,
216 group,
217 } => Err(vyre::Error::interp(format!(
218 "reference dispatch reached Broadcast on buffer `{buffer}` from root {root} for group {}. Fix: run this Program on a distributed backend with collective support or lower the single-rank collective before reference execution.",
219 group.as_u32()
220 ))),
221 Node::Region { body, .. } => eval_block(body, invocation),
222 Node::Opaque(extension) => Err(vyre::Error::interp(format!(
223 "reference interpreter does not support opaque node extension `{}`/`{}`. Fix: provide a reference evaluator for this NodeExtension or lower it to core Node variants before evaluation.",
224 extension.extension_kind(),
225 extension.debug_identity()
226 ))),
227 _ => Err(vyre::Error::interp(
228 "reference interpreter encountered an unknown Node variant. Fix: update vyre-reference before executing this IR.",
229 )),
230 }
231}
232
233fn eval_let(
234 name: &str,
235 value: &Expr,
236 invocation: &mut Invocation<'_>,
237 memory: &mut Memory,
238 program: &Program,
239) -> Result<(), vyre::Error> {
240 let value = eval_expr::eval(value, invocation, memory, program)?;
241 invocation.bind(name, value)
242}
243
244fn eval_assign(
245 name: &str,
246 value: &Expr,
247 invocation: &mut Invocation<'_>,
248 memory: &mut Memory,
249 program: &Program,
250) -> Result<(), vyre::Error> {
251 let value = eval_expr::eval(value, invocation, memory, program)?;
252 invocation.assign(name, value)
253}
254
255fn eval_store(
256 buffer: &str,
257 index: &Expr,
258 value: &Expr,
259 invocation: &mut Invocation<'_>,
260 memory: &mut Memory,
261 program: &Program,
262) -> Result<(), vyre::Error> {
263 let index = eval_expr::eval(index, invocation, memory, program)?;
264 let index = index
265 .try_as_u32()
266 .ok_or_else(|| Error::interp(format!(
267 "store index {index:?} cannot be represented as u32. Fix: use a non-negative scalar index within u32."
268 )))?;
269 let value = eval_expr::eval(value, invocation, memory, program)?;
270 let target = eval_expr::buffer_mut(memory, program, buffer)?;
271 oob::store(target, index, &value);
272 Ok(())
273}
274
275fn eval_indirect_dispatch(
276 count_buffer: &str,
277 count_offset: u64,
278 memory: &Memory,
279 program: &Program,
280) -> Result<(), vyre::Error> {
281 if count_offset % 4 != 0 {
282 return Err(Error::interp(format!(
283 "indirect dispatch offset {count_offset} is not 4-byte aligned. Fix: use a u32-aligned dispatch tuple."
284 )));
285 }
286 let decl = program.buffer(count_buffer).ok_or_else(|| {
287 Error::interp(format!(
288 "indirect dispatch references unknown buffer `{count_buffer}`. Fix: declare the count buffer before execution."
289 ))
290 })?;
291 let buffer = if decl.access() == vyre::ir::BufferAccess::Workgroup {
292 memory.workgroup.get(count_buffer)
293 } else {
294 memory.storage.get(count_buffer)
295 }
296 .ok_or_else(|| {
297 Error::interp(format!(
298 "indirect dispatch buffer `{count_buffer}` is missing. Fix: initialize the count buffer before execution."
299 ))
300 })?;
301 let required_end = count_offset.checked_add(12).ok_or_else(|| {
302 Error::interp(
303 "indirect dispatch byte range overflowed u64. Fix: shrink the count offset."
304 .to_string(),
305 )
306 })?;
307 let byte_len = buffer
308 .bytes
309 .read()
310 .map_err(|_| {
311 Error::interp(format!(
312 "indirect dispatch buffer `{count_buffer}` lock is poisoned. Fix: rebuild the interpreter memory state before execution."
313 ))
314 })?
315 .len();
316 if u64::try_from(byte_len).unwrap_or(u64::MAX) < required_end {
317 return Err(Error::interp(format!(
318 "indirect dispatch buffer `{count_buffer}` is too short for a 3-word dispatch tuple at byte offset {count_offset}. Fix: provide 12 readable bytes starting at that offset."
319 )));
320 }
321 Err(Error::interp(format!(
322 "Node::IndirectDispatch cannot execute in the sequential reference interpreter because dynamic indirect dispatch requires runtime queue scheduling. Fix: run this program on a backend/runtime that supports indirect dispatch or lower `{count_buffer}` at byte offset {count_offset} to a static workgroup grid before reference execution."
323 )))
324}
325
326struct AsyncLoadEval<'a> {
327 source: &'a str,
328 destination: &'a str,
329 offset: &'a Expr,
330 size: &'a Expr,
331 tag: &'a str,
332}
333
334struct AsyncStoreEval<'a> {
335 source: &'a str,
336 destination: &'a str,
337 offset: &'a Expr,
338 size: &'a Expr,
339 tag: &'a str,
340}
341
342fn eval_async_load(
343 request: AsyncLoadEval<'_>,
344 invocation: &mut Invocation<'_>,
345 memory: &mut Memory,
346 program: &Program,
347) -> Result<(), vyre::Error> {
348 let start = eval_byte_count(
349 request.offset,
350 "async load source offset",
351 invocation,
352 memory,
353 program,
354 )?;
355 let byte_count = eval_byte_count(request.size, "async load size", invocation, memory, program)?;
356 let payload = read_bytes(memory, program, request.source, start, byte_count)?;
357 ensure_writable_buffer(memory, program, request.destination)?;
358 invocation.begin_async(
359 request.tag,
360 AsyncTransfer::Copy {
361 destination: request.destination.into(),
362 start: 0,
363 payload,
364 },
365 )
366}
367
368fn eval_async_store(
369 request: AsyncStoreEval<'_>,
370 invocation: &mut Invocation<'_>,
371 memory: &mut Memory,
372 program: &Program,
373) -> Result<(), vyre::Error> {
374 let start = eval_byte_count(
375 request.offset,
376 "async store destination offset",
377 invocation,
378 memory,
379 program,
380 )?;
381 let byte_count = eval_byte_count(
382 request.size,
383 "async store size",
384 invocation,
385 memory,
386 program,
387 )?;
388 let payload = read_bytes(memory, program, request.source, 0, byte_count)?;
389 ensure_writable_buffer(memory, program, request.destination)?;
390 invocation.begin_async(
391 request.tag,
392 AsyncTransfer::Copy {
393 destination: request.destination.into(),
394 start,
395 payload,
396 },
397 )
398}
399
400fn eval_async_wait(
401 tag: &str,
402 invocation: &mut Invocation<'_>,
403 memory: &mut Memory,
404 program: &Program,
405) -> Result<(), vyre::Error> {
406 apply_async_transfer(invocation.finish_async(tag)?, memory, program)
407}
408
409fn eval_byte_count(
410 expr: &Expr,
411 label: &str,
412 invocation: &mut Invocation<'_>,
413 memory: &mut Memory,
414 program: &Program,
415) -> Result<usize, Error> {
416 let value = eval_expr::eval(expr, invocation, memory, program)?;
417 usize::try_from(value.try_as_u64().ok_or_else(|| {
418 Error::interp(format!(
419 "{label} cannot be represented as u64. Fix: use an in-range non-negative byte count."
420 ))
421 })?)
422 .map_err(|_| {
423 Error::interp(format!(
424 "{label} exceeds host usize. Fix: reduce the async transfer span."
425 ))
426 })
427}
428
429fn read_bytes(
430 memory: &Memory,
431 program: &Program,
432 source: &str,
433 start: usize,
434 byte_count: usize,
435) -> Result<Vec<u8>, Error> {
436 let buffer = resolve_buffer(memory, program, source)?;
437 let bytes = buffer
438 .bytes
439 .read()
440 .unwrap_or_else(|error| error.into_inner());
441 let mut payload = vec![0; byte_count];
442 if start < bytes.len() {
443 let available = (bytes.len() - start).min(byte_count);
444 payload[..available].copy_from_slice(&bytes[start..start + available]);
445 }
446 Ok(payload)
447}
448
449fn ensure_writable_buffer(memory: &mut Memory, program: &Program, name: &str) -> Result<(), Error> {
450 eval_expr::buffer_mut(memory, program, name).map(|_| ())
451}
452
453fn apply_async_transfer(
454 transfer: AsyncTransfer,
455 memory: &mut Memory,
456 program: &Program,
457) -> Result<(), Error> {
458 match transfer {
459 AsyncTransfer::Copy {
460 destination,
461 start,
462 payload,
463 } => {
464 let buffer = eval_expr::buffer_mut(memory, program, &destination)?;
465 let mut bytes = buffer
466 .bytes
467 .write()
468 .unwrap_or_else(|error| error.into_inner());
469 if start >= bytes.len() {
470 return Ok(());
471 }
472 let write_len = payload.len().min(bytes.len() - start);
473 bytes[start..start + write_len].copy_from_slice(&payload[..write_len]);
474 Ok(())
475 }
476 }
477}
478
479fn resolve_buffer<'a>(
480 memory: &'a Memory,
481 program: &Program,
482 name: &str,
483) -> Result<&'a oob::Buffer, Error> {
484 let decl = program.buffer(name).ok_or_else(|| {
485 Error::interp(format!(
486 "missing buffer declaration `{name}`. Fix: declare every async transfer buffer."
487 ))
488 })?;
489 if decl.access() == vyre::ir::BufferAccess::Workgroup {
490 memory.workgroup.get(name)
491 } else {
492 memory.storage.get(name)
493 }
494 .ok_or_else(|| {
495 Error::interp(format!(
496 "missing buffer `{name}`. Fix: initialize every declared async transfer buffer."
497 ))
498 })
499}
500
501fn eval_if<'a>(
502 cond: &Expr,
503 then: &'a [Node],
504 otherwise: &'a [Node],
505 node: &Node,
506 invocation: &mut Invocation<'a>,
507 memory: &mut Memory,
508 program: &Program,
509) -> Result<(), vyre::Error> {
510 let cond_value = eval_expr::eval(cond, invocation, memory, program)?.truthy();
511 if contains_barrier(then) || contains_barrier(otherwise) {
512 invocation.uniform_checks.push((node_id(node), cond_value));
513 }
514 let branch = if cond_value { then } else { otherwise };
515 invocation.push_scope();
516 invocation.frames_mut().push(Frame::Nodes {
517 nodes: branch,
518 index: 0,
519 scoped: true,
520 });
521 Ok(())
522}
523
524fn eval_loop<'a>(
525 var: &'a str,
526 from: &Expr,
527 to: &Expr,
528 body: &'a [Node],
529 invocation: &mut Invocation<'a>,
530 memory: &mut Memory,
531 program: &Program,
532) -> Result<(), vyre::Error> {
533 let from_value = eval_expr::eval(from, invocation, memory, program)?;
534 let to_value = eval_expr::eval(to, invocation, memory, program)?;
535 let from = from_value.try_as_u32().ok_or_else(|| {
536 Error::interp(format!(
537 "loop lower bound {from_value:?} cannot be represented as u32. Fix: use an in-range unsigned loop bound."
538 ))
539 })?;
540 let to = to_value.try_as_u32().ok_or_else(|| Error::interp(format!(
541 "loop upper bound {to_value:?} cannot be represented as u32. Fix: use an in-range unsigned loop bound."
542 )))?;
543 invocation.frames_mut().push(Frame::Loop {
544 var,
545 next: from,
546 to,
547 body,
548 });
549 Ok(())
550}
551
552fn eval_return(invocation: &mut Invocation<'_>) -> Result<(), vyre::Error> {
553 invocation.frames_mut().clear();
554 invocation.returned = true;
555 Ok(())
556}
557
558fn eval_block<'a>(nodes: &'a [Node], invocation: &mut Invocation<'a>) -> Result<(), vyre::Error> {
559 invocation.push_scope();
560 invocation.frames_mut().push(Frame::Nodes {
561 nodes,
562 index: 0,
563 scoped: true,
564 });
565 Ok(())
566}
567
568fn eval_barrier(invocation: &mut Invocation<'_>) -> Result<(), vyre::Error> {
569 invocation.waiting_at_barrier = true;
570 Ok(())
571}
572
#[cfg(test)]
mod tests {
    use super::*;
    use crate::oob::Buffer;
    use crate::workgroup::InvocationIds;
    use vyre::ir::{BufferDecl, DataType};

    /// Drives a single invocation of `program` with `step` until it reports done.
    fn run_program(program: &Program, memory: &mut Memory) -> Result<(), vyre::Error> {
        let mut invocation = Invocation::new(InvocationIds::ZERO, program.entry());
        loop {
            if invocation.done() {
                return Ok(());
            }
            step(&mut invocation, memory, program)?;
        }
    }

    /// Returns a copy of the named storage buffer's bytes.
    fn bytes(memory: &Memory, name: &str) -> Vec<u8> {
        let buffer = memory.storage.get(name).expect("Fix: test buffer exists");
        let contents = match buffer.bytes.read() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        contents.to_vec()
    }

    #[test]
    fn async_load_wait_copies_payload_into_destination() {
        let buffers = vec![
            BufferDecl::read("src", 0, DataType::Bytes).with_count(8),
            BufferDecl::output("dst", 1, DataType::Bytes).with_count(8),
        ];
        // Load 4 bytes starting at source offset 2, then wait on the same tag.
        let entry = vec![
            Node::async_load_ext("src", "dst", Expr::u32(2), Expr::u32(4), "copy"),
            Node::AsyncWait { tag: "copy".into() },
        ];
        let program = Program::wrapped(buffers, [1, 1, 1], entry);

        let source = Buffer::new(vec![10, 11, 12, 13, 14, 15, 16, 17], DataType::Bytes);
        let destination = Buffer::new(vec![0; 8], DataType::Bytes);
        let mut memory = Memory::empty()
            .with_storage("src", source)
            .with_storage("dst", destination);

        run_program(&program, &mut memory).unwrap();

        // Source bytes 2..6 land at the start of the destination.
        assert_eq!(bytes(&memory, "dst"), vec![12, 13, 14, 15, 0, 0, 0, 0]);
    }

    #[test]
    fn async_store_wait_copies_payload_at_destination_offset() {
        let buffers = vec![
            BufferDecl::read("src", 0, DataType::Bytes).with_count(4),
            BufferDecl::output("dst", 1, DataType::Bytes).with_count(8),
        ];
        // Store 4 bytes at destination offset 3, then wait on the same tag.
        let entry = vec![
            Node::async_store("src", "dst", Expr::u32(3), Expr::u32(4), "store"),
            Node::AsyncWait {
                tag: "store".into(),
            },
        ];
        let program = Program::wrapped(buffers, [1, 1, 1], entry);

        let source = Buffer::new(vec![21, 22, 23, 24], DataType::Bytes);
        let destination = Buffer::new(vec![0; 8], DataType::Bytes);
        let mut memory = Memory::empty()
            .with_storage("src", source)
            .with_storage("dst", destination);

        run_program(&program, &mut memory).unwrap();

        // The whole source lands at destination bytes 3..7.
        assert_eq!(bytes(&memory, "dst"), vec![0, 0, 0, 21, 22, 23, 24, 0]);
    }
}