1use std::collections::HashMap;
2use std::sync::atomic::{AtomicU64, Ordering};
3use std::{cell::RefCell, rc::Rc};
4
5use crate::Float;
6
/// Handle to a node stored in an [`ExprGraph`].
///
/// A handle pairs the node's position with the id of the graph that issued it,
/// which lets a graph recognise handles it did not create.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NodeId {
    /// Position of the node in the owning graph's node list.
    index: usize,
    /// Id of the graph that created this handle.
    graph_id: u64,
}

impl NodeId {
    /// Builds a handle for the node at `index` in the graph identified by `graph_id`.
    fn new(index: usize, graph_id: u64) -> Self {
        NodeId { index, graph_id }
    }
}
19
/// Counter that supplies the `graph_id` of each new `ExprGraph`.
///
/// Starts at 1; `ExprGraph::new` takes the current value with an atomic `fetch_add`.
static NEXT_GRAPH_ID: AtomicU64 = AtomicU64::new(1);
21
/// Expression graph stored as an append-only list of nodes.
///
/// An operation or output may only reference nodes that already exist, so walking
/// the node list front to back visits every operand before its consumer.
#[derive(Debug)]
pub struct ExprGraph {
    /// Identifier stamped into every `NodeId` this graph hands out.
    graph_id: u64,
    /// All nodes in insertion order; a `NodeId`'s index points into this list.
    nodes: Vec<Node>,
    /// Input name -> handle; consulted to reject duplicate input names.
    node_map: HashMap<String, NodeId>,
    /// Handles of the input nodes, in registration order.
    inputs: Vec<NodeId>,
    /// Input names, in the same order as `inputs`.
    input_names: Vec<String>,
    /// Handles of the `Node::Output` nodes, in registration order.
    outputs: Vec<NodeId>,
    /// Largest operand count of any operation added so far; sizes the tapes' scratch buffers.
    max_arity: usize,
    /// Index the next node will receive.
    next_id: usize,
}
35
/// One entry of an `ExprGraph`'s node list.
#[derive(Debug, Clone)]
pub enum Node {
    /// Named input whose value is supplied at evaluation time.
    Input(String),
    /// Fixed value.
    Const(Float),
    /// Result of applying an operation to the listed operand nodes.
    AfterOperation(Op, Box<[NodeId]>),
    /// Output marker that forwards the value of the referenced node.
    Output(NodeId),
}
44
/// Operation applied by a `Node::AfterOperation`.
///
/// `Scale`, `Sin`, `Cos` and `Pow` take exactly one operand; `Add` and `Mul` take two or more.
#[derive(Debug, Clone, Copy)]
pub enum Op {
    /// Multiplies the operand by the given factor.
    Scale(Float),
    /// Sine of the operand.
    Sin,
    /// Cosine of the operand.
    Cos,
    /// Raises the operand to the given integer power.
    Pow(i32),
    /// Sum of all operands.
    Add,
    /// Product of all operands.
    Mul,
}
55
56#[derive(Debug, Default)]
59pub struct EvalTape {
60 primals: Vec<Float>,
61 tangents: Vec<Float>,
62 input_count: usize,
63 scratch_primals: Vec<Float>,
64 scratch_partials: Vec<Float>,
65}
66
67impl EvalTape {
68 pub fn new() -> Self {
69 Self::default()
70 }
71
72 pub fn with_capacity(nodes: usize, input_count: usize, max_arity: usize) -> Self {
73 Self {
74 primals: Vec::with_capacity(nodes),
75 tangents: Vec::with_capacity(nodes * input_count),
76 input_count,
77 scratch_primals: Vec::with_capacity(max_arity),
78 scratch_partials: Vec::with_capacity(max_arity),
79 }
80 }
81
82 fn reset(&mut self, nodes: usize, input_count: usize, max_arity: usize) {
83 self.input_count = input_count;
84 self.primals.clear();
85 self.tangents.clear();
86 self.primals.resize(nodes, 0.0);
87 self.tangents.resize(nodes * input_count, 0.0);
88 self.scratch_primals.clear();
89 self.scratch_partials.clear();
90 self.scratch_primals.resize(max_arity, 0.0);
91 self.scratch_partials.resize(max_arity, 0.0);
92 }
93
94 fn tangent_index(&self, node_idx: usize, input_idx: usize) -> usize {
95 node_idx * self.input_count + input_idx
96 }
97}
98
99#[derive(Debug, Default)]
101pub struct ReverseTape {
102 primals: Vec<Float>,
103 adjoints: Vec<Float>,
104 scratch_primals: Vec<Float>,
105 scratch_partials: Vec<Float>,
106}
107
108impl ReverseTape {
109 pub fn new() -> Self {
110 Self::default()
111 }
112
113 pub fn with_capacity(nodes: usize, max_arity: usize) -> Self {
114 Self {
115 primals: Vec::with_capacity(nodes),
116 adjoints: Vec::with_capacity(nodes),
117 scratch_primals: Vec::with_capacity(max_arity),
118 scratch_partials: Vec::with_capacity(max_arity),
119 }
120 }
121
122 fn reset(&mut self, nodes: usize, max_arity: usize) {
123 self.primals.clear();
124 self.adjoints.clear();
125 self.primals.resize(nodes, 0.0);
126 self.adjoints.resize(nodes, 0.0);
127 self.scratch_primals.clear();
128 self.scratch_partials.clear();
129 self.scratch_primals.resize(max_arity, 0.0);
130 self.scratch_partials.resize(max_arity, 0.0);
131 }
132}
133
134impl Op {
135 fn validate_arity(self, inputs_len: usize) {
136 let ok = match self {
137 Op::Scale(_) | Op::Sin | Op::Cos | Op::Pow(_) => inputs_len == 1,
138 Op::Add | Op::Mul => inputs_len >= 2,
139 };
140
141 assert!(
142 ok,
143 "invalid arity for {:?}: expected {}, got {}",
144 self,
145 match self {
146 Op::Scale(_) | Op::Sin | Op::Cos | Op::Pow(_) => "1",
147 Op::Add | Op::Mul => ">= 2",
148 },
149 inputs_len
150 );
151 }
152
153 fn apply(self, inputs: &[Float]) -> Float {
154 match self {
155 Op::Scale(factor) => inputs[0] * factor,
156 Op::Sin => inputs[0].sin(),
157 Op::Cos => inputs[0].cos(),
158 Op::Pow(exp) => inputs[0].powi(exp),
159 Op::Add => inputs.iter().sum(),
160 Op::Mul => inputs.iter().product(),
161 }
162 }
163
164 fn compute_derivative(self, inputs: &[Float], input_idx: usize) -> Float {
165 match self {
166 Op::Scale(factor) => factor,
167 Op::Sin => inputs[0].cos(),
168 Op::Cos => -inputs[0].sin(),
169 Op::Pow(exp) => {
170 if exp == 0 {
171 0.0
172 } else {
173 exp as Float * inputs[0].powi(exp - 1)
174 }
175 }
176 Op::Add => 1.0,
177 Op::Mul => inputs
178 .iter()
179 .enumerate()
180 .filter(|(i, _)| *i != input_idx)
181 .map(|(_, &x)| x)
182 .product(),
183 }
184 }
185}
186
187impl ExprGraph {
188 pub fn new() -> Self {
189 Self {
190 graph_id: NEXT_GRAPH_ID.fetch_add(1, Ordering::Relaxed),
191 nodes: Vec::new(),
192 node_map: HashMap::new(),
193 inputs: Vec::new(),
194 input_names: Vec::new(),
195 outputs: Vec::new(),
196 max_arity: 0,
197 next_id: 0,
198 }
199 }
200
201 fn make_node_id(&self, index: usize) -> NodeId {
202 NodeId::new(index, self.graph_id)
203 }
204
205 fn is_valid_node(&self, id: NodeId) -> bool {
206 id.graph_id == self.graph_id && id.index < self.next_id
207 }
208
209 fn assert_valid_node(&self, id: NodeId, context: &str) {
210 assert!(
211 self.is_valid_node(id),
212 "{context} does not belong to this graph or is out of bounds"
213 );
214 }
215
216 pub fn input(&mut self, name: String) -> NodeId {
217 assert!(
218 !self.node_map.contains_key(&name),
219 "input name already exists: {name}"
220 );
221
222 let id = self.make_node_id(self.next_id);
223 self.next_id += 1;
224 self.nodes.push(Node::Input(name.clone()));
225 self.node_map.insert(name.clone(), id);
226 self.inputs.push(id);
227 self.input_names.push(name);
228 id
229 }
230
231 pub fn constant(&mut self, value: Float) -> NodeId {
232 let id = self.make_node_id(self.next_id);
233 self.next_id += 1;
234 self.nodes.push(Node::Const(value));
235 id
236 }
237
238 pub fn operation<I>(&mut self, op: Op, inputs: I) -> NodeId
239 where
240 I: AsRef<[NodeId]>,
241 {
242 let inputs_ref = inputs.as_ref();
243 op.validate_arity(inputs_ref.len());
244 assert!(
245 inputs_ref.iter().all(|id| self.is_valid_node(*id)),
246 "operation inputs must reference earlier nodes in the same graph"
247 );
248 self.max_arity = self.max_arity.max(inputs_ref.len());
249 let id = self.make_node_id(self.next_id);
250 self.next_id += 1;
251 self.nodes
252 .push(Node::AfterOperation(op, Box::from(inputs_ref)));
253 id
254 }
255
256 pub fn output(&mut self, node: NodeId) -> NodeId {
257 self.assert_valid_node(node, "output node");
258 let id = self.make_node_id(self.next_id);
259 self.next_id += 1;
260 self.nodes.push(Node::Output(node));
261 self.outputs.push(id);
262 id
263 }
264
265 pub fn fwd_tape(&self) -> EvalTape {
268 EvalTape::with_capacity(self.nodes.len(), self.inputs.len(), self.max_arity)
269 }
270
271 pub fn tape(&self) -> ReverseTape {
273 self.reverse_tape()
274 }
275
276 pub fn reverse_tape(&self) -> ReverseTape {
277 ReverseTape::with_capacity(self.nodes.len(), self.max_arity)
278 }
279
280 pub fn input_names(&self) -> &[String] {
281 &self.input_names
282 }
283
284 pub fn eval_fwd(&self, inputs: &[Float]) -> Vec<(Float, Vec<Float>)> {
287 let mut tape = self.fwd_tape();
288 self.eval_fwd_with_tape(inputs, &mut tape)
289 }
290
291 pub fn eval_fwd_with_tape(
294 &self,
295 inputs: &[Float],
296 tape: &mut EvalTape,
297 ) -> Vec<(Float, Vec<Float>)> {
298 assert_eq!(
299 inputs.len(),
300 self.inputs.len(),
301 "expected {} inputs, got {}",
302 self.inputs.len(),
303 inputs.len()
304 );
305
306 tape.reset(self.nodes.len(), self.inputs.len(), self.max_arity);
307
308 for (input_idx, node_id) in self.inputs.iter().enumerate() {
310 let node_idx = node_id.index;
311 tape.primals[node_idx] = inputs[input_idx];
312 let tangent_idx = tape.tangent_index(node_idx, input_idx);
313 tape.tangents[tangent_idx] = 1.0;
314 }
315
316 for (i, node) in self.nodes.iter().enumerate() {
318 match node {
319 Node::AfterOperation(op, inputs) => {
320 let arity = inputs.len();
321 let input_primals = &mut tape.scratch_primals[..arity];
322 for (slot, &id) in input_primals.iter_mut().zip(inputs.iter()) {
323 *slot = tape.primals[id.index];
324 }
325
326 tape.primals[i] = op.apply(input_primals);
327
328 let partials = &mut tape.scratch_partials[..arity];
330 for (j, partial) in partials.iter_mut().enumerate() {
331 *partial = op.compute_derivative(input_primals, j);
332 }
333
334 let input_count = tape.input_count;
335 let tangents = &mut tape.tangents;
336 for input_dim in 0..input_count {
337 let mut total = 0.0;
338 for (j, &input_id) in inputs.iter().enumerate() {
339 let idx = input_id.index * input_count + input_dim;
340 total += tangents[idx] * partials[j];
341 }
342 let out_idx = i * input_count + input_dim;
343 tangents[out_idx] = total;
344 }
345 }
346 Node::Const(value) => {
347 tape.primals[i] = *value;
348 }
349 _ => {}
350 }
351 }
352
353 for (i, node) in self.nodes.iter().enumerate() {
355 if let Node::Output(input_id) = node {
356 tape.primals[i] = tape.primals[input_id.index];
357 let src_start = tape.tangent_index(input_id.index, 0);
358 let dst_start = tape.tangent_index(i, 0);
359 let len = tape.input_count;
360 tape.tangents
361 .copy_within(src_start..(src_start + len), dst_start);
362 }
363 }
364
365 self.outputs
366 .iter()
367 .map(|id| {
368 let idx = id.index;
369 let start = tape.tangent_index(idx, 0);
370 let end = start + tape.input_count;
371 (tape.primals[idx], tape.tangents[start..end].to_vec())
372 })
373 .collect()
374 }
375
376 pub fn eval_fwd_one(&self, inputs: &[Float]) -> (Float, Vec<Float>) {
377 let mut tape = self.fwd_tape();
378 self.eval_fwd_one_with_tape(inputs, &mut tape)
379 }
380
381 pub fn eval_fwd_one_with_tape(
382 &self,
383 inputs: &[Float],
384 tape: &mut EvalTape,
385 ) -> (Float, Vec<Float>) {
386 let mut outputs = self.eval_fwd_with_tape(inputs, tape);
387 assert!(
388 outputs.len() == 1,
389 "expected a single output, got {}",
390 outputs.len()
391 );
392 outputs.remove(0)
393 }
394
395 pub fn eval_fwd_named(&self, inputs: &[Float]) -> Vec<(Float, Vec<(String, Float)>)> {
396 let mut tape = self.fwd_tape();
397 self.eval_fwd_named_with_tape(inputs, &mut tape)
398 }
399
400 pub fn eval_fwd_named_with_tape(
401 &self,
402 inputs: &[Float],
403 tape: &mut EvalTape,
404 ) -> Vec<(Float, Vec<(String, Float)>)> {
405 let outputs = self.eval_fwd_with_tape(inputs, tape);
406 outputs
407 .into_iter()
408 .map(|(value, grads)| {
409 let named = self
410 .input_names
411 .iter()
412 .cloned()
413 .zip(grads)
414 .collect::<Vec<_>>();
415 (value, named)
416 })
417 .collect()
418 }
419
420 pub fn eval(&self, inputs: &[Float]) -> Vec<(Float, Vec<Float>)> {
423 let mut tape = self.reverse_tape();
424 self.eval_with_tape(inputs, &mut tape)
425 }
426
427 pub fn eval_with_tape(
430 &self,
431 inputs: &[Float],
432 tape: &mut ReverseTape,
433 ) -> Vec<(Float, Vec<Float>)> {
434 self.eval_for_with_tape(inputs, &self.outputs, tape)
435 }
436
437 pub fn eval_for(&self, inputs: &[Float], outputs: &[NodeId]) -> Vec<(Float, Vec<Float>)> {
439 let mut tape = self.reverse_tape();
440 self.eval_for_with_tape(inputs, outputs, &mut tape)
441 }
442
443 pub fn eval_for_with_tape(
445 &self,
446 inputs: &[Float],
447 outputs: &[NodeId],
448 tape: &mut ReverseTape,
449 ) -> Vec<(Float, Vec<Float>)> {
450 assert_eq!(
451 inputs.len(),
452 self.inputs.len(),
453 "expected {} inputs, got {}",
454 self.inputs.len(),
455 inputs.len()
456 );
457 for &output in outputs {
458 self.assert_valid_node(output, "requested output");
459 }
460
461 tape.reset(self.nodes.len(), self.max_arity);
462
463 for (input_idx, node_id) in self.inputs.iter().enumerate() {
465 tape.primals[node_id.index] = inputs[input_idx];
466 }
467
468 for (i, node) in self.nodes.iter().enumerate() {
469 match node {
470 Node::AfterOperation(op, inputs) => {
471 let arity = inputs.len();
472 let input_primals = &mut tape.scratch_primals[..arity];
473 for (slot, &id) in input_primals.iter_mut().zip(inputs.iter()) {
474 *slot = tape.primals[id.index];
475 }
476 tape.primals[i] = op.apply(input_primals);
477 }
478 Node::Output(input_id) => {
479 tape.primals[i] = tape.primals[input_id.index];
480 }
481 Node::Const(value) => {
482 tape.primals[i] = *value;
483 }
484 Node::Input(_) => {}
485 }
486 }
487
488 let mut results = Vec::with_capacity(outputs.len());
489
490 for output_id in outputs {
491 tape.adjoints.fill(0.0);
492 tape.adjoints[output_id.index] = 1.0;
493
494 for (i, node) in self.nodes.iter().enumerate().rev() {
495 match node {
496 Node::Output(input_id) => {
497 tape.adjoints[input_id.index] += tape.adjoints[i];
498 }
499 Node::AfterOperation(op, inputs) => {
500 let arity = inputs.len();
501 let input_primals = &mut tape.scratch_primals[..arity];
502 for (slot, &id) in input_primals.iter_mut().zip(inputs.iter()) {
503 *slot = tape.primals[id.index];
504 }
505
506 let partials = &mut tape.scratch_partials[..arity];
507 for (j, partial) in partials.iter_mut().enumerate() {
508 *partial = op.compute_derivative(input_primals, j);
509 }
510
511 let adj = tape.adjoints[i];
512 if adj != 0.0 {
513 for (j, &input_id) in inputs.iter().enumerate() {
514 tape.adjoints[input_id.index] += adj * partials[j];
515 }
516 }
517 }
518 Node::Const(_) | Node::Input(_) => {}
519 }
520 }
521
522 let grads = self
523 .inputs
524 .iter()
525 .map(|id| tape.adjoints[id.index])
526 .collect::<Vec<_>>();
527 results.push((tape.primals[output_id.index], grads));
528 }
529
530 results
531 }
532
533 pub fn eval_one(&self, inputs: &[Float]) -> (Float, Vec<Float>) {
534 let mut tape = self.reverse_tape();
535 self.eval_one_with_tape(inputs, &mut tape)
536 }
537
538 pub fn eval_one_with_tape(
539 &self,
540 inputs: &[Float],
541 tape: &mut ReverseTape,
542 ) -> (Float, Vec<Float>) {
543 let mut outputs = self.eval_with_tape(inputs, tape);
544 assert!(
545 outputs.len() == 1,
546 "expected a single output, got {}",
547 outputs.len()
548 );
549 outputs.remove(0)
550 }
551
552 pub fn eval_named(&self, inputs: &[Float]) -> Vec<(Float, Vec<(String, Float)>)> {
553 let mut tape = self.reverse_tape();
554 self.eval_named_with_tape(inputs, &mut tape)
555 }
556
557 pub fn eval_named_with_tape(
558 &self,
559 inputs: &[Float],
560 tape: &mut ReverseTape,
561 ) -> Vec<(Float, Vec<(String, Float)>)> {
562 let outputs = self.eval_with_tape(inputs, tape);
563 outputs
564 .into_iter()
565 .map(|(value, grads)| {
566 let named = self
567 .input_names
568 .iter()
569 .cloned()
570 .zip(grads)
571 .collect::<Vec<_>>();
572 (value, named)
573 })
574 .collect()
575 }
576
577 pub fn eval_named_for(
578 &self,
579 inputs: &[Float],
580 outputs: &[NodeId],
581 ) -> Vec<(Float, Vec<(String, Float)>)> {
582 let mut tape = self.reverse_tape();
583 self.eval_named_for_with_tape(inputs, outputs, &mut tape)
584 }
585
586 pub fn eval_named_for_with_tape(
587 &self,
588 inputs: &[Float],
589 outputs: &[NodeId],
590 tape: &mut ReverseTape,
591 ) -> Vec<(Float, Vec<(String, Float)>)> {
592 let outputs = self.eval_for_with_tape(inputs, outputs, tape);
593 outputs
594 .into_iter()
595 .map(|(value, grads)| {
596 let named = self
597 .input_names
598 .iter()
599 .cloned()
600 .zip(grads)
601 .collect::<Vec<_>>();
602 (value, named)
603 })
604 .collect()
605 }
606}
607
608impl Default for ExprGraph {
609 fn default() -> Self {
610 Self::new()
611 }
612}
613
614#[derive(Debug, Clone)]
615pub struct Gradients {
616 pub value: Float,
617 pub grads: Vec<(String, Float)>,
618}
619
620impl Gradients {
621 pub fn get(&self, name: &str) -> Option<Float> {
622 self.grads
623 .iter()
624 .find_map(|(key, value)| (key == name).then_some(*value))
625 }
626}
627
/// Errors reported by the fallible `Tape` setters.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TapeError {
    /// The number of supplied values does not match the number of inputs.
    InputLengthMismatch { expected: usize, got: usize },
    /// No input with the given name exists.
    UnknownInput(String),
}

