1use std::collections::HashMap;
23
/// Settings that control numerical gradient verification.
#[derive(Debug, Clone)]
pub struct GradCheckConfig {
    /// Step size used for finite-difference perturbations.
    pub epsilon: f64,
    /// Largest relative error still considered a match.
    pub rel_tolerance: f64,
    /// Largest absolute error still considered a match.
    pub abs_tolerance: f64,
    /// When true, results are printed as parameters are checked.
    pub verbose: bool,
    /// Cap on the number of individual errors printed in reports.
    pub max_errors_to_report: usize,
}
38
39impl Default for GradCheckConfig {
40 fn default() -> Self {
41 GradCheckConfig {
42 epsilon: 1e-5,
43 rel_tolerance: 1e-3,
44 abs_tolerance: 1e-5,
45 verbose: false,
46 max_errors_to_report: 10,
47 }
48 }
49}
50
51impl GradCheckConfig {
52 pub fn strict() -> Self {
54 GradCheckConfig {
55 epsilon: 1e-6,
56 rel_tolerance: 1e-4,
57 abs_tolerance: 1e-6,
58 verbose: true,
59 max_errors_to_report: 10,
60 }
61 }
62
63 pub fn relaxed() -> Self {
65 GradCheckConfig {
66 epsilon: 1e-4,
67 rel_tolerance: 1e-2,
68 abs_tolerance: 1e-4,
69 verbose: false,
70 max_errors_to_report: 10,
71 }
72 }
73
74 pub fn with_verbose(mut self, verbose: bool) -> Self {
76 self.verbose = verbose;
77 self
78 }
79
80 pub fn with_epsilon(mut self, epsilon: f64) -> Self {
82 self.epsilon = epsilon;
83 self
84 }
85
86 pub fn with_rel_tolerance(mut self, tolerance: f64) -> Self {
88 self.rel_tolerance = tolerance;
89 self
90 }
91
92 pub fn with_abs_tolerance(mut self, tolerance: f64) -> Self {
94 self.abs_tolerance = tolerance;
95 self
96 }
97}
98
99#[derive(Debug, Clone)]
101pub struct GradCheckResult {
102 pub num_params: usize,
104 pub num_errors: usize,
106 pub max_error: f64,
108 pub max_rel_error: f64,
110 pub avg_error: f64,
112 pub passed: bool,
114 pub errors: Vec<GradientError>,
116}
117
118impl GradCheckResult {
119 pub fn new(num_params: usize) -> Self {
121 GradCheckResult {
122 num_params,
123 num_errors: 0,
124 max_error: 0.0,
125 max_rel_error: 0.0,
126 avg_error: 0.0,
127 passed: true,
128 errors: Vec::new(),
129 }
130 }
131
132 pub fn add_error(&mut self, error: GradientError) {
134 self.num_errors += 1;
135 self.max_error = self.max_error.max(error.abs_error);
136 self.max_rel_error = self.max_rel_error.max(error.rel_error);
137 self.passed = false;
138 self.errors.push(error);
139 }
140
141 pub fn finalize(mut self) -> Self {
143 if !self.errors.is_empty() {
144 let total_error: f64 = self.errors.iter().map(|e| e.abs_error).sum();
145 self.avg_error = total_error / self.errors.len() as f64;
146 }
147 self
148 }
149
150 pub fn summary(&self) -> String {
152 format!(
153 "Gradient Check: {} params, {} errors, max_error={:.2e}, max_rel_error={:.2e}, avg_error={:.2e}, {}",
154 self.num_params,
155 self.num_errors,
156 self.max_error,
157 self.max_rel_error,
158 self.avg_error,
159 if self.passed { "PASSED" } else { "FAILED" }
160 )
161 }
162
163 pub fn print_errors(&self, max_to_print: usize) {
165 if self.errors.is_empty() {
166 println!("✓ All gradients passed!");
167 return;
168 }
169
170 println!("\n✗ Gradient errors found:");
171 for (i, error) in self.errors.iter().take(max_to_print).enumerate() {
172 println!(
173 " [{}] Param {}: analytical={:.6e}, numerical={:.6e}, abs_err={:.2e}, rel_err={:.2e}",
174 i + 1,
175 error.param_id,
176 error.analytical_grad,
177 error.numerical_grad,
178 error.abs_error,
179 error.rel_error
180 );
181 }
182
183 if self.errors.len() > max_to_print {
184 println!(" ... and {} more errors", self.errors.len() - max_to_print);
185 }
186 }
187}
188
/// A single analytical-vs-numerical gradient mismatch.
#[derive(Debug, Clone)]
pub struct GradientError {
    /// Identifier of the parameter being checked.
    pub param_id: String,
    /// Component index within the parameter vector.
    pub index: usize,
    /// Gradient value produced by the analytical (backprop) path.
    pub analytical_grad: f64,
    /// Gradient value produced by finite differences.
    pub numerical_grad: f64,
    /// |analytical - numerical|.
    pub abs_error: f64,
    /// Absolute error scaled by |numerical| (or the raw absolute error
    /// when the numerical value is near zero).
    pub rel_error: f64,
}
205
206impl GradientError {
207 pub fn new(param_id: String, index: usize, analytical: f64, numerical: f64) -> Self {
209 let abs_error = (analytical - numerical).abs();
210 let rel_error = if numerical.abs() > 1e-10 {
211 abs_error / numerical.abs()
212 } else {
213 abs_error
214 };
215
216 GradientError {
217 param_id,
218 index,
219 analytical_grad: analytical,
220 numerical_grad: numerical,
221 abs_error,
222 rel_error,
223 }
224 }
225
226 pub fn exceeds_tolerance(&self, config: &GradCheckConfig) -> bool {
228 self.abs_error > config.abs_tolerance && self.rel_error > config.rel_tolerance
229 }
230}
231
/// Central-difference gradient: grad[i] ≈ (f(x+εe_i) - f(x-εe_i)) / 2ε.
/// Second-order accurate; costs two function evaluations per component.
pub fn numerical_gradient_central(
    forward_fn: impl Fn(&[f64]) -> f64,
    x: &[f64],
    epsilon: f64,
) -> Vec<f64> {
    (0..x.len())
        .map(|i| {
            // Perturb only component i, reusing one scratch vector for
            // both the forward and backward evaluation points.
            let mut shifted = x.to_vec();
            shifted[i] += epsilon;
            let f_plus = forward_fn(&shifted);
            shifted[i] = x[i] - epsilon;
            let f_minus = forward_fn(&shifted);
            (f_plus - f_minus) / (2.0 * epsilon)
        })
        .collect()
}
260
/// Forward-difference gradient: grad[i] ≈ (f(x+εe_i) - f(x)) / ε.
/// Only first-order accurate, but needs just one extra evaluation per
/// component given the precomputed base value `f_x`.
pub fn numerical_gradient_forward(
    forward_fn: impl Fn(&[f64]) -> f64,
    x: &[f64],
    f_x: f64,
    epsilon: f64,
) -> Vec<f64> {
    (0..x.len())
        .map(|i| {
            let mut shifted = x.to_vec();
            shifted[i] += epsilon;
            (forward_fn(&shifted) - f_x) / epsilon
        })
        .collect()
}
285
/// Five-point (fourth-order) central-difference gradient:
/// grad[i] ≈ (-f(x+2ε) + 8f(x+ε) - 8f(x-ε) + f(x-2ε)) / 12ε.
/// Four evaluations per component in exchange for O(ε⁴) truncation error.
pub fn numerical_gradient_fourth_order(
    forward_fn: impl Fn(&[f64]) -> f64,
    x: &[f64],
    epsilon: f64,
) -> Vec<f64> {
    (0..x.len())
        .map(|i| {
            // Evaluate f with component i shifted by `offset`.
            let eval = |offset: f64| {
                let mut point = x.to_vec();
                point[i] += offset;
                forward_fn(&point)
            };
            let f_p2 = eval(2.0 * epsilon);
            let f_p1 = eval(epsilon);
            let f_m1 = eval(-epsilon);
            let f_m2 = eval(-(2.0 * epsilon));
            (-f_p2 + 8.0 * f_p1 - 8.0 * f_m1 + f_m2) / (12.0 * epsilon)
        })
        .collect()
}
326
/// Richardson extrapolation of the central difference: combines estimates
/// at step ε and ε/2 as (4·g(ε/2) - g(ε)) / 3, cancelling the leading
/// O(ε²) error term of each.
pub fn numerical_gradient_richardson(
    forward_fn: impl Fn(&[f64]) -> f64,
    x: &[f64],
    epsilon: f64,
) -> Vec<f64> {
    // Central-difference gradient at step size h.
    let central = |h: f64| -> Vec<f64> {
        (0..x.len())
            .map(|i| {
                let mut shifted = x.to_vec();
                shifted[i] += h;
                let f_plus = forward_fn(&shifted);
                shifted[i] = x[i] - h;
                let f_minus = forward_fn(&shifted);
                (f_plus - f_minus) / (2.0 * h)
            })
            .collect()
    };

    let coarse = central(epsilon);
    let fine = central(epsilon / 2.0);

    fine.iter()
        .zip(coarse.iter())
        .map(|(&g_half, &g_full)| (4.0 * g_half - g_full) / 3.0)
        .collect()
}
350
/// NOTE(review): despite its name, this is NOT a true complex-step
/// derivative — `f64` carries no imaginary part. It is a plain central
/// difference taken with a very small real step (`epsilon * 1e-8`), which
/// is susceptible to floating-point cancellation; expect lower accuracy
/// than `numerical_gradient_central`. Behavior kept as-is for
/// compatibility with existing callers and tests.
pub fn numerical_gradient_complex_step(
    forward_fn: impl Fn(&[f64]) -> f64,
    x: &[f64],
    epsilon: f64,
) -> Vec<f64> {
    // Shrink the step dramatically to mimic the vanishing-step idea.
    let step = epsilon * 1e-8;

    (0..x.len())
        .map(|i| {
            let mut shifted = x.to_vec();
            shifted[i] += step;
            let f_plus = forward_fn(&shifted);
            shifted[i] = x[i] - step;
            let f_minus = forward_fn(&shifted);
            (f_plus - f_minus) / (2.0 * step)
        })
        .collect()
}
398
/// Tries several step sizes and keeps the central-difference gradient whose
/// components have the smallest variance — a cheap heuristic for picking a
/// numerically stable step without caller tuning.
pub fn numerical_gradient_adaptive(forward_fn: impl Fn(&[f64]) -> f64, x: &[f64]) -> Vec<f64> {
    // Central-difference gradient at step size h.
    let central = |h: f64| -> Vec<f64> {
        (0..x.len())
            .map(|i| {
                let mut shifted = x.to_vec();
                shifted[i] += h;
                let f_plus = forward_fn(&shifted);
                shifted[i] = x[i] - h;
                let f_minus = forward_fn(&shifted);
                (f_plus - f_minus) / (2.0 * h)
            })
            .collect()
    };

    let mut best_grad: Vec<f64> = Vec::new();
    let mut min_variance = f64::MAX;

    for &eps in &[1e-3, 1e-4, 1e-5, 1e-6, 1e-7] {
        let grad = central(eps);
        if grad.is_empty() {
            continue;
        }
        let mean = grad.iter().sum::<f64>() / grad.len() as f64;
        let variance =
            grad.iter().map(|&g| (g - mean).powi(2)).sum::<f64>() / grad.len() as f64;

        // `best_grad.is_empty()` guards the case where every variance
        // compares false (e.g. NaN), so we still return some estimate.
        if variance < min_variance || best_grad.is_empty() {
            min_variance = variance;
            best_grad = grad;
        }
    }

    best_grad
}
427
428pub fn compare_gradients(
430 param_id: String,
431 analytical: &[f64],
432 numerical: &[f64],
433 config: &GradCheckConfig,
434) -> Vec<GradientError> {
435 assert_eq!(analytical.len(), numerical.len());
436
437 let mut errors = Vec::new();
438
439 for (i, (&a, &n)) in analytical.iter().zip(numerical.iter()).enumerate() {
440 let error = GradientError::new(param_id.clone(), i, a, n);
441
442 if error.exceeds_tolerance(config) {
443 errors.push(error);
444 }
445 }
446
447 errors
448}
449
450pub struct GradientChecker {
452 config: GradCheckConfig,
453 results: HashMap<String, GradCheckResult>,
454}
455
456impl GradientChecker {
457 pub fn new(config: GradCheckConfig) -> Self {
459 GradientChecker {
460 config,
461 results: HashMap::new(),
462 }
463 }
464
465 pub fn with_defaults() -> Self {
467 Self::new(GradCheckConfig::default())
468 }
469
470 pub fn check_parameter(
472 &mut self,
473 param_id: String,
474 forward_fn: impl Fn(&[f64]) -> f64,
475 x: &[f64],
476 analytical_grad: &[f64],
477 ) -> GradCheckResult {
478 let numerical_grad = numerical_gradient_central(&forward_fn, x, self.config.epsilon);
480
481 let errors = compare_gradients(
483 param_id.clone(),
484 analytical_grad,
485 &numerical_grad,
486 &self.config,
487 );
488
489 let mut result = GradCheckResult::new(x.len());
491 for error in errors {
492 result.add_error(error);
493 }
494 let result = result.finalize();
495
496 if self.config.verbose {
497 println!("Checking parameter '{}':", param_id);
498 println!(" {}", result.summary());
499 if !result.passed {
500 result.print_errors(self.config.max_errors_to_report);
501 }
502 }
503
504 self.results.insert(param_id, result.clone());
505 result
506 }
507
508 pub fn results(&self) -> &HashMap<String, GradCheckResult> {
510 &self.results
511 }
512
513 pub fn all_passed(&self) -> bool {
515 self.results.values().all(|r| r.passed)
516 }
517
518 pub fn total_errors(&self) -> usize {
520 self.results.values().map(|r| r.num_errors).sum()
521 }
522
523 pub fn print_summary(&self) {
525 println!("\n=== Gradient Check Summary ===");
526 for (param_id, result) in &self.results {
527 println!("{}: {}", param_id, result.summary());
528 }
529 println!(
530 "\nTotal: {} parameters, {} errors",
531 self.results.len(),
532 self.total_errors()
533 );
534
535 if self.all_passed() {
536 println!("✓ All gradient checks PASSED");
537 } else {
538 println!("✗ Some gradient checks FAILED");
539 }
540 }
541}
542
543pub fn quick_check(
545 forward_fn: impl Fn(&[f64]) -> f64,
546 x: &[f64],
547 analytical_grad: &[f64],
548) -> Result<(), String> {
549 let config = GradCheckConfig::default();
550 let numerical = numerical_gradient_central(&forward_fn, x, config.epsilon);
551
552 let errors = compare_gradients(
553 "quick_check".to_string(),
554 analytical_grad,
555 &numerical,
556 &config,
557 );
558
559 if errors.is_empty() {
560 Ok(())
561 } else {
562 let mut result = GradCheckResult::new(x.len());
563 for error in errors {
564 result.add_error(error);
565 }
566 Err(result.finalize().summary())
567 }
568}
569
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_grad_check_config_default() {
        let config = GradCheckConfig::default();
        assert!(config.epsilon > 0.0);
        assert!(config.rel_tolerance > 0.0);
        assert!(config.abs_tolerance > 0.0);
    }

    #[test]
    fn test_grad_check_config_strict() {
        let strict = GradCheckConfig::strict();
        let default = GradCheckConfig::default();
        assert!(strict.epsilon <= default.epsilon);
        assert!(strict.rel_tolerance <= default.rel_tolerance);
    }

    #[test]
    fn test_grad_check_config_builder() {
        let config = GradCheckConfig::default()
            .with_epsilon(1e-4)
            .with_verbose(true)
            .with_rel_tolerance(1e-2);

        assert_eq!(config.epsilon, 1e-4);
        assert!(config.verbose);
        assert_eq!(config.rel_tolerance, 1e-2);
    }

    #[test]
    fn test_numerical_gradient_simple() {
        // d/dx x^2 = 2x = 6 at x = 3.
        let f = |x: &[f64]| x[0] * x[0];
        let grad = numerical_gradient_central(f, &[3.0], 1e-5);
        assert!((grad[0] - 6.0).abs() < 1e-4);
    }

    #[test]
    fn test_numerical_gradient_multivariate() {
        let f = |xy: &[f64]| xy[0] * xy[0] + xy[1] * xy[1];
        let grad = numerical_gradient_central(f, &[3.0, 4.0], 1e-5);
        assert!((grad[0] - 6.0).abs() < 1e-4);
        assert!((grad[1] - 8.0).abs() < 1e-4);
    }

    #[test]
    fn test_gradient_error_creation() {
        let error = GradientError::new("param1".to_string(), 0, 1.0, 1.01);

        assert_eq!(error.param_id, "param1");
        assert_eq!(error.index, 0);
        assert_eq!(error.analytical_grad, 1.0);
        assert_eq!(error.numerical_grad, 1.01);
        assert!(error.abs_error > 0.0);
        assert!(error.rel_error > 0.0);
    }

    #[test]
    fn test_gradient_error_exceeds_tolerance() {
        let config = GradCheckConfig::default();

        let big = GradientError::new("p1".to_string(), 0, 1.0, 2.0);
        assert!(big.exceeds_tolerance(&config));

        let tiny = GradientError::new("p2".to_string(), 0, 1.0, 1.0000001);
        assert!(!tiny.exceeds_tolerance(&config));
    }

    #[test]
    fn test_grad_check_result() {
        let mut result = GradCheckResult::new(10);
        assert!(result.passed);
        assert_eq!(result.num_errors, 0);

        result.add_error(GradientError::new("p1".to_string(), 0, 1.0, 2.0));
        assert!(!result.passed);
        assert_eq!(result.num_errors, 1);

        let final_result = result.finalize();
        assert!(final_result.avg_error > 0.0);
    }

    #[test]
    fn test_compare_gradients() {
        let config = GradCheckConfig::default();
        let analytical = vec![1.0, 2.0, 3.0];

        // Identical gradients produce no errors.
        let errors = compare_gradients("test".to_string(), &analytical, &[1.0, 2.0, 3.0], &config);
        assert_eq!(errors.len(), 0);

        // A large mismatch in one component is flagged.
        let errors2 = compare_gradients("test".to_string(), &analytical, &[1.0, 2.5, 3.0], &config);
        assert!(!errors2.is_empty());
    }

    #[test]
    fn test_gradient_checker() {
        let mut checker = GradientChecker::new(GradCheckConfig::default());

        let f = |x: &[f64]| x[0] * x[0];
        let result = checker.check_parameter("x".to_string(), f, &[3.0], &[6.0]);
        assert!(result.passed);
        assert!(checker.all_passed());
    }

    #[test]
    fn test_quick_check() {
        let f = |x: &[f64]| x[0] * x[0];
        let x = vec![3.0];

        assert!(quick_check(f, &x, &[6.0]).is_ok());
        assert!(quick_check(f, &x, &[7.0]).is_err());
    }

    #[test]
    fn test_forward_gradient() {
        let f = |x: &[f64]| x[0] * x[0];
        let x = vec![3.0];
        let f_x = f(&x);
        let grad = numerical_gradient_forward(f, &x, f_x, 1e-5);
        assert!((grad[0] - 6.0).abs() < 1e-3);
    }

    #[test]
    fn test_fourth_order_gradient() {
        // d/dx x^3 = 3x^2 = 12 at x = 2.
        let f = |x: &[f64]| x[0].powi(3);
        let grad = numerical_gradient_fourth_order(f, &[2.0], 1e-3);
        assert!((grad[0] - 12.0).abs() < 1e-5);
    }

    #[test]
    fn test_fourth_order_multivariate() {
        let f = |xy: &[f64]| xy[0].powi(3) + xy[1].powi(3);
        let grad = numerical_gradient_fourth_order(f, &[2.0, 3.0], 1e-3);
        assert!((grad[0] - 12.0).abs() < 1e-5);
        assert!((grad[1] - 27.0).abs() < 1e-5);
    }

    #[test]
    fn test_richardson_extrapolation() {
        // d/dx x^4 = 4x^3 = 32 at x = 2.
        let f = |x: &[f64]| x[0].powi(4);
        let grad = numerical_gradient_richardson(f, &[2.0], 1e-3);
        assert!((grad[0] - 32.0).abs() < 1e-6);
    }

    #[test]
    fn test_richardson_multivariate() {
        let f = |xy: &[f64]| xy[0].powi(4) + xy[1].powi(4);
        let grad = numerical_gradient_richardson(f, &[2.0, 1.5], 1e-3);
        assert!((grad[0] - 32.0).abs() < 1e-6);
        assert!((grad[1] - 13.5).abs() < 1e-6);
    }

    #[test]
    fn test_complex_step_approximation() {
        // d/dx (x^2 + 2x + 1) = 2x + 2 = 8 at x = 3; loose tolerance since
        // the tiny-step difference suffers cancellation.
        let f = |x: &[f64]| x[0] * x[0] + 2.0 * x[0] + 1.0;
        let grad = numerical_gradient_complex_step(f, &[3.0], 1e-5);
        assert!((grad[0] - 8.0).abs() < 0.1);
    }

    #[test]
    fn test_adaptive_gradient() {
        let f = |x: &[f64]| x[0] * x[0];
        let grad = numerical_gradient_adaptive(f, &[3.0]);
        assert!((grad[0] - 6.0).abs() < 1e-4);
    }

    #[test]
    fn test_adaptive_multivariate() {
        let f = |xyz: &[f64]| xyz[0] * xyz[0] + xyz[1] * xyz[1] + xyz[2] * xyz[2];
        let grad = numerical_gradient_adaptive(f, &[1.0, 2.0, 3.0]);
        assert!((grad[0] - 2.0).abs() < 1e-4);
        assert!((grad[1] - 4.0).abs() < 1e-4);
        assert!((grad[2] - 6.0).abs() < 1e-4);
    }

    #[test]
    fn test_gradient_method_comparison() {
        // Higher-order methods should be progressively more accurate
        // for a smooth function: d/dx sin(x) = cos(x).
        let f = |x: &[f64]| x[0].sin();
        let x = vec![1.0_f64];
        let expected = 1.0_f64.cos();

        let grad_central = numerical_gradient_central(f, &x, 1e-5);
        let grad_fourth = numerical_gradient_fourth_order(f, &x, 1e-3);
        let grad_richardson = numerical_gradient_richardson(f, &x, 1e-3);

        assert!((grad_central[0] - expected).abs() < 1e-5);
        assert!((grad_fourth[0] - expected).abs() < 1e-6);
        assert!((grad_richardson[0] - expected).abs() < 1e-7);
    }

    #[test]
    fn test_gradient_stability_near_zero() {
        let f = |x: &[f64]| x[0] * x[0] + 1e-10;
        let x = vec![1e-8_f64];
        let expected = 2.0 * 1e-8;

        let grad = numerical_gradient_adaptive(f, &x);
        assert!((grad[0] - expected).abs() < 1e-9);
    }

    #[test]
    fn test_gradient_nonpolynomial() {
        let f = |x: &[f64]| x[0].exp();
        let x = vec![1.0_f64];
        let expected = 1.0_f64.exp();

        let grad_fourth = numerical_gradient_fourth_order(f, &x, 1e-4);
        assert!((grad_fourth[0] - expected).abs() < 1e-6);

        let grad_richardson = numerical_gradient_richardson(f, &x, 1e-4);
        assert!((grad_richardson[0] - expected).abs() < 1e-7);
    }
}