scirs2_optimize/automatic_differentiation/
dual_numbers.rs

1//! Dual numbers for forward-mode automatic differentiation
2//!
3//! This module implements dual numbers, which are used for forward-mode automatic
4//! differentiation. Dual numbers extend real numbers with an infinitesimal part
5//! that tracks derivatives.
6
7use scirs2_core::ndarray::{Array1, ArrayView1};
8use std::ops::{Add, Div, Mul, Neg, Sub};
9
/// Dual number for forward-mode automatic differentiation
///
/// A dual number has the form `a + b*ε` where `ε² = 0`. Propagating the pair
/// `(a, b)` through a computation produces the function value in `a` and the
/// derivative in `b` simultaneously.
#[derive(Debug, Clone, Copy)]
pub struct Dual {
    /// Real part (function value)
    value: f64,
    /// Dual part (derivative)
    derivative: f64,
}

impl Dual {
    /// Create a dual number from an explicit value and derivative.
    pub fn new(value: f64, derivative: f64) -> Self {
        Self { value, derivative }
    }

    /// Create a dual number representing a constant (derivative = 0).
    pub fn constant(value: f64) -> Self {
        Self {
            value,
            derivative: 0.0,
        }
    }

    /// Create a dual number representing an independent variable (derivative = 1).
    pub fn variable(value: f64) -> Self {
        Self {
            value,
            derivative: 1.0,
        }
    }

    /// Get the real part (function value).
    pub fn value(self) -> f64 {
        self.value
    }

    /// Get the dual part (derivative).
    pub fn derivative(self) -> f64 {
        self.derivative
    }

    /// Compute the sine: `d/dx sin(u) = u' * cos(u)`.
    pub fn sin(self) -> Self {
        Self {
            value: self.value.sin(),
            derivative: self.derivative * self.value.cos(),
        }
    }

    /// Compute the cosine: `d/dx cos(u) = -u' * sin(u)`.
    pub fn cos(self) -> Self {
        Self {
            value: self.value.cos(),
            derivative: -self.derivative * self.value.sin(),
        }
    }

    /// Compute the tangent: `d/dx tan(u) = u' / cos²(u)`.
    pub fn tan(self) -> Self {
        let cos_val = self.value.cos();
        Self {
            value: self.value.tan(),
            derivative: self.derivative / (cos_val * cos_val),
        }
    }

    /// Compute the exponential: `d/dx exp(u) = u' * exp(u)`.
    pub fn exp(self) -> Self {
        let exp_val = self.value.exp();
        Self {
            value: exp_val,
            derivative: self.derivative * exp_val,
        }
    }

    /// Compute the natural logarithm: `d/dx ln(u) = u' / u`.
    ///
    /// No domain check is performed; non-positive values follow the IEEE
    /// results of `f64::ln` and of the division.
    pub fn ln(self) -> Self {
        Self {
            value: self.value.ln(),
            derivative: self.derivative / self.value,
        }
    }

    /// Compute an integer power (`self^n`): `d/dx u^n = u' * n * u^(n-1)`.
    ///
    /// For `n == 0` the result is the constant 1, so the derivative is exactly
    /// zero. The general formula would instead evaluate `0 * u^(-1)`, which is
    /// NaN at `u == 0`.
    pub fn powi(self, n: i32) -> Self {
        let value = self.value.powi(n);
        let derivative = if n == 0 {
            0.0
        } else {
            let n_f64 = f64::from(n);
            // `n - 1` overflows for `i32::MIN`; switch to the floating-point
            // exponent in that single case (it is exactly representable).
            let lowered = match n.checked_sub(1) {
                Some(m) => self.value.powi(m),
                None => self.value.powf(n_f64 - 1.0),
            };
            self.derivative * n_f64 * lowered
        };
        Self { value, derivative }
    }

    /// Compute a real power (`self^p`): `d/dx u^p = u' * p * u^(p-1)`.
    ///
    /// For `p == 0.0` the result is the constant 1, so the derivative is
    /// exactly zero (avoiding `0 * inf = NaN` at `u == 0`).
    pub fn powf(self, p: f64) -> Self {
        let value = self.value.powf(p);
        let derivative = if p == 0.0 {
            0.0
        } else {
            self.derivative * p * self.value.powf(p - 1.0)
        };
        Self { value, derivative }
    }

