Skip to main content

sim_lib_numbers_ad/implementation/
tape.rs

//! Reverse-mode autodiff tape: a graph of recorded operations with forward
//! evaluation and reverse gradient accumulation.

/// Position of a node recorded on a [`Tape`].
///
/// Every tape operation hands one of these back. Pass it to later operations
/// to use that node as an operand, or to [`Tape::value`] / [`Tape::grad`] to
/// read results.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Var(pub usize);

/// A single operation recorded on a [`Tape`].
///
/// Each variant names the nodes it reads by their tape position (the `usize`
/// inside a [`Var`]). Those references are the edges of the computation graph
/// that the reverse pass follows when accumulating gradients.
#[derive(Debug, Clone, PartialEq)]
pub enum TapeNode {
    /// Constant leaf carrying its value.
    Const(f64),
    /// Independent input leaf; the payload is the gradient slot it feeds.
    Input(usize),
    /// `first + second` over the two operand nodes.
    Add(usize, usize),
    /// `first - second` over the two operand nodes.
    Sub(usize, usize),
    /// `first * second` over the two operand nodes.
    Mul(usize, usize),
    /// `first / second` over the two operand nodes.
    Div(usize, usize),
    /// `sin(x)` of the operand node.
    Sin(usize),
    /// `cos(x)` of the operand node.
    Cos(usize),
    /// `exp(x)` (natural exponential) of the operand node.
    Exp(usize),
    /// `ln(x)` (natural logarithm) of the operand node.
    Ln(usize),
    /// `sqrt(x)` of the operand node.
    Sqrt(usize),
    /// `1 / x` of the operand node.
    Recip(usize),
}

