Skip to main content

sim_lib_numbers_ad/implementation/
dual.rs

1//! Forward-mode dual numbers: a value paired with an `N`-slot gradient, with
2//! arithmetic and elementary functions that propagate derivatives.
3
4use std::{
5    array,
6    ops::{Add, Div, Mul, Neg, Sub},
7};
8
9use super::Scalarish;
10
/// A dual number for forward-mode differentiation: a value `v` together with
/// `N` partial derivatives `d`.
///
/// The operator and elementary-function implementations on this type update
/// the derivative slots together with the value, so running a computation on
/// `Dual<N>` yields the result and its partials along `N` seeded directions
/// in one evaluation. `N` is the number of independent inputs being tracked;
/// with `N = 0` the gradient array is empty and only the `f64` value is
/// carried.
///
/// # Examples
///
/// For `f(x) = x * x + 3 * x` the derivative is `f'(x) = 2x + 3`, so at
/// `x = 2` the value is `10` and the slope is `7`:
///
/// ```
/// use sim_lib_numbers_ad::Dual;
///
/// let x = Dual::<1>::var(2.0, 0);
/// let y = x * x + Dual::<1>::cst(3.0) * x;
/// assert_eq!(y.v, 10.0);
/// assert_eq!(y.d, [7.0]);
/// ```
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Dual<const N: usize> {
    /// Primal value.
    pub v: f64,
    /// Gradient; slot `i` holds the partial derivative along direction `i`.
    pub d: [f64; N],
}
39
40impl<const N: usize> Dual<N> {
41    /// Builds a constant: the given value with a zero gradient.
42    pub fn cst(v: f64) -> Self {
43        Self { v, d: [0.0; N] }
44    }
45
46    /// Builds an independent variable: the given value seeded with derivative
47    /// `1.0` in gradient slot `slot` (and zero elsewhere).
48    ///
49    /// A `slot` outside `0..N` seeds no direction, yielding a constant.
50    pub fn var(v: f64, slot: usize) -> Self {
51        let mut d = [0.0; N];
52        if let Some(seed) = d.get_mut(slot) {
53            *seed = 1.0;
54        }
55        Self { v, d }
56    }
57}
58
59impl<const N: usize> Add for Dual<N> {
60    type Output = Self;
61
62    fn add(self, rhs: Self) -> Self::Output {
63        Self {
64            v: self.v + rhs.v,
65            d: array::from_fn(|index| self.d[index] + rhs.d[index]),
66        }
67    }
68}
69
70impl<const N: usize> Sub for Dual<N> {
71    type Output = Self;
72
73    fn sub(self, rhs: Self) -> Self::Output {
74        Self {
75            v: self.v - rhs.v,
76            d: array::from_fn(|index| self.d[index] - rhs.d[index]),
77        }
78    }
79}
80
81impl<const N: usize> Mul for Dual<N> {
82    type Output = Self;
83
84    fn mul(self, rhs: Self) -> Self::Output {
85        Self {
86            v: self.v * rhs.v,
87            d: array::from_fn(|index| self.d[index].mul_add(rhs.v, rhs.d[index] * self.v)),
88        }
89    }
90}
91
92impl<const N: usize> Div for Dual<N> {
93    type Output = Self;
94
95    fn div(self, rhs: Self) -> Self::Output {
96        let denom = rhs.v * rhs.v;
97        Self {
98            v: self.v / rhs.v,
99            d: array::from_fn(|index| (self.d[index] * rhs.v - self.v * rhs.d[index]) / denom),
100        }
101    }
102}
103
104impl<const N: usize> Neg for Dual<N> {
105    type Output = Self;
106
107    fn neg(self) -> Self::Output {
108        Self {
109            v: -self.v,
110            d: array::from_fn(|index| -self.d[index]),
111        }
112    }
113}
114
115impl<const N: usize> Scalarish for Dual<N> {
116    fn from_f64(x: f64) -> Self {
117        Self::cst(x)
118    }
119
120    fn sin(self) -> Self {
121        let cos_v = self.v.cos();
122        Self {
123            v: self.v.sin(),
124            d: array::from_fn(|index| self.d[index] * cos_v),
125        }
126    }
127
128    fn cos(self) -> Self {
129        let sin_v = self.v.sin();
130        Self {
131            v: self.v.cos(),
132            d: array::from_fn(|index| -self.d[index] * sin_v),
133        }
134    }
135
136    fn exp(self) -> Self {
137        let exp_v = self.v.exp();
138        Self {
139            v: exp_v,
140            d: array::from_fn(|index| self.d[index] * exp_v),
141        }
142    }
143
144    fn ln(self) -> Self {
145        Self {
146            v: self.v.ln(),
147            d: array::from_fn(|index| self.d[index] / self.v),
148        }
149    }
150
151    fn sqrt(self) -> Self {
152        let sqrt_v = self.v.sqrt();
153        Self {
154            v: sqrt_v,
155            d: array::from_fn(|index| self.d[index] / (2.0 * sqrt_v)),
156        }
157    }
158
159    fn recip(self) -> Self {
160        let denom = self.v * self.v;
161        Self {
162            v: self.v.recip(),
163            d: array::from_fn(|index| -self.d[index] / denom),
164        }
165    }
166}