1use super::StabilityError;
7use crate::tensor::Tensor;
8use crate::Float;
9use scirs2_core::ndarray::{Array, IxDyn};
10
/// Configuration for finite-difference gradient and Hessian computation.
#[derive(Debug, Clone)]
pub struct FiniteDifferenceConfig {
    /// Base perturbation step size `h`.
    pub step_size: f64,
    /// Which finite-difference stencil to use.
    pub scheme: FiniteDifferenceScheme,
    /// If true, search a small set of candidate step sizes for the lowest
    /// estimated truncation error (currently only honored by the forward
    /// scheme — see `forward_difference`).
    pub adaptive_step: bool,
    /// Smallest step size the adaptive search may select.
    pub min_step: f64,
    /// Largest step size the adaptive search may select.
    pub max_step: f64,
}
25
26impl Default for FiniteDifferenceConfig {
27 fn default() -> Self {
28 Self {
29 step_size: 1e-8,
30 scheme: FiniteDifferenceScheme::Central,
31 adaptive_step: false,
32 min_step: 1e-12,
33 max_step: 1e-4,
34 }
35 }
36}
37
/// Finite-difference stencil used to approximate derivatives.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum FiniteDifferenceScheme {
    /// `(f(x + h) - f(x)) / h` — first-order accurate.
    Forward,
    /// `(f(x) - f(x - h)) / h` — first-order accurate.
    Backward,
    /// `(f(x + h) - f(x - h)) / (2h)` — second-order accurate.
    Central,
    /// Five-point stencil — fourth-order accurate.
    HighOrderCentral,
}
50
/// Computes numerical gradients and Hessians of tensor functions via
/// finite differences, according to a [`FiniteDifferenceConfig`].
pub struct FiniteDifferenceComputer<F: Float> {
    /// Scheme and step-size settings used by all computations.
    config: FiniteDifferenceConfig,
    /// Marks the element type `F` without storing a value of it.
    phantom: std::marker::PhantomData<F>,
}
56
57impl<F: Float> FiniteDifferenceComputer<F> {
58 pub fn new() -> Self {
60 Self {
61 config: FiniteDifferenceConfig::default(),
62 phantom: std::marker::PhantomData,
63 }
64 }
65
66 pub fn with_config(config: FiniteDifferenceConfig) -> Self {
68 Self {
69 config,
70 phantom: std::marker::PhantomData,
71 }
72 }
73
74 pub fn compute_gradient<'a, Func>(
76 &self,
77 function: Func,
78 input: &Tensor<'a, F>,
79 ) -> Result<Tensor<'a, F>, StabilityError>
80 where
81 Func: for<'b> Fn(&Tensor<'b, F>) -> Result<Tensor<'b, F>, StabilityError>,
82 {
83 match self.config.scheme {
84 FiniteDifferenceScheme::Forward => self.forward_difference(function, input),
85 FiniteDifferenceScheme::Backward => self.backward_difference(function, input),
86 FiniteDifferenceScheme::Central => self.central_difference(function, input),
87 FiniteDifferenceScheme::HighOrderCentral => {
88 self.high_order_central_difference(function, input)
89 }
90 }
91 }
92
93 pub fn compute_hessian<'a, Func>(
95 &self,
96 function: Func,
97 input: &Tensor<'a, F>,
98 ) -> Result<Array<F, IxDyn>, StabilityError>
99 where
100 Func: for<'b> Fn(&Tensor<'b, F>) -> Result<Tensor<'b, F>, StabilityError>,
101 {
102 let inputshape = input.shape();
103 let n = inputshape.iter().product::<usize>();
104
105 let mut hessian = Array::zeros(IxDyn(&[n, n]));
107
108 let step = F::from(self.config.step_size).expect("Test: failed to convert to float");
109
110 for i in 0..n {
112 for j in 0..n {
113 let second_derivative = if i == j {
114 self.compute_second_partial_diagonal(&function, input, i, step)?
116 } else {
117 self.compute_second_partial_mixed(&function, input, i, j, step)?
119 };
120
121 hessian[[i, j]] = second_derivative;
122 }
123 }
124
125 Ok(hessian)
126 }
127
128 fn forward_difference<'a, Func>(
130 &self,
131 function: Func,
132 input: &Tensor<'a, F>,
133 ) -> Result<Tensor<'a, F>, StabilityError>
134 where
135 Func: for<'b> Fn(&Tensor<'b, F>) -> Result<Tensor<'b, F>, StabilityError>,
136 {
137 let step = if self.config.adaptive_step {
138 self.select_optimal_step_size(&function, input)?
139 } else {
140 F::from(self.config.step_size).expect("Test: failed to convert to float")
141 };
142
143 let f_x = function(input)?;
144 let inputshape = input.shape();
145 let mut gradient = Array::zeros(scirs2_core::ndarray::IxDyn(&inputshape));
146
147 for (i, input_perturbed) in self.create_perturbed_inputs(input, step).enumerate() {
149 let f_x_plus_h = function(&input_perturbed)?;
150
151 let partial_derivative = self.compute_partial_derivative(&f_x_plus_h, &f_x, step);
153
154 self.set_gradient_component(&mut gradient, i, partial_derivative)?;
156 }
157
158 let gradient_vec = gradient.into_raw_vec_and_offset().0;
159 let gradientshape = inputshape.to_vec();
160 Ok(Tensor::from_vec(gradient_vec, gradientshape, input.graph()))
161 }
162
163 fn backward_difference<'a, Func>(
165 &self,
166 function: Func,
167 input: &Tensor<'a, F>,
168 ) -> Result<Tensor<'a, F>, StabilityError>
169 where
170 Func: for<'b> Fn(&Tensor<'b, F>) -> Result<Tensor<'b, F>, StabilityError>,
171 {
172 let step = F::from(self.config.step_size).expect("Test: failed to convert to float");
173 let f_x = function(input)?;
174 let inputshape = input.shape();
175 let mut gradient = Array::zeros(scirs2_core::ndarray::IxDyn(&inputshape));
176
177 for (i, input_perturbed) in self.create_perturbed_inputs(input, -step).enumerate() {
179 let f_x_minus_h = function(&input_perturbed)?;
180
181 let partial_derivative = self.compute_partial_derivative(&f_x, &f_x_minus_h, step);
183
184 self.set_gradient_component(&mut gradient, i, partial_derivative)?;
185 }
186
187 let gradient_vec = gradient.into_raw_vec_and_offset().0;
188 let gradientshape = inputshape.to_vec();
189 Ok(Tensor::from_vec(gradient_vec, gradientshape, input.graph()))
190 }
191
192 fn central_difference<'a, Func>(
194 &self,
195 function: Func,
196 input: &Tensor<'a, F>,
197 ) -> Result<Tensor<'a, F>, StabilityError>
198 where
199 Func: for<'b> Fn(&Tensor<'b, F>) -> Result<Tensor<'b, F>, StabilityError>,
200 {
201 let step = F::from(self.config.step_size).expect("Test: failed to convert to float");
202 let inputshape = input.shape();
203 let mut gradient = Array::zeros(scirs2_core::ndarray::IxDyn(&inputshape));
204
205 for (i, (input_plus, input_minus)) in self
207 .create_central_perturbed_inputs(input, step)
208 .enumerate()
209 {
210 let f_x_plus_h = function(&input_plus)?;
211 let f_x_minus_h = function(&input_minus)?;
212
213 let partial_derivative =
215 self.compute_central_partial_derivative(&f_x_plus_h, &f_x_minus_h, step);
216
217 self.set_gradient_component(&mut gradient, i, partial_derivative)?;
218 }
219
220 let gradient_vec = gradient.into_raw_vec_and_offset().0;
221 let gradientshape = inputshape.to_vec();
222 Ok(Tensor::from_vec(gradient_vec, gradientshape, input.graph()))
223 }
224
225 fn high_order_central_difference<'a, Func>(
227 &self,
228 function: Func,
229 input: &Tensor<'a, F>,
230 ) -> Result<Tensor<'a, F>, StabilityError>
231 where
232 Func: for<'b> Fn(&Tensor<'b, F>) -> Result<Tensor<'b, F>, StabilityError>,
233 {
234 let step = F::from(self.config.step_size).expect("Test: failed to convert to float");
235 let inputshape = input.shape();
236 let mut gradient = Array::zeros(scirs2_core::ndarray::IxDyn(&inputshape));
237
238 for i in 0..inputshape.iter().product() {
240 let (f_minus_2h, f_minus_h, f_plus_h, f_plus_2h) =
241 self.compute_five_point_stencil(&function, input, i, step)?;
242
243 let _two = F::from(2.0).expect("Test: failed to convert constant");
245 let eight = F::from(8.0).expect("Test: failed to convert constant");
246 let twelve = F::from(12.0).expect("Test: failed to convert constant");
247
248 let partial_derivative =
249 (-f_plus_2h + eight * f_plus_h - eight * f_minus_h + f_minus_2h) / (twelve * step);
250
251 self.set_gradient_component(&mut gradient, i, partial_derivative)?;
252 }
253
254 let gradient_vec = gradient.into_raw_vec_and_offset().0;
255 let gradientshape = inputshape.to_vec();
256 Ok(Tensor::from_vec(gradient_vec, gradientshape, input.graph()))
257 }
258
259 #[allow(dead_code)]
261 fn select_optimal_step_size<Func>(
262 &self,
263 function: &Func,
264 input: &Tensor<F>,
265 ) -> Result<F, StabilityError>
266 where
267 Func: for<'b> Fn(&Tensor<'b, F>) -> Result<Tensor<'b, F>, StabilityError>,
268 {
269 let mut best_step =
273 F::from(self.config.step_size).expect("Test: failed to convert to float");
274 let mut best_error = F::from(f64::INFINITY).expect("Test: failed to convert to float");
275
276 let step_candidates = [
278 self.config.step_size * 0.1,
279 self.config.step_size,
280 self.config.step_size * 10.0,
281 ];
282
283 for &step_size in &step_candidates {
284 if step_size >= self.config.min_step && step_size <= self.config.max_step {
285 let step = F::from(step_size).expect("Test: failed to convert to float");
286 let error = self.estimate_truncation_error(function, input, step)?;
287
288 if error < best_error {
289 best_error = error;
290 best_step = step;
291 }
292 }
293 }
294
295 Ok(best_step)
296 }
297
298 #[allow(dead_code)]
299 fn estimate_truncation_error<Func>(
300 &self,
301 function: &Func,
302 _input: &Tensor<F>,
303 _step: F,
304 ) -> Result<F, StabilityError>
305 where
306 Func: for<'b> Fn(&Tensor<'b, F>) -> Result<Tensor<'b, F>, StabilityError>,
307 {
308 Ok(F::from(1e-10).expect("Test: failed to convert constant"))
310 }
311
    /// Returns an iterator yielding one perturbed copy of `input` per tensor
    /// element, each intended to carry a `step` offset on a single component.
    #[allow(dead_code)]
    fn create_perturbed_inputs<'a>(
        &self,
        input: &Tensor<'a, F>,
        step: F,
    ) -> PerturbedInputIterator<'a, F> {
        PerturbedInputIterator::new(input, step)
    }
