vyre_reference/dual_impls/
common.rs1use std::{error::Error, fmt};
4
5use crate::workgroup::Memory;
6use vyre_primitives::CombineOp;
7
/// Error produced by reference evaluators and the shared decoding helpers in this module.
///
/// By convention every message ends with an actionable `Fix:` hint telling the caller how
/// to correct the input; [`EvalError::new`] checks this in debug builds.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EvalError {
    message: String,
}

impl EvalError {
    /// Builds an error from a human-readable message.
    ///
    /// The message is stored verbatim and is exactly what [`fmt::Display`] prints.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `message` does not contain the substring `"Fix:"`. The
    /// panic text includes the offending message so the faulty call site is easy to find.
    #[must_use]
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        debug_assert!(
            message.contains("Fix:"),
            "EvalError message is missing an actionable `Fix:` hint: {message:?}"
        );
        Self { message }
    }
}

impl fmt::Display for EvalError {
    /// Writes the stored message unchanged.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.message)
    }
}

// No underlying cause is wrapped, so the default `source()` (None) is correct.
impl Error for EvalError {}
31
32pub trait ReferenceEvaluator {
34 fn evaluate(&self, inputs: &[Memory]) -> Result<Memory, EvalError>;
41}
42
43pub(crate) fn one_input(inputs: &[Memory], id: &str) -> Result<Vec<u8>, EvalError> {
44 if inputs.len() != 1 {
45 return Err(EvalError::new(format!(
46 "primitive `{id}` expected 1 input memory, got {}. Fix: pass exactly one byte payload.",
47 inputs.len()
48 )));
49 }
50 Ok(inputs[0].bytes())
51}
52
53pub(crate) fn two_inputs(inputs: &[Memory], id: &str) -> Result<(Vec<u8>, Vec<u8>), EvalError> {
54 if inputs.len() != 2 {
55 return Err(EvalError::new(format!(
56 "primitive `{id}` expected 2 input memories, got {}. Fix: pass left and right byte payloads.",
57 inputs.len()
58 )));
59 }
60 Ok((inputs[0].bytes(), inputs[1].bytes()))
61}
62
63pub(crate) fn read_u32(bytes: impl AsRef<[u8]>, id: &str) -> Result<u32, EvalError> {
64 let bytes = bytes.as_ref();
65 if bytes.len() != 4 {
66 return Err(EvalError::new(format!(
67 "primitive `{id}` expected a 4-byte u32 payload, got {} bytes. Fix: encode scalar inputs as little-endian u32.",
68 bytes.len()
69 )));
70 }
71 Ok(u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]))
72}
73
74pub(crate) fn u32_words(bytes: impl AsRef<[u8]>, id: &str) -> Result<Vec<u32>, EvalError> {
75 let bytes = bytes.as_ref();
76 if bytes.len() % 4 != 0 {
77 return Err(EvalError::new(format!(
78 "primitive `{id}` expected u32-aligned bytes, got {} bytes. Fix: encode every element as little-endian u32.",
79 bytes.len()
80 )));
81 }
82 Ok(bytes
83 .chunks_exact(4)
84 .map(|chunk| u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
85 .collect())
86}
87
88pub(crate) fn write_u32s(values: impl IntoIterator<Item = u32>) -> Memory {
89 let mut bytes = Vec::new();
90 for value in values {
91 bytes.extend_from_slice(&value.to_le_bytes());
92 }
93 Memory::from_bytes(bytes)
94}
95
96pub(crate) fn scalar(value: u32) -> Memory {
97 Memory::from_bytes(value.to_le_bytes().to_vec())
98}
99
100pub(crate) fn unary_u32_scalar(
101 inputs: &[Memory],
102 id: &str,
103 op: impl FnOnce(u32) -> u32,
104) -> Result<Memory, EvalError> {
105 let input = one_input(inputs, id)?;
106 Ok(scalar(op(read_u32(input, id)?)))
107}
108
109pub(crate) fn binary_u32_scalar(
110 inputs: &[Memory],
111 id: &str,
112 op: impl FnOnce(u32, u32) -> u32,
113) -> Result<Memory, EvalError> {
114 let (left, right) = two_inputs(inputs, id)?;
115 Ok(scalar(op(read_u32(left, id)?, read_u32(right, id)?)))
116}
117
118pub(crate) fn binary_u32_predicate(
119 inputs: &[Memory],
120 id: &str,
121 op: impl FnOnce(u32, u32) -> bool,
122) -> Result<Memory, EvalError> {
123 binary_u32_scalar(inputs, id, |left, right| u32::from(op(left, right)))
124}
125
126pub(crate) fn combine(op: CombineOp, left: u32, right: u32) -> Result<u32, EvalError> {
127 Ok(match op {
128 CombineOp::Add => left.wrapping_add(right),
129 CombineOp::Mul => left.wrapping_mul(right),
130 CombineOp::BitAnd => left & right,
131 CombineOp::BitOr => left | right,
132 CombineOp::BitXor => left ^ right,
133 CombineOp::Min => left.min(right),
134 CombineOp::Max => left.max(right),
135 _ => {
136 return Err(EvalError::new(format!(
137 "primitive combiner does not support CombineOp variant {op:?}. Fix: register a reference evaluator for the new combiner before dispatch."
138 )));
139 }
140 })
141}
142
143pub(crate) fn checked_index(index: u32, len: usize, id: &str) -> Result<usize, EvalError> {
144 let index = usize::try_from(index).map_err(|_| {
145 EvalError::new(format!(
146 "primitive `{id}` index does not fit usize. Fix: keep index regions within platform addressable bounds."
147 ))
148 })?;
149 if index >= len {
150 Err(EvalError::new(format!(
151 "primitive `{id}` index {index} is outside source length {len}. Fix: validate index regions before dispatch."
152 )))
153 } else {
154 Ok(index)
155 }
156}
157
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn combine_known_variants_do_not_panic() {
        let cases = [
            (CombineOp::Add, 7, 5, 12),
            (CombineOp::Mul, 7, 5, 35),
            (CombineOp::BitAnd, 0b1100, 0b1010, 0b1000),
            (CombineOp::BitOr, 0b1100, 0b1010, 0b1110),
            (CombineOp::BitXor, 0b1100, 0b1010, 0b0110),
            (CombineOp::Min, 7, 5, 5),
            (CombineOp::Max, 7, 5, 7),
        ];

        for (op, left, right, expected) in cases {
            assert_eq!(combine(op, left, right), Ok(expected), "{op:?}");
        }
    }

    /// Arithmetic combiners must wrap instead of panicking on overflow.
    #[test]
    fn combine_arithmetic_wraps_on_overflow() {
        assert_eq!(combine(CombineOp::Add, u32::MAX, 1), Ok(0));
        assert_eq!(combine(CombineOp::Mul, u32::MAX, 2), Ok(u32::MAX - 1));
    }

    #[test]
    fn scalar_helpers_preserve_contract_checks() {
        let left = Memory::from_bytes(7u32.to_le_bytes().to_vec());
        let right = Memory::from_bytes(5u32.to_le_bytes().to_vec());
        let malformed = Memory::from_bytes(vec![1, 2, 3]);

        assert_eq!(
            binary_u32_scalar(
                &[left.clone(), right.clone()],
                "test_add",
                u32::wrapping_add
            )
            .expect("Fix: valid binary scalar inputs must evaluate")
            .bytes(),
            12u32.to_le_bytes().to_vec()
        );
        assert_eq!(
            binary_u32_predicate(&[left.clone(), right], "test_gt", |a, b| a > b)
                .expect("Fix: valid binary predicate inputs must evaluate")
                .bytes(),
            1u32.to_le_bytes().to_vec()
        );
        assert_eq!(
            unary_u32_scalar(std::slice::from_ref(&left), "test_not", |value| !value)
                .expect("Fix: valid unary scalar input must evaluate")
                .bytes(),
            (!7u32).to_le_bytes().to_vec()
        );

        assert!(
            binary_u32_scalar(std::slice::from_ref(&left), "test_add", u32::wrapping_add).is_err()
        );
        assert!(unary_u32_scalar(&[malformed], "test_not", |value| !value).is_err());
    }

    /// `one_input` / `two_inputs` accept only their exact arity.
    #[test]
    fn input_helpers_enforce_arity() {
        let word = Memory::from_bytes(1u32.to_le_bytes().to_vec());

        assert!(one_input(&[], "test_one").is_err());
        assert!(one_input(&[word.clone(), word.clone()], "test_one").is_err());
        assert_eq!(
            one_input(std::slice::from_ref(&word), "test_one")
                .expect("Fix: a single input must be accepted"),
            1u32.to_le_bytes().to_vec()
        );

        assert!(two_inputs(&[], "test_two").is_err());
        assert!(two_inputs(std::slice::from_ref(&word), "test_two").is_err());
        assert_eq!(
            two_inputs(&[word.clone(), word], "test_two")
                .expect("Fix: two inputs must be accepted"),
            (1u32.to_le_bytes().to_vec(), 1u32.to_le_bytes().to_vec())
        );
    }

    /// Scalar and vector decoders enforce their byte-length contracts.
    #[test]
    fn word_decoders_check_lengths() {
        assert_eq!(read_u32([7u8, 0, 0, 0], "test_read"), Ok(7));
        assert!(read_u32([7u8, 0, 0], "test_read").is_err());
        assert!(read_u32([0u8; 8], "test_read").is_err());

        assert_eq!(
            u32_words([1u8, 0, 0, 0, 2, 0, 0, 0], "test_words"),
            Ok(vec![1, 2])
        );
        assert_eq!(u32_words(Vec::<u8>::new(), "test_words"), Ok(Vec::new()));
        assert!(u32_words([1u8, 2, 3, 4, 5], "test_words").is_err());
    }

    #[test]
    fn write_u32s_encodes_little_endian() {
        assert_eq!(
            write_u32s([1, 0x0102_0304]).bytes(),
            vec![1, 0, 0, 0, 4, 3, 2, 1]
        );
        assert!(write_u32s(Vec::new()).bytes().is_empty());
    }

    /// `checked_index` accepts `0..len` and rejects everything at or past `len`.
    #[test]
    fn checked_index_enforces_bounds() {
        assert_eq!(checked_index(0, 3, "test_index"), Ok(0));
        assert_eq!(checked_index(2, 3, "test_index"), Ok(2));
        assert!(checked_index(0, 0, "test_index").is_err());

        let err = checked_index(3, 3, "test_index")
            .expect_err("Fix: an index equal to the length must be rejected");
        assert!(err.to_string().contains("Fix:"));
    }
}