impl std::fmt::Display for TapeError {
    /// Writes a one-line, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            TapeError::InputLengthMismatch { expected, got } => {
                f.write_fmt(format_args!("expected {} inputs, got {}", expected, got))
            }
            TapeError::UnknownInput(ref name) => {
                f.write_fmt(format_args!("unknown input name: {}", name))
            }
        }
    }
}

impl std::error::Error for TapeError {}
646
/// Recording front end: builds an `ExprGraph` implicitly while `Var`s are combined.
///
/// The state lives behind `Rc<RefCell<..>>`, so cloning a `Tape` yields another
/// handle to the same graph and input values.
#[derive(Debug, Clone)]
pub struct Tape {
    /// Shared graph and input values; every `Var` created from this tape holds the same pointer.
    inner: Rc<RefCell<TapeInner>>,
}
652
/// State shared between a `Tape` and the `Var`s created from it.
#[derive(Debug)]
struct TapeInner {
    /// Graph being recorded.
    graph: ExprGraph,
    /// Current input values, one per graph input, in registration order.
    values: Vec<Float>,
}
658
/// Handle to one node of a `Tape`'s graph.
///
/// Operations on a `Var` append nodes to the shared graph and return a new `Var`.
#[derive(Debug, Clone)]
pub struct Var {
    /// Node this variable refers to.
    id: NodeId,
    /// State of the tape that created this variable.
    inner: Rc<RefCell<TapeInner>>,
}
665
666impl Tape {
667 pub fn new() -> Self {
668 Self {
669 inner: Rc::new(RefCell::new(TapeInner {
670 graph: ExprGraph::new(),
671 values: Vec::new(),
672 })),
673 }
674 }
675
676 pub fn input(&mut self, name: impl Into<String>, value: Float) -> Var {
677 let mut inner = self.inner.borrow_mut();
678 let id = inner.graph.input(name.into());
679 inner.values.push(value);
680 Var {
681 id,
682 inner: self.inner.clone(),
683 }
684 }
685
686 pub fn input_unnamed(&mut self, value: Float) -> Var {
687 let idx = self.inner.borrow().values.len();
688 self.input(format!("_{}", idx), value)
689 }
690
691 pub fn constant(&mut self, value: Float) -> Var {
692 let mut inner = self.inner.borrow_mut();
693 let id = inner.graph.constant(value);
694 Var {
695 id,
696 inner: self.inner.clone(),
697 }
698 }
699
700 pub fn set_inputs(&mut self, values: &[Float]) {
701 self.try_set_inputs(values)
702 .expect("input length mismatch for Tape::set_inputs");
703 }
704
705 pub fn try_set_inputs(&mut self, values: &[Float]) -> Result<(), TapeError> {
706 let mut inner = self.inner.borrow_mut();
707 let expected = inner.values.len();
708 if values.len() != expected {
709 return Err(TapeError::InputLengthMismatch {
710 expected,
711 got: values.len(),
712 });
713 }
714 inner.values.copy_from_slice(values);
715 Ok(())
716 }
717
718 pub fn set(&mut self, name: &str, value: Float) {
719 self.try_set(name, value)
720 .expect("unknown input name for Tape::set");
721 }
722
723 pub fn try_set(&mut self, name: &str, value: Float) -> Result<(), TapeError> {
724 let mut inner = self.inner.borrow_mut();
725 let Some(idx) = inner.graph.input_names.iter().position(|n| n == name) else {
726 return Err(TapeError::UnknownInput(name.to_string()));
727 };
728 inner.values[idx] = value;
729 Ok(())
730 }
731
732 pub fn input_names(&self) -> Vec<String> {
733 self.inner.borrow().graph.input_names.clone()
734 }
735
736 pub fn gradients(&self, output: &Var) -> Gradients {
737 output.assert_same_tape(self);
738 let inner = self.inner.borrow();
739 let results = inner.graph.eval_named_for(&inner.values, &[output.id]);
740 let (value, grads) = results.into_iter().next().expect("missing output");
741 Gradients { value, grads }
742 }
743
744 pub fn gradients_for(&self, outputs: &[Var]) -> Vec<Gradients> {
745 if outputs.is_empty() {
746 return Vec::new();
747 }
748 outputs[0].assert_same_tape(self);
749 for var in outputs.iter().skip(1) {
750 var.assert_same_tape(self);
751 }
752
753 let inner = self.inner.borrow();
754 let ids = outputs.iter().map(|var| var.id).collect::<Vec<_>>();
755 inner
756 .graph
757 .eval_named_for(&inner.values, &ids)
758 .into_iter()
759 .map(|(value, grads)| Gradients { value, grads })
760 .collect()
761 }
762}
763
764impl Default for Tape {
765 fn default() -> Self {
766 Self::new()
767 }
768}
769
770impl Var {
771 fn assert_same_tape(&self, tape: &Tape) {
772 assert!(
773 Rc::ptr_eq(&self.inner, &tape.inner),
774 "cannot mix Vars from different tapes"
775 );
776 }
777
778 fn assert_same_var_tape(&self, other: &Var) {
779 assert!(
780 Rc::ptr_eq(&self.inner, &other.inner),
781 "cannot mix Vars from different tapes"
782 );
783 }
784
785 fn unary_op(&self, op: Op) -> Var {
786 let mut inner = self.inner.borrow_mut();
787 let id = inner.graph.operation(op, vec![self.id]);
788 Var {
789 id,
790 inner: self.inner.clone(),
791 }
792 }
793
794 fn binary_op(&self, rhs: &Var, op: Op) -> Var {
795 self.assert_same_var_tape(rhs);
796 let mut inner = self.inner.borrow_mut();
797 let id = inner.graph.operation(op, vec![self.id, rhs.id]);
798 Var {
799 id,
800 inner: self.inner.clone(),
801 }
802 }
803
804 fn konst(&self, value: Float) -> Var {
805 let mut inner = self.inner.borrow_mut();
806 let id = inner.graph.constant(value);
807 Var {
808 id,
809 inner: self.inner.clone(),
810 }
811 }
812
813 pub fn sin(&self) -> Var {
814 self.unary_op(Op::Sin)
815 }
816
817 pub fn cos(&self) -> Var {
818 self.unary_op(Op::Cos)
819 }
820
821 pub fn powi(&self, exp: i32) -> Var {
822 self.unary_op(Op::Pow(exp))
823 }
824
825 pub fn scale(&self, factor: Float) -> Var {
826 self.unary_op(Op::Scale(factor))
827 }
828}
829
830impl std::ops::Add for Var {
831 type Output = Var;
832 fn add(self, rhs: Var) -> Self::Output {
833 self.binary_op(&rhs, Op::Add)
834 }
835}
836
837impl std::ops::Add<Float> for Var {
838 type Output = Var;
839 fn add(self, rhs: Float) -> Self::Output {
840 let rhs = self.konst(rhs);
841 self.binary_op(&rhs, Op::Add)
842 }
843}
844
845impl std::ops::Sub for Var {
846 type Output = Var;
847 fn sub(self, rhs: Var) -> Self::Output {
848 self + (-rhs)
849 }
850}
851
852impl std::ops::Sub<Float> for Var {
853 type Output = Var;
854 fn sub(self, rhs: Float) -> Self::Output {
855 self + (-rhs)
856 }
857}
858
859impl std::ops::Mul for Var {
860 type Output = Var;
861 fn mul(self, rhs: Var) -> Self::Output {
862 self.binary_op(&rhs, Op::Mul)
863 }
864}
865
866impl std::ops::Mul<Float> for Var {
867 type Output = Var;
868 fn mul(self, rhs: Float) -> Self::Output {
869 self.scale(rhs)
870 }
871}
872
873impl std::ops::Div for Var {
874 type Output = Var;
875 fn div(self, rhs: Var) -> Self::Output {
876 self * rhs.powi(-1)
877 }
878}
879
880impl std::ops::Div<Float> for Var {
881 type Output = Var;
882 fn div(self, rhs: Float) -> Self::Output {
883 self.scale(1.0 / rhs)
884 }
885}
886
887impl std::ops::Neg for Var {
888 type Output = Var;
889 fn neg(self) -> Self::Output {
890 self.scale(-1.0)
891 }
892}
893
/// Builds an `ExprGraph` from a compact pipeline notation and evaluates to that graph.
///
/// Two entry forms are accepted:
///
/// * `input -> Op -> Op(args) -> ... -> output` creates one input named `"input"`
///   and threads it through a chain of unary operations, then registers the last
///   node as the output.
/// * `inputs: [a, b, ...]` followed by a sequence of steps creates one input per
///   identifier, named after it. Each step is `x -> Op -> @r` or
///   `x -> Op(args) -> @r` for a unary operation, or `(@x, @y, ...) -> Op -> @r`
///   for an n-ary one, and binds the result to `r`. The sequence ends with
///   `output @r`, which registers `r` as an output, or with a bare `output`,
///   which registers none.
///
/// Illustrative invocations, derived from the rules below:
///
/// ```text
/// expr!(input -> Sin -> Pow(2) -> output)
/// expr!(inputs: [x, y] x -> Sin -> @s (@s, @y) -> Add -> @sum output @sum)
/// ```
///
/// Misuse that the rules can recognise (an n-ary operation in a unary position, a
/// unary operation applied to a tuple, `Add`/`Mul` with a single operand) is
/// rejected with `compile_error!`.
#[macro_export]
macro_rules! expr {
    // Entry point: single-input pipeline.
    (input -> $($rest:tt)*) => {
        {
            use $crate::autodiff::{ExprGraph, Op};
            let mut graph = ExprGraph::new();
            let __input = graph.input("input".to_string());
            $crate::expr! {
                @build_single
                graph,
                __input,
                $($rest)*
            }
        }
    };

    // Entry point: multi-input graph; every listed identifier becomes an input of the same name.
    (inputs: [$($input:ident),*] $($rest:tt)*) => {
        {
            use $crate::autodiff::{ExprGraph, Op};
            let mut graph = ExprGraph::new();
            $(let $input = graph.input(stringify!($input).to_string());)*
            $crate::expr! {
                @build_multi
                graph,
                $($rest)*
            }
        }
    };

    // Single-input pipeline: Add/Mul need several operands, so reject them here.
    // These rules must precede the generic `$op:ident` rule below.
    (@build_single $graph:ident, $node:ident, Add -> $($rest:tt)*) => {
        compile_error!("Add is n-ary; use `inputs: [...]` and `(@a, @b, ...) -> Add`");
    };

    (@build_single $graph:ident, $node:ident, Mul -> $($rest:tt)*) => {
        compile_error!("Mul is n-ary; use `inputs: [...]` and `(@a, @b, ...) -> Mul`");
    };

    // Single-input pipeline: unary operation without arguments.
    (@build_single $graph:ident, $node:ident, $op:ident -> $($rest:tt)*) => {
        let __next = $graph.operation(Op::$op, vec![$node]);
        $crate::expr! {
            @build_single
            $graph,
            __next,
            $($rest)*
        }
    };

    // Single-input pipeline: unary operation with arguments, e.g. `Pow(2)`.
    (@build_single $graph:ident, $node:ident, $op:ident ( $($op_args:tt)* ) -> $($rest:tt)*) => {
        let __next = $graph.operation(Op::$op($($op_args)*), vec![$node]);
        $crate::expr! {
            @build_single
            $graph,
            __next,
            $($rest)*
        }
    };

    // Single-input pipeline: terminator; registers the current node as the output.
    (@build_single $graph:ident, $node:ident, output) => {
        $graph.output($node);
        $graph
    };

    // Multi-input graph: Add/Mul applied to a single bare identifier are rejected.
    // These rules must precede the generic unary rules below.
    (@build_multi $graph:ident, $node:ident -> Add -> @ $result:ident $($rest:tt)*) => {
        compile_error!("Add is n-ary; use (@a, @b, ...) -> Add");
    };

    (@build_multi $graph:ident, $node:ident -> Mul -> @ $result:ident $($rest:tt)*) => {
        compile_error!("Mul is n-ary; use (@a, @b, ...) -> Mul");
    };

    (@build_multi $graph:ident, $node:ident -> Add ( $($op_args:tt)* ) -> @ $result:ident $($rest:tt)*) => {
        compile_error!("Add takes no arguments and is n-ary; use (@a, @b, ...) -> Add");
    };

    (@build_multi $graph:ident, $node:ident -> Mul ( $($op_args:tt)* ) -> @ $result:ident $($rest:tt)*) => {
        compile_error!("Mul takes no arguments and is n-ary; use (@a, @b, ...) -> Mul");
    };

    // Multi-input graph: unary operation without arguments; binds the result to `$result`.
    (@build_multi $graph:ident, $node:ident -> $op:ident -> @ $result:ident $($rest:tt)*) => {
        let $result = $graph.operation(Op::$op, vec![$node]);
        $crate::expr! { @build_multi $graph, $($rest)* }
    };

    // Multi-input graph: unary operation with arguments; binds the result to `$result`.
    (@build_multi $graph:ident, $node:ident -> $op:ident ( $($op_args:tt)* ) -> @ $result:ident $($rest:tt)*) => {
        let $result = $graph.operation(Op::$op($($op_args)*), vec![$node]);
        $crate::expr! { @build_multi $graph, $($rest)* }
    };

    // Multi-input graph: unary operations applied to a parenthesised operand list are
    // rejected, whatever the number of operands. Must precede the generic n-ary rules.
    (@build_multi $graph:ident, ( $( @ $node:ident ),+ ) -> Sin -> @ $result:ident $($rest:tt)*) => {
        compile_error!("Sin is unary; use x -> Sin");
    };

    (@build_multi $graph:ident, ( $( @ $node:ident ),+ ) -> Cos -> @ $result:ident $($rest:tt)*) => {
        compile_error!("Cos is unary; use x -> Cos");
    };

    (@build_multi $graph:ident, ( $( @ $node:ident ),+ ) -> Scale ( $($op_args:tt)* ) -> @ $result:ident $($rest:tt)*) => {
        compile_error!("Scale is unary; use x -> Scale(factor)");
    };

    (@build_multi $graph:ident, ( $( @ $node:ident ),+ ) -> Pow ( $($op_args:tt)* ) -> @ $result:ident $($rest:tt)*) => {
        compile_error!("Pow is unary; use x -> Pow(exp)");
    };

    // Multi-input graph: Add/Mul with a single parenthesised operand are rejected.
    // Must precede the generic n-ary rule, which would otherwise accept one operand.
    (@build_multi $graph:ident, ( @ $node:ident ) -> Add -> @ $result:ident $($rest:tt)*) => {
        compile_error!("Add requires at least 2 inputs");
    };

    (@build_multi $graph:ident, ( @ $node:ident ) -> Mul -> @ $result:ident $($rest:tt)*) => {
        compile_error!("Mul requires at least 2 inputs");
    };

    // Multi-input graph: n-ary operation without arguments; binds the result to `$result`.
    (@build_multi $graph:ident, ( $( @ $node:ident ),+ ) -> $op:ident -> @ $result:ident $($rest:tt)*) => {
        let $result = $graph.operation(Op::$op, vec![$($node),+]);
        $crate::expr! { @build_multi $graph, $($rest)* }
    };

    // Multi-input graph: n-ary operation with arguments; binds the result to `$result`.
    (@build_multi $graph:ident, ( $( @ $node:ident ),+ ) -> $op:ident ( $($op_args:tt)* ) -> @ $result:ident $($rest:tt)*) => {
        let $result = $graph.operation(Op::$op($($op_args)*), vec![$($node),+]);
        $crate::expr! { @build_multi $graph, $($rest)* }
    };

    // Multi-input graph: terminator that registers `$node` as the output.
    (@build_multi $graph:ident, output @ $node:ident) => {
        $graph.output($node);
        $graph
    };

    // Multi-input graph: terminator that returns the graph without registering an output.
    (@build_multi $graph:ident, output) => {
        $graph
    };
}
1070
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `a` and `b` differ by at most `eps`.
    fn approx_eq(a: Float, b: Float, eps: Float) {
        let diff = (a - b).abs();
        assert!(diff <= eps, "expected {a} ~= {b} (diff={diff}, eps={eps})");
    }

    /// Forward mode, reverse mode and central finite differences must agree on
    /// sin(x^2 + cos(z)).
    #[test]
    fn reverse_matches_forward_and_finite_difference() {
        let mut graph = ExprGraph::new();
        let x = graph.input("x".to_string());
        let z = graph.input("z".to_string());
        let x_squared = graph.operation(Op::Pow(2), [x]);
        let cos_z = graph.operation(Op::Cos, [z]);
        let inner_sum = graph.operation(Op::Add, [x_squared, cos_z]);
        let result = graph.operation(Op::Sin, [inner_sum]);
        graph.output(result);

        let point = [1.3, -0.7];
        let (forward_value, forward_grad) = graph.eval_fwd_one(&point);
        let (reverse_value, reverse_grad) = graph.eval_one(&point);

        approx_eq(forward_value, reverse_value, 1e-12);
        approx_eq(forward_grad[0], reverse_grad[0], 1e-10);
        approx_eq(forward_grad[1], reverse_grad[1], 1e-10);

        let step = 1e-7;
        for axis in 0..point.len() {
            let mut ahead = point;
            let mut behind = point;
            ahead[axis] += step;
            behind[axis] -= step;
            let value_ahead = graph.eval_fwd_one(&ahead).0;
            let value_behind = graph.eval_fwd_one(&behind).0;
            let central = (value_ahead - value_behind) / (2.0 * step);
            approx_eq(reverse_grad[axis], central, 1e-6);
        }
    }

    /// A handle issued by one graph must be refused by another.
    #[test]
    fn output_rejects_foreign_node_id() {
        let mut owner = ExprGraph::new();
        let foreign = owner.input("x".to_string());

        let mut other = ExprGraph::new();
        let _ = other.input("y".to_string());
        let attempt = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            other.output(foreign);
        }));
        assert!(attempt.is_err());
    }

    /// The fallible setters succeed on valid input and report the right error otherwise.
    #[test]
    fn tape_try_set_variants() {
        let mut tape = Tape::new();
        let x = tape.input("x", 1.0);
        let y = tape.input("y", 2.0);
        let sum = x + y;

        tape.try_set_inputs(&[3.0, 4.0])
            .expect("valid input update");
        let evaluated = tape.gradients(&sum);
        approx_eq(evaluated.value, 7.0, 1e-12);

        let length_error = tape
            .try_set_inputs(&[1.0])
            .expect_err("length mismatch should fail");
        assert!(matches!(
            length_error,
            TapeError::InputLengthMismatch {
                expected: 2,
                got: 1
            }
        ));

        tape.try_set("x", 5.0).expect("known input should be set");
        let name_error = tape
            .try_set("missing", 0.0)
            .expect_err("unknown input should fail");
        assert!(matches!(name_error, TapeError::UnknownInput(_)));
    }

    /// x^0 evaluates to 1 with a finite zero gradient, even at x = 0.
    #[test]
    fn pow_zero_has_zero_gradient_at_zero() {
        let mut graph = ExprGraph::new();
        let x = graph.input("x".to_string());
        let constant_one = graph.operation(Op::Pow(0), [x]);
        graph.output(constant_one);

        let (value, grads) = graph.eval_one(&[0.0]);
        approx_eq(value, 1.0, 1e-12);
        approx_eq(grads[0], 0.0, 1e-12);
        assert!(grads[0].is_finite());
    }
}