scirs2_optimize/automatic_differentiation/
mod.rs

//! Automatic differentiation for exact gradient and Hessian computation
//!
//! This module provides automatic differentiation capabilities for optimization,
//! supporting both forward-mode and reverse-mode AD for efficient and exact
//! derivative computation.
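//!
//! # Example
//!
//! A minimal usage sketch. The `scirs2_optimize` crate path is an assumption here,
//! so the snippet is not compiled as a doctest:
//!
//! ```ignore
//! use scirs2_core::ndarray::{Array1, ArrayView1};
//! use scirs2_optimize::automatic_differentiation::{autodiff, AutoDiffOptions};
//!
//! // Simple smooth objective: f(x) = x0^2 + x1^2
//! let f = |x: &ArrayView1<f64>| -> f64 { x[0] * x[0] + x[1] * x[1] };
//!
//! let x = Array1::from_vec(vec![1.0, 2.0]);
//! let mut options = AutoDiffOptions::default();
//! options.forward_options.compute_gradient = true;
//! options.reverse_options.compute_gradient = true;
//!
//! let result = autodiff(f, &x.view(), &options).unwrap();
//! assert_eq!(result.value, 5.0);
//! // result.gradient now holds Some([2.0, 4.0])
//! ```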

pub mod dual_numbers;
pub mod forward_mode;
pub mod reverse_mode;
pub mod tape;

// Re-export commonly used items
pub use dual_numbers::{Dual, DualNumber};
pub use forward_mode::{forward_gradient, forward_hessian_diagonal, ForwardADOptions};
pub use reverse_mode::{reverse_gradient, reverse_hessian, ReverseADOptions};
pub use tape::{ComputationTape, TapeNode, Variable};

use crate::error::OptimizeError;
use scirs2_core::ndarray::{Array1, Array2, ArrayView1};

/// Automatic differentiation mode selection
#[derive(Debug, Clone, Copy)]
pub enum ADMode {
    /// Forward-mode AD (efficient for low-dimensional problems)
    Forward,
    /// Reverse-mode AD (efficient for high-dimensional problems)
    Reverse,
    /// Automatic mode selection based on problem dimension
    Auto,
}

/// Options for automatic differentiation
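///
/// # Example
///
/// A minimal sketch of overriding individual fields while keeping the rest of the
/// defaults (imports omitted, so the snippet is not compiled):
///
/// ```ignore
/// let options = AutoDiffOptions {
///     mode: ADMode::Reverse,
///     compute_hessian: true,
///     ..Default::default()
/// };
/// ```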
#[derive(Debug, Clone)]
pub struct AutoDiffOptions {
    /// AD mode to use
    pub mode: ADMode,
    /// Threshold for automatic mode selection (dimension)
    pub auto_threshold: usize,
    /// Enable sparse AD for sparse functions
    pub enable_sparse: bool,
    /// Compute Hessian when possible
    pub compute_hessian: bool,
    /// Forward-mode specific options
    pub forward_options: ForwardADOptions,
    /// Reverse-mode specific options
    pub reverse_options: ReverseADOptions,
}

impl Default for AutoDiffOptions {
    fn default() -> Self {
        Self {
            mode: ADMode::Auto,
            auto_threshold: 10,
            enable_sparse: false,
            compute_hessian: false,
            forward_options: ForwardADOptions::default(),
            reverse_options: ReverseADOptions::default(),
        }
    }
}

/// Result of automatic differentiation computation
#[derive(Debug, Clone)]
pub struct ADResult {
    /// Function value
    pub value: f64,
    /// Gradient (if computed)
    pub gradient: Option<Array1<f64>>,
    /// Hessian (if computed)
    pub hessian: Option<Array2<f64>>,
    /// Number of function evaluations used
    pub n_fev: usize,
    /// AD mode used
    pub mode_used: ADMode,
}

/// Function trait for automatic differentiation
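///
/// # Example
///
/// A minimal sketch of implementing the trait for plain `f64` evaluation
/// (imports omitted, so the snippet is not compiled):
///
/// ```ignore
/// struct Quadratic;
///
/// impl AutoDiffFunction<f64> for Quadratic {
///     fn eval(&self, x: &[f64]) -> f64 {
///         x[0] * x[0] + x[1] * x[1]
///     }
/// }
/// ```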
pub trait AutoDiffFunction<T> {
    /// Evaluate the function with AD variables
    fn eval(&self, x: &[T]) -> T;
}

/// Wrapper for regular functions to make them compatible with AD
pub struct FunctionWrapper<F> {
    func: F,
}

impl<F> FunctionWrapper<F>
where
    F: Fn(&ArrayView1<f64>) -> f64,
{
    pub fn new(func: F) -> Self {
        Self { func }
    }
}

impl<F> AutoDiffFunction<f64> for FunctionWrapper<F>
where
    F: Fn(&ArrayView1<f64>) -> f64,
{
    fn eval(&self, x: &[f64]) -> f64 {
        let x_array = Array1::from_vec(x.to_vec());
        (self.func)(&x_array.view())
    }
}

impl<F> AutoDiffFunction<Dual> for FunctionWrapper<F>
where
    F: Fn(&ArrayView1<f64>) -> f64 + Clone,
{
    fn eval(&self, x: &[Dual]) -> Dual {
        // Placeholder implementation: the wrapped closure only accepts f64, so the
        // derivative parts of the duals are dropped and the result is returned as a
        // constant (zero derivative part). For real forward-mode derivatives the
        // function must be written directly in dual arithmetic (see the sketch below).
        let values: Vec<f64> = x.iter().map(|d| d.value()).collect();
        let x_array = Array1::from_vec(values);
        Dual::constant((self.func)(&x_array.view()))
    }
}
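
// Sketch: a function expressed directly in dual arithmetic, so that derivative
// information propagates through every operation. This assumes `Dual` overloads
// `+` and `*` (the usual design for dual-number types); adjust to the operators
// actually exposed by the `dual_numbers` module.
//
//     struct Quadratic;
//
//     impl AutoDiffFunction<Dual> for Quadratic {
//         fn eval(&self, x: &[Dual]) -> Dual {
//             // f(x) = x0^2 + 2*x1^2; the dual parts carry exact derivatives
//             x[0] * x[0] + Dual::constant(2.0) * x[1] * x[1]
//         }
//     }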

/// Main automatic differentiation entry point: evaluates `func` at `x` and
/// computes the requested derivatives using the selected (or automatically
/// chosen) AD mode.
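///
/// # Example
///
/// A sketch of forcing reverse mode explicitly (imports omitted, so the snippet
/// is not compiled):
///
/// ```ignore
/// let x = Array1::from_vec(vec![0.5, 1.5]);
/// let mut options = AutoDiffOptions::default();
/// options.mode = ADMode::Reverse;
/// options.reverse_options.compute_gradient = true;
///
/// let result = autodiff(|p: &ArrayView1<f64>| p[0].sin() + p[1], &x.view(), &options).unwrap();
/// assert!(matches!(result.mode_used, ADMode::Reverse));
/// ```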
#[allow(dead_code)]
pub fn autodiff<F>(
    func: F,
    x: &ArrayView1<f64>,
    options: &AutoDiffOptions,
) -> Result<ADResult, OptimizeError>
where
    F: Fn(&ArrayView1<f64>) -> f64 + Clone,
{
    let n = x.len();
    let mode = match options.mode {
        ADMode::Auto => {
            if n <= options.auto_threshold {
                ADMode::Forward
            } else {
                ADMode::Reverse
            }
        }
        mode => mode,
    };

    match mode {
        ADMode::Forward => autodiff_forward(func, x, &options.forward_options),
        ADMode::Reverse => autodiff_reverse(func, x, &options.reverse_options),
        ADMode::Auto => unreachable!(), // Already handled above
    }
}

/// Forward-mode automatic differentiation
#[allow(dead_code)]
fn autodiff_forward<F>(
    func: F,
    x: &ArrayView1<f64>,
    options: &ForwardADOptions,
) -> Result<ADResult, OptimizeError>
where
    F: Fn(&ArrayView1<f64>) -> f64 + Clone,
{
    let n = x.len();
    let mut n_fev = 0;

    // Compute function value
    let value = func(x);
    n_fev += 1;

    // Compute gradient using forward-mode AD
    let gradient = if options.compute_gradient {
        let grad = forward_gradient(func.clone(), x)?;
        n_fev += n; // Forward mode needs one dual-number pass per input dimension (n passes on top of the value evaluation above)
        Some(grad)
    } else {
        None
    };

    // Compute Hessian diagonal using forward-mode AD (if requested)
    let hessian = if options.compute_hessian {
        let hess_diag = forward_hessian_diagonal(func, x)?;
        n_fev += n; // Additional evaluations for Hessian diagonal

        // Convert diagonal to full matrix (zeros off-diagonal)
        let mut hess = Array2::zeros((n, n));
        for i in 0..n {
            hess[[i, i]] = hess_diag[i];
        }
        Some(hess)
    } else {
        None
    };

    Ok(ADResult {
        value,
        gradient,
        hessian,
        n_fev,
        mode_used: ADMode::Forward,
    })
}

/// Reverse-mode automatic differentiation
#[allow(dead_code)]
fn autodiff_reverse<F>(
    func: F,
    x: &ArrayView1<f64>,
    options: &ReverseADOptions,
) -> Result<ADResult, OptimizeError>
where
    F: Fn(&ArrayView1<f64>) -> f64 + Clone,
{
    let mut n_fev = 0;

    // Compute function value
    let value = func(x);
    n_fev += 1;

    // Compute gradient using reverse-mode AD
    let gradient = if options.compute_gradient {
        let grad = reverse_gradient(func.clone(), x)?;
        n_fev += 1; // Reverse mode requires only 1 additional evaluation for gradient
        Some(grad)
    } else {
        None
    };

    // Compute Hessian using reverse-mode AD (if requested)
    let hessian = if options.compute_hessian {
        let hess = reverse_hessian(func, x)?;
        n_fev += x.len(); // Reverse mode for Hessian requires n additional evaluations
        Some(hess)
    } else {
        None
    };

    Ok(ADResult {
        value,
        gradient,
        hessian,
        n_fev,
        mode_used: ADMode::Reverse,
    })
}

/// Create a gradient function using automatic differentiation
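///
/// # Example
///
/// A sketch of wrapping a closure as a gradient callback, mirroring the unit
/// test below (imports omitted, so the snippet is not compiled):
///
/// ```ignore
/// let grad_func = create_ad_gradient(
///     |x: &ArrayView1<f64>| x[0] * x[0] + x[1] * x[1],
///     AutoDiffOptions::default(),
/// );
///
/// let x = Array1::from_vec(vec![3.0, 4.0]);
/// let grad = grad_func(&x.view()); // approximately [6.0, 8.0]
/// ```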
#[allow(dead_code)]
pub fn create_ad_gradient<F>(
    func: F,
    options: AutoDiffOptions,
) -> impl Fn(&ArrayView1<f64>) -> Array1<f64>
where
    F: Fn(&ArrayView1<f64>) -> f64 + Clone + 'static,
{
    move |x: &ArrayView1<f64>| -> Array1<f64> {
        let mut opts = options.clone();
        opts.forward_options.compute_gradient = true;
        opts.reverse_options.compute_gradient = true;

        match autodiff(func.clone(), x, &opts) {
            Ok(result) => result.gradient.unwrap_or_else(|| Array1::zeros(x.len())),
            Err(_) => Array1::zeros(x.len()), // Fallback to zeros on error
        }
    }
}

/// Create a Hessian function using automatic differentiation
#[allow(dead_code)]
pub fn create_ad_hessian<F>(
    func: F,
    options: AutoDiffOptions,
) -> impl Fn(&ArrayView1<f64>) -> Array2<f64>
where
    F: Fn(&ArrayView1<f64>) -> f64 + Clone + 'static,
{
    move |x: &ArrayView1<f64>| -> Array2<f64> {
        let mut opts = options.clone();
        opts.forward_options.compute_hessian = true;
        opts.reverse_options.compute_hessian = true;

        match autodiff(func.clone(), x, &opts) {
            Ok(result) => result
                .hessian
                .unwrap_or_else(|| Array2::zeros((x.len(), x.len()))),
            Err(_) => Array2::zeros((x.len(), x.len())), // Fallback to zeros on error
        }
    }
}

/// Optimize AD mode selection based on problem characteristics
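///
/// # Example
///
/// ```ignore
/// // A scalar objective in 100 variables favors reverse mode
/// assert!(matches!(optimize_ad_mode(100, 1, 0.1), ADMode::Reverse));
/// ```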
#[allow(dead_code)]
pub fn optimize_ad_mode(problem_dim: usize, output_dim: usize, expected_sparsity: f64) -> ADMode {
    // Forward mode is efficient when input dimension is small
    // Reverse mode is efficient when output dimension is small (typically 1 for optimization)

    if problem_dim <= 5 {
        ADMode::Forward
    } else if expected_sparsity > 0.8 {
        // For very sparse problems, forward mode might be better
        ADMode::Forward
    } else if output_dim == 1 && problem_dim > 20 {
        ADMode::Reverse
    } else {
        // Default to reverse mode for optimization (output_dim = 1)
        ADMode::Reverse
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_autodiff_quadratic() {
        let func = |x: &ArrayView1<f64>| -> f64 { x[0] * x[0] + 2.0 * x[1] * x[1] + x[0] * x[1] };

        let x = Array1::from_vec(vec![1.0, 2.0]);
        let mut options = AutoDiffOptions::default();
        options.forward_options.compute_gradient = true;
        options.reverse_options.compute_gradient = true;

        // Test forward mode
        options.mode = ADMode::Forward;
        let result_forward = autodiff(func, &x.view(), &options).unwrap();

        assert_abs_diff_eq!(result_forward.value, 11.0, epsilon = 1e-10); // 1 + 8 + 2 = 11

        if let Some(grad) = result_forward.gradient {
            // ∂f/∂x₀ = 2x₀ + x₁ = 2(1) + 2 = 4
            // ∂f/∂x₁ = 4x₁ + x₀ = 4(2) + 1 = 9
            assert_abs_diff_eq!(grad[0], 4.0, epsilon = 1e-7);
            assert_abs_diff_eq!(grad[1], 9.0, epsilon = 1e-7);
        }

        // Test reverse mode
        options.mode = ADMode::Reverse;
        let result_reverse = autodiff(func, &x.view(), &options).unwrap();

        assert_abs_diff_eq!(result_reverse.value, 11.0, epsilon = 1e-10);

        if let Some(grad) = result_reverse.gradient {
            assert_abs_diff_eq!(grad[0], 4.0, epsilon = 1e-7);
            assert_abs_diff_eq!(grad[1], 9.0, epsilon = 1e-7);
        }
    }

    #[test]
    fn test_ad_mode_selection() {
        // Small problem should use forward mode
        assert!(matches!(optimize_ad_mode(3, 1, 0.1), ADMode::Forward));

        // Large problem should use reverse mode
        assert!(matches!(optimize_ad_mode(100, 1, 0.1), ADMode::Reverse));

        // Sparse problem should use forward mode
        assert!(matches!(optimize_ad_mode(50, 1, 0.9), ADMode::Forward));
    }

    #[test]
    fn test_create_ad_gradient() {
        let func = |x: &ArrayView1<f64>| -> f64 { x[0] * x[0] + x[1] * x[1] };

        let options = AutoDiffOptions::default();
        let grad_func = create_ad_gradient(func, options);

        let x = Array1::from_vec(vec![3.0, 4.0]);
        let grad = grad_func(&x.view());

        // ∂f/∂x₀ = 2x₀ = 6, ∂f/∂x₁ = 2x₁ = 8
        assert_abs_diff_eq!(grad[0], 6.0, epsilon = 1e-6);
        assert_abs_diff_eq!(grad[1], 8.0, epsilon = 1e-6);
    }
}