// scirs2_optimize/automatic_differentiation/dual_numbers.rs
//! Dual numbers for forward-mode automatic differentiation.
use scirs2_core::ndarray::{Array1, ArrayView1};
use std::ops::{Add, Div, Mul, Neg, Sub};

/// A first-order dual number `value + derivative·ε` (with `ε² = 0`) used for
/// forward-mode automatic differentiation of scalar functions.
///
/// Every operation on `Dual` applies the chain rule to `derivative`, so
/// evaluating `f(Dual::variable(x))` yields `f(x)` in [`Dual::value`] and
/// `f'(x)` in [`Dual::derivative`].
#[derive(Debug, Clone, Copy)]
pub struct Dual {
    /// Real part: the function value.
    value: f64,
    /// Infinitesimal part: the derivative propagated alongside `value`.
    derivative: f64,
}

impl Dual {
    /// Creates a dual number from an explicit value and derivative (seed).
    pub fn new(value: f64, derivative: f64) -> Self {
        Self { value, derivative }
    }

    /// Creates a constant: a dual number whose derivative is `0`.
    pub fn constant(value: f64) -> Self {
        Self {
            value,
            derivative: 0.0,
        }
    }

    /// Creates the independent variable: a dual number seeded with derivative `1`.
    pub fn variable(value: f64) -> Self {
        Self {
            value,
            derivative: 1.0,
        }
    }

    /// Returns the real part (the function value).
    pub fn value(self) -> f64 {
        self.value
    }

    /// Returns the infinitesimal part (the propagated derivative).
    pub fn derivative(self) -> f64 {
        self.derivative
    }

    /// `sin(x)`, using `d/dx sin(x) = cos(x)`.
    pub fn sin(self) -> Self {
        Self {
            value: self.value.sin(),
            derivative: self.derivative * self.value.cos(),
        }
    }

    /// `cos(x)`, using `d/dx cos(x) = -sin(x)`.
    pub fn cos(self) -> Self {
        Self {
            value: self.value.cos(),
            derivative: -self.derivative * self.value.sin(),
        }
    }

    /// `tan(x)`, using `d/dx tan(x) = 1 / cos²(x)`.
    pub fn tan(self) -> Self {
        let cos_val = self.value.cos();
        Self {
            value: self.value.tan(),
            derivative: self.derivative / (cos_val * cos_val),
        }
    }

    /// `exp(x)`, using `d/dx exp(x) = exp(x)`.
    pub fn exp(self) -> Self {
        let exp_val = self.value.exp();
        Self {
            value: exp_val,
            derivative: self.derivative * exp_val,
        }
    }

    /// Natural logarithm, using `d/dx ln(x) = 1 / x`.
    ///
    /// For `x <= 0` the value is whatever `f64::ln` returns (NaN or `-inf`).
    pub fn ln(self) -> Self {
        Self {
            value: self.value.ln(),
            derivative: self.derivative / self.value,
        }
    }

    /// Integer power `x^n`, using `d/dx x^n = n·x^(n-1)`.
    ///
    /// For `n == 0` the result is constant, so its derivative is exactly `0`.
    pub fn powi(self, n: i32) -> Self {
        let value = self.value.powi(n);
        // The general formula would evaluate `0 * x^-1` for n == 0, which is
        // NaN at x == 0 even though x^0 is constant.
        let derivative = if n == 0 {
            0.0
        } else {
            self.derivative * f64::from(n) * self.value.powi(n - 1)
        };
        Self { value, derivative }
    }

    /// Real power `x^p`, using `d/dx x^p = p·x^(p-1)`.
    ///
    /// For `p == 0` the result is constant, so its derivative is exactly `0`.
    pub fn powf(self, p: f64) -> Self {
        let value = self.value.powf(p);
        // Same reasoning as `powi`: avoid `0 * x^-1 = NaN` at x == 0.
        let derivative = if p == 0.0 {
            0.0
        } else {
            self.derivative * p * self.value.powf(p - 1.0)
        };
        Self { value, derivative }
    }

