// vyre_spec/data_type.rs

//! Frozen IR data-type tags shared by signatures, validators, and wire metadata.

use core::fmt;

5/// Canonical data types supported by the vyre IR frozen data contract.
6///
7/// Integer-first by design. GPU floating-point is nondeterministic across
8/// vendors through different rounding, fused multiply-add, and subnormal
9/// handling. Integer arithmetic is deterministic everywhere. F32 is supported
10/// for primitives that require it, with conformance validated per-backend.
11/// `vyre::ir::DataType` re-exports this same type; conformance metadata should
12/// use this canonical contract path. Example: `DataType::Vec4U32` records a
13/// four-word lane value and has a minimum byte width of 16.
14#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Deserialize, serde::Serialize)]
15pub enum DataType {
16    /// Unsigned 32-bit integer. The fundamental GPU word.
17    U32,
18    /// Signed 32-bit integer.
19    I32,
20    /// Unsigned 64-bit integer, emulated as `vec2<u32>` with low and high words.
21    U64,
22    /// Two-component `u32` vector.
23    Vec2U32,
24    /// Four-component `u32` vector.
25    Vec4U32,
26    /// Boolean value stored as a GPU word.
27    Bool,
28    /// Variable-length byte buffer.
29    Bytes,
30    /// Fixed-element-size array.
31    ///
32    /// Each element is `element_size` bytes. The total byte count is
33    /// `N * element_size` where N is encoded by the value.
34    Array {
35        /// Byte size of each element.
36        element_size: usize,
37    },
38    /// Strict IEEE 754 binary16 floating-point.
39    F16,
40    /// Strict bfloat16 floating-point.
41    BF16,
42    /// IEEE 754 binary32 floating-point.
43    F32,
44    /// Strict IEEE 754 binary64 floating-point.
45    F64,
46    /// Multi-dimensional tensor value.
47    Tensor,
48}
49
50impl DataType {
51    /// Minimum byte count to represent one value of this type.
52    #[must_use]
53    pub const fn min_bytes(&self) -> usize {
54        match self {
55            Self::Bool | Self::U32 | Self::I32 | Self::F32 => 4,
56            Self::U64 | Self::Vec2U32 => 8,
57            Self::Vec4U32 => 16,
58            Self::F16 | Self::BF16 => 2,
59            Self::F64 => 8,
60            Self::Bytes | Self::Array { .. } | Self::Tensor => 0,
61        }
62    }
63
64    /// Maximum byte count for one value of this type.
65    ///
66    /// Returns `None` for truly unbounded types; currently all variants
67    /// have a hard ceiling. Fixed-width types return `Some(min_bytes())`.
68    #[must_use]
69    pub const fn max_bytes(&self) -> Option<usize> {
70        match self {
71            Self::U32 | Self::I32 | Self::Bool => Some(4),
72            Self::U64 | Self::Vec2U32 => Some(8),
73            Self::Vec4U32 => Some(16),
74            Self::F16 | Self::BF16 => Some(2),
75            Self::F32 => Some(4),
76            Self::F64 => Some(8),
77            Self::Bytes => Some(64 * 1024 * 1024),
78            Self::Array { .. } | Self::Tensor => Some(256 * 1024 * 1024),
79        }
80    }
81
82    /// Element size for array-typed outputs, or `None` for scalar types.
83    #[must_use]
84    pub const fn element_size(&self) -> Option<usize> {
85        match self {
86            Self::Array { element_size } => Some(*element_size),
87            _ => None,
88        }
89    }
90
91    /// Whether this type belongs to the strict floating-point conformance family.
92    #[must_use]
93    pub const fn is_float_family(&self) -> bool {
94        matches!(self, Self::F16 | Self::BF16 | Self::F32 | Self::F64)
95    }
96}
97
98impl fmt::Display for DataType {
99    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
100        match self {
101            Self::U32 => f.write_str("u32"),
102            Self::I32 => f.write_str("i32"),
103            Self::U64 => f.write_str("u64"),
104            Self::Vec2U32 => f.write_str("vec2<u32>"),
105            Self::Vec4U32 => f.write_str("vec4<u32>"),
106            Self::Bool => f.write_str("bool"),
107            Self::Bytes => f.write_str("bytes"),
108            Self::Array { element_size } => write!(f, "array<{element_size}B>"),
109            Self::F16 => f.write_str("f16"),
110            Self::BF16 => f.write_str("bf16"),
111            Self::F32 => f.write_str("f32"),
112            Self::F64 => f.write_str("f64"),
113            Self::Tensor => f.write_str("tensor"),
114        }
115    }
116}