// tenflowers_core/gradient_clipping.rs
//! Advanced Gradient Clipping with Adaptive Scaling
//!
//! This module provides sophisticated gradient clipping techniques essential for training
//! stability in modern deep learning, particularly with large language models, RNNs, and
//! transformers. It implements both global norm clipping and adaptive scaling strategies
//! that automatically adjust based on training dynamics.
//!
//! # Features
//! - **Global Gradient Norm Clipping**: Scales all gradients by the same factor to maintain
//!   relative magnitudes while preventing exploding gradients
//! - **Adaptive Clipping**: Dynamically adjusts clipping thresholds based on gradient
//!   statistics and training progress
//! - **Per-Parameter Clipping**: Fine-grained control for different parameter groups
//! - **Gradient Statistics Tracking**: Monitors gradient norms for training diagnostics
//! - **Warmup and Decay**: Gradually adjusts clipping behavior during training phases

17use crate::{Result, Tensor};
18use scirs2_core::numeric::{Float, FromPrimitive};
19use std::collections::HashMap;
20use std::marker::PhantomData;
21
/// Configuration for gradient clipping behavior
///
/// Defaults (see [`Default`]): `max_norm = 1.0`, L2 norm, adaptive scaling
/// off, momentum `0.95`, adaptive bounds `[0.1, 10.0]`, no warmup.
#[derive(Debug, Clone)]
pub struct GradientClippingConfig {
    /// Maximum allowed gradient norm (global clipping threshold)
    pub max_norm: f64,
    /// Type of norm to use for clipping (L1, L2, or infinity)
    pub norm_type: NormType,
    /// Enable adaptive threshold adjustment based on gradient statistics
    pub adaptive_scaling: bool,
    /// Momentum factor for the EMA of gradient norms (0.0 - 1.0); higher
    /// values smooth the average more
    pub adaptive_momentum: f64,
    /// Minimum allowed clipping threshold (prevents overly aggressive clipping)
    pub min_threshold: f64,
    /// Maximum allowed clipping threshold (prevents disabling clipping)
    pub max_threshold: f64,
    /// Warmup steps for gradually enabling clipping; 0 disables warmup
    pub warmup_steps: usize,
    /// Enable per-parameter group clipping with different thresholds
    pub per_parameter_clipping: bool,
}
42
43impl Default for GradientClippingConfig {
44    fn default() -> Self {
45        Self {
46            max_norm: 1.0,
47            norm_type: NormType::L2,
48            adaptive_scaling: false,
49            adaptive_momentum: 0.95,
50            min_threshold: 0.1,
51            max_threshold: 10.0,
52            warmup_steps: 0,
53            per_parameter_clipping: false,
54        }
55    }
56}
57
/// Type of norm to use for gradient clipping
///
/// Selected via [`GradientClippingConfig::norm_type`]; determines how
/// per-tensor magnitudes are aggregated into the global gradient norm.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum NormType {
    /// L1 norm (sum of absolute values)
    L1,
    /// L2 norm (Euclidean norm, most common)
    L2,
    /// Infinity norm (maximum absolute value)
    Infinity,
}
68
/// Statistics tracking for gradient clipping
///
/// Updated on every `clip_gradients` call; exposed read-only through
/// `GradientClipper::get_statistics`.
#[derive(Debug, Clone)]
pub struct GradientStatistics {
    /// Most recently observed global gradient norm (pre-clipping)
    pub current_norm: f64,
    /// Exponential moving average of gradient norms
    pub avg_norm: f64,
    /// Standard deviation of recent gradient norms
    pub std_norm: f64,
    /// Number of times gradients were clipped
    pub clip_count: usize,
    /// Total number of gradient updates processed
    pub total_updates: usize,
    /// Current adaptive threshold (only meaningful when adaptive scaling is enabled)
    pub adaptive_threshold: f64,
    /// Sliding window of recent gradient norms (bounded to 100 entries,
    /// used for the standard-deviation estimate)
    pub norm_history: Vec<f64>,
}
87
88impl Default for GradientStatistics {
89    fn default() -> Self {
90        Self {
91            current_norm: 0.0,
92            avg_norm: 0.0,
93            std_norm: 0.0,
94            clip_count: 0,
95            total_updates: 0,
96            adaptive_threshold: 1.0,
97            norm_history: Vec::with_capacity(100), // Keep last 100 norms
98        }
99    }
100}
101
/// Advanced gradient clipping system with adaptive scaling
///
/// Holds the configuration, running statistics, and optional per-group
/// thresholds. `T` is the tensor element type; it is carried only via
/// `PhantomData` — no value of `T` is stored.
pub struct GradientClipper<T> {
    /// Clipping behavior settings (fixed at construction)
    config: GradientClippingConfig,
    /// Running norm/clip statistics, updated on each `clip_gradients` call
    statistics: GradientStatistics,
    parameter_groups: HashMap<String, f64>, // Group name -> threshold
    /// Number of `clip_gradients` calls so far; drives warmup scheduling
    step_count: usize,
    _phantom: PhantomData<T>,
}
110
111impl<T> GradientClipper<T>
112where
113    T: Float + FromPrimitive + Clone + Send + Sync + Default + 'static,
114{
115    /// Create a new gradient clipper with the specified configuration
116    pub fn new(config: GradientClippingConfig) -> Self {
117        Self {
118            config,
119            statistics: GradientStatistics::default(),
120            parameter_groups: HashMap::new(),
121            step_count: 0,
122            _phantom: PhantomData,
123        }
124    }
125
126    /// Create a gradient clipper with default settings for stable training
127    pub fn default_stable() -> Self {
128        Self::new(GradientClippingConfig {
129            max_norm: 1.0,
130            norm_type: NormType::L2,
131            adaptive_scaling: false,
132            ..Default::default()
133        })
134    }
135
136    /// Create a gradient clipper with adaptive scaling for dynamic adjustment
137    pub fn default_adaptive() -> Self {
138        Self::new(GradientClippingConfig {
139            max_norm: 1.0,
140            norm_type: NormType::L2,
141            adaptive_scaling: true,
142            adaptive_momentum: 0.95,
143            min_threshold: 0.1,
144            max_threshold: 5.0,
145            ..Default::default()
146        })
147    }
148
    /// Add a parameter group with its own clipping threshold
    ///
    /// A repeated call with the same `group_name` overwrites the stored
    /// threshold. Groups are consulted by `clip_parameter_group`; unknown
    /// groups there fall back to `config.max_norm`.
    pub fn add_parameter_group(&mut self, group_name: String, threshold: f64) {
        self.parameter_groups.insert(group_name, threshold);
    }
153
154    /// Clip gradients using global norm clipping
155    ///
156    /// This is the main method for applying gradient clipping. It computes the global
157    /// gradient norm across all tensors and scales them proportionally if needed.
158    pub fn clip_gradients(&mut self, gradients: &mut [Tensor<T>]) -> Result<f64> {
159        if gradients.is_empty() {
160            return Ok(0.0);
161        }
162
163        self.step_count += 1;
164
165        // Compute global gradient norm
166        let global_norm = self.compute_global_norm(gradients)?;
167
168        // Update statistics
169        self.update_statistics(global_norm);
170
171        // Determine effective clipping threshold
172        let effective_threshold = self.get_effective_threshold();
173
174        // Apply warmup if configured
175        let warmed_threshold = if self.step_count <= self.config.warmup_steps {
176            let warmup_factor = self.step_count as f64 / self.config.warmup_steps as f64;
177            effective_threshold * warmup_factor + self.config.max_norm * (1.0 - warmup_factor)
178        } else {
179            effective_threshold
180        };
181
182        // Apply clipping if necessary
183        if global_norm > warmed_threshold {
184            let scale_factor = warmed_threshold / global_norm;
185            self.scale_gradients(
186                gradients,
187                T::from_f64(scale_factor).unwrap_or_else(|| T::one()),
188            )?;
189            self.statistics.clip_count += 1;
190        }
191
192        Ok(global_norm)
193    }
194
195    /// Clip gradients for a specific parameter group
196    pub fn clip_parameter_group(
197        &mut self,
198        group_name: &str,
199        gradients: &mut [Tensor<T>],
200    ) -> Result<f64> {
201        let threshold = self
202            .parameter_groups
203            .get(group_name)
204            .copied()
205            .unwrap_or(self.config.max_norm);
206
207        let global_norm = self.compute_global_norm(gradients)?;
208
209        if global_norm > threshold {
210            let scale_factor = threshold / global_norm;
211            self.scale_gradients(
212                gradients,
213                T::from_f64(scale_factor).unwrap_or_else(|| T::one()),
214            )?;
215        }
216
217        Ok(global_norm)
218    }
219
220    /// Compute the global norm of all gradients
221    fn compute_global_norm(&self, gradients: &[Tensor<T>]) -> Result<f64> {
222        match self.config.norm_type {
223            NormType::L1 => {
224                let mut total_norm = 0.0;
225                for grad in gradients {
226                    total_norm += self.compute_tensor_l1_norm(grad)?;
227                }
228                Ok(total_norm)
229            }
230            NormType::L2 => {
231                let mut total_squared_norm = 0.0;
232                for grad in gradients {
233                    let tensor_norm = self.compute_tensor_l2_norm(grad)?;
234                    total_squared_norm += tensor_norm * tensor_norm;
235                }
236                Ok(total_squared_norm.sqrt())
237            }
238            NormType::Infinity => {
239                let mut max_norm = 0.0;
240                for grad in gradients {
241                    let tensor_max = self.compute_tensor_inf_norm(grad)?;
242                    max_norm = max_norm.max(tensor_max);
243                }
244                Ok(max_norm)
245            }
246        }
247    }
248
249    /// Compute L1 norm of a tensor
250    fn compute_tensor_l1_norm(&self, tensor: &Tensor<T>) -> Result<f64> {
251        // This is a simplified implementation - in practice, you'd use the tensor's
252        // actual data to compute the norm
253        match &tensor.storage {
254            crate::tensor::TensorStorage::Cpu(array) => {
255                let sum: f64 = array.iter().map(|&x| x.abs().to_f64().unwrap_or(0.0)).sum();
256                Ok(sum)
257            }
258            #[cfg(feature = "gpu")]
259            crate::tensor::TensorStorage::Gpu(_) => {
260                // For GPU tensors, we'd need to implement GPU norm computation
261                // For now, return a placeholder
262                Err(crate::TensorError::unsupported_operation_simple(
263                    "GPU L1 norm computation not yet implemented".to_string(),
264                ))
265            }
266        }
267    }
268
269    /// Compute L2 norm of a tensor
270    fn compute_tensor_l2_norm(&self, tensor: &Tensor<T>) -> Result<f64> {
271        match &tensor.storage {
272            crate::tensor::TensorStorage::Cpu(array) => {
273                let sum_squares: f64 = array
274                    .iter()
275                    .map(|&x| {
276                        let val = x.to_f64().unwrap_or(0.0);
277                        val * val
278                    })
279                    .sum();
280                Ok(sum_squares.sqrt())
281            }
282            #[cfg(feature = "gpu")]
283            crate::tensor::TensorStorage::Gpu(_) => {
284                // For GPU tensors, we'd need to implement GPU norm computation
285                Err(crate::TensorError::unsupported_operation_simple(
286                    "GPU L2 norm computation not yet implemented".to_string(),
287                ))
288            }
289        }
290    }
291
292    /// Compute infinity norm of a tensor
293    fn compute_tensor_inf_norm(&self, tensor: &Tensor<T>) -> Result<f64> {
294        match &tensor.storage {
295            crate::tensor::TensorStorage::Cpu(array) => {
296                let max_val = array
297                    .iter()
298                    .map(|&x| x.abs().to_f64().unwrap_or(0.0))
299                    .fold(0.0, f64::max);
300                Ok(max_val)
301            }
302            #[cfg(feature = "gpu")]
303            crate::tensor::TensorStorage::Gpu(_) => {
304                // For GPU tensors, we'd need to implement GPU norm computation
305                Err(crate::TensorError::unsupported_operation_simple(
306                    "GPU infinity norm computation not yet implemented".to_string(),
307                ))
308            }
309        }
310    }
311
312    /// Scale all gradients by a constant factor
313    fn scale_gradients(&self, gradients: &mut [Tensor<T>], scale_factor: T) -> Result<()> {
314        for grad in gradients.iter_mut() {
315            *grad = grad.mul_scalar(scale_factor)?;
316        }
317        Ok(())
318    }
319
320    /// Update gradient statistics for adaptive scaling
321    fn update_statistics(&mut self, global_norm: f64) {
322        self.statistics.current_norm = global_norm;
323        self.statistics.total_updates += 1;
324
325        // Update exponential moving average
326        if self.statistics.total_updates == 1 {
327            self.statistics.avg_norm = global_norm;
328        } else {
329            let momentum = self.config.adaptive_momentum;
330            self.statistics.avg_norm =
331                momentum * self.statistics.avg_norm + (1.0 - momentum) * global_norm;
332        }
333
334        // Update norm history for standard deviation calculation
335        self.statistics.norm_history.push(global_norm);
336        if self.statistics.norm_history.len() > 100 {
337            self.statistics.norm_history.remove(0);
338        }
339
340        // Calculate standard deviation
341        if self.statistics.norm_history.len() > 1 {
342            let mean = self.statistics.avg_norm;
343            let variance: f64 = self
344                .statistics
345                .norm_history
346                .iter()
347                .map(|&x| (x - mean).powi(2))
348                .sum::<f64>()
349                / (self.statistics.norm_history.len() - 1) as f64;
350            self.statistics.std_norm = variance.sqrt();
351        }
352
353        // Update adaptive threshold if enabled
354        if self.config.adaptive_scaling {
355            self.update_adaptive_threshold();
356        }
357    }
358
    /// Update the adaptive clipping threshold based on gradient statistics
    ///
    /// Derives a new threshold from the configured `max_norm`, scaled by the
    /// relative variation of recent norms and by how often clipping has fired,
    /// then clamps it into `[min_threshold, max_threshold]`.
    fn update_adaptive_threshold(&mut self) {
        let base_threshold = self.config.max_norm;

        // Relative variation (coefficient of variation, capped at 2.0). As
        // written, HIGHER variance RAISES the threshold (less clipping), and a
        // coefficient below 1.0 lowers it.
        // NOTE(review): adaptive-clipping schemes usually tighten the
        // threshold when gradients are noisy — confirm this direction is
        // intended and not inverted.
        let variance_factor = if self.statistics.std_norm > 0.0 {
            (self.statistics.std_norm / self.statistics.avg_norm).min(2.0)
        } else {
            1.0
        };

        // Lifetime clip rate (clips / total updates) — not a sliding window,
        // so it reacts slowly late in training.
        let recent_clip_rate = if self.statistics.total_updates > 0 {
            self.statistics.clip_count as f64 / self.statistics.total_updates as f64
        } else {
            0.0
        };

        // Clipping on >50% of steps lowers the factor; on <10% raises it;
        // otherwise neutral.
        let frequency_adjustment = if recent_clip_rate > 0.5 {
            0.9 // Reduce threshold by 10%
        } else if recent_clip_rate < 0.1 {
            1.1 // Increase threshold by 10%
        } else {
            1.0 // Keep current threshold
        };

        // Note: factors multiply the fixed base each time; adjustments do not
        // accumulate across calls.
        let new_threshold = base_threshold * variance_factor * frequency_adjustment;

        // Clamp to configured bounds
        self.statistics.adaptive_threshold = new_threshold
            .max(self.config.min_threshold)
            .min(self.config.max_threshold);
    }
394
    /// Get the effective clipping threshold (adaptive or fixed)
    ///
    /// Returns the current `adaptive_threshold` when adaptive scaling is
    /// enabled, otherwise the fixed `max_norm`. Warmup blending is applied
    /// separately inside `clip_gradients`.
    fn get_effective_threshold(&self) -> f64 {
        if self.config.adaptive_scaling {
            self.statistics.adaptive_threshold
        } else {
            self.config.max_norm
        }
    }

    /// Get current gradient statistics (read-only view)
    pub fn get_statistics(&self) -> &GradientStatistics {
        &self.statistics
    }

    /// Get the current configuration (read-only view)
    pub fn get_config(&self) -> &GradientClippingConfig {
        &self.config
    }
413
    /// Reset statistics (useful for training phase transitions)
    ///
    /// Clears all norm statistics AND the warmup step counter, so warmup
    /// (if configured) restarts from the beginning. Parameter groups and the
    /// configuration are preserved.
    pub fn reset_statistics(&mut self) {
        self.statistics = GradientStatistics::default();
        self.step_count = 0;
    }
419
420    /// Get clipping rate (percentage of updates where clipping was applied)
421    pub fn get_clipping_rate(&self) -> f64 {
422        if self.statistics.total_updates > 0 {
423            self.statistics.clip_count as f64 / self.statistics.total_updates as f64
424        } else {
425            0.0
426        }
427    }
428
429    /// Check if gradients would be clipped with current threshold
430    pub fn would_clip(&self, gradients: &[Tensor<T>]) -> Result<bool> {
431        let global_norm = self.compute_global_norm(gradients)?;
432        Ok(global_norm > self.get_effective_threshold())
433    }
434}
435
436impl<T> Tensor<T>
437where
438    T: Float + FromPrimitive + Clone + Send + Sync + Default + 'static,
439{
440    /// Convenience method to multiply tensor by a scalar
441    pub fn mul_scalar(&self, scalar: T) -> Result<Tensor<T>> {
442        match &self.storage {
443            crate::tensor::TensorStorage::Cpu(array) => {
444                let scaled_array = array.mapv(|x| x * scalar);
445                Ok(Tensor::from_array(scaled_array))
446            }
447            #[cfg(feature = "gpu")]
448            crate::tensor::TensorStorage::Gpu(_) => {
449                // For GPU tensors, we'd implement GPU scalar multiplication
450                Err(crate::TensorError::unsupported_operation_simple(
451                    "GPU scalar multiplication not yet implemented".to_string(),
452                ))
453            }
454        }
455    }
456}
457
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    // Fixed-threshold clipping: a norm well above max_norm must trigger a clip.
    #[test]
    fn test_gradient_clipping_basic() {
        let mut clipper = GradientClipper::<f32>::default_stable();

        // Create test gradients with large norms
        // L2 norm = sqrt(4 * 5^2) = 10.0, well above the 1.0 threshold.
        let large_grad = Tensor::from_array(Array1::from_vec(vec![5.0, 5.0, 5.0, 5.0]).into_dyn());
        let mut gradients = vec![large_grad];

        let norm = clipper
            .clip_gradients(&mut gradients)
            .expect("test: clip_gradients should succeed");

        // Should have clipped since norm > 1.0
        assert!(norm > 1.0);
        assert_eq!(clipper.get_statistics().clip_count, 1);
    }

    // Adaptive mode: after several updates the statistics are populated and
    // the adaptive threshold stays within its positive clamp range.
    #[test]
    fn test_adaptive_clipping() {
        let mut clipper = GradientClipper::<f32>::default_adaptive();

        // Simulate multiple gradient updates with growing norms to exercise
        // the EMA, history window, and threshold adjustment paths.
        for i in 0..10 {
            let scale = 1.0 + i as f32 * 0.5;
            let grad = Tensor::from_array(Array1::from_vec(vec![scale, scale]).into_dyn());
            let mut gradients = vec![grad];

            let _norm = clipper
                .clip_gradients(&mut gradients)
                .expect("test: clip_gradients should succeed");
        }

        // Adaptive threshold should have adjusted
        let stats = clipper.get_statistics();
        assert!(stats.total_updates == 10);
        assert!(stats.adaptive_threshold > 0.0);
    }

    // Per-group clipping: a registered group's threshold is used instead of
    // the global max_norm.
    #[test]
    fn test_parameter_groups() {
        let mut clipper = GradientClipper::<f32>::new(GradientClippingConfig {
            per_parameter_clipping: true,
            ..Default::default()
        });

        clipper.add_parameter_group("embeddings".to_string(), 0.5);
        clipper.add_parameter_group("output".to_string(), 2.0);

        // L2 norm = sqrt(2 * 1.5^2) ≈ 2.12, above the 0.5 group threshold.
        let grad = Tensor::from_array(Array1::from_vec(vec![1.5, 1.5]).into_dyn());
        let mut gradients = vec![grad];

        // Should clip with embedding threshold (0.5)
        let norm = clipper
            .clip_parameter_group("embeddings", &mut gradients)
            .expect("test: operation should succeed");
        assert!(norm > 0.5);
    }

    // Norm-type selection: L1 of [2, 2] is exactly 4.0 — equal to max_norm,
    // so the strict `>` comparison means no clipping occurs.
    #[test]
    fn test_different_norm_types() {
        let l1_config = GradientClippingConfig {
            norm_type: NormType::L1,
            max_norm: 4.0,
            ..Default::default()
        };
        let mut l1_clipper = GradientClipper::<f32>::new(l1_config);

        let grad = Tensor::from_array(Array1::from_vec(vec![2.0, 2.0]).into_dyn());
        let mut gradients = vec![grad];

        let norm = l1_clipper
            .clip_gradients(&mut gradients)
            .expect("test: clip_gradients should succeed");
        assert_eq!(norm, 4.0); // L1 norm should be 4.0
    }
}