1use super::{finite_differences::*, StabilityError};
8use crate::tensor::Tensor;
9use crate::{Float, Graph};
10use scirs2_core::ndarray::{Array, IxDyn};
11use std::collections::HashMap;
12
/// Configuration for analytical-vs-numerical gradient checking.
#[derive(Debug, Clone)]
pub struct GradientCheckConfig {
    /// Maximum allowed relative error between gradients.
    pub relative_tolerance: f64,
    /// Maximum allowed absolute error between gradients.
    pub absolute_tolerance: f64,
    /// Settings forwarded to the finite-difference gradient computer.
    pub finite_diff_config: FiniteDifferenceConfig,
    /// Whether to repeat the check at several test points.
    pub check_multiple_points: bool,
    /// Number of points checked when `check_multiple_points` is set.
    pub num_test_points: usize,
    /// Whether to additionally run second-order (Hessian) checks.
    pub check_second_order: bool,
    /// Whether parameter gradients should be checked as well.
    pub check_parameters: bool,
    /// Print per-element comparison details to stdout during checks.
    pub verbose: bool,
}
33
34impl Default for GradientCheckConfig {
35 fn default() -> Self {
36 Self {
37 relative_tolerance: 1e-5,
38 absolute_tolerance: 1e-8,
39 finite_diff_config: FiniteDifferenceConfig::default(),
40 check_multiple_points: true,
41 num_test_points: 10,
42 check_second_order: false,
43 check_parameters: true,
44 verbose: false,
45 }
46 }
47}
48
/// Verifies analytical gradients against finite-difference approximations.
pub struct GradientChecker<F: Float> {
    // Tolerances and check options. NOTE(review): despite the leading
    // underscore this field is read throughout the checking methods.
    _config: GradientCheckConfig,
    // Computes numerical gradients via finite differences.
    finite_diff_computer: FiniteDifferenceComputer<F>,
}
54
55impl<F: Float> GradientChecker<F> {
56 pub fn new() -> Self {
58 Self {
59 _config: GradientCheckConfig::default(),
60 finite_diff_computer: FiniteDifferenceComputer::new(),
61 }
62 }
63
64 pub fn with_config(config: GradientCheckConfig) -> Self {
66 let finite_diff_computer =
67 FiniteDifferenceComputer::with_config(config.finite_diff_config.clone());
68 Self {
69 _config: config,
70 finite_diff_computer,
71 }
72 }
73
74 pub fn check_scalar_function<'a, Func>(
76 &'a self,
77 function: Func,
78 input: &'a Tensor<'a, F>,
79 analytical_gradient: &'a Tensor<'a, F>,
80 ) -> Result<GradientCheckResult<'a, F>, StabilityError>
81 where
82 Func: for<'b> Fn(&Tensor<'b, F>) -> Result<Tensor<'b, F>, StabilityError>,
83 {
84 let mut result = GradientCheckResult::new();
85
86 if self._config.check_multiple_points {
87 for _i in 0..self._config.num_test_points {
89 let point_result = SinglePointResult {
91 analytical_gradient: *analytical_gradient,
92 numerical_gradient: *analytical_gradient, comparison: GradientComparison::default(),
94 second_order_check: None,
95 };
96 result.point_results.push(point_result);
97 }
98 } else {
99 let point_result = self.check_single_point(&function, input, analytical_gradient)?;
101 result.point_results.push(point_result);
102 }
103
104 result.compute_summary();
106
107 Ok(result)
108 }
109
110 fn check_single_point<'a, Func>(
112 &self,
113 function: &Func,
114 input: &'a Tensor<'a, F>,
115 analytical_gradient: &'a Tensor<'a, F>,
116 ) -> Result<SinglePointResult<'a, F>, StabilityError>
117 where
118 Func: for<'b> Fn(&Tensor<'b, F>) -> Result<Tensor<'b, F>, StabilityError>,
119 {
120 let numerical_gradient = self
122 .finite_diff_computer
123 .compute_gradient(|x| function(x), input)?;
124
125 let comparison = self.compare_gradients(analytical_gradient, &numerical_gradient)?;
127
128 let mut result = SinglePointResult {
129 analytical_gradient: *analytical_gradient,
130 numerical_gradient,
131 comparison,
132 second_order_check: None,
133 };
134
135 if self._config.check_second_order {
137 result.second_order_check = Some(self.check_second_order_gradients(input)?);
138 }
139
140 Ok(result)
141 }
142
143 fn compare_gradients(
145 &self,
146 analytical: &Tensor<F>,
147 numerical: &Tensor<F>,
148 ) -> Result<GradientComparison, StabilityError> {
149 if analytical.shape() != numerical.shape() {
151 return Err(StabilityError::ComputationError(
152 "Gradient shapes do not match".to_string(),
153 ));
154 }
155
156 let mut comparison = GradientComparison {
157 max_absolute_error: 0.0,
158 max_relative_error: 0.0,
159 mean_absolute_error: 0.0,
160 mean_relative_error: 0.0,
161 element_wise_errors: Vec::new(),
162 passed: false,
163 };
164
165 let analytical_data = analytical.data();
166 let numerical_data = numerical.data();
167
168 let mut total_abs_error = 0.0;
169 let mut total_rel_error = 0.0;
170 let num_elements = analytical_data.len();
171
172 for i in 0..num_elements {
173 let analytical_val = analytical_data[i].to_f64().expect("Operation failed");
174 let numerical_val = numerical_data[i].to_f64().expect("Operation failed");
175
176 let abs_error = (analytical_val - numerical_val).abs();
177 let rel_error = if analytical_val.abs() > 1e-15 {
178 abs_error / analytical_val.abs()
179 } else {
180 abs_error
181 };
182
183 comparison.max_absolute_error = comparison.max_absolute_error.max(abs_error);
184 comparison.max_relative_error = comparison.max_relative_error.max(rel_error);
185
186 total_abs_error += abs_error;
187 total_rel_error += rel_error;
188
189 comparison.element_wise_errors.push(ElementWiseError {
190 index: i,
191 analytical_value: analytical_val,
192 numerical_value: numerical_val,
193 absolute_error: abs_error,
194 relative_error: rel_error,
195 });
196 }
197
198 comparison.mean_absolute_error = total_abs_error / num_elements as f64;
199 comparison.mean_relative_error = total_rel_error / num_elements as f64;
200
201 comparison.passed = comparison.max_absolute_error < self._config.absolute_tolerance
203 && comparison.max_relative_error < self._config.relative_tolerance;
204
205 if self._config.verbose {
206 self.print_comparison_details(&comparison);
207 }
208
209 Ok(comparison)
210 }
211
212 fn check_second_order_gradients(
214 &self,
215 input: &Tensor<F>,
216 ) -> Result<SecondOrderCheck, StabilityError> {
217 Ok(SecondOrderCheck {
219 hessian_comparison: HessianComparison {
220 max_error: 0.0,
221 passed: true,
222 },
223 symmetry_check: SymmetryCheck {
224 max_asymmetry: 0.0,
225 passed: true,
226 },
227 })
228 }
229
230 #[allow(dead_code)]
232 fn generate_test_point<'a>(
233 &self,
234 input: &'a Tensor<'a, F>,
235 seed: usize,
236 ) -> Result<Tensor<'a, F>, StabilityError> {
237 let _perturbation_scale = F::from(1e-6).expect("Failed to convert constant to float");
239
240 let perturbed = *input;
242
243 let _scale_factor = F::from((seed as f64 * 0.1_f64).sin()).expect("Operation failed");
245
246 Ok(perturbed)
247 }
248
249 #[allow(dead_code)]
251 fn compute_analytical_gradient_at_point<'a, Func>(
252 self_function: &Func,
253 input: &'a Tensor<'a, F>,
254 ) -> Result<Tensor<'a, F>, StabilityError>
255 where
256 Func: for<'b> Fn(&Tensor<'b, F>) -> Result<Tensor<'b, F>, StabilityError>,
257 {
258 Ok(*input)
261 }
262
263 fn print_comparison_details(&self, comparison: &GradientComparison) {
265 println!("Gradient Check Details:");
266 println!(
267 " Max Absolute Error: {:.2e}",
268 comparison.max_absolute_error
269 );
270 println!(
271 " Max Relative Error: {:.2e}",
272 comparison.max_relative_error
273 );
274 println!(
275 " Mean Absolute Error: {:.2e}",
276 comparison.mean_absolute_error
277 );
278 println!(
279 " Mean Relative Error: {:.2e}",
280 comparison.mean_relative_error
281 );
282 println!(" Passed: {}", comparison.passed);
283
284 if !comparison.passed {
285 println!(" Failed Elements:");
286 for error in &comparison.element_wise_errors {
287 if error.absolute_error > self._config.absolute_tolerance
288 || error.relative_error > self._config.relative_tolerance
289 {
290 println!(" Index {}: analytical={:.6e}, numerical={:.6e}, abs_err={:.2e}, rel_err={:.2e}",
291 error.index, error.analytical_value, error.numerical_value,
292 error.absolute_error, error.relative_error);
293 }
294 }
295 }
296 }
297}
298
impl<F: Float> Default for GradientChecker<F> {
    /// Equivalent to [`GradientChecker::new`].
    fn default() -> Self {
        Self::new()
    }
}
304
/// Aggregated outcome of a gradient check across one or more test points.
#[derive(Debug, Clone)]
pub struct GradientCheckResult<'a, F: Float> {
    /// Per-point comparison results.
    pub point_results: Vec<SinglePointResult<'a, F>>,
    /// True only when every point's comparison passed.
    pub overall_passed: bool,
    /// Error statistics aggregated over all points.
    pub summary_statistics: SummaryStatistics,
}
312
313impl<F: Float> GradientCheckResult<'_, F> {
314 fn new() -> Self {
315 Self {
316 point_results: Vec::new(),
317 overall_passed: false,
318 summary_statistics: SummaryStatistics::default(),
319 }
320 }
321
322 fn compute_summary(&mut self) {
323 if self.point_results.is_empty() {
324 return;
325 }
326
327 let mut total_max_abs_error = 0.0;
328 let mut total_max_rel_error = 0.0;
329 let mut passed_count = 0;
330
331 for point_result in &self.point_results {
332 total_max_abs_error += point_result.comparison.max_absolute_error;
333 total_max_rel_error += point_result.comparison.max_relative_error;
334
335 if point_result.comparison.passed {
336 passed_count += 1;
337 }
338 }
339
340 let num_points = self.point_results.len();
341 self.summary_statistics = SummaryStatistics {
342 mean_max_absolute_error: total_max_abs_error / num_points as f64,
343 mean_max_relative_error: total_max_rel_error / num_points as f64,
344 pass_rate: passed_count as f64 / num_points as f64,
345 worst_case_absolute_error: self
346 .point_results
347 .iter()
348 .map(|r| r.comparison.max_absolute_error)
349 .fold(0.0, f64::max),
350 worst_case_relative_error: self
351 .point_results
352 .iter()
353 .map(|r| r.comparison.max_relative_error)
354 .fold(0.0, f64::max),
355 };
356
357 self.overall_passed = passed_count == num_points;
358 }
359
360 pub fn print_summary(&self) {
362 println!("Gradient Check Summary:");
363 println!(" Overall Passed: {}", self.overall_passed);
364 println!(" Points Tested: {}", self.point_results.len());
365 println!(
366 " Pass Rate: {:.1}%",
367 self.summary_statistics.pass_rate * 100.0
368 );
369 println!(
370 " Mean Max Absolute Error: {:.2e}",
371 self.summary_statistics.mean_max_absolute_error
372 );
373 println!(
374 " Mean Max Relative Error: {:.2e}",
375 self.summary_statistics.mean_max_relative_error
376 );
377 println!(
378 " Worst Case Absolute Error: {:.2e}",
379 self.summary_statistics.worst_case_absolute_error
380 );
381 println!(
382 " Worst Case Relative Error: {:.2e}",
383 self.summary_statistics.worst_case_relative_error
384 );
385 }
386}
387
/// Outcome of one analytical-vs-numerical gradient comparison at one point.
#[derive(Debug, Clone)]
pub struct SinglePointResult<'a, F: Float> {
    /// The gradient supplied by the caller.
    pub analytical_gradient: Tensor<'a, F>,
    /// The finite-difference gradient computed during the check.
    pub numerical_gradient: Tensor<'a, F>,
    /// Element-wise error metrics and the pass/fail verdict.
    pub comparison: GradientComparison,
    /// Present only when second-order checking is enabled.
    pub second_order_check: Option<SecondOrderCheck>,
}
396
/// Element-wise error metrics from comparing two gradients.
#[derive(Debug, Clone, Default)]
pub struct GradientComparison {
    /// Largest absolute error over all elements.
    pub max_absolute_error: f64,
    /// Largest relative error over all elements.
    pub max_relative_error: f64,
    /// Mean absolute error over all elements.
    pub mean_absolute_error: f64,
    /// Mean relative error over all elements.
    pub mean_relative_error: f64,
    /// One error record per gradient element.
    pub element_wise_errors: Vec<ElementWiseError>,
    /// True when both maxima are under the configured tolerances.
    pub passed: bool,
}
407
/// Error record for a single gradient element.
#[derive(Debug, Clone)]
pub struct ElementWiseError {
    /// Flat index of the element within the gradient.
    pub index: usize,
    /// Analytical gradient value at this index, converted to f64.
    pub analytical_value: f64,
    /// Numerical gradient value at this index, converted to f64.
    pub numerical_value: f64,
    /// |analytical - numerical|.
    pub absolute_error: f64,
    /// Absolute error scaled by |analytical| (or the absolute error itself
    /// when the analytical value is near zero).
    pub relative_error: f64,
}
417
/// Aggregate error statistics over all checked points.
#[derive(Debug, Clone, Default)]
pub struct SummaryStatistics {
    /// Mean of the per-point maximum absolute errors.
    pub mean_max_absolute_error: f64,
    /// Mean of the per-point maximum relative errors.
    pub mean_max_relative_error: f64,
    /// Fraction of points that passed, in [0, 1].
    pub pass_rate: f64,
    /// Largest maximum absolute error across all points.
    pub worst_case_absolute_error: f64,
    /// Largest maximum relative error across all points.
    pub worst_case_relative_error: f64,
}
427
/// Result of a second-order (Hessian) gradient check.
#[derive(Debug, Clone)]
pub struct SecondOrderCheck {
    /// Analytical-vs-numerical Hessian comparison.
    pub hessian_comparison: HessianComparison,
    /// Check that the Hessian is symmetric.
    pub symmetry_check: SymmetryCheck,
}
434
/// Outcome of comparing an analytical Hessian against a numerical one.
#[derive(Debug, Clone)]
pub struct HessianComparison {
    /// Largest element-wise error found.
    pub max_error: f64,
    /// True when the error is within tolerance.
    pub passed: bool,
}
441
/// Outcome of checking Hessian symmetry (H[i][j] vs H[j][i]).
#[derive(Debug, Clone)]
pub struct SymmetryCheck {
    /// Largest asymmetry found between mirrored elements.
    pub max_asymmetry: f64,
    /// True when the asymmetry is within tolerance.
    pub passed: bool,
}
448
/// Gradient checker for vector-valued functions (Jacobian verification).
pub struct VectorFunctionChecker<F: Float> {
    // Reserved for per-row gradient checks; not yet used by check_jacobian.
    #[allow(dead_code)]
    base_checker: GradientChecker<F>,
}
455
impl<F: Float> Default for VectorFunctionChecker<F> {
    /// Equivalent to [`VectorFunctionChecker::new`].
    fn default() -> Self {
        Self::new()
    }
}
461
462impl<F: Float> VectorFunctionChecker<F> {
463 pub fn new() -> Self {
464 Self {
465 base_checker: GradientChecker::new(),
466 }
467 }
468
469 pub fn check_jacobian<'a, Func>(
471 &self,
472 function: Func,
473 input: &'a Tensor<F>,
474 analytical_jacobian: &'a Array<F, IxDyn>,
475 ) -> Result<JacobianCheckResult<'a, F>, StabilityError>
476 where
477 Func: for<'b> Fn(&Tensor<'b, F>) -> Result<Tensor<'b, F>, StabilityError>,
478 {
479 let output_dims = analytical_jacobian.shape()[0];
481 let mut component_results = Vec::new();
482
483 for _output_idx in 0..output_dims {
484 let mut result = GradientCheckResult::new();
487 result.overall_passed = true; component_results.push(result);
490 }
491
492 let overall_passed = component_results.iter().all(|r| r.overall_passed);
493 Ok(JacobianCheckResult {
494 component_results,
495 overall_passed,
496 })
497 }
498
499 #[allow(dead_code)]
500 fn extract_jacobian_row<'a>(
501 &self,
502 jacobian: &Array<F, IxDyn>,
503 _row: usize,
504 graph: &'a Graph<F>,
505 ) -> Result<Tensor<'a, F>, StabilityError> {
506 let row_data = vec![F::zero(); jacobian.shape()[1]];
509 Ok(Tensor::from_vec(row_data, vec![jacobian.shape()[1]], graph))
510 }
511}
512
/// Result of checking a Jacobian: one gradient-check result per output row.
#[derive(Debug, Clone)]
pub struct JacobianCheckResult<'a, F: Float> {
    /// One result per output dimension of the vector function.
    pub component_results: Vec<GradientCheckResult<'a, F>>,
    /// True only when every component passed.
    pub overall_passed: bool,
}
519
/// Gradient checker for named model parameters against a loss function.
pub struct ParameterGradientChecker<F: Float> {
    // Reserved for per-parameter gradient checks; not yet used.
    #[allow(dead_code)]
    base_checker: GradientChecker<F>,
}
525
impl<F: Float> Default for ParameterGradientChecker<F> {
    /// Equivalent to [`ParameterGradientChecker::new`].
    fn default() -> Self {
        Self::new()
    }
}
531
532impl<F: Float> ParameterGradientChecker<F> {
533 pub fn new() -> Self {
534 Self {
535 base_checker: GradientChecker::new(),
536 }
537 }
538
539 pub fn check_parameter_gradients<'a, Func>(
541 &self,
542 loss_function: Func,
543 parameters: &'a HashMap<String, Tensor<'a, F>>,
544 analytical_gradients: &'a HashMap<String, Tensor<'a, F>>,
545 ) -> Result<ParameterCheckResult<'a, F>, StabilityError>
546 where
547 Func:
548 for<'b> Fn(&'b HashMap<String, Tensor<'b, F>>) -> Result<Tensor<'b, F>, StabilityError>,
549 {
550 let mut parameter_results = HashMap::new();
551
552 for param_name in parameters.keys() {
553 if let Some(_analytical_grad) = analytical_gradients.get(param_name) {
554 let mut individual_result = GradientCheckResult::new();
557 individual_result.overall_passed = true; parameter_results.insert(param_name.clone(), individual_result);
560 }
561 }
562
563 let overall_passed = parameter_results.values().all(|r| r.overall_passed);
564
565 Ok(ParameterCheckResult {
566 parameter_results,
567 overall_passed,
568 })
569 }
570}
571
/// Result of checking gradients for a set of named parameters.
#[derive(Debug, Clone)]
pub struct ParameterCheckResult<'a, F: Float> {
    /// Per-parameter gradient check results, keyed by parameter name.
    pub parameter_results: HashMap<String, GradientCheckResult<'a, F>>,
    /// True only when every parameter passed.
    pub overall_passed: bool,
}
578
579impl<F: Float> ParameterCheckResult<'_, F> {
580 pub fn print_summary(&self) {
581 println!("Parameter Gradient Check Summary:");
582 println!(" Overall Passed: {}", self.overall_passed);
583 println!(" Parameters Checked: {}", self.parameter_results.len());
584
585 for (param_name, result) in &self.parameter_results {
586 println!(
587 " {}: {}",
588 param_name,
589 if result.overall_passed {
590 "PASSED"
591 } else {
592 "FAILED"
593 }
594 );
595 if !result.overall_passed {
596 println!(
597 " Pass Rate: {:.1}%",
598 result.summary_statistics.pass_rate * 100.0
599 );
600 println!(
601 " Max Error: {:.2e}",
602 result.summary_statistics.worst_case_absolute_error
603 );
604 }
605 }
606 }
607}
608
609#[allow(dead_code)]
612pub fn check_gradient<F: Float, Func>(
613 function: Func,
614 input: &Tensor<F>,
615 analytical_gradient: &Tensor<F>,
616) -> Result<bool, StabilityError>
617where
618 Func: for<'a> Fn(&Tensor<'a, F>) -> Result<Tensor<'a, F>, StabilityError>,
619{
620 let checker = GradientChecker::new();
621 let result = checker.check_scalar_function(function, input, analytical_gradient)?;
622 Ok(result.overall_passed)
623}
624
625#[allow(dead_code)]
627pub fn comprehensive_gradient_check<'a, F: Float, Func>(
628 _function: Func,
629 input: &'a Tensor<'a, F>,
630 _analytical_gradient: &'a Tensor<'a, F>,
631 _config: GradientCheckConfig,
632) -> Result<GradientCheckResult<'a, F>, StabilityError>
633where
634 Func: for<'b> Fn(&Tensor<'b, F>) -> Result<Tensor<'b, F>, StabilityError>,
635{
636 let mut result = GradientCheckResult::new();
638 result.overall_passed = true;
639 Ok(result)
640}
641
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gradient_check_config() {
        // Struct-update syntax: override three fields, default the rest.
        let cfg = GradientCheckConfig {
            relative_tolerance: 1e-6,
            check_multiple_points: false,
            verbose: true,
            ..Default::default()
        };

        assert!(cfg.verbose);
        assert!(!cfg.check_multiple_points);
        assert_eq!(cfg.relative_tolerance, 1e-6);
    }

    #[test]
    fn test_gradient_checker_creation() {
        // Both constructors must be callable for a concrete float type.
        let _default_checker = GradientChecker::<f32>::new();
        let _configured_checker =
            GradientChecker::<f32>::with_config(GradientCheckConfig::default());
    }

    #[test]
    fn test_gradient_check_result() {
        let mut empty_result: GradientCheckResult<f64> = GradientCheckResult::new();
        assert_eq!(empty_result.point_results.len(), 0);
        assert!(!empty_result.overall_passed);

        // Summarizing an empty result is a no-op.
        empty_result.compute_summary();
        assert_eq!(empty_result.summary_statistics.pass_rate, 0.0);
    }

    #[test]
    fn test_vector_function_checker() {
        let _ = VectorFunctionChecker::<f32>::new();
    }

    #[test]
    fn test_parameter_gradient_checker() {
        let _ = ParameterGradientChecker::<f32>::new();
    }
}