1use vyre::ir::{AtomicOp, BinOp, BufferAccess, BufferDecl, DataType, Expr, Program, UnOp};
10
11use vyre::Error;
12
13use crate::{atomics, oob, value::Value, workgroup::Invocation, workgroup::Memory};
14
15pub use crate::oob::Buffer;
17
/// Upper bound, in bytes, on the concatenated argument payload assembled for
/// a single `Expr::Call` (64 MiB).
const MAX_CALL_INPUT_BYTES: usize = 64 << 20;
19
20pub fn eval(
27 expr: &Expr,
28 invocation: &mut Invocation<'_>,
29 memory: &mut Memory,
30 program: &Program,
31) -> Result<Value, vyre::Error> {
32 match expr {
33 Expr::LitU32(value) => eval_lit_u32(*value),
34 Expr::LitI32(value) => eval_lit_i32(*value),
35 Expr::LitBool(value) => eval_lit_bool(*value),
36 Expr::Var(name) => eval_var(name, invocation),
37 Expr::Load { buffer, index } => eval_load(buffer, index, invocation, memory, program),
38 Expr::BufLen { buffer } => eval_buf_len(buffer, memory, program),
39 Expr::InvocationId { axis } => eval_invocation_id(*axis, invocation),
40 Expr::WorkgroupId { axis } => eval_workgroup_id(*axis, invocation),
41 Expr::LocalId { axis } => eval_local_id(*axis, invocation),
42 Expr::BinOp { op, left, right } => {
43 eval_binop(op.clone(), left, right, invocation, memory, program)
44 }
45 Expr::UnOp { op, operand } => eval_unop(op.clone(), operand, invocation, memory, program),
46 Expr::Call { op_id, args } => eval_call(op_id, args, invocation, memory, program),
47 Expr::Select {
48 cond,
49 true_val,
50 false_val,
51 } => eval_select(cond, true_val, false_val, invocation, memory, program),
52 Expr::Cast { target, value } => {
53 eval_cast(target.clone(), value, invocation, memory, program)
54 }
55 Expr::Atomic {
56 op,
57 buffer,
58 index,
59 expected,
60 value,
61 } => eval_atomic(
62 op.clone(),
63 buffer,
64 index,
65 expected.as_deref(),
66 value,
67 invocation,
68 memory,
69 program,
70 ),
71 _ => Err(Error::interp(format!(
72 "unsupported IR `unknown Expr variant: {expr:?}`. Fix: update vyre-reference for the new vyre::ir variant."
73 ))),
74 }
75}
76
77pub fn buffer_mut<'a>(
84 memory: &'a mut Memory,
85 program: &Program,
86 name: &str,
87) -> Result<&'a mut Buffer, vyre::Error> {
88 let decl = buffer_decl(program, name)?;
89 match decl.access() {
90 BufferAccess::ReadWrite | BufferAccess::Workgroup => resolve_buffer_mut(memory, decl),
91 BufferAccess::ReadOnly | BufferAccess::Uniform => Err(Error::interp(format!(
92 "store target `{name}` is not writable. Fix: declare it ReadWrite or Workgroup."
93 ))),
94 _ => Err(Error::interp(format!(
95 "store target `{name}` uses an unsupported access mode. Fix: upgrade vyre-conform or use a supported BufferAccess."
96 ))),
97 }
98}
99
100fn eval_lit_u32(value: u32) -> Result<Value, vyre::Error> {
101 Ok(Value::U32(value))
102}
103
104fn eval_lit_i32(value: i32) -> Result<Value, vyre::Error> {
105 Ok(Value::I32(value))
106}
107
108fn eval_lit_bool(value: bool) -> Result<Value, vyre::Error> {
109 Ok(Value::Bool(value))
110}
111
112fn eval_var(name: &str, invocation: &Invocation<'_>) -> Result<Value, vyre::Error> {
113 invocation.locals.get(name).cloned().ok_or_else(|| {
114 Error::interp(format!(
115 "reference to undeclared variable `{name}`. Fix: add a Let before this use."
116 ))
117 })
118}
119
120fn eval_load(
121 buffer: &str,
122 index: &Expr,
123 invocation: &mut Invocation<'_>,
124 memory: &mut Memory,
125 program: &Program,
126) -> Result<Value, vyre::Error> {
127 let idx = eval_to_index(index, "load index", invocation, memory, program)?;
128 Ok(oob::load(resolve_buffer(memory, program, buffer)?, idx))
129}
130
131fn eval_buf_len(buffer: &str, memory: &Memory, program: &Program) -> Result<Value, vyre::Error> {
132 Ok(Value::U32(resolve_buffer(memory, program, buffer)?.len()))
133}
134
135fn eval_invocation_id(axis: u8, invocation: &Invocation<'_>) -> Result<Value, vyre::Error> {
136 axis_value(invocation.ids.global, axis)
137}
138
139fn eval_workgroup_id(axis: u8, invocation: &Invocation<'_>) -> Result<Value, vyre::Error> {
140 axis_value(invocation.ids.workgroup, axis)
141}
142
143fn eval_local_id(axis: u8, invocation: &Invocation<'_>) -> Result<Value, vyre::Error> {
144 axis_value(invocation.ids.local, axis)
145}
146
147fn eval_binop(
148 op: BinOp,
149 left: &Expr,
150 right: &Expr,
151 invocation: &mut Invocation<'_>,
152 memory: &mut Memory,
153 program: &Program,
154) -> Result<Value, vyre::Error> {
155 let left = eval(left, invocation, memory, program)?;
156 let right = eval(right, invocation, memory, program)?;
157 super::typed_ops::eval_binop(op, left, right)
158}
159
160fn eval_unop(
161 op: UnOp,
162 operand: &Expr,
163 invocation: &mut Invocation<'_>,
164 memory: &mut Memory,
165 program: &Program,
166) -> Result<Value, vyre::Error> {
167 let operand = eval(operand, invocation, memory, program)?;
168 super::typed_ops::eval_unop(op, operand)
169}
170
171fn eval_call(
172 op_id: &str,
173 args: &[Expr],
174 invocation: &mut Invocation<'_>,
175 memory: &mut Memory,
176 program: &Program,
177) -> Result<Value, vyre::Error> {
178 let spec = vyre::ops::registry::lookup(op_id).ok_or_else(|| Error::interp(format!(
179 "unsupported call `{op_id}`. Fix: register the op in core::ops::registry or inline the callee as IR."
180 )))?;
181 let expected = spec.inputs().len();
182 if args.len() != expected {
183 return Err(Error::interp(format!(
184 "call `{op_id}` received {} arguments but the primitive signature requires {expected}. Fix: pass exactly {expected} arguments.",
185 args.len()
186 )));
187 }
188 let mut input = Vec::new();
189 for (arg, declared_type) in args.iter().zip(spec.inputs()) {
190 let declared_width = declared_type.min_bytes();
191 let bytes = eval(arg, invocation, memory, program)?.to_bytes_width(declared_width);
192 let next_len = input
193 .len()
194 .checked_add(bytes.len())
195 .ok_or_else(|| Error::interp(format!(
196 "call `{op_id}` input byte size overflows usize. Fix: reduce the argument count or byte payload size."
197 )))?;
198 if next_len > MAX_CALL_INPUT_BYTES {
199 return Err(Error::interp(format!(
200 "call `{op_id}` requires {next_len} input bytes, exceeding the {MAX_CALL_INPUT_BYTES}-byte reference budget. Fix: reduce call input size."
201 )));
202 }
203 input.extend_from_slice(&bytes);
204 }
205 let mut output = Vec::new();
206 match spec.compose() {
207 vyre::ops::Compose::Composition(build) => {
208 crate::flat_cpu::run_flat(&build().with_entry_op_id(spec.id()), &input, &mut output)?;
209 }
210 vyre::ops::Compose::Intrinsic(intrinsic) => {
211 intrinsic.cpu_fn()(&input, &mut output);
212 }
213 other => {
214 return Err(Error::interp(format!(
215 "Fix: vyre-reference does not yet implement compose-kind `{other:?}` for op `{}`. Either implement the CPU path for this compose variant in vyre-reference/src/eval_expr.rs, or route the caller through a different op.",
216 spec.id()
217 )));
218 }
219 }
220 Ok(spec_output_value(
221 spec.outputs().first().cloned().unwrap_or(DataType::Bytes),
222 &output,
223 ))
224}
225
226fn eval_select(
227 cond: &Expr,
228 true_val: &Expr,
229 false_val: &Expr,
230 invocation: &mut Invocation<'_>,
231 memory: &mut Memory,
232 program: &Program,
233) -> Result<Value, vyre::Error> {
234 let cond = eval(cond, invocation, memory, program)?.truthy();
235 let true_val = eval(true_val, invocation, memory, program)?;
236 let false_val = eval(false_val, invocation, memory, program)?;
237 Ok(if cond { true_val } else { false_val })
238}
239
240fn eval_cast(
241 target: DataType,
242 value: &Expr,
243 invocation: &mut Invocation<'_>,
244 memory: &mut Memory,
245 program: &Program,
246) -> Result<Value, vyre::Error> {
247 let value = eval(value, invocation, memory, program)?;
248 cast_value(target, &value)
249}
250
251fn eval_atomic(
252 op: AtomicOp,
253 buffer: &str,
254 index: &Expr,
255 expected: Option<&Expr>,
256 value: &Expr,
257 invocation: &mut Invocation<'_>,
258 memory: &mut Memory,
259 program: &Program,
260) -> Result<Value, vyre::Error> {
261 match (op.clone(), expected) {
262 (AtomicOp::CompareExchange, None) => {
263 return Err(Error::interp(
264 "compare-exchange atomic is missing expected value. Fix: set Expr::Atomic.expected for AtomicOp::CompareExchange.",
265 ));
266 }
267 (AtomicOp::CompareExchange, Some(_)) => {}
268 (_, Some(_)) => {
269 return Err(Error::interp(
270 "non-compare-exchange atomic includes an expected value. Fix: use Expr::Atomic.expected only with AtomicOp::CompareExchange.",
271 ));
272 }
273 (_, None) => {}
274 }
275 let idx = eval_to_index(index, "atomic index", invocation, memory, program)?;
276 let expected = expected
277 .map(|expr| {
278 eval(expr, invocation, memory, program)?.try_as_u32().ok_or_else(|| {
279 Error::interp(format!(
280 "atomic expected value {expr:?} cannot be represented as u32. Fix: use a scalar u32-compatible argument."
281 ))
282 })
283 })
284 .transpose()?;
285 let value = eval(value, invocation, memory, program)?;
286 let value = value.try_as_u32().ok_or_else(|| {
287 Error::interp(
288 "atomic value cannot be represented as u32. Fix: use a scalar u32-compatible argument.",
289 )
290 })?;
291 let target = atomic_buffer_mut(memory, program, buffer)?;
292 let Some(old) = oob::atomic_load(target, idx) else {
293 return Ok(Value::U32(0));
294 };
295 let (old, new) = atomics::apply(op, old, expected, value)?;
296 oob::atomic_store(target, idx, new);
297 Ok(Value::U32(old))
298}
299
300fn eval_to_index(
301 index: &Expr,
302 context: &'static str,
303 invocation: &mut Invocation<'_>,
304 memory: &mut Memory,
305 program: &Program,
306) -> Result<u32, vyre::Error> {
307 let value = eval(index, invocation, memory, program)?;
308 value
309 .try_as_u32()
310 .ok_or_else(|| Error::interp(format!(
311 "{context} {value:?} cannot be represented as u32. Fix: use a non-negative scalar index within u32.",
312 )))
313}
314
315fn resolve_buffer<'a>(
316 memory: &'a Memory,
317 program: &Program,
318 name: &str,
319) -> Result<&'a oob::Buffer, vyre::Error> {
320 let decl = buffer_decl(program, name)?;
321 if decl.access() == BufferAccess::Workgroup {
322 memory.workgroup.get(name)
323 } else {
324 memory.storage.get(name)
325 }
326 .ok_or_else(|| {
327 Error::interp(format!(
328 "missing buffer `{name}`. Fix: initialize all declared buffers."
329 ))
330 })
331}
332
333fn resolve_buffer_mut<'a>(
334 memory: &'a mut Memory,
335 decl: &BufferDecl,
336) -> Result<&'a mut oob::Buffer, vyre::Error> {
337 let name = decl.name();
338 if decl.access() == BufferAccess::Workgroup {
339 memory.workgroup.get_mut(name)
340 } else {
341 memory.storage.get_mut(name)
342 }
343 .ok_or_else(|| {
344 Error::interp(format!(
345 "missing buffer `{name}`. Fix: initialize all declared buffers."
346 ))
347 })
348}
349
350fn atomic_buffer_mut<'a>(
351 memory: &'a mut Memory,
352 program: &Program,
353 name: &str,
354) -> Result<&'a mut oob::Buffer, vyre::Error> {
355 let decl = buffer_decl(program, name)?;
356 match decl.access() {
357 BufferAccess::ReadWrite => resolve_buffer_mut(memory, decl),
358 BufferAccess::Workgroup => Err(Error::interp(format!(
359 "atomic target `{name}` is workgroup memory. Fix: atomics only support ReadWrite storage buffers."
360 ))),
361 BufferAccess::ReadOnly | BufferAccess::Uniform => Err(Error::interp(format!(
362 "atomic target `{name}` is not writable. Fix: atomics only support ReadWrite storage buffers."
363 ))),
364 _ => Err(Error::interp(format!(
365 "atomic target `{name}` uses an unsupported access mode. Fix: upgrade vyre-conform or use a supported BufferAccess."
366 ))),
367 }
368}
369
370fn buffer_decl<'a>(program: &'a Program, name: &str) -> Result<&'a BufferDecl, vyre::Error> {
371 program.buffer(name).ok_or_else(|| {
372 Error::interp(format!(
373 "unknown buffer `{name}`. Fix: declare it in Program::buffers."
374 ))
375 })
376}
377
378fn axis_value(values: [u32; 3], axis: u8) -> Result<Value, vyre::Error> {
379 values
380 .get(axis as usize)
381 .copied()
382 .map(Value::U32)
383 .ok_or_else(|| {
384 Error::interp(format!(
385 "invocation/workgroup ID axis {axis} out of range. Fix: use 0, 1, or 2."
386 ))
387 })
388}
389
390fn spec_output_value(ty: DataType, bytes: &[u8]) -> Value {
391 match ty {
392 DataType::U32 => Value::U32(read_u32_prefix(bytes)),
393 DataType::I32 => Value::I32(read_u32_prefix(bytes) as i32),
394 DataType::Bool => Value::Bool(read_u32_prefix(bytes) != 0),
395 DataType::U64 => Value::U64(read_u64_prefix(bytes)),
396 DataType::F32 => Value::Float(f32::from_bits(read_u32_prefix(bytes)) as f64),
397 DataType::Vec2U32 => Value::Bytes(read_fixed_prefix(bytes, 8)),
398 DataType::Vec4U32 => Value::Bytes(read_fixed_prefix(bytes, 16)),
399 DataType::Bytes => Value::Bytes(bytes.to_vec()),
400 _ => Value::Bytes(bytes.to_vec()),
401 }
402}
403
/// Returns exactly `width` bytes: the leading bytes of `bytes`, truncated if
/// the input is longer than `width` and zero-padded at the end if shorter.
fn read_fixed_prefix(bytes: &[u8], width: usize) -> Vec<u8> {
    bytes
        .iter()
        .copied()
        // An endless tail of zeros supplies the padding; `take` caps the total.
        .chain(std::iter::repeat(0u8))
        .take(width)
        .collect()
}
410
411fn cast_value(target: DataType, value: &Value) -> Result<Value, vyre::Error> {
412 match target {
413 DataType::U32 => match value {
419 Value::I32(v) => Ok(Value::U32(*v as u32)),
420 _ => value
421 .try_as_u32()
422 .map(Value::U32)
423 .ok_or_else(|| invalid_cast(target, value)),
424 },
425 DataType::I32 => match value {
426 Value::I32(value) => Ok(Value::I32(*value)),
427 _ => value
428 .try_as_u32()
429 .map(|value| Value::I32(value as i32))
430 .ok_or_else(|| invalid_cast(target, value)),
431 },
432 DataType::U64 => value
436 .try_as_u64()
437 .map(Value::U64)
438 .ok_or_else(|| invalid_cast(target, value)),
439 DataType::Bool => Ok(Value::Bool(value.truthy())),
440 DataType::Bytes => Ok(Value::Bytes(value.to_bytes())),
441 DataType::Vec2U32 => Ok(Value::Bytes(widen_to_words(value, 2))),
446 DataType::Vec4U32 => Ok(Value::Bytes(widen_to_words(value, 4))),
447 _ => Ok(Value::Bytes(value.to_bytes())),
448 }
449}
450
451fn invalid_cast(target: DataType, value: &Value) -> Error {
452 Error::interp(format!(
453 "cast to {target:?} cannot represent {value:?} losslessly. Fix: cast from an in-range scalar value."
454 ))
455}
456
457fn widen_to_words(value: &Value, words: usize) -> Vec<u8> {
463 let target_bytes = words * 4;
464 let mut bytes = value.to_bytes();
465 if bytes.len() > target_bytes {
466 bytes.truncate(target_bytes);
467 } else if bytes.len() < target_bytes {
468 bytes.resize(target_bytes, 0);
469 }
470 bytes
471}
472
473use super::ops::{read_u32_prefix, read_u64_prefix};