sim_lib_numbers_ad/implementation/tape.rs
1//! Reverse-mode autodiff tape: a graph of recorded operations with forward
2//! evaluation and reverse gradient accumulation.
3
/// Identifies a single node on a [`Tape`] by its position in the recording.
///
/// Tape operations hand one back for each node they append, and accept them
/// as operands for later operations or as the key when reading a value or a
/// gradient.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Var(pub usize);
10
/// A single entry in the operation graph recorded by a [`Tape`].
///
/// Every operand is stored as the index (the position held by a [`Var`]) of
/// the node it reads, so the entries together describe the computation graph
/// that the reverse pass traverses when accumulating gradients.
#[derive(Debug, Clone, PartialEq)]
pub enum TapeNode {
    /// Constant holding the given value.
    Const(f64),
    /// Independent input; the payload is the gradient slot it is bound to.
    Input(usize),
    /// `lhs + rhs` of the two operand nodes.
    Add(usize, usize),
    /// `lhs - rhs` of the two operand nodes (first minus second).
    Sub(usize, usize),
    /// `lhs * rhs` of the two operand nodes.
    Mul(usize, usize),
    /// `lhs / rhs` of the two operand nodes (first divided by second).
    Div(usize, usize),
    /// `sin` of the operand node.
    Sin(usize),
    /// `cos` of the operand node.
    Cos(usize),
    /// `exp` of the operand node.
    Exp(usize),
    /// Natural logarithm of the operand node.
    Ln(usize),
    /// Square root of the operand node.
    Sqrt(usize),
    /// Reciprocal (`1 / x`) of the operand node.
    Recip(usize),
}
43
44/// A reverse-mode autodiff tape: the recorded operation graph plus the forward
45/// value of each node.
46///
47/// Build an expression by calling the operation methods, which append nodes and
48/// return [`Var`] handles; the forward value is computed eagerly as each node
49/// is pushed. Call [`grad`](Tape::grad) on an output node to run the reverse
50/// pass and accumulate the partial derivatives with respect to every input.
51///
52/// # Examples
53///
54/// ```
55/// use sim_lib_numbers_ad::Tape;
56///
57/// // f(a, b) = a * b + a, at a = 2, b = 5.
58/// let mut tape = Tape::new();
59/// let a = tape.input(0, 2.0);
60/// let b = tape.input(1, 5.0);
61/// let product = tape.mul(a, b);
62/// let out = tape.add(product, a);
63/// assert_eq!(tape.value(out), 12.0);
64/// // df/da = b + 1 = 6, df/db = a = 2.
65/// assert_eq!(tape.grad(out, 2), vec![6.0, 2.0]);
66/// ```
67#[derive(Clone, Debug, Default)]
68pub struct Tape {
69 nodes: Vec<TapeNode>,
70 values: Vec<f64>,
71}
72
73impl Tape {
74 /// Creates an empty tape.
75 pub fn new() -> Self {
76 Self::default()
77 }
78
79 /// Records a constant node and returns its handle.
80 pub fn constant(&mut self, value: f64) -> Var {
81 self.push(TapeNode::Const(value), value)
82 }
83
84 /// Records an independent input bound to gradient slot `slot` with the given
85 /// value, and returns its handle.
86 pub fn input(&mut self, slot: usize, value: f64) -> Var {
87 self.push(TapeNode::Input(slot), value)
88 }
89
90 /// Records `a + b` and returns the handle of the sum.
91 pub fn add(&mut self, a: Var, b: Var) -> Var {
92 self.push(TapeNode::Add(a.0, b.0), self.values[a.0] + self.values[b.0])
93 }
94
95 /// Records `a - b` and returns the handle of the difference.
96 pub fn sub(&mut self, a: Var, b: Var) -> Var {
97 self.push(TapeNode::Sub(a.0, b.0), self.values[a.0] - self.values[b.0])
98 }
99
100 /// Records `a * b` and returns the handle of the product.
101 pub fn mul(&mut self, a: Var, b: Var) -> Var {
102 self.push(TapeNode::Mul(a.0, b.0), self.values[a.0] * self.values[b.0])
103 }
104
105 /// Records `a / b` and returns the handle of the quotient.
106 pub fn div(&mut self, a: Var, b: Var) -> Var {
107 self.push(TapeNode::Div(a.0, b.0), self.values[a.0] / self.values[b.0])
108 }
109
110 /// Records `sin(arg)` and returns its handle.
111 pub fn sin(&mut self, arg: Var) -> Var {
112 self.push(TapeNode::Sin(arg.0), self.values[arg.0].sin())
113 }
114
115 /// Records `cos(arg)` and returns its handle.
116 pub fn cos(&mut self, arg: Var) -> Var {
117 self.push(TapeNode::Cos(arg.0), self.values[arg.0].cos())
118 }
119
120 /// Records `exp(arg)` and returns its handle.
121 pub fn exp(&mut self, arg: Var) -> Var {
122 self.push(TapeNode::Exp(arg.0), self.values[arg.0].exp())
123 }
124
125 /// Records `ln(arg)` and returns its handle.
126 pub fn ln(&mut self, arg: Var) -> Var {
127 self.push(TapeNode::Ln(arg.0), self.values[arg.0].ln())
128 }
129
130 /// Records `sqrt(arg)` and returns its handle.
131 pub fn sqrt(&mut self, arg: Var) -> Var {
132 self.push(TapeNode::Sqrt(arg.0), self.values[arg.0].sqrt())
133 }
134
135 /// Records `1 / arg` and returns its handle.
136 pub fn recip(&mut self, arg: Var) -> Var {
137 self.push(TapeNode::Recip(arg.0), self.values[arg.0].recip())
138 }
139
140 /// Returns the forward value recorded for `var`.
141 pub fn value(&self, var: Var) -> f64 {
142 self.values[var.0]
143 }
144
145 /// Runs the reverse pass from output `out` and returns the gradient with
146 /// respect to the `n_inputs` input slots.
147 ///
148 /// Seeds the adjoint of `out` with `1.0`, walks the recorded nodes in
149 /// reverse, and accumulates each input slot's partial derivative; the
150 /// returned vector has length `n_inputs`.
151 pub fn grad(&self, out: Var, n_inputs: usize) -> Vec<f64> {
152 let mut adjoints = vec![0.0; self.nodes.len()];
153 let mut input_grad = vec![0.0; n_inputs];
154 adjoints[out.0] = 1.0;
155
156 for index in (0..self.nodes.len()).rev() {
157 let seed = adjoints[index];
158 if seed == 0.0 {
159 continue;
160 }
161 match self.nodes[index] {
162 TapeNode::Const(_) => {}
163 TapeNode::Input(slot) => {
164 if let Some(grad) = input_grad.get_mut(slot) {
165 *grad += seed;
166 }
167 }
168 TapeNode::Add(a, b) => {
169 adjoints[a] += seed;
170 adjoints[b] += seed;
171 }
172 TapeNode::Sub(a, b) => {
173 adjoints[a] += seed;
174 adjoints[b] -= seed;
175 }
176 TapeNode::Mul(a, b) => {
177 adjoints[a] += seed * self.values[b];
178 adjoints[b] += seed * self.values[a];
179 }
180 TapeNode::Div(a, b) => {
181 let denom = self.values[b] * self.values[b];
182 adjoints[a] += seed / self.values[b];
183 adjoints[b] -= seed * self.values[a] / denom;
184 }
185 TapeNode::Sin(arg) => adjoints[arg] += seed * self.values[arg].cos(),
186 TapeNode::Cos(arg) => adjoints[arg] -= seed * self.values[arg].sin(),
187 TapeNode::Exp(arg) => adjoints[arg] += seed * self.values[index],
188 TapeNode::Ln(arg) => adjoints[arg] += seed / self.values[arg],
189 TapeNode::Sqrt(arg) => adjoints[arg] += seed / (2.0 * self.values[index]),
190 TapeNode::Recip(arg) => {
191 let denom = self.values[arg] * self.values[arg];
192 adjoints[arg] -= seed / denom;
193 }
194 }
195 }
196
197 input_grad
198 }
199
200 fn push(&mut self, node: TapeNode, value: f64) -> Var {
201 let index = self.nodes.len();
202 self.nodes.push(node);
203 self.values.push(value);
204 Var(index)
205 }
206}