1use anyhow::{anyhow, Result};
16use serde::{Deserialize, Serialize};
17use std::collections::HashMap;
18use trustformers_core::tensor::Tensor;
19
/// Configuration for the gradient processing pipeline.
///
/// Each `enable_*` flag toggles one processing stage; the remaining fields
/// hold the per-stage settings used when the corresponding stage is on.
/// All flags default to `false` (via `Default`), so processing is opt-in.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct GradientProcessingConfig {
    /// Subtract the mean from each gradient tensor (gradient centralization).
    pub enable_centralization: bool,
    /// Rescale each gradient tensor toward zero mean / unit variance.
    pub enable_standardization: bool,
    /// Clip the global gradient norm with an adaptively tuned threshold.
    pub enable_adaptive_clipping: bool,
    /// Add decaying random noise to the gradients.
    pub enable_noise_injection: bool,
    /// Exponential moving-average smoothing of gradients.
    pub enable_smoothing: bool,
    /// Precondition gradients with an approximate (diagonal) Hessian.
    pub enable_hessian_preconditioning: bool,
    /// Settings used when `enable_adaptive_clipping` is set.
    pub adaptive_clipping: AdaptiveClippingConfig,
    /// Settings used when `enable_noise_injection` is set.
    pub noise_injection: NoiseInjectionConfig,
    /// Settings used when `enable_smoothing` is set.
    pub smoothing: SmoothingConfig,
    /// Settings used when `enable_hessian_preconditioning` is set.
    pub hessian_preconditioning: HessianPreconditioningConfig,
}
44
/// Settings for adaptive gradient-norm clipping.
///
/// The clip threshold chases a target percentile of recently observed global
/// gradient norms, bounded to `[min_clip_norm, max_clip_norm]`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptiveClippingConfig {
    /// Threshold used before enough history has been collected.
    pub initial_clip_norm: f32,
    /// Lower bound on the adapted threshold.
    pub min_clip_norm: f32,
    /// Upper bound on the adapted threshold.
    pub max_clip_norm: f32,
    /// Fraction of the gap to the target norm closed per step.
    pub adaptation_rate: f32,
    /// Percentile (0..1) of recent norms the threshold chases.
    pub target_percentile: f32,
    /// Number of recent global norms kept in the sliding window.
    pub history_window: usize,
}
61
62impl Default for AdaptiveClippingConfig {
63 fn default() -> Self {
64 Self {
65 initial_clip_norm: 1.0,
66 min_clip_norm: 0.1,
67 max_clip_norm: 10.0,
68 adaptation_rate: 0.01,
69 target_percentile: 0.9,
70 history_window: 100,
71 }
72 }
73}
74
/// Settings for gradient noise injection.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NoiseInjectionConfig {
    /// Noise scale (std-dev multiplier) at step 0.
    pub initial_noise_scale: f32,
    /// Multiplicative decay applied to the scale each processed step.
    pub decay_rate: f32,
    /// Floor below which the scale never decays.
    pub min_noise_scale: f32,
    /// Distribution family for the injected noise.
    pub noise_type: NoiseType,
}
87
88impl Default for NoiseInjectionConfig {
89 fn default() -> Self {
90 Self {
91 initial_noise_scale: 0.1,
92 decay_rate: 0.999,
93 min_noise_scale: 1e-6,
94 noise_type: NoiseType::Gaussian,
95 }
96 }
97}
98
/// Settings for exponential moving-average gradient smoothing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SmoothingConfig {
    /// EMA decay factor (closer to 1.0 = smoother / slower).
    pub decay: f32,
    /// Whether to apply Adam-style warm-up bias correction.
    pub debias: bool,
}
107
108impl Default for SmoothingConfig {
109 fn default() -> Self {
110 Self {
111 decay: 0.9,
112 debias: true,
113 }
114 }
115}
116
/// Settings for approximate-Hessian gradient preconditioning.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HessianPreconditioningConfig {
    /// Which Hessian approximation scheme to maintain.
    pub approximation_type: HessianApproximationType,
    /// Constant added to the diagonal to keep it bounded away from zero.
    pub damping: f32,
    /// Refresh the approximation every this many processed steps.
    pub update_frequency: usize,
    /// Number of recent gradient snapshots kept for the approximation.
    pub history_window: usize,
    /// Floor applied to the Hessian diagonal before dividing.
    pub min_eigenvalue: f32,
    /// Intended cap on the preconditioner's condition number.
    /// NOTE(review): not referenced anywhere in this file — confirm whether
    /// it is consumed elsewhere or is dead configuration.
    pub max_condition_number: f32,
}
133
134impl Default for HessianPreconditioningConfig {
135 fn default() -> Self {
136 Self {
137 approximation_type: HessianApproximationType::Diagonal,
138 damping: 1e-4,
139 update_frequency: 10,
140 history_window: 20,
141 min_eigenvalue: 1e-8,
142 max_condition_number: 1e6,
143 }
144 }
145}
146
/// Distribution family for injected gradient noise.
///
/// Note: the current implementation draws all noise from `Tensor::randn`
/// and only rescales it per type, so `Uniform` and `Laplace` are Gaussian
/// approximations with matched standard deviation rather than true samples.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum NoiseType {
    Gaussian,
    Uniform,
    Laplace,
}
154
/// Hessian approximation scheme used for preconditioning.
///
/// All variants store a per-parameter *diagonal* approximation in this file;
/// they differ only in how that diagonal is computed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum HessianApproximationType {
    /// Element-wise variance of gradients over the history window.
    Diagonal,
    /// Element-wise square of the current gradient (diagonal of g·gᵀ).
    GaussNewton,
    /// Element-wise square of the current gradient (empirical Fisher).
    FisherInformation,
    /// |gₜ − gₜ₋₁| between consecutive history snapshots.
    QuasiNewton,
}
167
/// Stateful gradient post-processor applied between backprop and the
/// optimizer step. See [`GradientProcessingConfig`] for the stages.
#[derive(Debug)]
pub struct GradientProcessor {
    config: GradientProcessingConfig,
    // Number of `process_gradients` calls so far.
    current_step: usize,