    /// Compute the square root: `d/dx sqrt(u) = u' / (2 * sqrt(u))`.
    pub fn sqrt(self) -> Self {
        let sqrt_val = self.value.sqrt();
        Self {
            value: sqrt_val,
            derivative: self.derivative / (2.0 * sqrt_val),
        }
    }

    /// Compute the absolute value.
    ///
    /// The derivative keeps its sign when the value is `>= 0.0` (this includes
    /// zero, where `|x|` is not differentiable) and is negated otherwise.
    pub fn abs(self) -> Self {
        Self {
            value: self.value.abs(),
            derivative: if self.value >= 0.0 {
                self.derivative
            } else {
                -self.derivative
            },
        }
    }

    /// Return whichever dual number has the larger value (derivative included).
    ///
    /// On ties `self` is returned; if the comparison is false (for example
    /// because a value is NaN) `other` is returned.
    pub fn max(self, other: Self) -> Self {
        if self.value >= other.value {
            self
        } else {
            other
        }
    }

    /// Return whichever dual number has the smaller value (derivative included).
    ///
    /// On ties `self` is returned; if the comparison is false (for example
    /// because a value is NaN) `other` is returned.
    pub fn min(self, other: Self) -> Self {
        if self.value <= other.value {
            self
        } else {
            other
        }
    }
}
152
153// Arithmetic operations for dual numbers
154
155impl Add for Dual {
156    type Output = Self;
157
158    fn add(self, other: Self) -> Self {
159        Self {
160            value: self.value + other.value,
161            derivative: self.derivative + other.derivative,
162        }
163    }
164}
165
166impl Add<f64> for Dual {
167    type Output = Self;
168
169    fn add(self, scalar: f64) -> Self {
170        Self {
171            value: self.value + scalar,
172            derivative: self.derivative,
173        }
174    }
175}
176
177impl Add<Dual> for f64 {
178    type Output = Dual;
179
180    fn add(self, dual: Dual) -> Dual {
181        dual + self
182    }
183}
184
185impl Sub for Dual {
186    type Output = Self;
187
188    fn sub(self, other: Self) -> Self {
189        Self {
190            value: self.value - other.value,
191            derivative: self.derivative - other.derivative,
192        }
193    }
194}
195
196impl Sub<f64> for Dual {
197    type Output = Self;
198
199    fn sub(self, scalar: f64) -> Self {
200        Self {
201            value: self.value - scalar,
202            derivative: self.derivative,
203        }
204    }
205}
206
207impl Sub<Dual> for f64 {
208    type Output = Dual;
209
210    fn sub(self, dual: Dual) -> Dual {
211        Dual {
212            value: self - dual.value,
213            derivative: -dual.derivative,
214        }
215    }
216}
217
218impl Mul for Dual {
219    type Output = Self;
220
221    fn mul(self, other: Self) -> Self {
222        Self {
223            value: self.value * other.value,
224            derivative: self.derivative * other.value + self.value * other.derivative,
225        }
226    }
227}
228
229impl Mul<f64> for Dual {
230    type Output = Self;
231
232    fn mul(self, scalar: f64) -> Self {
233        Self {
234            value: self.value * scalar,
235            derivative: self.derivative * scalar,
236        }
237    }
238}
239
240impl Mul<Dual> for f64 {
241    type Output = Dual;
242
243    fn mul(self, dual: Dual) -> Dual {
244        dual * self
245    }
246}
247
248impl Div for Dual {
249    type Output = Self;
250
251    fn div(self, other: Self) -> Self {
252        let denom = other.value * other.value;
253
254        // Protect against division by zero
255        let value = if other.value == 0.0 {
256            if self.value == 0.0 {
257                f64::NAN
258            }
259            // 0/0 is undefined
260            else if self.value > 0.0 {
261                f64::INFINITY
262            } else {
263                f64::NEG_INFINITY
264            }
265        } else {
266            self.value / other.value
267        };
268
269        let derivative = if denom == 0.0 {
270            // Handle derivative at division by zero
271            if other.value == 0.0 && self.derivative == 0.0 && other.derivative == 0.0 {
272                f64::NAN
273            } else {
274                f64::INFINITY
275            }
276        } else {
277            (self.derivative * other.value - self.value * other.derivative) / denom
278        };
279
280        Self { value, derivative }
281    }
282}
283
284impl Div<f64> for Dual {
285    type Output = Self;
286
287    fn div(self, scalar: f64) -> Self {
288        if scalar == 0.0 {
289            // Division by zero
290            Self {
291                value: if self.value == 0.0 {
292                    f64::NAN
293                } else if self.value > 0.0 {
294                    f64::INFINITY
295                } else {
296                    f64::NEG_INFINITY
297                },
298                derivative: if self.derivative == 0.0 {
299                    f64::NAN
300                } else {
301                    f64::INFINITY
302                },
303            }
304        } else {
305            Self {
306                value: self.value / scalar,
307                derivative: self.derivative / scalar,
308            }
309        }
310    }
311}
312
313impl Div<Dual> for f64 {
314    type Output = Dual;
315
316    fn div(self, dual: Dual) -> Dual {
317        Dual::constant(self) / dual
318    }
319}
320
321impl Neg for Dual {
322    type Output = Self;
323
324    fn neg(self) -> Self {
325        Self {
326            value: -self.value,
327            derivative: -self.derivative,
328        }
329    }
330}
331
332// Conversion traits
333impl From<f64> for Dual {
334    fn from(value: f64) -> Self {
335        Self::constant(value)
336    }
337}
338
339impl From<Dual> for f64 {
340    fn from(dual: Dual) -> Self {
341        dual.value
342    }
343}
344
345// Partial ordering for optimization algorithms
346impl PartialEq for Dual {
347    fn eq(&self, other: &Self) -> bool {
348        self.value == other.value
349    }
350}
351
352impl PartialOrd for Dual {
353    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
354        self.value.partial_cmp(&other.value)
355    }
356}
357
/// Common interface for dual-number types: three constructors plus accessors
/// for the value and derivative parts.
pub trait DualNumber: Clone + Copy {
    /// Build from an explicit value and derivative.
    fn new(value: f64, derivative: f64) -> Self;

