Skip to main content

simsym_macros/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::{Span, TokenTree};
3use quote::quote;
4use syn::Ident;
5
6fn crate_ident() -> Ident {
7    Ident::new("simsym", Span::call_site())
8}
9
10#[proc_macro]
11pub fn expr(input: TokenStream) -> TokenStream {
12    let tokens: Vec<TokenTree> = proc_macro2::TokenStream::from(input).into_iter().collect();
13    let mut pos = 0;
14    match parse_expr(&tokens, &mut pos) {
15        Ok(ts) if pos == tokens.len() => ts.into(),
16        Ok(_) => syn::Error::new(Span::call_site(), "trailing tokens in expr!")
17            .to_compile_error()
18            .into(),
19        Err(e) => e.to_compile_error().into(),
20    }
21}
22
23fn parse_expr(tokens: &[TokenTree], pos: &mut usize) -> syn::Result<proc_macro2::TokenStream> {
24    parse_additive(tokens, pos)
25}
26
27fn parse_additive(tokens: &[TokenTree], pos: &mut usize) -> syn::Result<proc_macro2::TokenStream> {
28    let root = crate_ident();
29    let mut left = parse_multiplicative(tokens, pos)?;
30    while let Some(op) = peek_op(tokens, *pos) {
31        if op != '+' && op != '-' {
32            break;
33        }
34        *pos += 1;
35        let right = parse_multiplicative(tokens, pos)?;
36        left = if op == '+' {
37            quote! { #root::expr::add(#left, #right) }
38        } else {
39            quote! { #root::expr::sub(#left, #right) }
40        };
41    }
42    Ok(left)
43}
44
45fn parse_multiplicative(tokens: &[TokenTree], pos: &mut usize) -> syn::Result<proc_macro2::TokenStream> {
46    let root = crate_ident();
47    let mut left = parse_power(tokens, pos)?;
48    while let Some(op) = peek_op(tokens, *pos) {
49        if op != '*' && op != '/' {
50            break;
51        }
52        *pos += 1;
53        let right = parse_power(tokens, pos)?;
54        left = if op == '*' {
55            quote! { #root::expr::mul(#left, #right) }
56        } else {
57            quote! { #root::expr::div(#left, #right) }
58        };
59    }
60    Ok(left)
61}
62
63fn parse_power(tokens: &[TokenTree], pos: &mut usize) -> syn::Result<proc_macro2::TokenStream> {
64    let root = crate_ident();
65    let mut left = parse_unary(tokens, pos)?;
66    if peek_op(tokens, *pos) == Some('^') {
67        *pos += 1;
68        let right = parse_power(tokens, pos)?;
69        left = quote! { #root::expr::pow(#left, #right) };
70    }
71    Ok(left)
72}
73
74fn parse_unary(tokens: &[TokenTree], pos: &mut usize) -> syn::Result<proc_macro2::TokenStream> {
75    let root = crate_ident();
76    if peek_op(tokens, *pos) == Some('-') {
77        *pos += 1;
78        let inner = parse_unary(tokens, pos)?;
79        return Ok(quote! { #root::expr::neg(#inner) });
80    }
81    parse_atom(tokens, pos)
82}
83
84fn parse_atom(tokens: &[TokenTree], pos: &mut usize) -> syn::Result<proc_macro2::TokenStream> {
85    let root = crate_ident();
86    let Some(tok) = tokens.get(*pos) else {
87        return Err(syn::Error::new(Span::call_site(), "unexpected end of expr!"));
88    };
89    match tok {
90        TokenTree::Group(g) if g.delimiter() == proc_macro2::Delimiter::Parenthesis => {
91            *pos += 1;
92            let inner: Vec<TokenTree> = g.stream().into_iter().collect();
93            let mut p = 0;
94            let e = parse_expr(&inner, &mut p)?;
95            if p != inner.len() {
96                return Err(syn::Error::new(g.span(), "trailing tokens in parentheses"));
97            }
98            Ok(e)
99        }
100        TokenTree::Ident(id) => {
101            let name = id.to_string();
102            if matches!(
103                name.as_str(),
104                "sin" | "cos" | "tan" | "cot" | "sec" | "csc"
105                    | "asin" | "acos" | "atan" | "acot" | "asec" | "acsc"
106                    | "sinh" | "cosh" | "tanh" | "coth" | "sech" | "csch"
107                    | "asinh" | "acosh" | "atanh" | "acoth" | "asech" | "acsch"
108                    | "exp" | "ln"
109            ) {
110                *pos += 1;
111                let args = parse_paren_args(tokens, pos)?;
112                let fname = syn::Ident::new(&name, id.span());
113                return Ok(quote! { #root::expr::#fname(#args) });
114            }
115            // `e^x` is exp(x), not the symbol e raised to x
116            if name == "e" && peek_op(tokens, *pos) == Some('^') {
117                *pos += 1; // ^
118                let exp = parse_power(tokens, pos)?;
119                return Ok(quote! { #root::expr::exp(#exp) });
120            }
121            *pos += 1;
122            Ok(quote! { #root::expr::var(#root::symbol(#name)) })
123        }
124        TokenTree::Literal(lit) => {
125            *pos += 1;
126            let s = lit.to_string();
127            if s.contains('.') {
128                let v: f64 = s.parse().map_err(|_| lit_err(lit))?;
129                let n = (v * 1_000_000.0).round() as i64;
130                Ok(quote! { #root::expr::const_(#root::rational(#n, 1_000_000i64)) })
131            } else {
132                let v: i64 = s.parse().map_err(|_| lit_err(lit))?;
133                Ok(quote! { #root::expr::const_(#root::rational(#v, 1i64)) })
134            }
135        }
136        _ => Err(syn::Error::new(tok.span(), "expected atom in expr!")),
137    }
138}
139
140fn parse_paren_args(tokens: &[TokenTree], pos: &mut usize) -> syn::Result<proc_macro2::TokenStream> {
141    let Some(TokenTree::Group(g)) = tokens.get(*pos) else {
142        return Err(syn::Error::new(Span::call_site(), "expected '('"));
143    };
144    if g.delimiter() != proc_macro2::Delimiter::Parenthesis {
145        return Err(syn::Error::new(g.span(), "expected '('"));
146    }
147    *pos += 1;
148    let inner: Vec<TokenTree> = g.stream().into_iter().collect();
149    let mut p = 0;
150    let e = parse_expr(&inner, &mut p)?;
151    if p != inner.len() {
152        return Err(syn::Error::new(g.span(), "trailing tokens in function call"));
153    }
154    Ok(e)
155}
156
157fn peek_op(tokens: &[TokenTree], pos: usize) -> Option<char> {
158    match tokens.get(pos)? {
159        TokenTree::Punct(p) if p.spacing() == proc_macro2::Spacing::Alone => {
160            Some(p.as_char())
161        }
162        _ => None,
163    }
164}
165
166fn lit_err(lit: &proc_macro2::Literal) -> syn::Error {
167    syn::Error::new(lit.span(), "invalid numeric literal")
168}