vyre_spec/data_type.rs
1//! Frozen IR data-type tags shared by signatures, validators, and wire metadata.
2// TAG RESERVATIONS: U32=0x01, I32=0x02, U64=0x03, Vec2U32=0x04,
3// Vec4U32=0x05, Bool=0x06, Bytes=0x07, Array=0x08, F16=0x09,
4// BF16=0x0A, F32=0x0B, F64=0x0C, Tensor=0x0D, U8=0x0E, U16=0x0F,
5// I8=0x10, I16=0x11, I64=0x12, Handle=0x13, Vec=0x14,
6// TensorShaped=0x15, SparseCsr=0x16, SparseCoo=0x17, SparseBsr=0x18,
7// F8E4M3=0x19, F8E5M2=0x1A, I4=0x1B, FP4=0x1C, NF4=0x1D,
8// DeviceMesh=0x1E, Quantized=0x1F, 0x20..=0x7F reserved, Opaque=0x80.
9
10use crate::extension::ExtensionDataTypeId;
11
12mod display;
13mod layout;
14mod validation;
15
16/// Stable handle type id for backend-owned GPU resources.
17#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Deserialize, serde::Serialize)]
18pub struct TypeId(pub u32);
19
20impl TypeId {
21 /// Return the raw stable handle type id.
22 #[must_use]
23 pub const fn as_u32(self) -> u32 {
24 self.0
25 }
26}
27
/// Scale metadata layout for a quantized tensor or vector.
///
/// Carried as the `scale` field of `DataType::Quantized` and describes how
/// the scale sidecar buffer maps onto the packed values.
///
/// NOTE(review): nothing in this type forbids `group_size == 0` or an `axis`
/// beyond the tensor rank; presumably the `validation` submodule rejects
/// those — confirm before relying on it.
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Deserialize, serde::Serialize)]
pub enum QuantizationScale {
    /// One scale value for the whole buffer.
    PerTensor,
    /// One scale value per slice along `axis`.
    PerChannel {
        /// Tensor axis carrying independent scale values.
        axis: u32,
    },
    /// One scale value per contiguous group of `group_size` logical elements.
    PerGroup {
        /// Number of logical elements per quantization group.
        group_size: u32,
    },
}
44
/// Zero-point metadata layout for affine quantization.
///
/// Carried as the `zero_point` field of `DataType::Quantized`. `Absent`
/// selects symmetric quantization; every other variant describes how the
/// zero-point sidecar buffer maps onto the packed values.
///
/// NOTE(review): as with `QuantizationScale`, `axis`/`group_size` validity is
/// not enforced by this type — presumably checked in `validation`; confirm.
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Deserialize, serde::Serialize)]
pub enum QuantizationZeroPoint {
    /// Symmetric quantization; zero point is implicitly zero.
    Absent,
    /// One zero point for the whole buffer.
    PerTensor,
    /// One zero point per slice along `axis`.
    PerChannel {
        /// Tensor axis carrying independent zero-point values.
        axis: u32,
    },
    /// One zero point per contiguous quantization group.
    PerGroup {
        /// Number of logical elements per quantization group.
        group_size: u32,
    },
}
63
/// Canonical data types supported by the vyre IR frozen data contract.
///
/// Integer-first by design. GPU floating-point is nondeterministic across
/// vendors through different rounding, fused multiply-add, and subnormal
/// handling. Integer arithmetic is deterministic everywhere. F32 is supported
/// for primitives that require it, with conformance validated per-backend.
/// `vyre::ir::DataType` re-exports this same type; conformance metadata should
/// use this canonical contract path. Example: `DataType::Vec4U32` records a
/// four-word lane value and has a minimum byte width of 16.
///
/// Every builtin variant owns a frozen one-byte wire tag, returned by
/// `DataType::builtin_wire_tag` and noted on each variant below. Builtin tags
/// occupy `0x01..=0x1F`, `0x20..=0x7F` is reserved, and extension types
/// (`Opaque`) are encoded under `0x80`.
///
/// NOTE(review): declaration order feeds the derived `Hash` and serde's
/// variant index (used by non-self-describing serde formats), so new variants
/// should be appended rather than inserted or reordered — confirm against
/// `docs/wire-format.md`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Deserialize, serde::Serialize)]
#[non_exhaustive]
pub enum DataType {
    /// Unsigned 8-bit integer. Wire tag `0x0E`.
    U8,
    /// Unsigned 16-bit integer. Wire tag `0x0F`.
    U16,
    /// Unsigned 32-bit integer. The fundamental GPU word. Wire tag `0x01`.
    U32,
    /// Signed 8-bit integer. Wire tag `0x10`.
    I8,
    /// Signed 16-bit integer. Wire tag `0x11`.
    I16,
    /// Signed 32-bit integer. Wire tag `0x02`.
    I32,
    /// Signed 64-bit integer. Wire tag `0x12`.
    I64,
    /// Unsigned 64-bit integer, emulated as `vec2<u32>` with low and high words.
    /// Wire tag `0x03`.
    U64,
    /// Two-component `u32` vector. Wire tag `0x04`.
    Vec2U32,
    /// Four-component `u32` vector. Wire tag `0x05`.
    Vec4U32,
    /// Boolean value stored as a GPU word. Wire tag `0x06`.
    Bool,
    /// Variable-length byte buffer. Wire tag `0x07`.
    Bytes,
    /// Fixed-element-size array. Wire tag `0x08`.
    ///
    /// Each element is `element_size` bytes. The total byte count is
    /// `N * element_size`, where the element count `N` is carried by the value
    /// rather than recorded in the type.
    Array {
        /// Byte size of each element.
        element_size: usize,
    },
    /// Strict IEEE 754 binary16 floating-point. Wire tag `0x09`.
    F16,
    /// Strict bfloat16 floating-point. Wire tag `0x0A`.
    BF16,
    /// IEEE 754 binary32 floating-point. Wire tag `0x0B`.
    F32,
    /// Strict IEEE 754 binary64 floating-point. Wire tag `0x0C`.
    F64,
    /// Multi-dimensional tensor value whose element type and shape are not
    /// recorded in the type (see `TensorShaped` for the typed form).
    /// Wire tag `0x0D`.
    Tensor,
    /// Opaque backend resource handle identified by a stable `TypeId`.
    /// Wire tag `0x13`.
    Handle(TypeId),
    /// Generic fixed-lane vector: `count` lanes of `element`. Wire tag `0x14`.
    Vec {
        /// Lane element type.
        element: Box<Self>,
        /// Lane count.
        count: u8,
    },
    /// Tensor with explicit element type and shape. Wire tag `0x15`.
    ///
    /// NOTE(review): `SmallVec` spills to the heap beyond four dimensions, so
    /// any rank limit must be enforced elsewhere (presumably `validation`) —
    /// confirm.
    TensorShaped {
        /// Tensor element type.
        element: Box<Self>,
        /// Tensor dimensions. Up to four dimensions are stored inline without
        /// a heap allocation.
        shape: smallvec::SmallVec<[u32; 4]>,
    },
    /// Sparse-CSR tensor: compressed sparse row layout. Element type
    /// lives in the dense values buffer; structure (indptr + `col_idx`)
    /// is laid out separately by the consumer per the documented CSR
    /// contract. Size depends on nnz; conservative sentinel applies.
    ///
    /// Wire encoding: tag `0x16` followed by the element type tag.
    SparseCsr {
        /// Element type of the dense values buffer.
        element: Box<Self>,
    },
    /// Sparse-COO tensor: coordinate-list layout with (row, col, val)
    /// triples. Simpler than CSR but less cache-friendly; lowering
    /// passes typically convert COO → CSR before dispatch.
    ///
    /// Wire encoding: tag `0x17` followed by the element type tag.
    SparseCoo {
        /// Element type of each triple's value.
        element: Box<Self>,
    },
    /// Sparse-BSR tensor: block-sparse rows with a fixed block size.
    /// Favored by quantized LLM weight matrices (50%+ sparsity at
    /// block-granularity retains line-rate GEMM).
    ///
    /// Wire encoding: tag `0x18` followed by `block_rows u32`,
    /// `block_cols u32`, then the element type tag.
    SparseBsr {
        /// Element type.
        element: Box<Self>,
        /// Block height in elements.
        block_rows: u32,
        /// Block width in elements.
        block_cols: u32,
    },
    /// 8-bit float (E4M3 format, per FP8 spec) for quantized inference.
    /// Wire tag `0x19`.
    F8E4M3,
    /// 8-bit float (E5M2 format, per FP8 spec) - wider range than E4M3.
    /// Wire tag `0x1A`.
    F8E5M2,
    /// 4-bit signed integer for aggressive LLM weight quantization.
    /// Wire tag `0x1B`.
    I4,
    /// 4-bit float for LLM-class inference. Wire tag `0x1C`.
    FP4,
    /// 4-bit "normal-float" (per `QLoRA` paper) for LLM weight compression.
    /// Wire tag `0x1D`.
    NF4,
    /// Device-mesh handle - topology identifier consumed by
    /// collective ops (`all_reduce`, `all_gather`, `reduce_scatter`,
    /// broadcast). Shape is informational; actual topology is
    /// resolved through the backend's mesh registry. Wire tag `0x1E`.
    DeviceMesh {
        /// Device count along each mesh axis. 1-D = pure ring/tree;
        /// 2-D = torus; higher-D = hypercube. Up to three axes are stored
        /// inline without a heap allocation.
        axes: smallvec::SmallVec<[u32; 3]>,
    },
    /// First-class quantized value domain. Wire tag `0x1F`.
    ///
    /// `storage` is the physical packed element family (`I4`, `I8`, `U8`,
    /// `F8E4M3`, `NF4`, etc.); the accepted set is exactly the types for which
    /// `DataType::is_quantized_storage` returns `true`. `scale` and
    /// `zero_point` describe the sidecar buffers needed to dequantize,
    /// operate, and optionally requantize without losing the stable IR type.
    /// This closes RFC-0003 at the spec layer; concrete ops still choose
    /// whether to lower to tensor-core MMA, scalar dequantize-op-requantize,
    /// or a backend-specific packed path.
    Quantized {
        /// Physical storage element type.
        storage: Box<Self>,
        /// Scale sidecar layout.
        scale: QuantizationScale,
        /// Optional zero-point sidecar layout.
        zero_point: QuantizationZeroPoint,
    },
    /// Extension-declared data type. Has no builtin wire tag.
    ///
    /// The `ExtensionDataTypeId` is stable across process runs and
    /// resolves to a `&'static dyn ExtensionDataType` via
    /// `vyre::dialect::extension::resolve_data_type` (in vyre-core).
    /// Wire encoding of Opaque is `0x80 ++ u32 extension_id` - see
    /// `docs/wire-format.md` §Extensions.
    ///
    /// The builtin const methods on `DataType` (`min_bytes`, `max_bytes`,
    /// `size_bytes`, `is_float_family`) return conservative sentinels for
    /// Opaque because the real values live behind the trait and are not
    /// known at compile time. Consumers that need the actual values
    /// should resolve the trait via the vyre-core registry.
    Opaque(ExtensionDataTypeId),
}
217
218#[allow(clippy::match_same_arms)]
219impl DataType {
220 /// Frozen builtin wire tag for this data type.
221 ///
222 /// Returns `None` for extension-declared opaque types because their wire
223 /// representation is the high-bit extension id, not a core builtin tag.
224 #[must_use]
225 pub const fn builtin_wire_tag(&self) -> Option<u8> {
226 match self {
227 Self::U32 => Some(0x01),
228 Self::I32 => Some(0x02),
229 Self::U64 => Some(0x03),
230 Self::Vec2U32 => Some(0x04),
231 Self::Vec4U32 => Some(0x05),
232 Self::Bool => Some(0x06),
233 Self::Bytes => Some(0x07),
234 Self::Array { .. } => Some(0x08),
235 Self::F16 => Some(0x09),
236 Self::BF16 => Some(0x0A),
237 Self::F32 => Some(0x0B),
238 Self::F64 => Some(0x0C),
239 Self::Tensor => Some(0x0D),
240 Self::U8 => Some(0x0E),
241 Self::U16 => Some(0x0F),
242 Self::I8 => Some(0x10),
243 Self::I16 => Some(0x11),
244 Self::I64 => Some(0x12),
245 Self::Handle(_) => Some(0x13),
246 Self::Vec { .. } => Some(0x14),
247 Self::TensorShaped { .. } => Some(0x15),
248 Self::SparseCsr { .. } => Some(0x16),
249 Self::SparseCoo { .. } => Some(0x17),
250 Self::SparseBsr { .. } => Some(0x18),
251 Self::F8E4M3 => Some(0x19),
252 Self::F8E5M2 => Some(0x1A),
253 Self::I4 => Some(0x1B),
254 Self::FP4 => Some(0x1C),
255 Self::NF4 => Some(0x1D),
256 Self::DeviceMesh { .. } => Some(0x1E),
257 Self::Quantized { .. } => Some(0x1F),
258 Self::Opaque(_) => None,
259 }
260 }
261
262 /// Whether this type belongs to the strict floating-point conformance family.
263 #[must_use]
264 pub const fn is_float_family(&self) -> bool {
265 match self {
266 Self::F16 | Self::BF16 | Self::F32 | Self::F64 => true,
267 Self::F8E4M3 | Self::F8E5M2 | Self::FP4 | Self::NF4 => true,
268 Self::Vec { element, .. }
269 | Self::TensorShaped { element, .. }
270 | Self::SparseCsr { element }
271 | Self::SparseCoo { element }
272 | Self::SparseBsr { element, .. } => element.is_float_family(),
273 Self::Quantized { .. } => false,
274 _ => false,
275 }
276 }
277
278 /// Whether this type carries first-class quantization sidecar metadata.
279 #[must_use]
280 pub const fn is_quantized(&self) -> bool {
281 match self {
282 Self::Quantized { .. } => true,
283 Self::Vec { element, .. }
284 | Self::TensorShaped { element, .. }
285 | Self::SparseCsr { element }
286 | Self::SparseCoo { element }
287 | Self::SparseBsr { element, .. } => element.is_quantized(),
288 _ => false,
289 }
290 }
291
292 /// Whether this type is valid as the storage field of `DataType::Quantized`.
293 #[must_use]
294 pub const fn is_quantized_storage(&self) -> bool {
295 matches!(
296 self,
297 Self::I4
298 | Self::I8
299 | Self::I16
300 | Self::U8
301 | Self::U16
302 | Self::F8E4M3
303 | Self::F8E5M2
304 | Self::FP4
305 | Self::NF4
306 )
307 }
308}