sim_lib_numbers_ad/implementation/
dual.rs1use std::{
5 array,
6 ops::{Add, Div, Mul, Neg, Sub},
7};
8
9use super::Scalarish;
10
/// A forward-mode dual number: a value together with `N` partial derivatives.
///
/// The arithmetic operators and elementary functions implemented for `Dual`
/// propagate the derivatives alongside the value, so the result of a
/// computation carries its partial derivative with respect to each seeded slot.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Dual<const N: usize> {
    /// Primal (ordinary) value.
    pub v: f64,
    /// Tangent part: `d[i]` is the partial derivative with respect to slot `i`.
    pub d: [f64; N],
}
39
40impl<const N: usize> Dual<N> {
41 pub fn cst(v: f64) -> Self {
43 Self { v, d: [0.0; N] }
44 }
45
46 pub fn var(v: f64, slot: usize) -> Self {
51 let mut d = [0.0; N];
52 if let Some(seed) = d.get_mut(slot) {
53 *seed = 1.0;
54 }
55 Self { v, d }
56 }
57}
58
59impl<const N: usize> Add for Dual<N> {
60 type Output = Self;
61
62 fn add(self, rhs: Self) -> Self::Output {
63 Self {
64 v: self.v + rhs.v,
65 d: array::from_fn(|index| self.d[index] + rhs.d[index]),
66 }
67 }
68}
69
70impl<const N: usize> Sub for Dual<N> {
71 type Output = Self;
72
73 fn sub(self, rhs: Self) -> Self::Output {
74 Self {
75 v: self.v - rhs.v,
76 d: array::from_fn(|index| self.d[index] - rhs.d[index]),
77 }
78 }
79}
80
81impl<const N: usize> Mul for Dual<N> {
82 type Output = Self;
83
84 fn mul(self, rhs: Self) -> Self::Output {
85 Self {
86 v: self.v * rhs.v,
87 d: array::from_fn(|index| self.d[index].mul_add(rhs.v, rhs.d[index] * self.v)),
88 }
89 }
90}
91
92impl<const N: usize> Div for Dual<N> {
93 type Output = Self;
94
95 fn div(self, rhs: Self) -> Self::Output {
96 let denom = rhs.v * rhs.v;
97 Self {
98 v: self.v / rhs.v,
99 d: array::from_fn(|index| (self.d[index] * rhs.v - self.v * rhs.d[index]) / denom),
100 }
101 }
102}
103
104impl<const N: usize> Neg for Dual<N> {
105 type Output = Self;
106
107 fn neg(self) -> Self::Output {
108 Self {
109 v: -self.v,
110 d: array::from_fn(|index| -self.d[index]),
111 }
112 }
113}
114
115impl<const N: usize> Scalarish for Dual<N> {
116 fn from_f64(x: f64) -> Self {
117 Self::cst(x)
118 }
119
120 fn sin(self) -> Self {
121 let cos_v = self.v.cos();
122 Self {
123 v: self.v.sin(),
124 d: array::from_fn(|index| self.d[index] * cos_v),
125 }
126 }
127
128 fn cos(self) -> Self {
129 let sin_v = self.v.sin();
130 Self {
131 v: self.v.cos(),
132 d: array::from_fn(|index| -self.d[index] * sin_v),
133 }
134 }
135
136 fn exp(self) -> Self {
137 let exp_v = self.v.exp();
138 Self {
139 v: exp_v,
140 d: array::from_fn(|index| self.d[index] * exp_v),
141 }
142 }
143
144 fn ln(self) -> Self {
145 Self {
146 v: self.v.ln(),
147 d: array::from_fn(|index| self.d[index] / self.v),
148 }
149 }
150
151 fn sqrt(self) -> Self {
152 let sqrt_v = self.v.sqrt();
153 Self {
154 v: sqrt_v,
155 d: array::from_fn(|index| self.d[index] / (2.0 * sqrt_v)),
156 }
157 }
158
159 fn recip(self) -> Self {
160 let denom = self.v * self.v;
161 Self {
162 v: self.v.recip(),
163 d: array::from_fn(|index| -self.d[index] / denom),
164 }
165 }
166}