Skip to main content

vyre_reference/
value.rs

1//! Runtime values accepted and returned by the core reference interpreter.
2
3use std::sync::Arc;
4
/// A runtime value exchanged with the reference interpreter, used both for
/// program inputs and for program results.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum Value {
    /// A 32-bit unsigned integer word.
    U32(u32),
    /// A 32-bit two's-complement signed integer.
    I32(i32),
    /// A 64-bit unsigned integer.
    U64(u64),
    /// A boolean flag.
    Bool(bool),
    /// Raw storage bytes, read as little-endian wherever a numeric view is needed.
    Bytes(Arc<[u8]>),
    /// A floating-point number held as an `f64`; equality compares its raw bits.
    Float(f64),
    /// An ordered array of element values.
    Array(Vec<Value>),
}
24
25impl PartialEq for Value {
26    fn eq(&self, other: &Self) -> bool {
27        match (self, other) {
28            (Self::U32(a), Self::U32(b)) => a == b,
29            (Self::I32(a), Self::I32(b)) => a == b,
30            (Self::U64(a), Self::U64(b)) => a == b,
31            (Self::Bool(a), Self::Bool(b)) => a == b,
32            (Self::Bytes(a), Self::Bytes(b)) => a == b,
33            (Self::Float(a), Self::Float(b)) => a.to_bits() == b.to_bits(),
34            (Self::Array(a), Self::Array(b)) => a == b,
35            _ => false,
36        }
37    }
38}
39
40impl Eq for Value {}
41
42impl Value {
43    /// Interpret the value using the IR truth convention.
44    #[must_use]
45    pub fn truthy(&self) -> bool {
46        match self {
47            Self::Array(values) => !values.is_empty(),
48            Self::Float(value) => *value != 0.0,
49            _ => self.try_as_u32().unwrap_or(1) != 0,
50        }
51    }
52
53    /// Return this value as little-endian bytes for buffer initialization.
54    #[must_use]
55    pub fn to_bytes(&self) -> Vec<u8> {
56        match self {
57            Self::U32(value) => value.to_le_bytes().to_vec(),
58            Self::I32(value) => value.to_le_bytes().to_vec(),
59            Self::U64(value) => value.to_le_bytes().to_vec(),
60            Self::Bool(value) => u32::from(*value).to_le_bytes().to_vec(),
61            Self::Bytes(bytes) => bytes.to_vec(),
62            Self::Float(value) => value.to_le_bytes().to_vec(),
63            Self::Array(values) => values.iter().flat_map(Self::to_bytes).collect(),
64        }
65    }
66
67    /// Return this value encoded at the declared input width.
68    #[must_use]
69    pub fn to_bytes_width(&self, declared_width: usize) -> Vec<u8> {
70        let mut bytes = self.to_bytes();
71        if declared_width == 0 {
72            return bytes;
73        }
74        bytes.resize(declared_width, 0);
75        bytes.truncate(declared_width);
76        bytes
77    }
78
79    /// Write this value into an existing fixed-width byte slot.
80    ///
81    /// For non-empty targets, this is equivalent to
82    /// `target.copy_from_slice(&self.to_bytes_width(target.len()))` but avoids
83    /// allocating a temporary vector on store-heavy reference paths. Empty
84    /// targets are a no-op because they cannot carry the variable-width
85    /// `to_bytes_width(0)` payload.
86    pub fn write_bytes_width_into(&self, target: &mut [u8]) {
87        target.fill(0);
88        let mut cursor = 0usize;
89        self.copy_raw_bytes_prefix(target, &mut cursor);
90    }
91
92    /// Append this value encoded at the declared input width without
93    /// allocating a temporary byte vector for the caller.
94    ///
95    /// # Errors
96    ///
97    /// Returns an error if the destination length would overflow.
98    pub fn extend_bytes_width(
99        &self,
100        declared_width: usize,
101        out: &mut Vec<u8>,
102    ) -> Result<(), vyre::Error> {
103        let start_len = out.len();
104        let fixed_next_len = if declared_width == 0 {
105            None
106        } else {
107            Some(start_len.checked_add(declared_width).ok_or_else(|| {
108                vyre::Error::interp(
109                    "encoded value byte size overflows usize. Fix: reduce the argument count or byte payload size.",
110                )
111            })?)
112        };
113        match self {
114            Self::U32(value) => extend_fixed_width(&value.to_le_bytes(), declared_width, out),
115            Self::I32(value) => extend_fixed_width(&value.to_le_bytes(), declared_width, out),
116            Self::U64(value) => extend_fixed_width(&value.to_le_bytes(), declared_width, out),
117            Self::Bool(value) => {
118                extend_fixed_width(&u32::from(*value).to_le_bytes(), declared_width, out);
119            }
120            Self::Bytes(bytes) => extend_fixed_width(bytes, declared_width, out),
121            Self::Float(value) => extend_fixed_width(&value.to_le_bytes(), declared_width, out),
122            Self::Array(values) => {
123                for value in values {
124                    value.extend_bytes_width(0, out)?;
125                }
126                if let Some(next_len) = fixed_next_len {
127                    out.truncate(start_len + declared_width.min(out.len() - start_len));
128                    out.resize(next_len, 0);
129                }
130            }
131        }
132        if let Some(next_len) = fixed_next_len {
133            debug_assert_eq!(out.len(), next_len);
134        }
135        Ok(())
136    }
137
138    fn copy_raw_bytes_prefix(&self, target: &mut [u8], cursor: &mut usize) {
139        match self {
140            Self::U32(value) => copy_bytes_prefix(&value.to_le_bytes(), target, cursor),
141            Self::I32(value) => copy_bytes_prefix(&value.to_le_bytes(), target, cursor),
142            Self::U64(value) => copy_bytes_prefix(&value.to_le_bytes(), target, cursor),
143            Self::Bool(value) => {
144                copy_bytes_prefix(&u32::from(*value).to_le_bytes(), target, cursor);
145            }
146            Self::Bytes(bytes) => copy_bytes_prefix(bytes, target, cursor),
147            Self::Float(value) => copy_bytes_prefix(&value.to_le_bytes(), target, cursor),
148            Self::Array(values) => {
149                for value in values {
150                    if *cursor >= target.len() {
151                        break;
152                    }
153                    value.copy_raw_bytes_prefix(target, cursor);
154                }
155            }
156        }
157    }
158
159    /// Try to interpret the value as the IR's scalar `u32` word.
160    #[must_use]
161    pub fn try_as_u32(&self) -> Option<u32> {
162        match self {
163            Self::U32(value) => Some(*value),
164            Self::I32(value) => u32::try_from(*value).ok(),
165            Self::U64(value) => u32::try_from(*value).ok(),
166            Self::Bool(value) => Some(u32::from(*value)),
167            Self::Bytes(bytes) => (bytes.len() <= 4).then(|| read_u32_prefix(bytes)),
168            Self::Float(value) => f64_to_u32(*value),
169            Self::Array(_) => None,
170        }
171    }
172
173    /// Interpret the value as the IR's scalar `u32` word.
174    #[must_use]
175    pub fn as_u32(&self) -> u32 {
176        self.try_as_u32().unwrap_or(0)
177    }
178
179    /// Try to interpret the value as a full `u64`.
180    #[must_use]
181    pub fn try_as_u64(&self) -> Option<u64> {
182        match self {
183            Self::U32(value) => Some(u64::from(*value)),
184            Self::I32(value) => u64::try_from(*value).ok(),
185            Self::U64(value) => Some(*value),
186            Self::Bool(value) => Some(u64::from(*value)),
187            Self::Bytes(bytes) => (bytes.len() <= 8).then(|| read_u64_prefix(bytes)),
188            Self::Float(value) => f64_to_u64(*value),
189            Self::Array(_) => None,
190        }
191    }
192
193    /// Interpret the value as a full `u64`.
194    #[must_use]
195    pub fn as_u64(&self) -> u64 {
196        self.try_as_u64().unwrap_or(0)
197    }
198
199    /// Try to interpret the value as an `f32`.
200    #[must_use]
201    pub fn try_as_f32(&self) -> Option<f32> {
202        match self {
203            Self::Float(value) => Some(*value as f32),
204            Self::U32(value) => Some(f32::from_bits(*value)),
205            _ => None,
206        }
207    }
208
209    /// Return the full value payload as little-endian bytes.
210    #[must_use]
211    pub fn wide_bytes(&self) -> Vec<u8> {
212        self.to_bytes()
213    }
214
215    /// Create a zero value for the given data type.
216    #[must_use]
217    pub fn zero_for(ty: vyre::ir::DataType) -> Self {
218        Self::try_zero_for(ty).unwrap_or_else(|| Self::Bytes(Arc::from([])))
219    }
220
221    /// Try to create a zero value for the given data type.
222    #[must_use]
223    pub fn try_zero_for(ty: vyre::ir::DataType) -> Option<Self> {
224        match ty {
225            vyre::ir::DataType::U32 => Some(Self::U32(0)),
226            vyre::ir::DataType::I32 => Some(Self::I32(0)),
227            vyre::ir::DataType::U64 => Some(Self::U64(0)),
228            vyre::ir::DataType::Bool => Some(Self::Bool(false)),
229            vyre::ir::DataType::Bytes => Some(Self::Bytes(Arc::from([]))),
230            vyre::ir::DataType::F32 => Some(Self::Float(0.0)),
231            vyre::ir::DataType::F64 => Some(Self::Float(0.0)),
232            vyre::ir::DataType::Vec2U32 => Some(Self::Bytes(Arc::from(vec![0; 8]))),
233            vyre::ir::DataType::Vec4U32 => Some(Self::Bytes(Arc::from(vec![0; 16]))),
234            _ => {
235                fixed_scalar_storage_width(&ty).map(|width| Self::Bytes(Arc::from(vec![0; width])))
236            }
237        }
238    }
239
240    /// Create a value from element bytes for the given data type.
241    ///
242    /// # Errors
243    ///
244    /// Returns an error if the byte slice is too short for the declared type.
245    pub fn from_element_bytes(ty: vyre::ir::DataType, bytes: &[u8]) -> Result<Self, String> {
246        match ty {
247            vyre::ir::DataType::U32 => {
248                if bytes.len() < 4 {
249                    return Err("u32 requires 4 bytes".to_string());
250                }
251                Ok(Self::U32(u32::from_le_bytes([
252                    bytes[0], bytes[1], bytes[2], bytes[3],
253                ])))
254            }
255            vyre::ir::DataType::I32 => {
256                if bytes.len() < 4 {
257                    return Err("i32 requires 4 bytes".to_string());
258                }
259                Ok(Self::I32(i32::from_le_bytes([
260                    bytes[0], bytes[1], bytes[2], bytes[3],
261                ])))
262            }
263            vyre::ir::DataType::U64 => {
264                if bytes.len() < 8 {
265                    return Err("u64 requires 8 bytes".to_string());
266                }
267                Ok(Self::U64(u64::from_le_bytes([
268                    bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
269                ])))
270            }
271            vyre::ir::DataType::Bool => {
272                if bytes.len() < 4 {
273                    return Err("bool requires 4 bytes".to_string());
274                }
275                Ok(Self::Bool(
276                    u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) != 0,
277                ))
278            }
279            vyre::ir::DataType::Vec2U32 => {
280                if bytes.len() < 8 {
281                    return Err("vec2u32 requires 8 bytes".to_string());
282                }
283                Ok(Self::Bytes(Arc::from(&bytes[..8])))
284            }
285            vyre::ir::DataType::Vec4U32 => {
286                if bytes.len() < 16 {
287                    return Err("vec4u32 requires 16 bytes".to_string());
288                }
289                Ok(Self::Bytes(Arc::from(&bytes[..16])))
290            }
291            vyre::ir::DataType::F32 => {
292                if bytes.len() < 4 {
293                    return Err("f32 requires 4 bytes".to_string());
294                }
295                let value = f32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
296                Ok(Self::Float(f64::from(
297                    crate::execution::typed_ops::canonical_f32(value),
298                )))
299            }
300            vyre::ir::DataType::F64 => {
301                if bytes.len() < 8 {
302                    return Err("f64 requires 8 bytes".to_string());
303                }
304                Ok(Self::Float(f64::from_le_bytes([
305                    bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
306                ])))
307            }
308            vyre::ir::DataType::Bytes => Ok(Self::Bytes(Arc::from(bytes))),
309            _ => match fixed_scalar_storage_width(&ty) {
310                Some(width) => {
311                    if bytes.len() < width {
312                        return Err(format!("{ty} requires {width} bytes"));
313                    }
314                    Ok(Self::Bytes(Arc::from(&bytes[..width])))
315                }
316                None => Ok(Self::Bytes(Arc::from(bytes))),
317            },
318        }
319    }
320}
321
322fn fixed_scalar_storage_width(ty: &vyre::ir::DataType) -> Option<usize> {
323    match ty {
324        vyre::ir::DataType::U8
325        | vyre::ir::DataType::I8
326        | vyre::ir::DataType::F8E4M3
327        | vyre::ir::DataType::F8E5M2
328        | vyre::ir::DataType::I4
329        | vyre::ir::DataType::FP4
330        | vyre::ir::DataType::NF4 => Some(1),
331        vyre::ir::DataType::U16
332        | vyre::ir::DataType::I16
333        | vyre::ir::DataType::F16
334        | vyre::ir::DataType::BF16 => Some(2),
335        vyre::ir::DataType::Handle(_) | vyre::ir::DataType::DeviceMesh { .. } => Some(4),
336        vyre::ir::DataType::I64 => Some(8),
337        vyre::ir::DataType::Array { element_size } => Some(*element_size),
338        vyre::ir::DataType::Vec { element, count } => fixed_scalar_storage_width(element)
339            .and_then(|width| width.checked_mul(usize::from(*count))),
340        vyre::ir::DataType::TensorShaped { element, shape } => {
341            let element_width = fixed_scalar_storage_width(element)?;
342            shape
343                .iter()
344                .try_fold(element_width, |width, &dim| width.checked_mul(dim as usize))
345        }
346        vyre::ir::DataType::Quantized { storage, .. } => fixed_scalar_storage_width(storage),
347        _ => None,
348    }
349}
350
/// Append `bytes` to `out`, fitted to `declared_width`.
///
/// A `declared_width` of `0` appends every byte unchanged. Any other width
/// appends exactly `declared_width` bytes: the payload is cut short when too
/// long and zero-padded when too short.
fn extend_fixed_width(bytes: &[u8], declared_width: usize, out: &mut Vec<u8>) {
    if declared_width != 0 {
        let final_len = out.len() + declared_width;
        // `get` fails only when the payload is shorter than the slot; then
        // the whole payload is kept and `resize` pads the remainder.
        let kept = bytes.get(..declared_width).unwrap_or(bytes);
        out.extend_from_slice(kept);
        out.resize(final_len, 0);
    } else {
        out.extend_from_slice(bytes);
    }
}
360
/// Copy as much of `bytes` as fits into `target` starting at `*cursor`, then
/// advance `*cursor` by the number of bytes actually copied.
///
/// A cursor at or past the end of `target` copies nothing.
fn copy_bytes_prefix(bytes: &[u8], target: &mut [u8], cursor: &mut usize) {
    if let Some(remaining) = target.get_mut(*cursor..) {
        let count = remaining.len().min(bytes.len());
        remaining[..count].copy_from_slice(&bytes[..count]);
        *cursor += count;
    }
}
369
/// Convert a float to `u32` when it lies in `0.0..=u32::MAX`, truncating any
/// fractional part toward zero; `NaN` and infinities yield `None`.
fn f64_to_u32(value: f64) -> Option<u32> {
    // `RangeInclusive::contains` is false for `NaN` and for both infinities,
    // so this single check also enforces finiteness. `u32::MAX` is exactly
    // representable in `f64`.
    let representable = 0.0..=f64::from(u32::MAX);
    if representable.contains(&value) {
        Some(value as u32)
    } else {
        None
    }
}
373
/// Convert a float to `u64` when it lies in `0.0..2^64`, truncating any
/// fractional part toward zero; `NaN` and infinities yield `None`.
fn f64_to_u64(value: f64) -> Option<u64> {
    // 2^64 is exactly representable in `f64` and is the first value past the
    // `u64` range, so the bound is exclusive.
    const TWO_POW_64: f64 = 18_446_744_073_709_551_616.0;
    (0.0..TWO_POW_64).contains(&value).then(|| value as u64)
}
378
379impl From<Vec<u8>> for Value {
380    fn from(bytes: Vec<u8>) -> Self {
381        Self::Bytes(Arc::from(bytes))
382    }
383}
384
385impl From<&[u8]> for Value {
386    fn from(bytes: &[u8]) -> Self {
387        Self::Bytes(Arc::from(bytes))
388    }
389}
390
/// Read up to the first four bytes of `bytes` as a little-endian `u32`,
/// treating missing high bytes as zero.
fn read_u32_prefix(bytes: &[u8]) -> u32 {
    let mut word = [0u8; 4];
    // `zip` stops at the shorter side, so extra input bytes are ignored and
    // short inputs leave the remaining slots zero.
    for (slot, &byte) in word.iter_mut().zip(bytes) {
        *slot = byte;
    }
    u32::from_le_bytes(word)
}
397
/// Read up to the first eight bytes of `bytes` as a little-endian `u64`,
/// treating missing high bytes as zero.
fn read_u64_prefix(bytes: &[u8]) -> u64 {
    bytes
        .iter()
        .take(8)
        .enumerate()
        .fold(0u64, |word, (index, &byte)| word | (u64::from(byte) << (8 * index)))
}
404
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Decode `word` as a little-endian `F32` element and return the `f32`
    /// bit pattern the decoded value reports through `try_as_f32`.
    fn decoded_f32_bits(word: u32, decode_failure: &str) -> u32 {
        let decoded = Value::from_element_bytes(vyre::ir::DataType::F32, &word.to_le_bytes())
            .expect(decode_failure);
        decoded.try_as_f32().unwrap().to_bits()
    }

    #[test]
    fn neg_zero_truthiness_is_false() {
        let negative_zero = Value::Float(-0.0);
        assert!(!negative_zero.truthy());
    }

    #[test]
    fn pos_zero_truthiness_is_false() {
        let positive_zero = Value::Float(0.0);
        assert!(!positive_zero.truthy());
    }

    #[test]
    fn nonzero_float_truthiness_is_true() {
        let is_truthy = |sample: f64| Value::Float(sample).truthy();
        assert!(is_truthy(1.0));
        assert!(is_truthy(-1.0));
        assert!(is_truthy(f64::INFINITY));
        assert!(is_truthy(f64::NEG_INFINITY));
    }

    #[test]
    fn f32_element_decode_canonicalizes_subnormal_and_nan_payload_bits() {
        // Smallest positive subnormal flushes to +0.0.
        assert_eq!(
            decoded_f32_bits(
                1,
                "Fix: replace expect with fallible API or document caller precondition; panic only on programmer error - f32 positive subnormal decode must succeed",
            ),
            0x0000_0000
        );

        // Smallest negative subnormal flushes to -0.0 (sign bit kept).
        assert_eq!(
            decoded_f32_bits(
                0x8000_0001,
                "Fix: replace expect with fallible API or document caller precondition; panic only on programmer error - f32 negative subnormal decode must succeed",
            ),
            0x8000_0000
        );

        // A signaling NaN carrying a payload collapses to the quiet NaN.
        assert_eq!(
            decoded_f32_bits(
                0x7fa0_0001,
                "Fix: replace expect with fallible API or document caller precondition; panic only on programmer error - f32 payload NaN decode must succeed",
            ),
            0x7fc0_0000
        );
    }

    proptest! {
        #[test]
        fn neg_zero_select_branches_to_false(
            use_positive_zero in proptest::bool::ANY,
        ) {
            let zero = 0.0_f64.copysign(if use_positive_zero { 1.0 } else { -1.0 });
            prop_assert!(
                !Value::Float(zero).truthy(),
                "Value::Float({zero}).truthy() must be false to match backend bool(0.0)/bool(-0.0) semantics"
            );
        }
    }
}