1use crate::{Result, Tensor};
18use scirs2_core::numeric::{Float, FromPrimitive};
19use std::collections::HashMap;
20use std::marker::PhantomData;
21
/// Configuration for gradient clipping behavior.
#[derive(Debug, Clone)]
pub struct GradientClippingConfig {
    /// Maximum allowed global gradient norm; also the base threshold for
    /// adaptive scaling.
    pub max_norm: f64,
    /// Which norm (L1 / L2 / infinity) is used to measure gradient magnitude.
    pub norm_type: NormType,
    /// When true, the clipping threshold adapts to observed norm statistics.
    pub adaptive_scaling: bool,
    /// Momentum for the exponential moving average of gradient norms
    /// (only meaningful when `adaptive_scaling` is enabled).
    pub adaptive_momentum: f64,
    /// Lower bound for the adaptive threshold.
    pub min_threshold: f64,
    /// Upper bound for the adaptive threshold.
    pub max_threshold: f64,
    /// Number of initial steps over which the threshold is linearly warmed up
    /// (0 disables warmup).
    pub warmup_steps: usize,
    /// Enables per-parameter-group clipping via registered group thresholds.
    /// NOTE(review): this flag is not consulted anywhere in this file — the
    /// group API works regardless. Confirm intended semantics.
    pub per_parameter_clipping: bool,
}
42
43impl Default for GradientClippingConfig {
44 fn default() -> Self {
45 Self {
46 max_norm: 1.0,
47 norm_type: NormType::L2,
48 adaptive_scaling: false,
49 adaptive_momentum: 0.95,
50 min_threshold: 0.1,
51 max_threshold: 10.0,
52 warmup_steps: 0,
53 per_parameter_clipping: false,
54 }
55 }
56}
57
/// Norm used to measure gradient magnitude.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum NormType {
    /// Sum of absolute values.
    L1,
    /// Euclidean norm (square root of sum of squares).
    L2,
    /// Maximum absolute value.
    Infinity,
}
68
/// Running statistics collected while clipping gradients.
#[derive(Debug, Clone)]
pub struct GradientStatistics {
    /// Global norm observed at the most recent update.
    pub current_norm: f64,
    /// Exponential moving average of observed norms.
    pub avg_norm: f64,
    /// Sample standard deviation over the recent norm history window.
    pub std_norm: f64,
    /// Number of updates in which gradients were actually clipped.
    pub clip_count: usize,
    /// Total number of updates processed.
    pub total_updates: usize,
    /// Current threshold when adaptive scaling is enabled.
    pub adaptive_threshold: f64,
    /// Rolling window of recent global norms (bounded to 100 entries).
    pub norm_history: Vec<f64>,
}
87
88impl Default for GradientStatistics {
89 fn default() -> Self {
90 Self {
91 current_norm: 0.0,
92 avg_norm: 0.0,
93 std_norm: 0.0,
94 clip_count: 0,
95 total_updates: 0,
96 adaptive_threshold: 1.0,
97 norm_history: Vec::with_capacity(100), }
99 }
100}
101
/// Gradient clipper with optional adaptive thresholding and named
/// per-parameter-group thresholds.
pub struct GradientClipper<T> {
    // Clipping configuration (thresholds, norm type, warmup, adaptation).
    config: GradientClippingConfig,
    // Running norm statistics updated on each `clip_gradients` call.
    statistics: GradientStatistics,
    // Group name -> clipping threshold, used by `clip_parameter_group`.
    parameter_groups: HashMap<String, f64>,
    // Number of `clip_gradients` calls so far; drives warmup.
    step_count: usize,
    // Ties the clipper to the tensor element type without storing a T.
    _phantom: PhantomData<T>,
}
110
111impl<T> GradientClipper<T>
112where
113 T: Float + FromPrimitive + Clone + Send + Sync + Default + 'static,
114{
115 pub fn new(config: GradientClippingConfig) -> Self {
117 Self {
118 config,
119 statistics: GradientStatistics::default(),
120 parameter_groups: HashMap::new(),
121 step_count: 0,
122 _phantom: PhantomData,
123 }
124 }
125
126 pub fn default_stable() -> Self {
128 Self::new(GradientClippingConfig {
129 max_norm: 1.0,
130 norm_type: NormType::L2,
131 adaptive_scaling: false,
132 ..Default::default()
133 })
134 }
135
136 pub fn default_adaptive() -> Self {
138 Self::new(GradientClippingConfig {
139 max_norm: 1.0,
140 norm_type: NormType::L2,
141 adaptive_scaling: true,
142 adaptive_momentum: 0.95,
143 min_threshold: 0.1,
144 max_threshold: 5.0,
145 ..Default::default()
146 })
147 }
148
149 pub fn add_parameter_group(&mut self, group_name: String, threshold: f64) {
151 self.parameter_groups.insert(group_name, threshold);
152 }
153
154 pub fn clip_gradients(&mut self, gradients: &mut [Tensor<T>]) -> Result<f64> {
159 if gradients.is_empty() {
160 return Ok(0.0);
161 }
162
163 self.step_count += 1;
164
165 let global_norm = self.compute_global_norm(gradients)?;
167
168 self.update_statistics(global_norm);
170
171 let effective_threshold = self.get_effective_threshold();
173
174 let warmed_threshold = if self.step_count <= self.config.warmup_steps {
176 let warmup_factor = self.step_count as f64 / self.config.warmup_steps as f64;
177 effective_threshold * warmup_factor + self.config.max_norm * (1.0 - warmup_factor)
178 } else {
179 effective_threshold
180 };
181
182 if global_norm > warmed_threshold {
184 let scale_factor = warmed_threshold / global_norm;
185 self.scale_gradients(
186 gradients,
187 T::from_f64(scale_factor).unwrap_or_else(|| T::one()),
188 )?;
189 self.statistics.clip_count += 1;
190 }
191
192 Ok(global_norm)
193 }
194
195 pub fn clip_parameter_group(
197 &mut self,
198 group_name: &str,
199 gradients: &mut [Tensor<T>],
200 ) -> Result<f64> {
201 let threshold = self
202 .parameter_groups
203 .get(group_name)
204 .copied()
205 .unwrap_or(self.config.max_norm);
206
207 let global_norm = self.compute_global_norm(gradients)?;
208
209 if global_norm > threshold {
210 let scale_factor = threshold / global_norm;
211 self.scale_gradients(
212 gradients,
213 T::from_f64(scale_factor).unwrap_or_else(|| T::one()),
214 )?;
215 }
216
217 Ok(global_norm)
218 }
219
220 fn compute_global_norm(&self, gradients: &[Tensor<T>]) -> Result<f64> {
222 match self.config.norm_type {
223 NormType::L1 => {
224 let mut total_norm = 0.0;
225 for grad in gradients {
226 total_norm += self.compute_tensor_l1_norm(grad)?;
227 }
228 Ok(total_norm)
229 }
230 NormType::L2 => {
231 let mut total_squared_norm = 0.0;
232 for grad in gradients {
233 let tensor_norm = self.compute_tensor_l2_norm(grad)?;
234 total_squared_norm += tensor_norm * tensor_norm;
235 }
236 Ok(total_squared_norm.sqrt())
237 }
238 NormType::Infinity => {
239 let mut max_norm = 0.0;
240 for grad in gradients {
241 let tensor_max = self.compute_tensor_inf_norm(grad)?;
242 max_norm = max_norm.max(tensor_max);
243 }
244 Ok(max_norm)
245 }
246 }
247 }
248
249 fn compute_tensor_l1_norm(&self, tensor: &Tensor<T>) -> Result<f64> {
251 match &tensor.storage {
254 crate::tensor::TensorStorage::Cpu(array) => {
255 let sum: f64 = array.iter().map(|&x| x.abs().to_f64().unwrap_or(0.0)).sum();
256 Ok(sum)
257 }
258 #[cfg(feature = "gpu")]
259 crate::tensor::TensorStorage::Gpu(_) => {
260 Err(crate::TensorError::unsupported_operation_simple(
263 "GPU L1 norm computation not yet implemented".to_string(),
264 ))
265 }
266 }
267 }
268
269 fn compute_tensor_l2_norm(&self, tensor: &Tensor<T>) -> Result<f64> {
271 match &tensor.storage {
272 crate::tensor::TensorStorage::Cpu(array) => {
273 let sum_squares: f64 = array
274 .iter()
275 .map(|&x| {
276 let val = x.to_f64().unwrap_or(0.0);
277 val * val
278 })
279 .sum();
280 Ok(sum_squares.sqrt())
281 }
282 #[cfg(feature = "gpu")]
283 crate::tensor::TensorStorage::Gpu(_) => {
284 Err(crate::TensorError::unsupported_operation_simple(
286 "GPU L2 norm computation not yet implemented".to_string(),
287 ))
288 }
289 }
290 }
291
292 fn compute_tensor_inf_norm(&self, tensor: &Tensor<T>) -> Result<f64> {
294 match &tensor.storage {
295 crate::tensor::TensorStorage::Cpu(array) => {
296 let max_val = array
297 .iter()
298 .map(|&x| x.abs().to_f64().unwrap_or(0.0))
299 .fold(0.0, f64::max);
300 Ok(max_val)
301 }
302 #[cfg(feature = "gpu")]
303 crate::tensor::TensorStorage::Gpu(_) => {
304 Err(crate::TensorError::unsupported_operation_simple(
306 "GPU infinity norm computation not yet implemented".to_string(),
307 ))
308 }
309 }
310 }
311
312 fn scale_gradients(&self, gradients: &mut [Tensor<T>], scale_factor: T) -> Result<()> {
314 for grad in gradients.iter_mut() {
315 *grad = grad.mul_scalar(scale_factor)?;
316 }
317 Ok(())
318 }
319
320 fn update_statistics(&mut self, global_norm: f64) {
322 self.statistics.current_norm = global_norm;
323 self.statistics.total_updates += 1;
324
325 if self.statistics.total_updates == 1 {
327 self.statistics.avg_norm = global_norm;
328 } else {
329 let momentum = self.config.adaptive_momentum;
330 self.statistics.avg_norm =
331 momentum * self.statistics.avg_norm + (1.0 - momentum) * global_norm;
332 }
333
334 self.statistics.norm_history.push(global_norm);
336 if self.statistics.norm_history.len() > 100 {
337 self.statistics.norm_history.remove(0);
338 }
339
340 if self.statistics.norm_history.len() > 1 {
342 let mean = self.statistics.avg_norm;
343 let variance: f64 = self
344 .statistics
345 .norm_history
346 .iter()
347 .map(|&x| (x - mean).powi(2))
348 .sum::<f64>()
349 / (self.statistics.norm_history.len() - 1) as f64;
350 self.statistics.std_norm = variance.sqrt();
351 }
352
353 if self.config.adaptive_scaling {
355 self.update_adaptive_threshold();
356 }
357 }
358
359 fn update_adaptive_threshold(&mut self) {
361 let base_threshold = self.config.max_norm;
362
363 let variance_factor = if self.statistics.std_norm > 0.0 {
366 (self.statistics.std_norm / self.statistics.avg_norm).min(2.0)
367 } else {
368 1.0
369 };
370
371 let recent_clip_rate = if self.statistics.total_updates > 0 {
373 self.statistics.clip_count as f64 / self.statistics.total_updates as f64
374 } else {
375 0.0
376 };
377
378 let frequency_adjustment = if recent_clip_rate > 0.5 {
380 0.9 } else if recent_clip_rate < 0.1 {
382 1.1 } else {
384 1.0 };
386
387 let new_threshold = base_threshold * variance_factor * frequency_adjustment;
388
389 self.statistics.adaptive_threshold = new_threshold
391 .max(self.config.min_threshold)
392 .min(self.config.max_threshold);
393 }
394
395 fn get_effective_threshold(&self) -> f64 {
397 if self.config.adaptive_scaling {
398 self.statistics.adaptive_threshold
399 } else {
400 self.config.max_norm
401 }
402 }
403
404 pub fn get_statistics(&self) -> &GradientStatistics {
406 &self.statistics
407 }
408
409 pub fn get_config(&self) -> &GradientClippingConfig {
411 &self.config
412 }
413
414 pub fn reset_statistics(&mut self) {
416 self.statistics = GradientStatistics::default();
417 self.step_count = 0;
418 }
419
420 pub fn get_clipping_rate(&self) -> f64 {
422 if self.statistics.total_updates > 0 {
423 self.statistics.clip_count as f64 / self.statistics.total_updates as f64
424 } else {
425 0.0
426 }
427 }
428
429 pub fn would_clip(&self, gradients: &[Tensor<T>]) -> Result<bool> {
431 let global_norm = self.compute_global_norm(gradients)?;
432 Ok(global_norm > self.get_effective_threshold())
433 }
434}
435
436impl<T> Tensor<T>
437where
438 T: Float + FromPrimitive + Clone + Send + Sync + Default + 'static,
439{
440 pub fn mul_scalar(&self, scalar: T) -> Result<Tensor<T>> {
442 match &self.storage {
443 crate::tensor::TensorStorage::Cpu(array) => {
444 let scaled_array = array.mapv(|x| x * scalar);
445 Ok(Tensor::from_array(scaled_array))
446 }
447 #[cfg(feature = "gpu")]
448 crate::tensor::TensorStorage::Gpu(_) => {
449 Err(crate::TensorError::unsupported_operation_simple(
451 "GPU scalar multiplication not yet implemented".to_string(),
452 ))
453 }
454 }
455 }
456}
457
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_gradient_clipping_basic() {
        let mut clipper = GradientClipper::<f32>::default_stable();

        // L2 norm of [5, 5, 5, 5] is 10, well above the default threshold of 1.
        let mut gradients = vec![Tensor::from_array(
            Array1::from_vec(vec![5.0, 5.0, 5.0, 5.0]).into_dyn(),
        )];

        let reported_norm = clipper
            .clip_gradients(&mut gradients)
            .expect("test: clip_gradients should succeed");

        // The returned norm is the pre-clipping norm, and one clip occurred.
        assert!(reported_norm > 1.0);
        assert_eq!(clipper.get_statistics().clip_count, 1);
    }

    #[test]
    fn test_adaptive_clipping() {
        let mut clipper = GradientClipper::<f32>::default_adaptive();

        // Feed gradients of steadily increasing magnitude.
        for step in 0..10 {
            let magnitude = 1.0 + step as f32 * 0.5;
            let mut gradients = vec![Tensor::from_array(
                Array1::from_vec(vec![magnitude, magnitude]).into_dyn(),
            )];

            clipper
                .clip_gradients(&mut gradients)
                .expect("test: clip_gradients should succeed");
        }

        let stats = clipper.get_statistics();
        assert!(stats.total_updates == 10);
        assert!(stats.adaptive_threshold > 0.0);
    }

    #[test]
    fn test_parameter_groups() {
        let mut clipper = GradientClipper::<f32>::new(GradientClippingConfig {
            per_parameter_clipping: true,
            ..Default::default()
        });

        clipper.add_parameter_group("embeddings".to_string(), 0.5);
        clipper.add_parameter_group("output".to_string(), 2.0);

        let mut gradients = vec![Tensor::from_array(
            Array1::from_vec(vec![1.5, 1.5]).into_dyn(),
        )];

        // [1.5, 1.5] has L2 norm > 0.5, so the "embeddings" group would clip.
        let reported_norm = clipper
            .clip_parameter_group("embeddings", &mut gradients)
            .expect("test: operation should succeed");
        assert!(reported_norm > 0.5);
    }

    #[test]
    fn test_different_norm_types() {
        // L1 norm of [2, 2] is exactly 4 == max_norm, so no clipping occurs
        // and the reported norm is exact.
        let mut l1_clipper = GradientClipper::<f32>::new(GradientClippingConfig {
            norm_type: NormType::L1,
            max_norm: 4.0,
            ..Default::default()
        });

        let mut gradients = vec![Tensor::from_array(
            Array1::from_vec(vec![2.0, 2.0]).into_dyn(),
        )];

        let reported_norm = l1_clipper
            .clip_gradients(&mut gradients)
            .expect("test: clip_gradients should succeed");
        assert_eq!(reported_norm, 4.0);
    }
}