1use std::collections::HashMap;
4
5use crate::graph::ExprGraph;
6use crate::node::{ExprId, Node};
7
8impl ExprGraph {
9 pub fn diff(&mut self, expr: ExprId, var: u16) -> ExprId {
14 let mut memo = HashMap::new();
15 self.diff_inner(expr, var, &mut memo)
16 }
17
18 fn diff_inner(
19 &mut self,
20 expr: ExprId,
21 var: u16,
22 memo: &mut HashMap<(ExprId, u16), ExprId>,
23 ) -> ExprId {
24 if let Some(&cached) = memo.get(&(expr, var)) {
25 return cached;
26 }
27
28 let result = match self.node(expr) {
29 Node::Var(n) => {
30 if n == var {
31 ExprId::ONE
32 } else {
33 ExprId::ZERO
34 }
35 }
36 Node::Lit(_) => ExprId::ZERO,
37
38 Node::Add(a, b) => {
39 let da = self.diff_inner(a, var, memo);
41 let db = self.diff_inner(b, var, memo);
42 self.add(da, db)
43 }
44
45 Node::Mul(a, b) => {
46 let da = self.diff_inner(a, var, memo);
48 let db = self.diff_inner(b, var, memo);
49 let t1 = self.mul(da, b);
50 let t2 = self.mul(a, db);
51 self.add(t1, t2)
52 }
53
54 Node::Neg(a) => {
55 let da = self.diff_inner(a, var, memo);
57 self.neg(da)
58 }
59
60 Node::Recip(a) => {
61 let da = self.diff_inner(a, var, memo);
63 let a_sq = self.mul(a, a);
64 let r = self.recip(a_sq);
65 let t = self.mul(da, r);
66 self.neg(t)
67 }
68
69 Node::Sqrt(a) => {
70 let da = self.diff_inner(a, var, memo);
72 let sq = self.sqrt(a);
73 let two_sq = self.mul(ExprId::TWO, sq);
74 let r = self.recip(two_sq);
75 self.mul(da, r)
76 }
77
78 Node::Sin(a) => {
79 let da = self.diff_inner(a, var, memo);
82 let half_pi = self.lit(std::f64::consts::FRAC_PI_2);
83 let shifted = self.add(a, half_pi);
84 let cos_a = self.sin(shifted);
85 self.mul(cos_a, da)
86 }
87
88 Node::Atan2(y, x) => {
89 let dy = self.diff_inner(y, var, memo);
91 let dx = self.diff_inner(x, var, memo);
92 let x_dy = self.mul(x, dy);
93 let y_dx = self.mul(y, dx);
94 let neg_y_dx = self.neg(y_dx);
95 let numer = self.add(x_dy, neg_y_dx);
96 let xx = self.mul(x, x);
97 let yy = self.mul(y, y);
98 let denom = self.add(xx, yy);
99 let r = self.recip(denom);
100 self.mul(numer, r)
101 }
102
103 Node::Exp2(a) => {
104 let da = self.diff_inner(a, var, memo);
106 let ln2 = self.lit(std::f64::consts::LN_2);
107 let exp2_a = self.exp2(a);
108 let t = self.mul(ln2, exp2_a);
109 self.mul(t, da)
110 }
111
112 Node::Log2(a) => {
113 let da = self.diff_inner(a, var, memo);
115 let ln2 = self.lit(std::f64::consts::LN_2);
116 let ln2_a = self.mul(ln2, a);
117 let r = self.recip(ln2_a);
118 self.mul(da, r)
119 }
120
121 Node::Select(c, a, b) => {
122 let da = self.diff_inner(a, var, memo);
124 let db = self.diff_inner(b, var, memo);
125 self.select(c, da, db)
126 }
127 };
128
129 memo.insert((expr, var), result);
130 result
131 }
132}
133
#[cfg(test)]
mod tests {
    use crate::graph::ExprGraph;
    use crate::node::ExprId;

    /// Absolute tolerance used for every floating-point comparison below.
    const EPS: f64 = 1e-10;

    /// Asserts that `actual` is within `EPS` of `expected`, reporting both on failure.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < EPS,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    /// The derivative of a literal is the canonical zero node.
    #[test]
    fn diff_constant() {
        let mut g = ExprGraph::new();
        let c = g.lit(5.0);
        let dc = g.diff(c, 0);
        assert_eq!(dc, ExprId::ZERO);
    }

