scirs2_optimize/automatic_differentiation/mod.rs

pub mod dual_numbers;
pub mod forward_mode;
pub mod reverse_mode;
pub mod tape;

pub use dual_numbers::{Dual, DualNumber};
pub use forward_mode::{forward_gradient, forward_hessian_diagonal, ForwardADOptions};
pub use reverse_mode::{reverse_gradient, reverse_hessian, ReverseADOptions};
pub use tape::{ComputationTape, TapeNode, Variable};

use crate::error::OptimizeError;
use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
/// Differentiation mode used to compute derivatives.
#[derive(Debug, Clone, Copy)]
pub enum ADMode {
    /// Forward mode: propagates dual numbers; efficient for few inputs.
    Forward,
    /// Reverse mode: tape-based; efficient for many inputs and scalar outputs.
    Reverse,
    /// Choose forward or reverse automatically based on the input dimension.
    Auto,
}

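/// Options controlling automatic differentiation.
///
/// A minimal configuration sketch (ignored doctest; assumes this module is
/// reachable at the path shown):
///
/// ```ignore
/// use scirs2_optimize::automatic_differentiation::{ADMode, AutoDiffOptions};
///
/// // Let the input dimension decide: forward mode up to 25 variables,
/// // reverse mode beyond that.
/// let opts = AutoDiffOptions {
///     mode: ADMode::Auto,
///     auto_threshold: 25,
///     ..Default::default()
/// };
/// ```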
#[derive(Debug, Clone)]
pub struct AutoDiffOptions {
    /// Differentiation mode (forward, reverse, or automatic selection).
    pub mode: ADMode,
    /// Dimension threshold for `ADMode::Auto`: forward mode is used when the
    /// input dimension is at or below this value, reverse mode above it.
    pub auto_threshold: usize,
    /// Enable sparse derivative computations where supported.
    pub enable_sparse: bool,
    /// Also compute the Hessian, not just the gradient.
    pub compute_hessian: bool,
    /// Options forwarded to forward-mode AD.
    pub forward_options: ForwardADOptions,
    /// Options forwarded to reverse-mode AD.
    pub reverse_options: ReverseADOptions,
}

impl Default for AutoDiffOptions {
    fn default() -> Self {
        Self {
            mode: ADMode::Auto,
            auto_threshold: 10,
            enable_sparse: false,
            compute_hessian: false,
            forward_options: ForwardADOptions::default(),
            reverse_options: ReverseADOptions::default(),
        }
    }
}

/// Result of an automatic differentiation evaluation.
#[derive(Debug, Clone)]
pub struct ADResult {
    /// Function value at the evaluation point.
    pub value: f64,
    /// Gradient, if requested.
    pub gradient: Option<Array1<f64>>,
    /// Hessian, if requested.
    pub hessian: Option<Array2<f64>>,
    /// Approximate number of function evaluations (or evaluation-equivalents).
    pub n_fev: usize,
    /// Mode actually used (`Auto` is resolved to a concrete mode).
    pub mode_used: ADMode,
}

/// A function that can be evaluated generically, over plain `f64` values or
/// over dual numbers.
pub trait AutoDiffFunction<T> {
    /// Evaluates the function at the point `x`.
    fn eval(&self, x: &[T]) -> T;
}

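/// Adapter that lets a plain closure over `ArrayView1<f64>` be used where an
/// `AutoDiffFunction` is expected.
///
/// A minimal usage sketch (ignored doctest; assumes the module path shown):
///
/// ```ignore
/// use scirs2_core::ndarray::ArrayView1;
/// use scirs2_optimize::automatic_differentiation::{AutoDiffFunction, FunctionWrapper};
///
/// let w = FunctionWrapper::new(|x: &ArrayView1<f64>| x[0] * x[0]);
/// assert_eq!(w.eval(&[3.0_f64]), 9.0);
/// ```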
pub struct FunctionWrapper<F> {
    func: F,
}

impl<F> FunctionWrapper<F>
where
    F: Fn(&ArrayView1<f64>) -> f64,
{
    pub fn new(func: F) -> Self {
        Self { func }
    }
}

impl<F> AutoDiffFunction<f64> for FunctionWrapper<F>
where
    F: Fn(&ArrayView1<f64>) -> f64,
{
    fn eval(&self, x: &[f64]) -> f64 {
        let x_array = Array1::from_vec(x.to_vec());
        (self.func)(&x_array.view())
    }
}

impl<F> AutoDiffFunction<Dual> for FunctionWrapper<F>
where
    F: Fn(&ArrayView1<f64>) -> f64 + Clone,
{
    fn eval(&self, x: &[Dual]) -> Dual {
        // A plain `f64` closure cannot carry dual parts, so this fallback
        // evaluates at the primal values and returns a constant dual (the
        // derivative part is lost).
        let values: Vec<f64> = x.iter().map(|d| d.value()).collect();
        let x_array = Array1::from_vec(values);
        Dual::constant((self.func)(&x_array.view()))
    }
}

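// The constant fallback above loses derivative information. A function written
// directly over `Dual` does propagate it; the sketch below is a hypothetical
// illustration, assuming `Dual` is `Clone` and implements `Add` and `Mul` by
// value (the usual dual-number arithmetic supplied by `dual_numbers`).
#[allow(dead_code)]
struct SquaredNorm;

impl AutoDiffFunction<Dual> for SquaredNorm {
    // f(x) = sum_i x_i^2, evaluated entirely in dual arithmetic so that
    // derivative parts flow through.
    fn eval(&self, x: &[Dual]) -> Dual {
        let mut acc = Dual::constant(0.0);
        for xi in x {
            acc = acc + xi.clone() * xi.clone();
        }
        acc
    }
}
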
/// Evaluates `func` at `x`, computing the value plus any requested
/// derivatives, dispatching to forward or reverse mode according to `options`
/// (`ADMode::Auto` is resolved by input dimension).
#[allow(dead_code)]
pub fn autodiff<F>(
    func: F,
    x: &ArrayView1<f64>,
    options: &AutoDiffOptions,
) -> Result<ADResult, OptimizeError>
where
    F: Fn(&ArrayView1<f64>) -> f64 + Clone,
{
    let n = x.len();
    let mode = match options.mode {
        ADMode::Auto => {
            if n <= options.auto_threshold {
                ADMode::Forward
            } else {
                ADMode::Reverse
            }
        }
        mode => mode,
    };

    match mode {
        ADMode::Forward => autodiff_forward(func, x, &options.forward_options),
        ADMode::Reverse => autodiff_reverse(func, x, &options.reverse_options),
        // `Auto` was resolved to a concrete mode above.
        ADMode::Auto => unreachable!(),
    }
}

/// Forward-mode driver: gradient via dual numbers; the Hessian is limited to
/// its diagonal.
#[allow(dead_code)]
fn autodiff_forward<F>(
    func: F,
    x: &ArrayView1<f64>,
    options: &ForwardADOptions,
) -> Result<ADResult, OptimizeError>
where
    F: Fn(&ArrayView1<f64>) -> f64 + Clone,
{
    let n = x.len();
    let mut n_fev = 0;

    let value = func(x);
    n_fev += 1;

    let gradient = if options.compute_gradient {
        let grad = forward_gradient(func.clone(), x)?;
        // Counted as one dual-number pass per input dimension.
        n_fev += n;
        Some(grad)
    } else {
        None
    };

    let hessian = if options.compute_hessian {
        let hess_diag = forward_hessian_diagonal(func, x)?;
        n_fev += n;
        // Only the diagonal is available from forward mode here; off-diagonal
        // entries are left at zero.
        let mut hess = Array2::zeros((n, n));
        for i in 0..n {
            hess[[i, i]] = hess_diag[i];
        }
        Some(hess)
    } else {
        None
    };

    Ok(ADResult {
        value,
        gradient,
        hessian,
        n_fev,
        mode_used: ADMode::Forward,
    })
}

/// Reverse-mode driver: one taped forward/backward sweep for the gradient.
#[allow(dead_code)]
fn autodiff_reverse<F>(
    func: F,
    x: &ArrayView1<f64>,
    options: &ReverseADOptions,
) -> Result<ADResult, OptimizeError>
where
    F: Fn(&ArrayView1<f64>) -> f64 + Clone,
{
    let mut n_fev = 0;

    let value = func(x);
    n_fev += 1;

    let gradient = if options.compute_gradient {
        let grad = reverse_gradient(func.clone(), x)?;
        // Counted as a single forward-plus-backward sweep, independent of
        // the input dimension.
        n_fev += 1;
        Some(grad)
    } else {
        None
    };

    let hessian = if options.compute_hessian {
        let hess = reverse_hessian(func, x)?;
        // Counted as roughly one sweep per input dimension.
        n_fev += x.len();
        Some(hess)
    } else {
        None
    };

    Ok(ADResult {
        value,
        gradient,
        hessian,
        n_fev,
        mode_used: ADMode::Reverse,
    })
}

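/// Wraps `func` into a gradient closure for use by gradient-based optimizers.
/// Errors are swallowed: the closure falls back to a zero gradient on failure.
///
/// A minimal usage sketch (ignored doctest; assumes the module path shown):
///
/// ```ignore
/// use scirs2_core::ndarray::{Array1, ArrayView1};
/// use scirs2_optimize::automatic_differentiation::{create_ad_gradient, AutoDiffOptions};
///
/// let func = |x: &ArrayView1<f64>| -> f64 { x[0] * x[0] + x[1] * x[1] };
/// let grad = create_ad_gradient(func, AutoDiffOptions::default());
/// let g = grad(&Array1::from_vec(vec![3.0, 4.0]).view());
/// // g ≈ [6.0, 8.0]
/// ```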
#[allow(dead_code)]
pub fn create_ad_gradient<F>(
    func: F,
    options: AutoDiffOptions,
) -> impl Fn(&ArrayView1<f64>) -> Array1<f64>
where
    F: Fn(&ArrayView1<f64>) -> f64 + Clone + 'static,
{
    move |x: &ArrayView1<f64>| -> Array1<f64> {
        let mut opts = options.clone();
        opts.forward_options.compute_gradient = true;
        opts.reverse_options.compute_gradient = true;

        match autodiff(func.clone(), x, &opts) {
            Ok(result) => result.gradient.unwrap_or_else(|| Array1::zeros(x.len())),
            // Swallow errors and fall back to a zero gradient.
            Err(_) => Array1::zeros(x.len()),
        }
    }
}

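/// Wraps `func` into a Hessian closure. As with [`create_ad_gradient`], errors
/// are swallowed and a zero matrix is returned on failure; note that the
/// forward-mode path currently fills only the Hessian diagonal.
///
/// A minimal usage sketch (ignored doctest; assumes the module path shown):
///
/// ```ignore
/// use scirs2_core::ndarray::{Array1, ArrayView1};
/// use scirs2_optimize::automatic_differentiation::{create_ad_hessian, AutoDiffOptions};
///
/// let func = |x: &ArrayView1<f64>| -> f64 { x[0] * x[0] + 2.0 * x[1] * x[1] };
/// let hess = create_ad_hessian(func, AutoDiffOptions::default());
/// let h = hess(&Array1::from_vec(vec![1.0, 1.0]).view());
/// // h ≈ [[2.0, 0.0], [0.0, 4.0]]
/// ```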
#[allow(dead_code)]
pub fn create_ad_hessian<F>(
    func: F,
    options: AutoDiffOptions,
) -> impl Fn(&ArrayView1<f64>) -> Array2<f64>
where
    F: Fn(&ArrayView1<f64>) -> f64 + Clone + 'static,
{
    move |x: &ArrayView1<f64>| -> Array2<f64> {
        let mut opts = options.clone();
        opts.forward_options.compute_hessian = true;
        opts.reverse_options.compute_hessian = true;

        match autodiff(func.clone(), x, &opts) {
            Ok(result) => result
                .hessian
                .unwrap_or_else(|| Array2::zeros((x.len(), x.len()))),
            // Swallow errors and fall back to a zero matrix.
            Err(_) => Array2::zeros((x.len(), x.len())),
        }
    }
}

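/// Heuristic mode selection from problem structure (input dimension, output
/// dimension, and expected derivative sparsity in [0, 1]).
///
/// Illustration of the heuristic (ignored doctest; assumes the module path
/// shown):
///
/// ```ignore
/// use scirs2_optimize::automatic_differentiation::{optimize_ad_mode, ADMode};
///
/// // Small problem: forward mode.
/// assert!(matches!(optimize_ad_mode(3, 1, 0.1), ADMode::Forward));
/// // Large problem with a scalar output: reverse mode.
/// assert!(matches!(optimize_ad_mode(100, 1, 0.1), ADMode::Reverse));
/// // Mostly sparse derivatives: forward mode.
/// assert!(matches!(optimize_ad_mode(50, 1, 0.9), ADMode::Forward));
/// ```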
#[allow(dead_code)]
pub fn optimize_ad_mode(problem_dim: usize, output_dim: usize, expected_sparsity: f64) -> ADMode {
    if problem_dim <= 5 {
        // Few inputs: forward mode needs only a handful of dual passes.
        ADMode::Forward
    } else if expected_sparsity > 0.8 {
        // Highly sparse derivative structure favors forward mode.
        ADMode::Forward
    } else if output_dim == 1 && problem_dim > 20 {
        // Scalar output with many inputs: the classic reverse-mode case.
        ADMode::Reverse
    } else {
        // Default to reverse mode for the remaining mid-sized problems.
        ADMode::Reverse
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_autodiff_quadratic() {
        let func = |x: &ArrayView1<f64>| -> f64 { x[0] * x[0] + 2.0 * x[1] * x[1] + x[0] * x[1] };

        let x = Array1::from_vec(vec![1.0, 2.0]);
        let mut options = AutoDiffOptions::default();
        options.forward_options.compute_gradient = true;
        options.reverse_options.compute_gradient = true;

        // Forward mode.
        options.mode = ADMode::Forward;
        let result_forward = autodiff(func, &x.view(), &options).unwrap();

        assert_abs_diff_eq!(result_forward.value, 11.0, epsilon = 1e-10);

        if let Some(grad) = result_forward.gradient {
            assert_abs_diff_eq!(grad[0], 4.0, epsilon = 1e-7);
            assert_abs_diff_eq!(grad[1], 9.0, epsilon = 1e-7);
        }

        // Reverse mode should agree.
        options.mode = ADMode::Reverse;
        let result_reverse = autodiff(func, &x.view(), &options).unwrap();

        assert_abs_diff_eq!(result_reverse.value, 11.0, epsilon = 1e-10);

        if let Some(grad) = result_reverse.gradient {
            assert_abs_diff_eq!(grad[0], 4.0, epsilon = 1e-7);
            assert_abs_diff_eq!(grad[1], 9.0, epsilon = 1e-7);
        }
    }

    #[test]
    fn test_ad_mode_selection() {
        // Small problem: forward mode.
        assert!(matches!(optimize_ad_mode(3, 1, 0.1), ADMode::Forward));
        // Large scalar-output problem: reverse mode.
        assert!(matches!(optimize_ad_mode(100, 1, 0.1), ADMode::Reverse));
        // Mostly sparse derivatives: forward mode.
        assert!(matches!(optimize_ad_mode(50, 1, 0.9), ADMode::Forward));
    }
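
    // Added sketch test: `ADMode::Auto` resolves by comparing the input
    // dimension against `auto_threshold`.
    #[test]
    fn test_auto_mode_dispatch() {
        let func = |x: &ArrayView1<f64>| -> f64 { x[0] * x[0] + x[1] * x[1] };
        let x = Array1::from_vec(vec![1.0, 2.0]);

        // n = 2 is at or below the default threshold of 10: forward mode.
        let options = AutoDiffOptions::default();
        let result = autodiff(func, &x.view(), &options).unwrap();
        assert!(matches!(result.mode_used, ADMode::Forward));

        // Lowering the threshold below n forces reverse mode.
        let mut options = AutoDiffOptions::default();
        options.auto_threshold = 1;
        let result = autodiff(func, &x.view(), &options).unwrap();
        assert!(matches!(result.mode_used, ADMode::Reverse));
    }

    // Added sketch test for `create_ad_hessian` on a separable quadratic. It
    // assumes the forward-mode path (chosen here since n = 2) fills the
    // Hessian diagonal via `forward_hessian_diagonal`, which is exact for
    // this function since its mixed partials are zero.
    #[test]
    fn test_create_ad_hessian_diagonal() {
        let func = |x: &ArrayView1<f64>| -> f64 { x[0] * x[0] + 2.0 * x[1] * x[1] };

        let options = AutoDiffOptions::default();
        let hess_func = create_ad_hessian(func, options);

        let x = Array1::from_vec(vec![1.0, 1.0]);
        let hess = hess_func(&x.view());

        assert_abs_diff_eq!(hess[[0, 0]], 2.0, epsilon = 1e-6);
        assert_abs_diff_eq!(hess[[1, 1]], 4.0, epsilon = 1e-6);
        assert_abs_diff_eq!(hess[[0, 1]], 0.0, epsilon = 1e-6);
    }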

    #[test]
    fn test_create_ad_gradient() {
        let func = |x: &ArrayView1<f64>| -> f64 { x[0] * x[0] + x[1] * x[1] };

        let options = AutoDiffOptions::default();
        let grad_func = create_ad_gradient(func, options);

        let x = Array1::from_vec(vec![3.0, 4.0]);
        let grad = grad_func(&x.view());

        assert_abs_diff_eq!(grad[0], 6.0, epsilon = 1e-6);
        assert_abs_diff_eq!(grad[1], 8.0, epsilon = 1e-6);
    }
}