    // Adaptive clipping: sliding window of observed global norms and the
    // currently adapted clip threshold.
    gradient_norm_history: Vec<f32>,
    current_clip_norm: f32,

    // Noise injection: current (decayed) noise scale.
    current_noise_scale: f32,

    // Smoothing: per-parameter-index EMA tensors plus the running decay^t
    // factor used for bias correction.
    smoothed_gradients: HashMap<usize, Tensor>,
    smoothing_bias_correction: f32,

    // Hessian preconditioning: per-parameter diagonal approximations and a
    // window of recent gradient snapshots.
    hessian_diagonal: HashMap<usize, Tensor>,
    // NOTE(review): cleared on reset but never written anywhere in this
    // file — appears reserved for a future full-matrix path; confirm.
    hessian_inverse: HashMap<usize, Tensor>,
    // Step at which the approximation was last refreshed.
    last_hessian_update: usize,
    gradient_history: Vec<Vec<Tensor>>,
}
191
192impl GradientProcessor {
193 pub fn new(config: GradientProcessingConfig) -> Self {
195 Self {
196 current_clip_norm: config.adaptive_clipping.initial_clip_norm,
197 current_noise_scale: config.noise_injection.initial_noise_scale,
198 config,
199 current_step: 0,
200 gradient_norm_history: Vec::new(),
201 smoothed_gradients: HashMap::new(),
202 smoothing_bias_correction: 1.0,
203 hessian_diagonal: HashMap::new(),
204 hessian_inverse: HashMap::new(),
205 last_hessian_update: 0,
206 gradient_history: Vec::new(),
207 }
208 }
209
210 pub fn with_defaults() -> Self {
212 Self::new(GradientProcessingConfig::default())
213 }
214
    /// Runs the enabled processing stages over `gradients`, in place.
    ///
    /// Stage order is fixed: centralization → standardization → smoothing →
    /// Hessian preconditioning → adaptive clipping → noise injection.
    /// Clipping runs after preconditioning so the threshold applies to the
    /// preconditioned norms; noise is added last so it is never clipped.
    /// Also advances the internal step counter used by the Hessian update
    /// schedule.
    pub fn process_gradients(&mut self, gradients: &mut [Tensor]) -> Result<()> {
        self.current_step += 1;

        if self.config.enable_centralization {
            self.apply_centralization(gradients)?;
        }

        if self.config.enable_standardization {
            self.apply_standardization(gradients)?;
        }

        if self.config.enable_smoothing {
            self.apply_smoothing(gradients)?;
        }

        if self.config.enable_hessian_preconditioning {
            self.apply_hessian_preconditioning(gradients)?;
        }

        if self.config.enable_adaptive_clipping {
            self.apply_adaptive_clipping(gradients)?;
        }

        if self.config.enable_noise_injection {
            self.apply_noise_injection(gradients)?;
        }

        Ok(())
    }
251
252 fn apply_centralization(&self, gradients: &mut [Tensor]) -> Result<()> {
254 for gradient in gradients.iter_mut() {
255 let mean = gradient.mean()?;
257 *gradient = gradient.sub(&mean)?;
258 }
259 Ok(())
260 }
261
262 fn apply_standardization(&self, gradients: &mut [Tensor]) -> Result<()> {
264 for gradient in gradients.iter_mut() {
265 let mean = gradient.mean()?;
267 let centered = gradient.sub(&mean)?;
268 let squared = centered.mul(¢ered)?;
269 let variance = squared.mean()?;
270 let std_dev = variance.sqrt()?;
271
272 let epsilon = Tensor::scalar(1e-8)?;
274 let std_dev_safe = std_dev.add(&epsilon)?;
275
276 *gradient = gradient.div(&std_dev_safe)?;
278 }
279 Ok(())
280 }
281
282 fn apply_adaptive_clipping(&mut self, gradients: &mut [Tensor]) -> Result<()> {
284 let mut total_norm_sq = 0.0;
286 for gradient in gradients.iter() {
287 let norm_sq = gradient.norm_squared()?.to_scalar()?;
288 total_norm_sq += norm_sq;
289 }
290 let total_norm = total_norm_sq.sqrt();
291
292 self.gradient_norm_history.push(total_norm);
294 if self.gradient_norm_history.len() > self.config.adaptive_clipping.history_window {
295 self.gradient_norm_history.remove(0);
296 }
297
298 if self.gradient_norm_history.len() >= 10 {
300 let mut sorted_norms = self.gradient_norm_history.clone();
302 sorted_norms.sort_by(|a, b| a.partial_cmp(b).unwrap());
303 let percentile_idx = (sorted_norms.len() as f32
304 * self.config.adaptive_clipping.target_percentile)
305 as usize;
306 let target_norm = sorted_norms[percentile_idx.min(sorted_norms.len() - 1)];
307
308 let adaptation = self.config.adaptive_clipping.adaptation_rate
310 * (target_norm - self.current_clip_norm);
311 self.current_clip_norm += adaptation;
312
313 self.current_clip_norm = self
315 .current_clip_norm
316 .max(self.config.adaptive_clipping.min_clip_norm)
317 .min(self.config.adaptive_clipping.max_clip_norm);
318 }
319
320 if total_norm > self.current_clip_norm {
322 let clip_factor = self.current_clip_norm / total_norm;
323 for gradient in gradients.iter_mut() {
324 *gradient = gradient.mul_scalar(clip_factor)?;
325 }
326 }
327
328 Ok(())
329 }
330
331 fn apply_noise_injection(&mut self, gradients: &mut [Tensor]) -> Result<()> {
333 self.current_noise_scale *= self.config.noise_injection.decay_rate;
335 self.current_noise_scale =
336 self.current_noise_scale.max(self.config.noise_injection.min_noise_scale);
337
338 for gradient in gradients.iter_mut() {
339 let noise = match self.config.noise_injection.noise_type {
340 NoiseType::Gaussian => {
341 let noise_tensor = Tensor::randn(&gradient.shape())?;
342 noise_tensor.mul_scalar(self.current_noise_scale)?;
343 noise_tensor
344 },
345 NoiseType::Uniform => {
346 let bound = self.current_noise_scale * 3.0_f32.sqrt(); let noise_tensor = Tensor::randn(&gradient.shape())?;
348 noise_tensor.mul_scalar(bound)?;
349 noise_tensor
350 },
351 NoiseType::Laplace => {
352 let noise_tensor = Tensor::randn(&gradient.shape())?;
354 noise_tensor.mul_scalar(self.current_noise_scale * 2.0_f32.sqrt())?;
355 noise_tensor
356 },
357 };
358
359 *gradient = gradient.add(&noise)?;
360 }
361
362 Ok(())
363 }
364
365 fn apply_smoothing(&mut self, gradients: &mut [Tensor]) -> Result<()> {
367 let decay = self.config.smoothing.decay;
368
369 for (i, gradient) in gradients.iter_mut().enumerate() {
370 if let Some(smoothed) = self.smoothed_gradients.get(&i) {
371 let new_smoothed =
373 smoothed.mul_scalar(decay)?.add(&gradient.mul_scalar(1.0 - decay)?)?;
374 self.smoothed_gradients.insert(i, new_smoothed.clone());
375
376 if self.config.smoothing.debias {
378 self.smoothing_bias_correction *= decay;
379 let bias_corrected =
380 new_smoothed.div_scalar(1.0 - self.smoothing_bias_correction)?;
381 *gradient = bias_corrected;
382 } else {
383 *gradient = new_smoothed;
384 }
385 } else {
386 self.smoothed_gradients.insert(i, gradient.clone());
388 }
389 }
390
391 Ok(())
392 }
393
394 fn apply_hessian_preconditioning(&mut self, gradients: &mut [Tensor]) -> Result<()> {
396 self.gradient_history.push(gradients.to_vec());
398 if self.gradient_history.len() > self.config.hessian_preconditioning.history_window {
399 self.gradient_history.remove(0);
400 }
401
402 if self.current_step - self.last_hessian_update
404 >= self.config.hessian_preconditioning.update_frequency
405 {
406 self.update_hessian_approximation(gradients)?;
407 self.last_hessian_update = self.current_step;
408 }
409
410 match self.config.hessian_preconditioning.approximation_type {
412 HessianApproximationType::Diagonal => {
413 self.apply_diagonal_preconditioning(gradients)?;
414 },
415 HessianApproximationType::GaussNewton => {
416 self.apply_gauss_newton_preconditioning(gradients)?;
417 },
418 HessianApproximationType::FisherInformation => {
419 self.apply_fisher_information_preconditioning(gradients)?;
420 },
421 HessianApproximationType::QuasiNewton => {
422 self.apply_quasi_newton_preconditioning(gradients)?;
423 },
424 }
425
426 Ok(())
427 }
428
429 fn update_hessian_approximation(&mut self, gradients: &[Tensor]) -> Result<()> {
431 match self.config.hessian_preconditioning.approximation_type {
432 HessianApproximationType::Diagonal => {
433 self.update_diagonal_hessian(gradients)?;
434 },
435 HessianApproximationType::GaussNewton => {
436 self.update_gauss_newton_hessian(gradients)?;
437 },
438 HessianApproximationType::FisherInformation => {
439 self.update_fisher_information_hessian(gradients)?;
440 },
441 HessianApproximationType::QuasiNewton => {
442 self.update_quasi_newton_hessian(gradients)?;
443 },
444 }
445 Ok(())
446 }
447
    /// Approximates the Hessian diagonal per parameter as the element-wise
    /// (population) variance of that parameter's gradients over the history
    /// window, plus a damping term.
    ///
    /// Requires at least two history entries; otherwise no entry is written
    /// for the parameter, leaving preconditioning a no-op for it.
    fn update_diagonal_hessian(&mut self, gradients: &[Tensor]) -> Result<()> {
        for (i, gradient) in gradients.iter().enumerate() {
            if self.gradient_history.len() > 1 {
                let mut variance = Tensor::zeros(&gradient.shape())?;
                let mut mean = Tensor::zeros(&gradient.shape())?;

                // First pass: element-wise mean over the history window.
                for grad_vec in &self.gradient_history {
                    if let Some(hist_grad) = grad_vec.get(i) {
                        mean = mean.add(hist_grad)?;
                    }
                }
                mean = mean.div_scalar(self.gradient_history.len() as f32)?;

                // Second pass: element-wise variance around that mean.
                for grad_vec in &self.gradient_history {
                    if let Some(hist_grad) = grad_vec.get(i) {
                        let diff = hist_grad.sub(&mean)?;
                        variance = variance.add(&diff.mul(&diff)?)?;
                    }
                }
                variance = variance.div_scalar(self.gradient_history.len() as f32)?;

                // Damping keeps the diagonal bounded away from zero.
                let damping = Tensor::ones(&gradient.shape())?
                    .mul_scalar(self.config.hessian_preconditioning.damping)?;
                variance = variance.add(&damping)?;

                self.hessian_diagonal.insert(i, variance);
            }
        }
        Ok(())
    }
483
484 fn update_gauss_newton_hessian(&mut self, gradients: &[Tensor]) -> Result<()> {
486 for (i, gradient) in gradients.iter().enumerate() {
488 let outer_product = gradient.mul(gradient)?;
490
491 let damping = Tensor::ones(&gradient.shape())?
493 .mul_scalar(self.config.hessian_preconditioning.damping)?;
494 let hessian_approx = outer_product.add(&damping)?;
495
496 self.hessian_diagonal.insert(i, hessian_approx);
497 }
498 Ok(())
499 }
500
501 fn update_fisher_information_hessian(&mut self, gradients: &[Tensor]) -> Result<()> {
503 for (i, gradient) in gradients.iter().enumerate() {
505 let fisher_approx = gradient.mul(gradient)?;
507
508 let damping = Tensor::ones(&gradient.shape())?
510 .mul_scalar(self.config.hessian_preconditioning.damping)?;
511 let hessian_approx = fisher_approx.add(&damping)?;
512
513 self.hessian_diagonal.insert(i, hessian_approx);
514 }
515 Ok(())
516 }
517
518 fn update_quasi_newton_hessian(&mut self, gradients: &[Tensor]) -> Result<()> {
520 if self.gradient_history.len() > 1 {
522 for (i, gradient) in gradients.iter().enumerate() {
523 if let Some(prev_grad_vec) =
525 self.gradient_history.get(self.gradient_history.len() - 2)
526 {
527 if let Some(prev_grad) = prev_grad_vec.get(i) {
528 let grad_diff = gradient.sub(prev_grad)?;
530
531 let hessian_approx = grad_diff.abs()?;
533
534 let damping = Tensor::ones(&gradient.shape())?
536 .mul_scalar(self.config.hessian_preconditioning.damping)?;
537 let final_hessian = hessian_approx.add(&damping)?;
538
539 self.hessian_diagonal.insert(i, final_hessian);
540 }
541 }
542 }
543 }
544 Ok(())
545 }
546
547 fn apply_diagonal_preconditioning(&mut self, gradients: &mut [Tensor]) -> Result<()> {
549 for (i, gradient) in gradients.iter_mut().enumerate() {
550 if let Some(hessian_diag) = self.hessian_diagonal.get(&i) {
551 let min_val = Tensor::scalar(self.config.hessian_preconditioning.min_eigenvalue)?;
554 let clamped_hessian = hessian_diag.max(&min_val)?;
555
556 *gradient = gradient.div(&clamped_hessian)?;
557 }
558 }
559 Ok(())
560 }
561
    /// Gauss–Newton application: the stored approximation is diagonal, so
    /// it shares the diagonal application path.
    fn apply_gauss_newton_preconditioning(&mut self, gradients: &mut [Tensor]) -> Result<()> {
        self.apply_diagonal_preconditioning(gradients)
    }

    /// Fisher-information application: same shared diagonal path.
    fn apply_fisher_information_preconditioning(&mut self, gradients: &mut [Tensor]) -> Result<()> {
        self.apply_diagonal_preconditioning(gradients)
    }

    /// Quasi-Newton application: same shared diagonal path.
    fn apply_quasi_newton_preconditioning(&mut self, gradients: &mut [Tensor]) -> Result<()> {
        self.apply_diagonal_preconditioning(gradients)
    }
579
    /// Current adaptive clipping threshold (updated each processed step
    /// when adaptive clipping is enabled).
    pub fn get_current_clip_norm(&self) -> f32 {
        self.current_clip_norm
    }

    /// Current (decayed) noise scale used by noise injection.
    pub fn get_current_noise_scale(&self) -> f32 {
        self.current_noise_scale
    }
589
590 pub fn get_gradient_norm_stats(&self) -> Option<(f32, f32, f32)> {
592 if self.gradient_norm_history.is_empty() {
593 return None;
594 }
595
596 let sum: f32 = self.gradient_norm_history.iter().sum();
597 let mean = sum / self.gradient_norm_history.len() as f32;
598
599 let variance = self.gradient_norm_history.iter().map(|x| (x - mean).powi(2)).sum::<f32>()
600 / self.gradient_norm_history.len() as f32;
601 let std_dev = variance.sqrt();
602
603 let max_norm = self.gradient_norm_history.iter().fold(0.0f32, |acc, &x| acc.max(x));
604
605 Some((mean, std_dev, max_norm))
606 }
607
608 pub fn reset(&mut self) {
610 self.current_step = 0;
611 self.gradient_norm_history.clear();
612 self.smoothed_gradients.clear();
613 self.current_clip_norm = self.config.adaptive_clipping.initial_clip_norm;
614 self.current_noise_scale = self.config.noise_injection.initial_noise_scale;
615 self.smoothing_bias_correction = 1.0;
616 self.hessian_diagonal.clear();
617 self.hessian_inverse.clear();
618 self.last_hessian_update = 0;
619 self.gradient_history.clear();
620 }
621
    /// Replaces the configuration and resets all adaptive state.
    ///
    /// Resetting is deliberate: histories, EMA state, and the clip/noise
    /// scales were derived from the previous configuration and would be
    /// inconsistent with the new one.
    pub fn set_config(&mut self, config: GradientProcessingConfig) {
        self.config = config;
        self.reset();
    }

    /// Returns the active configuration.
    pub fn get_config(&self) -> &GradientProcessingConfig {
        &self.config
    }
632}
633
/// Wraps any optimizer, running every gradient through a
/// [`GradientProcessor`] before delegating the actual parameter update to
/// the wrapped optimizer.
pub struct GradientProcessedOptimizer<T> {
    base_optimizer: T,
    gradient_processor: GradientProcessor,
}
639
impl<T> GradientProcessedOptimizer<T> {
    /// Wraps `base_optimizer` with the given gradient-processing config.
    pub fn new(base_optimizer: T, config: GradientProcessingConfig) -> Self {
        Self {
            base_optimizer,
            gradient_processor: GradientProcessor::new(config),
        }
    }

