tensorlogic_train/gradient_centralization.rs

//! Gradient Centralization (GC) - Advanced gradient preprocessing technique.
//!
//! Gradient Centralization is a simple yet effective optimization technique that
//! centers gradients by subtracting their mean before the optimizer update is applied.
//! It has been shown to improve generalization, accelerate training, and stabilize
//! gradient flow, especially for deep networks.
//!
//! # Benefits
//! - Improved generalization and test accuracy
//! - Faster convergence
//! - Better gradient flow (reduces gradient explosion/vanishing)
//! - Works with any optimizer
//! - Minimal computational overhead
//!
//! # Reference
//! Yong et al., "Gradient Centralization: A New Optimization Technique for Deep Neural Networks",
//! ECCV 2020 - <https://arxiv.org/abs/2004.01461>
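//!
//! # Illustration
//! Centering shifts every element of a gradient by the same scalar so that the
//! result has zero mean. A minimal sketch using the same `scirs2_core::ndarray`
//! types as the implementation below (no optimizer involved):
//!
//! ```no_run
//! use scirs2_core::ndarray::Array2;
//!
//! // A 2x2 gradient matrix [[1, 2], [3, 4]] with mean 2.5.
//! let grad = Array2::from_shape_fn((2, 2), |(i, j)| (i * 2 + j) as f64 + 1.0);
//!
//! // Layer-wise centralization: subtract the scalar mean from every element.
//! let centered = &grad - grad.mean().unwrap_or(0.0);
//!
//! // The centered gradient [[-1.5, -0.5], [0.5, 1.5]] has zero mean.
//! assert!(centered.mean().unwrap().abs() < 1e-12);
//! ```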

use crate::{Optimizer, TrainResult};
use scirs2_core::ndarray::Array2;
use std::collections::HashMap;

/// Gradient centralization strategy.
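///
/// As a small hand-worked illustration: for the 2x2 gradient `[[1, 3], [5, 7]]`,
/// `LayerWise` subtracts the overall mean 4, giving `[[-3, -1], [1, 3]]`;
/// `PerRow` subtracts the row means 2 and 6, giving `[[-1, 1], [-1, 1]]`; and
/// `PerColumn` subtracts the column means 3 and 5, giving `[[-2, -2], [2, 2]]`.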
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum GcStrategy {
    /// Centralize each layer's gradients independently (most common).
    /// For each parameter matrix: g = g - mean(g)
    #[default]
    LayerWise,

    /// Centralize all gradients globally (experimental).
    /// Compute global mean across all parameters: g_all = g_all - mean(g_all)
    Global,

    /// Centralize per row (for weight matrices).
    /// For weight matrix: g\[i,:\] = g\[i,:\] - mean(g\[i,:\])
    PerRow,

    /// Centralize per column (for weight matrices).
    /// For weight matrix: g\[:,j\] = g\[:,j\] - mean(g\[:,j\])
    PerColumn,
}

/// Configuration for gradient centralization.
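///
/// # Example
/// A minimal sketch of the builder-style configuration; like the other examples
/// in this module, it assumes `GcConfig` and `GcStrategy` are re-exported at the
/// crate root.
///
/// ```no_run
/// use tensorlogic_train::*;
///
/// // Per-row centralization, applied to tensors of any rank, with a tighter epsilon.
/// let config = GcConfig::new(GcStrategy::PerRow)
///     .with_min_dims(1)
///     .with_eps(1e-10);
/// assert!(config.enabled);
/// ```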
#[derive(Debug, Clone)]
pub struct GcConfig {
    /// Centralization strategy.
    pub strategy: GcStrategy,

    /// Whether to apply GC (can be toggled dynamically).
    pub enabled: bool,

    /// Minimum parameter dimensions to apply GC (skip small parameters).
    /// For example, bias vectors (1D) are typically not centralized.
    pub min_dims: usize,

    /// Epsilon for numerical stability (not currently used by the centralization strategies).
    pub eps: f64,
}

impl Default for GcConfig {
    fn default() -> Self {
        Self {
            strategy: GcStrategy::LayerWise,
            enabled: true,
            min_dims: 2, // Only centralize 2D+ tensors (weight matrices)
            eps: 1e-8,
        }
    }
}

impl GcConfig {
    /// Create a new GC configuration.
    pub fn new(strategy: GcStrategy) -> Self {
        Self {
            strategy,
            ..Default::default()
        }
    }

    /// Enable gradient centralization.
    pub fn enable(&mut self) {
        self.enabled = true;
    }

    /// Disable gradient centralization.
    pub fn disable(&mut self) {
        self.enabled = false;
    }

    /// Set minimum dimensions for applying GC.
    pub fn with_min_dims(mut self, min_dims: usize) -> Self {
        self.min_dims = min_dims;
        self
    }

    /// Set epsilon for numerical stability.
    pub fn with_eps(mut self, eps: f64) -> Self {
        self.eps = eps;
        self
    }
}

/// Gradient Centralization optimizer wrapper.
///
/// Wraps any optimizer and applies gradient centralization before the optimizer step.
/// GC centers gradients by subtracting their mean, which improves training dynamics.
///
/// # Example
/// ```no_run
/// use tensorlogic_train::*;
/// use scirs2_core::ndarray::Array2;
/// use std::collections::HashMap;
///
/// // Create base optimizer
/// let config = OptimizerConfig { learning_rate: 0.001, ..Default::default() };
/// let adam = AdamOptimizer::new(config);
///
/// // Wrap with gradient centralization
/// let mut gc_adam = GradientCentralization::new(
///     Box::new(adam),
///     GcConfig::default(),
/// );
///
/// // Use as normal optimizer - GC is applied automatically
/// let mut params = HashMap::new();
/// params.insert("w1".to_string(), Array2::zeros((10, 5)));
///
/// let mut grads = HashMap::new();
/// grads.insert("w1".to_string(), Array2::ones((10, 5)));
///
/// gc_adam.step(&mut params, &grads).unwrap();
/// ```
pub struct GradientCentralization {
    /// Wrapped optimizer.
    inner_optimizer: Box<dyn Optimizer>,

    /// GC configuration.
    config: GcConfig,

    /// Statistics for monitoring.
    stats: GcStats,
}

/// Statistics for gradient centralization.
#[derive(Debug, Clone, Default)]
pub struct GcStats {
    /// Number of parameters centralized.
    pub num_centralized: usize,

    /// Number of parameters skipped (too small).
    pub num_skipped: usize,

    /// Average gradient L2 norm before centralization (over the most recent call).
    pub avg_grad_norm_before: f64,

    /// Average gradient L2 norm after centralization (over the most recent call).
    pub avg_grad_norm_after: f64,

    /// Total number of centralization operations.
    pub total_operations: usize,
}

impl GradientCentralization {
    /// Create a new gradient centralization optimizer.
    pub fn new(inner_optimizer: Box<dyn Optimizer>, config: GcConfig) -> Self {
        Self {
            inner_optimizer,
            config,
            stats: GcStats::default(),
        }
    }

    /// Create with default configuration.
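    ///
    /// Equivalent to `GradientCentralization::new(inner_optimizer, GcConfig::default())`.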
    pub fn with_default(inner_optimizer: Box<dyn Optimizer>) -> Self {
        Self::new(inner_optimizer, GcConfig::default())
    }

    /// Get GC configuration.
    pub fn config(&self) -> &GcConfig {
        &self.config
    }

    /// Get mutable GC configuration.
    pub fn config_mut(&mut self) -> &mut GcConfig {
        &mut self.config
    }

    /// Get statistics.
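    ///
    /// # Example
    /// A rough sketch of reading the monitoring counters after one step, reusing
    /// the crate-level types from the struct-level example above:
    ///
    /// ```no_run
    /// use tensorlogic_train::*;
    /// use scirs2_core::ndarray::Array2;
    /// use std::collections::HashMap;
    ///
    /// let config = OptimizerConfig { learning_rate: 0.001, ..Default::default() };
    /// let adam = AdamOptimizer::new(config);
    /// let mut gc = GradientCentralization::with_default(Box::new(adam));
    ///
    /// let mut params = HashMap::new();
    /// params.insert("w".to_string(), Array2::<f64>::zeros((4, 4)));
    /// let mut grads = HashMap::new();
    /// grads.insert("w".to_string(), Array2::<f64>::ones((4, 4)));
    ///
    /// gc.step(&mut params, &grads).unwrap();
    /// assert_eq!(gc.stats().total_operations, 1);
    /// ```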
    pub fn stats(&self) -> &GcStats {
        &self.stats
    }

    /// Reset statistics.
    pub fn reset_stats(&mut self) {
        self.stats = GcStats::default();
    }

    /// Apply gradient centralization to gradients.
    fn centralize_gradients(
        &mut self,
        grads: &HashMap<String, Array2<f64>>,
    ) -> HashMap<String, Array2<f64>> {
        if !self.config.enabled {
            return grads.clone();
        }

        let mut centralized_grads = HashMap::new();
        let mut total_norm_before = 0.0;
        let mut total_norm_after = 0.0;
        // Count the gradients centralized in this call so the average-norm
        // statistics below are per-call rather than diluted by earlier calls.
        let mut num_centralized_here = 0usize;

        for (name, grad) in grads {
            let shape = grad.shape();

            // Check if parameter meets minimum dimension requirement
            if shape.len() < self.config.min_dims {
                centralized_grads.insert(name.clone(), grad.clone());
                self.stats.num_skipped += 1;
                continue;
            }

            // Compute norm before centralization
            let norm_before = grad.iter().map(|&x| x * x).sum::<f64>().sqrt();
            total_norm_before += norm_before;

            // Apply centralization based on strategy
            let centered_grad = match self.config.strategy {
                GcStrategy::LayerWise => self.centralize_layerwise(grad),
                GcStrategy::Global => grad.clone(), // Global handled separately
                GcStrategy::PerRow => self.centralize_per_row(grad),
                GcStrategy::PerColumn => self.centralize_per_column(grad),
            };

            // Compute norm after centralization
            let norm_after = centered_grad.iter().map(|&x| x * x).sum::<f64>().sqrt();
            total_norm_after += norm_after;

            centralized_grads.insert(name.clone(), centered_grad);
            self.stats.num_centralized += 1;
            num_centralized_here += 1;
        }

        // Handle global strategy
        if self.config.strategy == GcStrategy::Global && !centralized_grads.is_empty() {
            centralized_grads = self.centralize_global(&centralized_grads);
        }

        // Update statistics (averaged over the gradients centralized in this call)
        let n = num_centralized_here.max(1) as f64;
        self.stats.avg_grad_norm_before = total_norm_before / n;
        self.stats.avg_grad_norm_after = total_norm_after / n;
        self.stats.total_operations += 1;

        centralized_grads
    }

    /// Centralize gradients layer-wise (subtract mean from each layer).
    fn centralize_layerwise(&self, grad: &Array2<f64>) -> Array2<f64> {
        let mean = grad.mean().unwrap_or(0.0);
        grad - mean
    }

    /// Centralize gradients per row.
    fn centralize_per_row(&self, grad: &Array2<f64>) -> Array2<f64> {
        let mut centered = grad.clone();

        for i in 0..grad.nrows() {
            let row_mean = grad.row(i).mean().unwrap_or(0.0);
            for j in 0..grad.ncols() {
                centered[[i, j]] -= row_mean;
            }
        }

        centered
    }

    /// Centralize gradients per column.
    fn centralize_per_column(&self, grad: &Array2<f64>) -> Array2<f64> {
        let mut centered = grad.clone();

        for j in 0..grad.ncols() {
            let col_mean = grad.column(j).mean().unwrap_or(0.0);
            for i in 0..grad.nrows() {
                centered[[i, j]] -= col_mean;
            }
        }

        centered
    }

    /// Centralize all gradients globally.
    fn centralize_global(
        &self,
        grads: &HashMap<String, Array2<f64>>,
    ) -> HashMap<String, Array2<f64>> {
        // Compute global mean across all parameters
        let mut total_sum = 0.0;
        let mut total_count = 0;

        for grad in grads.values() {
            total_sum += grad.sum();
            total_count += grad.len();
        }

        let global_mean = if total_count > 0 {
            total_sum / total_count as f64
        } else {
            0.0
        };

        // Subtract global mean from all gradients
        let mut centralized = HashMap::new();
        for (name, grad) in grads {
            centralized.insert(name.clone(), grad - global_mean);
        }

        centralized
    }
}

impl Optimizer for GradientCentralization {
    fn step(
        &mut self,
        params: &mut HashMap<String, Array2<f64>>,
        grads: &HashMap<String, Array2<f64>>,
    ) -> TrainResult<()> {
        // Apply gradient centralization
        let centralized_grads = self.centralize_gradients(grads);

        // Forward to wrapped optimizer
        self.inner_optimizer.step(params, &centralized_grads)
    }

    fn zero_grad(&mut self) {
        self.inner_optimizer.zero_grad();
    }

    fn get_lr(&self) -> f64 {
        self.inner_optimizer.get_lr()
    }

    fn set_lr(&mut self, lr: f64) {
        self.inner_optimizer.set_lr(lr);
    }

    fn state_dict(&self) -> HashMap<String, Vec<f64>> {
        // Include both inner optimizer state and GC config
        let mut state = self.inner_optimizer.state_dict();

        // Serialize GC config (simplified - just store enabled flag)
        let gc_state = if self.config.enabled {
            vec![1.0]
        } else {
            vec![0.0]
        };
        state.insert("gc_enabled".to_string(), gc_state);

        state
    }

    fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) {
        // Load GC config
        if let Some(gc_state) = state.get("gc_enabled") {
            self.config.enabled = !gc_state.is_empty() && gc_state[0] > 0.5;
        }

        // Load inner optimizer state
        self.inner_optimizer.load_state_dict(state);
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{AdamOptimizer, OptimizerConfig};
    use scirs2_core::ndarray::Array2;

    #[test]
    fn test_gc_config_default() {
        let config = GcConfig::default();
        assert!(config.enabled);
        assert_eq!(config.min_dims, 2);
        assert_eq!(config.strategy, GcStrategy::LayerWise);
    }

    #[test]
    fn test_gc_config_builder() {
        let config = GcConfig::new(GcStrategy::PerRow)
            .with_min_dims(1)
            .with_eps(1e-10);

        assert_eq!(config.strategy, GcStrategy::PerRow);
        assert_eq!(config.min_dims, 1);
        assert_eq!(config.eps, 1e-10);
    }

    #[test]
    fn test_gc_layerwise_centralization() {
        let config = OptimizerConfig {
            learning_rate: 0.001,
            ..Default::default()
        };
        let adam = AdamOptimizer::new(config);
        let mut gc = GradientCentralization::new(Box::new(adam), GcConfig::default());

        // Create gradient with known mean
        let grad = Array2::from_shape_fn((3, 3), |(i, j)| (i * 3 + j) as f64);
        let mean = grad.mean().unwrap();

        let mut grads = HashMap::new();
        grads.insert("w1".to_string(), grad.clone());

        let centered = gc.centralize_gradients(&grads);
        let centered_grad = &centered["w1"];

        // Mean should be close to zero after centralization
        let new_mean = centered_grad.mean().unwrap();
        assert!(new_mean.abs() < 1e-10);

        // Each element should be shifted by original mean
        for i in 0..3 {
            for j in 0..3 {
                assert!((centered_grad[[i, j]] - (grad[[i, j]] - mean)).abs() < 1e-10);
            }
        }
    }

    #[test]
    fn test_gc_per_row_centralization() {
        let config = OptimizerConfig {
            learning_rate: 0.001,
            ..Default::default()
        };
        let adam = AdamOptimizer::new(config);
        let config = GcConfig::new(GcStrategy::PerRow);
        let mut gc = GradientCentralization::new(Box::new(adam), config);

        let grad = Array2::from_shape_fn((2, 3), |(i, j)| (i * 10 + j) as f64);

        let mut grads = HashMap::new();
        grads.insert("w1".to_string(), grad.clone());

        let centered = gc.centralize_gradients(&grads);
        let centered_grad = &centered["w1"];

        // Each row should have mean close to zero
        for i in 0..2 {
            let row_mean = centered_grad.row(i).mean().unwrap();
            assert!(row_mean.abs() < 1e-10);
        }
    }

    #[test]
    fn test_gc_per_column_centralization() {
        let config = OptimizerConfig {
            learning_rate: 0.001,
            ..Default::default()
        };
        let adam = AdamOptimizer::new(config);
        let config = GcConfig::new(GcStrategy::PerColumn);
        let mut gc = GradientCentralization::new(Box::new(adam), config);

        let grad = Array2::from_shape_fn((3, 2), |(i, j)| (i + j * 10) as f64);

        let mut grads = HashMap::new();
        grads.insert("w1".to_string(), grad.clone());

        let centered = gc.centralize_gradients(&grads);
        let centered_grad = &centered["w1"];

        // Each column should have mean close to zero
        for j in 0..2 {
            let col_mean = centered_grad.column(j).mean().unwrap();
            assert!(col_mean.abs() < 1e-10);
        }
    }

    #[test]
    fn test_gc_global_centralization() {
        let config = OptimizerConfig {
            learning_rate: 0.001,
            ..Default::default()
        };
        let adam = AdamOptimizer::new(config);
        let config = GcConfig::new(GcStrategy::Global);
        let mut gc = GradientCentralization::new(Box::new(adam), config);

        let mut grads = HashMap::new();
        grads.insert("w1".to_string(), Array2::from_elem((2, 2), 5.0));
        grads.insert("w2".to_string(), Array2::from_elem((2, 2), 15.0));

        let centered = gc.centralize_gradients(&grads);

        // Global mean should be 10.0
        // After centralization: w1 = -5, w2 = 5
        let w1_centered = &centered["w1"];
        let w2_centered = &centered["w2"];

        assert!((w1_centered[[0, 0]] + 5.0).abs() < 1e-10);
        assert!((w2_centered[[0, 0]] - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_gc_skip_small_tensors() {
        let config = OptimizerConfig {
            learning_rate: 0.001,
            ..Default::default()
        };
        let adam = AdamOptimizer::new(config);
        let config = GcConfig::default().with_min_dims(2);
        let gc = GradientCentralization::new(Box::new(adam), config);

        // This test would require 1D tensors, but our implementation uses Array2
        // So we verify that the min_dims check is there
        assert_eq!(gc.config().min_dims, 2);
    }

    #[test]
    fn test_gc_enable_disable() {
        let mut config = GcConfig::default();
        assert!(config.enabled);

        config.disable();
        assert!(!config.enabled);

        config.enable();
        assert!(config.enabled);
    }

    #[test]
    fn test_gc_with_optimizer_step() {
        let config = OptimizerConfig {
            learning_rate: 0.01,
            ..Default::default()
        };
        let adam = AdamOptimizer::new(config);
        let mut gc = GradientCentralization::new(Box::new(adam), GcConfig::default());

        let mut params = HashMap::new();
        params.insert("w1".to_string(), Array2::ones((3, 3)));

        // Use varying gradients (not uniform) so after centralization there's still signal
        let mut grads = HashMap::new();
        grads.insert(
            "w1".to_string(),
            Array2::from_shape_fn((3, 3), |(i, j)| 0.1 + (i + j) as f64 * 0.05),
        );

        // Step should succeed
        assert!(gc.step(&mut params, &grads).is_ok());

        // Parameters should be updated (at least some of them should decrease)
        let updated = &params["w1"];
        // After GC, we still have non-zero centered gradients
        // At least one parameter should have changed
        let has_changed = updated.iter().any(|&x| (x - 1.0).abs() > 1e-6);
        assert!(has_changed);
    }

    #[test]
    fn test_gc_statistics() {
        let config = OptimizerConfig {
            learning_rate: 0.001,
            ..Default::default()
        };
        let adam = AdamOptimizer::new(config);
        let mut gc = GradientCentralization::new(Box::new(adam), GcConfig::default());

        let mut grads = HashMap::new();
        grads.insert("w1".to_string(), Array2::ones((3, 3)));
        grads.insert("w2".to_string(), Array2::ones((3, 3)));

        gc.centralize_gradients(&grads);

        let stats = gc.stats();
        assert_eq!(stats.num_centralized, 2);
        assert_eq!(stats.total_operations, 1);
        assert!(stats.avg_grad_norm_before > 0.0);
    }

    #[test]
    fn test_gc_reset_stats() {
        let config = OptimizerConfig {
            learning_rate: 0.001,
            ..Default::default()
        };
        let adam = AdamOptimizer::new(config);
        let mut gc = GradientCentralization::new(Box::new(adam), GcConfig::default());

        let mut grads = HashMap::new();
        grads.insert("w1".to_string(), Array2::ones((3, 3)));

        gc.centralize_gradients(&grads);
        assert_eq!(gc.stats().total_operations, 1);

        gc.reset_stats();
        assert_eq!(gc.stats().total_operations, 0);
    }

    #[test]
    fn test_gc_learning_rate() {
        let config = OptimizerConfig {
            learning_rate: 0.001,
            ..Default::default()
        };
        let adam = AdamOptimizer::new(config);
        let mut gc = GradientCentralization::new(Box::new(adam), GcConfig::default());

        assert_eq!(gc.get_lr(), 0.001);

        gc.set_lr(0.01);
        assert_eq!(gc.get_lr(), 0.01);
    }

    #[test]
    fn test_gc_state_dict() {
        let config = OptimizerConfig {
            learning_rate: 0.001,
            ..Default::default()
        };
        let adam = AdamOptimizer::new(config);
        let mut gc = GradientCentralization::new(Box::new(adam), GcConfig::default());

        // Get state
        let state = gc.state_dict();
        assert!(state.contains_key("gc_enabled"));

        // Modify and load
        gc.config_mut().disable();
        assert!(!gc.config().enabled);

        gc.load_state_dict(state);
        assert!(gc.config().enabled); // Should be restored
    }

    #[test]
    fn test_gc_disabled() {
        let config = OptimizerConfig {
            learning_rate: 0.001,
            ..Default::default()
        };
        let adam = AdamOptimizer::new(config);
        let mut config = GcConfig::default();
        config.disable();

        let mut gc = GradientCentralization::new(Box::new(adam), config);

        let grad = Array2::from_elem((3, 3), 5.0);
        let mut grads = HashMap::new();
        grads.insert("w1".to_string(), grad.clone());

        let centered = gc.centralize_gradients(&grads);

        // Should return unchanged gradients when disabled
        let centered_grad = &centered["w1"];
        assert_eq!(centered_grad[[0, 0]], 5.0);
    }
}