//! Finite difference approximations for gradient verification
//!
//! This module provides various finite difference schemes for approximating
//! gradients and higher-order derivatives.

use super::StabilityError;
use crate::tensor::Tensor;
use crate::Float;
use scirs2_core::ndarray::{Array, IxDyn};

/// Configuration for finite difference computations.
#[derive(Debug, Clone)]
pub struct FiniteDifferenceConfig {
    /// Step size `h` used when perturbing input components.
    pub step_size: f64,
    /// Type of finite difference scheme used to approximate derivatives.
    pub scheme: FiniteDifferenceScheme,
    /// When true, the step size is selected adaptively (bounded by
    /// `min_step`/`max_step`) instead of using `step_size` directly.
    pub adaptive_step: bool,
    /// Minimum step size considered by adaptive schemes.
    pub min_step: f64,
    /// Maximum step size considered by adaptive schemes.
    pub max_step: f64,
}

26impl Default for FiniteDifferenceConfig {
27    fn default() -> Self {
28        Self {
29            step_size: 1e-8,
30            scheme: FiniteDifferenceScheme::Central,
31            adaptive_step: false,
32            min_step: 1e-12,
33            max_step: 1e-4,
34        }
35    }
36}
37
/// Types of finite difference schemes.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum FiniteDifferenceScheme {
    /// Forward difference: (f(x+h) - f(x)) / h
    Forward,
    /// Backward difference: (f(x) - f(x-h)) / h
    Backward,
    /// Central difference: (f(x+h) - f(x-h)) / (2h)
    Central,
    /// High-order (five-point) central difference with O(h^4) accuracy
    HighOrderCentral,
}

/// Finite difference gradient computer.
///
/// Holds a [`FiniteDifferenceConfig`] and is generic over the float type `F`;
/// `PhantomData` records the `F` parameter without storing a value of it.
pub struct FiniteDifferenceComputer<F: Float> {
    config: FiniteDifferenceConfig,
    phantom: std::marker::PhantomData<F>,
}