    /// Wraps `base_optimizer` with the default config (all stages disabled).
    pub fn with_default_processing(base_optimizer: T) -> Self {
        Self::new(base_optimizer, GradientProcessingConfig::default())
    }

    /// Shared access to the inner gradient processor.
    pub fn gradient_processor(&self) -> &GradientProcessor {
        &self.gradient_processor
    }

    /// Mutable access to the inner gradient processor (e.g. to reconfigure).
    pub fn gradient_processor_mut(&mut self) -> &mut GradientProcessor {
        &mut self.gradient_processor
    }

    /// Shared access to the wrapped optimizer.
    pub fn base_optimizer(&self) -> &T {
        &self.base_optimizer
    }

    /// Mutable access to the wrapped optimizer.
    pub fn base_optimizer_mut(&mut self) -> &mut T {
        &mut self.base_optimizer
    }
}
674
675impl<T: crate::optimizer::OptimizerState> crate::optimizer::OptimizerState
676 for GradientProcessedOptimizer<T>
677{
678 fn zero_grad(&mut self) -> Result<()> {
679 self.base_optimizer.zero_grad()
680 }
681
682 fn step(&mut self, parameters: &mut [Tensor]) -> Result<()> {
683 let mut gradients = Vec::new();
685 for param in parameters.iter() {
686 if let Ok(grad) = param.grad() {
687 gradients.push(grad);
688 } else {
689 return Err(anyhow!("Parameter missing gradient"));
690 }
691 }
692
693 self.gradient_processor.process_gradients(&mut gradients)?;
695
696 for (param, processed_grad) in parameters.iter_mut().zip(gradients.iter()) {
698 param.set_grad(processed_grad.clone())?;
699 }
700
701 self.base_optimizer.step(parameters)
703 }
704
705 fn get_lr(&self) -> f32 {
706 self.base_optimizer.get_lr()
707 }
708
709 fn set_lr(&mut self, lr: f32) {
710 self.base_optimizer.set_lr(lr);
711 }
712
713 fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
714 self.base_optimizer.state_dict()
717 }
718
719 fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
720 self.base_optimizer.load_state_dict(state)
721 }
722}
723
#[cfg(test)]
mod tests {
    use super::*;

