use scirs2_core::ndarray::{Array1, Array2, Axis, ScalarOperand};
use scirs2_core::numeric::{Float, One, ToPrimitive};
use std::fmt::Debug;

use crate::activation::Activation;
use crate::self_supervised::{DenseLayer, SimpleMLP};
use sklears_core::error::SklearsError;
use sklears_core::types::FloatBounds;

/// Configuration for numerical gradient checking.
#[derive(Debug, Clone)]
pub struct GradientCheckConfig<T: Float> {
    /// Step size used for the finite-difference perturbation.
    pub epsilon: T,
    /// Maximum allowed relative error between analytical and numerical gradients.
    pub relative_tolerance: T,
    /// Maximum allowed absolute error between analytical and numerical gradients.
    pub absolute_tolerance: T,
    /// Use centered differences (O(epsilon^2) error) instead of forward differences (O(epsilon)).
    pub use_centered_differences: bool,
    /// Optional cap on the number of parameters to check.
    pub max_params_to_check: Option<usize>,
    /// Optional seed for reproducible parameter sampling.
    pub random_seed: Option<u64>,
}

impl<T: Float> Default for GradientCheckConfig<T> {
    fn default() -> Self {
        Self {
            epsilon: T::from(1e-7).unwrap(),
            relative_tolerance: T::from(1e-5).unwrap(),
            absolute_tolerance: T::from(1e-8).unwrap(),
            use_centered_differences: true,
            max_params_to_check: Some(100),
            random_seed: Some(42),
        }
    }
}

/// Aggregate results of a gradient check run.
#[derive(Debug, Clone)]
pub struct GradientCheckResults<T: Float> {
    /// Whether every checked parameter passed.
    pub all_passed: bool,
    /// Number of parameters checked.
    pub num_checked: usize,
    /// Number of parameters that passed.
    pub num_passed: usize,
    /// Largest relative error observed.
    pub max_relative_error: T,
    /// Largest absolute error observed.
    pub max_absolute_error: T,
    /// Mean relative error over all checked parameters.
    pub avg_relative_error: T,
    /// Mean absolute error over all checked parameters.
    pub avg_absolute_error: T,
    /// Per-parameter comparison details.
    pub parameter_results: Vec<ParameterGradientResult<T>>,
}

/// Comparison result for a single parameter.
#[derive(Debug, Clone)]
pub struct ParameterGradientResult<T: Float> {
    /// Flat index of the parameter.
    pub param_index: usize,
    /// Gradient computed by backpropagation.
    pub analytical_gradient: T,
    /// Gradient estimated by finite differences.
    pub numerical_gradient: T,
    /// Relative error between the two gradients.
    pub relative_error: T,
    /// Absolute error between the two gradients.
    pub absolute_error: T,
    /// Whether this parameter passed both tolerance checks.
    pub passed: bool,
}

/// A differentiable loss function used to drive the gradient check.
pub trait LossFunction<T: FloatBounds + ScalarOperand> {
    /// Computes the scalar loss for a batch of predictions and targets.
    fn compute_loss(&self, predictions: &Array2<T>, targets: &Array2<T>)
        -> Result<T, SklearsError>;

    /// Computes the gradient of the loss with respect to the predictions.
    fn compute_gradient(
        &self,
        predictions: &Array2<T>,
        targets: &Array2<T>,
    ) -> Result<Array2<T>, SklearsError>;
}

/// Mean squared error loss: MSE = sum((p - t)^2) / n.
#[derive(Debug, Clone)]
pub struct MeanSquaredError<T: FloatBounds + ScalarOperand> {
    _phantom: std::marker::PhantomData<T>,
}

impl<T: FloatBounds + ScalarOperand> MeanSquaredError<T> {
    pub fn new() -> Self {
        Self {
            _phantom: std::marker::PhantomData,
        }
    }
}

impl<T: FloatBounds + ScalarOperand> LossFunction<T> for MeanSquaredError<T> {
    fn compute_loss(
        &self,
        predictions: &Array2<T>,
        targets: &Array2<T>,
    ) -> Result<T, SklearsError> {
        let diff = predictions - targets;
        let squared_diff = diff.mapv(|x| x * x);
        let mse = squared_diff.sum() / T::from(predictions.len()).unwrap();
        Ok(mse)
    }

    fn compute_gradient(
        &self,
        predictions: &Array2<T>,
        targets: &Array2<T>,
    ) -> Result<Array2<T>, SklearsError> {
        // d(MSE)/d(prediction) = 2 * (prediction - target) / n
        let diff = predictions - targets;
        let factor = T::from(2.0).unwrap() / T::from(predictions.len()).unwrap();
        Ok(diff * factor)
    }
}

/// Cross-entropy loss for probabilistic predictions: CE = -sum(t * ln(p)) / n_rows.
#[derive(Debug, Clone)]
pub struct CrossEntropyLoss<T: FloatBounds + ScalarOperand> {
    _phantom: std::marker::PhantomData<T>,
}

impl<T: FloatBounds + ScalarOperand> CrossEntropyLoss<T> {
    pub fn new() -> Self {
        Self {
            _phantom: std::marker::PhantomData,
        }
    }
}

impl<T: FloatBounds + ScalarOperand> LossFunction<T> for CrossEntropyLoss<T> {
    fn compute_loss(
        &self,
        predictions: &Array2<T>,
        targets: &Array2<T>,
    ) -> Result<T, SklearsError> {
        // Clip predictions away from 0 and 1 so the logarithm stays finite.
        let epsilon = T::from(1e-15).unwrap();
        let clipped_preds = predictions.mapv(|x| x.max(epsilon).min(T::one() - epsilon));

        let log_preds = clipped_preds.mapv(|x| x.ln());
        let loss = -(targets * log_preds).sum() / T::from(predictions.nrows()).unwrap();
        Ok(loss)
    }

    fn compute_gradient(
        &self,
        predictions: &Array2<T>,
        targets: &Array2<T>,
    ) -> Result<Array2<T>, SklearsError> {
        let epsilon = T::from(1e-15).unwrap();
        let clipped_preds = predictions.mapv(|x| x.max(epsilon).min(T::one() - epsilon));

        // d(CE)/d(prediction) = -(target / prediction) / n_rows
        let grad = -(targets / clipped_preds) / T::from(predictions.nrows()).unwrap();
        Ok(grad)
    }
}

/// Checks analytical gradients against finite-difference approximations.
#[derive(Debug)]
pub struct GradientChecker<T: FloatBounds + ScalarOperand + ToPrimitive> {
    config: GradientCheckConfig<T>,
}

impl<T: FloatBounds + ScalarOperand + ToPrimitive> GradientChecker<T> {
    /// Creates a new checker with the given configuration.
    pub fn new(config: GradientCheckConfig<T>) -> Self {
        Self { config }
    }

    /// Checks the gradients of a full network against numerical estimates.
    pub fn check_network_gradients(
        &self,
        network: &mut SimpleMLP<T>,
        inputs: &Array2<T>,
        targets: &Array2<T>,
        loss_fn: &dyn LossFunction<T>,
    ) -> Result<GradientCheckResults<T>, SklearsError> {
        // Forward pass to make sure the network state is initialized.
        let _predictions = network.forward(inputs)?;

        // Gradients from backpropagation.
        let analytical_grads =
            self.compute_analytical_gradients(network, inputs, targets, loss_fn)?;

        // Gradients from finite differences.
        let numerical_grads =
            self.compute_numerical_gradients(network, inputs, targets, loss_fn)?;

        self.compare_gradients(&analytical_grads, &numerical_grads)
    }

    /// Checks the gradients of a single dense layer.
    ///
    /// NOTE: this is currently a simplified placeholder that exercises the
    /// comparison machinery with synthetic gradient values rather than
    /// perturbing the layer's actual parameters.
    pub fn check_layer_gradients(
        &self,
        layer: &mut DenseLayer<T>,
        inputs: &Array2<T>,
        output_gradients: &Array2<T>,
    ) -> Result<GradientCheckResults<T>, SklearsError> {
        let mut parameter_results = Vec::new();
        let mut num_passed = 0;
        let mut max_rel_error = T::zero();
        let mut max_abs_error = T::zero();
        let mut sum_rel_error = T::zero();
        let mut sum_abs_error = T::zero();

        // Placeholder: check a fixed number of synthetic parameters. A real
        // implementation would iterate over the layer's actual weights.
        let num_to_check = 10;

        for i in 0..num_to_check {
            // Placeholder gradient values standing in for real analytical
            // and numerical gradients of this layer's parameters.
            let analytical_grad = T::from(0.1).unwrap();
            let numerical_grad = T::from(0.101).unwrap();

            let abs_error = (analytical_grad - numerical_grad).abs();
            let rel_error = if numerical_grad.abs() > T::zero() {
                abs_error / numerical_grad.abs()
            } else {
                abs_error
            };

            let passed = rel_error < self.config.relative_tolerance
                && abs_error < self.config.absolute_tolerance;

            if passed {
                num_passed += 1;
            }

            max_rel_error = max_rel_error.max(rel_error);
            max_abs_error = max_abs_error.max(abs_error);
            sum_rel_error = sum_rel_error + rel_error;
            sum_abs_error = sum_abs_error + abs_error;

            parameter_results.push(ParameterGradientResult {
                param_index: i,
                analytical_gradient: analytical_grad,
                numerical_gradient: numerical_grad,
                relative_error: rel_error,
                absolute_error: abs_error,
                passed,
            });
        }

        let avg_rel_error = sum_rel_error / T::from(num_to_check).unwrap();
        let avg_abs_error = sum_abs_error / T::from(num_to_check).unwrap();

        Ok(GradientCheckResults {
            all_passed: num_passed == num_to_check,
            num_checked: num_to_check,
            num_passed,
            max_relative_error: max_rel_error,
            max_absolute_error: max_abs_error,
            avg_relative_error: avg_rel_error,
            avg_absolute_error: avg_abs_error,
            parameter_results,
        })
    }

    /// Computes analytical gradients via backpropagation.
    ///
    /// NOTE: the backward pass through the network is not wired up yet; the
    /// loss gradient is computed, but the returned per-parameter gradients
    /// are synthetic placeholders.
    fn compute_analytical_gradients(
        &self,
        network: &mut SimpleMLP<T>,
        inputs: &Array2<T>,
        targets: &Array2<T>,
        loss_fn: &dyn LossFunction<T>,
    ) -> Result<Vec<Array1<T>>, SklearsError> {
        // Forward pass.
        let predictions = network.forward(inputs)?;

        // Gradient of the loss with respect to the predictions.
        let _loss_grad = loss_fn.compute_gradient(&predictions, targets)?;

        // Placeholder parameter gradients.
        let mut gradients = Vec::new();
        for i in 0..10 {
            let grad = Array1::from_vec(vec![T::from(i as f64 * 0.1).unwrap(); 10]);
            gradients.push(grad);
        }

        Ok(gradients)
    }

    /// Computes numerical gradients by perturbing each parameter in turn.
    fn compute_numerical_gradients(
        &self,
        network: &mut SimpleMLP<T>,
        inputs: &Array2<T>,
        targets: &Array2<T>,
        loss_fn: &dyn LossFunction<T>,
    ) -> Result<Vec<Array1<T>>, SklearsError> {
        let mut numerical_grads = Vec::new();

        // Iterate over a fixed number of parameter groups and parameters
        // (placeholder bounds until the network exposes its parameters).
        for param_group in 0..10 {
            let mut param_grads = Vec::new();

            for param_idx in 0..10 {
                let grad = if self.config.use_centered_differences {
                    self.compute_centered_difference(
                        network,
                        inputs,
                        targets,
                        loss_fn,
                        param_group,
                        param_idx,
                    )?
                } else {
                    self.compute_forward_difference(
                        network,
                        inputs,
                        targets,
                        loss_fn,
                        param_group,
                        param_idx,
                    )?
                };
                param_grads.push(grad);
            }

            numerical_grads.push(Array1::from_vec(param_grads));
        }

        Ok(numerical_grads)
    }

    /// Estimates a gradient with the centered difference
    /// (L(p + eps) - L(p - eps)) / (2 * eps), which has O(eps^2) truncation error.
    fn compute_centered_difference(
        &self,
        network: &mut SimpleMLP<T>,
        inputs: &Array2<T>,
        targets: &Array2<T>,
        loss_fn: &dyn LossFunction<T>,
        param_group: usize,
        param_idx: usize,
    ) -> Result<T, SklearsError> {
        // Placeholder parameter value until real parameter access is wired up.
        let original_param = T::from(0.5).unwrap();

        let loss_plus = self.compute_loss_with_perturbed_param(
            network,
            inputs,
            targets,
            loss_fn,
            param_group,
            param_idx,
            original_param + self.config.epsilon,
        )?;

        let loss_minus = self.compute_loss_with_perturbed_param(
            network,
            inputs,
            targets,
            loss_fn,
            param_group,
            param_idx,
            original_param - self.config.epsilon,
        )?;

        let grad = (loss_plus - loss_minus) / (T::from(2.0).unwrap() * self.config.epsilon);
        Ok(grad)
    }

    /// Estimates a gradient with the forward difference
    /// (L(p + eps) - L(p)) / eps, which has O(eps) truncation error.
    fn compute_forward_difference(
        &self,
        network: &mut SimpleMLP<T>,
        inputs: &Array2<T>,
        targets: &Array2<T>,
        loss_fn: &dyn LossFunction<T>,
        param_group: usize,
        param_idx: usize,
    ) -> Result<T, SklearsError> {
        // Placeholder parameter value until real parameter access is wired up.
        let original_param = T::from(0.5).unwrap();

        let original_loss = self.compute_loss_with_perturbed_param(
            network,
            inputs,
            targets,
            loss_fn,
            param_group,
            param_idx,
            original_param,
        )?;

        let perturbed_loss = self.compute_loss_with_perturbed_param(
            network,
            inputs,
            targets,
            loss_fn,
            param_group,
            param_idx,
            original_param + self.config.epsilon,
        )?;

        let grad = (perturbed_loss - original_loss) / self.config.epsilon;
        Ok(grad)
    }

    /// Evaluates the loss with one parameter set to `_param_value`.
    ///
    /// NOTE: the perturbation is not actually applied yet (the underscored
    /// arguments are unused), so this currently returns the unperturbed loss.
    fn compute_loss_with_perturbed_param(
        &self,
        network: &mut SimpleMLP<T>,
        inputs: &Array2<T>,
        targets: &Array2<T>,
        loss_fn: &dyn LossFunction<T>,
        _param_group: usize,
        _param_idx: usize,
        _param_value: T,
    ) -> Result<T, SklearsError> {
        let predictions = network.forward(inputs)?;
        loss_fn.compute_loss(&predictions, targets)
    }

    /// Compares analytical and numerical gradients and aggregates the results.
    fn compare_gradients(
        &self,
        analytical: &[Array1<T>],
        numerical: &[Array1<T>],
    ) -> Result<GradientCheckResults<T>, SklearsError> {
        let mut parameter_results = Vec::new();
        let mut num_passed = 0;
        let mut max_rel_error = T::zero();
        let mut max_abs_error = T::zero();
        let mut sum_rel_error = T::zero();
        let mut sum_abs_error = T::zero();
        let mut total_checked = 0;

        for (group_idx, (anal_group, num_group)) in
            analytical.iter().zip(numerical.iter()).enumerate()
        {
            for (param_idx, (&anal_grad, &num_grad)) in
                anal_group.iter().zip(num_group.iter()).enumerate()
            {
                let abs_error = (anal_grad - num_grad).abs();
                let rel_error = if num_grad.abs() > T::zero() {
                    abs_error / num_grad.abs()
                } else {
                    abs_error
                };

                let passed = rel_error < self.config.relative_tolerance
                    && abs_error < self.config.absolute_tolerance;

                if passed {
                    num_passed += 1;
                }

                max_rel_error = max_rel_error.max(rel_error);
                max_abs_error = max_abs_error.max(abs_error);
                sum_rel_error = sum_rel_error + rel_error;
                sum_abs_error = sum_abs_error + abs_error;
                total_checked += 1;

                parameter_results.push(ParameterGradientResult {
                    // Flat index: group index and within-group index packed together.
                    param_index: group_idx * 1000 + param_idx,
                    analytical_gradient: anal_grad,
                    numerical_gradient: num_grad,
                    relative_error: rel_error,
                    absolute_error: abs_error,
                    passed,
                });

                // Respect the configured cap on checked parameters.
                if let Some(max_params) = self.config.max_params_to_check {
                    if total_checked >= max_params {
                        break;
                    }
                }
            }

            if let Some(max_params) = self.config.max_params_to_check {
                if total_checked >= max_params {
                    break;
                }
            }
        }

        let avg_rel_error = if total_checked > 0 {
            sum_rel_error / T::from(total_checked).unwrap()
        } else {
            T::zero()
        };

        let avg_abs_error = if total_checked > 0 {
            sum_abs_error / T::from(total_checked).unwrap()
        } else {
            T::zero()
        };

        Ok(GradientCheckResults {
            all_passed: num_passed == total_checked,
            num_checked: total_checked,
            num_passed,
            max_relative_error: max_rel_error,
            max_absolute_error: max_abs_error,
            avg_relative_error: avg_rel_error,
            avg_absolute_error: avg_abs_error,
            parameter_results,
        })
    }
}

impl<T: FloatBounds + ScalarOperand + ToPrimitive> GradientChecker<T> {
    /// Returns true if the two gradients agree within both tolerances.
    pub fn gradients_are_equal(&self, analytical: T, numerical: T) -> bool {
        let abs_error = (analytical - numerical).abs();
        let rel_error = if numerical.abs() > T::zero() {
            abs_error / numerical.abs()
        } else {
            abs_error
        };

        rel_error < self.config.relative_tolerance && abs_error < self.config.absolute_tolerance
    }

    /// Computes the relative error |a - n| / |n|, falling back to the
    /// absolute error when the numerical gradient is zero.
    pub fn compute_relative_error(&self, analytical: T, numerical: T) -> T {
        let abs_error = (analytical - numerical).abs();
        if numerical.abs() > T::zero() {
            abs_error / numerical.abs()
        } else {
            abs_error
        }
    }

    /// Renders a human-readable summary of a gradient check run.
    pub fn generate_report(&self, results: &GradientCheckResults<T>) -> String {
        let mut report = String::new();

        report.push_str("=== Gradient Checking Report ===\n");
        report.push_str(&format!(
            "Overall Status: {}\n",
            if results.all_passed {
                "PASSED"
            } else {
                "FAILED"
            }
        ));
        report.push_str(&format!("Parameters Checked: {}\n", results.num_checked));
        report.push_str(&format!("Parameters Passed: {}\n", results.num_passed));
        report.push_str(&format!(
            "Pass Rate: {:.2}%\n",
            (results.num_passed as f64 / results.num_checked as f64) * 100.0
        ));
        report.push_str(&format!(
            "Max Relative Error: {:.2e}\n",
            results.max_relative_error.to_f64().unwrap_or(0.0)
        ));
        report.push_str(&format!(
            "Max Absolute Error: {:.2e}\n",
            results.max_absolute_error.to_f64().unwrap_or(0.0)
        ));
        report.push_str(&format!(
            "Avg Relative Error: {:.2e}\n",
            results.avg_relative_error.to_f64().unwrap_or(0.0)
        ));
        report.push_str(&format!(
            "Avg Absolute Error: {:.2e}\n",
            results.avg_absolute_error.to_f64().unwrap_or(0.0)
        ));

        // List the first few failing parameters, if any.
        let failed_params: Vec<_> = results
            .parameter_results
            .iter()
            .filter(|r| !r.passed)
            .collect();

        if !failed_params.is_empty() {
            report.push_str("\nFailed Parameters:\n");
            for param in failed_params.iter().take(10) {
                report.push_str(&format!(
                    "  Param {}: analytical={:.6e}, numerical={:.6e}, rel_err={:.2e}, abs_err={:.2e}\n",
                    param.param_index,
                    param.analytical_gradient.to_f64().unwrap_or(0.0),
                    param.numerical_gradient.to_f64().unwrap_or(0.0),
                    param.relative_error.to_f64().unwrap_or(0.0),
                    param.absolute_error.to_f64().unwrap_or(0.0)
                ));
            }

            if failed_params.len() > 10 {
                report.push_str(&format!(
                    "  ... and {} more failures\n",
                    failed_params.len() - 10
                ));
            }
        }

        report
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_gradient_check_config_default() {
        let config = GradientCheckConfig::<f32>::default();
        assert!(config.epsilon > 0.0);
        assert!(config.use_centered_differences);
        assert_eq!(config.max_params_to_check, Some(100));
    }

    #[test]
    fn test_mse_loss_function() {
        let mse = MeanSquaredError::<f32>::new();

        let predictions = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let targets = Array2::from_shape_vec((2, 2), vec![1.1, 1.9, 3.1, 3.9]).unwrap();

        let loss = mse.compute_loss(&predictions, &targets).unwrap();
        assert!(loss > 0.0);

        let gradient = mse.compute_gradient(&predictions, &targets).unwrap();
        assert_eq!(gradient.dim(), predictions.dim());
    }
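
    // A minimal end-to-end illustration of the checking idea on a loss whose
    // gradient is actually implemented above: perturb one prediction entry by
    // +/- eps, re-evaluate the MSE, and compare the centered-difference
    // quotient with `compute_gradient`. This is a self-contained sketch of
    // the technique; it does not touch the (still stubbed) network plumbing.
    #[test]
    fn test_mse_gradient_matches_centered_difference() {
        let mse = MeanSquaredError::<f64>::new();
        let predictions = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let targets = Array2::from_shape_vec((2, 2), vec![1.1, 1.9, 3.1, 3.9]).unwrap();
        let analytical = mse.compute_gradient(&predictions, &targets).unwrap();

        let eps = 1e-6;
        for row in 0..2 {
            for col in 0..2 {
                let mut plus = predictions.clone();
                plus[[row, col]] += eps;
                let mut minus = predictions.clone();
                minus[[row, col]] -= eps;
                let numerical = (mse.compute_loss(&plus, &targets).unwrap()
                    - mse.compute_loss(&minus, &targets).unwrap())
                    / (2.0 * eps);
                assert!((numerical - analytical[[row, col]]).abs() < 1e-8);
            }
        }
    }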

    #[test]
    fn test_cross_entropy_loss_function() {
        let ce = CrossEntropyLoss::<f32>::new();

        let predictions = Array2::from_shape_vec((2, 2), vec![0.8, 0.2, 0.3, 0.7]).unwrap();
        let targets = Array2::from_shape_vec((2, 2), vec![1.0, 0.0, 0.0, 1.0]).unwrap();

        let loss = ce.compute_loss(&predictions, &targets).unwrap();
        assert!(loss > 0.0);

        let gradient = ce.compute_gradient(&predictions, &targets).unwrap();
        assert_eq!(gradient.dim(), predictions.dim());
    }
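
    // The same centered-difference cross-check for the cross-entropy loss:
    // away from the clipping boundaries, the analytical gradient
    // -(t / p) / n_rows should agree with the numerical quotient. A sketch
    // mirroring the MSE check above.
    #[test]
    fn test_cross_entropy_gradient_matches_centered_difference() {
        let ce = CrossEntropyLoss::<f64>::new();
        let predictions = Array2::from_shape_vec((2, 2), vec![0.8, 0.2, 0.3, 0.7]).unwrap();
        let targets = Array2::from_shape_vec((2, 2), vec![1.0, 0.0, 0.0, 1.0]).unwrap();
        let analytical = ce.compute_gradient(&predictions, &targets).unwrap();

        let eps = 1e-6;
        for row in 0..2 {
            for col in 0..2 {
                let mut plus = predictions.clone();
                plus[[row, col]] += eps;
                let mut minus = predictions.clone();
                minus[[row, col]] -= eps;
                let numerical = (ce.compute_loss(&plus, &targets).unwrap()
                    - ce.compute_loss(&minus, &targets).unwrap())
                    / (2.0 * eps);
                assert!((numerical - analytical[[row, col]]).abs() < 1e-6);
            }
        }
    }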

    #[test]
    fn test_gradient_checker_creation() {
        let config = GradientCheckConfig::<f32>::default();
        let checker = GradientChecker::new(config);
        assert!(checker.config.epsilon > 0.0);
    }

    #[test]
    fn test_gradients_are_equal() {
        let config = GradientCheckConfig {
            epsilon: 1e-7,
            relative_tolerance: 1e-5,
            absolute_tolerance: 1e-6,
            use_centered_differences: true,
            max_params_to_check: Some(100),
            random_seed: Some(42),
        };
        let checker = GradientChecker::new(config);

        // Identical gradients pass.
        assert!(checker.gradients_are_equal(1.0, 1.0));

        // Gradients within tolerance pass.
        assert!(checker.gradients_are_equal(1.0, 1.000001));

        // Gradients outside tolerance fail.
        assert!(!checker.gradients_are_equal(1.0, 1.1));
    }
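
    // Illustrates why `use_centered_differences` defaults to `true`: for a
    // smooth function the centered quotient has O(eps^2) truncation error,
    // versus O(eps) for the forward quotient. Uses f(x) = e^x at x = 1. A
    // self-contained sketch, independent of the checker's stubbed internals.
    #[test]
    fn test_centered_vs_forward_difference_accuracy() {
        let f = |x: f64| x.exp();
        let x = 1.0_f64;
        let eps = 1e-4;
        let exact = x.exp(); // d/dx e^x = e^x

        let forward = (f(x + eps) - f(x)) / eps;
        let centered = (f(x + eps) - f(x - eps)) / (2.0 * eps);

        // Forward error ~ eps/2 * f''(x) ~ 1.4e-4; centered ~ eps^2/6 * f'''(x) ~ 4.5e-9.
        assert!((centered - exact).abs() < (forward - exact).abs());
        assert!((centered - exact).abs() < 1e-7);
    }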

    #[test]
    fn test_compute_relative_error() {
        let config = GradientCheckConfig::<f32>::default();
        let checker = GradientChecker::new(config);

        let rel_error = checker.compute_relative_error(1.0, 1.1);
        assert_abs_diff_eq!(rel_error, 0.090909, epsilon = 1e-5);

        // A zero numerical gradient falls back to the absolute error.
        let rel_error_zero = checker.compute_relative_error(0.1, 0.0);
        assert_abs_diff_eq!(rel_error_zero, 0.1, epsilon = 1e-6);
    }

    #[test]
    fn test_parameter_gradient_result() {
        let result = ParameterGradientResult {
            param_index: 0,
            analytical_gradient: 1.0,
            numerical_gradient: 1.01,
            relative_error: 0.0099,
            absolute_error: 0.01,
            passed: true,
        };

        assert_eq!(result.param_index, 0);
        assert!(result.passed);
        assert_eq!(result.analytical_gradient, 1.0);
    }

    #[test]
    fn test_gradient_check_results() {
        let param_results = vec![
            ParameterGradientResult {
                param_index: 0,
                analytical_gradient: 1.0,
                numerical_gradient: 1.01,
                relative_error: 0.0099,
                absolute_error: 0.01,
                passed: true,
            },
            ParameterGradientResult {
                param_index: 1,
                analytical_gradient: 2.0,
                numerical_gradient: 2.2,
                relative_error: 0.091,
                absolute_error: 0.2,
                passed: false,
            },
        ];

        let results = GradientCheckResults {
            all_passed: false,
            num_checked: 2,
            num_passed: 1,
            max_relative_error: 0.091,
            max_absolute_error: 0.2,
            avg_relative_error: 0.05045,
            avg_absolute_error: 0.105,
            parameter_results: param_results,
        };

        assert!(!results.all_passed);
        assert_eq!(results.num_checked, 2);
        assert_eq!(results.num_passed, 1);
    }

    #[test]
    fn test_generate_report() {
        let config = GradientCheckConfig::<f32>::default();
        let checker = GradientChecker::new(config);

        let results = GradientCheckResults {
            all_passed: true,
            num_checked: 10,
            num_passed: 10,
            max_relative_error: 1e-6,
            max_absolute_error: 1e-8,
            avg_relative_error: 1e-7,
            avg_absolute_error: 1e-9,
            parameter_results: Vec::new(),
        };

        let report = checker.generate_report(&results);
        assert!(report.contains("PASSED"));
        assert!(report.contains("Parameters Checked: 10"));
        assert!(report.contains("Pass Rate: 100.00%"));
    }
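
    // Complements test_generate_report by exercising the failure path: the
    // report should flag the run as FAILED and list the offending parameter.
    // (A small hand-built result set, mirroring test_gradient_check_results.)
    #[test]
    fn test_generate_report_lists_failures() {
        let config = GradientCheckConfig::<f32>::default();
        let checker = GradientChecker::new(config);

        let results = GradientCheckResults {
            all_passed: false,
            num_checked: 2,
            num_passed: 1,
            max_relative_error: 0.091,
            max_absolute_error: 0.2,
            avg_relative_error: 0.05045,
            avg_absolute_error: 0.105,
            parameter_results: vec![ParameterGradientResult {
                param_index: 1,
                analytical_gradient: 2.0,
                numerical_gradient: 2.2,
                relative_error: 0.091,
                absolute_error: 0.2,
                passed: false,
            }],
        };

        let report = checker.generate_report(&results);
        assert!(report.contains("FAILED"));
        assert!(report.contains("Failed Parameters:"));
        assert!(report.contains("Param 1:"));
    }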

    #[test]
    fn test_layer_gradient_checking() {
        let config = GradientCheckConfig::<f32>::default();
        let checker = GradientChecker::new(config);

        let mut layer = DenseLayer::<f32>::new(5, 3, Some(Activation::Relu));
        let inputs = Array2::from_shape_vec((2, 5), vec![1.0; 10]).unwrap();
        let output_grads = Array2::from_shape_vec((2, 3), vec![0.1; 6]).unwrap();

        let results = checker
            .check_layer_gradients(&mut layer, &inputs, &output_grads)
            .unwrap();
        assert!(results.num_checked > 0);
    }
}