scirs2_optim/gradient_processing/mod.rs

//! Gradient processing utilities for machine learning optimization
//!
//! This module provides comprehensive gradient manipulation utilities including
//! various clipping strategies, normalization, and other processing techniques.

use ndarray::{Array, Dimension, ScalarOperand};
use num_traits::Float;
use std::fmt::Debug;

use crate::error::{OptimError, Result};

/// Gradient clipping configuration
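///
/// # Examples
///
/// A minimal sketch of building a config (hedged: the field values are
/// illustrative, and the path assumes this module is exposed as
/// `scirs2_optim::gradient_processing`):
///
/// ```ignore
/// use scirs2_optim::gradient_processing::GradientClipConfig;
///
/// // Clip element values to [-5.0, 5.0] and the global L2 norm to 10.0.
/// let config = GradientClipConfig::<f64> {
///     max_value: Some(5.0),
///     min_value: Some(-5.0),
///     max_norm: Some(10.0),
///     ..Default::default()
/// };
/// ```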
#[derive(Debug, Clone)]
pub struct GradientClipConfig<A: Float> {
    /// Maximum allowed value for individual gradient elements
    pub max_value: Option<A>,
    /// Minimum allowed value for individual gradient elements
    pub min_value: Option<A>,
    /// Maximum allowed L2 norm for the entire gradient vector
    pub max_norm: Option<A>,
    /// Maximum allowed L1 norm
    pub max_l1_norm: Option<A>,
    /// Whether to apply gradient centralization
    pub centralization: bool,
    /// Threshold for zeroing small gradients
    pub zero_threshold: Option<A>,
}

impl<A: Float> Default for GradientClipConfig<A> {
    fn default() -> Self {
        Self {
            max_value: None,
            min_value: None,
            max_norm: None,
            max_l1_norm: None,
            centralization: false,
            zero_threshold: None,
        }
    }
}

/// Gradient clipping processor
pub struct GradientProcessor<A: Float> {
    config: GradientClipConfig<A>,
}

impl<A: Float + ScalarOperand + Debug> Default for GradientProcessor<A> {
    fn default() -> Self {
        Self {
            config: GradientClipConfig::default(),
        }
    }
}

impl<A: Float + ScalarOperand + Debug> GradientProcessor<A> {
    /// Create a new gradient processor with default configuration
    pub fn new() -> Self {
        Self::default()
    }

    /// Create a new gradient processor with a specific configuration
    pub fn with_config(config: GradientClipConfig<A>) -> Self {
        Self { config }
    }

    /// Set max value clipping
    pub fn set_max_value(&mut self, value: A) -> &mut Self {
        self.config.max_value = Some(value);
        self
    }

    /// Set min value clipping
    pub fn set_min_value(&mut self, value: A) -> &mut Self {
        self.config.min_value = Some(value);
        self
    }

    /// Set max L2 norm clipping
    pub fn set_max_norm(&mut self, value: A) -> &mut Self {
        self.config.max_norm = Some(value);
        self
    }

    /// Set max L1 norm clipping
    pub fn set_max_l1_norm(&mut self, value: A) -> &mut Self {
        self.config.max_l1_norm = Some(value);
        self
    }

    /// Enable or disable gradient centralization
    pub fn set_centralization(&mut self, enabled: bool) -> &mut Self {
        self.config.centralization = enabled;
        self
    }

    /// Set threshold for zeroing small gradients
    pub fn set_zero_threshold(&mut self, value: A) -> &mut Self {
        self.config.zero_threshold = Some(value);
        self
    }

    /// Set value clipping range
    pub fn set_value_clip(&mut self, min: A, max: A) -> &mut Self {
        self.config.min_value = Some(min);
        self.config.max_value = Some(max);
        self
    }

    /// Set norm clipping
    pub fn set_norm_clip(&mut self, max_norm: A) -> &mut Self {
        self.config.max_norm = Some(max_norm);
        self
    }

    /// Set L1 norm clipping
    pub fn set_l1_norm_clip(&mut self, max_l1_norm: A) -> &mut Self {
        self.config.max_l1_norm = Some(max_l1_norm);
        self
    }

    /// Enable gradient centralization
    pub fn enable_centralization(&mut self) -> &mut Self {
        self.config.centralization = true;
        self
    }

    /// Process gradients according to configuration
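    ///
    /// # Examples
    ///
    /// A minimal usage sketch (hedged: the import path assumes this module is
    /// exposed as `scirs2_optim::gradient_processing`):
    ///
    /// ```ignore
    /// use ndarray::Array1;
    /// use scirs2_optim::gradient_processing::GradientProcessor;
    ///
    /// let mut processor = GradientProcessor::<f64>::new();
    /// processor.set_value_clip(-1.0, 1.0).set_norm_clip(5.0);
    ///
    /// let mut grads = Array1::from_vec(vec![-2.0, 0.5, 3.0]);
    /// processor.process(&mut grads).unwrap();
    /// // Elements are now clipped to [-1.0, 1.0].
    /// ```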
    pub fn process<D: Dimension>(&self, gradients: &mut Array<A, D>) -> Result<()> {
        // Apply value clipping if either bound is configured; a missing bound
        // defaults to the corresponding infinity so it imposes no constraint.
        if self.config.min_value.is_some() || self.config.max_value.is_some() {
            let min = self.config.min_value.unwrap_or_else(A::neg_infinity);
            let max = self.config.max_value.unwrap_or_else(A::infinity);
            clip_gradients_by_value(gradients, min, max);
        }

        // Apply L2 norm clipping if configured
        if let Some(max_norm) = self.config.max_norm {
            clip_gradient_norm(gradients, max_norm)?;
        }

        // Apply L1 norm clipping if configured
        if let Some(max_l1_norm) = self.config.max_l1_norm {
            clip_gradient_l1_norm(gradients, max_l1_norm)?;
        }

        // Apply gradient centralization if enabled
        if self.config.centralization {
            gradient_centralization(gradients);
        }

        // Zero small gradients if threshold is set
        if let Some(threshold) = self.config.zero_threshold {
            zero_small_gradients(gradients, threshold);
        }

        Ok(())
    }
}

/// Clip gradient values to a specified range
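///
/// # Examples
///
/// A minimal sketch (hedged: the path assumes this module is exposed as
/// `scirs2_optim::gradient_processing`):
///
/// ```ignore
/// use ndarray::Array1;
/// use scirs2_optim::gradient_processing::clip_gradients_by_value;
///
/// let mut grads = Array1::from_vec(vec![-8.0, 0.5, 7.0]);
/// clip_gradients_by_value(&mut grads, -5.0, 5.0);
/// assert_eq!(grads.as_slice().unwrap(), &[-5.0, 0.5, 5.0]);
/// ```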
pub fn clip_gradients_by_value<A, D>(
    gradients: &mut Array<A, D>,
    min_value: A,
    max_value: A,
) -> &mut Array<A, D>
where
    A: Float + ScalarOperand,
    D: Dimension,
{
    gradients.mapv_inplace(|x| {
        if x < min_value {
            min_value
        } else if x > max_value {
            max_value
        } else {
            x
        }
    });
    gradients
}

/// Clip gradient L2 norm (global gradient clipping)
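///
/// If the L2 norm of `gradients` exceeds `max_norm`, all elements are scaled
/// by `max_norm / norm`; otherwise the gradients are left unchanged.
///
/// # Examples
///
/// A minimal sketch (hedged: the path assumes this module is exposed as
/// `scirs2_optim::gradient_processing`):
///
/// ```ignore
/// use ndarray::Array1;
/// use scirs2_optim::gradient_processing::clip_gradient_norm;
///
/// let mut grads = Array1::from_vec(vec![3.0_f64, 4.0]); // L2 norm = 5
/// clip_gradient_norm(&mut grads, 1.0).unwrap();
/// // grads is rescaled to [0.6, 0.8], so its L2 norm is 1.
/// ```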
pub fn clip_gradient_norm<A, D>(
    gradients: &mut Array<A, D>,
    max_norm: A,
) -> Result<&mut Array<A, D>>
where
    A: Float + ScalarOperand,
    D: Dimension,
{
    if max_norm <= A::zero() {
        return Err(OptimError::InvalidConfig(
            "max_norm must be positive".to_string(),
        ));
    }

    // Calculate current L2 norm
    let norm = gradients
        .iter()
        .fold(A::zero(), |acc, &x| acc + x * x)
        .sqrt();

    // If norm exceeds max_norm, scale gradients
    if norm > max_norm {
        let scale = max_norm / norm;
        gradients.mapv_inplace(|x| x * scale);
    }

    Ok(gradients)
}

/// Clip gradient L1 norm
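///
/// If the L1 norm (sum of absolute values) of `gradients` exceeds
/// `max_l1_norm`, all elements are scaled by `max_l1_norm / l1_norm`.
///
/// # Examples
///
/// A minimal sketch (hedged: the path assumes this module is exposed as
/// `scirs2_optim::gradient_processing`):
///
/// ```ignore
/// use ndarray::Array1;
/// use scirs2_optim::gradient_processing::clip_gradient_l1_norm;
///
/// let mut grads = Array1::from_vec(vec![2.0_f64, -3.0]); // L1 norm = 5
/// clip_gradient_l1_norm(&mut grads, 1.0).unwrap();
/// // grads is rescaled to [0.4, -0.6], so its L1 norm is 1.
/// ```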
pub fn clip_gradient_l1_norm<A, D>(
    gradients: &mut Array<A, D>,
    max_l1_norm: A,
) -> Result<&mut Array<A, D>>
where
    A: Float + ScalarOperand,
    D: Dimension,
{
    if max_l1_norm <= A::zero() {
        return Err(OptimError::InvalidConfig(
            "max_l1_norm must be positive".to_string(),
        ));
    }

    // Calculate current L1 norm
    let l1_norm = gradients.iter().fold(A::zero(), |acc, &x| acc + x.abs());

    // If norm exceeds max_l1_norm, scale gradients
    if l1_norm > max_l1_norm {
        let scale = max_l1_norm / l1_norm;
        gradients.mapv_inplace(|x| x * scale);
    }

    Ok(gradients)
}

/// Apply gradient centralization by subtracting the mean of all gradient elements
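///
/// # Examples
///
/// A minimal sketch (hedged: the path assumes this module is exposed as
/// `scirs2_optim::gradient_processing`):
///
/// ```ignore
/// use ndarray::Array1;
/// use scirs2_optim::gradient_processing::gradient_centralization;
///
/// let mut grads = Array1::from_vec(vec![1.0_f64, 2.0, 3.0]); // mean = 2
/// gradient_centralization(&mut grads);
/// assert_eq!(grads.as_slice().unwrap(), &[-1.0, 0.0, 1.0]);
/// ```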
pub fn gradient_centralization<A, D>(gradients: &mut Array<A, D>) -> &mut Array<A, D>
where
    A: Float + ScalarOperand,
    D: Dimension,
{
    // Calculate mean
    let sum = gradients.iter().fold(A::zero(), |acc, &x| acc + x);
    let mean = sum / A::from(gradients.len()).unwrap_or(A::one());

    // Subtract mean from each element
    gradients.mapv_inplace(|x| x - mean);

    gradients
}

/// Zero out small gradient values
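///
/// Elements with an absolute value strictly below `threshold` are set to zero.
///
/// # Examples
///
/// A minimal sketch (hedged: the path assumes this module is exposed as
/// `scirs2_optim::gradient_processing`):
///
/// ```ignore
/// use ndarray::Array1;
/// use scirs2_optim::gradient_processing::zero_small_gradients;
///
/// let mut grads = Array1::from_vec(vec![1e-9_f64, 0.5, -1e-7]);
/// zero_small_gradients(&mut grads, 1e-6);
/// assert_eq!(grads.as_slice().unwrap(), &[0.0, 0.5, 0.0]);
/// ```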
pub fn zero_small_gradients<A, D>(gradients: &mut Array<A, D>, threshold: A) -> &mut Array<A, D>
where
    A: Float + ScalarOperand,
    D: Dimension,
{
    let abs_threshold = threshold.abs();

    gradients.mapv_inplace(|x| {
        if x.abs() < abs_threshold {
            A::zero()
        } else {
            x
        }
    });

    gradients
}

/// Gradient accumulation utility
#[derive(Debug, Clone)]
pub struct GradientAccumulator<A: Float, D: Dimension> {
    /// Accumulated gradients
    accumulated_gradients: Option<Array<A, D>>,
    /// Number of accumulated micro-batches
    num_accumulated: usize,
    /// Target number of micro-batches before step
    accumulation_steps: usize,
    /// Whether to average gradients (vs sum)
    average_gradients: bool,
}

impl<A: Float + ScalarOperand + Debug, D: Dimension> GradientAccumulator<A, D> {
    /// Create a new gradient accumulator
    ///
    /// # Arguments
    ///
    /// * `accumulation_steps` - Number of micro-batches to accumulate before stepping
    /// * `average_gradients` - Whether to average gradients (true) or sum them (false)
    pub fn new(accumulation_steps: usize, average_gradients: bool) -> Self {
        Self {
            accumulated_gradients: None,
            num_accumulated: 0,
            accumulation_steps,
            average_gradients,
        }
    }

    /// Add gradients from a micro-batch
    ///
    /// # Arguments
    ///
    /// * `gradients` - Gradients from the current micro-batch
    ///
    /// # Returns
    ///
    /// `true` if enough gradients have been accumulated and it's time to step
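    ///
    /// # Examples
    ///
    /// A minimal sketch (hedged: the path assumes this module is exposed as
    /// `scirs2_optim::gradient_processing`):
    ///
    /// ```ignore
    /// use ndarray::Array1;
    /// use scirs2_optim::gradient_processing::GradientAccumulator;
    ///
    /// let mut acc = GradientAccumulator::new(2, true); // average over 2 micro-batches
    /// assert!(!acc.accumulate(&Array1::from_vec(vec![1.0, 2.0])));
    /// assert!(acc.accumulate(&Array1::from_vec(vec![3.0, 4.0])));
    /// let grads = acc.get_and_reset().unwrap();
    /// assert_eq!(grads.as_slice().unwrap(), &[2.0, 3.0]); // element-wise average
    /// ```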
    pub fn accumulate(&mut self, gradients: &Array<A, D>) -> bool {
        if self.accumulated_gradients.is_none() {
            self.accumulated_gradients = Some(gradients.clone());
        } else {
            let acc = self.accumulated_gradients.as_mut().unwrap();
            for (acc_val, &grad_val) in acc.iter_mut().zip(gradients.iter()) {
                *acc_val = *acc_val + grad_val;
            }
        }

        self.num_accumulated += 1;
        self.num_accumulated >= self.accumulation_steps
    }

    /// Get the accumulated gradients and reset the accumulator
    ///
    /// # Returns
    ///
    /// The accumulated gradients, ready for optimization step
    pub fn get_and_reset(&mut self) -> Option<Array<A, D>> {
        if let Some(mut gradients) = self.accumulated_gradients.take() {
            if self.average_gradients && self.num_accumulated > 0 {
                let scale = A::one() / A::from(self.num_accumulated).unwrap_or(A::one());
                gradients.mapv_inplace(|x| x * scale);
            }
            self.num_accumulated = 0;
            Some(gradients)
        } else {
            None
        }
    }

    /// Get current accumulation progress
    pub fn progress(&self) -> (usize, usize) {
        (self.num_accumulated, self.accumulation_steps)
    }

    /// Check if ready for optimization step
    pub fn is_ready(&self) -> bool {
        self.num_accumulated >= self.accumulation_steps
    }

    /// Reset the accumulator
    pub fn reset(&mut self) {
        self.accumulated_gradients = None;
        self.num_accumulated = 0;
    }

    /// Change accumulation steps
    pub fn set_accumulation_steps(&mut self, steps: usize) {
        self.accumulation_steps = steps;
    }
}

/// Adaptive gradient clipping
///
/// Clips gradients based on the ratio of gradient norm to parameter norm.
/// This is particularly useful for transformer models.
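///
/// # Examples
///
/// A minimal sketch (hedged: the path assumes this module is exposed as
/// `scirs2_optim::gradient_processing`):
///
/// ```ignore
/// use ndarray::Array1;
/// use scirs2_optim::gradient_processing::adaptive_gradient_clipping;
///
/// let mut grads = Array1::from_vec(vec![3.0_f64, 4.0]); // norm = 5
/// let params = Array1::from_vec(vec![1.0_f64, 0.0]);    // norm = 1
/// // Ratio 5/1 exceeds max_ratio 2, so the gradients are rescaled to norm 2.
/// adaptive_gradient_clipping(&mut grads, &params, 2.0).unwrap();
/// ```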
pub fn adaptive_gradient_clipping<'a, A, D>(
    gradients: &'a mut Array<A, D>,
    parameters: &Array<A, D>,
    max_ratio: A,
) -> Result<&'a mut Array<A, D>>
where
    A: Float + ScalarOperand,
    D: Dimension,
{
    if max_ratio <= A::zero() {
        return Err(OptimError::InvalidConfig(
            "max_ratio must be positive".to_string(),
        ));
    }

    let grad_norm = gradients
        .iter()
        .fold(A::zero(), |acc, &x| acc + x * x)
        .sqrt();

    let param_norm = parameters
        .iter()
        .fold(A::zero(), |acc, &x| acc + x * x)
        .sqrt();

    if param_norm > A::zero() && grad_norm > A::zero() {
        let ratio = grad_norm / param_norm;
        if ratio > max_ratio {
            let scale = max_ratio / ratio;
            gradients.mapv_inplace(|x| x * scale);
        }
    }

    Ok(gradients)
}

/// Add noise to gradients for regularization
///
/// # Arguments
///
/// * `gradients` - Gradients to add noise to
/// * `noise_std` - Standard deviation of Gaussian noise to add
/// * `seed` - Optional seed for reproducible results
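///
/// # Examples
///
/// A minimal sketch (hedged: the path assumes this module is exposed as
/// `scirs2_optim::gradient_processing`):
///
/// ```ignore
/// use ndarray::Array1;
/// use scirs2_optim::gradient_processing::add_gradient_noise;
///
/// let mut grads = Array1::from_vec(vec![1.0_f64, 2.0, 3.0]);
/// // Add Gaussian noise with std 0.1; a fixed seed makes the result reproducible.
/// add_gradient_noise(&mut grads, 0.1, Some(42));
/// ```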
pub fn add_gradient_noise<A, D>(
    gradients: &mut Array<A, D>,
    noise_std: A,
    seed: Option<u64>,
) -> &mut Array<A, D>
where
    A: Float + ScalarOperand,
    D: Dimension,
{
    use ndarray_rand::rand::SeedableRng;
    use ndarray_rand::rand_distr::Normal;
    use ndarray_rand::RandomExt;

    if noise_std <= A::zero() {
        return gradients;
    }

    let mut rng = if let Some(s) = seed {
        ndarray_rand::rand::rngs::StdRng::seed_from_u64(s)
    } else {
        ndarray_rand::rand::rngs::StdRng::from_entropy()
    };

    let normal = Normal::new(0.0, noise_std.to_f64().unwrap_or(0.01)).unwrap();
    let noise = Array::random_using(gradients.raw_dim(), normal, &mut rng);

    gradients.zip_mut_with(&noise, |g, &n| {
        *g = *g + A::from(n).unwrap_or(A::zero());
    });

    gradients
}

/// Gradient masking and freezing utilities
///
/// Allows selective gradient updates by masking certain parameters
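///
/// # Examples
///
/// A minimal sketch (hedged: the path assumes this module is exposed as
/// `scirs2_optim::gradient_processing`):
///
/// ```ignore
/// use ndarray::Array1;
/// use scirs2_optim::gradient_processing::GradientMask;
///
/// // Update the first and third parameters; freeze the second.
/// let mask = GradientMask::<f64, ndarray::Ix1>::new(Array1::from_vec(vec![true, false, true]));
/// let mut grads = Array1::from_vec(vec![1.0, 2.0, 3.0]);
/// mask.apply_mask(&mut grads);
/// assert_eq!(grads.as_slice().unwrap(), &[1.0, 0.0, 3.0]);
/// ```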
#[derive(Debug, Clone)]
pub struct GradientMask<A: Float, D: Dimension> {
    /// Mask indicating which parameters to update (true = update, false = freeze)
    mask: Array<bool, D>,
    /// Optional learning rate multipliers for each parameter
    lr_multipliers: Option<Array<A, D>>,
}

impl<A: Float + ScalarOperand + Debug, D: Dimension> GradientMask<A, D> {
    /// Create a new gradient mask
    ///
    /// # Arguments
    ///
    /// * `mask` - Boolean mask indicating which parameters to update
    pub fn new(mask: Array<bool, D>) -> Self {
        Self {
            mask,
            lr_multipliers: None,
        }
    }

    /// Create a mask that freezes all parameters
    pub fn freeze_all(shape: D) -> Self {
        Self {
            mask: Array::from_elem(shape, false),
            lr_multipliers: None,
        }
    }

    /// Create a mask that updates all parameters
    pub fn update_all(shape: D) -> Self {
        Self {
            mask: Array::from_elem(shape, true),
            lr_multipliers: None,
        }
    }

    /// Set learning rate multipliers for different parameters
    pub fn with_lr_multipliers(mut self, multipliers: Array<A, D>) -> Self {
        self.lr_multipliers = Some(multipliers);
        self
    }

    /// Apply the mask to gradients
    ///
    /// # Arguments
    ///
    /// * `gradients` - Gradients to mask
    ///
    /// # Returns
    ///
    /// Masked gradients where frozen parameters have zero gradients
    pub fn apply_mask<'a>(&self, gradients: &'a mut Array<A, D>) -> &'a mut Array<A, D> {
        gradients.zip_mut_with(&self.mask, |grad, &should_update| {
            if !should_update {
                *grad = A::zero();
            }
        });

        // Apply learning rate multipliers if present
        if let Some(multipliers) = &self.lr_multipliers {
            gradients.zip_mut_with(multipliers, |grad, &mult| {
                *grad = *grad * mult;
            });
        }

        gradients
    }

    /// Freeze specific parameters by indices
    pub fn freeze_indices(&mut self, indices: &[usize]) -> Result<()> {
        let flat_mask = self.mask.as_slice_mut().ok_or_else(|| {
            OptimError::InvalidConfig("Cannot access mask as flat slice".to_string())
        })?;

        for &idx in indices {
            if idx < flat_mask.len() {
                flat_mask[idx] = false;
            } else {
                return Err(OptimError::InvalidConfig(format!(
                    "Index {} out of bounds for mask of size {}",
                    idx,
                    flat_mask.len()
                )));
            }
        }
        Ok(())
    }

    /// Unfreeze specific parameters by indices
    pub fn unfreeze_indices(&mut self, indices: &[usize]) -> Result<()> {
        let flat_mask = self.mask.as_slice_mut().ok_or_else(|| {
            OptimError::InvalidConfig("Cannot access mask as flat slice".to_string())
        })?;

        for &idx in indices {
            if idx < flat_mask.len() {
                flat_mask[idx] = true;
            } else {
                return Err(OptimError::InvalidConfig(format!(
                    "Index {} out of bounds for mask of size {}",
                    idx,
                    flat_mask.len()
                )));
            }
        }
        Ok(())
    }

    /// Get the number of frozen parameters
    pub fn num_frozen(&self) -> usize {
        self.mask.iter().filter(|&&x| !x).count()
    }

    /// Get the number of active (unfrozen) parameters
    pub fn num_active(&self) -> usize {
        self.mask.iter().filter(|&&x| x).count()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::Array1;

    #[test]
    fn test_gradient_processor() {
        let config = GradientClipConfig::<f64> {
            max_value: Some(5.0),
            min_value: Some(-5.0),
            max_norm: Some(10.0),
            ..Default::default()
        };

        let processor = GradientProcessor::with_config(config);

        let mut gradients = Array1::from_vec(vec![-8.0, 3.0, 7.0, -2.0, 6.0]);
        processor.process(&mut gradients).unwrap();

        // Check value clipping
        assert_eq!(gradients[0], -5.0);
        assert_eq!(gradients[2], 5.0);
        assert_eq!(gradients[4], 5.0);
    }

    #[test]
    fn test_adaptive_clipping() {
        let mut gradients = Array1::from_vec(vec![3.0, 4.0]); // norm = 5
        let parameters = Array1::from_vec(vec![1.0, 0.0]); // norm = 1

        // Gradient/parameter ratio = 5/1 = 5, max_ratio = 2
        adaptive_gradient_clipping(&mut gradients, &parameters, 2.0).unwrap();

        // After clipping, ratio should be 2
        let new_grad_norm = gradients.iter().fold(0.0, |acc, &x| acc + x * x).sqrt();
        assert!((new_grad_norm - 2.0).abs() < 1e-6);
    }

    #[test]
    fn test_gradient_accumulator() {
        let mut accumulator = GradientAccumulator::new(3, true);

        // First micro-batch
        let grad1 = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        assert!(!accumulator.accumulate(&grad1));
        assert_eq!(accumulator.progress(), (1, 3));

        // Second micro-batch
        let grad2 = Array1::from_vec(vec![2.0, 3.0, 4.0]);
        assert!(!accumulator.accumulate(&grad2));
        assert_eq!(accumulator.progress(), (2, 3));

        // Third micro-batch - should trigger ready
        let grad3 = Array1::from_vec(vec![3.0, 4.0, 5.0]);
        assert!(accumulator.accumulate(&grad3));
        assert!(accumulator.is_ready());

        // Get accumulated gradients (should be averaged)
        let final_grads = accumulator.get_and_reset().unwrap();
        assert_relative_eq!(final_grads[0], 2.0, epsilon = 1e-6); // (1+2+3)/3
        assert_relative_eq!(final_grads[1], 3.0, epsilon = 1e-6); // (2+3+4)/3
        assert_relative_eq!(final_grads[2], 4.0, epsilon = 1e-6); // (3+4+5)/3

        // Should be reset now
        assert_eq!(accumulator.progress(), (0, 3));
        assert!(!accumulator.is_ready());
    }

    #[test]
    fn test_gradient_accumulator_sum_mode() {
        let mut accumulator = GradientAccumulator::new(2, false); // sum mode

        let grad1 = Array1::from_vec(vec![1.0, 2.0]);
        let grad2 = Array1::from_vec(vec![3.0, 4.0]);

        accumulator.accumulate(&grad1);
        accumulator.accumulate(&grad2);

        let final_grads = accumulator.get_and_reset().unwrap();
        assert_relative_eq!(final_grads[0], 4.0, epsilon = 1e-6); // 1+3
        assert_relative_eq!(final_grads[1], 6.0, epsilon = 1e-6); // 2+4
    }

    #[test]
    fn test_gradient_noise() {
        let mut gradients = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let original = gradients.clone();

        // Add noise with fixed seed for reproducibility
        add_gradient_noise(&mut gradients, 0.1, Some(42));

        // Gradients should be different but close to original
        for (i, (&orig, &noisy)) in original.iter().zip(gradients.iter()).enumerate() {
            assert!(
                (orig - noisy).abs() < 1.0,
                "Index {}: {} vs {}",
                i,
                orig,
                noisy
            );
        }
    }

    #[test]
    fn test_gradient_noise_zero_std() {
        let mut gradients = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let original = gradients.clone();

        // Zero noise should leave gradients unchanged
        add_gradient_noise(&mut gradients, 0.0, Some(42));

        for (orig, noisy) in original.iter().zip(gradients.iter()) {
            assert_relative_eq!(*orig, *noisy, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_gradient_mask_creation() {
        let mask = Array1::from_vec(vec![true, false, true]);
        let grad_mask: GradientMask<f64, ndarray::Ix1> = GradientMask::new(mask);

        assert_eq!(grad_mask.num_active(), 2);
        assert_eq!(grad_mask.num_frozen(), 1);
    }

    #[test]
    fn test_gradient_mask_apply() {
        let mask = Array1::from_vec(vec![true, false, true]);
        let grad_mask: GradientMask<f64, ndarray::Ix1> = GradientMask::new(mask);
        let mut gradients = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        grad_mask.apply_mask(&mut gradients);

        assert_eq!(gradients.as_slice().unwrap(), &[1.0, 0.0, 3.0]);
    }

    #[test]
    fn test_gradient_mask_freeze_unfreeze() {
        let mask = Array1::from_vec(vec![true, true, true]);
        let mut grad_mask: GradientMask<f64, ndarray::Ix1> = GradientMask::new(mask);

        // Freeze some indices
        grad_mask.freeze_indices(&[0, 2]).unwrap();
        assert_eq!(grad_mask.num_frozen(), 2);
        assert_eq!(grad_mask.num_active(), 1);

        // Unfreeze one index
        grad_mask.unfreeze_indices(&[0]).unwrap();
        assert_eq!(grad_mask.num_frozen(), 1);
        assert_eq!(grad_mask.num_active(), 2);
    }

    #[test]
    fn test_gradient_mask_with_lr_multipliers() {
        let mask = Array1::from_vec(vec![true, true, true]);
        let multipliers = Array1::from_vec(vec![1.0, 0.5, 2.0]);
        let grad_mask: GradientMask<f64, ndarray::Ix1> =
            GradientMask::new(mask).with_lr_multipliers(multipliers);
        let mut gradients = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        grad_mask.apply_mask(&mut gradients);

        assert_relative_eq!(gradients[0], 1.0, epsilon = 1e-6);
        assert_relative_eq!(gradients[1], 1.0, epsilon = 1e-6); // 2.0 * 0.5
        assert_relative_eq!(gradients[2], 6.0, epsilon = 1e-6); // 3.0 * 2.0
    }

    #[test]
    fn test_gradient_mask_freeze_all() {
        let grad_mask = GradientMask::<f64, ndarray::Ix1>::freeze_all(ndarray::Ix1(3));
        assert_eq!(grad_mask.num_frozen(), 3);
        assert_eq!(grad_mask.num_active(), 0);
    }

    #[test]
    fn test_gradient_mask_update_all() {
        let grad_mask = GradientMask::<f64, ndarray::Ix1>::update_all(ndarray::Ix1(3));
        assert_eq!(grad_mask.num_frozen(), 0);
        assert_eq!(grad_mask.num_active(), 3);
    }
}