// scirs2_autograd/gradient_clipping.rs
1//! Gradient clipping utilities
2//!
3//! Gradient clipping is a technique used to prevent the exploding gradient problem
4//! in deep learning by constraining the gradients to a reasonable range or magnitude.
5//! This module provides various gradient clipping strategies.
6
7use crate::tensor::Tensor;
8use crate::tensor_ops;
9use crate::Float;
10
/// Trait for gradient clipping strategies
///
/// Gradient clipping modifies gradients to prevent exploding gradients while
/// preserving the direction of optimization.
pub trait GradientClipper<F: Float> {
    /// Apply gradient clipping to a list of gradients
    ///
    /// # Arguments
    /// * `gradients` - Slice of gradient tensors to clip
    ///
    /// # Returns
    /// Vector of clipped gradient tensors
    fn clip_gradients<'g>(&mut self, gradients: &[Tensor<'g, F>]) -> Vec<Tensor<'g, F>>;

    /// Check if clipping was applied in the last call to `clip_gradients`
    ///
    /// This can be useful for monitoring whether gradients are being clipped.
    fn was_clipped(&self) -> bool {
        // Default implementation - individual clippers can override
        false
    }

    /// Get statistics about the last clipping operation
    ///
    /// Returns information that can be used for logging or monitoring.
    fn get_clipping_stats(&self) -> ClippingStats<F> {
        // Default: an empty record; stateful clippers override this.
        ClippingStats::default()
    }
}
40
/// Statistics about gradient clipping operations
///
/// The `Option` fields are only populated by clippers that actually track
/// the corresponding quantity; `None` means "not measured".
#[derive(Debug, Clone)]
pub struct ClippingStats<F: Float> {
    /// Whether clipping was applied
    pub was_clipped: bool,
    /// Original gradient norm (before clipping)
    pub original_norm: Option<F>,
    /// Clipped gradient norm (after clipping)
    pub clipped_norm: Option<F>,
    /// Clipping factor applied
    pub clipping_factor: Option<F>,
    /// Number of gradients that were clipped
    pub num_clipped: usize,
    /// Total number of gradients processed
    pub total_gradients: usize,
}
57
58impl<F: Float> Default for ClippingStats<F> {
59 fn default() -> Self {
60 Self {
61 was_clipped: false,
62 original_norm: None,
63 clipped_norm: None,
64 clipping_factor: None,
65 num_clipped: 0,
66 total_gradients: 0,
67 }
68 }
69}
70
/// Clip gradients by value
///
/// Clips each element of each gradient tensor to be within the range [min_value, max_value].
/// This is the simplest form of gradient clipping.
///
/// # Example
/// ```
/// use scirs2_autograd as ag;
/// use scirs2_autograd::gradient_clipping::{ClipByValue, GradientClipper};
/// use scirs2_autograd::tensor_ops::convert_to_tensor;
///
/// let mut env = ag::VariableEnvironment::new();
/// let mut rng = ag::ndarray_ext::ArrayRng::<f32>::default();
///
/// env.run(|g| {
///     // Create some example gradients
///     let grad1 = convert_to_tensor(rng.standard_normal(&[2, 2]), g);
///     let grad2 = convert_to_tensor(rng.standard_normal(&[3]), g);
///     let gradients = vec![grad1, grad2];
///
///     let mut clipper = ClipByValue::new(-1.0f32, 1.0f32);
///     let _clipped_gradients = clipper.clip_gradients(&gradients);
/// });
/// ```
pub struct ClipByValue<F: Float> {
    /// Minimum allowed gradient value (inclusive lower bound of the clip range)
    pub min_value: F,
    /// Maximum allowed gradient value (inclusive upper bound of the clip range)
    pub max_value: F,
    // Interior mutability so `was_clipped` can be read through `&self`.
    last_clipped: std::cell::Cell<bool>,
}
100
101impl<F: Float> ClipByValue<F> {
102 /// Create a new value-based gradient clipper
103 ///
104 /// # Arguments
105 /// * `min_value` - Minimum allowed gradient value
106 /// * `max_value` - Maximum allowed gradient value
107 ///
108 /// # Panics
109 /// Panics if `min_value` >= `max_value`
110 pub fn new(min_value: F, max_value: F) -> Self {
111 assert!(
112 min_value < max_value,
113 "min_value must be less than max_value"
114 );
115
116 Self {
117 min_value,
118 max_value,
119 last_clipped: std::cell::Cell::new(false),
120 }
121 }
122
123 /// Create a symmetric value clipper
124 ///
125 /// Creates a clipper that clips values to [-max_abs_value, max_abs_value].
126 ///
127 /// # Arguments
128 /// * `max_abs_value` - Maximum absolute value allowed
129 pub fn symmetric(max_abs_value: F) -> Self {
130 Self::new(-max_abs_value, max_abs_value)
131 }
132}
133
134impl<F: Float> GradientClipper<F> for ClipByValue<F> {
135 fn clip_gradients<'g>(&mut self, gradients: &[Tensor<'g, F>]) -> Vec<Tensor<'g, F>> {
136 let any_clipped = false;
137
138 let clipped: Vec<_> = gradients
139 .iter()
140 .map(|grad| {
141 let clipped_grad = tensor_ops::clip(*grad, self.min_value, self.max_value);
142 // Note: In a real implementation, we'd want to check if actual clipping occurred
143 // For now, we assume clipping may have occurred if the operation was performed
144 clipped_grad
145 })
146 .collect();
147
148 self.last_clipped.set(any_clipped);
149 clipped
150 }
151
152 fn was_clipped(&self) -> bool {
153 self.last_clipped.get()
154 }
155}
156
/// Clip gradients by norm
///
/// Clips the norm of each individual gradient tensor. If the L2 norm of a gradient
/// exceeds the maximum norm, the gradient is scaled down proportionally.
///
/// For a gradient g with norm ||g||, if ||g|| > max_norm, then:
/// g_clipped = g * (max_norm / ||g||)
///
/// # Example
/// ```
/// use scirs2_autograd as ag;
/// use scirs2_autograd::gradient_clipping::{ClipByNorm, GradientClipper};
/// use scirs2_autograd::tensor_ops::convert_to_tensor;
///
/// let mut env = ag::VariableEnvironment::new();
/// let mut rng = ag::ndarray_ext::ArrayRng::<f32>::default();
///
/// env.run(|g| {
///     // Create some example gradients
///     let grad1 = convert_to_tensor(rng.standard_normal(&[2, 2]), g);
///     let grad2 = convert_to_tensor(rng.standard_normal(&[3]), g);
///     let gradients = vec![grad1, grad2];
///
///     let mut clipper = ClipByNorm::new(1.0f32);
///     let _clipped_gradients = clipper.clip_gradients(&gradients);
/// });
/// ```
pub struct ClipByNorm<F: Float> {
    /// Maximum allowed L2 norm for each gradient tensor
    pub max_norm: F,
    // Interior mutability so `was_clipped` can be read through `&self`.
    last_clipped: std::cell::Cell<bool>,
    // Stats from the most recent `clip_gradients` call.
    last_stats: std::cell::RefCell<ClippingStats<F>>,
}
189
190impl<F: Float> ClipByNorm<F> {
191 /// Create a new norm-based gradient clipper
192 ///
193 /// # Arguments
194 /// * `max_norm` - Maximum allowed L2 norm for gradients
195 ///
196 /// # Panics
197 /// Panics if `max_norm` is not positive
198 pub fn new(max_norm: F) -> Self {
199 assert!(max_norm > F::zero(), "max_norm must be positive");
200
201 Self {
202 max_norm,
203 last_clipped: std::cell::Cell::new(false),
204 last_stats: std::cell::RefCell::new(ClippingStats::default()),
205 }
206 }
207}
208
209impl<F: Float> GradientClipper<F> for ClipByNorm<F> {
210 fn clip_gradients<'g>(&mut self, gradients: &[Tensor<'g, F>]) -> Vec<Tensor<'g, F>> {
211 let any_clipped = false;
212 let num_clipped = 0;
213
214 let clipped: Vec<_> = gradients
215 .iter()
216 .map(|grad| {
217 // Compute the Frobenius norm of the gradient (equivalent to L2 norm for vectors)
218 let grad_norm = tensor_ops::frobenius_norm(grad);
219
220 // Create scalar tensors for comparison
221 let max_norm_tensor = tensor_ops::scalar(self.max_norm, grad.graph());
222 let one_tensor = tensor_ops::scalar(F::one(), grad.graph());
223
224 // Compute clipping factor: min(1.0, max_norm / grad_norm)
225 let ratio = max_norm_tensor / grad_norm;
226 let clipping_factor = tensor_ops::minimum(one_tensor, ratio);
227
228 // Note: In a full implementation, we'd track whether clipping actually occurred
229 // For simplicity, we assume clipping may have occurred
230 (*grad) * clipping_factor
231 })
232 .collect();
233
234 self.last_clipped.set(any_clipped);
235
236 // Update stats
237 let mut stats = self.last_stats.borrow_mut();
238 stats.was_clipped = any_clipped;
239 stats.num_clipped = num_clipped;
240 stats.total_gradients = gradients.len();
241
242 clipped
243 }
244
245 fn was_clipped(&self) -> bool {
246 self.last_clipped.get()
247 }
248
249 fn get_clipping_stats(&self) -> ClippingStats<F> {
250 self.last_stats.borrow().clone()
251 }
252}
253
/// Clip gradients by global norm
///
/// Clips all gradients jointly based on their global norm. The global norm is
/// computed as the L2 norm of the concatenation of all gradient vectors.
///
/// If the global norm exceeds max_norm, all gradients are scaled by the same factor:
/// scaling_factor = max_norm / global_norm
///
/// This method preserves the relative magnitudes between different gradients
/// while ensuring the overall gradient update is not too large.
///
/// # Example
/// ```
/// use scirs2_autograd as ag;
/// use scirs2_autograd::gradient_clipping::{ClipByGlobalNorm, GradientClipper};
/// use scirs2_autograd::tensor_ops::convert_to_tensor;
///
/// let mut env = ag::VariableEnvironment::new();
/// let mut rng = ag::ndarray_ext::ArrayRng::<f32>::default();
///
/// env.run(|g| {
///     // Create some example gradients
///     let grad1 = convert_to_tensor(rng.standard_normal(&[2, 2]), g);
///     let grad2 = convert_to_tensor(rng.standard_normal(&[3]), g);
///     let gradients = vec![grad1, grad2];
///
///     let mut clipper = ClipByGlobalNorm::new(1.0f32);
///     let _clipped_gradients = clipper.clip_gradients(&gradients);
/// });
/// ```
pub struct ClipByGlobalNorm<F: Float> {
    /// Maximum allowed global L2 norm across all gradients combined
    pub max_norm: F,
    // Interior mutability so `was_clipped` can be read through `&self`.
    last_clipped: std::cell::Cell<bool>,
    // Stats from the most recent `clip_gradients` call.
    last_stats: std::cell::RefCell<ClippingStats<F>>,
}
289
290impl<F: Float> ClipByGlobalNorm<F> {
291 /// Create a new global norm-based gradient clipper
292 ///
293 /// # Arguments
294 /// * `max_norm` - Maximum allowed global norm for all gradients combined
295 ///
296 /// # Panics
297 /// Panics if `max_norm` is not positive
298 pub fn new(max_norm: F) -> Self {
299 assert!(max_norm > F::zero(), "max_norm must be positive");
300
301 Self {
302 max_norm,
303 last_clipped: std::cell::Cell::new(false),
304 last_stats: std::cell::RefCell::new(ClippingStats::default()),
305 }
306 }
307}
308
309impl<F: Float> GradientClipper<F> for ClipByGlobalNorm<F> {
310 fn clip_gradients<'g>(&mut self, gradients: &[Tensor<'g, F>]) -> Vec<Tensor<'g, F>> {
311 if gradients.is_empty() {
312 return Vec::new();
313 }
314
315 let g = gradients[0].graph();
316
317 // Compute global norm: sqrt(sum(norm(grad_i)^2))
318 let squared_norms: Vec<_> = gradients
319 .iter()
320 .map(|grad| {
321 let norm = tensor_ops::frobenius_norm(grad);
322 tensor_ops::square(norm)
323 })
324 .collect();
325
326 let global_norm_squared = tensor_ops::add_n(&squared_norms);
327 let global_norm = tensor_ops::sqrt(global_norm_squared);
328
329 // Compute clipping factor
330 let max_norm_tensor = tensor_ops::scalar(self.max_norm, g);
331 let one_tensor = tensor_ops::scalar(F::one(), g);
332 let ratio = max_norm_tensor / global_norm;
333 let clipping_factor = tensor_ops::minimum(one_tensor, ratio);
334
335 // Apply the same clipping factor to all gradients
336 let clipped: Vec<_> = gradients
337 .iter()
338 .map(|grad| (*grad) * clipping_factor)
339 .collect();
340
341 // Note: In a full implementation, we'd evaluate global_norm and check if clipping occurred
342 let was_clipped = false; // Placeholder - would need evaluation to determine
343
344 self.last_clipped.set(was_clipped);
345
346 // Update stats
347 let mut stats = self.last_stats.borrow_mut();
348 stats.was_clipped = was_clipped;
349 stats.total_gradients = gradients.len();
350 stats.num_clipped = if was_clipped { gradients.len() } else { 0 };
351
352 clipped
353 }
354
355 fn was_clipped(&self) -> bool {
356 self.last_clipped.get()
357 }
358
359 fn get_clipping_stats(&self) -> ClippingStats<F> {
360 self.last_stats.borrow().clone()
361 }
362}
363
/// Adaptive gradient clipper
///
/// Adjusts the clipping threshold based on the history of gradient norms.
/// This can help automatically tune the clipping threshold during training.
pub struct AdaptiveClipByNorm<F: Float> {
    // Performs the actual per-gradient norm clipping at the current threshold.
    base_clipper: ClipByNorm<F>,
    // Reserved for future automatic adaptation logic (currently unused).
    #[allow(dead_code)]
    adaptation_rate: F,
    // Interior mutability so `set_threshold` can be called through `&self`.
    current_threshold: std::cell::Cell<F>,
}
374
375impl<F: Float> AdaptiveClipByNorm<F> {
376 /// Create a new adaptive gradient clipper
377 ///
378 /// # Arguments
379 /// * `initial_max_norm` - Initial maximum norm threshold
380 /// * `adaptation_rate` - Rate at which to adapt the threshold (0.0 to 1.0)
381 pub fn new(initial_max_norm: F, adaptation_rate: F) -> Self {
382 assert!(
383 adaptation_rate >= F::zero() && adaptation_rate <= F::one(),
384 "adaptation_rate must be between 0.0 and 1.0"
385 );
386
387 Self {
388 base_clipper: ClipByNorm::new(initial_max_norm),
389 adaptation_rate,
390 current_threshold: std::cell::Cell::new(initial_max_norm),
391 }
392 }
393
394 /// Get the current adaptive threshold
395 pub fn current_threshold(&self) -> F {
396 self.current_threshold.get()
397 }
398
399 /// Manually update the threshold (for external adaptation logic)
400 pub fn set_threshold(&self, new_threshold: F) {
401 assert!(new_threshold > F::zero(), "threshold must be positive");
402 self.current_threshold.set(new_threshold);
403 }
404}
405
406impl<F: Float> GradientClipper<F> for AdaptiveClipByNorm<F> {
407 fn clip_gradients<'g>(&mut self, gradients: &[Tensor<'g, F>]) -> Vec<Tensor<'g, F>> {
408 // Update the base clipper's threshold
409 let current_threshold = self.current_threshold.get();
410 self.base_clipper.max_norm = current_threshold;
411
412 // Apply clipping with current threshold
413 let result = self.base_clipper.clip_gradients(gradients);
414
415 // Note: In a full implementation, we'd compute actual gradient norms
416 // and adapt the threshold based on recent history
417 // For now, this is a placeholder for the adaptation logic
418
419 result
420 }
421
422 fn was_clipped(&self) -> bool {
423 self.base_clipper.was_clipped()
424 }
425
426 fn get_clipping_stats(&self) -> ClippingStats<F> {
427 self.base_clipper.get_clipping_stats()
428 }
429}
430
431/// Convenience functions for gradient clipping
432impl<F: Float> Tensor<'_, F> {
433 /// Clip this tensor's values to a range
434 ///
435 /// # Arguments
436 /// * `min_value` - Minimum allowed value
437 /// * `max_value` - Maximum allowed value
438 pub fn clip_values(self, min_value: F, max_value: F) -> Self {
439 tensor_ops::clip(self, min_value, max_value)
440 }
441
442 /// Clip this tensor's norm to a maximum value
443 ///
444 /// # Arguments
445 /// * `max_norm` - Maximum allowed norm
446 pub fn clip_norm(self, max_norm: F) -> Self {
447 let norm = tensor_ops::frobenius_norm(self);
448 let max_norm_tensor = tensor_ops::scalar(max_norm, self.graph());
449 let one_tensor = tensor_ops::scalar(F::one(), self.graph());
450 let ratio = max_norm_tensor / norm;
451 let clipping_factor = tensor_ops::minimum(one_tensor, ratio);
452 self * clipping_factor
453 }
454}
455
456/// Common gradient clipping presets
457pub mod presets {
458 use super::*;
459
460 /// Create a conservative gradient clipper for fine-tuning
461 pub fn conservative<F: Float>() -> ClipByGlobalNorm<F> {
462 ClipByGlobalNorm::new(F::from(0.5).expect("Failed to convert constant to float"))
463 }
464
465 /// Create a standard gradient clipper for general training
466 pub fn standard<F: Float>() -> ClipByGlobalNorm<F> {
467 ClipByGlobalNorm::new(F::from(1.0).expect("Failed to convert constant to float"))
468 }
469
470 /// Create an aggressive gradient clipper for unstable training
471 pub fn aggressive<F: Float>() -> ClipByGlobalNorm<F> {
472 ClipByGlobalNorm::new(F::from(0.1).expect("Failed to convert constant to float"))
473 }
474
475 /// Create a value-based clipper for preventing extreme gradients
476 pub fn extreme_prevention<F: Float>() -> ClipByValue<F> {
477 ClipByValue::symmetric(F::from(10.0).expect("Failed to convert constant to float"))
478 }
479}
480
#[cfg(test)]
mod tests {
    use super::*;