    /// Build a constant (derivative = 0).
    fn constant(value: f64) -> Self;

    /// Build a variable (derivative = 1).
    fn variable(value: f64) -> Self;

    /// Return the value part.
    fn value(self) -> f64;

    /// Return the derivative part.
    fn derivative(self) -> f64;
}
375
376impl DualNumber for Dual {
377    fn value(self) -> f64 {
378        self.value
379    }
380
381    fn derivative(self) -> f64 {
382        self.derivative
383    }
384
385    fn new(value: f64, derivative: f64) -> Self {
386        Self::new(value, derivative)
387    }
388
389    fn constant(value: f64) -> Self {
390        Self::constant(value)
391    }
392
393    fn variable(value: f64) -> Self {
394        Self::variable(value)
395    }
396}
397
398/// Multi-dimensional dual number for computing gradients
399#[derive(Debug, Clone)]
400pub struct MultiDual {
401    /// Function value
402    value: f64,
403    /// Partial derivatives (gradient components)
404    derivatives: Array1<f64>,
405}
406
407impl MultiDual {
408    /// Create a new multi-dimensional dual number
409    pub fn new(value: f64, derivatives: Array1<f64>) -> Self {
410        Self { value, derivatives }
411    }
412
413    /// Create a constant multi-dual (all derivatives = 0)
414    pub fn constant(value: f64, nvars: usize) -> Self {
415        Self {
416            value,
417            derivatives: Array1::zeros(nvars),
418        }
419    }
420
421    /// Create a variable multi-dual (one derivative = 1, others = 0)
422    pub fn variable(value: f64, var_index: usize, nvars: usize) -> Self {
423        let mut derivatives = Array1::zeros(nvars);
424        derivatives[var_index] = 1.0;
425        Self { value, derivatives }
426    }
427
428    /// Get the function value
429    pub fn value(&self) -> f64 {
430        self.value
431    }
432
433    /// Get the gradient
434    pub fn gradient(&self) -> &Array1<f64> {
435        &self.derivatives
436    }
437
438    /// Get a specific partial derivative
439    pub fn partial(&self, index: usize) -> f64 {
440        self.derivatives[index]
441    }
442}
443
444// Arithmetic operations for MultiDual
445impl Add for MultiDual {
446    type Output = Self;
447
448    fn add(self, other: Self) -> Self {
449        Self {
450            value: self.value + other.value,
451            derivatives: &self.derivatives + &other.derivatives,
452        }
453    }
454}
455
456impl Mul for MultiDual {
457    type Output = Self;
458
459    fn mul(self, other: Self) -> Self {
460        Self {
461            value: self.value * other.value,
462            derivatives: &self.derivatives * other.value + &other.derivatives * self.value,
463        }
464    }
465}
466
467impl Mul<f64> for MultiDual {
468    type Output = Self;
469
470    fn mul(self, scalar: f64) -> Self {
471        Self {
472            value: self.value * scalar,
473            derivatives: &self.derivatives * scalar,
474        }
475    }
476}
477
478/// Create an array of dual numbers for gradient computation
479#[allow(dead_code)]
480pub fn create_dual_variables(x: &ArrayView1<f64>) -> Vec<Dual> {
481    x.iter().map(|&xi| Dual::variable(xi)).collect()
482}
483
484/// Create multi-dual variables for a given point
485#[allow(dead_code)]
486pub fn create_multi_dual_variables(x: &ArrayView1<f64>) -> Vec<MultiDual> {
487    let n = x.len();
488    x.iter()
489        .enumerate()
490        .map(|(i, &xi)| MultiDual::variable(xi, i, n))
491        .collect()
492}
493
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Sum, product and quotient rules on two generic dual numbers.
    #[test]
    fn test_dual_arithmetic() {
        let lhs = Dual::new(2.0, 1.0);
        let rhs = Dual::new(3.0, 0.5);

        // Addition: both parts add.
        let added = lhs + rhs;
        assert_abs_diff_eq!(added.value(), 5.0, epsilon = 1e-10);
        assert_abs_diff_eq!(added.derivative(), 1.5, epsilon = 1e-10);

        // Multiplication: 1*3 + 2*0.5 = 4.
        let multiplied = lhs * rhs;
        assert_abs_diff_eq!(multiplied.value(), 6.0, epsilon = 1e-10);
        assert_abs_diff_eq!(multiplied.derivative(), 4.0, epsilon = 1e-10);

        // Division: (u'v - uv') / v².
        let divided = lhs / rhs;
        let expected_slope = (1.0 * 3.0 - 2.0 * 0.5) / (3.0 * 3.0);
        assert_abs_diff_eq!(divided.value(), 2.0 / 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(divided.derivative(), expected_slope, epsilon = 1e-10);
    }

    /// Elementary functions evaluated on seeded variables.
    #[test]
    fn test_dual_functions() {
        // exp(x) at x = 1: value and slope are both e.
        let at_one = Dual::variable(1.0).exp();
        assert_abs_diff_eq!(at_one.value(), std::f64::consts::E, epsilon = 1e-10);
        assert_abs_diff_eq!(at_one.derivative(), std::f64::consts::E, epsilon = 1e-10);

        // sin(x) at x = 0: value 0, slope cos(0) = 1.
        let at_zero = Dual::variable(0.0).sin();
        assert_abs_diff_eq!(at_zero.value(), 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(at_zero.derivative(), 1.0, epsilon = 1e-10);

        // x² at x = 3: value 9, slope 2*3 = 6.
        let squared = Dual::variable(3.0).powi(2);
        assert_abs_diff_eq!(squared.value(), 9.0, epsilon = 1e-10);
        assert_abs_diff_eq!(squared.derivative(), 6.0, epsilon = 1e-10);
    }

    /// Gradient of x*y with respect to both variables.
    #[test]
    fn test_multi_dual() {
        let x = MultiDual::variable(2.0, 0, 2);
        let y = MultiDual::variable(3.0, 1, 2);

        let xy = x * y;
        assert_abs_diff_eq!(xy.value(), 6.0, epsilon = 1e-10);
        // ∂(xy)/∂x = y
        assert_abs_diff_eq!(xy.partial(0), 3.0, epsilon = 1e-10);
        // ∂(xy)/∂y = x
        assert_abs_diff_eq!(xy.partial(1), 2.0, epsilon = 1e-10);
    }

    /// Every element produced by the helper is a seeded variable.
    #[test]
    fn test_create_dual_variables() {
        let point = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let duals = create_dual_variables(&point.view());

        assert_eq!(duals.len(), 3);
        assert_abs_diff_eq!(duals[0].value(), 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(duals[1].value(), 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(duals[2].value(), 3.0, epsilon = 1e-10);

        // Variables carry derivative = 1.
        for dual in &duals {
            assert_abs_diff_eq!(dual.derivative(), 1.0, epsilon = 1e-10);
        }
    }
}