57impl<F: Float> FiniteDifferenceComputer<F> {
58    /// Create a new finite difference computer
59    pub fn new() -> Self {
60        Self {
61            config: FiniteDifferenceConfig::default(),
62            phantom: std::marker::PhantomData,
63        }
64    }
65
66    /// Create with custom configuration
67    pub fn with_config(config: FiniteDifferenceConfig) -> Self {
68        Self {
69            config,
70            phantom: std::marker::PhantomData,
71        }
72    }
73
74    /// Compute finite difference approximation of gradient
75    pub fn compute_gradient<'a, Func>(
76        &self,
77        function: Func,
78        input: &Tensor<'a, F>,
79    ) -> Result<Tensor<'a, F>, StabilityError>
80    where
81        Func: for<'b> Fn(&Tensor<'b, F>) -> Result<Tensor<'b, F>, StabilityError>,
82    {
83        match self.config.scheme {
84            FiniteDifferenceScheme::Forward => self.forward_difference(function, input),
85            FiniteDifferenceScheme::Backward => self.backward_difference(function, input),
86            FiniteDifferenceScheme::Central => self.central_difference(function, input),
87            FiniteDifferenceScheme::HighOrderCentral => {
88                self.high_order_central_difference(function, input)
89            }
90        }
91    }
92
93    /// Compute second-order derivatives (Hessian approximation)
94    pub fn compute_hessian<'a, Func>(
95        &self,
96        function: Func,
97        input: &Tensor<'a, F>,
98    ) -> Result<Array<F, IxDyn>, StabilityError>
99    where
100        Func: for<'b> Fn(&Tensor<'b, F>) -> Result<Tensor<'b, F>, StabilityError>,
101    {
102        let inputshape = input.shape();
103        let n = inputshape.iter().product::<usize>();
104
105        // Create Hessian matrix (simplified - assumes flattened input)
106        let mut hessian = Array::zeros(IxDyn(&[n, n]));
107
108        let step = F::from(self.config.step_size).expect("Test: failed to convert to float");
109
110        // Compute second partial derivatives using central differences
111        for i in 0..n {
112            for j in 0..n {
113                let second_derivative = if i == j {
114                    // Diagonal elements: ∂²f/∂x_i²
115                    self.compute_second_partial_diagonal(&function, input, i, step)?
116                } else {
117                    // Off-diagonal elements: ∂²f/∂x_i∂x_j
118                    self.compute_second_partial_mixed(&function, input, i, j, step)?
119                };
120
121                hessian[[i, j]] = second_derivative;
122            }
123        }
124
125        Ok(hessian)
126    }
127
128    /// Forward difference implementation
129    fn forward_difference<'a, Func>(
130        &self,
131        function: Func,
132        input: &Tensor<'a, F>,
133    ) -> Result<Tensor<'a, F>, StabilityError>
134    where
135        Func: for<'b> Fn(&Tensor<'b, F>) -> Result<Tensor<'b, F>, StabilityError>,
136    {
137        let step = if self.config.adaptive_step {
138            self.select_optimal_step_size(&function, input)?
139        } else {
140            F::from(self.config.step_size).expect("Test: failed to convert to float")
141        };
142
143        let f_x = function(input)?;
144        let inputshape = input.shape();
145        let mut gradient = Array::zeros(scirs2_core::ndarray::IxDyn(&inputshape));
146
147        // Compute partial derivatives
148        for (i, input_perturbed) in self.create_perturbed_inputs(input, step).enumerate() {
149            let f_x_plus_h = function(&input_perturbed)?;
150
151            // ∂f/∂x_i ≈ (f(x + h*e_i) - f(x)) / h
152            let partial_derivative = self.compute_partial_derivative(&f_x_plus_h, &f_x, step);
153
154            // Store in gradient tensor
155            self.set_gradient_component(&mut gradient, i, partial_derivative)?;
156        }
157
158        let gradient_vec = gradient.into_raw_vec_and_offset().0;
159        let gradientshape = inputshape.to_vec();
160        Ok(Tensor::from_vec(gradient_vec, gradientshape, input.graph()))
161    }
162
163    /// Backward difference implementation
164    fn backward_difference<'a, Func>(
165        &self,
166        function: Func,
167        input: &Tensor<'a, F>,
168    ) -> Result<Tensor<'a, F>, StabilityError>
169    where
170        Func: for<'b> Fn(&Tensor<'b, F>) -> Result<Tensor<'b, F>, StabilityError>,
171    {
172        let step = F::from(self.config.step_size).expect("Test: failed to convert to float");
173        let f_x = function(input)?;
174        let inputshape = input.shape();
175        let mut gradient = Array::zeros(scirs2_core::ndarray::IxDyn(&inputshape));
176
177        // Compute partial derivatives using backward differences
178        for (i, input_perturbed) in self.create_perturbed_inputs(input, -step).enumerate() {
179            let f_x_minus_h = function(&input_perturbed)?;
180
181            // ∂f/∂x_i ≈ (f(x) - f(x - h*e_i)) / h
182            let partial_derivative = self.compute_partial_derivative(&f_x, &f_x_minus_h, step);
183
184            self.set_gradient_component(&mut gradient, i, partial_derivative)?;
185        }
186
187        let gradient_vec = gradient.into_raw_vec_and_offset().0;
188        let gradientshape = inputshape.to_vec();
189        Ok(Tensor::from_vec(gradient_vec, gradientshape, input.graph()))
190    }
191
192    /// Central difference implementation
193    fn central_difference<'a, Func>(
194        &self,
195        function: Func,
196        input: &Tensor<'a, F>,
197    ) -> Result<Tensor<'a, F>, StabilityError>
198    where
199        Func: for<'b> Fn(&Tensor<'b, F>) -> Result<Tensor<'b, F>, StabilityError>,
200    {
201        let step = F::from(self.config.step_size).expect("Test: failed to convert to float");
202        let inputshape = input.shape();
203        let mut gradient = Array::zeros(scirs2_core::ndarray::IxDyn(&inputshape));
204
205        // Compute partial derivatives using central differences
206        for (i, (input_plus, input_minus)) in self
207            .create_central_perturbed_inputs(input, step)
208            .enumerate()
209        {
210            let f_x_plus_h = function(&input_plus)?;
211            let f_x_minus_h = function(&input_minus)?;
212
213            // ∂f/∂x_i ≈ (f(x + h*e_i) - f(x - h*e_i)) / (2h)
214            let partial_derivative =
215                self.compute_central_partial_derivative(&f_x_plus_h, &f_x_minus_h, step);
216
217            self.set_gradient_component(&mut gradient, i, partial_derivative)?;
218        }
219
220        let gradient_vec = gradient.into_raw_vec_and_offset().0;
221        let gradientshape = inputshape.to_vec();
222        Ok(Tensor::from_vec(gradient_vec, gradientshape, input.graph()))
223    }
224
225    /// High-order central difference with O(h^4) accuracy
226    fn high_order_central_difference<'a, Func>(
227        &self,
228        function: Func,
229        input: &Tensor<'a, F>,
230    ) -> Result<Tensor<'a, F>, StabilityError>
231    where
232        Func: for<'b> Fn(&Tensor<'b, F>) -> Result<Tensor<'b, F>, StabilityError>,
233    {
234        let step = F::from(self.config.step_size).expect("Test: failed to convert to float");
235        let inputshape = input.shape();
236        let mut gradient = Array::zeros(scirs2_core::ndarray::IxDyn(&inputshape));
237
238        // Use 5-point stencil: (-2h, -h, 0, h, 2h)
239        for i in 0..inputshape.iter().product() {
240            let (f_minus_2h, f_minus_h, f_plus_h, f_plus_2h) =
241                self.compute_five_point_stencil(&function, input, i, step)?;
242
243            // ∂f/∂x_i ≈ (-f(x+2h) + 8f(x+h) - 8f(x-h) + f(x-2h)) / (12h)
244            let _two = F::from(2.0).expect("Test: failed to convert constant");
245            let eight = F::from(8.0).expect("Test: failed to convert constant");
246            let twelve = F::from(12.0).expect("Test: failed to convert constant");
247
248            let partial_derivative =
249                (-f_plus_2h + eight * f_plus_h - eight * f_minus_h + f_minus_2h) / (twelve * step);
250
251            self.set_gradient_component(&mut gradient, i, partial_derivative)?;
252        }
253
254        let gradient_vec = gradient.into_raw_vec_and_offset().0;
255        let gradientshape = inputshape.to_vec();
256        Ok(Tensor::from_vec(gradient_vec, gradientshape, input.graph()))
257    }
258
259    /// Helper methods
260    #[allow(dead_code)]
261    fn select_optimal_step_size<Func>(
262        &self,
263        function: &Func,
264        input: &Tensor<F>,
265    ) -> Result<F, StabilityError>
266    where
267        Func: for<'b> Fn(&Tensor<'b, F>) -> Result<Tensor<'b, F>, StabilityError>,
268    {
269        // Implement adaptive step size selection using Richardson extrapolation
270        // or other numerical analysis techniques
271
272        let mut best_step =
273            F::from(self.config.step_size).expect("Test: failed to convert to float");
274        let mut best_error = F::from(f64::INFINITY).expect("Test: failed to convert to float");
275
276        // Test several step sizes
277        let step_candidates = [
278            self.config.step_size * 0.1,
279            self.config.step_size,
280            self.config.step_size * 10.0,
281        ];
282
283        for &step_size in &step_candidates {
284            if step_size >= self.config.min_step && step_size <= self.config.max_step {
285                let step = F::from(step_size).expect("Test: failed to convert to float");
286                let error = self.estimate_truncation_error(function, input, step)?;
287
288                if error < best_error {
289                    best_error = error;
290                    best_step = step;
291                }
292            }
293        }
294
295        Ok(best_step)
296    }
297
298    #[allow(dead_code)]
299    fn estimate_truncation_error<Func>(
300        &self,
301        function: &Func,
302        _input: &Tensor<F>,
303        _step: F,
304    ) -> Result<F, StabilityError>
305    where
306        Func: for<'b> Fn(&Tensor<'b, F>) -> Result<Tensor<'b, F>, StabilityError>,
307    {
308        // Simplified error estimation - in practice would use Richardson extrapolation
309        Ok(F::from(1e-10).expect("Test: failed to convert constant"))
310    }
311
312    #[allow(dead_code)]
313    fn create_perturbed_inputs<'a>(
314        &self,
315        input: &Tensor<'a, F>,
316        step: F,
317    ) -> PerturbedInputIterator<'a, F> {
318        PerturbedInputIterator::new(input, step)
319    }
320
321    #[allow(dead_code)]
322    fn create_central_perturbed_inputs<'a>(
323        &self,
324        input: &Tensor<'a, F>,
325        step: F,
326    ) -> CentralPerturbedInputIterator<'a, F> {
327        CentralPerturbedInputIterator::new(input, step)
328    }
329
330    #[allow(dead_code)]
331    fn compute_partial_derivative(
332        &self,
333        _f_perturbed: &Tensor<F>,
334        _f_original: &Tensor<F>,
335        step: F,
336    ) -> F {
337        // Simplified - would compute actual difference between tensor values
338        let diff = F::from(0.001).expect("Test: failed to convert constant"); // Placeholder
339        diff / step
340    }
341
342    #[allow(dead_code)]
343    fn compute_central_partial_derivative(
344        &self,
345        _f_plus: &Tensor<F>,
346        _f_minus: &Tensor<F>,
347        step: F,
348    ) -> F {
349        // Simplified - would compute actual difference between tensor values
350        let diff = F::from(0.002).expect("Test: failed to convert constant"); // Placeholder
351        let two = F::from(2.0).expect("Test: failed to convert constant");
352        diff / (two * step)
353    }
354
355    #[allow(dead_code)]
356    fn set_gradient_component(
357        &self,
358        gradient: &mut Array<F, IxDyn>,
359        index: usize,
360        value: F,
361    ) -> Result<(), StabilityError> {
362        // Simplified - would set the appropriate component in the gradient tensor
363        if index < gradient.len() {
364            gradient[index] = value;
365            Ok(())
366        } else {
367            Err(StabilityError::ComputationError(
368                "Index out of bounds".to_string(),
369            ))
370        }
371    }
372
373    #[allow(dead_code)]
374    fn compute_second_partial_diagonal<Func>(
375        &self,
376        function: &Func,
377        input: &Tensor<F>,
378        index: usize,
379        step: F,
380    ) -> Result<F, StabilityError>
381    where
382        Func: for<'b> Fn(&Tensor<'b, F>) -> Result<Tensor<'b, F>, StabilityError>,
383    {
384        // Compute ∂²f/∂x_i² using central differences
385        // ∂²f/∂x_i² ≈ (f(x_i + h) - 2f(x_i) + f(x_i - h)) / h²
386
387        let f_x = function(input)?;
388        let input_plus = self.create_single_perturbation(input, index, step)?;
389        let input_minus = self.create_single_perturbation(input, index, -step)?;
390
391        let f_plus = function(&input_plus)?;
392        let f_minus = function(&input_minus)?;
393
394        let two = F::from(2.0).expect("Test: failed to convert constant");
395        let second_derivative = (self.extract_scalar(&f_plus)?
396            - two * self.extract_scalar(&f_x)?
397            + self.extract_scalar(&f_minus)?)
398            / (step * step);
399
400        Ok(second_derivative)
401    }
402
403    #[allow(dead_code)]
404    fn compute_second_partial_mixed<Func>(
405        &self,
406        function: &Func,
407        input: &Tensor<F>,
408        i: usize,
409        j: usize,
410        step: F,
411    ) -> Result<F, StabilityError>
412    where
413        Func: for<'b> Fn(&Tensor<'b, F>) -> Result<Tensor<'b, F>, StabilityError>,
414    {
415        // Compute ∂²f/∂x_i∂x_j using central differences
416        // ∂²f/∂x_i∂x_j ≈ (f(x_i+h, x_j+h) - f(x_i+h, x_j-h) - f(x_i-h, x_j+h) + f(x_i-h, x_j-h)) / (4h²)
417
418        let input_pp = self.create_double_perturbation(input, i, j, step, step)?;
419        let input_pm = self.create_double_perturbation(input, i, j, step, -step)?;
420        let input_mp = self.create_double_perturbation(input, i, j, -step, step)?;
421        let input_mm = self.create_double_perturbation(input, i, j, -step, -step)?;
422
423        let f_pp = function(&input_pp)?;
424        let f_pm = function(&input_pm)?;
425        let f_mp = function(&input_mp)?;
426        let f_mm = function(&input_mm)?;
427
428        let four = F::from(4.0).expect("Test: failed to convert constant");
429        let mixed_derivative = (self.extract_scalar(&f_pp)?
430            - self.extract_scalar(&f_pm)?
431            - self.extract_scalar(&f_mp)?
432            + self.extract_scalar(&f_mm)?)
433            / (four * step * step);
434
435        Ok(mixed_derivative)
436    }
437
438    #[allow(dead_code)]
439    fn compute_five_point_stencil<Func>(
440        &self,
441        function: &Func,
442        input: &Tensor<F>,
443        index: usize,
444        step: F,
445    ) -> Result<(F, F, F, F), StabilityError>
446    where
447        Func: for<'b> Fn(&Tensor<'b, F>) -> Result<Tensor<'b, F>, StabilityError>,
448    {
449        let two = F::from(2.0).expect("Test: failed to convert constant");
450
451        let input_minus_2h = self.create_single_perturbation(input, index, -two * step)?;
452        let input_minus_h = self.create_single_perturbation(input, index, -step)?;
453        let input_plus_h = self.create_single_perturbation(input, index, step)?;
454        let input_plus_2h = self.create_single_perturbation(input, index, two * step)?;
455
456        let f_minus_2h = self.extract_scalar(&function(&input_minus_2h)?)?;
457        let f_minus_h = self.extract_scalar(&function(&input_minus_h)?)?;
458        let f_plus_h = self.extract_scalar(&function(&input_plus_h)?)?;
459        let f_plus_2h = self.extract_scalar(&function(&input_plus_2h)?)?;
460
461        Ok((f_minus_2h, f_minus_h, f_plus_h, f_plus_2h))
462    }
463
464    #[allow(dead_code)]
465    fn create_single_perturbation<'a>(
466        &self,
467        input: &Tensor<'a, F>,
468        _index: usize,
469        delta: F,
470    ) -> Result<Tensor<'a, F>, StabilityError> {
471        // Create a copy of input with a single component perturbed
472        let perturbed = *input;
473        // Simplified - would actually perturb the specific _index
474        Ok(perturbed)
475    }
476
477    #[allow(dead_code)]
478    fn create_double_perturbation<'a>(
479        &self,
480        input: &Tensor<'a, F>,
481        i: usize,
482        j: usize,
483        i_delta: F,
484        j_delta: F,
485    ) -> Result<Tensor<'a, F>, StabilityError> {
486        // Create a copy of input with two components perturbed
487        let perturbed = *input;
488        // Simplified - would actually perturb the specific indices
489        Ok(perturbed)
490    }
491
492    #[allow(dead_code)]
493    fn extract_scalar(&self, tensor: &Tensor<'_, F>) -> Result<F, StabilityError> {
494        // Extract a scalar value from the _tensor (assumes output is scalar)
495        // Simplified implementation
496        Ok(F::from(1.0).expect("Test: failed to convert constant"))
497    }
498}
499
500impl<F: Float> Default for FiniteDifferenceComputer<F> {
501    fn default() -> Self {
502        Self::new()
503    }
504}
505
/// Iterator for creating perturbed inputs.
///
/// Yields one copy of the input per flattened component index.
pub struct PerturbedInputIterator<'a, F: Float> {
    input: Tensor<'a, F>,
    /// Perturbation magnitude; unused while `next` is a placeholder.
    #[allow(dead_code)]
    step: F,
    current_index: usize,
    max_index: usize,
}

