vyre_spec/extension.rs
1//! Extension contracts for open IR.
2//!
3//! Downstream crates ship new `Expr`, `Node`, `DataType`, `BinOp`, `UnOp`,
4//! `AtomicOp`, `TernaryOp`, and `RuleCondition` variants by implementing the
5//! traits in this module and registering an id with the vyre-core inventory
6//! layer.
7//!
8//! `vyre-spec` is intentionally data-only and carries no dependency on
9//! `inventory`. The trait signatures below describe the stable contract;
10//! actual registration + resolution lives in `vyre::dialect::extension`
11//! (see the vyre-core crate).
12//!
//! Every extension id occupies the range `0x8000_0000..=0xFFFF_FFFF`: the
//! high bit of the wire tag distinguishes extension ids from the frozen
//! core tag space `0x00..=0x7F`. Each `ExtensionXxxId::from_name`
//! constructor hashes a stable extension name with FNV-1a and sets the high
//! bit, folding the result into the reserved range. Two extensions that pick
//! the same name always collide; distinct names can in principle collide too
//! (only 31 hash bits survive the fold), but only with low probability.
19
20use core::fmt::Debug;
21
22/// Stable u32 id for an extension variant.
23///
24/// Extension ids are generated deterministically from a stable name via
25/// [`ExtensionDataTypeId::from_name`]. A crate that never changes its
26/// extension name keeps the same id across versions, which is the
27/// wire-format contract: a `Program` encoded by v1.0 of an extension
28/// decodes identically in v1.1 so long as the name is stable.
29#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
30pub struct ExtensionDataTypeId(pub u32);
31
32impl ExtensionDataTypeId {
33 /// Reserved range: every extension id has its high bit set.
34 ///
35 /// Core IR discriminants occupy `0x00..=0x7F`; extensions occupy
36 /// `0x80..=0xFFFF_FFFF`. Wire decoders test the high byte to route
37 /// decoding between the two.
38 pub const EXTENSION_RANGE_MASK: u32 = 0x8000_0000;
39
40 /// Construct an id from a stable extension name.
41 ///
42 /// The id is derived deterministically: callers that pass the same
43 /// `name` always get the same id. Wire encoders serialize this id
44 /// directly; decoders on a machine with the same extension crate
45 /// linked resolve it back to the original trait vtable.
46 ///
47 /// The implementation hashes `name` with FNV-1a and folds the
48 /// 32-bit result into the extension range by setting the high
49 /// bit. Two collision-free names produce two distinct ids with
50 /// overwhelming probability.
51 #[must_use]
52 pub const fn from_name(name: &str) -> Self {
53 Self(fnv1a_with_high_bit(name))
54 }
55
56 /// Return the raw id.
57 #[must_use]
58 pub const fn as_u32(self) -> u32 {
59 self.0
60 }
61
62 /// Is this a reserved extension id (high bit set)?
63 #[must_use]
64 pub const fn is_extension(self) -> bool {
65 (self.0 & Self::EXTENSION_RANGE_MASK) != 0
66 }
67}
68
69/// The contract for an extension-declared `DataType`.
70///
71/// An implementer describes the runtime shape of a non-core data type:
72/// how many bytes it occupies, whether it participates in the float
73/// conformance family, and how it should be displayed.
74///
75/// vyre-core walks a link-time inventory of `ExtensionDataTypeRegistration`
76/// entries to resolve a `DataType::Opaque(id)` back to the trait vtable.
77/// The resolver caches `&'static dyn ExtensionDataType` so downstream
78/// consumers never re-consult the registry on the hot path.
79pub trait ExtensionDataType: Send + Sync + Debug + 'static {
80 /// Stable id for this data type.
81 fn id(&self) -> ExtensionDataTypeId;
82 /// Human-readable name for display / debug.
83 fn display_name(&self) -> &'static str;
84 /// Minimum byte count to represent one value of this type.
85 fn min_bytes(&self) -> usize;
86 /// Maximum byte count for one value of this type; `None` when unbounded.
87 fn max_bytes(&self) -> Option<usize>;
88 /// Fixed element size in bytes, or `None` for variable-size types.
89 fn size_bytes(&self) -> Option<usize>;
90 /// Whether this type belongs to the IEEE-754 float conformance family.
91 fn is_float_family(&self) -> bool {
92 false
93 }
94 /// Whether values can be safely memcpy'd between host and device.
95 fn is_host_shareable(&self) -> bool {
96 true
97 }
98}
99
100/// Runtime contract for an extension-declared binary operator.
101///
102/// Vyre-core's resolver caches `&'static dyn ExtensionBinOp` pointers keyed
103/// by [`ExtensionBinOpId`]; downstream evaluators / lowerings call through
104/// this trait without re-consulting the registry on the hot path.
105pub trait ExtensionBinOp: Send + Sync + Debug + 'static {
106 /// Stable id of this binary operator.
107 fn id(&self) -> ExtensionBinOpId;
108 /// Human-readable name for display / debug.
109 fn display_name(&self) -> &'static str;
110 /// Evaluate on the reference (CPU) backend.
111 ///
112 /// Returning `None` means "this backend does not support the op"; the
113 /// caller surfaces a typed error. Extensions implementing backends
114 /// other than reference supply their own lowering via the backend
115 /// registry.
116 fn eval_u32(&self, _a: u32, _b: u32) -> Option<u32> {
117 None
118 }
119}
120
121/// Runtime contract for an extension-declared unary operator.
122pub trait ExtensionUnOp: Send + Sync + Debug + 'static {
123 /// Stable id of this unary operator.
124 fn id(&self) -> ExtensionUnOpId;
125 /// Human-readable name for display / debug.
126 fn display_name(&self) -> &'static str;
127 /// Evaluate on the reference (CPU) backend. `None` = unsupported.
128 fn eval_u32(&self, _a: u32) -> Option<u32> {
129 None
130 }
131}
132
133/// Runtime contract for an extension-declared atomic operator.
134pub trait ExtensionAtomicOp: Send + Sync + Debug + 'static {
135 /// Stable id of this atomic operator.
136 fn id(&self) -> ExtensionAtomicOpId;
137 /// Human-readable name for display / debug.
138 fn display_name(&self) -> &'static str;
139}
140
141/// Runtime contract for an extension-declared ternary operator.
142pub trait ExtensionTernaryOp: Send + Sync + Debug + 'static {
143 /// Stable id of this ternary operator.
144 fn id(&self) -> ExtensionTernaryOpId;
145 /// Human-readable name for display / debug.
146 fn display_name(&self) -> &'static str;
147}
148
149/// Stable u32 id for an extension binary operator.
150///
151/// Identical discipline to [`ExtensionDataTypeId`]: stable across process
152/// runs, high bit set, generated by FNV-1a of the extension name.
153#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
154pub struct ExtensionBinOpId(pub u32);
155
156impl ExtensionBinOpId {
157 /// Reserved range mask (see [`ExtensionDataTypeId::EXTENSION_RANGE_MASK`]).
158 pub const EXTENSION_RANGE_MASK: u32 = 0x8000_0000;
159
160 /// Construct from a stable extension name.
161 #[must_use]
162 pub const fn from_name(name: &str) -> Self {
163 Self(fnv1a_with_high_bit(name))
164 }
165
166 /// Raw id.
167 #[must_use]
168 pub const fn as_u32(self) -> u32 {
169 self.0
170 }
171}
172
173/// Stable u32 id for an extension unary operator.
174#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
175pub struct ExtensionUnOpId(pub u32);
176
177impl ExtensionUnOpId {
178 /// Reserved range mask.
179 pub const EXTENSION_RANGE_MASK: u32 = 0x8000_0000;
180
181 /// Construct from a stable extension name.
182 #[must_use]
183 pub const fn from_name(name: &str) -> Self {
184 Self(fnv1a_with_high_bit(name))
185 }
186
187 /// Raw id.
188 #[must_use]
189 pub const fn as_u32(self) -> u32 {
190 self.0
191 }
192}
193
194/// Stable u32 id for an extension atomic operator.
195#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
196pub struct ExtensionAtomicOpId(pub u32);
197
198impl ExtensionAtomicOpId {
199 /// Reserved range mask.
200 pub const EXTENSION_RANGE_MASK: u32 = 0x8000_0000;
201
202 /// Construct from a stable extension name.
203 #[must_use]
204 pub const fn from_name(name: &str) -> Self {
205 Self(fnv1a_with_high_bit(name))
206 }
207
208 /// Raw id.
209 #[must_use]
210 pub const fn as_u32(self) -> u32 {
211 self.0
212 }
213}
214
215/// Stable u32 id for an extension ternary operator.
216#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
217pub struct ExtensionTernaryOpId(pub u32);
218
219impl ExtensionTernaryOpId {
220 /// Reserved range mask.
221 pub const EXTENSION_RANGE_MASK: u32 = 0x8000_0000;
222
223 /// Construct from a stable extension name.
224 #[must_use]
225 pub const fn from_name(name: &str) -> Self {
226 Self(fnv1a_with_high_bit(name))
227 }
228
229 /// Raw id.
230 #[must_use]
231 pub const fn as_u32(self) -> u32 {
232 self.0
233 }
234}
235
236/// Stable u32 id for an extension rule condition.
237#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
238pub struct ExtensionRuleConditionId(pub u32);
239
240impl ExtensionRuleConditionId {
241 /// Reserved range mask.
242 pub const EXTENSION_RANGE_MASK: u32 = 0x8000_0000;
243
244 /// Construct from a stable extension name.
245 #[must_use]
246 pub const fn from_name(name: &str) -> Self {
247 Self(fnv1a_with_high_bit(name))
248 }
249
250 /// Raw id.
251 #[must_use]
252 pub const fn as_u32(self) -> u32 {
253 self.0
254 }
255}
256
/// FNV-1a 32-bit hash of `name`, folded into the extension range by setting
/// the high bit.
///
/// Shared helper backing every `ExtensionXxxId::from_name`, so every id kind
/// derives the same raw value from the same name, and the result always has
/// bit 31 set. Keeping this helper private does not by itself enforce that
/// invariant: each id type exposes its inner `u32` as a public field, so raw
/// ids can still be constructed directly.
///
/// The hash is part of the wire format; changing the basis, prime, byte
/// order, or fold changes every extension id.
#[must_use]
const fn fnv1a_with_high_bit(name: &str) -> u32 {
    /// FNV-1a 32-bit offset basis.
    const OFFSET_BASIS: u32 = 0x811c_9dc5;
    /// FNV-1a 32-bit prime.
    const PRIME: u32 = 0x0100_0193;
    /// Bit marking an id as belonging to the extension range.
    const HIGH_BIT: u32 = 0x8000_0000;

    let bytes = name.as_bytes();
    let mut hash = OFFSET_BASIS;
    // Manual `while` loop: iterator-based `for` is not allowed in `const fn`.
    let mut idx = 0;
    while idx < bytes.len() {
        // `u8 as u32` is lossless; `u32::from` is not callable in const context.
        hash = (hash ^ bytes[idx] as u32).wrapping_mul(PRIME);
        idx += 1;
    }
    hash | HIGH_BIT
}
273
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn id_from_name_is_deterministic() {
        assert_eq!(
            ExtensionDataTypeId::from_name("tensor.gather"),
            ExtensionDataTypeId::from_name("tensor.gather"),
        );
    }

    #[test]
    fn id_from_different_names_differ() {
        let a = ExtensionDataTypeId::from_name("tensor.gather");
        let b = ExtensionDataTypeId::from_name("tensor.scatter");
        assert_ne!(a, b);
    }

    #[test]
    fn every_id_is_in_extension_range() {
        let id = ExtensionDataTypeId::from_name("anything");
        assert!(id.is_extension(), "{:#010x} missing high bit", id.as_u32());
        assert!(id.as_u32() & ExtensionDataTypeId::EXTENSION_RANGE_MASK != 0);
    }

    /// Ids are a wire-format contract, so pin exact values: any change to the
    /// hash (basis, prime, byte order, high-bit fold) must fail here.
    #[test]
    fn id_matches_pinned_fnv1a_vectors() {
        assert_eq!(ExtensionDataTypeId::from_name("").as_u32(), 0x811c_9dc5);
        assert_eq!(ExtensionDataTypeId::from_name("a").as_u32(), 0xe40c_292c);
        // Raw FNV-1a("ab") = 0x4d25_05ca; the high bit is folded in.
        assert_eq!(ExtensionDataTypeId::from_name("ab").as_u32(), 0xcd25_05ca);
    }

    #[test]
    fn from_name_is_const_evaluable() {
        const ID: ExtensionDataTypeId = ExtensionDataTypeId::from_name("a");
        assert_eq!(ID, ExtensionDataTypeId::from_name("a"));
    }

    #[test]
    fn every_id_kind_uses_the_same_hash_and_range() {
        let name = "tensor.gather";
        let raw = ExtensionDataTypeId::from_name(name).as_u32();
        for other in [
            ExtensionBinOpId::from_name(name).as_u32(),
            ExtensionUnOpId::from_name(name).as_u32(),
            ExtensionAtomicOpId::from_name(name).as_u32(),
            ExtensionTernaryOpId::from_name(name).as_u32(),
            ExtensionRuleConditionId::from_name(name).as_u32(),
        ] {
            assert_eq!(other, raw);
            assert_ne!(other & ExtensionDataTypeId::EXTENSION_RANGE_MASK, 0);
        }
    }

    #[test]
    fn directly_constructed_core_range_id_is_not_extension() {
        // The field is public, so a core-range value can be wrapped directly.
        assert!(!ExtensionDataTypeId(0x7F).is_extension());
    }
}