44/// A reverse-mode autodiff tape: the recorded operation graph plus the forward
45/// value of each node.
46///
47/// Build an expression by calling the operation methods, which append nodes and
48/// return [`Var`] handles; the forward value is computed eagerly as each node
49/// is pushed. Call [`grad`](Tape::grad) on an output node to run the reverse
50/// pass and accumulate the partial derivatives with respect to every input.
51///
52/// # Examples
53///
54/// ```
55/// use sim_lib_numbers_ad::Tape;
56///
57/// // f(a, b) = a * b + a, at a = 2, b = 5.
58/// let mut tape = Tape::new();
59/// let a = tape.input(0, 2.0);
60/// let b = tape.input(1, 5.0);
61/// let product = tape.mul(a, b);
62/// let out = tape.add(product, a);
63/// assert_eq!(tape.value(out), 12.0);
64/// // df/da = b + 1 = 6, df/db = a = 2.
65/// assert_eq!(tape.grad(out, 2), vec![6.0, 2.0]);
66/// ```
67#[derive(Clone, Debug, Default)]
68pub struct Tape {
69    nodes: Vec<TapeNode>,
70    values: Vec<f64>,
71}
72
73impl Tape {
74    /// Creates an empty tape.
75    pub fn new() -> Self {
76        Self::default()
77    }
78
79    /// Records a constant node and returns its handle.
80    pub fn constant(&mut self, value: f64) -> Var {
81        self.push(TapeNode::Const(value), value)
82    }
83
84    /// Records an independent input bound to gradient slot `slot` with the given
85    /// value, and returns its handle.
86    pub fn input(&mut self, slot: usize, value: f64) -> Var {
87        self.push(TapeNode::Input(slot), value)
88    }
89
90    /// Records `a + b` and returns the handle of the sum.
91    pub fn add(&mut self, a: Var, b: Var) -> Var {
92        self.push(TapeNode::Add(a.0, b.0), self.values[a.0] + self.values[b.0])
93    }
94
95    /// Records `a - b` and returns the handle of the difference.
96    pub fn sub(&mut self, a: Var, b: Var) -> Var {
97        self.push(TapeNode::Sub(a.0, b.0), self.values[a.0] - self.values[b.0])
98    }
99
100    /// Records `a * b` and returns the handle of the product.
101    pub fn mul(&mut self, a: Var, b: Var) -> Var {
102        self.push(TapeNode::Mul(a.0, b.0), self.values[a.0] * self.values[b.0])
103    }
104
105    /// Records `a / b` and returns the handle of the quotient.
106    pub fn div(&mut self, a: Var, b: Var) -> Var {
107        self.push(TapeNode::Div(a.0, b.0), self.values[a.0] / self.values[b.0])
108    }
109
110    /// Records `sin(arg)` and returns its handle.
111    pub fn sin(&mut self, arg: Var) -> Var {
112        self.push(TapeNode::Sin(arg.0), self.values[arg.0].sin())
113    }
114
115    /// Records `cos(arg)` and returns its handle.
116    pub fn cos(&mut self, arg: Var) -> Var {
117        self.push(TapeNode::Cos(arg.0), self.values[arg.0].cos())
118    }
119
120    /// Records `exp(arg)` and returns its handle.
121    pub fn exp(&mut self, arg: Var) -> Var {
122        self.push(TapeNode::Exp(arg.0), self.values[arg.0].exp())
123    }
124
125    /// Records `ln(arg)` and returns its handle.
126    pub fn ln(&mut self, arg: Var) -> Var {
127        self.push(TapeNode::Ln(arg.0), self.values[arg.0].ln())
128    }
129
130    /// Records `sqrt(arg)` and returns its handle.
131    pub fn sqrt(&mut self, arg: Var) -> Var {
132        self.push(TapeNode::Sqrt(arg.0), self.values[arg.0].sqrt())
133    }
134
135    /// Records `1 / arg` and returns its handle.
136    pub fn recip(&mut self, arg: Var) -> Var {
137        self.push(TapeNode::Recip(arg.0), self.values[arg.0].recip())
138    }
139
140    /// Returns the forward value recorded for `var`.
141    pub fn value(&self, var: Var) -> f64 {
142        self.values[var.0]
143    }
144
145    /// Runs the reverse pass from output `out` and returns the gradient with
146    /// respect to the `n_inputs` input slots.
147    ///
148    /// Seeds the adjoint of `out` with `1.0`, walks the recorded nodes in
149    /// reverse, and accumulates each input slot's partial derivative; the
150    /// returned vector has length `n_inputs`.
151    pub fn grad(&self, out: Var, n_inputs: usize) -> Vec<f64> {
152        let mut adjoints = vec![0.0; self.nodes.len()];
153        let mut input_grad = vec![0.0; n_inputs];
154        adjoints[out.0] = 1.0;
155
156        for index in (0..self.nodes.len()).rev() {
157            let seed = adjoints[index];
158            if seed == 0.0 {
159                continue;
160            }
161            match self.nodes[index] {
162                TapeNode::Const(_) => {}
163                TapeNode::Input(slot) => {
164                    if let Some(grad) = input_grad.get_mut(slot) {
165                        *grad += seed;
166                    }
167                }
168                TapeNode::Add(a, b) => {
169                    adjoints[a] += seed;
170                    adjoints[b] += seed;
171                }
172                TapeNode::Sub(a, b) => {
173                    adjoints[a] += seed;
174                    adjoints[b] -= seed;
175                }
176                TapeNode::Mul(a, b) => {
177                    adjoints[a] += seed * self.values[b];
178                    adjoints[b] += seed * self.values[a];
179                }
180                TapeNode::Div(a, b) => {
181                    let denom = self.values[b] * self.values[b];
182                    adjoints[a] += seed / self.values[b];
183                    adjoints[b] -= seed * self.values[a] / denom;
184                }
185                TapeNode::Sin(arg) => adjoints[arg] += seed * self.values[arg].cos(),
186                TapeNode::Cos(arg) => adjoints[arg] -= seed * self.values[arg].sin(),
187                TapeNode::Exp(arg) => adjoints[arg] += seed * self.values[index],
188                TapeNode::Ln(arg) => adjoints[arg] += seed / self.values[arg],
189                TapeNode::Sqrt(arg) => adjoints[arg] += seed / (2.0 * self.values[index]),
190                TapeNode::Recip(arg) => {
191                    let denom = self.values[arg] * self.values[arg];
192                    adjoints[arg] -= seed / denom;
193                }
194            }
195        }
196
197        input_grad
198    }
199
200    fn push(&mut self, node: TapeNode, value: f64) -> Var {
201        let index = self.nodes.len();
202        self.nodes.push(node);
203        self.values.push(value);
204        Var(index)
205    }
206}