    /// Square root, using `d/dx √x = 1 / (2√x)`.
    ///
    /// At `x = 0` this divides by zero, giving an infinite derivative (or NaN
    /// when the incoming derivative is itself `0`).
    pub fn sqrt(self) -> Self {
        let sqrt_val = self.value.sqrt();
        Self {
            value: sqrt_val,
            derivative: self.derivative / (2.0 * sqrt_val),
        }
    }

    /// Absolute value.
    ///
    /// For `value >= 0` (including `-0.0`) the derivative passes through
    /// unchanged; otherwise (negative or NaN value) it is negated.
    pub fn abs(self) -> Self {
        Self {
            value: self.value.abs(),
            derivative: if self.value >= 0.0 {
                self.derivative
            } else {
                -self.derivative
            },
        }
    }

    /// Returns whichever operand has the larger value, derivative included.
    ///
    /// Ties keep `self`; if either value is NaN the comparison is false and
    /// `other` is returned.
    pub fn max(self, other: Self) -> Self {
        if self.value >= other.value {
            self
        } else {
            other
        }
    }

    /// Returns whichever operand has the smaller value, derivative included.
    ///
    /// Ties keep `self`; if either value is NaN the comparison is false and
    /// `other` is returned.
    pub fn min(self, other: Self) -> Self {
        if self.value <= other.value {
            self
        } else {
            other
        }
    }
}

153impl Add for Dual {
156 type Output = Self;
157
158 fn add(self, other: Self) -> Self {
159 Self {
160 value: self.value + other.value,
161 derivative: self.derivative + other.derivative,
162 }
163 }
164}
165
166impl Add<f64> for Dual {
167 type Output = Self;
168
169 fn add(self, scalar: f64) -> Self {
170 Self {
171 value: self.value + scalar,
172 derivative: self.derivative,
173 }
174 }
175}
176
177impl Add<Dual> for f64 {
178 type Output = Dual;
179
180 fn add(self, dual: Dual) -> Dual {
181 dual + self
182 }
183}
184
185impl Sub for Dual {
186 type Output = Self;
187
188 fn sub(self, other: Self) -> Self {
189 Self {
190 value: self.value - other.value,
191 derivative: self.derivative - other.derivative,
192 }
193 }
194}
195
196impl Sub<f64> for Dual {
197 type Output = Self;
198
199 fn sub(self, scalar: f64) -> Self {
200 Self {
201 value: self.value - scalar,
202 derivative: self.derivative,
203 }
204 }
205}
206
207impl Sub<Dual> for f64 {
208 type Output = Dual;
209
210 fn sub(self, dual: Dual) -> Dual {
211 Dual {
212 value: self - dual.value,
213 derivative: -dual.derivative,
214 }
215 }
216}
217
218impl Mul for Dual {
219 type Output = Self;
220
221 fn mul(self, other: Self) -> Self {
222 Self {
223 value: self.value * other.value,
224 derivative: self.derivative * other.value + self.value * other.derivative,
225 }
226 }
227}
228
229impl Mul<f64> for Dual {
230 type Output = Self;
231
232 fn mul(self, scalar: f64) -> Self {
233 Self {
234 value: self.value * scalar,
235 derivative: self.derivative * scalar,
236 }
237 }
238}
239
240impl Mul<Dual> for f64 {
241 type Output = Dual;
242
243 fn mul(self, dual: Dual) -> Dual {
244 dual * self
245 }
246}
247
248impl Div for Dual {
249 type Output = Self;
250
251 fn div(self, other: Self) -> Self {
252 let denom = other.value * other.value;
253
254 let value = if other.value == 0.0 {
256 if self.value == 0.0 {
257 f64::NAN
258 }
259 else if self.value > 0.0 {
261 f64::INFINITY
262 } else {
263 f64::NEG_INFINITY
264 }
265 } else {
266 self.value / other.value
267 };
268
269 let derivative = if denom == 0.0 {
270 if other.value == 0.0 && self.derivative == 0.0 && other.derivative == 0.0 {
272 f64::NAN
273 } else {
274 f64::INFINITY
275 }
276 } else {
277 (self.derivative * other.value - self.value * other.derivative) / denom
278 };
279
280 Self { value, derivative }
281 }
282}
283
284impl Div<f64> for Dual {
285 type Output = Self;
286
287 fn div(self, scalar: f64) -> Self {
288 if scalar == 0.0 {
289 Self {
291 value: if self.value == 0.0 {
292 f64::NAN
293 } else if self.value > 0.0 {
294 f64::INFINITY
295 } else {
296 f64::NEG_INFINITY
297 },
298 derivative: if self.derivative == 0.0 {
299 f64::NAN
300 } else {
301 f64::INFINITY
302 },
303 }
304 } else {
305 Self {
306 value: self.value / scalar,
307 derivative: self.derivative / scalar,
308 }
309 }
310 }
311}
312
313impl Div<Dual> for f64 {
314 type Output = Dual;
315
316 fn div(self, dual: Dual) -> Dual {
317 Dual::constant(self) / dual
318 }
319}
320
321impl Neg for Dual {
322 type Output = Self;
323
324 fn neg(self) -> Self {
325 Self {
326 value: -self.value,
327 derivative: -self.derivative,
328 }
329 }
330}
331
332impl From<f64> for Dual {
334 fn from(value: f64) -> Self {
335 Self::constant(value)
336 }
337}
338
339impl From<Dual> for f64 {
340 fn from(dual: Dual) -> Self {
341 dual.value
342 }
343}
344
345impl PartialEq for Dual {
347 fn eq(&self, other: &Self) -> bool {
348 self.value == other.value
349 }
350}
351
352impl PartialOrd for Dual {
353 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
354 self.value.partial_cmp(&other.value)
355 }
356}
357
/// Common interface for dual-number types that carry a single derivative.
pub trait DualNumber: Clone + Copy {
    /// Returns the real part (the function value).
    fn value(self) -> f64;

