vyre_foundation/dispatch/
dialect_lookup.rs1use crate::ir_inner::model::program::Program;
14use lasso::ThreadedRodeo;
15use std::sync::{Arc, OnceLock};
16use vyre_spec::{AlgebraicLaw, CpuFn};
17
/// Opaque handle for an op name interned by [`intern_string`].
///
/// Handles returned by [`intern_string`] compare equal exactly when they were
/// interned from the same string. The wrapped `u32` is public, but callers
/// should treat its numeric value as opaque identity only.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct InternedOpId(pub u32);
21
22fn get_interner() -> &'static ThreadedRodeo {
23 static INTERNER: OnceLock<ThreadedRodeo> = OnceLock::new();
24 INTERNER.get_or_init(ThreadedRodeo::new)
25}
26
27#[must_use]
29pub fn intern_string(s: &str) -> InternedOpId {
30 let interner = get_interner();
31 let key = interner.get_or_intern(s);
32 InternedOpId(key.into_inner().get())
33}
34
35pub type ReferenceKind = CpuFn;
37
/// Context handed to every lowering builder.
///
/// It currently carries no data. Its only field is a `PhantomData` that keeps
/// the `'a` lifetime parameter in the type. NOTE(review): `'a` is presumably
/// reserved for borrowed lowering state; confirm.
#[derive(Clone, Debug, Default)]
pub struct LoweringCtx<'a> {
    /// Placeholder tying the otherwise-unused `'a` lifetime to the type.
    pub unused: std::marker::PhantomData<&'a ()>,
}
44
/// Text-form module returned by a `SecondaryTextBuilder`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TextModule {
    /// Module source text.
    pub asm: String,
    /// Version number accompanying the text; its encoding is backend-defined.
    /// NOTE(review): the unit tests pair `70` with ".version 7.0"; confirm.
    pub version: u32,
}
53
/// Module returned by a `NativeModuleBuilder`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct NativeModule {
    /// Module payload bytes; the format is backend-defined.
    pub ast: Vec<u8>,
    /// Entry-point name. NOTE(review): presumably the symbol to launch; confirm.
    pub entry: String,
}
62
63pub type PrimaryTextBuilder = fn(&LoweringCtx<'_>) -> Result<(), String>;
65pub type PrimaryBinaryBuilder = fn(&LoweringCtx<'_>) -> Vec<u32>;
67pub type SecondaryTextBuilder = fn(&LoweringCtx<'_>) -> TextModule;
69pub type NativeModuleBuilder = fn(&LoweringCtx<'_>) -> NativeModule;
71pub type ExtensionLoweringFn =
84 fn(&LoweringCtx<'_>) -> Result<std::vec::Vec<u8>, std::string::String>;
85
86#[derive(Clone)]
99pub struct LoweringTable {
100 pub cpu_ref: ReferenceKind,
102 pub primary_text: Option<PrimaryTextBuilder>,
104 pub primary_binary: Option<PrimaryBinaryBuilder>,
106 pub secondary_text: Option<SecondaryTextBuilder>,
108 pub native_module: Option<NativeModuleBuilder>,
110 pub extensions: rustc_hash::FxHashMap<&'static str, ExtensionLoweringFn>,
116}
117
118impl Default for LoweringTable {
119 fn default() -> Self {
120 Self::empty()
121 }
122}
123
124impl LoweringTable {
125 #[must_use]
127 pub fn new(cpu_ref: ReferenceKind) -> Self {
128 Self {
129 cpu_ref,
130 primary_text: None,
131 primary_binary: None,
132 secondary_text: None,
133 native_module: None,
134 extensions: rustc_hash::FxHashMap::default(),
135 }
136 }
137
138 #[must_use]
140 pub fn empty() -> Self {
141 Self {
142 cpu_ref: crate::cpu_op::structured_intrinsic_cpu,
143 primary_text: None,
144 primary_binary: None,
145 secondary_text: None,
146 native_module: None,
147 extensions: rustc_hash::FxHashMap::default(),
148 }
149 }
150
151 #[must_use]
155 pub fn with_extension(
156 mut self,
157 backend_id: &'static str,
158 builder: ExtensionLoweringFn,
159 ) -> Self {
160 self.extensions.insert(backend_id, builder);
161 self
162 }
163
164 #[must_use]
166 pub fn extension(&self, backend_id: &str) -> Option<ExtensionLoweringFn> {
167 self.extensions.get(backend_id).copied()
168 }
169}
170
171impl std::fmt::Debug for LoweringTable {
172 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
173 f.debug_struct("LoweringTable")
174 .field("cpu_ref", &"<fn>")
175 .field("primary_text", &self.primary_text.map(|_| "<fn>"))
176 .field("primary_binary", &self.primary_binary.map(|_| "<fn>"))
177 .field("secondary_text", &self.secondary_text.map(|_| "<fn>"))
178 .field("native_module", &self.native_module.map(|_| "<fn>"))
179 .field(
180 "extensions",
181 &self
182 .extensions
183 .keys()
184 .copied()
185 .collect::<std::vec::Vec<_>>(),
186 )
187 .finish()
188 }
189}
190
/// Declared value type of an op attribute (the `ty` of an `AttrSchema`).
#[derive(Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum AttrType {
    /// Unsigned 32-bit integer value.
    U32,
    /// Signed 32-bit integer value.
    I32,
    /// 32-bit floating-point value.
    F32,
    /// Boolean value.
    Bool,
    /// Raw byte-string value.
    Bytes,
    /// Text value.
    String,
    /// One of a fixed set of names. NOTE(review): presumably the slice lists
    /// the permitted values; confirm.
    Enum(&'static [&'static str]),
    /// Undeclared or unrecognized type. NOTE(review): confirm how consumers
    /// treat it.
    Unknown,
}
212
213#[derive(Debug, Clone, PartialEq, Eq)]
215pub struct AttrSchema {
216 pub name: &'static str,
218 pub ty: AttrType,
220 pub default: Option<&'static str>,
222}
223
/// A named input or output parameter of an op signature.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TypedParam {
    /// Parameter name.
    pub name: &'static str,
    /// Parameter type spelled as a string; its vocabulary is not defined here.
    pub ty: &'static str,
}
232
233#[derive(Debug, Clone, PartialEq, Eq)]
235pub struct Signature {
236 pub inputs: &'static [TypedParam],
238 pub outputs: &'static [TypedParam],
240 pub attrs: &'static [AttrSchema],
242 pub bytes_extraction: bool,
244}
245
246impl Signature {
247 #[must_use]
249 pub const fn bytes_extractor(
250 inputs: &'static [TypedParam],
251 outputs: &'static [TypedParam],
252 attrs: &'static [AttrSchema],
253 ) -> Self {
254 Self {
255 inputs,
256 outputs,
257 attrs,
258 bytes_extraction: true,
259 }
260 }
261}
262
/// Broad classification of an op definition.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Category {
    /// Composite op. NOTE(review): presumably built via `OpDef::compose`; confirm.
    Composite,
    /// Extension op.
    Extension,
    /// Intrinsic op; this is the category `OpDef::default()` uses.
    Intrinsic,
}
273
274#[derive(Debug, Clone)]
276pub struct OpDef {
277 pub id: &'static str,
279 pub dialect: &'static str,
281 pub category: Category,
283 pub signature: Signature,
285 pub lowerings: LoweringTable,
287 pub laws: &'static [AlgebraicLaw],
289 pub compose: Option<fn() -> Program>,
291}
292
293impl OpDef {
294 #[must_use]
296 pub const fn id(&self) -> &'static str {
297 self.id
298 }
299
300 #[must_use]
302 pub fn program(&self) -> Option<Program> {
303 self.compose
304 .map(|compose| compose().with_entry_op_id(self.id))
305 }
306}
307
308impl Default for OpDef {
309 fn default() -> Self {
310 Self {
311 id: "",
312 dialect: "",
313 category: Category::Intrinsic,
314 signature: Signature {
315 inputs: &[],
316 outputs: &[],
317 attrs: &[],
318 bytes_extraction: false,
319 },
320 lowerings: LoweringTable::empty(),
321 laws: &[],
322 compose: None,
323 }
324 }
325}
326
/// Sealing helper for the dialect lookup trait; hidden from rustdoc.
///
/// NOTE(review): the module is `pub`, so this is only a soft seal unless the
/// enclosing module is private; confirm.
#[doc(hidden)]
pub mod private {
    /// Marker supertrait that a type must implement before it can implement
    /// the dialect lookup trait.
    pub trait Sealed {}
}
331
332pub trait DialectLookup: private::Sealed + Send + Sync {
334 fn provider_id(&self) -> &'static str;
342
343 fn intern_op(&self, name: &str) -> InternedOpId;
345
346 fn lookup(&self, id: InternedOpId) -> Option<&'static OpDef>;
348}
349
350static DIALECT_LOOKUP: OnceLock<Arc<dyn DialectLookup>> = OnceLock::new();
351
352pub fn install_dialect_lookup(lookup: Arc<dyn DialectLookup>) -> Result<(), String> {
369 match DIALECT_LOOKUP.get() {
370 Some(existing) => {
371 let existing_id = existing.provider_id();
372 let incoming_id = lookup.provider_id();
373 ensure_same_provider(existing_id, incoming_id)?;
374 }
375 None => {
376 if let Err(lookup) = DIALECT_LOOKUP.set(lookup) {
377 let Some(existing) = DIALECT_LOOKUP.get() else {
381 return Err(
382 "dialect lookup install lost the value after OnceLock::set failed. Fix: report this impossible OnceLock state."
383 .to_string(),
384 );
385 };
386 let existing_id = existing.provider_id();
387 let incoming_id = lookup.provider_id();
388 ensure_same_provider(existing_id, incoming_id)?;
389 }
390 }
391 }
392 Ok(())
393}
394
/// Checks that an incoming provider id matches the installed one.
///
/// Returns `Ok(())` when the two ids are equal. Otherwise it returns an
/// explanatory error that names both providers.
fn ensure_same_provider(existing_id: &str, incoming_id: &str) -> Result<(), String> {
    if existing_id != incoming_id {
        return Err(format!(
            "dialect lookup already installed by provider `{existing_id}`; second installer `{incoming_id}` reports a different id. Fix: pick one provider for the process or reuse the first provider's id. Silent replacement is refused because two divergent lookups would mis-resolve op ids at runtime."
        ));
    }
    Ok(())
}
404
405#[must_use]
407pub fn dialect_lookup() -> Option<&'static dyn DialectLookup> {
408 DIALECT_LOOKUP.get().map(Arc::as_ref)
409}
410
#[cfg(test)]
mod tests {
    use super::*;

