tensorlogic_infer/gradcheck.rs

//! Gradient checking utilities for validating autodiff implementations.
//!
//! This module provides numerical gradient checking to verify that
//! automatic differentiation implementations are correct.
//!
//! # Example
//!
//! ```ignore
//! use tensorlogic_infer::gradcheck::{GradCheckConfig, GradientChecker};
//!
//! let mut checker = GradientChecker::new(GradCheckConfig::default());
//! let f = |x: &[f64]| x[0] * x[0];
//! let result = checker.check_parameter("x".to_string(), f, &[3.0], &[6.0]);
//!
//! assert!(result.passed);
//! assert!(result.max_error < 1e-5);
//! ```

use std::collections::HashMap;

/// Configuration for gradient checking
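///
/// The preset constructors and builder-style setters below can be chained;
/// for example (a minimal sketch using only methods defined in this module):
///
/// ```ignore
/// let config = GradCheckConfig::relaxed()
///     .with_epsilon(1e-4)
///     .with_verbose(true);
/// assert!(config.verbose);
/// ```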
#[derive(Debug, Clone)]
pub struct GradCheckConfig {
    /// Epsilon for numerical differentiation
    pub epsilon: f64,
    /// Relative tolerance for comparing gradients
    pub rel_tolerance: f64,
    /// Absolute tolerance for comparing gradients
    pub abs_tolerance: f64,
    /// Whether to print detailed errors
    pub verbose: bool,
    /// Maximum number of errors to report
    pub max_errors_to_report: usize,
}

impl Default for GradCheckConfig {
    fn default() -> Self {
        GradCheckConfig {
            epsilon: 1e-5,
            rel_tolerance: 1e-3,
            abs_tolerance: 1e-5,
            verbose: false,
            max_errors_to_report: 10,
        }
    }
}

impl GradCheckConfig {
    /// Create a strict configuration with tighter tolerances
    pub fn strict() -> Self {
        GradCheckConfig {
            epsilon: 1e-6,
            rel_tolerance: 1e-4,
            abs_tolerance: 1e-6,
            verbose: true,
            max_errors_to_report: 10,
        }
    }

    /// Create a relaxed configuration with looser tolerances
    pub fn relaxed() -> Self {
        GradCheckConfig {
            epsilon: 1e-4,
            rel_tolerance: 1e-2,
            abs_tolerance: 1e-4,
            verbose: false,
            max_errors_to_report: 10,
        }
    }

    /// Enable verbose error reporting
    pub fn with_verbose(mut self, verbose: bool) -> Self {
        self.verbose = verbose;
        self
    }

    /// Set epsilon for numerical differentiation
    pub fn with_epsilon(mut self, epsilon: f64) -> Self {
        self.epsilon = epsilon;
        self
    }

    /// Set relative tolerance
    pub fn with_rel_tolerance(mut self, tolerance: f64) -> Self {
        self.rel_tolerance = tolerance;
        self
    }

    /// Set absolute tolerance
    pub fn with_abs_tolerance(mut self, tolerance: f64) -> Self {
        self.abs_tolerance = tolerance;
        self
    }
}

/// Result of gradient checking
#[derive(Debug, Clone)]
pub struct GradCheckResult {
    /// Number of parameters checked
    pub num_params: usize,
    /// Number of mismatches found
    pub num_errors: usize,
    /// Maximum absolute error
    pub max_error: f64,
    /// Maximum relative error
    pub max_rel_error: f64,
    /// Average absolute error across the reported gradient errors
    pub avg_error: f64,
    /// Whether all gradients passed the check
    pub passed: bool,
    /// Detailed error information
    pub errors: Vec<GradientError>,
}

impl GradCheckResult {
    /// Create a new result
    pub fn new(num_params: usize) -> Self {
        GradCheckResult {
            num_params,
            num_errors: 0,
            max_error: 0.0,
            max_rel_error: 0.0,
            avg_error: 0.0,
            passed: true,
            errors: Vec::new(),
        }
    }

    /// Add an error to the result
    pub fn add_error(&mut self, error: GradientError) {
        self.num_errors += 1;
        self.max_error = self.max_error.max(error.abs_error);
        self.max_rel_error = self.max_rel_error.max(error.rel_error);
        self.passed = false;
        self.errors.push(error);
    }

    /// Finalize the result by computing averages
    pub fn finalize(mut self) -> Self {
        if !self.errors.is_empty() {
            let total_error: f64 = self.errors.iter().map(|e| e.abs_error).sum();
            self.avg_error = total_error / self.errors.len() as f64;
        }
        self
    }

    /// Generate a summary report
    pub fn summary(&self) -> String {
        format!(
            "Gradient Check: {} params, {} errors, max_error={:.2e}, max_rel_error={:.2e}, avg_error={:.2e}, {}",
            self.num_params,
            self.num_errors,
            self.max_error,
            self.max_rel_error,
            self.avg_error,
            if self.passed { "PASSED" } else { "FAILED" }
        )
    }

    /// Print detailed error report
    pub fn print_errors(&self, max_to_print: usize) {
        if self.errors.is_empty() {
            println!("✓ All gradients passed!");
            return;
        }

        println!("\n✗ Gradient errors found:");
        for (i, error) in self.errors.iter().take(max_to_print).enumerate() {
            println!(
                "  [{}] Param {}: analytical={:.6e}, numerical={:.6e}, abs_err={:.2e}, rel_err={:.2e}",
                i + 1,
                error.param_id,
                error.analytical_grad,
                error.numerical_grad,
                error.abs_error,
                error.rel_error
            );
        }

        if self.errors.len() > max_to_print {
            println!("  ... and {} more errors", self.errors.len() - max_to_print);
        }
    }
}

/// Information about a gradient error
#[derive(Debug, Clone)]
pub struct GradientError {
    /// Parameter identifier
    pub param_id: String,
    /// Index in the flattened parameter vector
    pub index: usize,
    /// Analytical gradient from autodiff
    pub analytical_grad: f64,
    /// Numerical gradient from finite differences
    pub numerical_grad: f64,
    /// Absolute error
    pub abs_error: f64,
    /// Relative error
    pub rel_error: f64,
}

impl GradientError {
    /// Create a new gradient error
    pub fn new(param_id: String, index: usize, analytical: f64, numerical: f64) -> Self {
        let abs_error = (analytical - numerical).abs();
        let rel_error = if numerical.abs() > 1e-10 {
            abs_error / numerical.abs()
        } else {
            abs_error
        };

        GradientError {
            param_id,
            index,
            analytical_grad: analytical,
            numerical_grad: numerical,
            abs_error,
            rel_error,
        }
    }

    /// Check if this error exceeds both the absolute and relative tolerances
    pub fn exceeds_tolerance(&self, config: &GradCheckConfig) -> bool {
        self.abs_error > config.abs_tolerance && self.rel_error > config.rel_tolerance
    }
}

/// Compute numerical gradient using central differences
///
/// For a function f(x), the numerical gradient is approximated as:
/// df/dx ≈ (f(x + ε) - f(x - ε)) / (2ε)
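///
/// For example, a minimal sketch checking f(x) = x² at x = 3, whose exact
/// derivative is 2x = 6:
///
/// ```ignore
/// let f = |x: &[f64]| x[0] * x[0];
/// let grad = numerical_gradient_central(f, &[3.0], 1e-5);
/// assert!((grad[0] - 6.0).abs() < 1e-4);
/// ```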
pub fn numerical_gradient_central(
    forward_fn: impl Fn(&[f64]) -> f64,
    x: &[f64],
    epsilon: f64,
) -> Vec<f64> {
    let mut grad = vec![0.0; x.len()];

    for i in 0..x.len() {
        // Compute f(x + ε)
        let mut x_plus = x.to_vec();
        x_plus[i] += epsilon;
        let f_plus = forward_fn(&x_plus);

        // Compute f(x - ε)
        let mut x_minus = x.to_vec();
        x_minus[i] -= epsilon;
        let f_minus = forward_fn(&x_minus);

        // Central difference
        grad[i] = (f_plus - f_minus) / (2.0 * epsilon);
    }

    grad
}

/// Compute numerical gradient using forward differences
///
/// For a function f(x), the numerical gradient is approximated as:
/// df/dx ≈ (f(x + ε) - f(x)) / ε
///
/// The caller passes the precomputed base value `f_x = f(x)`, so only one extra
/// function evaluation is needed per component.
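///
/// A minimal sketch, mirroring the central-difference example above:
///
/// ```ignore
/// let f = |x: &[f64]| x[0] * x[0];
/// let f_x = f(&[3.0]);
/// let grad = numerical_gradient_forward(f, &[3.0], f_x, 1e-5);
/// assert!((grad[0] - 6.0).abs() < 1e-3);
/// ```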
pub fn numerical_gradient_forward(
    forward_fn: impl Fn(&[f64]) -> f64,
    x: &[f64],
    f_x: f64,
    epsilon: f64,
) -> Vec<f64> {
    let mut grad = vec![0.0; x.len()];

    for i in 0..x.len() {
        // Compute f(x + ε)
        let mut x_plus = x.to_vec();
        x_plus[i] += epsilon;
        let f_plus = forward_fn(&x_plus);

        // Forward difference
        grad[i] = (f_plus - f_x) / epsilon;
    }

    grad
}

/// Compute numerical gradient using fourth-order central differences
///
/// For a function f(x), the fourth-order approximation is:
/// df/dx ≈ (-f(x+2ε) + 8f(x+ε) - 8f(x-ε) + f(x-2ε)) / (12ε)
///
/// This method provides O(ε⁴) accuracy compared to O(ε²) for standard central differences.
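///
/// A minimal sketch on f(x) = x³ at x = 2 (exact derivative 3x² = 12):
///
/// ```ignore
/// let f = |x: &[f64]| x[0].powi(3);
/// let grad = numerical_gradient_fourth_order(f, &[2.0], 1e-3);
/// assert!((grad[0] - 12.0).abs() < 1e-5);
/// ```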
pub fn numerical_gradient_fourth_order(
    forward_fn: impl Fn(&[f64]) -> f64,
    x: &[f64],
    epsilon: f64,
) -> Vec<f64> {
    let mut grad = vec![0.0; x.len()];

    for i in 0..x.len() {
        // Compute f(x + 2ε)
        let mut x_plus2 = x.to_vec();
        x_plus2[i] += 2.0 * epsilon;
        let f_plus2 = forward_fn(&x_plus2);

        // Compute f(x + ε)
        let mut x_plus = x.to_vec();
        x_plus[i] += epsilon;
        let f_plus = forward_fn(&x_plus);

        // Compute f(x - ε)
        let mut x_minus = x.to_vec();
        x_minus[i] -= epsilon;
        let f_minus = forward_fn(&x_minus);

        // Compute f(x - 2ε)
        let mut x_minus2 = x.to_vec();
        x_minus2[i] -= 2.0 * epsilon;
        let f_minus2 = forward_fn(&x_minus2);

        // Fourth-order central difference
        grad[i] = (-f_plus2 + 8.0 * f_plus - 8.0 * f_minus + f_minus2) / (12.0 * epsilon);
    }

    grad
}

/// Compute numerical gradient using Richardson extrapolation
///
/// This method improves accuracy by combining finite difference approximations
/// at different step sizes and extrapolating to zero step size.
///
/// Uses the formula: I_improved = (4*I(h/2) - I(h)) / 3
/// where I(h) is the central difference with step size h.
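///
/// A minimal sketch on f(x) = x⁴ at x = 2 (exact derivative 4x³ = 32):
///
/// ```ignore
/// let f = |x: &[f64]| x[0].powi(4);
/// let grad = numerical_gradient_richardson(f, &[2.0], 1e-3);
/// assert!((grad[0] - 32.0).abs() < 1e-6);
/// ```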
pub fn numerical_gradient_richardson(
    forward_fn: impl Fn(&[f64]) -> f64,
    x: &[f64],
    epsilon: f64,
) -> Vec<f64> {
    // Compute gradients at two different step sizes
    let grad_h = numerical_gradient_central(&forward_fn, x, epsilon);
    let grad_h_half = numerical_gradient_central(&forward_fn, x, epsilon / 2.0);

    // Richardson extrapolation: (4*I(h/2) - I(h)) / 3
    grad_h_half
        .iter()
        .zip(grad_h.iter())
        .map(|(&g_half, &g_full)| (4.0 * g_half - g_full) / 3.0)
        .collect()
}

/// Compute numerical gradient using a complex-step-style approximation
///
/// True complex-step differentiation evaluates the function at a complex point:
/// df/dx = Im(f(x + iε)) / ε
/// Because no subtraction is involved, it avoids cancellation errors and reaches
/// near machine precision even for very small epsilon, but it requires an analytic
/// function that accepts complex inputs.
///
/// Since `forward_fn` only takes real inputs, this implementation instead uses a
/// central difference with a much smaller step (`epsilon * 1e-8`). This mimics the
/// spirit of the complex-step method but does not inherit its immunity to
/// cancellation, so its accuracy is limited compared to a true complex step.
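///
/// A minimal sketch on f(x) = x² + 2x + 1 at x = 3 (exact derivative 2x + 2 = 8);
/// the loose tolerance reflects the very small internal step:
///
/// ```ignore
/// let f = |x: &[f64]| x[0] * x[0] + 2.0 * x[0] + 1.0;
/// let grad = numerical_gradient_complex_step(f, &[3.0], 1e-5);
/// assert!((grad[0] - 8.0).abs() < 0.1);
/// ```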
pub fn numerical_gradient_complex_step(
    forward_fn: impl Fn(&[f64]) -> f64,
    x: &[f64],
    epsilon: f64,
) -> Vec<f64> {
    let mut grad = vec![0.0; x.len()];

    // For each dimension, we perturb with a complex step
    // f(x + i*epsilon) ≈ f(x) + i*epsilon*f'(x) + O(epsilon^2)
    // Thus: Im(f(x + i*epsilon)) / epsilon ≈ f'(x)
    //
    // Since we can't directly use complex numbers without modifying the function signature,
    // we approximate this using two function evaluations:
    // For analytic functions: f(x + iε) ≈ f(x) + iε*f'(x)
    // We can extract the derivative from the Taylor series

    for i in 0..x.len() {
        // We use a second-order approximation that mimics complex-step behavior
        // by using very small perturbations in both directions
        let eps_tiny = epsilon * 1e-8;

        let mut x_plus_small = x.to_vec();
        x_plus_small[i] += eps_tiny;
        let f_plus_small = forward_fn(&x_plus_small);

        let mut x_minus_small = x.to_vec();
        x_minus_small[i] -= eps_tiny;
        let f_minus_small = forward_fn(&x_minus_small);

        // Central difference with extremely small epsilon
        // This approximates the complex-step derivative behavior
        grad[i] = (f_plus_small - f_minus_small) / (2.0 * eps_tiny);
    }

    grad
}

/// Adaptive numerical gradient that automatically selects a step size
///
/// This method computes central-difference gradients for several epsilon values
/// (1e-3 down to 1e-7) and keeps the estimate whose components have the smallest
/// variance, using that variance as a rough stability heuristic.
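///
/// A minimal sketch on f(x) = x² at x = 3:
///
/// ```ignore
/// let f = |x: &[f64]| x[0] * x[0];
/// let grad = numerical_gradient_adaptive(f, &[3.0]);
/// assert!((grad[0] - 6.0).abs() < 1e-4);
/// ```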
pub fn numerical_gradient_adaptive(forward_fn: impl Fn(&[f64]) -> f64, x: &[f64]) -> Vec<f64> {
    // Try multiple epsilon values
    let epsilons = vec![1e-3, 1e-4, 1e-5, 1e-6, 1e-7];
    let mut best_grad = Vec::new();
    let mut min_variance = f64::MAX;

    for &eps in &epsilons {
        let grad = numerical_gradient_central(&forward_fn, x, eps);

        // Variance of the gradient components, used as a rough stability heuristic
        if !grad.is_empty() {
            let mean: f64 = grad.iter().sum::<f64>() / grad.len() as f64;
            let variance: f64 =
                grad.iter().map(|&g| (g - mean).powi(2)).sum::<f64>() / grad.len() as f64;

            if variance < min_variance || best_grad.is_empty() {
                min_variance = variance;
                best_grad = grad;
            }
        }
    }

    best_grad
}

/// Compare analytical and numerical gradients element-wise, returning an error
/// entry for each element that exceeds the configured tolerances
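///
/// A minimal sketch: only the mismatching element produces an error entry.
///
/// ```ignore
/// let config = GradCheckConfig::default();
/// let errors = compare_gradients(
///     "w".to_string(),
///     &[1.0, 2.0, 3.0],
///     &[1.0, 2.5, 3.0],
///     &config,
/// );
/// assert_eq!(errors.len(), 1);
/// ```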
pub fn compare_gradients(
    param_id: String,
    analytical: &[f64],
    numerical: &[f64],
    config: &GradCheckConfig,
) -> Vec<GradientError> {
    assert_eq!(analytical.len(), numerical.len());

    let mut errors = Vec::new();

    for (i, (&a, &n)) in analytical.iter().zip(numerical.iter()).enumerate() {
        let error = GradientError::new(param_id.clone(), i, a, n);

        if error.exceeds_tolerance(config) {
            errors.push(error);
        }
    }

    errors
}

/// Gradient checker for multi-parameter functions
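///
/// A minimal usage sketch, checking one parameter of f(x) = x²:
///
/// ```ignore
/// let mut checker = GradientChecker::with_defaults();
/// let f = |x: &[f64]| x[0] * x[0];
/// let result = checker.check_parameter("x".to_string(), f, &[3.0], &[6.0]);
/// assert!(result.passed);
/// assert!(checker.all_passed());
/// ```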
pub struct GradientChecker {
    config: GradCheckConfig,
    results: HashMap<String, GradCheckResult>,
}

impl GradientChecker {
    /// Create a new gradient checker with the given configuration
    pub fn new(config: GradCheckConfig) -> Self {
        GradientChecker {
            config,
            results: HashMap::new(),
        }
    }

    /// Create with default configuration
    pub fn with_defaults() -> Self {
        Self::new(GradCheckConfig::default())
    }

    /// Check gradients for a single parameter
    pub fn check_parameter(
        &mut self,
        param_id: String,
        forward_fn: impl Fn(&[f64]) -> f64,
        x: &[f64],
        analytical_grad: &[f64],
    ) -> GradCheckResult {
        // Compute numerical gradient
        let numerical_grad = numerical_gradient_central(&forward_fn, x, self.config.epsilon);

        // Compare gradients
        let errors = compare_gradients(
            param_id.clone(),
            analytical_grad,
            &numerical_grad,
            &self.config,
        );

        // Build result
        let mut result = GradCheckResult::new(x.len());
        for error in errors {
            result.add_error(error);
        }
        let result = result.finalize();

        if self.config.verbose {
            println!("Checking parameter '{}':", param_id);
            println!("  {}", result.summary());
            if !result.passed {
                result.print_errors(self.config.max_errors_to_report);
            }
        }

        self.results.insert(param_id, result.clone());
        result
    }

    /// Get results for all checked parameters
    pub fn results(&self) -> &HashMap<String, GradCheckResult> {
        &self.results
    }

    /// Check if all parameters passed
    pub fn all_passed(&self) -> bool {
        self.results.values().all(|r| r.passed)
    }

    /// Get total number of errors across all parameters
    pub fn total_errors(&self) -> usize {
        self.results.values().map(|r| r.num_errors).sum()
    }

    /// Print summary of all checks
    pub fn print_summary(&self) {
        println!("\n=== Gradient Check Summary ===");
        for (param_id, result) in &self.results {
            println!("{}: {}", param_id, result.summary());
        }
        println!(
            "\nTotal: {} parameters, {} errors",
            self.results.len(),
            self.total_errors()
        );

        if self.all_passed() {
            println!("✓ All gradient checks PASSED");
        } else {
            println!("✗ Some gradient checks FAILED");
        }
    }
}

/// Quick gradient check for a single function using the default configuration
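///
/// A minimal sketch, mirroring the unit test below:
///
/// ```ignore
/// let f = |x: &[f64]| x[0] * x[0];
/// assert!(quick_check(f, &[3.0], &[6.0]).is_ok());
/// assert!(quick_check(f, &[3.0], &[7.0]).is_err());
/// ```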
pub fn quick_check(
    forward_fn: impl Fn(&[f64]) -> f64,
    x: &[f64],
    analytical_grad: &[f64],
) -> Result<(), String> {
    let config = GradCheckConfig::default();
    let numerical = numerical_gradient_central(&forward_fn, x, config.epsilon);

    let errors = compare_gradients(
        "quick_check".to_string(),
        analytical_grad,
        &numerical,
        &config,
    );

    if errors.is_empty() {
        Ok(())
    } else {
        let mut result = GradCheckResult::new(x.len());
        for error in errors {
            result.add_error(error);
        }
        Err(result.finalize().summary())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_grad_check_config_default() {
        let config = GradCheckConfig::default();
        assert!(config.epsilon > 0.0);
        assert!(config.rel_tolerance > 0.0);
        assert!(config.abs_tolerance > 0.0);
    }

    #[test]
    fn test_grad_check_config_strict() {
        let strict = GradCheckConfig::strict();
        let default = GradCheckConfig::default();
        assert!(strict.epsilon <= default.epsilon);
        assert!(strict.rel_tolerance <= default.rel_tolerance);
    }

    #[test]
    fn test_grad_check_config_builder() {
        let config = GradCheckConfig::default()
            .with_epsilon(1e-4)
            .with_verbose(true)
            .with_rel_tolerance(1e-2);

        assert_eq!(config.epsilon, 1e-4);
        assert!(config.verbose);
        assert_eq!(config.rel_tolerance, 1e-2);
    }

    #[test]
    fn test_numerical_gradient_simple() {
        // f(x) = x^2, df/dx = 2x
        let f = |x: &[f64]| x[0] * x[0];
        let x = vec![3.0];
        let grad = numerical_gradient_central(f, &x, 1e-5);

        // Should be close to 2*3 = 6
        assert!((grad[0] - 6.0).abs() < 1e-4);
    }

    #[test]
    fn test_numerical_gradient_multivariate() {
        // f(x, y) = x^2 + y^2, df/dx = 2x, df/dy = 2y
        let f = |xy: &[f64]| xy[0] * xy[0] + xy[1] * xy[1];
        let xy = vec![3.0, 4.0];
        let grad = numerical_gradient_central(f, &xy, 1e-5);

        assert!((grad[0] - 6.0).abs() < 1e-4);
        assert!((grad[1] - 8.0).abs() < 1e-4);
    }

    #[test]
    fn test_gradient_error_creation() {
        let error = GradientError::new("param1".to_string(), 0, 1.0, 1.01);

        assert_eq!(error.param_id, "param1");
        assert_eq!(error.index, 0);
        assert_eq!(error.analytical_grad, 1.0);
        assert_eq!(error.numerical_grad, 1.01);
        assert!(error.abs_error > 0.0);
        assert!(error.rel_error > 0.0);
    }

    #[test]
    fn test_gradient_error_exceeds_tolerance() {
        let config = GradCheckConfig::default();

        // Large error
        let error1 = GradientError::new("p1".to_string(), 0, 1.0, 2.0);
        assert!(error1.exceeds_tolerance(&config));

        // Small error
        let error2 = GradientError::new("p2".to_string(), 0, 1.0, 1.0000001);
        assert!(!error2.exceeds_tolerance(&config));
    }

    #[test]
    fn test_grad_check_result() {
        let mut result = GradCheckResult::new(10);
        assert!(result.passed);
        assert_eq!(result.num_errors, 0);

        result.add_error(GradientError::new("p1".to_string(), 0, 1.0, 2.0));
        assert!(!result.passed);
        assert_eq!(result.num_errors, 1);

        let final_result = result.finalize();
        assert!(final_result.avg_error > 0.0);
    }

    #[test]
    fn test_compare_gradients() {
        let config = GradCheckConfig::default();

        // Perfect match
        let analytical = vec![1.0, 2.0, 3.0];
        let numerical = vec![1.0, 2.0, 3.0];
        let errors = compare_gradients("test".to_string(), &analytical, &numerical, &config);
        assert_eq!(errors.len(), 0);

        // With errors
        let numerical2 = vec![1.0, 2.5, 3.0];
        let errors2 = compare_gradients("test".to_string(), &analytical, &numerical2, &config);
        assert!(!errors2.is_empty());
    }

    #[test]
    fn test_gradient_checker() {
        let mut checker = GradientChecker::new(GradCheckConfig::default());

        // Check a simple function: f(x) = x^2
        let f = |x: &[f64]| x[0] * x[0];
        let x = vec![3.0];
        let analytical = vec![6.0]; // df/dx = 2x = 6

        let result = checker.check_parameter("x".to_string(), f, &x, &analytical);
        assert!(result.passed);
        assert!(checker.all_passed());
    }

    #[test]
    fn test_quick_check() {
        // Correct gradient
        let f = |x: &[f64]| x[0] * x[0];
        let x = vec![3.0];
        let grad = vec![6.0];
        assert!(quick_check(f, &x, &grad).is_ok());

        // Incorrect gradient
        let bad_grad = vec![7.0];
        assert!(quick_check(f, &x, &bad_grad).is_err());
    }

    #[test]
    fn test_forward_gradient() {
        let f = |x: &[f64]| x[0] * x[0];
        let x = vec![3.0];
        let f_x = f(&x);
        let grad = numerical_gradient_forward(f, &x, f_x, 1e-5);

        // Should be close to 6.0, but less accurate than central
        assert!((grad[0] - 6.0).abs() < 1e-3);
    }

    #[test]
    fn test_fourth_order_gradient() {
        // f(x) = x^3, df/dx = 3x^2
        let f = |x: &[f64]| x[0].powi(3);
        let x = vec![2.0];
        let grad = numerical_gradient_fourth_order(f, &x, 1e-3);

        // Should be close to 3*2^2 = 12, with high accuracy
        assert!((grad[0] - 12.0).abs() < 1e-5);
    }

    #[test]
    fn test_fourth_order_multivariate() {
        // f(x, y) = x^3 + y^3, df/dx = 3x^2, df/dy = 3y^2
        let f = |xy: &[f64]| xy[0].powi(3) + xy[1].powi(3);
        let xy = vec![2.0, 3.0];
        let grad = numerical_gradient_fourth_order(f, &xy, 1e-3);

        assert!((grad[0] - 12.0).abs() < 1e-5); // 3 * 2^2 = 12
        assert!((grad[1] - 27.0).abs() < 1e-5); // 3 * 3^2 = 27
    }

    #[test]
    fn test_richardson_extrapolation() {
        // f(x) = x^4, df/dx = 4x^3
        let f = |x: &[f64]| x[0].powi(4);
        let x = vec![2.0];
        let grad = numerical_gradient_richardson(f, &x, 1e-3);

        // Should be close to 4*2^3 = 32, with improved accuracy
        assert!((grad[0] - 32.0).abs() < 1e-6);
    }

    #[test]
    fn test_richardson_multivariate() {
        // f(x, y) = x^4 + y^4, df/dx = 4x^3, df/dy = 4y^3
        let f = |xy: &[f64]| xy[0].powi(4) + xy[1].powi(4);
        let xy = vec![2.0, 1.5];
        let grad = numerical_gradient_richardson(f, &xy, 1e-3);

        assert!((grad[0] - 32.0).abs() < 1e-6); // 4 * 2^3 = 32
        assert!((grad[1] - 13.5).abs() < 1e-6); // 4 * 1.5^3 = 13.5
    }

    #[test]
    fn test_complex_step_approximation() {
        // f(x) = x^2 + 2x + 1, df/dx = 2x + 2
        let f = |x: &[f64]| x[0] * x[0] + 2.0 * x[0] + 1.0;
        let x = vec![3.0];
        let grad = numerical_gradient_complex_step(f, &x, 1e-5);

        // Should be close to 2*3 + 2 = 8
        // Note: Our approximation uses extremely small epsilon (1e-13)
        // which can have numerical instability, so we use wider tolerance
        assert!((grad[0] - 8.0).abs() < 0.1);
    }

    #[test]
    fn test_adaptive_gradient() {
        // f(x) = x^2, df/dx = 2x
        let f = |x: &[f64]| x[0] * x[0];
        let x = vec![3.0];
        let grad = numerical_gradient_adaptive(f, &x);

        // Should automatically select good epsilon and give accurate result
        assert!((grad[0] - 6.0).abs() < 1e-4);
    }

    #[test]
    fn test_adaptive_multivariate() {
        // f(x, y, z) = x^2 + y^2 + z^2
        let f = |xyz: &[f64]| xyz[0] * xyz[0] + xyz[1] * xyz[1] + xyz[2] * xyz[2];
        let xyz = vec![1.0, 2.0, 3.0];
        let grad = numerical_gradient_adaptive(f, &xyz);

        assert!((grad[0] - 2.0).abs() < 1e-4);
        assert!((grad[1] - 4.0).abs() < 1e-4);
        assert!((grad[2] - 6.0).abs() < 1e-4);
    }

    #[test]
    fn test_gradient_method_comparison() {
        // Compare all methods on the same function
        // f(x) = sin(x), df/dx = cos(x)
        let f = |x: &[f64]| x[0].sin();
        let x = vec![1.0_f64];
        let expected = 1.0_f64.cos(); // exact derivative

        let grad_central = numerical_gradient_central(f, &x, 1e-5);
        let grad_fourth = numerical_gradient_fourth_order(f, &x, 1e-3);
        let grad_richardson = numerical_gradient_richardson(f, &x, 1e-3);

        // All methods should be reasonably accurate
        assert!((grad_central[0] - expected).abs() < 1e-5);
        assert!((grad_fourth[0] - expected).abs() < 1e-6);
        assert!((grad_richardson[0] - expected).abs() < 1e-7);
    }

    #[test]
    fn test_gradient_stability_near_zero() {
        // Test gradient computation near zero where numerical issues can occur
        // f(x) = x^2 + 1e-10, df/dx = 2x
        let f = |x: &[f64]| x[0] * x[0] + 1e-10;
        let x = vec![1e-8_f64];
        let expected = 2.0 * 1e-8;

        let grad = numerical_gradient_adaptive(f, &x);
        // Should handle near-zero values reasonably
        assert!((grad[0] - expected).abs() < 1e-9);
    }

    #[test]
    fn test_gradient_nonpolynomial() {
        // Test on non-polynomial function: f(x) = exp(x), df/dx = exp(x)
        let f = |x: &[f64]| x[0].exp();
        let x = vec![1.0_f64];
        let expected = 1.0_f64.exp();

        let grad_fourth = numerical_gradient_fourth_order(f, &x, 1e-4);
        assert!((grad_fourth[0] - expected).abs() < 1e-6);

        let grad_richardson = numerical_gradient_richardson(f, &x, 1e-4);
        assert!((grad_richardson[0] - expected).abs() < 1e-7);
840    }
841}