    /// All processing stages must be disabled by default (opt-in design).
    #[test]
    fn test_gradient_processing_config_default() {
        let config = GradientProcessingConfig::default();
        assert!(!config.enable_centralization);
        assert!(!config.enable_standardization);
        assert!(!config.enable_adaptive_clipping);
        assert!(!config.enable_noise_injection);
        assert!(!config.enable_smoothing);
    }

    #[test]
    fn test_adaptive_clipping_config_default() {
        let config = AdaptiveClippingConfig::default();
        assert_eq!(config.initial_clip_norm, 1.0);
        assert_eq!(config.min_clip_norm, 0.1);
        assert_eq!(config.max_clip_norm, 10.0);
        assert_eq!(config.adaptation_rate, 0.01);
        assert_eq!(config.target_percentile, 0.9);
        assert_eq!(config.history_window, 100);
    }

    #[test]
    fn test_gradient_processor_creation() {
        let processor = GradientProcessor::with_defaults();
        assert_eq!(processor.current_step, 0);
        assert_eq!(processor.gradient_norm_history.len(), 0);
    }

    #[test]
    fn test_gradient_norm_stats_empty() {
        let processor = GradientProcessor::with_defaults();
        assert!(processor.get_gradient_norm_stats().is_none());
    }

    #[test]
    fn test_gradient_processor_reset() {
        let mut processor = GradientProcessor::with_defaults();
        processor.current_step = 10;
        processor.gradient_norm_history.push(1.0);

        processor.reset();

        assert_eq!(processor.current_step, 0);
        assert_eq!(processor.gradient_norm_history.len(), 0);
        assert_eq!(processor.hessian_diagonal.len(), 0);
        assert_eq!(processor.gradient_history.len(), 0);
    }