    // Constructor stores both bounds; `symmetric` mirrors the bound around zero.
    #[test]
    fn test_clip_by_value_creation() {
        let clipper = ClipByValue::new(-1.0f32, 1.0f32);
        assert_eq!(clipper.min_value, -1.0);
        assert_eq!(clipper.max_value, 1.0);

        let symmetric = ClipByValue::symmetric(0.5f32);
        assert_eq!(symmetric.min_value, -0.5);
        assert_eq!(symmetric.max_value, 0.5);
    }

    #[test]
    fn test_clip_by_norm_creation() {
        let clipper = ClipByNorm::new(1.0f32);
        assert_eq!(clipper.max_norm, 1.0);
    }

    #[test]
    fn test_clip_by_global_norm_creation() {
        let clipper = ClipByGlobalNorm::new(1.0f32);
        assert_eq!(clipper.max_norm, 1.0);
    }

    // `set_threshold` takes &self (Cell-backed), so no `mut` binding is needed.
    #[test]
    fn test_adaptive_clipper() {
        let clipper = AdaptiveClipByNorm::new(1.0f32, 0.1);
        assert_eq!(clipper.current_threshold(), 1.0);

        clipper.set_threshold(0.5);
        assert_eq!(clipper.current_threshold(), 0.5);
    }

    #[test]
    fn test_clipping_stats_default() {
        let stats = ClippingStats::<f32>::default();
        assert!(!stats.was_clipped);
        assert_eq!(stats.num_clipped, 0);
        assert_eq!(stats.total_gradients, 0);
    }

    // Smoke test: all presets construct without panicking.
    #[test]
    fn test_presets() {
        let _conservative = presets::conservative::<f32>();
        let _standard = presets::standard::<f32>();
        let _aggressive = presets::aggressive::<f32>();
        let _extreme = presets::extreme_prevention::<f32>();
    }

    // The expected strings below must match the assert messages exactly.
    #[test]
    #[should_panic(expected = "min_value must be less than max_value")]
    fn test_clip_by_value_invalid_range() {
        ClipByValue::new(1.0f32, -1.0f32);
    }

    #[test]
    #[should_panic(expected = "max_norm must be positive")]
    fn test_clip_by_norm_negative_norm() {
        ClipByNorm::new(-1.0f32);
    }

    #[test]
    #[should_panic(expected = "adaptation_rate must be between 0.0 and 1.0")]
    fn test_adaptive_clipper_invalid_rate() {
        AdaptiveClipByNorm::new(1.0f32, 2.0);
    }
}