scirs2_optimize/automatic_differentiation/
forward_mode.rs1use crate::automatic_differentiation::dual_numbers::{Dual, MultiDual};
8use crate::error::OptimizeError;
9use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
10
/// Option bundle for the forward-mode differentiation routines.
///
/// NOTE(review): no function visible in this file reads these fields, so the
/// per-field notes below are inferred from the field names only — confirm
/// against the callers before relying on them.
#[derive(Debug, Clone)]
pub struct ForwardADOptions {
    /// Presumably requests that a gradient be computed. Default: `true`.
    pub compute_gradient: bool,
    /// Presumably requests that Hessian information be computed. Default: `false`.
    pub compute_hessian: bool,
    /// Presumably a finite-difference step used for Hessian estimates
    /// (TODO confirm). Default: `1e-8`.
    pub h_hessian: f64,
    /// Presumably selects second-order dual numbers over finite differences
    /// (TODO confirm). Default: `false`.
    pub use_second_order: bool,
}

impl Default for ForwardADOptions {
    /// Builds the default configuration: only `compute_gradient` is switched
    /// on and `h_hessian` is `1e-8`.
    fn default() -> Self {
        ForwardADOptions {
            use_second_order: false,
            h_hessian: 1e-8,
            compute_hessian: false,
            compute_gradient: true,
        }
    }
}
34
35#[allow(dead_code)]
37pub fn forward_gradient<F>(func: F, x: &ArrayView1<f64>) -> Result<Array1<f64>, OptimizeError>
38where
39 F: Fn(&ArrayView1<f64>) -> f64,
40{
41 let n = x.len();
42 let mut gradient = Array1::zeros(n);
43
44 for i in 0..n {
46 let mut x_dual = Vec::with_capacity(n);
48 for j in 0..n {
49 if i == j {
50 x_dual.push(Dual::variable(x[j]));
51 } else {
52 x_dual.push(Dual::constant(x[j]));
53 }
54 }
55
56 let x_values: Vec<f64> = x_dual.iter().map(|d| d.value()).collect();
58 let _x_array = Array1::from_vec(x_values);
59
60 let h = 1e-8;
63 let mut x_plus = x.to_owned();
64 x_plus[i] += h;
65 let f_plus = func(&x_plus.view());
66
67 let mut x_minus = x.to_owned();
68 x_minus[i] -= h;
69 let f_minus = func(&x_minus.view());
70
71 gradient[i] = (f_plus - f_minus) / (2.0 * h);
72 }
73
74 Ok(gradient)
75}
76
77#[allow(dead_code)]
79pub fn forward_hessian_diagonal<F>(
80 func: F,
81 x: &ArrayView1<f64>,
82) -> Result<Array1<f64>, OptimizeError>
83where
84 F: Fn(&ArrayView1<f64>) -> f64,
85{
86 let n = x.len();
87 let mut hessian_diagonal = Array1::zeros(n);
88
89 let h = 1e-5; for i in 0..n {
93 let mut x_plus = x.to_owned();
94 x_plus[i] += h;
95 let f_plus = func(&x_plus.view());
96
97 let f_center = func(x);
98
99 let mut x_minus = x.to_owned();
100 x_minus[i] -= h;
101 let f_minus = func(&x_minus.view());
102
103 hessian_diagonal[i] = (f_plus - 2.0 * f_center + f_minus) / (h * h);
105 }
106
107 Ok(hessian_diagonal)
108}
109
/// A scalar bundled with its first and second derivative along one seeded
/// direction.
///
/// Every method applies the chain rule to all three components, so evaluating
/// an expression on a [`SecondOrderDual::variable`] yields the expression's
/// value, first derivative and second derivative in one pass.
#[derive(Debug, Clone, Copy)]
pub struct SecondOrderDual {
    /// Function value.
    value: f64,
    /// First derivative.
    first: f64,
    /// Second derivative.
    second: f64,
}

impl SecondOrderDual {
    /// Creates a number from explicit value, first- and second-derivative parts.
    pub fn new(value: f64, first: f64, second: f64) -> Self {
        Self {
            value,
            first,
            second,
        }
    }

