1#![no_std]
7
8extern crate alloc;
9
10mod tape;
11mod var;
12
13pub use tape::Tape;
14pub use var::Var;
15
16use alloc::vec::Vec;
17use tang_la::{DMat, DVec};
18pub fn grad<F>(f: F, x: &[f64]) -> DVec<f64>
22where
23 F: Fn(&[Var]) -> Var,
24{
25 let tape = Tape::new();
26 let vars: Vec<Var> = x.iter().map(|&v| tape.var(v)).collect();
27 let indices: Vec<usize> = vars.iter().map(|v| v.index).collect();
28 let result = f(&vars);
29 let all_grads = result.backward();
30 DVec::from_fn(x.len(), |i| all_grads[indices[i]])
31}
32
33pub fn jacobian_fwd<F>(f: F, x: &[f64]) -> DMat<f64>
37where
38 F: Fn(&[tang::Dual<f64>]) -> Vec<tang::Dual<f64>>,
39{
40 let n = x.len();
41 let mut columns = Vec::new();
42 for i in 0..n {
43 let inputs: Vec<tang::Dual<f64>> = x
44 .iter()
45 .enumerate()
46 .map(|(j, &v)| {
47 if i == j {
48 tang::Dual::var(v)
49 } else {
50 tang::Dual::constant(v)
51 }
52 })
53 .collect();
54 let outputs = f(&inputs);
55 columns.push(outputs.iter().map(|d| d.dual).collect::<Vec<_>>());
56 }
57 let m = columns.first().map_or(0, |c| c.len());
58 DMat::from_fn(m, n, |i, j| columns[j][i])
59}
60
61pub fn hessian<F>(f: F, x: &[f64]) -> DMat<f64>
63where
64 F: Fn(&[tang::Dual<tang::Dual<f64>>]) -> tang::Dual<tang::Dual<f64>>,
65{
66 let n = x.len();
67 DMat::from_fn(n, n, |i, j| {
68 let inputs: Vec<tang::Dual<tang::Dual<f64>>> = (0..n)
70 .map(|k| {
71 let real = tang::Dual::new(x[k], if k == j { 1.0 } else { 0.0 });
72 let dual = tang::Dual::new(if k == i { 1.0 } else { 0.0 }, 0.0);
73 tang::Dual::new(real, dual)
74 })
75 .collect();
76 f(&inputs).dual.dual
77 })
78}
79
80pub fn vjp<F>(f: F, x: &[f64], v: &[f64]) -> DVec<f64>
82where
83 F: Fn(&[Var]) -> Vec<Var>,
84{
85 let tape = Tape::new();
86 let vars: Vec<Var> = x.iter().map(|&val| tape.var(val)).collect();
87 let indices: Vec<usize> = vars.iter().map(|var| var.index).collect();
88 let outputs = f(&vars);
89
90 let n = x.len();
92 let mut result = DVec::zeros(n);
93 for (k, out) in outputs.iter().enumerate() {
94 let grads = out.backward();
95 let vk = v[k];
96 for i in 0..n {
97 result[i] = result[i] + vk * grads[indices[i]];
98 }
99 }
100 result
101}
102
103pub fn jvp<F>(f: F, x: &[f64], v: &[f64]) -> DVec<f64>
105where
106 F: Fn(&[tang::Dual<f64>]) -> Vec<tang::Dual<f64>>,
107{
108 let inputs: Vec<tang::Dual<f64>> = x
109 .iter()
110 .zip(v.iter())
111 .map(|(&xi, &vi)| tang::Dual::new(xi, vi))
112 .collect();
113 let outputs = f(&inputs);
114 DVec::from_fn(outputs.len(), |i| outputs[i].dual)
115}
116
#[cfg(test)]
mod tests {
    use super::*;
    use tang::Scalar;

    #[test]
    fn grad_simple() {
        let g = grad(|x| &x[0] * &x[1], &[3.0, 5.0]);
        assert!((g[0] - 5.0).abs() < 1e-10);
        assert!((g[1] - 3.0).abs() < 1e-10);
    }

    #[test]
    fn grad_quadratic() {
        let g = grad(|x| &x[0] * &x[0], &[4.0]);
        assert!((g[0] - 8.0).abs() < 1e-10);
    }

    #[test]
    fn jacobian_fwd_linear() {
        let j = jacobian_fwd(|x| alloc::vec![x[0] + x[1], x[0] - x[1]], &[1.0, 2.0]);
        assert_eq!(j.get(0, 0), 1.0);
        assert_eq!(j.get(0, 1), 1.0);
        assert_eq!(j.get(1, 0), 1.0);
        assert_eq!(j.get(1, 1), -1.0);
    }

    // The linear case above has a symmetric Jacobian, so it cannot catch a
    // transposed result. This 3×2 case pins the (outputs × inputs) orientation.
    #[test]
    fn jacobian_fwd_nonsquare() {
        let j = jacobian_fwd(
            |x| alloc::vec![x[0] * x[1], x[0] + x[1], x[1] * x[1]],
            &[3.0, 5.0],
        );
        assert!((j.get(0, 0) - 5.0).abs() < 1e-10);
        assert!((j.get(0, 1) - 3.0).abs() < 1e-10);
        assert!((j.get(1, 0) - 1.0).abs() < 1e-10);
        assert!((j.get(1, 1) - 1.0).abs() < 1e-10);
        assert!(j.get(2, 0).abs() < 1e-10);
        assert!((j.get(2, 1) - 10.0).abs() < 1e-10);
    }

    #[test]
    fn hessian_quadratic() {
        let h = hessian(
            |x| x[0] * x[0] + x[0] * x[1] * tang::Dual::from_f64(2.0) + x[1] * x[1],
            &[1.0, 1.0],
        );
        assert!((h.get(0, 0) - 2.0).abs() < 1e-10);
        assert!((h.get(0, 1) - 2.0).abs() < 1e-10);
        assert!((h.get(1, 0) - 2.0).abs() < 1e-10);
        assert!((h.get(1, 1) - 2.0).abs() < 1e-10);
    }

    // f = x0² · x1 at (2, 3): H = [[2·x1, 2·x0], [2·x0, 0]] = [[6, 4], [4, 0]].
    // Distinct entries catch diagonal/off-diagonal mix-ups.
    #[test]
    fn hessian_cubic() {
        let h = hessian(|x| x[0] * x[0] * x[1], &[2.0, 3.0]);
        assert!((h.get(0, 0) - 6.0).abs() < 1e-10);
        assert!((h.get(0, 1) - 4.0).abs() < 1e-10);
        assert!((h.get(1, 0) - 4.0).abs() < 1e-10);
        assert!(h.get(1, 1).abs() < 1e-10);
    }

    #[test]
    fn jvp_simple() {
        let result = jvp(
            |x| alloc::vec![x[0] * x[1], x[0] + x[1]],
            &[3.0, 5.0],
            &[1.0, 0.0],
        );
        assert!((result[0] - 5.0).abs() < 1e-10);
        assert!((result[1] - 1.0).abs() < 1e-10);
    }

    // f = (x0·x1, x0²) at (3, 5): J = [[5, 3], [6, 0]];
    // with v = (1, 2), vᵀJ = (5 + 12, 3 + 0) = (17, 3).
    #[test]
    fn vjp_simple() {
        let result = vjp(
            |x| alloc::vec![&x[0] * &x[1], &x[0] * &x[0]],
            &[3.0, 5.0],
            &[1.0, 2.0],
        );
        assert!((result[0] - 17.0).abs() < 1e-10);
        assert!((result[1] - 3.0).abs() < 1e-10);
    }
}