    #[test]
    fn test_hessian_preconditioning_config_default() {
        let config = HessianPreconditioningConfig::default();
        assert!(matches!(
            config.approximation_type,
            HessianApproximationType::Diagonal
        ));
        assert_eq!(config.damping, 1e-4);
        assert_eq!(config.update_frequency, 10);
        assert_eq!(config.history_window, 20);
        assert_eq!(config.min_eigenvalue, 1e-8);
        assert_eq!(config.max_condition_number, 1e6);
    }

    #[test]
    fn test_hessian_preconditioning_enabled() {
        // Struct-update syntax instead of mutate-after-Default
        // (clippy::field_reassign_with_default).
        let config = GradientProcessingConfig {
            enable_hessian_preconditioning: true,
            ..Default::default()
        };

        let processor = GradientProcessor::new(config);
        assert!(processor.config.enable_hessian_preconditioning);
    }

    /// The processor must preserve whichever approximation type is
    /// configured, for every variant.
    #[test]
    fn test_hessian_approximation_types() {
        let mut config = GradientProcessingConfig {
            enable_hessian_preconditioning: true,
            ..Default::default()
        };

        config.hessian_preconditioning.approximation_type = HessianApproximationType::Diagonal;
        let processor = GradientProcessor::new(config.clone());
        assert!(matches!(
            processor.config.hessian_preconditioning.approximation_type,
            HessianApproximationType::Diagonal
        ));

        config.hessian_preconditioning.approximation_type = HessianApproximationType::GaussNewton;
        let processor = GradientProcessor::new(config.clone());
        assert!(matches!(
            processor.config.hessian_preconditioning.approximation_type,
            HessianApproximationType::GaussNewton
        ));

        config.hessian_preconditioning.approximation_type =
            HessianApproximationType::FisherInformation;
        let processor = GradientProcessor::new(config.clone());
        assert!(matches!(
            processor.config.hessian_preconditioning.approximation_type,
            HessianApproximationType::FisherInformation
        ));

        config.hessian_preconditioning.approximation_type = HessianApproximationType::QuasiNewton;
        let processor = GradientProcessor::new(config);
        assert!(matches!(
            processor.config.hessian_preconditioning.approximation_type,
            HessianApproximationType::QuasiNewton
        ));
    }
}