scirs2_symbolic_macros/lib.rs
1//! # scirs2-symbolic-macros
2//!
3//! Proc-macro DSL helpers for the `scirs2-symbolic` crate.
4//!
5//! ## Macros
6//!
7//! * [`eml_pattern!`] — construct a `Pattern` from a concise DSL expression.
8//! * [`eml_template!`] — same syntax, different label (marks the *right-hand side*
9//! of a rewrite rule to aid readability).
10//!
11//! Both macros emit fully-qualified `scirs2_symbolic::cas::pattern` construction
12//! code so no `use` imports are required at the call site.
13//!
14//! ## Mini-DSL reference
15//!
16//! ```text
17//! ?0, ?1, ?2 → PatVar(0), PatVar(1), PatVar(2)
18//! var(0), var(1) → PatGroundVar(0), PatGroundVar(1)
19//! const(f) → PatConst(f) (f may be a float or integer literal)
20//! int(n) → PatConstInt(n) (n is a u32 integer literal)
21//! add(A, B) → PatOp2(BinaryKind::Add, A, B)
22//! sub(A, B) → PatOp2(BinaryKind::Sub, A, B)
23//! mul(A, B) → PatOp2(BinaryKind::Mul, A, B)
24//! div(A, B) → PatOp2(BinaryKind::Div, A, B)
25//! pow(A, B) → PatOp2(BinaryKind::Pow, A, B)
26//! neg(A) → PatOp1(UnaryKind::Neg, A)
27//! sin(A) … tanh(A) → PatOp1(UnaryKind::Sin …)
28//! exp(A), ln(A) → PatOp1(UnaryKind::Exp / Ln)
29//! sqrt(A), abs(A) → PatOp1(UnaryKind::Sqrt / Abs)
30//! arcsin … arctanh → PatOp1(UnaryKind::Arcsin …)
31//! ```
32
33use proc_macro::TokenStream;
34use proc_macro2::{Literal, Span, TokenStream as TokenStream2};
35use quote::quote;
36use syn::{
37 ext::IdentExt,
38 parse::{Parse, ParseStream, Result as SynResult},
39 parse_macro_input, LitFloat, LitInt, Token,
40};
41
42// ---------------------------------------------------------------------------
43// Module paths — single source of truth for the emitted fully-qualified names.
44// ---------------------------------------------------------------------------
/// Expands to the fully-qualified `::scirs2_symbolic::cas::pattern` path as a
/// `proc_macro2::TokenStream`, so every emitted reference to `Pattern`,
/// `BinaryKind` and `UnaryKind` is spelled from this one definition.
macro_rules! pat_path {
    () => {
        quote! { ::scirs2_symbolic::cas::pattern }
    };
}
50
51// ---------------------------------------------------------------------------
52// PatternExpr — the parsed DSL node
53// ---------------------------------------------------------------------------
54
55/// A single DSL expression that emits `Pattern` construction code.
56struct PatternExpr(TokenStream2);
57
58impl Parse for PatternExpr {
59 fn parse(input: ParseStream<'_>) -> SynResult<Self> {
60 parse_pattern_expr(input).map(PatternExpr)
61 }
62}
63
64/// Core recursive parser. Uses `IdentExt::parse_any` so that Rust keywords
65/// such as `const` are parsed as identifier-like tokens rather than causing
66/// a hard parse error.
67fn parse_pattern_expr(input: ParseStream<'_>) -> SynResult<TokenStream2> {
68 // Peek: `?` → wildcard capture variable
69 if input.peek(Token![?]) {
70 return parse_patvar(input);
71 }
72
73 // Everything else starts with an identifier (possibly a keyword such as `const`).
74 let name_ident = syn::Ident::parse_any(input)?;
75 let name = name_ident.to_string();
76 let span = name_ident.span();
77
78 match name.as_str() {
79 // ----------------------------------------------------------------
80 // const(f) → PatConst(f as f64)
81 // ----------------------------------------------------------------
82 "const" => {
83 let content;
84 syn::parenthesized!(content in input);
85 let ts = parse_const_arg(&content, span)?;
86 Ok(ts)
87 }
88
89 // ----------------------------------------------------------------
90 // int(n) → PatConstInt(n as u32)
91 // ----------------------------------------------------------------
92 "int" => {
93 let content;
94 syn::parenthesized!(content in input);
95 let lit: LitInt = content.parse()?;
96 let n: u32 = lit
97 .base10_parse()
98 .map_err(|e| syn::Error::new(lit.span(), format!("expected u32 for int(): {e}")))?;
99 let lit_u32 = Literal::u32_suffixed(n);
100 let p = pat_path!();
101 Ok(quote! { #p::Pattern::PatConstInt(#lit_u32) })
102 }
103
104 // ----------------------------------------------------------------
105 // var(n) → PatGroundVar(n as usize)
106 // ----------------------------------------------------------------
107 "var" => {
108 let content;
109 syn::parenthesized!(content in input);
110 let lit: LitInt = content.parse()?;
111 let n: usize = lit.base10_parse().map_err(|e| {
112 syn::Error::new(lit.span(), format!("expected usize for var(): {e}"))
113 })?;
114 let lit_usize = Literal::usize_suffixed(n);
115 let p = pat_path!();
116 Ok(quote! { #p::Pattern::PatGroundVar(#lit_usize) })
117 }
118
119 // ----------------------------------------------------------------
120 // Binary operators: add, sub, mul, div, pow
121 // ----------------------------------------------------------------
122 "add" | "sub" | "mul" | "div" | "pow" => {
123 let kind_ts = binary_kind_tokens(&name, span)?;
124 let content;
125 syn::parenthesized!(content in input);
126 let left = parse_pattern_expr(&content)?;
127 content.parse::<Token![,]>()?;
128 let right = parse_pattern_expr(&content)?;
129 let p = pat_path!();
130 Ok(quote! {
131 #p::Pattern::PatOp2(
132 #kind_ts,
133 ::std::boxed::Box::new(#left),
134 ::std::boxed::Box::new(#right),
135 )
136 })
137 }
138
139 // ----------------------------------------------------------------
140 // Unary operators: neg, sin, cos, tan, exp, ln, sqrt, abs,
141 // sinh, cosh, tanh, arcsin, arccos, arctan,
142 // arcsinh, arccosh, arctanh
143 // ----------------------------------------------------------------
144 "neg" | "sin" | "cos" | "tan" | "exp" | "ln" | "sqrt" | "abs" | "sinh" | "cosh"
145 | "tanh" | "arcsin" | "arccos" | "arctan" | "arcsinh" | "arccosh" | "arctanh" => {
146 let kind_ts = unary_kind_tokens(&name, span)?;
147 let content;
148 syn::parenthesized!(content in input);
149 let child = parse_pattern_expr(&content)?;
150 let p = pat_path!();
151 Ok(quote! {
152 #p::Pattern::PatOp1(
153 #kind_ts,
154 ::std::boxed::Box::new(#child),
155 )
156 })
157 }
158
159 other => Err(syn::Error::new(
160 span,
161 format!(
162 "unknown pattern operator `{other}`; expected one of: \
163 const, int, var, add, sub, mul, div, pow, neg, sin, cos, tan, exp, ln, \
164 sqrt, abs, sinh, cosh, tanh, arcsin, arccos, arctan, arcsinh, arccosh, arctanh"
165 ),
166 )),
167 }
168}
169
170// ---------------------------------------------------------------------------
171// `?N` → PatVar(N)
172// ---------------------------------------------------------------------------
173
174fn parse_patvar(input: ParseStream<'_>) -> SynResult<TokenStream2> {
175 let q_span = input.span();
176 input.parse::<Token![?]>()?;
177 let lit: LitInt = input
178 .parse()
179 .map_err(|_| syn::Error::new(q_span, "expected an integer after `?`, e.g. `?0`, `?1`"))?;
180 let n: u32 = lit
181 .base10_parse()
182 .map_err(|e| syn::Error::new(lit.span(), format!("wildcard index must be u32: {e}")))?;
183 let lit_u32 = Literal::u32_suffixed(n);
184 let p = pat_path!();
185 Ok(quote! { #p::Pattern::PatVar(#lit_u32) })
186}
187
188// ---------------------------------------------------------------------------
189// `const(f)` — handles both float literals (2.0) and integer literals (0)
190// ---------------------------------------------------------------------------
191
192fn parse_const_arg(content: &syn::parse::ParseBuffer<'_>, span: Span) -> SynResult<TokenStream2> {
193 let p = pat_path!();
194
195 // Try float literal first.
196 if content.peek(LitFloat) {
197 let lit: LitFloat = content.parse()?;
198 let v: f64 = lit
199 .base10_parse()
200 .map_err(|e| syn::Error::new(lit.span(), format!("expected f64 float literal: {e}")))?;
201 let lit_f64 = Literal::f64_suffixed(v);
202 return Ok(quote! { #p::Pattern::PatConst(#lit_f64) });
203 }
204
205 // Handle negative sign before a literal.
206 if content.peek(Token![-]) {
207 content.parse::<Token![-]>()?;
208 if content.peek(LitFloat) {
209 let lit: LitFloat = content.parse()?;
210 let v: f64 = lit.base10_parse().map_err(|e| {
211 syn::Error::new(lit.span(), format!("expected f64 float literal: {e}"))
212 })?;
213 let neg_lit = Literal::f64_suffixed(-v);
214 return Ok(quote! { #p::Pattern::PatConst(#neg_lit) });
215 }
216 let lit: LitInt = content.parse()?;
217 let v: u64 = lit
218 .base10_parse()
219 .map_err(|e| syn::Error::new(lit.span(), format!("expected integer literal: {e}")))?;
220 let neg_lit = Literal::f64_suffixed(-(v as f64));
221 return Ok(quote! { #p::Pattern::PatConst(#neg_lit) });
222 }
223
224 // Integer literal — convert to f64.
225 if content.peek(LitInt) {
226 let lit: LitInt = content.parse()?;
227 let v: u64 = lit.base10_parse().map_err(|e| {
228 syn::Error::new(
229 lit.span(),
230 format!("expected integer literal for const(): {e}"),
231 )
232 })?;
233 let lit_f64 = Literal::f64_suffixed(v as f64);
234 return Ok(quote! { #p::Pattern::PatConst(#lit_f64) });
235 }
236
237 Err(syn::Error::new(
238 span,
239 "expected a numeric literal for const() (e.g. `const(0)`, `const(2.0)`)",
240 ))
241}
242
243// ---------------------------------------------------------------------------
244// Helpers that emit BinaryKind / UnaryKind tokens
245// ---------------------------------------------------------------------------
246
247fn binary_kind_tokens(name: &str, span: Span) -> SynResult<TokenStream2> {
248 let p = pat_path!();
249 match name {
250 "add" => Ok(quote! { #p::BinaryKind::Add }),
251 "sub" => Ok(quote! { #p::BinaryKind::Sub }),
252 "mul" => Ok(quote! { #p::BinaryKind::Mul }),
253 "div" => Ok(quote! { #p::BinaryKind::Div }),
254 "pow" => Ok(quote! { #p::BinaryKind::Pow }),
255 _ => Err(syn::Error::new(
256 span,
257 format!("unknown binary operator: {name}"),
258 )),
259 }
260}
261
262fn unary_kind_tokens(name: &str, span: Span) -> SynResult<TokenStream2> {
263 let p = pat_path!();
264 match name {
265 "neg" => Ok(quote! { #p::UnaryKind::Neg }),
266 "exp" => Ok(quote! { #p::UnaryKind::Exp }),
267 "ln" => Ok(quote! { #p::UnaryKind::Ln }),
268 "sin" => Ok(quote! { #p::UnaryKind::Sin }),
269 "cos" => Ok(quote! { #p::UnaryKind::Cos }),
270 "tan" => Ok(quote! { #p::UnaryKind::Tan }),
271 "sinh" => Ok(quote! { #p::UnaryKind::Sinh }),
272 "cosh" => Ok(quote! { #p::UnaryKind::Cosh }),
273 "tanh" => Ok(quote! { #p::UnaryKind::Tanh }),
274 "arcsin" => Ok(quote! { #p::UnaryKind::Arcsin }),
275 "arccos" => Ok(quote! { #p::UnaryKind::Arccos }),
276 "arctan" => Ok(quote! { #p::UnaryKind::Arctan }),
277 "arcsinh" => Ok(quote! { #p::UnaryKind::Arcsinh }),
278 "arccosh" => Ok(quote! { #p::UnaryKind::Arccosh }),
279 "arctanh" => Ok(quote! { #p::UnaryKind::Arctanh }),
280 "sqrt" => Ok(quote! { #p::UnaryKind::Sqrt }),
281 "abs" => Ok(quote! { #p::UnaryKind::Abs }),
282 _ => Err(syn::Error::new(
283 span,
284 format!("unknown unary operator: {name}"),
285 )),
286 }
287}
288
289// ---------------------------------------------------------------------------
290// Public proc-macros
291// ---------------------------------------------------------------------------
292
293/// Construct a `scirs2_symbolic::cas::pattern::Pattern` from the EML mini-DSL.
294///
295/// # Syntax
296///
297/// ```text
298/// eml_pattern!( <expr> )
299/// ```
300///
301/// where `<expr>` is one of:
302///
303/// | Token | Expansion |
304/// |-------|-----------|
305/// | `?N` | `Pattern::PatVar(N)` |
306/// | `var(N)` | `Pattern::PatGroundVar(N)` |
307/// | `const(f)` | `Pattern::PatConst(f)` |
308/// | `int(n)` | `Pattern::PatConstInt(n)` |
309/// | `add(A,B)` `sub(A,B)` `mul(A,B)` `div(A,B)` `pow(A,B)` | `Pattern::PatOp2(…, A, B)` |
310/// | `neg(A)` `sin(A)` `cos(A)` … | `Pattern::PatOp1(…, A)` |
311///
312/// # Example
313///
314/// ```rust,ignore
315/// use scirs2_symbolic::eml_pattern;
316///
317/// let pat = eml_pattern!(add(?0, const(0)));
318/// ```
319#[proc_macro]
320pub fn eml_pattern(input: TokenStream) -> TokenStream {
321 let expr = parse_macro_input!(input as PatternExpr);
322 expr.0.into()
323}
324
325/// Construct a `scirs2_symbolic::cas::pattern::Pattern` from the EML mini-DSL.
326///
327/// Identical to [`eml_pattern!`] — the different name labels the *right-hand side*
328/// (template / replacement) of a rewrite rule for readability.
329///
330/// # Example
331///
332/// ```rust,ignore
333/// use scirs2_symbolic::{eml_pattern, eml_template};
334///
335/// let lhs = eml_pattern!(add(?0, ?1));
336/// let rhs = eml_template!(add(?1, ?0)); // commutativity rewrite
337/// ```
338#[proc_macro]
339pub fn eml_template(input: TokenStream) -> TokenStream {
340 let expr = parse_macro_input!(input as PatternExpr);
341 expr.0.into()
342}