    /// Creates a constant: both derivative parts are zero.
    pub fn constant(value: f64) -> Self {
        Self {
            value,
            first: 0.0,
            second: 0.0,
        }
    }

    /// Creates the differentiation variable: first derivative `1`, second
    /// derivative `0`.
    pub fn variable(value: f64) -> Self {
        Self {
            value,
            first: 1.0,
            second: 0.0,
        }
    }

    /// Returns the value part.
    pub fn value(self) -> f64 {
        self.value
    }

    /// Returns the first-derivative part.
    pub fn first_derivative(self) -> f64 {
        self.first
    }

    /// Returns the second-derivative part.
    pub fn second_derivative(self) -> f64 {
        self.second
    }

    /// Exponential: `(eᵘ)' = u'·eᵘ`, `(eᵘ)'' = u''·eᵘ + u'²·eᵘ`.
    pub fn exp(self) -> Self {
        let exp_val = self.value.exp();
        Self {
            value: exp_val,
            first: self.first * exp_val,
            second: self.second * exp_val + self.first * self.first * exp_val,
        }
    }

    /// Natural logarithm: `(ln u)' = u'/u`, `(ln u)'' = (u''·u − u'²) / u²`.
    ///
    /// No domain check is performed; the components follow ordinary `f64`
    /// rules for non-positive values.
    // The grouping `second * value - first * first` is the intended quotient-rule
    // numerator, not a typo, hence the lint exemption.
    #[allow(clippy::suspicious_operation_groupings)]
    pub fn ln(self) -> Self {
        Self {
            value: self.value.ln(),
            first: self.first / self.value,
            second: (self.second * self.value - self.first * self.first)
                / (self.value * self.value),
        }
    }

    /// Integer power: `(uⁿ)' = n·uⁿ⁻¹·u'` and
    /// `(uⁿ)'' = n·uⁿ⁻¹·u'' + n(n−1)·uⁿ⁻²·u'²`.
    ///
    /// Negative exponents are supported. Terms whose coefficient is
    /// identically zero (`n` for `n == 0`, `n(n−1)` for `n ∈ {0, 1}`) are
    /// skipped rather than multiplied out, so a zero value does not produce
    /// `0 · ∞ = NaN` through a negative power of zero.
    pub fn powi(self, n: i32) -> Self {
        let n_f64 = n as f64;

        // Contributions proportional to n·uⁿ⁻¹; they vanish for n == 0.
        let (first, second_linear) = if n == 0 {
            (0.0, 0.0)
        } else {
            let value_pow_n_minus_1 = self.value.powi(n - 1);
            (
                self.first * n_f64 * value_pow_n_minus_1,
                self.second * n_f64 * value_pow_n_minus_1,
            )
        };

        // Contribution proportional to n(n−1)·uⁿ⁻²; it vanishes only for
        // n ∈ {0, 1} and must be kept for negative exponents.
        let second_curvature = if n == 0 || n == 1 {
            0.0
        } else {
            self.first * self.first * n_f64 * (n_f64 - 1.0) * self.value.powi(n - 2)
        };

        Self {
            value: self.value.powi(n),
            first,
            second: second_linear + second_curvature,
        }
    }

    /// Sine: `(sin u)' = u'·cos u`, `(sin u)'' = u''·cos u − u'²·sin u`.
    pub fn sin(self) -> Self {
        let sin_val = self.value.sin();
        let cos_val = self.value.cos();
        Self {
            value: sin_val,
            first: self.first * cos_val,
            second: self.second * cos_val - self.first * self.first * sin_val,
        }
    }

