// trustformers_optim/gradient_processing.rs

//! # Gradient Processing Enhancements
//!
//! This module provides advanced gradient processing techniques that can improve
//! training stability, convergence speed, and final model performance.
//!
//! ## Available Techniques
//!
//! - **Gradient Centralization**: Removes the mean of gradients to improve convergence
//! - **Gradient Standardization**: Normalizes gradients to unit variance
//! - **Adaptive Gradient Clipping**: Dynamically adjusts clipping based on gradient history
//! - **Gradient Noise Injection**: Adds controlled noise to escape local minima
//! - **Gradient Smoothing**: Applies exponential moving average to gradients
//! - **Hessian-based Preconditioning**: Uses second-order information to precondition gradients
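//!
//! ## Example
//!
//! A minimal usage sketch (compile-ignored: it assumes a surrounding training
//! loop that has already collected `gradients: Vec<Tensor>` from a backward
//! pass):
//!
//! ```ignore
//! use trustformers_optim::gradient_processing::{
//!     GradientProcessor, GradientProcessingConfig,
//! };
//!
//! let mut config = GradientProcessingConfig::default();
//! config.enable_adaptive_clipping = true;
//! config.enable_smoothing = true;
//!
//! let mut processor = GradientProcessor::new(config);
//! processor.process_gradients(&mut gradients)?;
//! ```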

use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use trustformers_core::tensor::Tensor;

/// Configuration for gradient processing techniques.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct GradientProcessingConfig {
    /// Enable gradient centralization
    pub enable_centralization: bool,
    /// Enable gradient standardization
    pub enable_standardization: bool,
    /// Enable adaptive gradient clipping
    pub enable_adaptive_clipping: bool,
    /// Enable gradient noise injection
    pub enable_noise_injection: bool,
    /// Enable gradient smoothing
    pub enable_smoothing: bool,
    /// Enable Hessian-based preconditioning
    pub enable_hessian_preconditioning: bool,
    /// Adaptive clipping parameters
    pub adaptive_clipping: AdaptiveClippingConfig,
    /// Noise injection parameters
    pub noise_injection: NoiseInjectionConfig,
    /// Smoothing parameters
    pub smoothing: SmoothingConfig,
    /// Hessian preconditioning parameters
    pub hessian_preconditioning: HessianPreconditioningConfig,
}

/// Configuration for adaptive gradient clipping.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptiveClippingConfig {
    /// Initial clipping threshold
    pub initial_clip_norm: f32,
    /// Minimum clipping threshold
    pub min_clip_norm: f32,
    /// Maximum clipping threshold
    pub max_clip_norm: f32,
    /// Adaptation rate
    pub adaptation_rate: f32,
    /// Target gradient norm percentile
    pub target_percentile: f32,
    /// History window size for computing statistics
    pub history_window: usize,
}

impl Default for AdaptiveClippingConfig {
    fn default() -> Self {
        Self {
            initial_clip_norm: 1.0,
            min_clip_norm: 0.1,
            max_clip_norm: 10.0,
            adaptation_rate: 0.01,
            target_percentile: 0.9,
            history_window: 100,
        }
    }
}

/// Configuration for gradient noise injection.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NoiseInjectionConfig {
    /// Initial noise scale
    pub initial_noise_scale: f32,
    /// Noise decay rate per step
    pub decay_rate: f32,
    /// Minimum noise scale
    pub min_noise_scale: f32,
    /// Noise type
    pub noise_type: NoiseType,
}

impl Default for NoiseInjectionConfig {
    fn default() -> Self {
        Self {
            initial_noise_scale: 0.1,
            decay_rate: 0.999,
            min_noise_scale: 1e-6,
            noise_type: NoiseType::Gaussian,
        }
    }
}

/// Configuration for gradient smoothing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SmoothingConfig {
    /// Exponential moving average decay rate
    pub decay: f32,
    /// Whether to debias the moving average
    pub debias: bool,
}

impl Default for SmoothingConfig {
    fn default() -> Self {
        Self {
            decay: 0.9,
            debias: true,
        }
    }
}

/// Configuration for Hessian-based preconditioning.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HessianPreconditioningConfig {
    /// Type of Hessian approximation to use
    pub approximation_type: HessianApproximationType,
    /// Damping factor for numerical stability
    pub damping: f32,
    /// Update frequency for Hessian approximation (every N steps)
    pub update_frequency: usize,
    /// History window for maintaining Hessian approximation
    pub history_window: usize,
    /// Minimum eigenvalue threshold for conditioning
    pub min_eigenvalue: f32,
    /// Maximum condition number allowed
    pub max_condition_number: f32,
}

impl Default for HessianPreconditioningConfig {
    fn default() -> Self {
        Self {
            approximation_type: HessianApproximationType::Diagonal,
            damping: 1e-4,
            update_frequency: 10,
            history_window: 20,
            min_eigenvalue: 1e-8,
            max_condition_number: 1e6,
        }
    }
}

/// Types of noise for gradient noise injection.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum NoiseType {
    Gaussian,
    Uniform,
    Laplace,
}

/// Types of Hessian approximation methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum HessianApproximationType {
    /// Use only the diagonal of the Hessian (most efficient)
    Diagonal,
    /// Use Gauss-Newton approximation (J^T J)
    GaussNewton,
    /// Use Fisher Information Matrix approximation
    FisherInformation,
    /// Use quasi-Newton L-BFGS-style approximation
    QuasiNewton,
}

/// Gradient processor that applies various enhancement techniques.
#[derive(Debug)]
pub struct GradientProcessor {
    config: GradientProcessingConfig,
    current_step: usize,

    // Adaptive clipping state
    gradient_norm_history: Vec<f32>,
    current_clip_norm: f32,

    // Noise injection state
    current_noise_scale: f32,

    // Smoothing state
    smoothed_gradients: HashMap<usize, Tensor>,
    smoothing_bias_correction: f32,

    // Hessian preconditioning state
    hessian_diagonal: HashMap<usize, Tensor>,
    hessian_inverse: HashMap<usize, Tensor>,
    last_hessian_update: usize,
    gradient_history: Vec<Vec<Tensor>>,
}

impl GradientProcessor {
    /// Create a new gradient processor with the given configuration.
    pub fn new(config: GradientProcessingConfig) -> Self {
        Self {
            current_clip_norm: config.adaptive_clipping.initial_clip_norm,
            current_noise_scale: config.noise_injection.initial_noise_scale,
            config,
            current_step: 0,
            gradient_norm_history: Vec::new(),
            smoothed_gradients: HashMap::new(),
            smoothing_bias_correction: 1.0,
            hessian_diagonal: HashMap::new(),
            hessian_inverse: HashMap::new(),
            last_hessian_update: 0,
            gradient_history: Vec::new(),
        }
    }

    /// Create a gradient processor with default configuration.
    pub fn with_defaults() -> Self {
        Self::new(GradientProcessingConfig::default())
    }

    /// Process gradients with enabled techniques.
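    ///
    /// Techniques run in a fixed order: centralization, standardization,
    /// smoothing, Hessian preconditioning, adaptive clipping, and finally
    /// noise injection, so clipping bounds the already-preconditioned
    /// gradient and injected noise is never clipped away.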
    pub fn process_gradients(&mut self, gradients: &mut [Tensor]) -> Result<()> {
        self.current_step += 1;

        // Apply gradient centralization
        if self.config.enable_centralization {
            self.apply_centralization(gradients)?;
        }

        // Apply gradient standardization
        if self.config.enable_standardization {
            self.apply_standardization(gradients)?;
        }

        // Apply gradient smoothing
        if self.config.enable_smoothing {
            self.apply_smoothing(gradients)?;
        }

        // Apply Hessian-based preconditioning
        if self.config.enable_hessian_preconditioning {
            self.apply_hessian_preconditioning(gradients)?;
        }

        // Apply adaptive gradient clipping
        if self.config.enable_adaptive_clipping {
            self.apply_adaptive_clipping(gradients)?;
        }

        // Apply gradient noise injection
        if self.config.enable_noise_injection {
            self.apply_noise_injection(gradients)?;
        }

        Ok(())
    }

    /// Apply gradient centralization (remove mean).
    fn apply_centralization(&self, gradients: &mut [Tensor]) -> Result<()> {
        for gradient in gradients.iter_mut() {
            // Compute mean across all dimensions
            let mean = gradient.mean()?;
            *gradient = gradient.sub(&mean)?;
        }
        Ok(())
    }

    /// Apply gradient standardization (normalize to unit variance).
    fn apply_standardization(&self, gradients: &mut [Tensor]) -> Result<()> {
        for gradient in gradients.iter_mut() {
            // Compute standard deviation manually
            let mean = gradient.mean()?;
            let centered = gradient.sub(&mean)?;
            let squared = centered.mul(&centered)?;
            let variance = squared.mean()?;
            let std_dev = variance.sqrt()?;

            // Add small epsilon to prevent division by zero
            let epsilon = Tensor::scalar(1e-8)?;
            let std_dev_safe = std_dev.add(&epsilon)?;

            // Normalize
            *gradient = gradient.div(&std_dev_safe)?;
        }
        Ok(())
    }

    /// Apply adaptive gradient clipping.
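    ///
    /// The threshold is nudged toward a percentile of recent gradient norms,
    /// `clip_norm += adaptation_rate * (percentile_norm - clip_norm)`, then
    /// clamped to `[min_clip_norm, max_clip_norm]`. When the total norm
    /// exceeds the threshold, every gradient is scaled by
    /// `clip_norm / total_norm`.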
    fn apply_adaptive_clipping(&mut self, gradients: &mut [Tensor]) -> Result<()> {
        // Compute total gradient norm
        let mut total_norm_sq = 0.0;
        for gradient in gradients.iter() {
            let norm_sq = gradient.norm_squared()?.to_scalar()?;
            total_norm_sq += norm_sq;
        }
        let total_norm = total_norm_sq.sqrt();

        // Update gradient norm history
        self.gradient_norm_history.push(total_norm);
        if self.gradient_norm_history.len() > self.config.adaptive_clipping.history_window {
            self.gradient_norm_history.remove(0);
        }

        // Update the adaptive clipping threshold once enough history has accumulated
        if self.gradient_norm_history.len() >= 10 {
            // Compute target percentile of gradient norms (total_cmp avoids
            // panicking if a NaN norm ever slips into the history)
            let mut sorted_norms = self.gradient_norm_history.clone();
            sorted_norms.sort_by(|a, b| a.total_cmp(b));
            let percentile_idx = (sorted_norms.len() as f32
                * self.config.adaptive_clipping.target_percentile)
                as usize;
            let target_norm = sorted_norms[percentile_idx.min(sorted_norms.len() - 1)];

            // Adapt clipping threshold towards target
            let adaptation = self.config.adaptive_clipping.adaptation_rate
                * (target_norm - self.current_clip_norm);
            self.current_clip_norm += adaptation;

            // Clamp to bounds
            self.current_clip_norm = self.current_clip_norm.clamp(
                self.config.adaptive_clipping.min_clip_norm,
                self.config.adaptive_clipping.max_clip_norm,
            );
        }

        // Apply clipping if needed
        if total_norm > self.current_clip_norm {
            let clip_factor = self.current_clip_norm / total_norm;
            for gradient in gradients.iter_mut() {
                *gradient = gradient.mul_scalar(clip_factor)?;
            }
        }

        Ok(())
    }

    /// Apply gradient noise injection.
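    ///
    /// The noise scale decays geometrically on every call,
    /// `scale_t = max(scale_0 * decay_rate^t, min_noise_scale)`, so early
    /// training steps receive the largest perturbations.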
    fn apply_noise_injection(&mut self, gradients: &mut [Tensor]) -> Result<()> {
        // Decay noise scale
        self.current_noise_scale *= self.config.noise_injection.decay_rate;
        self.current_noise_scale =
            self.current_noise_scale.max(self.config.noise_injection.min_noise_scale);

        for gradient in gradients.iter_mut() {
            let noise = match self.config.noise_injection.noise_type {
                NoiseType::Gaussian => {
                    Tensor::randn(&gradient.shape())?.mul_scalar(self.current_noise_scale)?
                },
                NoiseType::Uniform => {
                    // `bound` is chosen so a true Uniform(-bound, bound) would match
                    // the Gaussian branch's variance; approximated here with scaled
                    // Gaussian samples since only `randn` is available
                    let bound = self.current_noise_scale * 3.0_f32.sqrt();
                    Tensor::randn(&gradient.shape())?.mul_scalar(bound)?
                },
                NoiseType::Laplace => {
                    // Approximate Laplace(0, scale) with a Gaussian of matching
                    // variance: std = scale * sqrt(2)
                    Tensor::randn(&gradient.shape())?
                        .mul_scalar(self.current_noise_scale * 2.0_f32.sqrt())?
                },
            };

            *gradient = gradient.add(&noise)?;
        }

        Ok(())
    }

    /// Apply gradient smoothing with exponential moving average.
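    ///
    /// Maintains `m_t = decay * m_{t-1} + (1 - decay) * g_t` per tensor and,
    /// when `debias` is set, replaces the gradient with the Adam-style
    /// bias-corrected estimate `m_t / (1 - decay^t)`.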
    fn apply_smoothing(&mut self, gradients: &mut [Tensor]) -> Result<()> {
        let decay = self.config.smoothing.decay;

        // Track decay^t once per step (not once per tensor) so the bias
        // correction term is shared by every gradient
        self.smoothing_bias_correction *= decay;

        for (i, gradient) in gradients.iter_mut().enumerate() {
            if let Some(smoothed) = self.smoothed_gradients.get(&i) {
                // Update smoothed gradient: smoothed = decay * smoothed + (1 - decay) * gradient
                let new_smoothed =
                    smoothed.mul_scalar(decay)?.add(&gradient.mul_scalar(1.0 - decay)?)?;
                self.smoothed_gradients.insert(i, new_smoothed.clone());

                // Apply bias correction if enabled
                if self.config.smoothing.debias {
                    let bias_corrected =
                        new_smoothed.div_scalar(1.0 - self.smoothing_bias_correction)?;
                    *gradient = bias_corrected;
                } else {
                    *gradient = new_smoothed;
                }
            } else {
                // First time seeing this gradient: seed the moving average
                self.smoothed_gradients.insert(i, gradient.clone());
            }
        }

        Ok(())
    }

    /// Apply Hessian-based preconditioning to gradients.
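    ///
    /// Every `update_frequency` steps the diagonal approximation is refreshed
    /// from recent gradient history; gradients are then rescaled element-wise
    /// as `g / max(H_diag, min_eigenvalue)`, i.e. a damped diagonal Newton step.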
    fn apply_hessian_preconditioning(&mut self, gradients: &mut [Tensor]) -> Result<()> {
        // Store gradient history for Hessian approximation
        self.gradient_history.push(gradients.to_vec());
        if self.gradient_history.len() > self.config.hessian_preconditioning.history_window {
            self.gradient_history.remove(0);
        }

        // Update Hessian approximation if needed
        if self.current_step - self.last_hessian_update
            >= self.config.hessian_preconditioning.update_frequency
        {
            self.update_hessian_approximation(gradients)?;
            self.last_hessian_update = self.current_step;
        }

        // Apply preconditioning based on approximation type
        match self.config.hessian_preconditioning.approximation_type {
            HessianApproximationType::Diagonal => {
                self.apply_diagonal_preconditioning(gradients)?;
            },
            HessianApproximationType::GaussNewton => {
                self.apply_gauss_newton_preconditioning(gradients)?;
            },
            HessianApproximationType::FisherInformation => {
                self.apply_fisher_information_preconditioning(gradients)?;
            },
            HessianApproximationType::QuasiNewton => {
                self.apply_quasi_newton_preconditioning(gradients)?;
            },
        }

        Ok(())
    }

    /// Update Hessian approximation based on gradient history.
    fn update_hessian_approximation(&mut self, gradients: &[Tensor]) -> Result<()> {
        match self.config.hessian_preconditioning.approximation_type {
            HessianApproximationType::Diagonal => {
                self.update_diagonal_hessian(gradients)?;
            },
            HessianApproximationType::GaussNewton => {
                self.update_gauss_newton_hessian(gradients)?;
            },
            HessianApproximationType::FisherInformation => {
                self.update_fisher_information_hessian(gradients)?;
            },
            HessianApproximationType::QuasiNewton => {
                self.update_quasi_newton_hessian(gradients)?;
            },
        }
        Ok(())
    }

    /// Update diagonal Hessian approximation using gradient variance.
    fn update_diagonal_hessian(&mut self, gradients: &[Tensor]) -> Result<()> {
        for (i, gradient) in gradients.iter().enumerate() {
            // Approximate diagonal Hessian using gradient variance over history
            if self.gradient_history.len() > 1 {
                let mut variance = Tensor::zeros(&gradient.shape())?;
                let mut mean = Tensor::zeros(&gradient.shape())?;

                // Compute mean
                for grad_vec in &self.gradient_history {
                    if let Some(hist_grad) = grad_vec.get(i) {
                        mean = mean.add(hist_grad)?;
                    }
                }
                mean = mean.div_scalar(self.gradient_history.len() as f32)?;

                // Compute variance (approximation of diagonal Hessian)
                for grad_vec in &self.gradient_history {
                    if let Some(hist_grad) = grad_vec.get(i) {
                        let diff = hist_grad.sub(&mean)?;
                        variance = variance.add(&diff.mul(&diff)?)?;
                    }
                }
                variance = variance.div_scalar(self.gradient_history.len() as f32)?;

                // Add damping for numerical stability
                let damping = Tensor::ones(&gradient.shape())?
                    .mul_scalar(self.config.hessian_preconditioning.damping)?;
                variance = variance.add(&damping)?;

                self.hessian_diagonal.insert(i, variance);
            }
        }
        Ok(())
    }

    /// Update Gauss-Newton Hessian approximation (simplified).
    fn update_gauss_newton_hessian(&mut self, gradients: &[Tensor]) -> Result<()> {
        // Simplified Gauss-Newton approximation: keep only the diagonal of the
        // gradient outer product, i.e. the element-wise squared gradient
        for (i, gradient) in gradients.iter().enumerate() {
            let squared_gradient = gradient.mul(gradient)?;

            // Add damping
            let damping = Tensor::ones(&gradient.shape())?
                .mul_scalar(self.config.hessian_preconditioning.damping)?;
            let hessian_approx = squared_gradient.add(&damping)?;

            self.hessian_diagonal.insert(i, hessian_approx);
        }
        Ok(())
    }

    /// Update Fisher Information Matrix approximation.
    fn update_fisher_information_hessian(&mut self, gradients: &[Tensor]) -> Result<()> {
        // Empirical Fisher approximation (coincides with the simplified
        // Gauss-Newton diagonal in this context)
        for (i, gradient) in gradients.iter().enumerate() {
            // Approximate the Fisher diagonal with the squared gradient
            let fisher_approx = gradient.mul(gradient)?;

            // Add damping
            let damping = Tensor::ones(&gradient.shape())?
                .mul_scalar(self.config.hessian_preconditioning.damping)?;
            let hessian_approx = fisher_approx.add(&damping)?;

            self.hessian_diagonal.insert(i, hessian_approx);
        }
        Ok(())
    }

    /// Update quasi-Newton Hessian approximation using L-BFGS-style update.
    fn update_quasi_newton_hessian(&mut self, gradients: &[Tensor]) -> Result<()> {
        // Simplified quasi-Newton approximation using gradient differences
        if self.gradient_history.len() > 1 {
            for (i, gradient) in gradients.iter().enumerate() {
                // Get previous gradient
                if let Some(prev_grad_vec) =
                    self.gradient_history.get(self.gradient_history.len() - 2)
                {
                    if let Some(prev_grad) = prev_grad_vec.get(i) {
                        // Compute gradient difference
                        let grad_diff = gradient.sub(prev_grad)?;

                        // Approximate Hessian using gradient difference magnitude
                        let hessian_approx = grad_diff.abs()?;

                        // Add damping
                        let damping = Tensor::ones(&gradient.shape())?
                            .mul_scalar(self.config.hessian_preconditioning.damping)?;
                        let final_hessian = hessian_approx.add(&damping)?;

                        self.hessian_diagonal.insert(i, final_hessian);
                    }
                }
            }
        }
        Ok(())
    }

    /// Apply diagonal preconditioning to gradients.
    fn apply_diagonal_preconditioning(&mut self, gradients: &mut [Tensor]) -> Result<()> {
        for (i, gradient) in gradients.iter_mut().enumerate() {
            if let Some(hessian_diag) = self.hessian_diagonal.get(&i) {
                // Compute preconditioned gradient: H^{-1} * g
                // For diagonal H, this is element-wise division
                let min_val = Tensor::scalar(self.config.hessian_preconditioning.min_eigenvalue)?;
                let clamped_hessian = hessian_diag.max(&min_val)?;

                *gradient = gradient.div(&clamped_hessian)?;
            }
        }
        Ok(())
    }

    /// Apply Gauss-Newton preconditioning to gradients.
    fn apply_gauss_newton_preconditioning(&mut self, gradients: &mut [Tensor]) -> Result<()> {
        // For simplicity, use the diagonal approximation
        self.apply_diagonal_preconditioning(gradients)
    }

    /// Apply Fisher Information preconditioning to gradients.
    fn apply_fisher_information_preconditioning(&mut self, gradients: &mut [Tensor]) -> Result<()> {
        // For simplicity, use the diagonal approximation
        self.apply_diagonal_preconditioning(gradients)
    }

    /// Apply quasi-Newton preconditioning to gradients.
    fn apply_quasi_newton_preconditioning(&mut self, gradients: &mut [Tensor]) -> Result<()> {
        // For simplicity, use the diagonal approximation
        self.apply_diagonal_preconditioning(gradients)
    }

    /// Get current adaptive clipping threshold.
    pub fn get_current_clip_norm(&self) -> f32 {
        self.current_clip_norm
    }

    /// Get current noise scale.
    pub fn get_current_noise_scale(&self) -> f32 {
        self.current_noise_scale
    }

    /// Get gradient norm statistics as `(mean, std_dev, max)` over the history window.
    pub fn get_gradient_norm_stats(&self) -> Option<(f32, f32, f32)> {
        if self.gradient_norm_history.is_empty() {
            return None;
        }

        let sum: f32 = self.gradient_norm_history.iter().sum();
        let mean = sum / self.gradient_norm_history.len() as f32;

        let variance = self.gradient_norm_history.iter().map(|x| (x - mean).powi(2)).sum::<f32>()
            / self.gradient_norm_history.len() as f32;
        let std_dev = variance.sqrt();

        let max_norm = self.gradient_norm_history.iter().fold(0.0f32, |acc, &x| acc.max(x));

        Some((mean, std_dev, max_norm))
    }

    /// Reset internal state.
    pub fn reset(&mut self) {
        self.current_step = 0;
        self.gradient_norm_history.clear();
        self.smoothed_gradients.clear();
        self.current_clip_norm = self.config.adaptive_clipping.initial_clip_norm;
        self.current_noise_scale = self.config.noise_injection.initial_noise_scale;
        self.smoothing_bias_correction = 1.0;
        self.hessian_diagonal.clear();
        self.hessian_inverse.clear();
        self.last_hessian_update = 0;
        self.gradient_history.clear();
    }

    /// Update configuration (resets all internal state).
    pub fn set_config(&mut self, config: GradientProcessingConfig) {
        self.config = config;
        self.reset();
    }

    /// Get current configuration.
    pub fn get_config(&self) -> &GradientProcessingConfig {
        &self.config
    }
}

/// Wrapper for optimizers that automatically applies gradient processing.
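///
/// A minimal sketch of wiring it up (compile-ignored: `base` stands in for
/// any optimizer in this crate that implements `OptimizerState`, and
/// `parameters` for the model's parameter tensors):
///
/// ```ignore
/// let config = GradientProcessingConfig {
///     enable_adaptive_clipping: true,
///     ..Default::default()
/// };
/// let mut optimizer = GradientProcessedOptimizer::new(base, config);
///
/// // Each step first processes the parameters' gradients, then delegates
/// // to the wrapped optimizer.
/// optimizer.step(&mut parameters)?;
/// ```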
pub struct GradientProcessedOptimizer<T> {
    base_optimizer: T,
    gradient_processor: GradientProcessor,
}

impl<T> GradientProcessedOptimizer<T> {
    /// Create a new gradient-processed optimizer.
    pub fn new(base_optimizer: T, config: GradientProcessingConfig) -> Self {
        Self {
            base_optimizer,
            gradient_processor: GradientProcessor::new(config),
        }
    }

    /// Create with default gradient processing configuration.
    pub fn with_default_processing(base_optimizer: T) -> Self {
        Self::new(base_optimizer, GradientProcessingConfig::default())
    }

    /// Get reference to the gradient processor.
    pub fn gradient_processor(&self) -> &GradientProcessor {
        &self.gradient_processor
    }

    /// Get mutable reference to the gradient processor.
    pub fn gradient_processor_mut(&mut self) -> &mut GradientProcessor {
        &mut self.gradient_processor
    }

    /// Get reference to the base optimizer.
    pub fn base_optimizer(&self) -> &T {
        &self.base_optimizer
    }

    /// Get mutable reference to the base optimizer.
    pub fn base_optimizer_mut(&mut self) -> &mut T {
        &mut self.base_optimizer
    }
}

impl<T: crate::optimizer::OptimizerState> crate::optimizer::OptimizerState
    for GradientProcessedOptimizer<T>
{
    fn zero_grad(&mut self) -> Result<()> {
        self.base_optimizer.zero_grad()
    }

    fn step(&mut self, parameters: &mut [Tensor]) -> Result<()> {
        // Extract gradients from parameters
        let mut gradients = Vec::new();
        for param in parameters.iter() {
            if let Ok(grad) = param.grad() {
                gradients.push(grad);
            } else {
                return Err(anyhow!("Parameter missing gradient"));
            }
        }

        // Process gradients
        self.gradient_processor.process_gradients(&mut gradients)?;

        // Update parameter gradients with processed versions
        for (param, processed_grad) in parameters.iter_mut().zip(gradients.iter()) {
            param.set_grad(processed_grad.clone())?;
        }

        // Perform optimization step
        self.base_optimizer.step(parameters)
    }

    fn get_lr(&self) -> f32 {
        self.base_optimizer.get_lr()
    }

    fn set_lr(&mut self, lr: f32) {
        self.base_optimizer.set_lr(lr);
    }

    fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
        // For simplicity, we'll only save the base optimizer state
        // In a full implementation, we'd also save gradient processor state
        self.base_optimizer.state_dict()
    }

    fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
        self.base_optimizer.load_state_dict(state)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gradient_processing_config_default() {
        let config = GradientProcessingConfig::default();
        assert!(!config.enable_centralization);
        assert!(!config.enable_standardization);
        assert!(!config.enable_adaptive_clipping);
        assert!(!config.enable_noise_injection);
        assert!(!config.enable_smoothing);
    }

    #[test]
    fn test_adaptive_clipping_config_default() {
        let config = AdaptiveClippingConfig::default();
        assert_eq!(config.initial_clip_norm, 1.0);
        assert_eq!(config.min_clip_norm, 0.1);
        assert_eq!(config.max_clip_norm, 10.0);
        assert_eq!(config.adaptation_rate, 0.01);
        assert_eq!(config.target_percentile, 0.9);
        assert_eq!(config.history_window, 100);
    }

    #[test]
    fn test_gradient_processor_creation() {
        let processor = GradientProcessor::with_defaults();
        assert_eq!(processor.current_step, 0);
        assert_eq!(processor.gradient_norm_history.len(), 0);
    }

    #[test]
    fn test_gradient_norm_stats_empty() {
        let processor = GradientProcessor::with_defaults();
        assert!(processor.get_gradient_norm_stats().is_none());
    }
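
    // A small sanity check (added as a sketch) of the population statistics
    // returned by `get_gradient_norm_stats`, driven through the private
    // history field the same way the reset test below is.
    #[test]
    fn test_gradient_norm_stats_values() {
        let mut processor = GradientProcessor::with_defaults();
        processor.gradient_norm_history.extend([1.0, 2.0, 3.0]);

        let (mean, std_dev, max_norm) = processor.get_gradient_norm_stats().unwrap();
        assert!((mean - 2.0).abs() < 1e-6);
        // Population std dev of [1, 2, 3] is sqrt(2/3)
        assert!((std_dev - (2.0f32 / 3.0).sqrt()).abs() < 1e-6);
        assert_eq!(max_norm, 3.0);
    }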

    #[test]
    fn test_gradient_processor_reset() {
        let mut processor = GradientProcessor::with_defaults();
        processor.current_step = 10;
        processor.gradient_norm_history.push(1.0);

        processor.reset();

        assert_eq!(processor.current_step, 0);
        assert_eq!(processor.gradient_norm_history.len(), 0);
        assert_eq!(processor.hessian_diagonal.len(), 0);
        assert_eq!(processor.gradient_history.len(), 0);
    }

    #[test]
    fn test_hessian_preconditioning_config_default() {
        let config = HessianPreconditioningConfig::default();
        assert!(matches!(
            config.approximation_type,
            HessianApproximationType::Diagonal
        ));
        assert_eq!(config.damping, 1e-4);
        assert_eq!(config.update_frequency, 10);
        assert_eq!(config.history_window, 20);
        assert_eq!(config.min_eigenvalue, 1e-8);
        assert_eq!(config.max_condition_number, 1e6);
    }

    #[test]
    fn test_hessian_preconditioning_enabled() {
        let config = GradientProcessingConfig {
            enable_hessian_preconditioning: true,
            ..Default::default()
        };

        let processor = GradientProcessor::new(config);
        assert!(processor.config.enable_hessian_preconditioning);
    }

    #[test]
    fn test_hessian_approximation_types() {
        let mut config = GradientProcessingConfig {
            enable_hessian_preconditioning: true,
            ..Default::default()
        };

        // Test different approximation types
        config.hessian_preconditioning.approximation_type = HessianApproximationType::Diagonal;
        let processor = GradientProcessor::new(config.clone());
        assert!(matches!(
            processor.config.hessian_preconditioning.approximation_type,
            HessianApproximationType::Diagonal
        ));

        config.hessian_preconditioning.approximation_type = HessianApproximationType::GaussNewton;
        let processor = GradientProcessor::new(config.clone());
        assert!(matches!(
            processor.config.hessian_preconditioning.approximation_type,
            HessianApproximationType::GaussNewton
        ));

        config.hessian_preconditioning.approximation_type =
            HessianApproximationType::FisherInformation;
        let processor = GradientProcessor::new(config.clone());
        assert!(matches!(
            processor.config.hessian_preconditioning.approximation_type,
            HessianApproximationType::FisherInformation
        ));

        config.hessian_preconditioning.approximation_type = HessianApproximationType::QuasiNewton;
        let processor = GradientProcessor::new(config);
        assert!(matches!(
            processor.config.hessian_preconditioning.approximation_type,
            HessianApproximationType::QuasiNewton
        ));
    }
}