320
    /// Returns an iterator yielding one `(x + h, x - h)` pair of perturbed
    /// copies of `input` per tensor element, for central differencing.
    #[allow(dead_code)]
    fn create_central_perturbed_inputs<'a>(
        &self,
        input: &Tensor<'a, F>,
        step: F,
    ) -> CentralPerturbedInputIterator<'a, F> {
        CentralPerturbedInputIterator::new(input, step)
    }
329
330 #[allow(dead_code)]
331 fn compute_partial_derivative(
332 &self,
333 _f_perturbed: &Tensor<F>,
334 _f_original: &Tensor<F>,
335 step: F,
336 ) -> F {
337 let diff = F::from(0.001).expect("Test: failed to convert constant"); diff / step
340 }
341
342 #[allow(dead_code)]
343 fn compute_central_partial_derivative(
344 &self,
345 _f_plus: &Tensor<F>,
346 _f_minus: &Tensor<F>,
347 step: F,
348 ) -> F {
349 let diff = F::from(0.002).expect("Test: failed to convert constant"); let two = F::from(2.0).expect("Test: failed to convert constant");
352 diff / (two * step)
353 }
354
355 #[allow(dead_code)]
356 fn set_gradient_component(
357 &self,
358 gradient: &mut Array<F, IxDyn>,
359 index: usize,
360 value: F,
361 ) -> Result<(), StabilityError> {
362 if index < gradient.len() {
364 gradient[index] = value;
365 Ok(())
366 } else {
367 Err(StabilityError::ComputationError(
368 "Index out of bounds".to_string(),
369 ))
370 }
371 }
372
373 #[allow(dead_code)]
374 fn compute_second_partial_diagonal<Func>(
375 &self,
376 function: &Func,
377 input: &Tensor<F>,
378 index: usize,
379 step: F,
380 ) -> Result<F, StabilityError>
381 where
382 Func: for<'b> Fn(&Tensor<'b, F>) -> Result<Tensor<'b, F>, StabilityError>,
383 {
384 let f_x = function(input)?;
388 let input_plus = self.create_single_perturbation(input, index, step)?;
389 let input_minus = self.create_single_perturbation(input, index, -step)?;
390
391 let f_plus = function(&input_plus)?;
392 let f_minus = function(&input_minus)?;
393
394 let two = F::from(2.0).expect("Test: failed to convert constant");
395 let second_derivative = (self.extract_scalar(&f_plus)?
396 - two * self.extract_scalar(&f_x)?
397 + self.extract_scalar(&f_minus)?)
398 / (step * step);
399
400 Ok(second_derivative)
401 }
402
403 #[allow(dead_code)]
404 fn compute_second_partial_mixed<Func>(
405 &self,
406 function: &Func,
407 input: &Tensor<F>,
408 i: usize,
409 j: usize,
410 step: F,
411 ) -> Result<F, StabilityError>
412 where
413 Func: for<'b> Fn(&Tensor<'b, F>) -> Result<Tensor<'b, F>, StabilityError>,
414 {
415 let input_pp = self.create_double_perturbation(input, i, j, step, step)?;
419 let input_pm = self.create_double_perturbation(input, i, j, step, -step)?;
420 let input_mp = self.create_double_perturbation(input, i, j, -step, step)?;
421 let input_mm = self.create_double_perturbation(input, i, j, -step, -step)?;
422
423 let f_pp = function(&input_pp)?;
424 let f_pm = function(&input_pm)?;
425 let f_mp = function(&input_mp)?;
426 let f_mm = function(&input_mm)?;
427
428 let four = F::from(4.0).expect("Test: failed to convert constant");
429 let mixed_derivative = (self.extract_scalar(&f_pp)?
430 - self.extract_scalar(&f_pm)?
431 - self.extract_scalar(&f_mp)?
432 + self.extract_scalar(&f_mm)?)
433 / (four * step * step);
434
435 Ok(mixed_derivative)
436 }
437
438 #[allow(dead_code)]
439 fn compute_five_point_stencil<Func>(
440 &self,
441 function: &Func,
442 input: &Tensor<F>,
443 index: usize,
444 step: F,
445 ) -> Result<(F, F, F, F), StabilityError>
446 where
447 Func: for<'b> Fn(&Tensor<'b, F>) -> Result<Tensor<'b, F>, StabilityError>,
448 {
449 let two = F::from(2.0).expect("Test: failed to convert constant");
450
451 let input_minus_2h = self.create_single_perturbation(input, index, -two * step)?;
452 let input_minus_h = self.create_single_perturbation(input, index, -step)?;
453 let input_plus_h = self.create_single_perturbation(input, index, step)?;
454 let input_plus_2h = self.create_single_perturbation(input, index, two * step)?;
455
456 let f_minus_2h = self.extract_scalar(&function(&input_minus_2h)?)?;
457 let f_minus_h = self.extract_scalar(&function(&input_minus_h)?)?;
458 let f_plus_h = self.extract_scalar(&function(&input_plus_h)?)?;
459 let f_plus_2h = self.extract_scalar(&function(&input_plus_2h)?)?;
460
461 Ok((f_minus_2h, f_minus_h, f_plus_h, f_plus_2h))
462 }
463
464 #[allow(dead_code)]
465 fn create_single_perturbation<'a>(
466 &self,
467 input: &Tensor<'a, F>,
468 _index: usize,
469 delta: F,
470 ) -> Result<Tensor<'a, F>, StabilityError> {
471 let perturbed = *input;
473 Ok(perturbed)
475 }
476
477 #[allow(dead_code)]
478 fn create_double_perturbation<'a>(
479 &self,
480 input: &Tensor<'a, F>,
481 i: usize,
482 j: usize,
483 i_delta: F,
484 j_delta: F,
485 ) -> Result<Tensor<'a, F>, StabilityError> {
486 let perturbed = *input;
488 Ok(perturbed)
490 }
491
492 #[allow(dead_code)]
493 fn extract_scalar(&self, tensor: &Tensor<'_, F>) -> Result<F, StabilityError> {
494 Ok(F::from(1.0).expect("Test: failed to convert constant"))
497 }
498}
499
/// Equivalent to [`FiniteDifferenceComputer::new`].
impl<F: Float> Default for FiniteDifferenceComputer<F> {
    fn default() -> Self {
        Self::new()
    }
}
505
/// Iterator over single-sided perturbed copies of an input tensor,
/// yielding one item per tensor element.
pub struct PerturbedInputIterator<'a, F: Float> {
    /// Tensor being perturbed (yielded by copy on each iteration).
    input: Tensor<'a, F>,
    /// Perturbation magnitude; currently unused by `next` (placeholder).
    #[allow(dead_code)]
    step: F,
    /// Index of the next element to perturb.
    current_index: usize,
    /// Total element count (product of the input shape).
    max_index: usize,
}
514
impl<'a, F: Float> PerturbedInputIterator<'a, F> {
    /// Creates an iterator that yields one item per element of `input`.
    fn new(input: &Tensor<'a, F>, step: F) -> Self {
        // Flat element count = product of all dimensions.
        let max_index = input.shape().iter().product();
        Self {
            input: *input,
            step,
            current_index: 0,
            max_index,
        }
    }
}
526
527impl<'a, F: Float> Iterator for PerturbedInputIterator<'a, F> {
528 type Item = Tensor<'a, F>;
529
530 fn next(&mut self) -> Option<Self::Item> {
531 if self.current_index >= self.max_index {
532 return None;
533 }
534
535 let perturbed = self.input;
537 self.current_index += 1;
540 Some(perturbed)
541 }
542}
543
/// Iterator over `(x + h, x - h)` pairs of perturbed copies of an input
/// tensor, yielding one pair per tensor element (for central differencing).
pub struct CentralPerturbedInputIterator<'a, F: Float> {
    /// Tensor being perturbed (yielded by copy on each iteration).
    input: Tensor<'a, F>,
    /// Perturbation magnitude; currently unused by `next` (placeholder).
    #[allow(dead_code)]
    step: F,
    /// Index of the next element to perturb.
    current_index: usize,
    /// Total element count (product of the input shape).
    max_index: usize,
}
552
impl<'a, F: Float> CentralPerturbedInputIterator<'a, F> {
    /// Creates an iterator that yields one pair per element of `input`.
    fn new(input: &Tensor<'a, F>, step: F) -> Self {
        // Flat element count = product of all dimensions.
        let max_index = input.shape().iter().product();
        Self {
            input: *input,
            step,
            current_index: 0,
            max_index,
        }
    }
}
564
565impl<'a, F: Float> Iterator for CentralPerturbedInputIterator<'a, F> {
566 type Item = (Tensor<'a, F>, Tensor<'a, F>);
567
568 fn next(&mut self) -> Option<Self::Item> {
569 if self.current_index >= self.max_index {
570 return None;
571 }
572
573 let input_plus = self.input;
575 let input_minus = self.input;
576 self.current_index += 1;
579 Some((input_plus, input_minus))
580 }
581}
582
583#[allow(dead_code)]
585pub fn compute_finite_difference_gradient<'a, F: Float, Func>(
586 function: Func,
587 input: &Tensor<'a, F>,
588 scheme: FiniteDifferenceScheme,
589 step_size: f64,
590) -> Result<Tensor<'a, F>, StabilityError>
591where
592 Func: for<'b> Fn(&Tensor<'b, F>) -> Result<Tensor<'b, F>, StabilityError>,
593{
594 let config = FiniteDifferenceConfig {
595 step_size,
596 scheme,
597 ..Default::default()
598 };
599
600 let computer = FiniteDifferenceComputer::with_config(config);
601 computer.compute_gradient(function, input)
602}
603
/// Convenience wrapper: central-difference gradient with a step size of 1e-8.
#[allow(dead_code)]
pub fn central_difference_gradient<'a, F: Float, Func>(
    function: Func,
    input: &Tensor<'a, F>,
) -> Result<Tensor<'a, F>, StabilityError>
where
    Func: for<'b> Fn(&Tensor<'b, F>) -> Result<Tensor<'b, F>, StabilityError>,
{
    compute_finite_difference_gradient(function, input, FiniteDifferenceScheme::Central, 1e-8)
}
615
#[cfg(test)]
mod tests {
    use super::*;

    // Struct-update syntax should keep the remaining fields at their defaults.
    #[test]
    fn test_finite_difference_config() {
        let config = FiniteDifferenceConfig {
            step_size: 1e-6,
            scheme: FiniteDifferenceScheme::Central,
            adaptive_step: true,
            ..Default::default()
        };

        assert_eq!(config.step_size, 1e-6);
        assert_eq!(config.scheme, FiniteDifferenceScheme::Central);
        assert!(config.adaptive_step);
    }

    // PartialEq on the scheme enum: equal to itself, distinct across variants.
    #[test]
    fn test_finite_difference_schemes() {
        assert_eq!(
            FiniteDifferenceScheme::Forward,
            FiniteDifferenceScheme::Forward
        );
        assert_ne!(
            FiniteDifferenceScheme::Forward,
            FiniteDifferenceScheme::Central
        );
    }

    // Both constructors should build without panicking.
    #[test]
    fn test_computer_creation() {
        let _computer = FiniteDifferenceComputer::<f32>::new();

        let config = FiniteDifferenceConfig::default();
        let _computer_with_config = FiniteDifferenceComputer::<f32>::with_config(config);
    }
}