    /// d(x0)/d(x0) is the canonical one node.
    #[test]
    fn diff_var_self() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let dx = g.diff(x, 0);
        assert_eq!(dx, ExprId::ONE);
    }

    /// d(x0)/d(x1) is the canonical zero node.
    #[test]
    fn diff_var_other() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let dx = g.diff(x, 1);
        assert_eq!(dx, ExprId::ZERO);
    }

    /// d(x + 3)/dx = 1 regardless of x.
    #[test]
    fn diff_add() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let c = g.lit(3.0);
        let sum = g.add(x, c);
        let d = g.diff(sum, 0);
        let result: f64 = g.eval(d, &[99.0]);
        assert_close(result, 1.0);
    }

    /// d(x * x)/dx = 2x, i.e. 6 at x = 3.
    #[test]
    fn diff_mul_product_rule() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let xx = g.mul(x, x);
        let d = g.diff(xx, 0);
        let result: f64 = g.eval(d, &[3.0]);
        assert_close(result, 6.0);
    }

    /// d(-x)/dx = -1.
    #[test]
    fn diff_neg() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let n = g.neg(x);
        let d = g.diff(n, 0);
        let result: f64 = g.eval(d, &[5.0]);
        assert_close(result, -1.0);
    }

    /// d(sin x)/dx = cos x, i.e. 1 at x = 0.
    #[test]
    fn diff_sin() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let s = g.sin(x);
        let ds = g.diff(s, 0);
        let result: f64 = g.eval(ds, &[0.0]);
        assert_close(result, 1.0);
    }

    /// d(sin x)/dx = cos x away from zero: cos(pi/3) = 0.5.
    #[test]
    fn diff_sin_nonzero() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let s = g.sin(x);
        let ds = g.diff(s, 0);
        let result: f64 = g.eval(ds, &[std::f64::consts::FRAC_PI_3]);
        assert_close(result, 0.5);
    }

    /// Chain rule: d(sin(x^2))/dx = 2x * cos(x^2), i.e. 2 * cos(1) at x = 1.
    #[test]
    fn diff_chain_rule() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let xx = g.mul(x, x);
        let s = g.sin(xx);
        let ds = g.diff(s, 0);
        let expected = 2.0 * 1.0_f64.cos();
        let result: f64 = g.eval(ds, &[1.0]);
        assert_close(result, expected);
    }

    /// d(sqrt x)/dx = 1 / (2 sqrt x), i.e. 0.25 at x = 4.
    #[test]
    fn diff_sqrt() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let sq = g.sqrt(x);
        let d = g.diff(sq, 0);
        let result: f64 = g.eval(d, &[4.0]);
        assert_close(result, 0.25);
    }

    /// d(1/x)/dx = -1 / x^2, i.e. -0.25 at x = 2.
    #[test]
    fn diff_recip() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let r = g.recip(x);
        let d = g.diff(r, 0);
        let result: f64 = g.eval(d, &[2.0]);
        assert_close(result, -0.25);
    }

    /// d(2^x)/dx = ln(2) * 2^x, i.e. 8 ln 2 at x = 3.
    #[test]
    fn diff_exp2() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let e = g.exp2(x);
        let d = g.diff(e, 0);
        let result: f64 = g.eval(d, &[3.0]);
        assert_close(result, 8.0 * std::f64::consts::LN_2);
    }

    /// `xx` is referenced twice by `sum`; its derivative is computed once
    /// (memoized) but must still contribute twice: d(2x^2)/dx = 4x = 12 at x = 3.
    #[test]
    fn diff_memoization() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let xx = g.mul(x, x);
        let sum = g.add(xx, xx);
        let d = g.diff(sum, 0);
        let result: f64 = g.eval(d, &[3.0]);
        assert_close(result, 12.0);
    }

    /// Differentiating a derivative works: d^2(x^2)/dx^2 = 2 for any x.
    #[test]
    fn diff_second_derivative() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let xx = g.mul(x, x);
        let d1 = g.diff(xx, 0);
        let d2 = g.diff(d1, 0);
        let result: f64 = g.eval(d2, &[7.0]);
        assert_close(result, 2.0);
    }

    /// The derivative of a select follows whichever branch the condition picks.
    /// At x = 2 the `x * x` branch is taken (derivative 4); at x = -1 the
    /// `x + 1` branch is taken (derivative 1). Presumably select picks its
    /// first branch for a positive condition — see `ExprGraph::select`.
    #[test]
    fn diff_select() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let xx = g.mul(x, x);
        let xp1 = g.add(x, ExprId::ONE);
        let s = g.select(x, xx, xp1);
        let ds = g.diff(s, 0);

        let result: f64 = g.eval(ds, &[2.0]);
        assert_close(result, 4.0);

        let result2: f64 = g.eval(ds, &[-1.0]);
        assert_close(result2, 1.0);
    }

    /// Partial derivatives of a 3-D dot product x0*x3 + x1*x4 + x2*x5:
    /// d/dx0 = x3 = 4 and d/dx4 = x1 = 2 at the point (1, 2, 3, 4, 5, 6).
    #[test]
    fn diff_dot_product() {
        let mut g = ExprGraph::new();
        let x0 = g.var(0);
        let x1 = g.var(1);
        let x2 = g.var(2);
        let x3 = g.var(3);
        let x4 = g.var(4);
        let x5 = g.var(5);

        let t0 = g.mul(x0, x3);
        let t1 = g.mul(x1, x4);
        let t2 = g.mul(x2, x5);
        let s01 = g.add(t0, t1);
        let dot = g.add(s01, t2);

        let point = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let d0 = g.diff(dot, 0);
        let result: f64 = g.eval(d0, &point);
        assert_close(result, 4.0);

        let d4 = g.diff(dot, 4);
        let result4: f64 = g.eval(d4, &point);
        assert_close(result4, 2.0);
    }
}