vyre_macros/lib.rs
1#![forbid(unsafe_code)]
2#![warn(missing_docs)]
3//! Procedural macros for the [`vyre`](https://docs.rs/vyre) GPU compute IR
4//! compiler.
5//!
6//! This crate is compile-time only. Downstream users import from
7//! `vyre::optimizer::vyre_pass` rather than depending on this crate directly.
8//!
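//! The primary macro is [`macro@vyre_pass`] — see that item for the full usage
//! contract, argument shape, and a worked example. The crate also exports
//! [`macro@define_op`], [`macro@vyre_ast_registry`], [`macro@skip_builder`], and
//! the [`AlgebraicLaws`](derive@AlgebraicLaws) derive. A high-level narrative
//! lives in the crate README.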
12
13mod ast_registry;
14mod define_op;
15
16use proc_macro::TokenStream;
17use quote::quote;
18use syn::parse::{Parse, ParseStream};
19use syn::spanned::Spanned;
20use syn::{
21 parse_macro_input, Attribute, Data, DeriveInput, ExprArray, Fields, ItemStruct, LitStr, Meta,
22 Token,
23};
24
25/// Function-like `define_op!` — single-site op registration via inventory.
26///
27/// See [`define_op`](define_op/index.html) for the full argument contract.
28#[proc_macro]
29pub fn define_op(item: TokenStream) -> TokenStream {
30 define_op::define_op_impl(item)
31}
32
33/// Generates the declarative IR AST core (Expr and Node enums)
34/// plus serialization and visitor traits.
35#[proc_macro]
36pub fn vyre_ast_registry(item: TokenStream) -> TokenStream {
37 ast_registry::vyre_ast_registry_impl(item)
38}
39
40/// A generic marker attribute used exclusively to instruct `vyre_ast_registry!`
41/// to skip generating a builder method for a specific struct field.
42#[proc_macro_attribute]
43pub fn skip_builder(_attr: TokenStream, item: TokenStream) -> TokenStream {
44 item
45}
46
47struct PassArgs {
48 name: LitStr,
49 requires: Vec<LitStr>,
50 invalidates: Vec<LitStr>,
51}
52
53impl Parse for PassArgs {
54 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
55 let mut name = None;
56 let mut requires = Vec::new();
57 let mut invalidates = Vec::new();
58
59 while !input.is_empty() {
60 let key: syn::Ident = input.parse()?;
61 input.parse::<Token![=]>()?;
62 match key.to_string().as_str() {
63 "name" => name = Some(input.parse()?),
64 "requires" => requires = parse_string_array(input)?,
65 "invalidates" => invalidates = parse_string_array(input)?,
66 _ => {
67 return Err(syn::Error::new(
68 key.span(),
69 "unsupported vyre_pass argument. Fix: use name, requires, or invalidates.",
70 ));
71 }
72 }
73 if input.peek(Token![,]) {
74 input.parse::<Token![,]>()?;
75 }
76 }
77
78 Ok(Self {
79 name: name.ok_or_else(|| input.error("missing pass name. Fix: add name = \"...\"."))?,
80 requires,
81 invalidates,
82 })
83 }
84}
85
86fn parse_string_array(input: ParseStream<'_>) -> syn::Result<Vec<LitStr>> {
87 let array: ExprArray = input.parse()?;
88 array
89 .elems
90 .into_iter()
91 .map(|expr| match expr {
92 syn::Expr::Lit(lit) => match lit.lit {
93 syn::Lit::Str(value) => Ok(value),
94 other => Err(syn::Error::new_spanned(
95 other,
96 "pass metadata arrays accept only string literals. Fix: use [\"analysis_name\"].",
97 )),
98 },
99 other => Err(syn::Error::new_spanned(
100 other,
101 "pass metadata arrays accept only string literals. Fix: use [\"analysis_name\"].",
102 )),
103 })
104 .collect()
105}
106
107/// Register a unit struct as a `vyre::optimizer::ProgramPass`.
108///
109/// Expands to (a) a full `ProgramPass` trait impl that forwards to your inherent
110/// `analyze` / `transform` / `fingerprint` methods and (b) an
111/// `inventory::submit!` that adds the pass to the global registry so
112/// `vyre::optimize()` picks it up automatically.
113///
114/// # Arguments
115///
116/// | Argument | Type | Meaning |
117/// |----------------|-------------|---------------------------------------------------------------------|
118/// | `name` | string lit | Stable pass name used in diagnostics / ordering. |
119/// | `requires` | `[&str]` | Pass names that must fire before this one. |
120/// | `invalidates` | `[&str]` | Analyses invalidated when this pass rewrites the program. |
121///
122/// # Required inherent methods on the annotated type
123///
124/// ```ignore
125/// fn analyze(program: &Program) -> PassAnalysis;
126/// fn transform(program: Program) -> PassResult;
127/// fn fingerprint(program: &Program) -> u64;
128/// ```
129///
130/// # Example
131///
132/// ```ignore
133/// use vyre::optimizer::{vyre_pass, PassAnalysis, PassResult, fingerprint_program};
134/// use vyre::ir::Program;
135///
136/// #[vyre_pass(name = "fold_zero_add", requires = [], invalidates = [])]
137/// pub struct FoldZeroAdd;
138///
139/// impl FoldZeroAdd {
140/// fn analyze(_program: &Program) -> PassAnalysis { PassAnalysis::RUN }
141/// fn transform(program: Program) -> PassResult {
142/// // ... real rewrite ...
143/// PassResult::from_programs(&program.clone(), program)
144/// }
145/// fn fingerprint(program: &Program) -> u64 { fingerprint_program(program) }
146/// }
147/// ```
148///
149/// After expansion, `vyre::optimize(p)` will pick up `FoldZeroAdd` through
150/// the `inventory::collect!(ProgramPassRegistration)` entry emitted by the macro.
151/// No manual registration needed.
152#[proc_macro_attribute]
153pub fn vyre_pass(args: TokenStream, item: TokenStream) -> TokenStream {
154 let args = parse_macro_input!(args as PassArgs);
155 let item = parse_macro_input!(item as ItemStruct);
156 let ident = &item.ident;
157 let name = args.name;
158 let requires = args.requires;
159 let invalidates = args.invalidates;
160
161 quote! {
162 #item
163
164 impl ::vyre::optimizer::private::Sealed for #ident {}
165
166 impl ::vyre::optimizer::ProgramPass for #ident {
167 #[inline]
168 fn metadata(&self) -> ::vyre::optimizer::PassMetadata {
169 ::vyre::optimizer::PassMetadata {
170 name: #name,
171 requires: &[#(#requires),*],
172 invalidates: &[#(#invalidates),*],
173 }
174 }
175
176 #[inline]
177 fn analyze(&self, program: &::vyre::ir::Program) -> ::vyre::optimizer::PassAnalysis {
178 Self::analyze(program)
179 }
180
181 #[inline]
182 fn transform(
183 &self,
184 program: ::vyre::ir::Program,
185 ) -> ::vyre::optimizer::PassResult {
186 Self::transform(program)
187 }
188
189 #[inline]
190 fn fingerprint(&self, program: &::vyre::ir::Program) -> u64 {
191 Self::fingerprint(program)
192 }
193 }
194
195 ::inventory::submit! {
196 ::vyre::optimizer::ProgramPassRegistration {
197 metadata: ::vyre::optimizer::PassMetadata {
198 name: #name,
199 requires: &[#(#requires),*],
200 invalidates: &[#(#invalidates),*],
201 },
202 factory: || ::std::boxed::Box::new(#ident),
203 }
204 }
205 }
206 .into()
207}
208
209/// Derive `vyre::AlgebraicLawProvider` from a `#[vyre(laws = [...])]` attribute.
210///
211/// Attach the derive to a unit struct (or any struct) that represents an op
212/// type. List its algebraic laws in the attribute; the macro emits the trait
213/// impl plus a `const LAWS: &[AlgebraicLaw]` associated item.
214///
215/// # Example
216///
217/// ```ignore
218/// use vyre_macros::AlgebraicLaws;
219///
220/// #[derive(AlgebraicLaws)]
221/// #[vyre(laws = [Commutative, Associative, "Identity { element: 0 }"])]
222/// pub struct Xor;
223/// ```
224///
225/// Expands to:
226///
227/// ```ignore
228/// impl Xor {
229/// pub const LAWS: &'static [::vyre::ops::AlgebraicLaw] = &[
230/// ::vyre::ops::AlgebraicLaw::Commutative,
231/// ::vyre::ops::AlgebraicLaw::Associative,
232/// ::vyre::ops::AlgebraicLaw::Identity { element: 0 },
233/// ];
234/// }
235/// impl ::vyre::ops::AlgebraicLawProvider for Xor {
236/// fn laws() -> &'static [::vyre::ops::AlgebraicLaw] { Self::LAWS }
237/// }
238/// ```
239#[proc_macro_derive(AlgebraicLaws, attributes(vyre))]
240pub fn derive_algebraic_laws(item: TokenStream) -> TokenStream {
241 let input = parse_macro_input!(item as DeriveInput);
242 let ident = &input.ident;
243 let laws = match extract_laws_attribute(&input.attrs) {
244 Ok(v) => v,
245 Err(e) => return e.to_compile_error().into(),
246 };
247
248 // Parse each law string as an AlgebraicLaw variant expression.
249 let law_exprs = laws.iter().map(|lit| {
250 let src = lit.value();
251 let trimmed = src.trim();
252 let path: syn::Expr = match syn::parse_str(&format!("::vyre::ops::AlgebraicLaw::{trimmed}"))
253 {
254 Ok(e) => e,
255 Err(err) => {
256 return syn::Error::new_spanned(
257 lit,
258 format!("failed to parse AlgebraicLaw variant `{trimmed}`: {err}"),
259 )
260 .to_compile_error();
261 }
262 };
263 quote! { #path }
264 });
265
266 // ensure the input type is a struct/enum we can attach impls to
267 match &input.data {
268 Data::Struct(_) | Data::Enum(_) => {}
269 Data::Union(_) => {
270 return syn::Error::new_spanned(
271 ident,
272 "#[derive(AlgebraicLaws)] does not support unions.",
273 )
274 .to_compile_error()
275 .into();
276 }
277 }
278
279 let law_exprs_vec: Vec<_> = law_exprs.collect();
280
281 quote! {
282 impl #ident {
283 /// Algebraic laws declared on this op type.
284 pub const LAWS: &'static [::vyre::ops::AlgebraicLaw] = &[
285 #(#law_exprs_vec),*
286 ];
287 }
288
289 impl ::vyre::ops::AlgebraicLawProvider for #ident {
290 fn laws() -> &'static [::vyre::ops::AlgebraicLaw] {
291 Self::LAWS
292 }
293 }
294 }
295 .into()
296}
297
298fn extract_laws_attribute(attrs: &[Attribute]) -> syn::Result<Vec<LitStr>> {
299 for attr in attrs {
300 if !attr.path().is_ident("vyre") {
301 continue;
302 }
303 let mut laws: Option<Vec<LitStr>> = None;
304 attr.parse_nested_meta(|meta| {
305 if meta.path.is_ident("laws") {
306 let value = meta.value()?;
307 // Accept both [Commutative, Identity{element:0}] bracketed
308 // identifier lists and [ "Commutative", "Identity{element:0}" ]
309 // string-literal arrays.
310 let lookahead = value.lookahead1();
311 if lookahead.peek(syn::token::Bracket) {
312 let content;
313 syn::bracketed!(content in value);
314 let mut collected = Vec::new();
315 while !content.is_empty() {
316 if content.peek(LitStr) {
317 let lit: LitStr = content.parse()?;
318 collected.push(lit);
319 } else {
320 // parse as raw token stream up to the next comma
321 let expr: syn::Expr = content.parse()?;
322 let rendered = quote! { #expr }.to_string();
323 collected.push(LitStr::new(&rendered, expr.span()));
324 }
325 if content.peek(Token![,]) {
326 content.parse::<Token![,]>()?;
327 }
328 }
329 laws = Some(collected);
330 Ok(())
331 } else {
332 Err(meta.error("expected `laws = [..]`"))
333 }
334 } else {
335 Err(meta.error("unknown vyre() argument; expected `laws = [..]`"))
336 }
337 })?;
338 if let Some(l) = laws {
339 return Ok(l);
340 }
341 }
342 Ok(Vec::new())
343}
344
345// Keep unused imports alive (silence the compiler's unused warnings; `Fields`
346// and `Meta` are referenced through docs/future use, and removing them here
347// risks churn during the open-IR migration).
348#[allow(dead_code)]
349fn _keep_imports_alive(_: Fields, _: Meta) {}