    /// Returns the infinitesimal part (the propagated derivative).
    fn derivative(self) -> f64;

    /// Builds a dual number from an explicit value and derivative.
    fn new(value: f64, derivative: f64) -> Self;

    /// Builds a constant: a dual number whose derivative is zero.
    fn constant(value: f64) -> Self;

    /// Builds an independent variable: a dual number seeded with derivative one.
    fn variable(value: f64) -> Self;
}

376impl DualNumber for Dual {
377 fn value(self) -> f64 {
378 self.value
379 }
380
381 fn derivative(self) -> f64 {
382 self.derivative
383 }
384
385 fn new(value: f64, derivative: f64) -> Self {
386 Self::new(value, derivative)
387 }
388
389 fn constant(value: f64) -> Self {
390 Self::constant(value)
391 }
392
393 fn variable(value: f64) -> Self {
394 Self::variable(value)
395 }
396}
397
398#[derive(Debug, Clone)]
400pub struct MultiDual {
401 value: f64,
403 derivatives: Array1<f64>,
405}
406
407impl MultiDual {
408 pub fn new(value: f64, derivatives: Array1<f64>) -> Self {
410 Self { value, derivatives }
411 }
412
413 pub fn constant(value: f64, nvars: usize) -> Self {
415 Self {
416 value,
417 derivatives: Array1::zeros(nvars),
418 }
419 }
420
421 pub fn variable(value: f64, var_index: usize, nvars: usize) -> Self {
423 let mut derivatives = Array1::zeros(nvars);
424 derivatives[var_index] = 1.0;
425 Self { value, derivatives }
426 }
427
428 pub fn value(&self) -> f64 {
430 self.value
431 }
432
433 pub fn gradient(&self) -> &Array1<f64> {
435 &self.derivatives
436 }
437
438 pub fn partial(&self, index: usize) -> f64 {
440 self.derivatives[index]
441 }
442}
443
444impl Add for MultiDual {
446 type Output = Self;
447
448 fn add(self, other: Self) -> Self {
449 Self {
450 value: self.value + other.value,
451 derivatives: &self.derivatives + &other.derivatives,
452 }
453 }
454}
455
456impl Mul for MultiDual {
457 type Output = Self;
458
459 fn mul(self, other: Self) -> Self {
460 Self {
461 value: self.value * other.value,
462 derivatives: &self.derivatives * other.value + &other.derivatives * self.value,
463 }
464 }
465}
466
467impl Mul<f64> for MultiDual {
468 type Output = Self;
469
470 fn mul(self, scalar: f64) -> Self {
471 Self {
472 value: self.value * scalar,
473 derivatives: &self.derivatives * scalar,
474 }
475 }
476}
477
478#[allow(dead_code)]
480pub fn create_dual_variables(x: &ArrayView1<f64>) -> Vec<Dual> {
481 x.iter().map(|&xi| Dual::variable(xi)).collect()
482}
483
484#[allow(dead_code)]
486pub fn create_multi_dual_variables(x: &ArrayView1<f64>) -> Vec<MultiDual> {
487 let n = x.len();
488 x.iter()
489 .enumerate()
490 .map(|(i, &xi)| MultiDual::variable(xi, i, n))
491 .collect()
492}
493
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Absolute tolerance shared by every comparison below.
    const EPS: f64 = 1e-10;

    #[test]
    fn test_dual_arithmetic() {
        let a = Dual::new(2.0, 1.0);
        let b = Dual::new(3.0, 0.5);

        // Sum rule: (a + b)' = a' + b'
        let sum = a + b;
        assert_abs_diff_eq!(sum.value(), 5.0, epsilon = EPS);
        assert_abs_diff_eq!(sum.derivative(), 1.5, epsilon = EPS);

        // Product rule: (a * b)' = a'·b + a·b' = 1·3 + 2·0.5
        let product = a * b;
        assert_abs_diff_eq!(product.value(), 6.0, epsilon = EPS);
        assert_abs_diff_eq!(product.derivative(), 4.0, epsilon = EPS);

        // Quotient rule: (a / b)' = (a'·b - a·b') / b²
        let quotient = a / b;
        assert_abs_diff_eq!(quotient.value(), 2.0 / 3.0, epsilon = EPS);
        assert_abs_diff_eq!(
            quotient.derivative(),
            (1.0 * 3.0 - 2.0 * 0.5) / (3.0 * 3.0),
            epsilon = EPS
        );
    }

    #[test]
    fn test_dual_functions() {
        // d/dx e^x = e^x, evaluated at x = 1
        let exp_at_one = Dual::variable(1.0).exp();
        assert_abs_diff_eq!(exp_at_one.value(), std::f64::consts::E, epsilon = EPS);
        assert_abs_diff_eq!(exp_at_one.derivative(), std::f64::consts::E, epsilon = EPS);

        // d/dx sin x = cos x, evaluated at x = 0
        let sin_at_zero = Dual::variable(0.0).sin();
        assert_abs_diff_eq!(sin_at_zero.value(), 0.0, epsilon = EPS);
        assert_abs_diff_eq!(sin_at_zero.derivative(), 1.0, epsilon = EPS);

        // d/dx x² = 2x, evaluated at x = 3
        let square_at_three = Dual::variable(3.0).powi(2);
        assert_abs_diff_eq!(square_at_three.value(), 9.0, epsilon = EPS);
        assert_abs_diff_eq!(square_at_three.derivative(), 6.0, epsilon = EPS);
    }

    #[test]
    fn test_multi_dual() {
        let x = MultiDual::variable(2.0, 0, 2);
        let y = MultiDual::variable(3.0, 1, 2);

        // ∂(xy)/∂x = y = 3 and ∂(xy)/∂y = x = 2
        let xy = x * y;
        assert_abs_diff_eq!(xy.value(), 6.0, epsilon = EPS);
        assert_abs_diff_eq!(xy.partial(0), 3.0, epsilon = EPS);
        assert_abs_diff_eq!(xy.partial(1), 2.0, epsilon = EPS);
    }

    #[test]
    fn test_create_dual_variables() {
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let duals = create_dual_variables(&x.view());

        assert_eq!(duals.len(), 3);
        for (dual, expected) in duals.iter().zip(&[1.0, 2.0, 3.0]) {
            assert_abs_diff_eq!(dual.value(), *expected, epsilon = EPS);
        }

        // Every element is seeded as a variable with derivative 1.
        for dual in &duals {
            assert_abs_diff_eq!(dual.derivative(), 1.0, epsilon = EPS);
        }
    }
}
573}