Skip to main content

tang_ad/
lib.rs

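//! Automatic differentiation: reverse-mode and forward-mode.
//!
//! Reverse mode is tape-based ([`Tape`], [`Var`]) for efficient gradient
//! computation when outputs are scalar and inputs are many (the ML/physics
//! optimization case). Forward mode uses dual numbers (`tang::Dual`) for
//! Jacobians, Jacobian-vector products, and Hessians.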
5
6#![no_std]
7
8extern crate alloc;
9
10mod tape;
11mod var;
12
13pub use tape::Tape;
14pub use var::Var;
15
16use alloc::vec::Vec;
17use tang_la::{DMat, DVec};
18/// Compute gradient of scalar-valued function via reverse-mode AD.
19///
20/// Returns gradient vector where `grad[i]` = ∂f/∂x_i.
21pub fn grad<F>(f: F, x: &[f64]) -> DVec<f64>
22where
23    F: Fn(&[Var]) -> Var,
24{
25    let tape = Tape::new();
26    let vars: Vec<Var> = x.iter().map(|&v| tape.var(v)).collect();
27    let indices: Vec<usize> = vars.iter().map(|v| v.index).collect();
28    let result = f(&vars);
29    let all_grads = result.backward();
30    DVec::from_fn(x.len(), |i| all_grads[indices[i]])
31}
32
33/// Compute Jacobian via forward-mode (Dual numbers).
34///
35/// Efficient when n (input dim) is small relative to m (output dim).
36pub fn jacobian_fwd<F>(f: F, x: &[f64]) -> DMat<f64>
37where
38    F: Fn(&[tang::Dual<f64>]) -> Vec<tang::Dual<f64>>,
39{
40    let n = x.len();
41    let mut columns = Vec::new();
42    for i in 0..n {
43        let inputs: Vec<tang::Dual<f64>> = x
44            .iter()
45            .enumerate()
46            .map(|(j, &v)| {
47                if i == j {
48                    tang::Dual::var(v)
49                } else {
50                    tang::Dual::constant(v)
51                }
52            })
53            .collect();
54        let outputs = f(&inputs);
55        columns.push(outputs.iter().map(|d| d.dual).collect::<Vec<_>>());
56    }
57    let m = columns.first().map_or(0, |c| c.len());
58    DMat::from_fn(m, n, |i, j| columns[j][i])
59}
60
61/// Compute Hessian of scalar-valued function via forward-over-forward (Dual<Dual<f64>>).
62pub fn hessian<F>(f: F, x: &[f64]) -> DMat<f64>
63where
64    F: Fn(&[tang::Dual<tang::Dual<f64>>]) -> tang::Dual<tang::Dual<f64>>,
65{
66    let n = x.len();
67    DMat::from_fn(n, n, |i, j| {
68        // Seed: x_k = Dual(Dual(x_k, δ_kj), Dual(δ_ki, 0))
69        let inputs: Vec<tang::Dual<tang::Dual<f64>>> = (0..n)
70            .map(|k| {
71                let real = tang::Dual::new(x[k], if k == j { 1.0 } else { 0.0 });
72                let dual = tang::Dual::new(if k == i { 1.0 } else { 0.0 }, 0.0);
73                tang::Dual::new(real, dual)
74            })
75            .collect();
76        f(&inputs).dual.dual
77    })
78}
79
80/// Vector-Jacobian product via reverse-mode: v^T J
81pub fn vjp<F>(f: F, x: &[f64], v: &[f64]) -> DVec<f64>
82where
83    F: Fn(&[Var]) -> Vec<Var>,
84{
85    let tape = Tape::new();
86    let vars: Vec<Var> = x.iter().map(|&val| tape.var(val)).collect();
87    let indices: Vec<usize> = vars.iter().map(|var| var.index).collect();
88    let outputs = f(&vars);
89
90    // Accumulate v^T J by taking weighted sum of output gradients
91    let n = x.len();
92    let mut result = DVec::zeros(n);
93    for (k, out) in outputs.iter().enumerate() {
94        let grads = out.backward();
95        let vk = v[k];
96        for i in 0..n {
97            result[i] = result[i] + vk * grads[indices[i]];
98        }
99    }
100    result
101}
102
103/// Jacobian-vector product via forward-mode: J v
104pub fn jvp<F>(f: F, x: &[f64], v: &[f64]) -> DVec<f64>
105where
106    F: Fn(&[tang::Dual<f64>]) -> Vec<tang::Dual<f64>>,
107{
108    let inputs: Vec<tang::Dual<f64>> = x
109        .iter()
110        .zip(v.iter())
111        .map(|(&xi, &vi)| tang::Dual::new(xi, vi))
112        .collect();
113    let outputs = f(&inputs);
114    DVec::from_fn(outputs.len(), |i| outputs[i].dual)
115}
116
/// Unit tests for the public differentiation helpers, each checked against a
/// hand-derived result for a small polynomial function.
#[cfg(test)]
mod tests {
    use super::*;
    use tang::Scalar;

    #[test]
    fn grad_simple() {
        // f(x, y) = x*y, grad = (y, x)
        let g = grad(|x| &x[0] * &x[1], &[3.0, 5.0]);
        assert!((g[0] - 5.0).abs() < 1e-10);
        assert!((g[1] - 3.0).abs() < 1e-10);
    }

    #[test]
    fn grad_quadratic() {
        // f(x) = x^2, grad = 2x at x=4 -> 8
        let g = grad(|x| &x[0] * &x[0], &[4.0]);
        assert!((g[0] - 8.0).abs() < 1e-10);
    }

    #[test]
    fn jacobian_fwd_linear() {
        // f(x, y) = (x + y, x - y)
        let j = jacobian_fwd(|x| alloc::vec![x[0] + x[1], x[0] - x[1]], &[1.0, 2.0]);
        assert_eq!(j.get(0, 0), 1.0);
        assert_eq!(j.get(0, 1), 1.0);
        assert_eq!(j.get(1, 0), 1.0);
        assert_eq!(j.get(1, 1), -1.0);
    }

    #[test]
    fn hessian_quadratic() {
        // f(x, y) = x^2 + 2*x*y + y^2
        // H = [[2, 2], [2, 2]]
        let h = hessian(
            |x| x[0] * x[0] + x[0] * x[1] * tang::Dual::from_f64(2.0) + x[1] * x[1],
            &[1.0, 1.0],
        );
        assert!((h.get(0, 0) - 2.0).abs() < 1e-10);
        assert!((h.get(0, 1) - 2.0).abs() < 1e-10);
        assert!((h.get(1, 0) - 2.0).abs() < 1e-10);
        assert!((h.get(1, 1) - 2.0).abs() < 1e-10);
    }

    #[test]
    fn vjp_simple() {
        // f(x, y) = (x*y, x*x), J = [[y, x], [2x, 0]]
        // v^T J at (3,5), v=(1,2): (5 + 2*6, 3 + 2*0) = (17, 3)
        let result = vjp(
            |x| alloc::vec![&x[0] * &x[1], &x[0] * &x[0]],
            &[3.0, 5.0],
            &[1.0, 2.0],
        );
        assert!((result[0] - 17.0).abs() < 1e-10);
        assert!((result[1] - 3.0).abs() < 1e-10);
    }

    #[test]
    fn jvp_simple() {
        // f(x, y) = (x*y, x+y), J = [[y, x], [1, 1]]
        // Jv at (3,5), v=(1,0): (5, 1)
        let result = jvp(
            |x| alloc::vec![x[0] * x[1], x[0] + x[1]],
            &[3.0, 5.0],
            &[1.0, 0.0],
        );
        assert!((result[0] - 5.0).abs() < 1e-10);
        assert!((result[1] - 1.0).abs() < 1e-10);
    }
}