Skip to main content

vyre_reference/dual_impls/
common.rs

//! Shared byte helpers for canonical primitive reference evaluators.
2
3use std::{error::Error, fmt};
4
5use crate::workgroup::Memory;
6use vyre_primitives::CombineOp;
7
/// Error returned by canonical primitive reference evaluation.
///
/// By convention every message is actionable: it describes what went wrong
/// and carries a `Fix:` hint telling the caller how to correct the input.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EvalError {
    message: String,
}

impl EvalError {
    /// Build an actionable evaluation error.
    ///
    /// # Panics
    ///
    /// In builds with debug assertions enabled, panics when `message` does not
    /// contain a `Fix:` hint. The panic message echoes the offending text so the
    /// call site that violated the convention can be found quickly.
    #[must_use]
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        debug_assert!(
            message.contains("Fix:"),
            "EvalError message must contain a `Fix:` hint, got: {message:?}"
        );
        Self { message }
    }
}

impl fmt::Display for EvalError {
    /// Writes the stored message verbatim.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.message)
    }
}

// No underlying cause is tracked, so the default `source()` (None) is kept.
impl Error for EvalError {}
31
32/// CPU reference evaluator for one canonical primitive.
33pub trait ReferenceEvaluator {
34    /// Evaluate the primitive over byte-backed memory payloads.
35    ///
36    /// # Errors
37    ///
38    /// Returns [`EvalError`] when the input arity or payload format violates
39    /// the primitive contract.
40    fn evaluate(&self, inputs: &[Memory]) -> Result<Memory, EvalError>;
41}
42
43pub(crate) fn one_input(inputs: &[Memory], id: &str) -> Result<Vec<u8>, EvalError> {
44    if inputs.len() != 1 {
45        return Err(EvalError::new(format!(
46            "primitive `{id}` expected 1 input memory, got {}. Fix: pass exactly one byte payload.",
47            inputs.len()
48        )));
49    }
50    Ok(inputs[0].bytes())
51}
52
53pub(crate) fn two_inputs(inputs: &[Memory], id: &str) -> Result<(Vec<u8>, Vec<u8>), EvalError> {
54    if inputs.len() != 2 {
55        return Err(EvalError::new(format!(
56            "primitive `{id}` expected 2 input memories, got {}. Fix: pass left and right byte payloads.",
57            inputs.len()
58        )));
59    }
60    Ok((inputs[0].bytes(), inputs[1].bytes()))
61}
62
63pub(crate) fn read_u32(bytes: impl AsRef<[u8]>, id: &str) -> Result<u32, EvalError> {
64    let bytes = bytes.as_ref();
65    if bytes.len() != 4 {
66        return Err(EvalError::new(format!(
67            "primitive `{id}` expected a 4-byte u32 payload, got {} bytes. Fix: encode scalar inputs as little-endian u32.",
68            bytes.len()
69        )));
70    }
71    Ok(u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]))
72}
73
74pub(crate) fn u32_words(bytes: impl AsRef<[u8]>, id: &str) -> Result<Vec<u32>, EvalError> {
75    let bytes = bytes.as_ref();
76    if bytes.len() % 4 != 0 {
77        return Err(EvalError::new(format!(
78            "primitive `{id}` expected u32-aligned bytes, got {} bytes. Fix: encode every element as little-endian u32.",
79            bytes.len()
80        )));
81    }
82    Ok(bytes
83        .chunks_exact(4)
84        .map(|chunk| u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
85        .collect())
86}
87
88pub(crate) fn write_u32s(values: impl IntoIterator<Item = u32>) -> Memory {
89    let mut bytes = Vec::new();
90    for value in values {
91        bytes.extend_from_slice(&value.to_le_bytes());
92    }
93    Memory::from_bytes(bytes)
94}
95
96pub(crate) fn scalar(value: u32) -> Memory {
97    Memory::from_bytes(value.to_le_bytes().to_vec())
98}
99
100pub(crate) fn combine(op: CombineOp, left: u32, right: u32) -> Result<u32, EvalError> {
101    Ok(match op {
102        CombineOp::Add => left.wrapping_add(right),
103        CombineOp::Mul => left.wrapping_mul(right),
104        CombineOp::BitAnd => left & right,
105        CombineOp::BitOr => left | right,
106        CombineOp::BitXor => left ^ right,
107        CombineOp::Min => left.min(right),
108        CombineOp::Max => left.max(right),
109        _ => {
110            return Err(EvalError::new(format!(
111                "primitive combiner does not support CombineOp variant {op:?}. Fix: register a reference evaluator for the new combiner before dispatch."
112            )));
113        }
114    })
115}
116
117pub(crate) fn checked_index(index: u32, len: usize, id: &str) -> Result<usize, EvalError> {
118    let index = usize::try_from(index).map_err(|_| {
119        EvalError::new(format!(
120            "primitive `{id}` index does not fit usize. Fix: keep index regions within platform addressable bounds."
121        ))
122    })?;
123    if index >= len {
124        Err(EvalError::new(format!(
125            "primitive `{id}` index {index} is outside source length {len}. Fix: validate index regions before dispatch."
126        )))
127    } else {
128        Ok(index)
129    }
130}
131
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `CombineOp` variant with a reference implementation must produce
    /// the expected value instead of falling through to the error arm.
    #[test]
    fn combine_known_variants_do_not_panic() {
        let cases = [
            (CombineOp::Add, 7, 5, 12),
            (CombineOp::Mul, 7, 5, 35),
            (CombineOp::BitAnd, 0b1100, 0b1010, 0b1000),
            (CombineOp::BitOr, 0b1100, 0b1010, 0b1110),
            (CombineOp::BitXor, 0b1100, 0b1010, 0b0110),
            (CombineOp::Min, 7, 5, 5),
            (CombineOp::Max, 7, 5, 7),
        ];

        for (op, left, right, expected) in cases {
            assert_eq!(combine(op, left, right), Ok(expected), "{op:?}");
        }
    }

    /// `Add` and `Mul` must wrap modulo 2^32 instead of panicking on overflow,
    /// which pins the `wrapping_*` semantics of the reference combiner.
    #[test]
    fn combine_arithmetic_wraps_on_overflow() {
        assert_eq!(combine(CombineOp::Add, u32::MAX, 1), Ok(0));
        assert_eq!(combine(CombineOp::Mul, u32::MAX, 2), Ok(u32::MAX - 1));
        assert_eq!(combine(CombineOp::Mul, 0x8000_0000, 2), Ok(0));
    }

    /// `Min` and `Max` must not depend on operand order.
    #[test]
    fn combine_min_max_are_order_independent() {
        let pairs: [(u32, u32); 3] = [(7, 5), (5, 7), (0, u32::MAX)];
        for (left, right) in pairs {
            assert_eq!(combine(CombineOp::Min, left, right), Ok(left.min(right)));
            assert_eq!(combine(CombineOp::Max, left, right), Ok(left.max(right)));
        }
    }
}