sklears_neural/gradient_checking.rs

//! Numerical gradient checking utilities for neural networks
//!
//! This module provides tools for validating analytical gradients by comparing them
//! with numerically computed gradients using finite differences.

use scirs2_core::ndarray::{Array1, Array2, Axis, ScalarOperand};
use scirs2_core::numeric::{Float, One, ToPrimitive};
use std::fmt::Debug;

use crate::activation::Activation;
use crate::self_supervised::{DenseLayer, SimpleMLP};
use sklears_core::error::SklearsError;
use sklears_core::types::FloatBounds;

/// Gradient checking configuration
#[derive(Debug, Clone)]
pub struct GradientCheckConfig<T: Float> {
    /// Finite difference step size (epsilon)
    pub epsilon: T,
    /// Relative tolerance for gradient comparison
    pub relative_tolerance: T,
    /// Absolute tolerance for gradient comparison
    pub absolute_tolerance: T,
    /// Whether to use centered differences (more accurate but 2x slower)
    pub use_centered_differences: bool,
    /// Maximum number of parameters to check (for efficiency)
    pub max_params_to_check: Option<usize>,
    /// Random seed for parameter sampling
    pub random_seed: Option<u64>,
}

impl<T: Float> Default for GradientCheckConfig<T> {
    fn default() -> Self {
        Self {
            epsilon: T::from(1e-7).unwrap(),
            relative_tolerance: T::from(1e-5).unwrap(),
            absolute_tolerance: T::from(1e-8).unwrap(),
            use_centered_differences: true,
            max_params_to_check: Some(100),
            random_seed: Some(42),
        }
    }
}
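
// A minimal, self-contained sketch (not part of the checker API in this module) of the
// finite-difference rules the module is built on, for a scalar function `f`:
//   centered difference: (f(x + eps) - f(x - eps)) / (2 * eps), truncation error O(eps^2)
//   forward  difference: (f(x + eps) - f(x)) / eps,             truncation error O(eps)
// `centered_difference_sketch` is illustrative only and is not used by `GradientChecker`.
#[allow(dead_code)]
fn centered_difference_sketch(f: impl Fn(f64) -> f64, x: f64, eps: f64) -> f64 {
    (f(x + eps) - f(x - eps)) / (2.0 * eps)
}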

/// Results of gradient checking
#[derive(Debug, Clone)]
pub struct GradientCheckResults<T: Float> {
    /// Whether all gradients passed the check
    pub all_passed: bool,
    /// Number of parameters checked
    pub num_checked: usize,
    /// Number of parameters that passed
    pub num_passed: usize,
    /// Maximum relative error found
    pub max_relative_error: T,
    /// Maximum absolute error found
    pub max_absolute_error: T,
    /// Average relative error
    pub avg_relative_error: T,
    /// Average absolute error
    pub avg_absolute_error: T,
    /// Detailed results for each parameter
    pub parameter_results: Vec<ParameterGradientResult<T>>,
}

/// Results for a single parameter gradient check
#[derive(Debug, Clone)]
pub struct ParameterGradientResult<T: Float> {
    /// Parameter index/identifier
    pub param_index: usize,
    /// Analytical gradient value
    pub analytical_gradient: T,
    /// Numerical gradient value
    pub numerical_gradient: T,
    /// Relative error
    pub relative_error: T,
    /// Absolute error
    pub absolute_error: T,
    /// Whether this parameter passed the check
    pub passed: bool,
}

/// Loss function trait for gradient checking
pub trait LossFunction<T: FloatBounds + ScalarOperand> {
    /// Compute loss given predictions and targets
    fn compute_loss(&self, predictions: &Array2<T>, targets: &Array2<T>)
        -> Result<T, SklearsError>;

    /// Compute loss gradient with respect to predictions
    fn compute_gradient(
        &self,
        predictions: &Array2<T>,
        targets: &Array2<T>,
    ) -> Result<Array2<T>, SklearsError>;
}

/// Mean Squared Error loss function
#[derive(Debug, Clone)]
pub struct MeanSquaredError<T: FloatBounds + ScalarOperand> {
    _phantom: std::marker::PhantomData<T>,
}

impl<T: FloatBounds + ScalarOperand> MeanSquaredError<T> {
    pub fn new() -> Self {
        Self {
            _phantom: std::marker::PhantomData,
        }
    }
}

impl<T: FloatBounds + ScalarOperand> LossFunction<T> for MeanSquaredError<T> {
    fn compute_loss(
        &self,
        predictions: &Array2<T>,
        targets: &Array2<T>,
    ) -> Result<T, SklearsError> {
        let diff = predictions - targets;
        let squared_diff = diff.mapv(|x| x * x);
        let mse = squared_diff.sum() / T::from(predictions.len()).unwrap();
        Ok(mse)
    }

    fn compute_gradient(
        &self,
        predictions: &Array2<T>,
        targets: &Array2<T>,
    ) -> Result<Array2<T>, SklearsError> {
        let diff = predictions - targets;
        let factor = T::from(2.0).unwrap() / T::from(predictions.len()).unwrap();
        Ok(diff * factor)
    }
}

/// Cross-entropy loss function
#[derive(Debug, Clone)]
pub struct CrossEntropyLoss<T: FloatBounds + ScalarOperand> {
    _phantom: std::marker::PhantomData<T>,
}

impl<T: FloatBounds + ScalarOperand> CrossEntropyLoss<T> {
    pub fn new() -> Self {
        Self {
            _phantom: std::marker::PhantomData,
        }
    }
}

impl<T: FloatBounds + ScalarOperand> LossFunction<T> for CrossEntropyLoss<T> {
    fn compute_loss(
        &self,
        predictions: &Array2<T>,
        targets: &Array2<T>,
    ) -> Result<T, SklearsError> {
        let epsilon = T::from(1e-15).unwrap();
        let clipped_preds = predictions.mapv(|x| x.max(epsilon).min(T::one() - epsilon));

        let log_preds = clipped_preds.mapv(|x| x.ln());
        let loss = -(targets * log_preds).sum() / T::from(predictions.nrows()).unwrap();
        Ok(loss)
    }

    fn compute_gradient(
        &self,
        predictions: &Array2<T>,
        targets: &Array2<T>,
    ) -> Result<Array2<T>, SklearsError> {
        let epsilon = T::from(1e-15).unwrap();
        let clipped_preds = predictions.mapv(|x| x.max(epsilon).min(T::one() - epsilon));

        let grad = -(targets / clipped_preds) / T::from(predictions.nrows()).unwrap();
        Ok(grad)
    }
}

/// Gradient checker for neural networks
#[derive(Debug)]
pub struct GradientChecker<T: FloatBounds + ScalarOperand + ToPrimitive> {
    config: GradientCheckConfig<T>,
}

impl<T: FloatBounds + ScalarOperand + ToPrimitive> GradientChecker<T> {
    /// Create a new gradient checker
    pub fn new(config: GradientCheckConfig<T>) -> Self {
        Self { config }
    }

    /// Check gradients for a neural network
    pub fn check_network_gradients(
        &self,
        network: &mut SimpleMLP<T>,
        inputs: &Array2<T>,
        targets: &Array2<T>,
        loss_fn: &dyn LossFunction<T>,
    ) -> Result<GradientCheckResults<T>, SklearsError> {
        // Run a forward pass so shape or configuration errors surface early; the helper
        // methods below perform their own forward passes, so the value itself is unused.
        let _predictions = network.forward(inputs)?;

        // Compute analytical gradients using backpropagation
        let analytical_grads =
            self.compute_analytical_gradients(network, inputs, targets, loss_fn)?;

        // Compute numerical gradients using finite differences
        let numerical_grads =
            self.compute_numerical_gradients(network, inputs, targets, loss_fn)?;

        // Compare gradients and generate results
        self.compare_gradients(&analytical_grads, &numerical_grads)
    }

    /// Check gradients for a single layer
    ///
    /// This is a simplified placeholder: it does not yet inspect the layer, inputs,
    /// or output gradients, and instead reports fixed example values. A full
    /// implementation would perturb each weight and bias of the layer.
    pub fn check_layer_gradients(
        &self,
        _layer: &mut DenseLayer<T>,
        _inputs: &Array2<T>,
        _output_gradients: &Array2<T>,
    ) -> Result<GradientCheckResults<T>, SklearsError> {
        let mut parameter_results = Vec::new();
        let mut num_passed = 0;
        let mut max_rel_error = T::zero();
        let mut max_abs_error = T::zero();
        let mut sum_rel_error = T::zero();
        let mut sum_abs_error = T::zero();

        // For demonstration, check only a handful of placeholder parameters;
        // in practice, all weights and biases would be checked.
        let num_to_check = 10;

        for i in 0..num_to_check {
            let analytical_grad = T::from(0.1).unwrap(); // Placeholder
            let numerical_grad = T::from(0.101).unwrap(); // Placeholder

            let abs_error = (analytical_grad - numerical_grad).abs();
            let rel_error = if numerical_grad.abs() > T::zero() {
                abs_error / numerical_grad.abs()
            } else {
                abs_error
            };

            let passed = rel_error < self.config.relative_tolerance
                && abs_error < self.config.absolute_tolerance;

            if passed {
                num_passed += 1;
            }

            max_rel_error = max_rel_error.max(rel_error);
            max_abs_error = max_abs_error.max(abs_error);
            sum_rel_error = sum_rel_error + rel_error;
            sum_abs_error = sum_abs_error + abs_error;

            parameter_results.push(ParameterGradientResult {
                param_index: i,
                analytical_gradient: analytical_grad,
                numerical_gradient: numerical_grad,
                relative_error: rel_error,
                absolute_error: abs_error,
                passed,
            });
        }

        let avg_rel_error = sum_rel_error / T::from(num_to_check).unwrap();
        let avg_abs_error = sum_abs_error / T::from(num_to_check).unwrap();

        Ok(GradientCheckResults {
            all_passed: num_passed == num_to_check,
            num_checked: num_to_check,
            num_passed,
            max_relative_error: max_rel_error,
            max_absolute_error: max_abs_error,
            avg_relative_error: avg_rel_error,
            avg_absolute_error: avg_abs_error,
            parameter_results,
        })
    }

    /// Compute analytical gradients using backpropagation
    fn compute_analytical_gradients(
        &self,
        network: &mut SimpleMLP<T>,
        inputs: &Array2<T>,
        targets: &Array2<T>,
        loss_fn: &dyn LossFunction<T>,
    ) -> Result<Vec<Array1<T>>, SklearsError> {
        // Forward pass
        let predictions = network.forward(inputs)?;

        // Compute the loss gradient; the simplified backward pass below does not
        // propagate it yet, so the binding is intentionally unused.
        let _loss_grad = loss_fn.compute_gradient(&predictions, targets)?;

        // Backward pass through the network.
        // This is simplified - in practice you'd implement full backpropagation
        // and collect the per-parameter gradients it produces.
        let mut gradients = Vec::new();

        // For demonstration, create some dummy gradients
        for i in 0..10 {
            let grad = Array1::from_vec(vec![T::from(i as f64 * 0.1).unwrap(); 10]);
            gradients.push(grad);
        }

        Ok(gradients)
    }

    /// Compute numerical gradients using finite differences
    fn compute_numerical_gradients(
        &self,
        network: &mut SimpleMLP<T>,
        inputs: &Array2<T>,
        targets: &Array2<T>,
        loss_fn: &dyn LossFunction<T>,
    ) -> Result<Vec<Array1<T>>, SklearsError> {
        let mut numerical_grads = Vec::new();

        // For each parameter in the network (simplified to a fixed 10x10 placeholder grid)
        for param_group in 0..10 {
            let mut param_grads = Vec::new();

            for param_idx in 0..10 {
                let grad = if self.config.use_centered_differences {
                    self.compute_centered_difference(
                        network,
                        inputs,
                        targets,
                        loss_fn,
                        param_group,
                        param_idx,
                    )?
                } else {
                    self.compute_forward_difference(
                        network,
                        inputs,
                        targets,
                        loss_fn,
                        param_group,
                        param_idx,
                    )?
                };
                param_grads.push(grad);
            }

            numerical_grads.push(Array1::from_vec(param_grads));
        }

        Ok(numerical_grads)
    }

    /// Compute the centered finite difference for a parameter:
    /// (L(θ + ε) − L(θ − ε)) / (2ε), with truncation error O(ε²)
    fn compute_centered_difference(
        &self,
        network: &mut SimpleMLP<T>,
        inputs: &Array2<T>,
        targets: &Array2<T>,
        loss_fn: &dyn LossFunction<T>,
        param_group: usize,
        param_idx: usize,
    ) -> Result<T, SklearsError> {
        // Get current parameter value
        let original_param = T::from(0.5).unwrap(); // Placeholder

        // Compute loss with parameter + epsilon
        // (This is simplified - in practice you'd modify the actual network parameters)
        let loss_plus = self.compute_loss_with_perturbed_param(
            network,
            inputs,
            targets,
            loss_fn,
            param_group,
            param_idx,
            original_param + self.config.epsilon,
        )?;

        // Compute loss with parameter - epsilon
        let loss_minus = self.compute_loss_with_perturbed_param(
            network,
            inputs,
            targets,
            loss_fn,
            param_group,
            param_idx,
            original_param - self.config.epsilon,
        )?;

        // Centered difference: (loss_plus - loss_minus) / (2 * epsilon)
        let grad = (loss_plus - loss_minus) / (T::from(2.0).unwrap() * self.config.epsilon);
        Ok(grad)
    }

    /// Compute the forward finite difference for a parameter:
    /// (L(θ + ε) − L(θ)) / ε, with truncation error O(ε)
    fn compute_forward_difference(
        &self,
        network: &mut SimpleMLP<T>,
        inputs: &Array2<T>,
        targets: &Array2<T>,
        loss_fn: &dyn LossFunction<T>,
        param_group: usize,
        param_idx: usize,
    ) -> Result<T, SklearsError> {
        // Get current parameter value
        let original_param = T::from(0.5).unwrap(); // Placeholder

        // Compute original loss
        let original_loss = self.compute_loss_with_perturbed_param(
            network,
            inputs,
            targets,
            loss_fn,
            param_group,
            param_idx,
            original_param,
        )?;

        // Compute loss with parameter + epsilon
        let perturbed_loss = self.compute_loss_with_perturbed_param(
            network,
            inputs,
            targets,
            loss_fn,
            param_group,
            param_idx,
            original_param + self.config.epsilon,
        )?;

        // Forward difference: (perturbed_loss - original_loss) / epsilon
        let grad = (perturbed_loss - original_loss) / self.config.epsilon;
        Ok(grad)
    }

    /// Compute loss with a perturbed parameter (simplified)
    fn compute_loss_with_perturbed_param(
        &self,
        network: &mut SimpleMLP<T>,
        inputs: &Array2<T>,
        targets: &Array2<T>,
        loss_fn: &dyn LossFunction<T>,
        _param_group: usize,
        _param_idx: usize,
        _param_value: T,
    ) -> Result<T, SklearsError> {
        // This is simplified - in practice you'd:
        // 1. Save the original parameter value
        // 2. Set the parameter to the new value
        // 3. Run forward pass
        // 4. Compute loss
        // 5. Restore original parameter value

        let predictions = network.forward(inputs)?;
        loss_fn.compute_loss(&predictions, targets)
    }

    /// Compare analytical and numerical gradients
    fn compare_gradients(
        &self,
        analytical: &[Array1<T>],
        numerical: &[Array1<T>],
    ) -> Result<GradientCheckResults<T>, SklearsError> {
        let mut parameter_results = Vec::new();
        let mut num_passed = 0;
        let mut max_rel_error = T::zero();
        let mut max_abs_error = T::zero();
        let mut sum_rel_error = T::zero();
        let mut sum_abs_error = T::zero();
        let mut total_checked = 0;

        for (group_idx, (anal_group, num_group)) in
            analytical.iter().zip(numerical.iter()).enumerate()
        {
            for (param_idx, (&anal_grad, &num_grad)) in
                anal_group.iter().zip(num_group.iter()).enumerate()
            {
                let abs_error = (anal_grad - num_grad).abs();
                let rel_error = if num_grad.abs() > T::zero() {
                    abs_error / num_grad.abs()
                } else {
                    abs_error
                };

                let passed = rel_error < self.config.relative_tolerance
                    && abs_error < self.config.absolute_tolerance;

                if passed {
                    num_passed += 1;
                }

                max_rel_error = max_rel_error.max(rel_error);
                max_abs_error = max_abs_error.max(abs_error);
                sum_rel_error = sum_rel_error + rel_error;
                sum_abs_error = sum_abs_error + abs_error;
                total_checked += 1;

                parameter_results.push(ParameterGradientResult {
                    param_index: group_idx * 1000 + param_idx, // Simple encoding
                    analytical_gradient: anal_grad,
                    numerical_gradient: num_grad,
                    relative_error: rel_error,
                    absolute_error: abs_error,
                    passed,
                });

                // Limit number of parameters checked for efficiency
                if let Some(max_params) = self.config.max_params_to_check {
                    if total_checked >= max_params {
                        break;
                    }
                }
            }

            if let Some(max_params) = self.config.max_params_to_check {
                if total_checked >= max_params {
                    break;
                }
            }
        }

        let avg_rel_error = if total_checked > 0 {
            sum_rel_error / T::from(total_checked).unwrap()
        } else {
            T::zero()
        };

        let avg_abs_error = if total_checked > 0 {
            sum_abs_error / T::from(total_checked).unwrap()
        } else {
            T::zero()
        };

        Ok(GradientCheckResults {
            all_passed: num_passed == total_checked,
            num_checked: total_checked,
            num_passed,
            max_relative_error: max_rel_error,
            max_absolute_error: max_abs_error,
            avg_relative_error: avg_rel_error,
            avg_absolute_error: avg_abs_error,
            parameter_results,
        })
    }
}
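
// A minimal sketch of the perturb/evaluate/restore procedure outlined in
// `compute_loss_with_perturbed_param`, written against a plain `&mut [f64]`
// parameter buffer and a loss closure rather than `SimpleMLP`'s internal storage
// (which the simplified checker above does not modify). Illustrative only.
#[allow(dead_code)]
fn loss_with_perturbed_param_sketch(
    params: &mut [f64],
    idx: usize,
    new_value: f64,
    loss: impl Fn(&[f64]) -> f64,
) -> f64 {
    let original = params[idx]; // 1. save the original parameter value
    params[idx] = new_value; // 2. set the parameter to the new value
    let value = loss(params); // 3./4. run the forward pass and compute the loss
    params[idx] = original; // 5. restore the original parameter value
    value
}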

/// Utility functions for gradient checking
impl<T: FloatBounds + ScalarOperand + ToPrimitive> GradientChecker<T> {
    /// Check if gradients are approximately equal
    pub fn gradients_are_equal(&self, analytical: T, numerical: T) -> bool {
        let abs_error = (analytical - numerical).abs();
        let rel_error = if numerical.abs() > T::zero() {
            abs_error / numerical.abs()
        } else {
            abs_error
        };

        rel_error < self.config.relative_tolerance && abs_error < self.config.absolute_tolerance
    }

    /// Compute relative error between two gradients
    pub fn compute_relative_error(&self, analytical: T, numerical: T) -> T {
        let abs_error = (analytical - numerical).abs();
        if numerical.abs() > T::zero() {
            abs_error / numerical.abs()
        } else {
            abs_error
        }
    }

    /// Generate a summary report of gradient checking results
    pub fn generate_report(&self, results: &GradientCheckResults<T>) -> String {
        let mut report = String::new();

        report.push_str("=== Gradient Checking Report ===\n");
        report.push_str(&format!(
            "Overall Status: {}\n",
            if results.all_passed {
                "PASSED"
            } else {
                "FAILED"
            }
        ));
        report.push_str(&format!("Parameters Checked: {}\n", results.num_checked));
        report.push_str(&format!("Parameters Passed: {}\n", results.num_passed));
        let pass_rate = if results.num_checked > 0 {
            (results.num_passed as f64 / results.num_checked as f64) * 100.0
        } else {
            // Avoid a NaN pass rate when nothing was checked
            0.0
        };
        report.push_str(&format!("Pass Rate: {:.2}%\n", pass_rate));
        report.push_str(&format!(
            "Max Relative Error: {:.2e}\n",
            results.max_relative_error.to_f64().unwrap_or(0.0)
        ));
        report.push_str(&format!(
            "Max Absolute Error: {:.2e}\n",
            results.max_absolute_error.to_f64().unwrap_or(0.0)
        ));
        report.push_str(&format!(
            "Avg Relative Error: {:.2e}\n",
            results.avg_relative_error.to_f64().unwrap_or(0.0)
        ));
        report.push_str(&format!(
            "Avg Absolute Error: {:.2e}\n",
            results.avg_absolute_error.to_f64().unwrap_or(0.0)
        ));

        // Add details for the failed parameters (first 10 only)
        let failed_params: Vec<_> = results
            .parameter_results
            .iter()
            .filter(|r| !r.passed)
            .collect();

        if !failed_params.is_empty() {
            report.push_str("\nFailed Parameters:\n");
            for param in failed_params.iter().take(10) {
                report.push_str(&format!(
                    "  Param {}: analytical={:.6e}, numerical={:.6e}, rel_err={:.2e}, abs_err={:.2e}\n",
                    param.param_index,
                    param.analytical_gradient.to_f64().unwrap_or(0.0),
                    param.numerical_gradient.to_f64().unwrap_or(0.0),
                    param.relative_error.to_f64().unwrap_or(0.0),
                    param.absolute_error.to_f64().unwrap_or(0.0)
                ));
            }

            if failed_params.len() > 10 {
                report.push_str(&format!(
                    "  ... and {} more failures\n",
                    failed_params.len() - 10
                ));
            }
        }

        report
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_gradient_check_config_default() {
        let config = GradientCheckConfig::<f32>::default();
        assert!(config.epsilon > 0.0);
        assert!(config.use_centered_differences);
        assert_eq!(config.max_params_to_check, Some(100));
    }

    #[test]
    fn test_mse_loss_function() {
        let mse = MeanSquaredError::<f32>::new();

        let predictions = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let targets = Array2::from_shape_vec((2, 2), vec![1.1, 1.9, 3.1, 3.9]).unwrap();

        let loss = mse.compute_loss(&predictions, &targets).unwrap();
        assert!(loss > 0.0);

        let gradient = mse.compute_gradient(&predictions, &targets).unwrap();
        assert_eq!(gradient.dim(), predictions.dim());
    }
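
    // Added example: numerically verify the analytical MSE gradient against a
    // centered difference on `compute_loss`, the same check this module automates
    // for full networks. A minimal sketch, assuming `FloatBounds` is implemented
    // for `f64` (as it is for `f32` in the tests above).
    #[test]
    fn test_mse_gradient_matches_centered_difference() {
        let mse = MeanSquaredError::<f64>::new();
        let predictions = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let targets = Array2::from_shape_vec((2, 2), vec![1.1, 1.9, 3.1, 3.9]).unwrap();

        let analytical = mse.compute_gradient(&predictions, &targets).unwrap();

        let eps = 1e-6;
        for idx in 0..predictions.len() {
            let (row, col) = (idx / 2, idx % 2);
            let mut plus = predictions.clone();
            plus[(row, col)] += eps;
            let mut minus = predictions.clone();
            minus[(row, col)] -= eps;

            let loss_plus = mse.compute_loss(&plus, &targets).unwrap();
            let loss_minus = mse.compute_loss(&minus, &targets).unwrap();
            let numerical = (loss_plus - loss_minus) / (2.0 * eps);

            assert_abs_diff_eq!(analytical[(row, col)], numerical, epsilon = 1e-6);
        }
    }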

    #[test]
    fn test_cross_entropy_loss_function() {
        let ce = CrossEntropyLoss::<f32>::new();

        let predictions = Array2::from_shape_vec((2, 2), vec![0.8, 0.2, 0.3, 0.7]).unwrap();
        let targets = Array2::from_shape_vec((2, 2), vec![1.0, 0.0, 0.0, 1.0]).unwrap();

        let loss = ce.compute_loss(&predictions, &targets).unwrap();
        assert!(loss > 0.0);

        let gradient = ce.compute_gradient(&predictions, &targets).unwrap();
        assert_eq!(gradient.dim(), predictions.dim());
    }
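
    // Added example: the same centered-difference check for the cross-entropy
    // gradient. Predictions stay well inside (0, 1) so the clipping in
    // `compute_loss` never triggers. A minimal sketch, assuming `FloatBounds`
    // is implemented for `f64`.
    #[test]
    fn test_cross_entropy_gradient_matches_centered_difference() {
        let ce = CrossEntropyLoss::<f64>::new();
        let predictions = Array2::from_shape_vec((2, 2), vec![0.8, 0.2, 0.3, 0.7]).unwrap();
        let targets = Array2::from_shape_vec((2, 2), vec![1.0, 0.0, 0.0, 1.0]).unwrap();

        let analytical = ce.compute_gradient(&predictions, &targets).unwrap();

        let eps = 1e-6;
        for idx in 0..predictions.len() {
            let (row, col) = (idx / 2, idx % 2);
            let mut plus = predictions.clone();
            plus[(row, col)] += eps;
            let mut minus = predictions.clone();
            minus[(row, col)] -= eps;

            let loss_plus = ce.compute_loss(&plus, &targets).unwrap();
            let loss_minus = ce.compute_loss(&minus, &targets).unwrap();
            let numerical = (loss_plus - loss_minus) / (2.0 * eps);

            assert_abs_diff_eq!(analytical[(row, col)], numerical, epsilon = 1e-6);
        }
    }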

    #[test]
    fn test_gradient_checker_creation() {
        let config = GradientCheckConfig::<f32>::default();
        let checker = GradientChecker::new(config);
        assert!(checker.config.epsilon > 0.0);
    }

    #[test]
    fn test_gradients_are_equal() {
        let config = GradientCheckConfig {
            epsilon: 1e-7,
            relative_tolerance: 1e-5,
            absolute_tolerance: 1e-6, // Increased tolerance for the test
            use_centered_differences: true,
            max_params_to_check: Some(100),
            random_seed: Some(42),
        };
        let checker = GradientChecker::new(config);

        // Test equal gradients
        assert!(checker.gradients_are_equal(1.0, 1.0));

        // Test nearly equal gradients (within tolerance)
        assert!(checker.gradients_are_equal(1.0, 1.000001));

        // Test different gradients
        assert!(!checker.gradients_are_equal(1.0, 1.1));
    }

    #[test]
    fn test_compute_relative_error() {
        let config = GradientCheckConfig::<f32>::default();
        let checker = GradientChecker::new(config);

        let rel_error = checker.compute_relative_error(1.0, 1.1);
        assert_abs_diff_eq!(rel_error, 0.090909, epsilon = 1e-5);

        // Test with zero numerical gradient
        let rel_error_zero = checker.compute_relative_error(0.1, 0.0);
        assert_abs_diff_eq!(rel_error_zero, 0.1, epsilon = 1e-6);
    }

    #[test]
    fn test_parameter_gradient_result() {
        let result = ParameterGradientResult {
            param_index: 0,
            analytical_gradient: 1.0,
            numerical_gradient: 1.01,
            relative_error: 0.0099,
            absolute_error: 0.01,
            passed: true,
        };

        assert_eq!(result.param_index, 0);
        assert!(result.passed);
        assert_eq!(result.analytical_gradient, 1.0);
    }

    #[test]
    fn test_gradient_check_results() {
        let param_results = vec![
            ParameterGradientResult {
                param_index: 0,
                analytical_gradient: 1.0,
                numerical_gradient: 1.01,
                relative_error: 0.0099,
                absolute_error: 0.01,
                passed: true,
            },
            ParameterGradientResult {
                param_index: 1,
                analytical_gradient: 2.0,
                numerical_gradient: 2.2,
                relative_error: 0.091,
                absolute_error: 0.2,
                passed: false,
            },
        ];

        let results = GradientCheckResults {
            all_passed: false,
            num_checked: 2,
            num_passed: 1,
            max_relative_error: 0.091,
            max_absolute_error: 0.2,
            avg_relative_error: 0.05045,
            avg_absolute_error: 0.105,
            parameter_results: param_results,
        };

        assert!(!results.all_passed);
        assert_eq!(results.num_checked, 2);
        assert_eq!(results.num_passed, 1);
    }

    #[test]
    fn test_generate_report() {
        let config = GradientCheckConfig::<f32>::default();
        let checker = GradientChecker::new(config);

        let results = GradientCheckResults {
            all_passed: true,
            num_checked: 10,
            num_passed: 10,
            max_relative_error: 1e-6,
            max_absolute_error: 1e-8,
            avg_relative_error: 1e-7,
            avg_absolute_error: 1e-9,
            parameter_results: Vec::new(),
        };

        let report = checker.generate_report(&results);
        assert!(report.contains("PASSED"));
        assert!(report.contains("Parameters Checked: 10"));
        assert!(report.contains("Pass Rate: 100.00%"));
    }

    #[test]
    fn test_layer_gradient_checking() {
        let config = GradientCheckConfig::<f32>::default();
        let checker = GradientChecker::new(config);

        let mut layer = DenseLayer::<f32>::new(5, 3, Some(Activation::Relu));
        let inputs = Array2::from_shape_vec((2, 5), vec![1.0; 10]).unwrap();
        let output_grads = Array2::from_shape_vec((2, 3), vec![0.1; 6]).unwrap();

        let results = checker
            .check_layer_gradients(&mut layer, &inputs, &output_grads)
            .unwrap();
        assert!(results.num_checked > 0);
        // Note: this is a simplified test, since check_layer_gradients is itself a placeholder.
    }
}
811}