1mod arena;
71mod coordinator;
72mod transformers;
73mod visitors;
74
75use std::collections::HashMap;
76
77use quote::quote;
78use syn::{Expr, Ident, Pat, Stmt};
79
80use arena::{NodeId, SymArena, SymNode, VarId};
81use transformers::{DiffTransformer, SimplifyTransformer};
82use visitors::{RefCountVisitor, ToTokenStreamVisitor};
83
84use crate::coordinator::{Coordinator, GreedyCoordinator};
85
86fn parse_syn(arena: &mut SymArena, expr: &Expr, env: &HashMap<syn::Ident, NodeId>) -> NodeId {
90 match expr {
91 Expr::Lit(lit) => match &lit.lit {
92 syn::Lit::Int(int_lit) => {
93 let value = int_lit.base10_parse::<u64>().unwrap();
94 return arena.intern(SymNode::Const((value as f64).to_bits()));
95 }
96 syn::Lit::Float(flt_lit) => {
97 let value = flt_lit.base10_parse::<f64>().unwrap();
98 return arena.intern(SymNode::Const(value.to_bits()));
99 }
100 _ => panic!("Unsupported literal type"),
101 },
102
103 Expr::Index(index_expr) => {
104 if let Expr::Path(path_expr) = &*index_expr.expr {
105 if let Some(var_ident) = path_expr.path.get_ident() {
106 if let Some(var_id) = arena.get_var_id(var_ident) {
107 if let Expr::Lit(lit) = &*index_expr.index {
108 if let syn::Lit::Int(int_lit) = &lit.lit {
109 let idx = int_lit.base10_parse::<NodeId>().unwrap();
110 return arena.intern(SymNode::Var(var_id, idx));
111 } else {
112 panic!("Expected integer literal for variable index");
113 }
114 } else {
115 panic!("Expected literal for variable index");
116 }
117 } else {
118 panic!("Unexpected variable identifier");
119 }
120 } else {
121 panic!("Unable to find variable identifier");
122 }
123 } else {
124 panic!("Expected variable access of the form x[i]");
125 }
126 }
127
128 Expr::Path(path_expr) => {
129 if let Some(ident) = path_expr.path.get_ident() {
130 if let Some(node_id) = env.get(ident) {
131 return *node_id;
132 } else {
133 panic!("Undefined variable: {}", ident);
134 }
135 } else {
136 panic!("Unsupported path expression");
137 }
138 }
139
140 Expr::Binary(bin_expr) => {
141 let left_id = parse_syn(arena, &bin_expr.left, env);
142 let right_id = parse_syn(arena, &bin_expr.right, env);
143 return match bin_expr.op {
144 syn::BinOp::Add(_) => arena.intern(SymNode::Add(left_id, right_id)),
145 syn::BinOp::Sub(_) => arena.intern(SymNode::Sub(left_id, right_id)),
146 syn::BinOp::Mul(_) => arena.intern(SymNode::Mul(left_id, right_id)),
147 syn::BinOp::Div(_) => arena.intern(SymNode::Div(left_id, right_id)),
148 _ => panic!("Unsupported binary operator"),
149 };
150 }
151
152 Expr::Unary(un) => {
153 let operand_id = parse_syn(arena, &un.expr, env);
154 return match un.op {
155 syn::UnOp::Neg(_) => arena.intern(SymNode::Neg(operand_id)),
156 _ => panic!("Unsupported unary operator"),
157 };
158 }
159
160 Expr::MethodCall(call) => {
161 let receiver_id = parse_syn(arena, &call.receiver, env);
162 let method_name = call.method.to_string();
163 return match method_name.as_str() {
164 "powi" => {
165 if call.args.len() != 1 {
166 panic!("powi method call must have exactly one argument");
167 }
168 if let Expr::Lit(lit) = &call.args[0] {
169 if let syn::Lit::Int(int_lit) = &lit.lit {
170 let exp = int_lit.base10_parse::<i32>().unwrap();
171 arena.intern(SymNode::Powi(receiver_id, exp))
172 } else {
173 panic!("Expected integer literal for powi exponent");
174 }
175 } else {
176 panic!("Expected literal for powi exponent");
177 }
178 }
179 "sin" => arena.intern(SymNode::Sin(receiver_id)),
180 "cos" => arena.intern(SymNode::Cos(receiver_id)),
181 "ln" => arena.intern(SymNode::Ln(receiver_id)),
182 "exp" => arena.intern(SymNode::Exp(receiver_id)),
183 "sqrt" => arena.intern(SymNode::Sqrt(receiver_id)),
184 _ => panic!("Unsupported method call: {}", method_name),
185 };
186 }
187
188 Expr::Paren(paren) => parse_syn(arena, &paren.expr, env),
189 Expr::Group(group) => parse_syn(arena, &group.expr, env),
190 Expr::Cast(c) => parse_syn(arena, &c.expr, env),
191
192 _ => panic!("Unsupported expression type {:?}", expr),
193 }
194}
195
196fn compile_expression(
199 arena: &mut SymArena,
200 root_id: NodeId,
201 var_id: VarId,
202 var_idx: usize,
203 options: &Options,
204) -> NodeId {
205 let cost_estimates = HashMap::from([
206 (SymNode::Const(0), 0),
207 (SymNode::Var(0, 0), 0),
208 (SymNode::Add(0, 0), 1),
209 (SymNode::Sub(0, 0), 1),
210 (SymNode::Mul(0, 0), 1),
211 (SymNode::Div(0, 0), 15),
212 (SymNode::Powi(0, 0), 3),
213 (SymNode::Neg(0), 0),
214 (SymNode::Sin(0), 100),
215 (SymNode::Cos(0), 100),
216 (SymNode::Ln(0), 60),
217 (SymNode::Exp(0), 60),
218 (SymNode::Sqrt(0), 5),
219 ]);
220
221 let diff_transformer = DiffTransformer::new(var_id, var_idx);
223 let mut root_id = arena.transform(root_id, &diff_transformer);
224
225 let simplify_transformer = SimplifyTransformer::new();
227 for _ in 0..options.max_passes {
228 let new_root_id = arena.transform(root_id, &simplify_transformer);
229 if new_root_id == root_id {
230 break; }
232 root_id = new_root_id;
233 }
234
235 let greedy_coordinator = GreedyCoordinator::new(&cost_estimates);
237 for _ in 0..options.max_passes {
238 let new_root_id = greedy_coordinator.optimize(root_id, arena);
239 if new_root_id == root_id {
240 break; }
242 root_id = new_root_id;
243 }
244
245 root_id
246}
247
248fn parse_body(
254 arena: &mut SymArena,
255 block: &syn::Block,
256) -> (HashMap<syn::Ident, NodeId>, Option<Expr>) {
257 let mut env = HashMap::new();
258
259 for stmt in &block.stmts {
260 match stmt {
261 Stmt::Local(local) => {
262 let ident = match &local.pat {
263 Pat::Ident(pat_ident) => Some(&pat_ident.ident),
264 Pat::Type(pat_ident) => {
265 if let Pat::Ident(inner_ident) = &*pat_ident.pat {
266 Some(&inner_ident.ident)
267 } else {
268 None
269 }
270 }
271 _ => None,
272 };
273
274 if let (Some(name), Some(init)) = (ident, &local.init) {
275 let expr = &init.expr;
276 let node_id = parse_syn(arena, expr, &env);
277 env.insert(name.clone(), node_id);
278 }
279 }
280 Stmt::Expr(expr, None) => return (env, Some(expr.clone())),
281 _ => {
282 }
285 }
286 }
287
288 (env, None)
289}
290
/// Raw arguments of `#[gradient(...)]`, parsed from the attribute with
/// `deluxe`.
///
/// The optional settings are `Option`s; `Options::from` substitutes their
/// defaults when they are `None`.
#[derive(deluxe::ParseMetaItem)]
struct DerivativeOptions {
    /// Name of the function parameter to differentiate with respect to; the
    /// body is expected to access it as `var[i]`.
    var: Ident,
    /// Number of gradient components to generate (`i` ranges over `0..dim`).
    dim: usize,
    /// Pass budget for simplification and for greedy optimization, applied
    /// separately to each phase (defaults to 10).
    max_passes: Option<usize>,
    /// Return `(indices, values)` for the components that are not the
    /// constant zero, instead of a dense array (defaults to false).
    sparse: Option<bool>,
    /// Forwarded (via `Options`) to the token-generation visitor.
    /// NOTE(review): exact effect is defined in `ToTokenStreamVisitor` —
    /// confirm there (defaults to false).
    unchecked: Option<bool>,
    /// Call `SymArena::prune` on the parsed root before differentiating
    /// (defaults to false).
    prune: Option<bool>,
}
308
309struct Options {
310 max_passes: usize,
311 sparse: bool,
312 unchecked: bool,
313 prune: bool,
314}
315
316impl From<DerivativeOptions> for Options {
317 fn from(value: DerivativeOptions) -> Self {
318 Self {
319 max_passes: value.max_passes.unwrap_or(10),
320 sparse: value.sparse.unwrap_or(false),
321 unchecked: value.unchecked.unwrap_or(false),
322 prune: value.prune.unwrap_or(false),
323 }
324 }
325}
326
327#[proc_macro_attribute]
330pub fn gradient(
331 attr: proc_macro::TokenStream,
332 item: proc_macro::TokenStream,
333) -> proc_macro::TokenStream {
334 let input_fn = syn::parse_macro_input!(item as syn::ItemFn);
335 let fn_name = &input_fn.sig.ident;
336 let params = &input_fn.sig.inputs;
337 let body = &input_fn.block;
338 let vis = &input_fn.vis;
339
340 let derivative_options = deluxe::parse::<DerivativeOptions>(attr.into())
341 .expect("Failed to parse macro attribute arguments for gradient.");
342 let dim = derivative_options.dim;
343 let var = derivative_options.var.clone();
344 let options = Options::from(derivative_options);
345
346 let mut arena = SymArena::new();
347
348 let param_idents = params.iter().filter_map(|param| {
350 if let syn::FnArg::Typed(pat_type) = param {
351 if let syn::Pat::Ident(pat_ident) = *pat_type.pat.clone() {
352 return Some(pat_ident.ident);
353 }
354 }
355 None
356 });
357
358 param_idents.for_each(|ident| {
359 arena.intern_var_ident(&ident);
360 });
361
362 let var_id = arena.get_var_id(&var).expect("Unable to find variable");
364
365 let (env, func_def) = parse_body(&mut arena, body);
366
367 if func_def.is_none() {
368 panic!("Function body must end with a bare expression (no trailing semicolon).");
369 }
370
371 let root = parse_syn(&mut arena, &func_def.unwrap(), &env);
372
373 if options.prune {
374 arena.prune(root);
376 }
377
378 let tokens = (0..dim)
379 .map(|i| compile_expression(&mut arena, root, var_id, i, &options))
380 .collect::<Vec<_>>();
381
382 let mut ref_count = RefCountVisitor::new();
384 tokens.iter().for_each(|t| arena.accept(*t, &mut ref_count));
385
386 let mut cache = HashMap::new();
387 let mut instructions = Vec::new();
388
389 let mut to_tokens_visitor = ToTokenStreamVisitor::new(
391 &ref_count.get_counts(),
392 &mut cache,
393 &mut instructions,
394 &options,
395 );
396
397 let (dim, gradient_token) = {
398 if options.sparse {
399 let sparse_tokens = tokens
401 .iter()
402 .enumerate()
403 .filter(|(_, t)| &SymNode::Const(0) != arena.get_node(**t))
404 .map(|(i, t)| (i, arena.accept(*t, &mut to_tokens_visitor)))
405 .collect::<Vec<_>>();
406 let indices = sparse_tokens
407 .iter()
408 .map(|(i, _)| quote! { #i })
409 .collect::<Vec<_>>();
410 let values = sparse_tokens
411 .iter()
412 .map(|(_, t)| t.clone())
413 .collect::<Vec<_>>();
414 (
415 sparse_tokens.len(),
416 quote! { ([#(#indices),*], [#(#values),*]) },
417 )
418 } else {
419 let dense_tokens = tokens
420 .iter()
421 .map(|t| arena.accept(*t, &mut to_tokens_visitor))
422 .collect::<Vec<_>>();
423 (dim, quote! { [#(#dense_tokens),*] })
424 }
425 };
426
427 let instruction_tokens = to_tokens_visitor.get_instructions().to_vec();
429
430 let return_type = if options.sparse {
431 quote! { ([usize; #dim], [f64; #dim]) }
432 } else {
433 quote! { [f64; #dim] }
434 };
435
436 let grad_name = syn::Ident::new(
437 &format!("{}_gradient", fn_name),
438 proc_macro2::Span::call_site(),
439 );
440
441 let expanded = quote!(
442 #input_fn
443
444 #vis fn #grad_name(#params) -> #return_type {
445 unsafe {
446 #(#instruction_tokens)*
447 #gradient_token
448 }
449 }
450 );
451
452 expanded.into()
453}