Skip to main content

vyre_spec/
data_type.rs

1//! Frozen IR data-type tags shared by signatures, validators, and wire metadata.
2// TAG RESERVATIONS: U32=0x01, I32=0x02, U64=0x03, Vec2U32=0x04,
3// Vec4U32=0x05, Bool=0x06, Bytes=0x07, Array=0x08, F16=0x09,
4// BF16=0x0A, F32=0x0B, F64=0x0C, Tensor=0x0D, U8=0x0E, U16=0x0F,
5// I8=0x10, I16=0x11, I64=0x12, Handle=0x13, Vec=0x14,
6// TensorShaped=0x15, SparseCsr=0x16, SparseCoo=0x17, SparseBsr=0x18,
7// F8E4M3=0x19, F8E5M2=0x1A, I4=0x1B, FP4=0x1C, NF4=0x1D,
8// DeviceMesh=0x1E, 0x1F..=0x7F reserved, Opaque=0x80.
9
10use core::fmt;
11
12use crate::extension::ExtensionDataTypeId;
13
14/// Stable handle type id for backend-owned GPU resources.
15#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Deserialize, serde::Serialize)]
16pub struct TypeId(pub u32);
17
18impl TypeId {
19    /// Return the raw stable handle type id.
20    #[must_use]
21    pub const fn as_u32(self) -> u32 {
22        self.0
23    }
24}
25
/// Canonical data types supported by the vyre IR frozen data contract.
///
/// Integer-first by design. GPU floating-point is nondeterministic across
/// vendors through different rounding, fused multiply-add, and subnormal
/// handling. Integer arithmetic is deterministic everywhere. Floating-point
/// types (`F16`, `BF16`, `F32`, `F64`, and the 8-/4-bit quantized formats)
/// are supported for primitives that require them, with conformance
/// validated per-backend.
///
/// `vyre::ir::DataType` re-exports this same type; conformance metadata should
/// use this canonical contract path. Example: `DataType::Vec4U32` records a
/// four-word lane value and has a minimum byte width of 16.
///
/// The wire tag reserved for each variant is listed in the `TAG RESERVATIONS`
/// comment at the top of this module.
///
/// NOTE(review): if any serde format used with this type encodes enum
/// variants by index, reordering or removing variants changes the encoding.
/// Confirm before touching the variant order.
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Deserialize, serde::Serialize)]
#[non_exhaustive]
pub enum DataType {
    /// Unsigned 8-bit integer.
    U8,
    /// Unsigned 16-bit integer.
    U16,
    /// Unsigned 32-bit integer. The fundamental GPU word.
    U32,
    /// Signed 8-bit integer.
    I8,
    /// Signed 16-bit integer.
    I16,
    /// Signed 32-bit integer.
    I32,
    /// Signed 64-bit integer.
    I64,
    /// Unsigned 64-bit integer, emulated as `vec2<u32>` with low and high words.
    U64,
    /// Two-component `u32` vector (8 bytes).
    Vec2U32,
    /// Four-component `u32` vector (16 bytes).
    Vec4U32,
    /// Boolean value stored as a full 4-byte GPU word.
    Bool,
    /// Variable-length byte buffer.
    Bytes,
    /// Fixed-element-size array.
    ///
    /// Each element is `element_size` bytes. The total byte count is
    /// `N * element_size`, where the element count `N` is carried by the
    /// value itself rather than by the type.
    Array {
        /// Byte size of each element.
        element_size: usize,
    },
    /// Strict IEEE 754 binary16 floating-point.
    F16,
    /// Strict bfloat16 floating-point.
    BF16,
    /// IEEE 754 binary32 floating-point.
    F32,
    /// Strict IEEE 754 binary64 floating-point.
    F64,
    /// Multi-dimensional tensor value with no statically known element type or shape.
    Tensor,
    /// Opaque backend resource handle, identified by its stable [`TypeId`].
    Handle(TypeId),
    /// Generic fixed-lane vector.
    ///
    /// Its byte width is the element width multiplied by `count`.
    Vec {
        /// Lane element type.
        element: Box<Self>,
        /// Lane count.
        count: u8,
    },
    /// Tensor with an explicit element type and shape.
    TensorShaped {
        /// Tensor element type.
        element: Box<Self>,
        /// Tensor dimensions. Up to four dimensions are stored inline;
        /// higher ranks spill to a heap allocation.
        shape: smallvec::SmallVec<[u32; 4]>,
    },
    /// Sparse-CSR tensor: compressed sparse row layout. Element type
    /// lives in the dense values buffer; structure (indptr + `col_idx`)
    /// is laid out separately by the consumer per the documented CSR
    /// contract. Size depends on nnz; conservative sentinel applies.
    ///
    /// Wire encoding: tag `0x16` followed by the element type tag.
    SparseCsr {
        /// Element type of the dense values buffer.
        element: Box<Self>,
    },
    /// Sparse-COO tensor: coordinate-list layout with (row, col, val)
    /// triples. Simpler than CSR but less cache-friendly; lowering
    /// passes typically convert COO → CSR before dispatch.
    ///
    /// Wire encoding: tag `0x17` followed by the element type tag.
    SparseCoo {
        /// Element type of each triple's value.
        element: Box<Self>,
    },
    /// Sparse-BSR tensor: block-sparse rows with fixed block size.
    /// Favored by quantized LLM weight matrices (50%+ sparsity at
    /// block-granularity retains line-rate GEMM).
    ///
    /// Wire encoding: tag `0x18` followed by `block_rows u32`,
    /// `block_cols u32`, then the element type tag.
    SparseBsr {
        /// Element type.
        element: Box<Self>,
        /// Block height in elements.
        block_rows: u32,
        /// Block width in elements.
        block_cols: u32,
    },
    /// 8-bit float (E4M3 format, per FP8 spec) for quantized inference.
    F8E4M3,
    /// 8-bit float (E5M2 format, per FP8 spec) — wider range than E4M3.
    F8E5M2,
    /// 4-bit signed integer for aggressive LLM weight quantization.
    I4,
    /// 4-bit float for LLM-class inference.
    FP4,
    /// 4-bit "normal-float" (per `QLoRA` paper) for LLM weight compression.
    NF4,
    /// Device-mesh handle — topology identifier consumed by
    /// collective ops (`all_reduce`, `all_gather`, `reduce_scatter`,
    /// broadcast). Shape is informational; actual topology is
    /// resolved through the backend's mesh registry.
    DeviceMesh {
        /// Device count along each mesh axis. 1-D = pure ring/tree;
        /// 2-D = torus; higher-D = hypercube. Up to three axes are
        /// stored inline.
        axes: smallvec::SmallVec<[u32; 3]>,
    },
    /// Extension-declared data type.
    ///
    /// The `ExtensionDataTypeId` is stable across process runs and
    /// resolves to a `&'static dyn ExtensionDataType` via
    /// `vyre::dialect::extension::resolve_data_type` (in vyre-core).
    /// Wire encoding of Opaque is `0x80 ++ u32 extension_id` — see
    /// `docs/wire-format.md` §Extensions.
    ///
    /// The builtin const methods on `DataType` (`min_bytes`, `max_bytes`,
    /// `size_bytes`, `is_float_family`) return conservative sentinels for
    /// Opaque because the real values live behind the trait and are not
    /// known at compile time. Consumers that need the actual values
    /// should resolve the trait via the vyre-core registry.
    Opaque(ExtensionDataTypeId),
}
163
164#[allow(clippy::match_same_arms)]
165impl DataType {
166    /// Minimum byte count to represent one value of this type.
167    #[must_use]
168    pub const fn min_bytes(&self) -> usize {
169        match self {
170            Self::U16 | Self::I16 | Self::F16 | Self::BF16 => 2,
171            Self::Bool | Self::U32 | Self::I32 | Self::F32 | Self::Handle(_) => 4,
172            Self::I64 | Self::U64 | Self::Vec2U32 | Self::F64 => 8,
173            Self::Vec4U32 => 16,
174            Self::Vec { element, count } => element.min_bytes() * (*count as usize),
175            Self::Bytes | Self::Array { .. } | Self::Tensor | Self::TensorShaped { .. } => 0,
176            // Quantized / compressed scalar families. F8/F4 = 1 byte rounded up;
177            // I4 / NF4 = 1 byte rounded up (two values share a byte in practice,
178            // but the conservative minimum is one byte per logical value).
179            Self::U8
180            | Self::I8
181            | Self::F8E4M3
182            | Self::F8E5M2
183            | Self::I4
184            | Self::FP4
185            | Self::NF4 => 1,
186            // Sparse layouts + device-mesh handles are unbounded at the
187            // spec level; runtime asks the extension for a concrete size.
188            Self::SparseCsr { .. } | Self::SparseCoo { .. } | Self::SparseBsr { .. } => 0,
189            Self::DeviceMesh { .. } => 0,
190            // Opaque: conservative sentinel. Real value via ExtensionDataType::min_bytes.
191            Self::Opaque(_) => 0,
192        }
193    }
194
195    /// Maximum byte count for one value of this type.
196    ///
197    /// Returns `None` for truly unbounded types; currently all variants
198    /// have a hard ceiling. Fixed-width types return `Some(min_bytes())`.
199    #[must_use]
200    pub const fn max_bytes(&self) -> Option<usize> {
201        match self {
202            Self::U8 | Self::I8 => Some(1),
203            Self::U16 | Self::I16 | Self::F16 | Self::BF16 => Some(2),
204            Self::U32 | Self::I32 | Self::Bool => Some(4),
205            Self::I64 | Self::U64 | Self::Vec2U32 | Self::F64 => Some(8),
206            Self::Vec4U32 => Some(16),
207            Self::F32 => Some(4),
208            Self::Handle(_) => Some(4),
209            Self::Vec { element, count } => match element.max_bytes() {
210                Some(bytes) => Some(bytes * (*count as usize)),
211                None => None,
212            },
213            Self::Bytes => Some(64 * 1024 * 1024),
214            Self::Array { .. } | Self::Tensor => Some(256 * 1024 * 1024),
215            Self::TensorShaped { .. } => None,
216            Self::F8E4M3 | Self::F8E5M2 => Some(1),
217            Self::I4 | Self::FP4 | Self::NF4 => Some(1),
218            Self::SparseCsr { .. } | Self::SparseCoo { .. } | Self::SparseBsr { .. } => None,
219            Self::DeviceMesh { .. } => Some(4),
220            // Opaque: unbounded at the spec level. Real ceiling via ExtensionDataType::max_bytes.
221            Self::Opaque(_) => None,
222        }
223    }
224
225    /// Element size for array-typed outputs, or `None` for scalar types.
226    #[must_use]
227    pub const fn element_size(&self) -> Option<usize> {
228        match self {
229            Self::Array { element_size } => Some(*element_size),
230            Self::Vec { element, .. }
231            | Self::TensorShaped { element, .. }
232            | Self::SparseCsr { element }
233            | Self::SparseCoo { element }
234            | Self::SparseBsr { element, .. } => element.size_bytes(),
235            Self::Opaque(_) => None,
236            _ => None,
237        }
238    }
239
240    /// Fixed scalar element size in bytes, or `None` for variable-size types.
241    ///
242    /// Scalar types return their natural width (`U32` → `Some(4)`, `Vec4U32` →
243    /// `Some(16)`). `Bytes` returns `Some(1)` because each element is one byte.
244    /// `Array` returns `Some(element_size)`. `Tensor` returns `None` because it
245    /// has no fixed per-element size.
246    #[must_use]
247    pub const fn size_bytes(&self) -> Option<usize> {
248        match self {
249            Self::U8 | Self::I8 => Some(1),
250            Self::U16 | Self::I16 | Self::F16 | Self::BF16 => Some(2),
251            Self::Bool | Self::U32 | Self::I32 | Self::F32 => Some(4),
252            Self::I64 | Self::U64 | Self::Vec2U32 | Self::F64 => Some(8),
253            Self::Vec4U32 => Some(16),
254            Self::Handle(_) => Some(4),
255            Self::Bytes => Some(1),
256            Self::Array { element_size } => Some(*element_size),
257            Self::Vec { element, count } => match element.size_bytes() {
258                Some(bytes) => Some(bytes * (*count as usize)),
259                None => None,
260            },
261            Self::Tensor | Self::TensorShaped { .. } => None,
262            Self::F8E4M3 | Self::F8E5M2 => Some(1),
263            Self::I4 | Self::FP4 | Self::NF4 => Some(1),
264            Self::SparseCsr { .. } | Self::SparseCoo { .. } | Self::SparseBsr { .. } => None,
265            Self::DeviceMesh { .. } => Some(4),
266            // Opaque: real size via ExtensionDataType::size_bytes (runtime).
267            Self::Opaque(_) => None,
268        }
269    }
270
271    /// Whether this type belongs to the strict floating-point conformance family.
272    #[must_use]
273    pub const fn is_float_family(&self) -> bool {
274        match self {
275            Self::F16 | Self::BF16 | Self::F32 | Self::F64 => true,
276            Self::F8E4M3 | Self::F8E5M2 | Self::FP4 | Self::NF4 => true,
277            Self::Vec { element, .. }
278            | Self::TensorShaped { element, .. }
279            | Self::SparseCsr { element }
280            | Self::SparseCoo { element }
281            | Self::SparseBsr { element, .. } => element.is_float_family(),
282            _ => false,
283        }
284    }
285}
286
287impl fmt::Display for DataType {
288    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
289        match self {
290            Self::U8 => f.write_str("u8"),
291            Self::U16 => f.write_str("u16"),
292            Self::U32 => f.write_str("u32"),
293            Self::I8 => f.write_str("i8"),
294            Self::I16 => f.write_str("i16"),
295            Self::I32 => f.write_str("i32"),
296            Self::I64 => f.write_str("i64"),
297            Self::U64 => f.write_str("u64"),
298            Self::Vec2U32 => f.write_str("vec2<u32>"),
299            Self::Vec4U32 => f.write_str("vec4<u32>"),
300            Self::Bool => f.write_str("bool"),
301            Self::Bytes => f.write_str("bytes"),
302            Self::Array { element_size } => write!(f, "array<{element_size}B>"),
303            Self::F16 => f.write_str("f16"),
304            Self::BF16 => f.write_str("bf16"),
305            Self::F32 => f.write_str("f32"),
306            Self::F64 => f.write_str("f64"),
307            Self::Tensor => f.write_str("tensor"),
308            Self::Handle(id) => write!(f, "handle<{:#010x}>", id.as_u32()),
309            Self::Vec { element, count } => write!(f, "vec<{element};{count}>"),
310            Self::TensorShaped { element, shape } => {
311                write!(f, "tensor<{element};")?;
312                for (idx, dim) in shape.iter().enumerate() {
313                    if idx > 0 {
314                        f.write_str("x")?;
315                    }
316                    write!(f, "{dim}")?;
317                }
318                f.write_str(">")
319            }
320            Self::Opaque(id) => write!(f, "opaque<{:#010x}>", id.as_u32()),
321            Self::F8E4M3 => f.write_str("f8e4m3"),
322            Self::F8E5M2 => f.write_str("f8e5m2"),
323            Self::I4 => f.write_str("i4"),
324            Self::FP4 => f.write_str("fp4"),
325            Self::NF4 => f.write_str("nf4"),
326            Self::SparseCsr { element } => write!(f, "sparse_csr<{element}>"),
327            Self::SparseCoo { element } => write!(f, "sparse_coo<{element}>"),
328            Self::SparseBsr {
329                element,
330                block_rows,
331                block_cols,
332            } => write!(f, "sparse_bsr<{element};{block_rows}x{block_cols}>"),
333            Self::DeviceMesh { axes } => {
334                f.write_str("device_mesh<")?;
335                for (i, a) in axes.iter().enumerate() {
336                    if i > 0 {
337                        f.write_str("x")?;
338                    }
339                    write!(f, "{a}")?;
340                }
341                f.write_str(">")
342            }
343        }
344    }
345}