// scirs2_autograd/gradient_clipping.rs

1//! Gradient clipping utilities
2//!
3//! Gradient clipping is a technique used to prevent the exploding gradient problem
4//! in deep learning by constraining the gradients to a reasonable range or magnitude.
5//! This module provides various gradient clipping strategies.
6
7use crate::tensor::Tensor;
8use crate::tensor_ops;
9use crate::Float;
10
/// Trait for gradient clipping strategies
///
/// Gradient clipping modifies gradients to prevent exploding gradients while
/// preserving the direction of optimization.
pub trait GradientClipper<F: Float> {
    /// Apply gradient clipping to a list of gradients
    ///
    /// # Arguments
    /// * `gradients` - Slice of gradient tensors to clip
    ///
    /// # Returns
    /// Vector of clipped gradient tensors, in the same order as the input slice
    fn clip_gradients<'g>(&mut self, gradients: &[Tensor<'g, F>]) -> Vec<Tensor<'g, F>>;

    /// Check if clipping was applied in the last call to clip_gradients
    ///
    /// This can be useful for monitoring whether gradients are being clipped.
    /// The default implementation always reports `false`; clippers that track
    /// clipping state should override it.
    fn was_clipped(&self) -> bool {
        // Default implementation - individual clippers can override
        false
    }

    /// Get statistics about the last clipping operation
    ///
    /// Returns information that can be used for logging or monitoring.
    /// The default implementation returns empty (all-zero/`None`) stats.
    fn get_clipping_stats(&self) -> ClippingStats<F> {
        ClippingStats::default()
    }
}
40
/// Statistics about gradient clipping operations
///
/// Produced by [`GradientClipper::get_clipping_stats`]; norm fields are
/// `None` when the clipper did not (or could not) measure them.
#[derive(Debug, Clone)]
pub struct ClippingStats<F: Float> {
    /// Whether clipping was applied
    pub was_clipped: bool,
    /// Original gradient norm (before clipping)
    pub original_norm: Option<F>,
    /// Clipped gradient norm (after clipping)
    pub clipped_norm: Option<F>,
    /// Clipping factor applied
    pub clipping_factor: Option<F>,
    /// Number of gradients that were clipped
    pub num_clipped: usize,
    /// Total number of gradients processed
    pub total_gradients: usize,
}
57
58impl<F: Float> Default for ClippingStats<F> {
59    fn default() -> Self {
60        Self {
61            was_clipped: false,
62            original_norm: None,
63            clipped_norm: None,
64            clipping_factor: None,
65            num_clipped: 0,
66            total_gradients: 0,
67        }
68    }
69}
70
/// Clip gradients by value
///
/// Clips each element of each gradient tensor to be within the range [min_value, max_value].
/// This is the simplest form of gradient clipping.
///
/// # Example
/// ```
/// use scirs2_autograd as ag;
/// use scirs2_autograd::gradient_clipping::{ClipByValue, GradientClipper};
/// use scirs2_autograd::tensor_ops::convert_to_tensor;
///
/// let mut env = ag::VariableEnvironment::new();
/// let mut rng = ag::ndarray_ext::ArrayRng::<f32>::default();
///
/// env.run(|g| {
///     // Create some example gradients
///     let grad1 = convert_to_tensor(rng.standard_normal(&[2, 2]), g);
///     let grad2 = convert_to_tensor(rng.standard_normal(&[3]), g);
///     let gradients = vec![grad1, grad2];
///
///     let mut clipper = ClipByValue::new(-1.0f32, 1.0f32);
///     let _clipped_gradients = clipper.clip_gradients(&gradients);
/// });
/// ```
pub struct ClipByValue<F: Float> {
    /// Lower bound applied to every gradient element
    pub min_value: F,
    /// Upper bound applied to every gradient element
    pub max_value: F,
    /// Whether the last `clip_gradients` call applied clipping.
    /// NOTE(review): currently always reset to `false` — the implementation
    /// does not evaluate tensors to detect actual clipping.
    last_clipped: std::cell::Cell<bool>,
}
100
101impl<F: Float> ClipByValue<F> {
102    /// Create a new value-based gradient clipper
103    ///
104    /// # Arguments
105    /// * `min_value` - Minimum allowed gradient value
106    /// * `max_value` - Maximum allowed gradient value
107    ///
108    /// # Panics
109    /// Panics if `min_value` >= `max_value`
110    pub fn new(min_value: F, max_value: F) -> Self {
111        assert!(
112            min_value < max_value,
113            "min_value must be less than max_value"
114        );
115
116        Self {
117            min_value,
118            max_value,
119            last_clipped: std::cell::Cell::new(false),
120        }
121    }
122
123    /// Create a symmetric value clipper
124    ///
125    /// Creates a clipper that clips values to [-max_abs_value, max_abs_value].
126    ///
127    /// # Arguments
128    /// * `max_abs_value` - Maximum absolute value allowed
129    pub fn symmetric(max_abs_value: F) -> Self {
130        Self::new(-max_abs_value, max_abs_value)
131    }
132}
133
134impl<F: Float> GradientClipper<F> for ClipByValue<F> {
135    fn clip_gradients<'g>(&mut self, gradients: &[Tensor<'g, F>]) -> Vec<Tensor<'g, F>> {
136        let any_clipped = false;
137
138        let clipped: Vec<_> = gradients
139            .iter()
140            .map(|grad| {
141                let clipped_grad = tensor_ops::clip(*grad, self.min_value, self.max_value);
142                // Note: In a real implementation, we'd want to check if actual clipping occurred
143                // For now, we assume clipping may have occurred if the operation was performed
144                clipped_grad
145            })
146            .collect();
147
148        self.last_clipped.set(any_clipped);
149        clipped
150    }
151
152    fn was_clipped(&self) -> bool {
153        self.last_clipped.get()
154    }
155}
156
/// Clip gradients by norm
///
/// Clips the norm of each individual gradient tensor. If the L2 norm of a gradient
/// exceeds the maximum norm, the gradient is scaled down proportionally.
///
/// For a gradient g with norm ||g||, if ||g|| > max_norm, then:
/// g_clipped = g * (max_norm / ||g||)
///
/// # Example
/// ```
/// use scirs2_autograd as ag;
/// use scirs2_autograd::gradient_clipping::{ClipByNorm, GradientClipper};
/// use scirs2_autograd::tensor_ops::convert_to_tensor;
///
/// let mut env = ag::VariableEnvironment::new();
/// let mut rng = ag::ndarray_ext::ArrayRng::<f32>::default();
///
/// env.run(|g| {
///     // Create some example gradients
///     let grad1 = convert_to_tensor(rng.standard_normal(&[2, 2]), g);
///     let grad2 = convert_to_tensor(rng.standard_normal(&[3]), g);
///     let gradients = vec![grad1, grad2];
///
///     let mut clipper = ClipByNorm::new(1.0f32);
///     let _clipped_gradients = clipper.clip_gradients(&gradients);
/// });
/// ```
pub struct ClipByNorm<F: Float> {
    /// Maximum allowed L2 norm for each individual gradient
    pub max_norm: F,
    /// Whether the last `clip_gradients` call applied clipping.
    /// NOTE(review): currently always `false` — actual clipping detection
    /// would require evaluating the gradient norms.
    last_clipped: std::cell::Cell<bool>,
    /// Stats recorded by the last `clip_gradients` call
    last_stats: std::cell::RefCell<ClippingStats<F>>,
}
189
190impl<F: Float> ClipByNorm<F> {
191    /// Create a new norm-based gradient clipper
192    ///
193    /// # Arguments
194    /// * `max_norm` - Maximum allowed L2 norm for gradients
195    ///
196    /// # Panics
197    /// Panics if `max_norm` is not positive
198    pub fn new(max_norm: F) -> Self {
199        assert!(max_norm > F::zero(), "max_norm must be positive");
200
201        Self {
202            max_norm,
203            last_clipped: std::cell::Cell::new(false),
204            last_stats: std::cell::RefCell::new(ClippingStats::default()),
205        }
206    }
207}
208
209impl<F: Float> GradientClipper<F> for ClipByNorm<F> {
210    fn clip_gradients<'g>(&mut self, gradients: &[Tensor<'g, F>]) -> Vec<Tensor<'g, F>> {
211        let any_clipped = false;
212        let num_clipped = 0;
213
214        let clipped: Vec<_> = gradients
215            .iter()
216            .map(|grad| {
217                // Compute the Frobenius norm of the gradient (equivalent to L2 norm for vectors)
218                let grad_norm = tensor_ops::frobenius_norm(grad);
219
220                // Create scalar tensors for comparison
221                let max_norm_tensor = tensor_ops::scalar(self.max_norm, grad.graph());
222                let one_tensor = tensor_ops::scalar(F::one(), grad.graph());
223
224                // Compute clipping factor: min(1.0, max_norm / grad_norm)
225                let ratio = max_norm_tensor / grad_norm;
226                let clipping_factor = tensor_ops::minimum(one_tensor, ratio);
227
228                // Note: In a full implementation, we'd track whether clipping actually occurred
229                // For simplicity, we assume clipping may have occurred
230                (*grad) * clipping_factor
231            })
232            .collect();
233
234        self.last_clipped.set(any_clipped);
235
236        // Update stats
237        let mut stats = self.last_stats.borrow_mut();
238        stats.was_clipped = any_clipped;
239        stats.num_clipped = num_clipped;
240        stats.total_gradients = gradients.len();
241
242        clipped
243    }
244
245    fn was_clipped(&self) -> bool {
246        self.last_clipped.get()
247    }
248
249    fn get_clipping_stats(&self) -> ClippingStats<F> {
250        self.last_stats.borrow().clone()
251    }
252}
253
/// Clip gradients by global norm
///
/// Clips all gradients jointly based on their global norm. The global norm is
/// computed as the L2 norm of the concatenation of all gradient vectors.
///
/// If the global norm exceeds max_norm, all gradients are scaled by the same factor:
/// scaling_factor = max_norm / global_norm
///
/// This method preserves the relative magnitudes between different gradients
/// while ensuring the overall gradient update is not too large.
///
/// # Example
/// ```
/// use scirs2_autograd as ag;
/// use scirs2_autograd::gradient_clipping::{ClipByGlobalNorm, GradientClipper};
/// use scirs2_autograd::tensor_ops::convert_to_tensor;
///
/// let mut env = ag::VariableEnvironment::new();
/// let mut rng = ag::ndarray_ext::ArrayRng::<f32>::default();
///
/// env.run(|g| {
///     // Create some example gradients
///     let grad1 = convert_to_tensor(rng.standard_normal(&[2, 2]), g);
///     let grad2 = convert_to_tensor(rng.standard_normal(&[3]), g);
///     let gradients = vec![grad1, grad2];
///
///     let mut clipper = ClipByGlobalNorm::new(1.0f32);
///     let _clipped_gradients = clipper.clip_gradients(&gradients);
/// });
/// ```
pub struct ClipByGlobalNorm<F: Float> {
    /// Maximum allowed global (joint) L2 norm for all gradients combined
    pub max_norm: F,
    /// Whether the last `clip_gradients` call applied clipping.
    /// NOTE(review): currently always `false` — detection would require
    /// evaluating the global norm.
    last_clipped: std::cell::Cell<bool>,
    /// Stats recorded by the last `clip_gradients` call
    last_stats: std::cell::RefCell<ClippingStats<F>>,
}
289
290impl<F: Float> ClipByGlobalNorm<F> {
291    /// Create a new global norm-based gradient clipper
292    ///
293    /// # Arguments
294    /// * `max_norm` - Maximum allowed global norm for all gradients combined
295    ///
296    /// # Panics
297    /// Panics if `max_norm` is not positive
298    pub fn new(max_norm: F) -> Self {
299        assert!(max_norm > F::zero(), "max_norm must be positive");
300
301        Self {
302            max_norm,
303            last_clipped: std::cell::Cell::new(false),
304            last_stats: std::cell::RefCell::new(ClippingStats::default()),
305        }
306    }
307}
308
309impl<F: Float> GradientClipper<F> for ClipByGlobalNorm<F> {
310    fn clip_gradients<'g>(&mut self, gradients: &[Tensor<'g, F>]) -> Vec<Tensor<'g, F>> {
311        if gradients.is_empty() {
312            return Vec::new();
313        }
314
315        let g = gradients[0].graph();
316
317        // Compute global norm: sqrt(sum(norm(grad_i)^2))
318        let squared_norms: Vec<_> = gradients
319            .iter()
320            .map(|grad| {
321                let norm = tensor_ops::frobenius_norm(grad);
322                tensor_ops::square(norm)
323            })
324            .collect();
325
326        let global_norm_squared = tensor_ops::add_n(&squared_norms);
327        let global_norm = tensor_ops::sqrt(global_norm_squared);
328
329        // Compute clipping factor
330        let max_norm_tensor = tensor_ops::scalar(self.max_norm, g);
331        let one_tensor = tensor_ops::scalar(F::one(), g);
332        let ratio = max_norm_tensor / global_norm;
333        let clipping_factor = tensor_ops::minimum(one_tensor, ratio);
334
335        // Apply the same clipping factor to all gradients
336        let clipped: Vec<_> = gradients
337            .iter()
338            .map(|grad| (*grad) * clipping_factor)
339            .collect();
340
341        // Note: In a full implementation, we'd evaluate global_norm and check if clipping occurred
342        let was_clipped = false; // Placeholder - would need evaluation to determine
343
344        self.last_clipped.set(was_clipped);
345
346        // Update stats
347        let mut stats = self.last_stats.borrow_mut();
348        stats.was_clipped = was_clipped;
349        stats.total_gradients = gradients.len();
350        stats.num_clipped = if was_clipped { gradients.len() } else { 0 };
351
352        clipped
353    }
354
355    fn was_clipped(&self) -> bool {
356        self.last_clipped.get()
357    }
358
359    fn get_clipping_stats(&self) -> ClippingStats<F> {
360        self.last_stats.borrow().clone()
361    }
362}
363
/// Adaptive gradient clipper
///
/// Adjusts the clipping threshold based on the history of gradient norms.
/// This can help automatically tune the clipping threshold during training.
/// NOTE(review): automatic adaptation is not implemented yet; the threshold
/// only changes via `set_threshold`.
pub struct AdaptiveClipByNorm<F: Float> {
    /// Delegate that performs the actual per-gradient norm clipping
    base_clipper: ClipByNorm<F>,
    /// Intended adaptation speed in [0, 1]; currently unused (placeholder)
    #[allow(dead_code)]
    adaptation_rate: F,
    /// The threshold currently in effect (copied into `base_clipper` per call)
    current_threshold: std::cell::Cell<F>,
}
374
375impl<F: Float> AdaptiveClipByNorm<F> {
376    /// Create a new adaptive gradient clipper
377    ///
378    /// # Arguments
379    /// * `initial_max_norm` - Initial maximum norm threshold
380    /// * `adaptation_rate` - Rate at which to adapt the threshold (0.0 to 1.0)
381    pub fn new(initial_max_norm: F, adaptation_rate: F) -> Self {
382        assert!(
383            adaptation_rate >= F::zero() && adaptation_rate <= F::one(),
384            "adaptation_rate must be between 0.0 and 1.0"
385        );
386
387        Self {
388            base_clipper: ClipByNorm::new(initial_max_norm),
389            adaptation_rate,
390            current_threshold: std::cell::Cell::new(initial_max_norm),
391        }
392    }
393
394    /// Get the current adaptive threshold
395    pub fn current_threshold(&self) -> F {
396        self.current_threshold.get()
397    }
398
399    /// Manually update the threshold (for external adaptation logic)
400    pub fn set_threshold(&self, new_threshold: F) {
401        assert!(new_threshold > F::zero(), "threshold must be positive");
402        self.current_threshold.set(new_threshold);
403    }
404}
405
406impl<F: Float> GradientClipper<F> for AdaptiveClipByNorm<F> {
407    fn clip_gradients<'g>(&mut self, gradients: &[Tensor<'g, F>]) -> Vec<Tensor<'g, F>> {
408        // Update the base clipper's threshold
409        let current_threshold = self.current_threshold.get();
410        self.base_clipper.max_norm = current_threshold;
411
412        // Apply clipping with current threshold
413        let result = self.base_clipper.clip_gradients(gradients);
414
415        // Note: In a full implementation, we'd compute actual gradient norms
416        // and adapt the threshold based on recent history
417        // For now, this is a placeholder for the adaptation logic
418
419        result
420    }
421
422    fn was_clipped(&self) -> bool {
423        self.base_clipper.was_clipped()
424    }
425
426    fn get_clipping_stats(&self) -> ClippingStats<F> {
427        self.base_clipper.get_clipping_stats()
428    }
429}
430
431/// Convenience functions for gradient clipping
432impl<F: Float> Tensor<'_, F> {
433    /// Clip this tensor's values to a range
434    ///
435    /// # Arguments
436    /// * `min_value` - Minimum allowed value
437    /// * `max_value` - Maximum allowed value
438    pub fn clip_values(self, min_value: F, max_value: F) -> Self {
439        tensor_ops::clip(self, min_value, max_value)
440    }
441
442    /// Clip this tensor's norm to a maximum value
443    ///
444    /// # Arguments
445    /// * `max_norm` - Maximum allowed norm
446    pub fn clip_norm(self, max_norm: F) -> Self {
447        let norm = tensor_ops::frobenius_norm(self);
448        let max_norm_tensor = tensor_ops::scalar(max_norm, self.graph());
449        let one_tensor = tensor_ops::scalar(F::one(), self.graph());
450        let ratio = max_norm_tensor / norm;
451        let clipping_factor = tensor_ops::minimum(one_tensor, ratio);
452        self * clipping_factor
453    }
454}
455
456/// Common gradient clipping presets
457pub mod presets {
458    use super::*;
459
460    /// Create a conservative gradient clipper for fine-tuning
461    pub fn conservative<F: Float>() -> ClipByGlobalNorm<F> {
462        ClipByGlobalNorm::new(F::from(0.5).expect("Failed to convert constant to float"))
463    }
464
465    /// Create a standard gradient clipper for general training
466    pub fn standard<F: Float>() -> ClipByGlobalNorm<F> {
467        ClipByGlobalNorm::new(F::from(1.0).expect("Failed to convert constant to float"))
468    }
469
470    /// Create an aggressive gradient clipper for unstable training
471    pub fn aggressive<F: Float>() -> ClipByGlobalNorm<F> {
472        ClipByGlobalNorm::new(F::from(0.1).expect("Failed to convert constant to float"))
473    }
474
475    /// Create a value-based clipper for preventing extreme gradients
476    pub fn extreme_prevention<F: Float>() -> ClipByValue<F> {
477        ClipByValue::symmetric(F::from(10.0).expect("Failed to convert constant to float"))
478    }
479}
480
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_clip_by_value_creation() {
        // Explicit-range constructor stores the bounds verbatim.
        let ranged = ClipByValue::new(-1.0f32, 1.0f32);
        assert_eq!(ranged.min_value, -1.0);
        assert_eq!(ranged.max_value, 1.0);

        // Symmetric constructor mirrors the bound around zero.
        let mirrored = ClipByValue::symmetric(0.5f32);
        assert_eq!(mirrored.min_value, -0.5);
        assert_eq!(mirrored.max_value, 0.5);
    }

    #[test]
    fn test_clip_by_norm_creation() {
        assert_eq!(ClipByNorm::new(1.0f32).max_norm, 1.0);
    }

    #[test]
    fn test_clip_by_global_norm_creation() {
        assert_eq!(ClipByGlobalNorm::new(1.0f32).max_norm, 1.0);
    }

    #[test]
    fn test_adaptive_clipper() {
        let adaptive = AdaptiveClipByNorm::new(1.0f32, 0.1);
        assert_eq!(adaptive.current_threshold(), 1.0);

        // set_threshold takes &self (interior mutability via Cell).
        adaptive.set_threshold(0.5);
        assert_eq!(adaptive.current_threshold(), 0.5);
    }

    #[test]
    fn test_clipping_stats_default() {
        // Default stats report nothing processed and nothing clipped.
        let empty = ClippingStats::<f32>::default();
        assert!(!empty.was_clipped);
        assert_eq!(empty.num_clipped, 0);
        assert_eq!(empty.total_gradients, 0);
    }

    #[test]
    fn test_presets() {
        // Each preset must construct without panicking.
        let _conservative = presets::conservative::<f32>();
        let _standard = presets::standard::<f32>();
        let _aggressive = presets::aggressive::<f32>();
        let _extreme = presets::extreme_prevention::<f32>();
    }

    #[test]
    #[should_panic(expected = "min_value must be less than max_value")]
    fn test_clip_by_value_invalid_range() {
        // An inverted range must be rejected.
        let _ = ClipByValue::new(1.0f32, -1.0f32);
    }

    #[test]
    #[should_panic(expected = "max_norm must be positive")]
    fn test_clip_by_norm_negative_norm() {
        let _ = ClipByNorm::new(-1.0f32);
    }

    #[test]
    #[should_panic(expected = "adaptation_rate must be between 0.0 and 1.0")]
    fn test_adaptive_clipper_invalid_rate() {
        let _ = AdaptiveClipByNorm::new(1.0f32, 2.0);
    }
}