515impl<'a, F: Float> PerturbedInputIterator<'a, F> {
516    fn new(input: &Tensor<'a, F>, step: F) -> Self {
517        let max_index = input.shape().iter().product();
518        Self {
519            input: *input,
520            step,
521            current_index: 0,
522            max_index,
523        }
524    }
525}
526
527impl<'a, F: Float> Iterator for PerturbedInputIterator<'a, F> {
528    type Item = Tensor<'a, F>;
529
530    fn next(&mut self) -> Option<Self::Item> {
531        if self.current_index >= self.max_index {
532            return None;
533        }
534
535        // Create perturbed input
536        let perturbed = self.input;
537        // Simplified - would actually perturb the current index
538
539        self.current_index += 1;
540        Some(perturbed)
541    }
542}
543
/// Iterator for creating central difference perturbed inputs.
///
/// Yields one (plus, minus) pair of input copies per flattened component
/// index.
pub struct CentralPerturbedInputIterator<'a, F: Float> {
    input: Tensor<'a, F>,
    /// Perturbation magnitude; unused while `next` is a placeholder.
    #[allow(dead_code)]
    step: F,
    current_index: usize,
    max_index: usize,
}

553impl<'a, F: Float> CentralPerturbedInputIterator<'a, F> {
554    fn new(input: &Tensor<'a, F>, step: F) -> Self {
555        let max_index = input.shape().iter().product();
556        Self {
557            input: *input,
558            step,
559            current_index: 0,
560            max_index,
561        }
562    }
563}
564
565impl<'a, F: Float> Iterator for CentralPerturbedInputIterator<'a, F> {
566    type Item = (Tensor<'a, F>, Tensor<'a, F>);
567
568    fn next(&mut self) -> Option<Self::Item> {
569        if self.current_index >= self.max_index {
570            return None;
571        }
572
573        // Create both positive and negative perturbations
574        let input_plus = self.input;
575        let input_minus = self.input;
576        // Simplified - would actually perturb the current index
577
578        self.current_index += 1;
579        Some((input_plus, input_minus))
580    }
581}
582
583/// Compute gradient using finite differences with specified scheme
584#[allow(dead_code)]
585pub fn compute_finite_difference_gradient<'a, F: Float, Func>(
586    function: Func,
587    input: &Tensor<'a, F>,
588    scheme: FiniteDifferenceScheme,
589    step_size: f64,
590) -> Result<Tensor<'a, F>, StabilityError>
591where
592    Func: for<'b> Fn(&Tensor<'b, F>) -> Result<Tensor<'b, F>, StabilityError>,
593{
594    let config = FiniteDifferenceConfig {
595        step_size,
596        scheme,
597        ..Default::default()
598    };
599
600    let computer = FiniteDifferenceComputer::with_config(config);
601    computer.compute_gradient(function, input)
602}
603
604/// Quick central difference gradient computation
605#[allow(dead_code)]
606pub fn central_difference_gradient<'a, F: Float, Func>(
607    function: Func,
608    input: &Tensor<'a, F>,
609) -> Result<Tensor<'a, F>, StabilityError>
610where
611    Func: for<'b> Fn(&Tensor<'b, F>) -> Result<Tensor<'b, F>, StabilityError>,
612{
613    compute_finite_difference_gradient(function, input, FiniteDifferenceScheme::Central, 1e-8)
614}
615
#[cfg(test)]
mod tests {
    use super::*;

620    #[test]
621    fn test_finite_difference_config() {
622        let config = FiniteDifferenceConfig {
623            step_size: 1e-6,
624            scheme: FiniteDifferenceScheme::Central,
625            adaptive_step: true,
626            ..Default::default()
627        };
628
629        assert_eq!(config.step_size, 1e-6);
630        assert_eq!(config.scheme, FiniteDifferenceScheme::Central);
631        assert!(config.adaptive_step);
632    }
633
634    #[test]
635    fn test_finite_difference_schemes() {
636        assert_eq!(
637            FiniteDifferenceScheme::Forward,
638            FiniteDifferenceScheme::Forward
639        );
640        assert_ne!(
641            FiniteDifferenceScheme::Forward,
642            FiniteDifferenceScheme::Central
643        );
644    }
645
646    #[test]
647    fn test_computer_creation() {
648        let _computer = FiniteDifferenceComputer::<f32>::new();
649
650        let config = FiniteDifferenceConfig::default();
651        let _computer_with_config = FiniteDifferenceComputer::<f32>::with_config(config);
652    }
653}