Skip to main content

vyre_spec/
extension.rs

1//! Extension contracts for open IR.
2//!
3//! Downstream crates ship new `Expr`, `Node`, `DataType`, `BinOp`, `UnOp`,
4//! `AtomicOp`, `TernaryOp`, and `RuleCondition` variants by implementing the
5//! traits in this module and registering an id with the vyre-core inventory
6//! layer.
7//!
8//! `vyre-spec` is intentionally data-only and carries no dependency on
9//! `inventory`. The trait signatures below describe the stable contract;
10//! actual registration + resolution lives in `vyre::dialect::extension`
11//! (see the vyre-core crate).
12//!
//! Every extension id occupies the range `0x8000_0000..=0xFFFF_FFFF`: the
//! high bit of the 32-bit wire tag distinguishes extension ids from the
//! frozen core tag space `0x00..=0x7F`. Each `ExtensionXxxId::from_name`
//! constructor hashes a stable extension name with FNV-1a and sets the high
//! bit, so two independently-authored extensions collide on a deliberate
//! name clash or, rarely, on an accidental 31-bit hash collision.
19
20use core::fmt::Debug;
21
22/// Stable u32 id for an extension variant.
23///
24/// Extension ids are generated deterministically from a stable name via
25/// [`ExtensionDataTypeId::from_name`]. A crate that never changes its
26/// extension name keeps the same id across versions, which is the
27/// wire-format contract: a `Program` encoded by v1.0 of an extension
28/// decodes identically in v1.1 so long as the name is stable.
29#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
30pub struct ExtensionDataTypeId(pub u32);
31
32impl ExtensionDataTypeId {
33    /// Reserved range: every extension id has its high bit set.
34    ///
35    /// Core IR discriminants occupy `0x00..=0x7F`; extensions occupy
36    /// `0x80..=0xFFFF_FFFF`. Wire decoders test the high byte to route
37    /// decoding between the two.
38    pub const EXTENSION_RANGE_MASK: u32 = 0x8000_0000;
39
40    /// Construct an id from a stable extension name.
41    ///
42    /// The id is derived deterministically: callers that pass the same
43    /// `name` always get the same id. Wire encoders serialize this id
44    /// directly; decoders on a machine with the same extension crate
45    /// linked resolve it back to the original trait vtable.
46    ///
47    /// The implementation hashes `name` with FNV-1a and folds the
48    /// 32-bit result into the extension range by setting the high
49    /// bit. Two collision-free names produce two distinct ids with
50    /// overwhelming probability.
51    #[must_use]
52    pub const fn from_name(name: &str) -> Self {
53        Self(fnv1a_with_high_bit(name))
54    }
55
56    /// Return the raw id.
57    #[must_use]
58    pub const fn as_u32(self) -> u32 {
59        self.0
60    }
61
62    /// Is this a reserved extension id (high bit set)?
63    #[must_use]
64    pub const fn is_extension(self) -> bool {
65        (self.0 & Self::EXTENSION_RANGE_MASK) != 0
66    }
67}
68
69/// The contract for an extension-declared `DataType`.
70///
71/// An implementer describes the runtime shape of a non-core data type:
72/// how many bytes it occupies, whether it participates in the float
73/// conformance family, and how it should be displayed.
74///
75/// vyre-core walks a link-time inventory of `ExtensionDataTypeRegistration`
76/// entries to resolve a `DataType::Opaque(id)` back to the trait vtable.
77/// The resolver caches `&'static dyn ExtensionDataType` so downstream
78/// consumers never re-consult the registry on the hot path.
79pub trait ExtensionDataType: Send + Sync + Debug + 'static {
80    /// Stable id for this data type.
81    fn id(&self) -> ExtensionDataTypeId;
82    /// Human-readable name for display / debug.
83    fn display_name(&self) -> &'static str;
84    /// Minimum byte count to represent one value of this type.
85    fn min_bytes(&self) -> usize;
86    /// Maximum byte count for one value of this type; `None` when unbounded.
87    fn max_bytes(&self) -> Option<usize>;
88    /// Fixed element size in bytes, or `None` for variable-size types.
89    fn size_bytes(&self) -> Option<usize>;
90    /// Whether this type belongs to the IEEE-754 float conformance family.
91    fn is_float_family(&self) -> bool {
92        false
93    }
94    /// Whether values can be safely memcpy'd between host and device.
95    fn is_host_shareable(&self) -> bool {
96        true
97    }
98}
99
100/// Runtime contract for an extension-declared binary operator.
101///
102/// Vyre-core's resolver caches `&'static dyn ExtensionBinOp` pointers keyed
103/// by [`ExtensionBinOpId`]; downstream evaluators / lowerings call through
104/// this trait without re-consulting the registry on the hot path.
105pub trait ExtensionBinOp: Send + Sync + Debug + 'static {
106    /// Stable id of this binary operator.
107    fn id(&self) -> ExtensionBinOpId;
108    /// Human-readable name for display / debug.
109    fn display_name(&self) -> &'static str;
110    /// Evaluate on the reference (CPU) backend.
111    ///
112    /// Returning `None` means "this backend does not support the op"; the
113    /// caller surfaces a typed error. Extensions implementing backends
114    /// other than reference supply their own lowering via the backend
115    /// registry.
116    fn eval_u32(&self, _a: u32, _b: u32) -> Option<u32> {
117        None
118    }
119}
120
121/// Runtime contract for an extension-declared unary operator.
122pub trait ExtensionUnOp: Send + Sync + Debug + 'static {
123    /// Stable id of this unary operator.
124    fn id(&self) -> ExtensionUnOpId;
125    /// Human-readable name for display / debug.
126    fn display_name(&self) -> &'static str;
127    /// Evaluate on the reference (CPU) backend. `None` = unsupported.
128    fn eval_u32(&self, _a: u32) -> Option<u32> {
129        None
130    }
131}
132
133/// Runtime contract for an extension-declared atomic operator.
134pub trait ExtensionAtomicOp: Send + Sync + Debug + 'static {
135    /// Stable id of this atomic operator.
136    fn id(&self) -> ExtensionAtomicOpId;
137    /// Human-readable name for display / debug.
138    fn display_name(&self) -> &'static str;
139}
140
141/// Runtime contract for an extension-declared ternary operator.
142pub trait ExtensionTernaryOp: Send + Sync + Debug + 'static {
143    /// Stable id of this ternary operator.
144    fn id(&self) -> ExtensionTernaryOpId;
145    /// Human-readable name for display / debug.
146    fn display_name(&self) -> &'static str;
147}
148
149/// Stable u32 id for an extension binary operator.
150///
151/// Identical discipline to [`ExtensionDataTypeId`]: stable across process
152/// runs, high bit set, generated by FNV-1a of the extension name.
153#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
154pub struct ExtensionBinOpId(pub u32);
155
156impl ExtensionBinOpId {
157    /// Reserved range mask (see [`ExtensionDataTypeId::EXTENSION_RANGE_MASK`]).
158    pub const EXTENSION_RANGE_MASK: u32 = 0x8000_0000;
159
160    /// Construct from a stable extension name.
161    #[must_use]
162    pub const fn from_name(name: &str) -> Self {
163        Self(fnv1a_with_high_bit(name))
164    }
165
166    /// Raw id.
167    #[must_use]
168    pub const fn as_u32(self) -> u32 {
169        self.0
170    }
171}
172
173/// Stable u32 id for an extension unary operator.
174#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
175pub struct ExtensionUnOpId(pub u32);
176
177impl ExtensionUnOpId {
178    /// Reserved range mask.
179    pub const EXTENSION_RANGE_MASK: u32 = 0x8000_0000;
180
181    /// Construct from a stable extension name.
182    #[must_use]
183    pub const fn from_name(name: &str) -> Self {
184        Self(fnv1a_with_high_bit(name))
185    }
186
187    /// Raw id.
188    #[must_use]
189    pub const fn as_u32(self) -> u32 {
190        self.0
191    }
192}
193
194/// Stable u32 id for an extension atomic operator.
195#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
196pub struct ExtensionAtomicOpId(pub u32);
197
198impl ExtensionAtomicOpId {
199    /// Reserved range mask.
200    pub const EXTENSION_RANGE_MASK: u32 = 0x8000_0000;
201
202    /// Construct from a stable extension name.
203    #[must_use]
204    pub const fn from_name(name: &str) -> Self {
205        Self(fnv1a_with_high_bit(name))
206    }
207
208    /// Raw id.
209    #[must_use]
210    pub const fn as_u32(self) -> u32 {
211        self.0
212    }
213}
214
215/// Stable u32 id for an extension ternary operator.
216#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
217pub struct ExtensionTernaryOpId(pub u32);
218
219impl ExtensionTernaryOpId {
220    /// Reserved range mask.
221    pub const EXTENSION_RANGE_MASK: u32 = 0x8000_0000;
222
223    /// Construct from a stable extension name.
224    #[must_use]
225    pub const fn from_name(name: &str) -> Self {
226        Self(fnv1a_with_high_bit(name))
227    }
228
229    /// Raw id.
230    #[must_use]
231    pub const fn as_u32(self) -> u32 {
232        self.0
233    }
234}
235
236/// Stable u32 id for an extension rule condition.
237#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
238pub struct ExtensionRuleConditionId(pub u32);
239
240impl ExtensionRuleConditionId {
241    /// Reserved range mask.
242    pub const EXTENSION_RANGE_MASK: u32 = 0x8000_0000;
243
244    /// Construct from a stable extension name.
245    #[must_use]
246    pub const fn from_name(name: &str) -> Self {
247        Self(fnv1a_with_high_bit(name))
248    }
249
250    /// Raw id.
251    #[must_use]
252    pub const fn as_u32(self) -> u32 {
253        self.0
254    }
255}
256
/// FNV-1a 32-bit hash of `name` with the high bit forced on, placing the
/// result in the extension id range `0x8000_0000..=0xFFFF_FFFF`.
///
/// Shared helper backing every `ExtensionXxxId::from_name`. Its output is
/// part of the wire format: changing the offset basis, the prime, or the
/// byte order here changes every extension id ever encoded.
///
/// Note: keeping this helper private does not by itself enforce the
/// high-bit invariant. Every id type exposes its `u32` field publicly, so
/// callers can still build ids that lack the high bit.
#[must_use]
const fn fnv1a_with_high_bit(name: &str) -> u32 {
    // Standard 32-bit FNV-1a parameters.
    const OFFSET_BASIS: u32 = 0x811c_9dc5;
    const PRIME: u32 = 0x0100_0193;
    // High bit reserved for extension ids.
    const HIGH_BIT: u32 = 0x8000_0000;

    let bytes = name.as_bytes();
    let mut hash = OFFSET_BASIS;
    let mut i = 0;
    // Index loop rather than an iterator: iterators are not usable in `const fn`.
    while i < bytes.len() {
        hash ^= bytes[i] as u32;
        hash = hash.wrapping_mul(PRIME);
        i += 1;
    }
    hash | HIGH_BIT
}
273
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn id_from_name_is_deterministic() {
        assert_eq!(
            ExtensionDataTypeId::from_name("tensor.gather"),
            ExtensionDataTypeId::from_name("tensor.gather"),
        );
    }

    #[test]
    fn id_from_different_names_differ() {
        let a = ExtensionDataTypeId::from_name("tensor.gather");
        let b = ExtensionDataTypeId::from_name("tensor.scatter");
        assert_ne!(a, b);
    }

    #[test]
    fn every_id_is_in_extension_range() {
        for name in ["anything", "", "a", "tensor.gather", "\u{1F600}"] {
            let id = ExtensionDataTypeId::from_name(name);
            assert!(id.is_extension(), "{:#010x} missing high bit", id.as_u32());
            assert!(id.as_u32() & ExtensionDataTypeId::EXTENSION_RANGE_MASK != 0);
        }
    }

    /// Pins the wire format: these ids are published FNV-1a 32-bit vectors
    /// and must never change across releases.
    #[test]
    fn ids_match_published_fnv1a_vectors() {
        assert_eq!(ExtensionDataTypeId::from_name("").as_u32(), 0x811c_9dc5);
        assert_eq!(ExtensionDataTypeId::from_name("a").as_u32(), 0xe40c_292c);
    }

    /// Every id kind shares the same hashing, so one name yields the same
    /// raw value regardless of which id type it was constructed as.
    #[test]
    fn every_id_kind_hashes_names_identically() {
        let name = "tensor.gather";
        let expected = ExtensionDataTypeId::from_name(name).as_u32();
        assert_eq!(ExtensionBinOpId::from_name(name).as_u32(), expected);
        assert_eq!(ExtensionUnOpId::from_name(name).as_u32(), expected);
        assert_eq!(ExtensionAtomicOpId::from_name(name).as_u32(), expected);
        assert_eq!(ExtensionTernaryOpId::from_name(name).as_u32(), expected);
        assert_eq!(ExtensionRuleConditionId::from_name(name).as_u32(), expected);
    }

    #[test]
    fn from_name_is_const_evaluable() {
        const ID: ExtensionBinOpId = ExtensionBinOpId::from_name("tensor.gather");
        assert_eq!(ID, ExtensionBinOpId::from_name("tensor.gather"));
    }
}