Skip to main content

scirs2_symbolic_macros/
lib.rs

1//! # scirs2-symbolic-macros
2//!
3//! Proc-macro DSL helpers for the `scirs2-symbolic` crate.
4//!
5//! ## Macros
6//!
7//! * [`eml_pattern!`] — construct a `Pattern` from a concise DSL expression.
8//! * [`eml_template!`] — same syntax, different label (marks the *right-hand side*
9//!   of a rewrite rule to aid readability).
10//!
11//! Both macros emit fully-qualified `scirs2_symbolic::cas::pattern` construction
12//! code so no `use` imports are required at the call site.
13//!
14//! ## Mini-DSL reference
15//!
16//! ```text
17//! ?0, ?1, ?2          → PatVar(0), PatVar(1), PatVar(2)
18//! var(0), var(1)      → PatGroundVar(0), PatGroundVar(1)
//! const(f)            → PatConst(f)          (f is a float or integer literal, optionally negated, e.g. const(-1.5))
20//! int(n)              → PatConstInt(n)        (n is a u32 integer literal)
21//! add(A, B)           → PatOp2(BinaryKind::Add, A, B)
22//! sub(A, B)           → PatOp2(BinaryKind::Sub, A, B)
23//! mul(A, B)           → PatOp2(BinaryKind::Mul, A, B)
24//! div(A, B)           → PatOp2(BinaryKind::Div, A, B)
25//! pow(A, B)           → PatOp2(BinaryKind::Pow, A, B)
26//! neg(A)              → PatOp1(UnaryKind::Neg, A)
27//! sin(A) … tanh(A)    → PatOp1(UnaryKind::Sin …)
28//! exp(A), ln(A)       → PatOp1(UnaryKind::Exp / Ln)
29//! sqrt(A), abs(A)     → PatOp1(UnaryKind::Sqrt / Abs)
30//! arcsin … arctanh    → PatOp1(UnaryKind::Arcsin …)
31//! ```
32
33use proc_macro::TokenStream;
34use proc_macro2::{Literal, Span, TokenStream as TokenStream2};
35use quote::quote;
36use syn::{
37    ext::IdentExt,
38    parse::{Parse, ParseStream, Result as SynResult},
39    parse_macro_input, LitFloat, LitInt, Token,
40};
41
42// ---------------------------------------------------------------------------
43// Module paths — single source of truth for the emitted fully-qualified names.
44// ---------------------------------------------------------------------------
/// Expands to the token path `::scirs2_symbolic::cas::pattern`.
///
/// Every emitted `Pattern`, `BinaryKind` and `UnaryKind` reference is prefixed
/// with this path, so call sites need no `use` imports.
macro_rules! pat_path {
    () => {
        quote! { ::scirs2_symbolic::cas::pattern }
    };
}
50
51// ---------------------------------------------------------------------------
52// PatternExpr — the parsed DSL node
53// ---------------------------------------------------------------------------
54
55/// A single DSL expression that emits `Pattern` construction code.
56struct PatternExpr(TokenStream2);
57
58impl Parse for PatternExpr {
59    fn parse(input: ParseStream<'_>) -> SynResult<Self> {
60        parse_pattern_expr(input).map(PatternExpr)
61    }
62}
63
64/// Core recursive parser.  Uses `IdentExt::parse_any` so that Rust keywords
65/// such as `const` are parsed as identifier-like tokens rather than causing
66/// a hard parse error.
67fn parse_pattern_expr(input: ParseStream<'_>) -> SynResult<TokenStream2> {
68    // Peek: `?` → wildcard capture variable
69    if input.peek(Token![?]) {
70        return parse_patvar(input);
71    }
72
73    // Everything else starts with an identifier (possibly a keyword such as `const`).
74    let name_ident = syn::Ident::parse_any(input)?;
75    let name = name_ident.to_string();
76    let span = name_ident.span();
77
78    match name.as_str() {
79        // ----------------------------------------------------------------
80        // const(f) → PatConst(f as f64)
81        // ----------------------------------------------------------------
82        "const" => {
83            let content;
84            syn::parenthesized!(content in input);
85            let ts = parse_const_arg(&content, span)?;
86            Ok(ts)
87        }
88
89        // ----------------------------------------------------------------
90        // int(n) → PatConstInt(n as u32)
91        // ----------------------------------------------------------------
92        "int" => {
93            let content;
94            syn::parenthesized!(content in input);
95            let lit: LitInt = content.parse()?;
96            let n: u32 = lit
97                .base10_parse()
98                .map_err(|e| syn::Error::new(lit.span(), format!("expected u32 for int(): {e}")))?;
99            let lit_u32 = Literal::u32_suffixed(n);
100            let p = pat_path!();
101            Ok(quote! { #p::Pattern::PatConstInt(#lit_u32) })
102        }
103
104        // ----------------------------------------------------------------
105        // var(n) → PatGroundVar(n as usize)
106        // ----------------------------------------------------------------
107        "var" => {
108            let content;
109            syn::parenthesized!(content in input);
110            let lit: LitInt = content.parse()?;
111            let n: usize = lit.base10_parse().map_err(|e| {
112                syn::Error::new(lit.span(), format!("expected usize for var(): {e}"))
113            })?;
114            let lit_usize = Literal::usize_suffixed(n);
115            let p = pat_path!();
116            Ok(quote! { #p::Pattern::PatGroundVar(#lit_usize) })
117        }
118
119        // ----------------------------------------------------------------
120        // Binary operators: add, sub, mul, div, pow
121        // ----------------------------------------------------------------
122        "add" | "sub" | "mul" | "div" | "pow" => {
123            let kind_ts = binary_kind_tokens(&name, span)?;
124            let content;
125            syn::parenthesized!(content in input);
126            let left = parse_pattern_expr(&content)?;
127            content.parse::<Token![,]>()?;
128            let right = parse_pattern_expr(&content)?;
129            let p = pat_path!();
130            Ok(quote! {
131                #p::Pattern::PatOp2(
132                    #kind_ts,
133                    ::std::boxed::Box::new(#left),
134                    ::std::boxed::Box::new(#right),
135                )
136            })
137        }
138
139        // ----------------------------------------------------------------
140        // Unary operators: neg, sin, cos, tan, exp, ln, sqrt, abs,
141        //                  sinh, cosh, tanh, arcsin, arccos, arctan,
142        //                  arcsinh, arccosh, arctanh
143        // ----------------------------------------------------------------
144        "neg" | "sin" | "cos" | "tan" | "exp" | "ln" | "sqrt" | "abs" | "sinh" | "cosh"
145        | "tanh" | "arcsin" | "arccos" | "arctan" | "arcsinh" | "arccosh" | "arctanh" => {
146            let kind_ts = unary_kind_tokens(&name, span)?;
147            let content;
148            syn::parenthesized!(content in input);
149            let child = parse_pattern_expr(&content)?;
150            let p = pat_path!();
151            Ok(quote! {
152                #p::Pattern::PatOp1(
153                    #kind_ts,
154                    ::std::boxed::Box::new(#child),
155                )
156            })
157        }
158
159        other => Err(syn::Error::new(
160            span,
161            format!(
162                "unknown pattern operator `{other}`; expected one of: \
163                 const, int, var, add, sub, mul, div, pow, neg, sin, cos, tan, exp, ln, \
164                 sqrt, abs, sinh, cosh, tanh, arcsin, arccos, arctan, arcsinh, arccosh, arctanh"
165            ),
166        )),
167    }
168}
169
170// ---------------------------------------------------------------------------
171// `?N` → PatVar(N)
172// ---------------------------------------------------------------------------
173
174fn parse_patvar(input: ParseStream<'_>) -> SynResult<TokenStream2> {
175    let q_span = input.span();
176    input.parse::<Token![?]>()?;
177    let lit: LitInt = input
178        .parse()
179        .map_err(|_| syn::Error::new(q_span, "expected an integer after `?`, e.g. `?0`, `?1`"))?;
180    let n: u32 = lit
181        .base10_parse()
182        .map_err(|e| syn::Error::new(lit.span(), format!("wildcard index must be u32: {e}")))?;
183    let lit_u32 = Literal::u32_suffixed(n);
184    let p = pat_path!();
185    Ok(quote! { #p::Pattern::PatVar(#lit_u32) })
186}
187
188// ---------------------------------------------------------------------------
189// `const(f)` — handles both float literals (2.0) and integer literals (0)
190// ---------------------------------------------------------------------------
191
192fn parse_const_arg(content: &syn::parse::ParseBuffer<'_>, span: Span) -> SynResult<TokenStream2> {
193    let p = pat_path!();
194
195    // Try float literal first.
196    if content.peek(LitFloat) {
197        let lit: LitFloat = content.parse()?;
198        let v: f64 = lit
199            .base10_parse()
200            .map_err(|e| syn::Error::new(lit.span(), format!("expected f64 float literal: {e}")))?;
201        let lit_f64 = Literal::f64_suffixed(v);
202        return Ok(quote! { #p::Pattern::PatConst(#lit_f64) });
203    }
204
205    // Handle negative sign before a literal.
206    if content.peek(Token![-]) {
207        content.parse::<Token![-]>()?;
208        if content.peek(LitFloat) {
209            let lit: LitFloat = content.parse()?;
210            let v: f64 = lit.base10_parse().map_err(|e| {
211                syn::Error::new(lit.span(), format!("expected f64 float literal: {e}"))
212            })?;
213            let neg_lit = Literal::f64_suffixed(-v);
214            return Ok(quote! { #p::Pattern::PatConst(#neg_lit) });
215        }
216        let lit: LitInt = content.parse()?;
217        let v: u64 = lit
218            .base10_parse()
219            .map_err(|e| syn::Error::new(lit.span(), format!("expected integer literal: {e}")))?;
220        let neg_lit = Literal::f64_suffixed(-(v as f64));
221        return Ok(quote! { #p::Pattern::PatConst(#neg_lit) });
222    }
223
224    // Integer literal — convert to f64.
225    if content.peek(LitInt) {
226        let lit: LitInt = content.parse()?;
227        let v: u64 = lit.base10_parse().map_err(|e| {
228            syn::Error::new(
229                lit.span(),
230                format!("expected integer literal for const(): {e}"),
231            )
232        })?;
233        let lit_f64 = Literal::f64_suffixed(v as f64);
234        return Ok(quote! { #p::Pattern::PatConst(#lit_f64) });
235    }
236
237    Err(syn::Error::new(
238        span,
239        "expected a numeric literal for const() (e.g. `const(0)`, `const(2.0)`)",
240    ))
241}
242
243// ---------------------------------------------------------------------------
244// Helpers that emit BinaryKind / UnaryKind tokens
245// ---------------------------------------------------------------------------
246
247fn binary_kind_tokens(name: &str, span: Span) -> SynResult<TokenStream2> {
248    let p = pat_path!();
249    match name {
250        "add" => Ok(quote! { #p::BinaryKind::Add }),
251        "sub" => Ok(quote! { #p::BinaryKind::Sub }),
252        "mul" => Ok(quote! { #p::BinaryKind::Mul }),
253        "div" => Ok(quote! { #p::BinaryKind::Div }),
254        "pow" => Ok(quote! { #p::BinaryKind::Pow }),
255        _ => Err(syn::Error::new(
256            span,
257            format!("unknown binary operator: {name}"),
258        )),
259    }
260}
261
262fn unary_kind_tokens(name: &str, span: Span) -> SynResult<TokenStream2> {
263    let p = pat_path!();
264    match name {
265        "neg" => Ok(quote! { #p::UnaryKind::Neg }),
266        "exp" => Ok(quote! { #p::UnaryKind::Exp }),
267        "ln" => Ok(quote! { #p::UnaryKind::Ln }),
268        "sin" => Ok(quote! { #p::UnaryKind::Sin }),
269        "cos" => Ok(quote! { #p::UnaryKind::Cos }),
270        "tan" => Ok(quote! { #p::UnaryKind::Tan }),
271        "sinh" => Ok(quote! { #p::UnaryKind::Sinh }),
272        "cosh" => Ok(quote! { #p::UnaryKind::Cosh }),
273        "tanh" => Ok(quote! { #p::UnaryKind::Tanh }),
274        "arcsin" => Ok(quote! { #p::UnaryKind::Arcsin }),
275        "arccos" => Ok(quote! { #p::UnaryKind::Arccos }),
276        "arctan" => Ok(quote! { #p::UnaryKind::Arctan }),
277        "arcsinh" => Ok(quote! { #p::UnaryKind::Arcsinh }),
278        "arccosh" => Ok(quote! { #p::UnaryKind::Arccosh }),
279        "arctanh" => Ok(quote! { #p::UnaryKind::Arctanh }),
280        "sqrt" => Ok(quote! { #p::UnaryKind::Sqrt }),
281        "abs" => Ok(quote! { #p::UnaryKind::Abs }),
282        _ => Err(syn::Error::new(
283            span,
284            format!("unknown unary operator: {name}"),
285        )),
286    }
287}
288
289// ---------------------------------------------------------------------------
290// Public proc-macros
291// ---------------------------------------------------------------------------
292
293/// Construct a `scirs2_symbolic::cas::pattern::Pattern` from the EML mini-DSL.
294///
295/// # Syntax
296///
297/// ```text
298/// eml_pattern!( <expr> )
299/// ```
300///
301/// where `<expr>` is one of:
302///
303/// | Token | Expansion |
304/// |-------|-----------|
305/// | `?N` | `Pattern::PatVar(N)` |
306/// | `var(N)` | `Pattern::PatGroundVar(N)` |
307/// | `const(f)` | `Pattern::PatConst(f)` |
308/// | `int(n)` | `Pattern::PatConstInt(n)` |
309/// | `add(A,B)` `sub(A,B)` `mul(A,B)` `div(A,B)` `pow(A,B)` | `Pattern::PatOp2(…, A, B)` |
310/// | `neg(A)` `sin(A)` `cos(A)` … | `Pattern::PatOp1(…, A)` |
311///
312/// # Example
313///
314/// ```rust,ignore
315/// use scirs2_symbolic::eml_pattern;
316///
317/// let pat = eml_pattern!(add(?0, const(0)));
318/// ```
319#[proc_macro]
320pub fn eml_pattern(input: TokenStream) -> TokenStream {
321    let expr = parse_macro_input!(input as PatternExpr);
322    expr.0.into()
323}
324
325/// Construct a `scirs2_symbolic::cas::pattern::Pattern` from the EML mini-DSL.
326///
327/// Identical to [`eml_pattern!`] — the different name labels the *right-hand side*
328/// (template / replacement) of a rewrite rule for readability.
329///
330/// # Example
331///
332/// ```rust,ignore
333/// use scirs2_symbolic::{eml_pattern, eml_template};
334///
335/// let lhs = eml_pattern!(add(?0, ?1));
336/// let rhs = eml_template!(add(?1, ?0));  // commutativity rewrite
337/// ```
338#[proc_macro]
339pub fn eml_template(input: TokenStream) -> TokenStream {
340    let expr = parse_macro_input!(input as PatternExpr);
341    expr.0.into()
342}