    // Interning.

    #[test]
    fn intern_string_is_deterministic() {
        const NAME: &str = "test::op::add";
        let first = intern_string(NAME);
        let second = intern_string(NAME);
        assert_eq!(first, second);
    }

    #[test]
    fn intern_string_distinct_for_different_ops() {
        let add = intern_string("test::op::add");
        let mul = intern_string("test::op::mul");
        assert_ne!(add, mul);
    }

    // Lowering tables.

    #[test]
    fn lowering_table_empty_has_no_native_builders() {
        let LoweringTable {
            primary_text,
            primary_binary,
            secondary_text,
            native_module,
            extensions,
            ..
        } = LoweringTable::empty();
        assert!(primary_text.is_none());
        assert!(primary_binary.is_none());
        assert!(secondary_text.is_none());
        assert!(native_module.is_none());
        assert!(extensions.is_empty());
    }

    #[test]
    fn lowering_table_extension_lookup() {
        fn dummy_builder(_: &LoweringCtx<'_>) -> Result<Vec<u8>, String> {
            Ok(vec![1, 2, 3])
        }
        let table = LoweringTable::empty().with_extension("my-extension", dummy_builder);
        let [registered, missing] =
            ["my-extension", "nonexistent"].map(|id| table.extension(id).is_some());
        assert!(registered);
        assert!(!missing);
    }

    // Op definitions and signatures.

    #[test]
    fn opdef_default_has_empty_id() {
        let def = OpDef::default();
        assert!(def.id().is_empty());
        assert!(def.program().is_none());
    }

    #[test]
    fn signature_bytes_extractor_sets_flag() {
        let Signature { bytes_extraction, .. } = Signature::bytes_extractor(&[], &[], &[]);
        assert!(bytes_extraction);
    }

    // Module payloads.

    #[test]
    fn secondary_text_module_equality() {
        let build = || TextModule {
            asm: ".version 7.0".into(),
            version: 70,
        };
        assert_eq!(build(), build());
    }

    #[test]
    fn native_module_module_equality() {
        let build = || NativeModule {
            ast: vec![1, 2, 3],
            entry: "main".into(),
        };
        assert_eq!(build(), build());
    }

    #[test]
    fn category_debug() {
        let cases = [
            (Category::Composite, "Composite"),
            (Category::Extension, "Extension"),
            (Category::Intrinsic, "Intrinsic"),
        ];
        for (category, expected) in cases {
            assert_eq!(format!("{category:?}"), expected);
        }
    }
}