use ndarray::{Array, Dimension, ScalarOperand};
use num_traits::Float;
use std::fmt::Debug;

use crate::error::{OptimError, Result};

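/// Configuration controlling how gradients are clipped and post-processed.
///
/// Each field is optional; only the transformations that are configured are
/// applied when the configuration is used by a [`GradientProcessor`].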
#[derive(Debug, Clone)]
pub struct GradientClipConfig<A: Float> {
    /// Element-wise upper bound for gradient values
    pub max_value: Option<A>,
    /// Element-wise lower bound for gradient values
    pub min_value: Option<A>,
    /// Maximum allowed L2 norm of the gradient
    pub max_norm: Option<A>,
    /// Maximum allowed L1 norm of the gradient
    pub max_l1_norm: Option<A>,
    /// Whether to apply gradient centralization (subtract the mean)
    pub centralization: bool,
    /// Threshold below which gradient elements are zeroed out
    pub zero_threshold: Option<A>,
}

impl<A: Float> Default for GradientClipConfig<A> {
    fn default() -> Self {
        Self {
            max_value: None,
            min_value: None,
            max_norm: None,
            max_l1_norm: None,
            centralization: false,
            zero_threshold: None,
        }
    }
}

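/// Applies a configurable pipeline of gradient transformations (value clipping,
/// norm clipping, centralization, and zeroing of small values) to gradient arrays.
///
/// Illustrative usage (a minimal sketch, assuming this module is in scope):
///
/// ```ignore
/// let mut processor = GradientProcessor::new();
/// processor.set_norm_clip(1.0);
/// let mut grads = ndarray::Array1::from_vec(vec![3.0_f64, 4.0]);
/// processor.process(&mut grads).unwrap();
/// ```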
pub struct GradientProcessor<A: Float> {
    /// Configuration describing which gradient transformations to apply
    config: GradientClipConfig<A>,
}

impl<A: Float + ScalarOperand + Debug> Default for GradientProcessor<A> {
    fn default() -> Self {
        Self {
            config: GradientClipConfig::default(),
        }
    }
}

impl<A: Float + ScalarOperand + Debug> GradientProcessor<A> {
    /// Creates a new processor with the default (no-op) configuration
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a processor from an existing configuration
    pub fn with_config(config: GradientClipConfig<A>) -> Self {
        Self { config }
    }

    /// Sets the element-wise maximum gradient value
    pub fn set_max_value(&mut self, value: A) -> &mut Self {
        self.config.max_value = Some(value);
        self
    }

    /// Sets the element-wise minimum gradient value
    pub fn set_min_value(&mut self, value: A) -> &mut Self {
        self.config.min_value = Some(value);
        self
    }

    /// Sets the maximum L2 norm for norm clipping
    pub fn set_max_norm(&mut self, value: A) -> &mut Self {
        self.config.max_norm = Some(value);
        self
    }

    /// Sets the maximum L1 norm for L1 norm clipping
    pub fn set_max_l1_norm(&mut self, value: A) -> &mut Self {
        self.config.max_l1_norm = Some(value);
        self
    }

    /// Enables or disables gradient centralization
    pub fn set_centralization(&mut self, enabled: bool) -> &mut Self {
        self.config.centralization = enabled;
        self
    }

    /// Sets the threshold below which gradient elements are zeroed
    pub fn set_zero_threshold(&mut self, value: A) -> &mut Self {
        self.config.zero_threshold = Some(value);
        self
    }

    /// Sets both the lower and upper bounds for value clipping
    pub fn set_value_clip(&mut self, min: A, max: A) -> &mut Self {
        self.config.min_value = Some(min);
        self.config.max_value = Some(max);
        self
    }

    /// Sets the maximum L2 norm for norm clipping
    pub fn set_norm_clip(&mut self, max_norm: A) -> &mut Self {
        self.config.max_norm = Some(max_norm);
        self
    }

    /// Sets the maximum L1 norm for L1 norm clipping
    pub fn set_l1_norm_clip(&mut self, max_l1_norm: A) -> &mut Self {
        self.config.max_l1_norm = Some(max_l1_norm);
        self
    }

    /// Enables gradient centralization
    pub fn enable_centralization(&mut self) -> &mut Self {
        self.config.centralization = true;
        self
    }

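    /// Applies every configured transformation to `gradients` in place, in the
    /// order: value clipping, L2 norm clipping, L1 norm clipping,
    /// centralization, and zeroing of small values.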
    pub fn process<D: Dimension>(&self, gradients: &mut Array<A, D>) -> Result<()> {
        // Apply value clipping when either bound is set; a missing bound is unbounded.
        if self.config.min_value.is_some() || self.config.max_value.is_some() {
            let min = self.config.min_value.unwrap_or_else(A::neg_infinity);
            let max = self.config.max_value.unwrap_or_else(A::infinity);
            clip_gradients_by_value(gradients, min, max);
        }

        if let Some(max_norm) = self.config.max_norm {
            clip_gradient_norm(gradients, max_norm)?;
        }

        if let Some(max_l1_norm) = self.config.max_l1_norm {
            clip_gradient_l1_norm(gradients, max_l1_norm)?;
        }

        if self.config.centralization {
            gradient_centralization(gradients);
        }

        if let Some(threshold) = self.config.zero_threshold {
            zero_small_gradients(gradients, threshold);
        }

        Ok(())
    }
}

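/// Clips each gradient element to the closed range `[min_value, max_value]`,
/// modifying the array in place and returning it for chaining.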
pub fn clip_gradients_by_value<A, D>(
    gradients: &mut Array<A, D>,
    min_value: A,
    max_value: A,
) -> &mut Array<A, D>
where
    A: Float + ScalarOperand,
    D: Dimension,
{
    gradients.mapv_inplace(|x| {
        if x < min_value {
            min_value
        } else if x > max_value {
            max_value
        } else {
            x
        }
    });
    gradients
}

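/// Scales the gradient so that its L2 norm does not exceed `max_norm`.
///
/// Returns an error if `max_norm` is not strictly positive; gradients whose
/// norm is already within the bound are left unchanged.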
pub fn clip_gradient_norm<A, D>(
    gradients: &mut Array<A, D>,
    max_norm: A,
) -> Result<&mut Array<A, D>>
where
    A: Float + ScalarOperand,
    D: Dimension,
{
    if max_norm <= A::zero() {
        return Err(OptimError::InvalidConfig(
            "max_norm must be positive".to_string(),
        ));
    }

    let norm = gradients
        .iter()
        .fold(A::zero(), |acc, &x| acc + x * x)
        .sqrt();

    if norm > max_norm {
        let scale = max_norm / norm;
        gradients.mapv_inplace(|x| x * scale);
    }

    Ok(gradients)
}

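/// Scales the gradient so that its L1 norm (sum of absolute values) does not
/// exceed `max_l1_norm`.
///
/// Returns an error if `max_l1_norm` is not strictly positive.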
pub fn clip_gradient_l1_norm<A, D>(
    gradients: &mut Array<A, D>,
    max_l1_norm: A,
) -> Result<&mut Array<A, D>>
where
    A: Float + ScalarOperand,
    D: Dimension,
{
    if max_l1_norm <= A::zero() {
        return Err(OptimError::InvalidConfig(
            "max_l1_norm must be positive".to_string(),
        ));
    }

    let l1_norm = gradients.iter().fold(A::zero(), |acc, &x| acc + x.abs());

    if l1_norm > max_l1_norm {
        let scale = max_l1_norm / l1_norm;
        gradients.mapv_inplace(|x| x * scale);
    }

    Ok(gradients)
}

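/// Applies gradient centralization: subtracts the mean of all elements from
/// each element so that the gradient has zero mean.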
pub fn gradient_centralization<A, D>(gradients: &mut Array<A, D>) -> &mut Array<A, D>
where
    A: Float + ScalarOperand,
    D: Dimension,
{
    // Guard against division by zero on empty arrays
    if gradients.is_empty() {
        return gradients;
    }

    let sum = gradients.iter().fold(A::zero(), |acc, &x| acc + x);
    let mean = sum / A::from(gradients.len()).unwrap_or(A::one());

    gradients.mapv_inplace(|x| x - mean);

    gradients
}

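/// Sets every gradient element whose absolute value is below `threshold`
/// (taken as an absolute value) to zero, which can be used to sparsify updates.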
pub fn zero_small_gradients<A, D>(gradients: &mut Array<A, D>, threshold: A) -> &mut Array<A, D>
where
    A: Float + ScalarOperand,
    D: Dimension,
{
    let abs_threshold = threshold.abs();

    gradients.mapv_inplace(|x| {
        if x.abs() < abs_threshold {
            A::zero()
        } else {
            x
        }
    });

    gradients
}

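/// Accumulates gradients over several steps so that an optimizer update can be
/// applied once every `accumulation_steps` micro-batches, either summing or
/// averaging the accumulated gradients.
///
/// Illustrative usage (a minimal sketch):
///
/// ```ignore
/// let mut acc = GradientAccumulator::new(2, true);
/// let grads = ndarray::Array1::from_vec(vec![1.0_f64, 2.0]);
/// acc.accumulate(&grads);
/// if acc.accumulate(&grads) {
///     let averaged = acc.get_and_reset().unwrap();
///     // apply `averaged` with the optimizer
/// }
/// ```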
#[derive(Debug, Clone)]
pub struct GradientAccumulator<A: Float, D: Dimension> {
    /// Running sum of accumulated gradients, if any have been added
    accumulated_gradients: Option<Array<A, D>>,
    /// Number of gradient arrays accumulated since the last reset
    num_accumulated: usize,
    /// Number of accumulation steps before the accumulator is ready
    accumulation_steps: usize,
    /// Whether to average (rather than sum) the accumulated gradients
    average_gradients: bool,
}

impl<A: Float + ScalarOperand + Debug, D: Dimension> GradientAccumulator<A, D> {
    /// Creates a new accumulator that becomes ready after `accumulation_steps`
    /// calls to [`accumulate`](Self::accumulate)
    pub fn new(accumulation_steps: usize, average_gradients: bool) -> Self {
        Self {
            accumulated_gradients: None,
            num_accumulated: 0,
            accumulation_steps,
            average_gradients,
        }
    }

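    /// Adds `gradients` (expected to have the same shape on every call) to the
    /// running accumulation and returns `true` once enough steps have been
    /// accumulated for an optimizer update.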
    pub fn accumulate(&mut self, gradients: &Array<A, D>) -> bool {
        if self.accumulated_gradients.is_none() {
            self.accumulated_gradients = Some(gradients.clone());
        } else {
            let acc = self.accumulated_gradients.as_mut().unwrap();
            for (acc_val, &grad_val) in acc.iter_mut().zip(gradients.iter()) {
                *acc_val = *acc_val + grad_val;
            }
        }

        self.num_accumulated += 1;
        self.num_accumulated >= self.accumulation_steps
    }

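    /// Returns the accumulated gradients (averaged if configured) and resets
    /// the accumulator, or `None` if nothing has been accumulated.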
    pub fn get_and_reset(&mut self) -> Option<Array<A, D>> {
        if let Some(mut gradients) = self.accumulated_gradients.take() {
            if self.average_gradients && self.num_accumulated > 0 {
                let scale = A::one() / A::from(self.num_accumulated).unwrap_or(A::one());
                gradients.mapv_inplace(|x| x * scale);
            }
            self.num_accumulated = 0;
            Some(gradients)
        } else {
            None
        }
    }

    /// Returns `(steps accumulated so far, total accumulation steps)`
    pub fn progress(&self) -> (usize, usize) {
        (self.num_accumulated, self.accumulation_steps)
    }

    /// Returns `true` if enough steps have been accumulated for an update
    pub fn is_ready(&self) -> bool {
        self.num_accumulated >= self.accumulation_steps
    }

    /// Discards any accumulated gradients and resets the step counter
    pub fn reset(&mut self) {
        self.accumulated_gradients = None;
        self.num_accumulated = 0;
    }

    /// Changes the number of accumulation steps required before readiness
    pub fn set_accumulation_steps(&mut self, steps: usize) {
        self.accumulation_steps = steps;
    }
}

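/// Adaptive gradient clipping: rescales the gradient whenever the ratio of its
/// L2 norm to the parameters' L2 norm exceeds `max_ratio`, so that the update
/// magnitude stays proportional to the parameter magnitude.
///
/// Returns an error if `max_ratio` is not strictly positive. If either norm is
/// zero, the gradient is left unchanged.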
pub fn adaptive_gradient_clipping<'a, A, D>(
    gradients: &'a mut Array<A, D>,
    parameters: &Array<A, D>,
    max_ratio: A,
) -> Result<&'a mut Array<A, D>>
where
    A: Float + ScalarOperand,
    D: Dimension,
{
    if max_ratio <= A::zero() {
        return Err(OptimError::InvalidConfig(
            "max_ratio must be positive".to_string(),
        ));
    }

    let grad_norm = gradients
        .iter()
        .fold(A::zero(), |acc, &x| acc + x * x)
        .sqrt();

    let param_norm = parameters
        .iter()
        .fold(A::zero(), |acc, &x| acc + x * x)
        .sqrt();

    if param_norm > A::zero() && grad_norm > A::zero() {
        let ratio = grad_norm / param_norm;
        if ratio > max_ratio {
            let scale = max_ratio / ratio;
            gradients.mapv_inplace(|x| x * scale);
        }
    }

    Ok(gradients)
}

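/// Adds Gaussian noise with standard deviation `noise_std` to every gradient
/// element, optionally using a fixed `seed` for reproducibility.
///
/// If `noise_std` is not strictly positive the gradients are returned unchanged.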
pub fn add_gradient_noise<A, D>(
    gradients: &mut Array<A, D>,
    noise_std: A,
    seed: Option<u64>,
) -> &mut Array<A, D>
where
    A: Float + ScalarOperand,
    D: Dimension,
{
    use ndarray_rand::rand::SeedableRng;
    use ndarray_rand::rand_distr::Normal;
    use ndarray_rand::RandomExt;

    if noise_std <= A::zero() {
        return gradients;
    }

    let mut rng = if let Some(s) = seed {
        ndarray_rand::rand::rngs::StdRng::seed_from_u64(s)
    } else {
        ndarray_rand::rand::rngs::StdRng::from_entropy()
    };

    let normal = Normal::new(0.0, noise_std.to_f64().unwrap_or(0.01)).unwrap();
    let noise = Array::random_using(gradients.raw_dim(), normal, &mut rng);

    gradients.zip_mut_with(&noise, |g, &n| {
        *g = *g + A::from(n).unwrap_or(A::zero());
    });

    gradients
}

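/// Element-wise mask over gradients: entries with a `false` mask value are
/// frozen (set to zero), and optional per-element learning-rate multipliers
/// can scale the remaining entries.
///
/// Illustrative usage (a minimal sketch):
///
/// ```ignore
/// let mask = ndarray::Array1::from_vec(vec![true, false, true]);
/// let grad_mask = GradientMask::<f64, ndarray::Ix1>::new(mask);
/// let mut gradients = ndarray::Array1::from_vec(vec![1.0, 2.0, 3.0]);
/// grad_mask.apply_mask(&mut gradients); // second element is zeroed
/// ```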
#[derive(Debug, Clone)]
pub struct GradientMask<A: Float, D: Dimension> {
    /// Element-wise mask: `true` means the element is updated, `false` means frozen
    mask: Array<bool, D>,
    /// Optional per-element learning-rate multipliers applied after masking
    lr_multipliers: Option<Array<A, D>>,
}

impl<A: Float + ScalarOperand + Debug, D: Dimension> GradientMask<A, D> {
    /// Creates a mask from an array of booleans (`true` = update, `false` = freeze)
    pub fn new(mask: Array<bool, D>) -> Self {
        Self {
            mask,
            lr_multipliers: None,
        }
    }

    /// Creates a mask that freezes every element of the given shape
    pub fn freeze_all(shape: D) -> Self {
        Self {
            mask: Array::from_elem(shape, false),
            lr_multipliers: None,
        }
    }

    /// Creates a mask that updates every element of the given shape
    pub fn update_all(shape: D) -> Self {
        Self {
            mask: Array::from_elem(shape, true),
            lr_multipliers: None,
        }
    }

    /// Attaches per-element learning-rate multipliers to the mask
    pub fn with_lr_multipliers(mut self, multipliers: Array<A, D>) -> Self {
        self.lr_multipliers = Some(multipliers);
        self
    }

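    /// Zeroes out frozen gradient elements and, if configured, scales the
    /// remaining elements by their learning-rate multipliers, in place.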
    pub fn apply_mask<'a>(&self, gradients: &'a mut Array<A, D>) -> &'a mut Array<A, D> {
        gradients.zip_mut_with(&self.mask, |grad, &should_update| {
            if !should_update {
                *grad = A::zero();
            }
        });

        if let Some(multipliers) = &self.lr_multipliers {
            gradients.zip_mut_with(multipliers, |grad, &mult| {
                *grad = *grad * mult;
            });
        }

        gradients
    }

    /// Freezes the elements at the given flat indices (sets their mask to `false`)
    pub fn freeze_indices(&mut self, indices: &[usize]) -> Result<()> {
        let flat_mask = self.mask.as_slice_mut().ok_or_else(|| {
            OptimError::InvalidConfig("Cannot access mask as flat slice".to_string())
        })?;

        for &idx in indices {
            if idx < flat_mask.len() {
                flat_mask[idx] = false;
            } else {
                return Err(OptimError::InvalidConfig(format!(
                    "Index {} out of bounds for mask of size {}",
                    idx,
                    flat_mask.len()
                )));
            }
        }
        Ok(())
    }

    /// Unfreezes the elements at the given flat indices (sets their mask to `true`)
    pub fn unfreeze_indices(&mut self, indices: &[usize]) -> Result<()> {
        let flat_mask = self.mask.as_slice_mut().ok_or_else(|| {
            OptimError::InvalidConfig("Cannot access mask as flat slice".to_string())
        })?;

        for &idx in indices {
            if idx < flat_mask.len() {
                flat_mask[idx] = true;
            } else {
                return Err(OptimError::InvalidConfig(format!(
                    "Index {} out of bounds for mask of size {}",
                    idx,
                    flat_mask.len()
                )));
            }
        }
        Ok(())
    }

    /// Returns the number of frozen (masked-out) elements
    pub fn num_frozen(&self) -> usize {
        self.mask.iter().filter(|&&x| !x).count()
    }

    /// Returns the number of active (updatable) elements
    pub fn num_active(&self) -> usize {
        self.mask.iter().filter(|&&x| x).count()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::Array1;

    #[test]
    fn test_gradient_processor() {
        let config = GradientClipConfig::<f64> {
            max_value: Some(5.0),
            min_value: Some(-5.0),
            max_norm: Some(10.0),
            ..Default::default()
        };

        let processor = GradientProcessor::with_config(config);

        let mut gradients = Array1::from_vec(vec![-8.0, 3.0, 7.0, -2.0, 6.0]);
        processor.process(&mut gradients).unwrap();

        assert_eq!(gradients[0], -5.0);
        assert_eq!(gradients[2], 5.0);
        assert_eq!(gradients[4], 5.0);
    }

    #[test]
    fn test_adaptive_clipping() {
        let mut gradients = Array1::from_vec(vec![3.0, 4.0]);
        let parameters = Array1::from_vec(vec![1.0, 0.0]);
        adaptive_gradient_clipping(&mut gradients, &parameters, 2.0).unwrap();

        // Gradient norm (5.0) exceeds max_ratio * param_norm (2.0), so it is rescaled to 2.0
        let new_grad_norm = gradients.iter().fold(0.0, |acc, &x| acc + x * x).sqrt();
        assert!((new_grad_norm - 2.0).abs() < 1e-6);
    }

    #[test]
    fn test_gradient_accumulator() {
        let mut accumulator = GradientAccumulator::new(3, true);

        let grad1 = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        assert!(!accumulator.accumulate(&grad1));
        assert_eq!(accumulator.progress(), (1, 3));

        let grad2 = Array1::from_vec(vec![2.0, 3.0, 4.0]);
        assert!(!accumulator.accumulate(&grad2));
        assert_eq!(accumulator.progress(), (2, 3));

        let grad3 = Array1::from_vec(vec![3.0, 4.0, 5.0]);
        assert!(accumulator.accumulate(&grad3));
        assert!(accumulator.is_ready());

        // Averaged gradients: ([1,2,3] + [2,3,4] + [3,4,5]) / 3 = [2,3,4]
        let final_grads = accumulator.get_and_reset().unwrap();
        assert_relative_eq!(final_grads[0], 2.0, epsilon = 1e-6);
        assert_relative_eq!(final_grads[1], 3.0, epsilon = 1e-6);
        assert_relative_eq!(final_grads[2], 4.0, epsilon = 1e-6);
        assert_eq!(accumulator.progress(), (0, 3));
        assert!(!accumulator.is_ready());
    }

    #[test]
    fn test_gradient_accumulator_sum_mode() {
        let mut accumulator = GradientAccumulator::new(2, false);
        let grad1 = Array1::from_vec(vec![1.0, 2.0]);
        let grad2 = Array1::from_vec(vec![3.0, 4.0]);

        accumulator.accumulate(&grad1);
        accumulator.accumulate(&grad2);

        // With averaging disabled the gradients are summed: [1+3, 2+4] = [4, 6]
        let final_grads = accumulator.get_and_reset().unwrap();
        assert_relative_eq!(final_grads[0], 4.0, epsilon = 1e-6);
        assert_relative_eq!(final_grads[1], 6.0, epsilon = 1e-6);
    }

    #[test]
    fn test_gradient_noise() {
        let mut gradients = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let original = gradients.clone();

        add_gradient_noise(&mut gradients, 0.1, Some(42));

        for (i, (&orig, &noisy)) in original.iter().zip(gradients.iter()).enumerate() {
            assert!(
                (orig - noisy).abs() < 1.0,
                "Index {}: {} vs {}",
                i,
                orig,
                noisy
            );
        }
    }

    #[test]
    fn test_gradient_noise_zero_std() {
        let mut gradients = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let original = gradients.clone();

        add_gradient_noise(&mut gradients, 0.0, Some(42));

        for (orig, noisy) in original.iter().zip(gradients.iter()) {
            assert_relative_eq!(*orig, *noisy, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_gradient_mask_creation() {
        let mask = Array1::from_vec(vec![true, false, true]);
        let grad_mask: GradientMask<f64, ndarray::Ix1> = GradientMask::new(mask);

        assert_eq!(grad_mask.num_active(), 2);
        assert_eq!(grad_mask.num_frozen(), 1);
    }

    #[test]
    fn test_gradient_mask_apply() {
        let mask = Array1::from_vec(vec![true, false, true]);
        let grad_mask: GradientMask<f64, ndarray::Ix1> = GradientMask::new(mask);
        let mut gradients = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        grad_mask.apply_mask(&mut gradients);

        assert_eq!(gradients.as_slice().unwrap(), &[1.0, 0.0, 3.0]);
    }

    #[test]
    fn test_gradient_mask_freeze_unfreeze() {
        let mask = Array1::from_vec(vec![true, true, true]);
        let mut grad_mask: GradientMask<f64, ndarray::Ix1> = GradientMask::new(mask);

        grad_mask.freeze_indices(&[0, 2]).unwrap();
        assert_eq!(grad_mask.num_frozen(), 2);
        assert_eq!(grad_mask.num_active(), 1);

        grad_mask.unfreeze_indices(&[0]).unwrap();
        assert_eq!(grad_mask.num_frozen(), 1);
        assert_eq!(grad_mask.num_active(), 2);
    }

    #[test]
    fn test_gradient_mask_with_lr_multipliers() {
        let mask = Array1::from_vec(vec![true, true, true]);
        let multipliers = Array1::from_vec(vec![1.0, 0.5, 2.0]);
        let grad_mask: GradientMask<f64, ndarray::Ix1> =
            GradientMask::new(mask).with_lr_multipliers(multipliers);
        let mut gradients = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        grad_mask.apply_mask(&mut gradients);

        assert_relative_eq!(gradients[0], 1.0, epsilon = 1e-6);
        assert_relative_eq!(gradients[1], 1.0, epsilon = 1e-6);
        assert_relative_eq!(gradients[2], 6.0, epsilon = 1e-6);
    }

    #[test]
    fn test_gradient_mask_freeze_all() {
        let grad_mask = GradientMask::<f64, ndarray::Ix1>::freeze_all(ndarray::Ix1(3));
        assert_eq!(grad_mask.num_frozen(), 3);
        assert_eq!(grad_mask.num_active(), 0);
    }

    #[test]
    fn test_gradient_mask_update_all() {
        let grad_mask = GradientMask::<f64, ndarray::Ix1>::update_all(ndarray::Ix1(3));
        assert_eq!(grad_mask.num_frozen(), 0);
        assert_eq!(grad_mask.num_active(), 3);
    }
}