//! Gradient centralization: an optimizer wrapper that subtracts the mean from
//! each parameter gradient before handing the update to an inner optimizer.
//! Centering gradients this way is a simple technique reported to stabilize
//! and speed up training (Yong et al., 2020) without touching the parameters.

use crate::{Optimizer, TrainResult};
use scirs2_core::ndarray::Array2;
use std::collections::HashMap;

/// Strategy used to compute the mean that is subtracted from a gradient.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum GcStrategy {
    /// Subtract the mean of the entire gradient tensor, layer by layer (default).
    #[default]
    LayerWise,

    /// Subtract a single mean computed across all parameter gradients.
    Global,

    /// Subtract each row's mean from the elements of that row.
    PerRow,

    /// Subtract each column's mean from the elements of that column.
    PerColumn,
}

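/// Configuration for gradient centralization.
///
/// A minimal usage sketch, mirroring the builder calls exercised in the tests
/// at the bottom of this file (no APIs beyond this module are assumed):
///
/// ```ignore
/// let config = GcConfig::new(GcStrategy::PerRow)
///     .with_min_dims(1) // centralize tensors with at least one dimension
///     .with_eps(1e-10); // stability constant stored on the config
/// assert!(config.enabled);
/// ```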
#[derive(Debug, Clone)]
pub struct GcConfig {
    /// Which centralization strategy to apply.
    pub strategy: GcStrategy,

    /// Whether centralization is applied at all; when `false`, gradients pass
    /// through unchanged.
    pub enabled: bool,

    /// Minimum number of tensor dimensions a gradient must have to be
    /// centralized; smaller tensors are skipped.
    pub min_dims: usize,

    /// Small constant reserved for numerical stability.
    pub eps: f64,
}

impl Default for GcConfig {
    fn default() -> Self {
        Self {
            strategy: GcStrategy::LayerWise,
            enabled: true,
            min_dims: 2,
            eps: 1e-8,
        }
    }
}

impl GcConfig {
    /// Creates a config with the given strategy and default remaining fields.
    pub fn new(strategy: GcStrategy) -> Self {
        Self {
            strategy,
            ..Default::default()
        }
    }

    /// Turns centralization on.
    pub fn enable(&mut self) {
        self.enabled = true;
    }

    /// Turns centralization off.
    pub fn disable(&mut self) {
        self.enabled = false;
    }

    /// Sets the minimum number of dimensions required for centralization.
    pub fn with_min_dims(mut self, min_dims: usize) -> Self {
        self.min_dims = min_dims;
        self
    }

    /// Sets the numerical-stability constant.
    pub fn with_eps(mut self, eps: f64) -> Self {
        self.eps = eps;
        self
    }
}

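/// Optimizer wrapper that centralizes gradients before delegating the update
/// to an inner optimizer.
///
/// A minimal sketch of the intended wiring, assuming the `AdamOptimizer` and
/// `OptimizerConfig` types used by the tests in this module (the inner
/// optimizer and the sample values are illustrative, not required):
///
/// ```ignore
/// let adam = AdamOptimizer::new(OptimizerConfig {
///     learning_rate: 0.001,
///     ..Default::default()
/// });
/// let mut gc = GradientCentralization::with_default(Box::new(adam));
///
/// let mut params = HashMap::new();
/// params.insert("w1".to_string(), Array2::ones((3, 3)));
/// let mut grads = HashMap::new();
/// grads.insert(
///     "w1".to_string(),
///     Array2::from_shape_fn((3, 3), |(i, j)| 0.1 + (i + j) as f64 * 0.05),
/// );
///
/// // Gradients are centralized, then the inner optimizer updates `params`.
/// assert!(gc.step(&mut params, &grads).is_ok());
/// ```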
pub struct GradientCentralization {
    /// The wrapped optimizer that performs the actual parameter update.
    inner_optimizer: Box<dyn Optimizer>,

    /// Centralization settings.
    config: GcConfig,

    /// Running statistics about the centralization passes.
    stats: GcStats,
}

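/// Statistics collected while centralizing gradients. Counters accumulate
/// across passes until `reset_stats` is called; the average norms are
/// recomputed on every pass.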
#[derive(Debug, Clone, Default)]
pub struct GcStats {
    /// Number of gradients that have been centralized.
    pub num_centralized: usize,

    /// Number of gradients skipped because they had fewer than `min_dims` dimensions.
    pub num_skipped: usize,

    /// Average gradient L2 norm before centralization (most recent pass).
    pub avg_grad_norm_before: f64,

    /// Average gradient L2 norm after centralization (most recent pass).
    pub avg_grad_norm_after: f64,

    /// Number of centralization passes performed.
    pub total_operations: usize,
}

impl GradientCentralization {
    /// Wraps `inner_optimizer` with the given centralization config.
    pub fn new(inner_optimizer: Box<dyn Optimizer>, config: GcConfig) -> Self {
        Self {
            inner_optimizer,
            config,
            stats: GcStats::default(),
        }
    }

    /// Wraps `inner_optimizer` with the default configuration.
    pub fn with_default(inner_optimizer: Box<dyn Optimizer>) -> Self {
        Self::new(inner_optimizer, GcConfig::default())
    }

    /// Returns the current configuration.
    pub fn config(&self) -> &GcConfig {
        &self.config
    }

    /// Returns a mutable reference to the configuration.
    pub fn config_mut(&mut self) -> &mut GcConfig {
        &mut self.config
    }

    /// Returns the accumulated centralization statistics.
    pub fn stats(&self) -> &GcStats {
        &self.stats
    }

    /// Clears the accumulated statistics.
    pub fn reset_stats(&mut self) {
        self.stats = GcStats::default();
    }

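    /// Applies the configured strategy to every gradient and records norm
    /// statistics. Gradients with fewer than `min_dims` dimensions pass
    /// through unchanged. As a worked example, under `GcStrategy::LayerWise`
    /// the gradient `[[1.0, 2.0], [3.0, 4.0]]` has mean `2.5` and becomes
    /// `[[-1.5, -0.5], [0.5, 1.5]]`.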
    fn centralize_gradients(
        &mut self,
        grads: &HashMap<String, Array2<f64>>,
    ) -> HashMap<String, Array2<f64>> {
        if !self.config.enabled {
            return grads.clone();
        }

        let mut centralized_grads = HashMap::new();
        let mut total_norm_before = 0.0;
        let mut total_norm_after = 0.0;

        for (name, grad) in grads {
            let shape = grad.shape();

            // Skip tensors that are too small to centralize.
            if shape.len() < self.config.min_dims {
                centralized_grads.insert(name.clone(), grad.clone());
                self.stats.num_skipped += 1;
                continue;
            }

            let norm_before = grad.iter().map(|&x| x * x).sum::<f64>().sqrt();
            total_norm_before += norm_before;

            let centered_grad = match self.config.strategy {
                GcStrategy::LayerWise => self.centralize_layerwise(grad),
                // The global pass runs over all tensors after this loop.
                GcStrategy::Global => grad.clone(),
                GcStrategy::PerRow => self.centralize_per_row(grad),
                GcStrategy::PerColumn => self.centralize_per_column(grad),
            };

            let norm_after = centered_grad.iter().map(|&x| x * x).sum::<f64>().sqrt();
            total_norm_after += norm_after;

            centralized_grads.insert(name.clone(), centered_grad);
            self.stats.num_centralized += 1;
        }

        if self.config.strategy == GcStrategy::Global && !centralized_grads.is_empty() {
            centralized_grads = self.centralize_global(&centralized_grads);
        }

        // Average the norms over the gradients processed in this call rather
        // than over the cumulative counters, so repeated calls do not dilute
        // the averages.
        let n = grads.len().max(1) as f64;
        self.stats.avg_grad_norm_before = total_norm_before / n;
        self.stats.avg_grad_norm_after = total_norm_after / n;
        self.stats.total_operations += 1;

        centralized_grads
    }

    /// Subtracts the tensor-wide mean from every element.
    fn centralize_layerwise(&self, grad: &Array2<f64>) -> Array2<f64> {
        let mean = grad.mean().unwrap_or(0.0);
        grad - mean
    }

    /// Subtracts each row's mean from the elements of that row.
    fn centralize_per_row(&self, grad: &Array2<f64>) -> Array2<f64> {
        let mut centered = grad.clone();

        for i in 0..grad.nrows() {
            let row_mean = grad.row(i).mean().unwrap_or(0.0);
            for j in 0..grad.ncols() {
                centered[[i, j]] -= row_mean;
            }
        }

        centered
    }

    /// Subtracts each column's mean from the elements of that column.
    fn centralize_per_column(&self, grad: &Array2<f64>) -> Array2<f64> {
        let mut centered = grad.clone();

        for j in 0..grad.ncols() {
            let col_mean = grad.column(j).mean().unwrap_or(0.0);
            for i in 0..grad.nrows() {
                centered[[i, j]] -= col_mean;
            }
        }

        centered
    }

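    /// Subtracts a single global mean from every tensor. The mean is taken
    /// over all elements of all gradients:
    /// `global_mean = (sum of all entries) / (total number of entries)`,
    /// falling back to `0.0` when the map is empty.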
    fn centralize_global(
        &self,
        grads: &HashMap<String, Array2<f64>>,
    ) -> HashMap<String, Array2<f64>> {
        let mut total_sum = 0.0;
        let mut total_count = 0;

        for grad in grads.values() {
            total_sum += grad.sum();
            total_count += grad.len();
        }

        let global_mean = if total_count > 0 {
            total_sum / total_count as f64
        } else {
            0.0
        };

        let mut centralized = HashMap::new();
        for (name, grad) in grads {
            centralized.insert(name.clone(), grad - global_mean);
        }

        centralized
    }
}

impl Optimizer for GradientCentralization {
    fn step(
        &mut self,
        params: &mut HashMap<String, Array2<f64>>,
        grads: &HashMap<String, Array2<f64>>,
    ) -> TrainResult<()> {
        // Centralize first, then delegate the parameter update to the inner optimizer.
        let centralized_grads = self.centralize_gradients(grads);

        self.inner_optimizer.step(params, &centralized_grads)
    }

    fn zero_grad(&mut self) {
        self.inner_optimizer.zero_grad();
    }

    fn get_lr(&self) -> f64 {
        self.inner_optimizer.get_lr()
    }

    fn set_lr(&mut self, lr: f64) {
        self.inner_optimizer.set_lr(lr);
    }

    fn state_dict(&self) -> HashMap<String, Vec<f64>> {
        let mut state = self.inner_optimizer.state_dict();

        // Persist the enabled flag alongside the inner optimizer's state.
        let gc_state = if self.config.enabled {
            vec![1.0]
        } else {
            vec![0.0]
        };
        state.insert("gc_enabled".to_string(), gc_state);

        state
    }

    fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) {
        // Restore the enabled flag, then hand the full state to the inner optimizer.
        if let Some(gc_state) = state.get("gc_enabled") {
            self.config.enabled = !gc_state.is_empty() && gc_state[0] > 0.5;
        }

        self.inner_optimizer.load_state_dict(state);
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{AdamOptimizer, OptimizerConfig};
    use scirs2_core::ndarray::Array2;

    #[test]
    fn test_gc_config_default() {
        let config = GcConfig::default();
        assert!(config.enabled);
        assert_eq!(config.min_dims, 2);
        assert_eq!(config.strategy, GcStrategy::LayerWise);
    }

    #[test]
    fn test_gc_config_builder() {
        let config = GcConfig::new(GcStrategy::PerRow)
            .with_min_dims(1)
            .with_eps(1e-10);

        assert_eq!(config.strategy, GcStrategy::PerRow);
        assert_eq!(config.min_dims, 1);
        assert_eq!(config.eps, 1e-10);
    }

    #[test]
    fn test_gc_layerwise_centralization() {
        let config = OptimizerConfig {
            learning_rate: 0.001,
            ..Default::default()
        };
        let adam = AdamOptimizer::new(config);
        let mut gc = GradientCentralization::new(Box::new(adam), GcConfig::default());

        let grad = Array2::from_shape_fn((3, 3), |(i, j)| (i * 3 + j) as f64);
        let mean = grad.mean().unwrap();

        let mut grads = HashMap::new();
        grads.insert("w1".to_string(), grad.clone());

        let centered = gc.centralize_gradients(&grads);
        let centered_grad = &centered["w1"];

        let new_mean = centered_grad.mean().unwrap();
        assert!(new_mean.abs() < 1e-10);

        for i in 0..3 {
            for j in 0..3 {
                assert!((centered_grad[[i, j]] - (grad[[i, j]] - mean)).abs() < 1e-10);
            }
        }
    }

    #[test]
    fn test_gc_per_row_centralization() {
        let config = OptimizerConfig {
            learning_rate: 0.001,
            ..Default::default()
        };
        let adam = AdamOptimizer::new(config);
        let config = GcConfig::new(GcStrategy::PerRow);
        let mut gc = GradientCentralization::new(Box::new(adam), config);

        let grad = Array2::from_shape_fn((2, 3), |(i, j)| (i * 10 + j) as f64);

        let mut grads = HashMap::new();
        grads.insert("w1".to_string(), grad.clone());

        let centered = gc.centralize_gradients(&grads);
        let centered_grad = &centered["w1"];

        for i in 0..2 {
            let row_mean = centered_grad.row(i).mean().unwrap();
            assert!(row_mean.abs() < 1e-10);
        }
    }

    #[test]
    fn test_gc_per_column_centralization() {
        let config = OptimizerConfig {
            learning_rate: 0.001,
            ..Default::default()
        };
        let adam = AdamOptimizer::new(config);
        let config = GcConfig::new(GcStrategy::PerColumn);
        let mut gc = GradientCentralization::new(Box::new(adam), config);

        let grad = Array2::from_shape_fn((3, 2), |(i, j)| (i + j * 10) as f64);

        let mut grads = HashMap::new();
        grads.insert("w1".to_string(), grad.clone());

        let centered = gc.centralize_gradients(&grads);
        let centered_grad = &centered["w1"];

        for j in 0..2 {
            let col_mean = centered_grad.column(j).mean().unwrap();
            assert!(col_mean.abs() < 1e-10);
        }
    }

    #[test]
    fn test_gc_global_centralization() {
        let config = OptimizerConfig {
            learning_rate: 0.001,
            ..Default::default()
        };
        let adam = AdamOptimizer::new(config);
        let config = GcConfig::new(GcStrategy::Global);
        let mut gc = GradientCentralization::new(Box::new(adam), config);

        let mut grads = HashMap::new();
        grads.insert("w1".to_string(), Array2::from_elem((2, 2), 5.0));
        grads.insert("w2".to_string(), Array2::from_elem((2, 2), 15.0));

        let centered = gc.centralize_gradients(&grads);

        // The global mean is 10.0, so w1 becomes -5.0 and w2 becomes 5.0.
        let w1_centered = &centered["w1"];
        let w2_centered = &centered["w2"];

        assert!((w1_centered[[0, 0]] + 5.0).abs() < 1e-10);
        assert!((w2_centered[[0, 0]] - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_gc_skip_small_tensors() {
        let config = OptimizerConfig {
            learning_rate: 0.001,
            ..Default::default()
        };
        let adam = AdamOptimizer::new(config);
        let config = GcConfig::default().with_min_dims(2);
        let gc = GradientCentralization::new(Box::new(adam), config);

        // All gradients here are `Array2`s (always two-dimensional), so this
        // only verifies that the configured threshold is stored.
        assert_eq!(gc.config().min_dims, 2);
    }

    #[test]
    fn test_gc_enable_disable() {
        let mut config = GcConfig::default();
        assert!(config.enabled);

        config.disable();
        assert!(!config.enabled);

        config.enable();
        assert!(config.enabled);
    }

    #[test]
    fn test_gc_with_optimizer_step() {
        let config = OptimizerConfig {
            learning_rate: 0.01,
            ..Default::default()
        };
        let adam = AdamOptimizer::new(config);
        let mut gc = GradientCentralization::new(Box::new(adam), GcConfig::default());

        let mut params = HashMap::new();
        params.insert("w1".to_string(), Array2::ones((3, 3)));

        let mut grads = HashMap::new();
        grads.insert(
            "w1".to_string(),
            Array2::from_shape_fn((3, 3), |(i, j)| 0.1 + (i + j) as f64 * 0.05),
        );

        assert!(gc.step(&mut params, &grads).is_ok());

        let updated = &params["w1"];
        let has_changed = updated.iter().any(|&x| (x - 1.0).abs() > 1e-6);
        assert!(has_changed);
    }

    #[test]
    fn test_gc_statistics() {
        let config = OptimizerConfig {
            learning_rate: 0.001,
            ..Default::default()
        };
        let adam = AdamOptimizer::new(config);
        let mut gc = GradientCentralization::new(Box::new(adam), GcConfig::default());

        let mut grads = HashMap::new();
        grads.insert("w1".to_string(), Array2::ones((3, 3)));
        grads.insert("w2".to_string(), Array2::ones((3, 3)));

        gc.centralize_gradients(&grads);

        let stats = gc.stats();
        assert_eq!(stats.num_centralized, 2);
        assert_eq!(stats.total_operations, 1);
        assert!(stats.avg_grad_norm_before > 0.0);
    }

    #[test]
    fn test_gc_reset_stats() {
        let config = OptimizerConfig {
            learning_rate: 0.001,
            ..Default::default()
        };
        let adam = AdamOptimizer::new(config);
        let mut gc = GradientCentralization::new(Box::new(adam), GcConfig::default());

        let mut grads = HashMap::new();
        grads.insert("w1".to_string(), Array2::ones((3, 3)));

        gc.centralize_gradients(&grads);
        assert_eq!(gc.stats().total_operations, 1);

        gc.reset_stats();
        assert_eq!(gc.stats().total_operations, 0);
    }

    #[test]
    fn test_gc_learning_rate() {
        let config = OptimizerConfig {
            learning_rate: 0.001,
            ..Default::default()
        };
        let adam = AdamOptimizer::new(config);
        let mut gc = GradientCentralization::new(Box::new(adam), GcConfig::default());

        assert_eq!(gc.get_lr(), 0.001);

        gc.set_lr(0.01);
        assert_eq!(gc.get_lr(), 0.01);
    }

    #[test]
    fn test_gc_state_dict() {
        let config = OptimizerConfig {
            learning_rate: 0.001,
            ..Default::default()
        };
        let adam = AdamOptimizer::new(config);
        let mut gc = GradientCentralization::new(Box::new(adam), GcConfig::default());

        let state = gc.state_dict();
        assert!(state.contains_key("gc_enabled"));

        gc.config_mut().disable();
        assert!(!gc.config().enabled);

        gc.load_state_dict(state);
        assert!(gc.config().enabled);
    }

    #[test]
    fn test_gc_disabled() {
        let config = OptimizerConfig {
            learning_rate: 0.001,
            ..Default::default()
        };
        let adam = AdamOptimizer::new(config);
        let mut config = GcConfig::default();
        config.disable();

        let mut gc = GradientCentralization::new(Box::new(adam), config);

        let grad = Array2::from_elem((3, 3), 5.0);
        let mut grads = HashMap::new();
        grads.insert("w1".to_string(), grad.clone());

        let centered = gc.centralize_gradients(&grads);

        // With centralization disabled, gradients are returned untouched.
        let centered_grad = &centered["w1"];
        assert_eq!(centered_grad[[0, 0]], 5.0);
    }
}