Skip to main content

symdiff/
lib.rs

1//! Compile-time symbolic differentiation via a proc-macro attribute.
2//!
3//! The [`gradient`] macro differentiates a Rust function analytically at
4//! compile time and emits a companion `{fn}_gradient` that returns the
5//! closed-form gradient as a fixed-size array. No numerical approximation,
6//! no runtime overhead beyond the arithmetic itself.
7//!
8//! # How it works
9//!
10//! At compile time the macro:
11//!
12//! 1. Parses the function body into a symbolic expression tree.
13//! 2. Differentiates each component analytically using standard calculus rules.
14//! 3. Simplifies the result (constant folding, identity laws, CSE, etc.).
15//! 4. Applies greedy commutative/associative reordering to expose further
16//!    common sub-expressions.
17//! 5. Emits the gradient function with shared sub-expressions hoisted into
18//!    `let` bindings.
19//!
20//! # Example
21//!
22//! ```rust,ignore
23//! use symdiff::gradient;
24//!
//! #[gradient(var = x, dim = 2)]
26//! fn rosenbrock(x: &[f64]) -> f64 {
27//!     let a = 1.0 - x[0];
28//!     let b = x[1] - x[0].powi(2);
29//!     a.powi(2) + 100.0 * b.powi(2)
30//! }
31//!
32//! // Generates rosenbrock unchanged, plus:
33//! // fn rosenbrock_gradient(x: &[f64]) -> [f64; 2] {
34//! //     let tmp0 = x[0].powi(1);   // shared sub-expressions hoisted
35//! //     ...
36//! //     [∂f/∂x[0], ∂f/∂x[1]]
37//! // }
38//!
39//! let g = rosenbrock_gradient(&[1.0, 1.0]);
40//! assert_eq!(g, [0.0, 0.0]); // minimum of the Rosenbrock function
41//! ```
42//!
43//! # Attribute parameters
44//!
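//! - `var: Ident` — name of the slice parameter to differentiate with respect
//!   to, e.g. `var = x` (required)
//! - `dim: usize` — number of gradient components `∂f/∂var[0]` …
//!   `∂f/∂var[dim - 1]`; should cover every index used in the body (required)
//! - `max_passes: usize` — maximum simplification passes (also bounds the
//!   reordering rounds); default 10 (optional)
//! - `sparse: bool` — return `(indices, values)` arrays holding only the
//!   components that do not simplify to `0.0`, instead of a dense
//!   `[f64; dim]`; default `false` (optional)
//! - `unchecked: bool` — request unchecked indexing of the input slice in the
//!   generated code; default `false` (optional)
//! - `prune: bool` — prune the expression arena before differentiating (may be
//!   expensive); default `false` (optional)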
48//!
49//! # Supported syntax
50//!
51//! The function body may contain `let` bindings followed by a bare tail
52//! expression. Within expressions, the following are supported:
53//!
54//! - Variables: `x[0]`, `x[1]`, … (integer literal indices only)
55//! - Numeric literals: `1`, `2.0`, etc.
56//! - Arithmetic: `+`, `-`, `*`, `/`, unary `-`
57//! - Methods: `powi(n)` (integer literal exponent), `sin`, `cos`, `ln`,
58//!   `exp`, `sqrt`
59//! - Transparent: parentheses, `as f64` casts, block expressions
60//!
61//! Anything else (closures, function calls, loops, …) causes a compile-time
62//! panic with a descriptive message.
63//!
64//! # Limitations
65//!
//! - Only the parameter named by `var` is differentiated against, and only
//!   through accesses `var[i]` with an integer-literal index.
67//! - `powi` exponents must be integer literals, not variables.
68//! - Conditional expressions and loops cannot be differentiated symbolically.
69
70mod arena;
71mod coordinator;
72mod transformers;
73mod visitors;
74
75use std::collections::HashMap;
76
77use quote::quote;
78use syn::{Expr, Ident, Pat, Stmt};
79
80use arena::{NodeId, SymArena, SymNode, VarId};
81use transformers::{DiffTransformer, SimplifyTransformer};
82use visitors::{RefCountVisitor, ToTokenStreamVisitor};
83
84use crate::coordinator::{Coordinator, GreedyCoordinator};
85
86/// Parse a `syn` expression into the [`SymArena`], returning the root [`NodeId`].
87///
88/// Panics on any unsupported expression form.
89fn parse_syn(arena: &mut SymArena, expr: &Expr, env: &HashMap<syn::Ident, NodeId>) -> NodeId {
90    match expr {
91        Expr::Lit(lit) => match &lit.lit {
92            syn::Lit::Int(int_lit) => {
93                let value = int_lit.base10_parse::<u64>().unwrap();
94                return arena.intern(SymNode::Const((value as f64).to_bits()));
95            }
96            syn::Lit::Float(flt_lit) => {
97                let value = flt_lit.base10_parse::<f64>().unwrap();
98                return arena.intern(SymNode::Const(value.to_bits()));
99            }
100            _ => panic!("Unsupported literal type"),
101        },
102
103        Expr::Index(index_expr) => {
104            if let Expr::Path(path_expr) = &*index_expr.expr {
105                if let Some(var_ident) = path_expr.path.get_ident() {
106                    if let Some(var_id) = arena.get_var_id(var_ident) {
107                        if let Expr::Lit(lit) = &*index_expr.index {
108                            if let syn::Lit::Int(int_lit) = &lit.lit {
109                                let idx = int_lit.base10_parse::<NodeId>().unwrap();
110                                return arena.intern(SymNode::Var(var_id, idx));
111                            } else {
112                                panic!("Expected integer literal for variable index");
113                            }
114                        } else {
115                            panic!("Expected literal for variable index");
116                        }
117                    } else {
118                        panic!("Unexpected variable identifier");
119                    }
120                } else {
121                    panic!("Unable to find variable identifier");
122                }
123            } else {
124                panic!("Expected variable access of the form x[i]");
125            }
126        }
127
128        Expr::Path(path_expr) => {
129            if let Some(ident) = path_expr.path.get_ident() {
130                if let Some(node_id) = env.get(ident) {
131                    return *node_id;
132                } else {
133                    panic!("Undefined variable: {}", ident);
134                }
135            } else {
136                panic!("Unsupported path expression");
137            }
138        }
139
140        Expr::Binary(bin_expr) => {
141            let left_id = parse_syn(arena, &bin_expr.left, env);
142            let right_id = parse_syn(arena, &bin_expr.right, env);
143            return match bin_expr.op {
144                syn::BinOp::Add(_) => arena.intern(SymNode::Add(left_id, right_id)),
145                syn::BinOp::Sub(_) => arena.intern(SymNode::Sub(left_id, right_id)),
146                syn::BinOp::Mul(_) => arena.intern(SymNode::Mul(left_id, right_id)),
147                syn::BinOp::Div(_) => arena.intern(SymNode::Div(left_id, right_id)),
148                _ => panic!("Unsupported binary operator"),
149            };
150        }
151
152        Expr::Unary(un) => {
153            let operand_id = parse_syn(arena, &un.expr, env);
154            return match un.op {
155                syn::UnOp::Neg(_) => arena.intern(SymNode::Neg(operand_id)),
156                _ => panic!("Unsupported unary operator"),
157            };
158        }
159
160        Expr::MethodCall(call) => {
161            let receiver_id = parse_syn(arena, &call.receiver, env);
162            let method_name = call.method.to_string();
163            return match method_name.as_str() {
164                "powi" => {
165                    if call.args.len() != 1 {
166                        panic!("powi method call must have exactly one argument");
167                    }
168                    if let Expr::Lit(lit) = &call.args[0] {
169                        if let syn::Lit::Int(int_lit) = &lit.lit {
170                            let exp = int_lit.base10_parse::<i32>().unwrap();
171                            arena.intern(SymNode::Powi(receiver_id, exp))
172                        } else {
173                            panic!("Expected integer literal for powi exponent");
174                        }
175                    } else {
176                        panic!("Expected literal for powi exponent");
177                    }
178                }
179                "sin" => arena.intern(SymNode::Sin(receiver_id)),
180                "cos" => arena.intern(SymNode::Cos(receiver_id)),
181                "ln" => arena.intern(SymNode::Ln(receiver_id)),
182                "exp" => arena.intern(SymNode::Exp(receiver_id)),
183                "sqrt" => arena.intern(SymNode::Sqrt(receiver_id)),
184                _ => panic!("Unsupported method call: {}", method_name),
185            };
186        }
187
188        Expr::Paren(paren) => parse_syn(arena, &paren.expr, env),
189        Expr::Group(group) => parse_syn(arena, &group.expr, env),
190        Expr::Cast(c) => parse_syn(arena, &c.expr, env),
191
192        _ => panic!("Unsupported expression type {:?}", expr),
193    }
194}
195
196/// Differentiate `root_id` with respect to `var_idx`, simplify the result, and
197/// emit the `root_id` of the simplified expression.
198fn compile_expression(
199    arena: &mut SymArena,
200    root_id: NodeId,
201    var_id: VarId,
202    var_idx: usize,
203    options: &Options,
204) -> NodeId {
205    let cost_estimates = HashMap::from([
206        (SymNode::Const(0), 0),
207        (SymNode::Var(0, 0), 0),
208        (SymNode::Add(0, 0), 1),
209        (SymNode::Sub(0, 0), 1),
210        (SymNode::Mul(0, 0), 1),
211        (SymNode::Div(0, 0), 15),
212        (SymNode::Powi(0, 0), 3),
213        (SymNode::Neg(0), 0),
214        (SymNode::Sin(0), 100),
215        (SymNode::Cos(0), 100),
216        (SymNode::Ln(0), 60),
217        (SymNode::Exp(0), 60),
218        (SymNode::Sqrt(0), 5),
219    ]);
220
221    // Differentiate with respect to variable var_idx
222    let diff_transformer = DiffTransformer::new(var_id, var_idx);
223    let mut root_id = arena.transform(root_id, &diff_transformer);
224
225    // Simplify the result
226    let simplify_transformer = SimplifyTransformer::new();
227    for _ in 0..options.max_passes {
228        let new_root_id = arena.transform(root_id, &simplify_transformer);
229        if new_root_id == root_id {
230            break; // No further simplification possible
231        }
232        root_id = new_root_id;
233    }
234
235    // Commutative and associative reordering to canonicalize expressions and expose more common sub-expressions.
236    let greedy_coordinator = GreedyCoordinator::new(&cost_estimates);
237    for _ in 0..options.max_passes {
238        let new_root_id = greedy_coordinator.optimize(root_id, arena);
239        if new_root_id == root_id {
240            break; // No further optimization possible
241        }
242        root_id = new_root_id;
243    }
244
245    root_id
246}
247
248/// Walk a function body and return its environment and tail expression.
249///
250/// Expects zero or more `let` bindings followed by a bare return expression.
251/// Bound names are stored so later uses are inlined symbolically. Returns
252/// `None` for the tail if the body ends with a semicolon or is empty.
253fn parse_body(
254    arena: &mut SymArena,
255    block: &syn::Block,
256) -> (HashMap<syn::Ident, NodeId>, Option<Expr>) {
257    let mut env = HashMap::new();
258
259    for stmt in &block.stmts {
260        match stmt {
261            Stmt::Local(local) => {
262                let ident = match &local.pat {
263                    Pat::Ident(pat_ident) => Some(&pat_ident.ident),
264                    Pat::Type(pat_ident) => {
265                        if let Pat::Ident(inner_ident) = &*pat_ident.pat {
266                            Some(&inner_ident.ident)
267                        } else {
268                            None
269                        }
270                    }
271                    _ => None,
272                };
273
274                if let (Some(name), Some(init)) = (ident, &local.init) {
275                    let expr = &init.expr;
276                    let node_id = parse_syn(arena, expr, &env);
277                    env.insert(name.clone(), node_id);
278                }
279            }
280            Stmt::Expr(expr, None) => return (env, Some(expr.clone())),
281            _ => {
282                // Unsupported statement type (e.g. semi-colon terminated expr, item, macro).
283                // For simplicity, we require the function body to be a single expression.
284            }
285        }
286    }
287
288    (env, None)
289}
290
291/// Parsed attribute arguments for [`gradient`].
292#[derive(deluxe::ParseMetaItem)]
293struct DerivativeOptions {
294    // Variable over which to compute the gradient
295    var: Ident,
296    /// Number of gradient components, i.e. the length of the output array.
297    dim: usize,
298    /// Maximum simplification passes; defaults to 10.
299    max_passes: Option<usize>,
300    /// Specify whether to output data as sparse structures
301    sparse: Option<bool>,
302    /// Whether the input array access should be unchecked
303    unchecked: Option<bool>,
304    /// Specify whether to prune the tree after simplification
305    /// This operation may be expensive so diabled by default
306    prune: Option<bool>,
307}
308
309struct Options {
310    max_passes: usize,
311    sparse: bool,
312    unchecked: bool,
313    prune: bool,
314}
315
316impl From<DerivativeOptions> for Options {
317    fn from(value: DerivativeOptions) -> Self {
318        Self {
319            max_passes: value.max_passes.unwrap_or(10),
320            sparse: value.sparse.unwrap_or(false),
321            unchecked: value.unchecked.unwrap_or(false),
322            prune: value.prune.unwrap_or(false),
323        }
324    }
325}
326
327/// Emit the original function unchanged, plus `{fn}_gradient(x: &[f64]) -> [f64; dim]`
328/// where element `i` is `∂f/∂x[i]` in closed form.
329#[proc_macro_attribute]
330pub fn gradient(
331    attr: proc_macro::TokenStream,
332    item: proc_macro::TokenStream,
333) -> proc_macro::TokenStream {
334    let input_fn = syn::parse_macro_input!(item as syn::ItemFn);
335    let fn_name = &input_fn.sig.ident;
336    let params = &input_fn.sig.inputs;
337    let body = &input_fn.block;
338    let vis = &input_fn.vis;
339
340    let derivative_options = deluxe::parse::<DerivativeOptions>(attr.into())
341        .expect("Failed to parse macro attribute arguments for gradient.");
342    let dim = derivative_options.dim;
343    let var = derivative_options.var.clone();
344    let options = Options::from(derivative_options);
345
346    let mut arena = SymArena::new();
347
348    // Add all parameters as variables
349    let param_idents = params.iter().filter_map(|param| {
350        if let syn::FnArg::Typed(pat_type) = param {
351            if let syn::Pat::Ident(pat_ident) = *pat_type.pat.clone() {
352                return Some(pat_ident.ident);
353            }
354        }
355        None
356    });
357
358    param_idents.for_each(|ident| {
359        arena.intern_var_ident(&ident);
360    });
361
362    // Find the id of the variable of interest
363    let var_id = arena.get_var_id(&var).expect("Unable to find variable");
364
365    let (env, func_def) = parse_body(&mut arena, body);
366
367    if func_def.is_none() {
368        panic!("Function body must end with a bare expression (no trailing semicolon).");
369    }
370
371    let root = parse_syn(&mut arena, &func_def.unwrap(), &env);
372
373    if options.prune {
374        // Prune the expression tree to remove any nodes that are not ancestors of the root.
375        arena.prune(root);
376    }
377
378    let tokens = (0..dim)
379        .map(|i| compile_expression(&mut arena, root, var_id, i, &options))
380        .collect::<Vec<_>>();
381
382    // Count all references to each node to determine which sub-expressions are shared and should be hoisted into `let` bindings.
383    let mut ref_count = RefCountVisitor::new();
384    tokens.iter().for_each(|t| arena.accept(*t, &mut ref_count));
385
386    let mut cache = HashMap::new();
387    let mut instructions = Vec::new();
388
389    // Generate tokens for the gradient expression, hoisting any shared sub-expressions into `let` bindings and caching their tokens for reuse.
390    let mut to_tokens_visitor = ToTokenStreamVisitor::new(
391        &ref_count.get_counts(),
392        &mut cache,
393        &mut instructions,
394        &options,
395    );
396
397    let (dim, gradient_token) = {
398        if options.sparse {
399            // For sparse output, we must emit both the values and indices
400            let sparse_tokens = tokens
401                .iter()
402                .enumerate()
403                .filter(|(_, t)| &SymNode::Const(0) != arena.get_node(**t))
404                .map(|(i, t)| (i, arena.accept(*t, &mut to_tokens_visitor)))
405                .collect::<Vec<_>>();
406            let indices = sparse_tokens
407                .iter()
408                .map(|(i, _)| quote! { #i })
409                .collect::<Vec<_>>();
410            let values = sparse_tokens
411                .iter()
412                .map(|(_, t)| t.clone())
413                .collect::<Vec<_>>();
414            (
415                sparse_tokens.len(),
416                quote! { ([#(#indices),*], [#(#values),*]) },
417            )
418        } else {
419            let dense_tokens = tokens
420                .iter()
421                .map(|t| arena.accept(*t, &mut to_tokens_visitor))
422                .collect::<Vec<_>>();
423            (dim, quote! { [#(#dense_tokens),*] })
424        }
425    };
426
427    // Get the emitted `let` bindings in the order they must be emitted.
428    let instruction_tokens = to_tokens_visitor.get_instructions().to_vec();
429
430    let return_type = if options.sparse {
431        quote! { ([usize; #dim], [f64; #dim]) }
432    } else {
433        quote! { [f64; #dim] }
434    };
435
436    let grad_name = syn::Ident::new(
437        &format!("{}_gradient", fn_name),
438        proc_macro2::Span::call_site(),
439    );
440
441    let expanded = quote!(
442        #input_fn
443
444        #vis fn #grad_name(#params) -> #return_type {
445            unsafe {
446                #(#instruction_tokens)*
447                #gradient_token
448            }
449        }
450    );
451
452    expanded.into()
453}