    /// Cosine: `(cos u)' = −u'·sin u`, `(cos u)'' = −u''·sin u − u'²·cos u`.
    pub fn cos(self) -> Self {
        let sin_val = self.value.sin();
        let cos_val = self.value.cos();
        Self {
            value: cos_val,
            first: -self.first * sin_val,
            second: -self.second * sin_val - self.first * self.first * cos_val,
        }
    }
}
222
223impl std::ops::Add for SecondOrderDual {
225 type Output = Self;
226
227 fn add(self, other: Self) -> Self {
228 Self {
229 value: self.value + other.value,
230 first: self.first + other.first,
231 second: self.second + other.second,
232 }
233 }
234}
235
236impl std::ops::Sub for SecondOrderDual {
237 type Output = Self;
238
239 fn sub(self, other: Self) -> Self {
240 Self {
241 value: self.value - other.value,
242 first: self.first - other.first,
243 second: self.second - other.second,
244 }
245 }
246}
247
248impl std::ops::Mul for SecondOrderDual {
249 type Output = Self;
250
251 fn mul(self, other: Self) -> Self {
252 Self {
253 value: self.value * other.value,
254 first: self.first * other.value + self.value * other.first,
255 second: self.second * other.value
256 + 2.0 * self.first * other.first
257 + self.value * other.second,
258 }
259 }
260}
261
262impl std::ops::Mul<f64> for SecondOrderDual {
263 type Output = Self;
264
265 fn mul(self, scalar: f64) -> Self {
266 Self {
267 value: self.value * scalar,
268 first: self.first * scalar,
269 second: self.second * scalar,
270 }
271 }
272}
273
274impl std::ops::Div for SecondOrderDual {
275 type Output = Self;
276
277 fn div(self, other: Self) -> Self {
278 let denom = other.value;
279 let denom_sq = denom * denom;
280 let denom_cb = denom_sq * denom;
281
282 Self {
283 value: self.value / denom,
284 first: (self.first * denom - self.value * other.first) / denom_sq,
285 second: (self.second * denom_sq - 2.0 * self.first * other.first * denom
286 + 2.0 * self.value * other.first * other.first
287 - self.value * other.second * denom)
288 / denom_cb,
289 }
290 }
291}
292
293#[allow(dead_code)]
295pub fn forward_hessian_diagonal_exact<F>(
296 func: F,
297 x: &ArrayView1<f64>,
298) -> Result<Array1<f64>, OptimizeError>
299where
300 F: Fn(&[SecondOrderDual]) -> SecondOrderDual,
301{
302 let n = x.len();
303 let mut hessian_diagonal = Array1::zeros(n);
304
305 for i in 0..n {
307 let mut x_dual = Vec::with_capacity(n);
309 for j in 0..n {
310 if i == j {
311 x_dual.push(SecondOrderDual::variable(x[j]));
312 } else {
313 x_dual.push(SecondOrderDual::constant(x[j]));
314 }
315 }
316
317 let result = func(&x_dual);
318 hessian_diagonal[i] = result.second_derivative();
319 }
320
321 Ok(hessian_diagonal)
322}
323
324#[allow(dead_code)]
326pub fn forward_gradient_multi<F>(func: F, x: &ArrayView1<f64>) -> Result<Array1<f64>, OptimizeError>
327where
328 F: Fn(&[MultiDual]) -> MultiDual,
329{
330 let n = x.len();
331
332 let x_multi: Vec<MultiDual> = x
334 .iter()
335 .enumerate()
336 .map(|(i, &xi)| MultiDual::variable(xi, i, n))
337 .collect();
338
339 let result = func(&x_multi);
340 Ok(result.gradient().clone())
341}
342
343#[allow(dead_code)]
345pub fn forward_jacobian<F>(
346 func: F,
347 x: &ArrayView1<f64>,
348 output_dim: usize,
349) -> Result<Array2<f64>, OptimizeError>
350where
351 F: Fn(&ArrayView1<f64>) -> Array1<f64>,
352{
353 let n = x.len();
354 let mut jacobian = Array2::zeros((output_dim, n));
355
356 for j in 0..n {
358 let h = 1e-8;
359 let mut x_plus = x.to_owned();
360 x_plus[j] += h;
361 let f_plus = func(&x_plus.view());
362
363 let mut x_minus = x.to_owned();
364 x_minus[j] -= h;
365 let f_minus = func(&x_minus.view());
366
367 for i in 0..output_dim {
368 jacobian[[i, j]] = (f_plus[i] - f_minus[i]) / (2.0 * h);
369 }
370 }
371
372 Ok(jacobian)
373}
374
/// Heuristic deciding whether forward mode is considered efficient for a
/// problem of the given dimensions.
///
/// Returns `true` when `input_dim` is at most 10, or when `input_dim` is at
/// most 50 and does not exceed `output_dim`; otherwise returns `false`.
#[allow(dead_code)]
pub fn is_forward_mode_efficient(input_dim: usize, output_dim: usize) -> bool {
    match input_dim {
        // Few inputs: always accepted, whatever the output size.
        0..=10 => true,
        // Moderate input count: accepted only with at least as many outputs.
        11..=50 => output_dim >= input_dim,
        _ => false,
    }
}
382
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// f(x, y) = x² + x·y + 2·y², the fixture shared by the gradient and
    /// Hessian-diagonal tests.
    fn quadratic(p: &ArrayView1<f64>) -> f64 {
        p[0] * p[0] + p[0] * p[1] + 2.0 * p[1] * p[1]
    }

    #[test]
    fn test_forward_gradient() {
        let point = Array1::from_vec(vec![1.0, 2.0]);
        let grad = forward_gradient(quadratic, &point.view()).unwrap();

        // At (1, 2): ∂f/∂x = 2x + y = 4 and ∂f/∂y = x + 4y = 9.
        let expected = [4.0, 9.0];
        for (idx, want) in expected.iter().enumerate() {
            assert_abs_diff_eq!(grad[idx], *want, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_forward_hessian_diagonal() {
        let point = Array1::from_vec(vec![1.0, 2.0]);
        let hess_diag = forward_hessian_diagonal(quadratic, &point.view()).unwrap();

        // ∂²f/∂x² = 2 and ∂²f/∂y² = 4 everywhere.
        let expected = [2.0, 4.0];
        for (idx, want) in expected.iter().enumerate() {
            assert_abs_diff_eq!(hess_diag[idx], *want, epsilon = 1e-4);
        }
    }

    #[test]
    fn test_second_order_dual_arithmetic() {
        // Product of the seeded value 2 (first derivative 1) with the constant 3.
        let lhs = SecondOrderDual::new(2.0, 1.0, 0.0);
        let rhs = SecondOrderDual::new(3.0, 0.0, 0.0);
        let product = lhs * rhs;
        assert_abs_diff_eq!(product.value(), 6.0, epsilon = 1e-10);
        assert_abs_diff_eq!(product.first_derivative(), 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(product.second_derivative(), 0.0, epsilon = 1e-10);

        // x² at x = 2: value 4, first derivative 2x = 4, second derivative 2.
        let square = SecondOrderDual::variable(2.0).powi(2);
        assert_abs_diff_eq!(square.value(), 4.0, epsilon = 1e-10);
        assert_abs_diff_eq!(square.first_derivative(), 4.0, epsilon = 1e-10);
        assert_abs_diff_eq!(square.second_derivative(), 2.0, epsilon = 1e-10);
    }

    #[test]
    fn test_forward_jacobian() {
        // F(x, y) = [x² + y, x·y, y²]
        let func = |p: &ArrayView1<f64>| -> Array1<f64> {
            Array1::from_vec(vec![p[0] * p[0] + p[1], p[0] * p[1], p[1] * p[1]])
        };

        let point = Array1::from_vec(vec![2.0, 3.0]);
        let jac = forward_jacobian(func, &point.view(), 3).unwrap();

        // Rows are [2x, 1], [y, x], [0, 2y] evaluated at (2, 3).
        let expected = [[4.0, 1.0], [3.0, 2.0], [0.0, 6.0]];
        for (row, wants) in expected.iter().enumerate() {
            for (col, want) in wants.iter().enumerate() {
                assert_abs_diff_eq!(jac[[row, col]], *want, epsilon = 1e-6);
            }
        }
    }

    #[test]
    fn test_is_forward_mode_efficient() {
        // Accepted dimension pairs.
        for &(inputs, outputs) in [(3, 1), (5, 10)].iter() {
            assert!(is_forward_mode_efficient(inputs, outputs));
        }

        // Many inputs with a single output is rejected.
        assert!(!is_forward_mode_efficient(100, 1));
    }
}