vyre_reference/dual_impls/
common.rs1use std::{error::Error, fmt};
4
5use crate::workgroup::Memory;
6use vyre_primitives::CombineOp;
7
/// Error produced when a reference evaluation cannot proceed.
///
/// Carries a single human-readable message. [`EvalError::new`] checks (in
/// builds with debug assertions) that the message contains a `Fix:` hint
/// telling the caller how to correct the problem.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EvalError {
    /// Full diagnostic text; written out verbatim by the `Display` impl.
    message: String,
}

impl EvalError {
    /// Builds an error from `message`.
    ///
    /// # Panics
    ///
    /// When debug assertions are enabled, panics if `message` does not
    /// contain the substring `Fix:`. The panic text quotes the offending
    /// message so the faulty call site can be identified from the log alone.
    #[must_use]
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        // Explicit format argument (rather than an inline capture) so the
        // message is formatted regardless of the crate's edition.
        debug_assert!(
            message.contains("Fix:"),
            "EvalError message must contain a `Fix:` hint, got: {:?}",
            message
        );
        Self { message }
    }
}

impl fmt::Display for EvalError {
    /// Writes the stored message unchanged.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.message)
    }
}

impl Error for EvalError {}
31
32pub trait ReferenceEvaluator {
34 fn evaluate(&self, inputs: &[Memory]) -> Result<Memory, EvalError>;
41}
42
43pub(crate) fn one_input(inputs: &[Memory], id: &str) -> Result<Vec<u8>, EvalError> {
44 if inputs.len() != 1 {
45 return Err(EvalError::new(format!(
46 "primitive `{id}` expected 1 input memory, got {}. Fix: pass exactly one byte payload.",
47 inputs.len()
48 )));
49 }
50 Ok(inputs[0].bytes())
51}
52
53pub(crate) fn two_inputs(inputs: &[Memory], id: &str) -> Result<(Vec<u8>, Vec<u8>), EvalError> {
54 if inputs.len() != 2 {
55 return Err(EvalError::new(format!(
56 "primitive `{id}` expected 2 input memories, got {}. Fix: pass left and right byte payloads.",
57 inputs.len()
58 )));
59 }
60 Ok((inputs[0].bytes(), inputs[1].bytes()))
61}
62
63pub(crate) fn read_u32(bytes: impl AsRef<[u8]>, id: &str) -> Result<u32, EvalError> {
64 let bytes = bytes.as_ref();
65 if bytes.len() != 4 {
66 return Err(EvalError::new(format!(
67 "primitive `{id}` expected a 4-byte u32 payload, got {} bytes. Fix: encode scalar inputs as little-endian u32.",
68 bytes.len()
69 )));
70 }
71 Ok(u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]))
72}
73
74pub(crate) fn u32_words(bytes: impl AsRef<[u8]>, id: &str) -> Result<Vec<u32>, EvalError> {
75 let bytes = bytes.as_ref();
76 if bytes.len() % 4 != 0 {
77 return Err(EvalError::new(format!(
78 "primitive `{id}` expected u32-aligned bytes, got {} bytes. Fix: encode every element as little-endian u32.",
79 bytes.len()
80 )));
81 }
82 Ok(bytes
83 .chunks_exact(4)
84 .map(|chunk| u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
85 .collect())
86}
87
88pub(crate) fn write_u32s(values: impl IntoIterator<Item = u32>) -> Memory {
89 let mut bytes = Vec::new();
90 for value in values {
91 bytes.extend_from_slice(&value.to_le_bytes());
92 }
93 Memory::from_bytes(bytes)
94}
95
96pub(crate) fn scalar(value: u32) -> Memory {
97 Memory::from_bytes(value.to_le_bytes().to_vec())
98}
99
100pub(crate) fn combine(op: CombineOp, left: u32, right: u32) -> Result<u32, EvalError> {
101 Ok(match op {
102 CombineOp::Add => left.wrapping_add(right),
103 CombineOp::Mul => left.wrapping_mul(right),
104 CombineOp::BitAnd => left & right,
105 CombineOp::BitOr => left | right,
106 CombineOp::BitXor => left ^ right,
107 CombineOp::Min => left.min(right),
108 CombineOp::Max => left.max(right),
109 _ => {
110 return Err(EvalError::new(format!(
111 "primitive combiner does not support CombineOp variant {op:?}. Fix: register a reference evaluator for the new combiner before dispatch."
112 )));
113 }
114 })
115}
116
117pub(crate) fn checked_index(index: u32, len: usize, id: &str) -> Result<usize, EvalError> {
118 let index = usize::try_from(index).map_err(|_| {
119 EvalError::new(format!(
120 "primitive `{id}` index does not fit usize. Fix: keep index regions within platform addressable bounds."
121 ))
122 })?;
123 if index >= len {
124 Err(EvalError::new(format!(
125 "primitive `{id}` index {index} is outside source length {len}. Fix: validate index regions before dispatch."
126 )))
127 } else {
128 Ok(index)
129 }
130}
131
#[cfg(test)]
mod tests {
    use super::*;

    /// Each combiner handled by `combine` must yield the expected word
    /// instead of an error.
    #[test]
    fn combine_known_variants_do_not_panic() {
        // (operator, left operand, right operand, expected result)
        let table: [(CombineOp, u32, u32, u32); 7] = [
            (CombineOp::Add, 7, 5, 12),
            (CombineOp::Mul, 7, 5, 35),
            (CombineOp::BitAnd, 0b1100, 0b1010, 0b1000),
            (CombineOp::BitOr, 0b1100, 0b1010, 0b1110),
            (CombineOp::BitXor, 0b1100, 0b1010, 0b0110),
            (CombineOp::Min, 7, 5, 5),
            (CombineOp::Max, 7, 5, 7),
        ];

        for (op, lhs, rhs, want) in table {
            let got = combine(op, lhs, rhs);
            assert_eq!(got, Ok(want), "{op:?}");
        }
    }
}