Skip to main content

vyre_reference/dual_impls/
common.rs

1//! Shared byte helpers for canonical primitive evaluators.
2
3use std::{error::Error, fmt};
4
5use crate::workgroup::Memory;
6use vyre_primitives::CombineOp;
7
/// Error returned by canonical primitive reference evaluation.
///
/// Carries a single human-readable message. Values are built through
/// [`EvalError::new`], which debug-asserts that the message contains a `Fix:`
/// hint. Two errors compare equal when their messages are equal.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EvalError {
    /// Full diagnostic text; written verbatim by the `Display` implementation.
    message: String,
}
13
14impl EvalError {
15    /// Build an actionable evaluation error.
16    #[must_use]
17    pub fn new(message: impl Into<String>) -> Self {
18        let message = message.into();
19        debug_assert!(message.contains("Fix:"));
20        Self { message }
21    }
22}
23
24impl fmt::Display for EvalError {
25    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
26        f.write_str(&self.message)
27    }
28}
29
// Marker implementation: only the default `Error` methods are used, so no
// underlying source error is reported.
impl Error for EvalError {}
31
/// CPU reference evaluator for one canonical primitive.
///
/// Implementors take their operands as [`Memory`] payloads and produce a
/// single [`Memory`] result, reporting contract violations as [`EvalError`].
pub trait ReferenceEvaluator {
    /// Evaluate the primitive over byte-backed memory payloads.
    ///
    /// `inputs` holds the operand memories. The trait itself does not fix how
    /// many entries are required; each implementation enforces its own arity.
    ///
    /// # Errors
    ///
    /// Returns [`EvalError`] when the input arity or payload format violates
    /// the primitive contract.
    fn evaluate(&self, inputs: &[Memory]) -> Result<Memory, EvalError>;
}
42
43pub(crate) fn one_input(inputs: &[Memory], id: &str) -> Result<Vec<u8>, EvalError> {
44    if inputs.len() != 1 {
45        return Err(EvalError::new(format!(
46            "primitive `{id}` expected 1 input memory, got {}. Fix: pass exactly one byte payload.",
47            inputs.len()
48        )));
49    }
50    Ok(inputs[0].bytes())
51}
52
53pub(crate) fn two_inputs(inputs: &[Memory], id: &str) -> Result<(Vec<u8>, Vec<u8>), EvalError> {
54    if inputs.len() != 2 {
55        return Err(EvalError::new(format!(
56            "primitive `{id}` expected 2 input memories, got {}. Fix: pass left and right byte payloads.",
57            inputs.len()
58        )));
59    }
60    Ok((inputs[0].bytes(), inputs[1].bytes()))
61}
62
63pub(crate) fn read_u32(bytes: impl AsRef<[u8]>, id: &str) -> Result<u32, EvalError> {
64    let bytes = bytes.as_ref();
65    if bytes.len() != 4 {
66        return Err(EvalError::new(format!(
67            "primitive `{id}` expected a 4-byte u32 payload, got {} bytes. Fix: encode scalar inputs as little-endian u32.",
68            bytes.len()
69        )));
70    }
71    Ok(u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]))
72}
73
74pub(crate) fn u32_words(bytes: impl AsRef<[u8]>, id: &str) -> Result<Vec<u32>, EvalError> {
75    let bytes = bytes.as_ref();
76    if bytes.len() % 4 != 0 {
77        return Err(EvalError::new(format!(
78            "primitive `{id}` expected u32-aligned bytes, got {} bytes. Fix: encode every element as little-endian u32.",
79            bytes.len()
80        )));
81    }
82    Ok(bytes
83        .chunks_exact(4)
84        .map(|chunk| u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
85        .collect())
86}
87
88pub(crate) fn write_u32s(values: impl IntoIterator<Item = u32>) -> Memory {
89    let mut bytes = Vec::new();
90    for value in values {
91        bytes.extend_from_slice(&value.to_le_bytes());
92    }
93    Memory::from_bytes(bytes)
94}
95
96pub(crate) fn scalar(value: u32) -> Memory {
97    Memory::from_bytes(value.to_le_bytes().to_vec())
98}
99
100pub(crate) fn unary_u32_scalar(
101    inputs: &[Memory],
102    id: &str,
103    op: impl FnOnce(u32) -> u32,
104) -> Result<Memory, EvalError> {
105    let input = one_input(inputs, id)?;
106    Ok(scalar(op(read_u32(input, id)?)))
107}
108
109pub(crate) fn binary_u32_scalar(
110    inputs: &[Memory],
111    id: &str,
112    op: impl FnOnce(u32, u32) -> u32,
113) -> Result<Memory, EvalError> {
114    let (left, right) = two_inputs(inputs, id)?;
115    Ok(scalar(op(read_u32(left, id)?, read_u32(right, id)?)))
116}
117
118pub(crate) fn binary_u32_predicate(
119    inputs: &[Memory],
120    id: &str,
121    op: impl FnOnce(u32, u32) -> bool,
122) -> Result<Memory, EvalError> {
123    binary_u32_scalar(inputs, id, |left, right| u32::from(op(left, right)))
124}
125
126pub(crate) fn combine(op: CombineOp, left: u32, right: u32) -> Result<u32, EvalError> {
127    Ok(match op {
128        CombineOp::Add => left.wrapping_add(right),
129        CombineOp::Mul => left.wrapping_mul(right),
130        CombineOp::BitAnd => left & right,
131        CombineOp::BitOr => left | right,
132        CombineOp::BitXor => left ^ right,
133        CombineOp::Min => left.min(right),
134        CombineOp::Max => left.max(right),
135        _ => {
136            return Err(EvalError::new(format!(
137                "primitive combiner does not support CombineOp variant {op:?}. Fix: register a reference evaluator for the new combiner before dispatch."
138            )));
139        }
140    })
141}
142
143pub(crate) fn checked_index(index: u32, len: usize, id: &str) -> Result<usize, EvalError> {
144    let index = usize::try_from(index).map_err(|_| {
145        EvalError::new(format!(
146            "primitive `{id}` index does not fit usize. Fix: keep index regions within platform addressable bounds."
147        ))
148    })?;
149    if index >= len {
150        Err(EvalError::new(format!(
151            "primitive `{id}` index {index} is outside source length {len}. Fix: validate index regions before dispatch."
152        )))
153    } else {
154        Ok(index)
155    }
156}
157
#[cfg(test)]
mod tests {
    use super::*;

    /// Every combiner handled by `combine` yields its expected `u32` result.
    #[test]
    fn combine_known_variants_do_not_panic() {
        let table = [
            (CombineOp::Add, 7, 5, 12),
            (CombineOp::Mul, 7, 5, 35),
            (CombineOp::BitAnd, 0b1100, 0b1010, 0b1000),
            (CombineOp::BitOr, 0b1100, 0b1010, 0b1110),
            (CombineOp::BitXor, 0b1100, 0b1010, 0b0110),
            (CombineOp::Min, 7, 5, 5),
            (CombineOp::Max, 7, 5, 7),
        ];

        for (op, lhs, rhs, want) in table {
            let got = combine(op, lhs, rhs);
            assert_eq!(got, Ok(want), "{op:?}");
        }
    }

    /// Scalar helpers evaluate well-formed inputs and reject wrong arity or
    /// malformed payloads.
    #[test]
    fn scalar_helpers_preserve_contract_checks() {
        let encode = |value: u32| value.to_le_bytes().to_vec();

        let seven = Memory::from_bytes(encode(7));
        let five = Memory::from_bytes(encode(5));
        let truncated = Memory::from_bytes(vec![1, 2, 3]);
        let pair = [seven.clone(), five];
        let single = std::slice::from_ref(&seven);

        // Well-formed inputs evaluate to the expected little-endian payloads.
        let sum = binary_u32_scalar(&pair, "test_add", u32::wrapping_add)
            .expect("Fix: valid binary scalar inputs must evaluate");
        assert_eq!(sum.bytes(), encode(12));

        let greater = binary_u32_predicate(&pair, "test_gt", |a, b| a > b)
            .expect("Fix: valid binary predicate inputs must evaluate");
        assert_eq!(greater.bytes(), encode(1));

        let negated = unary_u32_scalar(single, "test_not", |value| !value)
            .expect("Fix: valid unary scalar input must evaluate");
        assert_eq!(negated.bytes(), encode(!7));

        // One memory passed to a binary helper is an arity violation.
        assert!(binary_u32_scalar(single, "test_add", u32::wrapping_add).is_err());
        // A three-byte payload is not a valid u32 scalar.
        assert!(unary_u32_scalar(&[truncated], "test_not", |value| !value).is_err());
    }
}