//! Gradient utilities for `trustformers_training` (`gradient.rs`):
//! clipping by global norm, clipping by value, accumulation, and zeroing.
1use trustformers_core::errors::{tensor_op_error, Result};
2use trustformers_core::Tensor;
3
/// Utility functions for gradient operations: norm/value clipping,
/// accumulation across micro-batches, and zeroing of accumulated buffers.
/// Stateless — all functionality is exposed as associated functions.
pub struct GradientUtils;
6
7impl GradientUtils {
8    /// Clip gradients by global norm
9    ///
10    /// This function clips the gradients of a list of tensors by their global norm.
11    /// If the global norm exceeds max_norm, all gradients are scaled down proportionally.
12    ///
13    /// # Arguments
14    /// * `gradients` - Mutable reference to a vector of gradients
15    /// * `max_norm` - Maximum allowed norm for gradient clipping
16    ///
17    /// # Returns
18    /// The actual norm of the gradients before clipping
19    pub fn clip_grad_norm(gradients: &mut Vec<Tensor>, max_norm: f32) -> Result<f32> {
20        if gradients.is_empty() {
21            return Ok(0.0);
22        }
23
24        // Compute global norm
25        let mut total_norm_sq = 0.0;
26        for grad in gradients.iter() {
27            let norm = grad.norm()?;
28            total_norm_sq += norm * norm;
29        }
30        let total_norm = total_norm_sq.sqrt();
31
32        // Clip if necessary
33        if total_norm > max_norm {
34            let clip_coef = max_norm / total_norm;
35            for grad in gradients.iter_mut() {
36                *grad = grad.scale(clip_coef)?;
37            }
38        }
39
40        Ok(total_norm)
41    }
42
43    /// Clip gradients by value
44    ///
45    /// This function clips each gradient tensor element-wise to be within [-clip_value, clip_value]
46    ///
47    /// # Arguments
48    /// * `gradients` - Mutable reference to a vector of gradients
49    /// * `clip_value` - Maximum absolute value for each gradient element
50    pub fn clip_grad_value(gradients: &mut Vec<Tensor>, clip_value: f32) -> Result<()> {
51        for grad in gradients.iter_mut() {
52            match grad {
53                Tensor::F32(arr) => {
54                    arr.mapv_inplace(|x| x.clamp(-clip_value, clip_value));
55                },
56                Tensor::F64(arr) => {
57                    let clip_value_f64 = clip_value as f64;
58                    arr.mapv_inplace(|x| x.clamp(-clip_value_f64, clip_value_f64));
59                },
60                _ => {
61                    return Err(tensor_op_error(
62                        "gradient_clipping",
63                        "Unsupported tensor type for gradient value clipping",
64                    ))
65                },
66            }
67        }
68        Ok(())
69    }
70
71    /// Accumulate gradients by averaging
72    ///
73    /// This function adds new gradients to accumulated gradients and scales them
74    /// by 1/accumulation_steps to maintain proper scaling
75    ///
76    /// # Arguments
77    /// * `accumulated_grads` - Mutable reference to accumulated gradients
78    /// * `new_grads` - New gradients to add
79    /// * `accumulation_steps` - Number of accumulation steps (for scaling)
80    pub fn accumulate_gradients(
81        accumulated_grads: &mut Vec<Tensor>,
82        new_grads: &[Tensor],
83        accumulation_steps: usize,
84    ) -> Result<()> {
85        if accumulated_grads.len() != new_grads.len() {
86            return Err(tensor_op_error(
87                "gradient_accumulation",
88                "mismatched number of gradients",
89            ));
90        }
91
92        let scale = 1.0 / accumulation_steps as f32;
93
94        for (acc_grad, new_grad) in accumulated_grads.iter_mut().zip(new_grads.iter()) {
95            let scaled_new_grad = new_grad.scale(scale)?;
96            *acc_grad = acc_grad.add(&scaled_new_grad)?;
97        }
98
99        Ok(())
100    }
101
102    /// Zero out accumulated gradients
103    pub fn zero_accumulated_gradients(accumulated_grads: &mut Vec<Tensor>) -> Result<()> {
104        for grad in accumulated_grads.iter_mut() {
105            match grad {
106                Tensor::F32(arr) => {
107                    arr.fill(0.0);
108                },
109                Tensor::F64(arr) => {
110                    arr.fill(0.0);
111                },
112                _ => {
113                    return Err(tensor_op_error(
114                        "gradient_zeroing",
115                        "Unsupported tensor type for zeroing gradients",
116                    ))
117                },
118            }
119        }
120        Ok(())
121    }
122}
123
#[cfg(test)]
mod tests {
    use super::*;
    use trustformers_core::Tensor;

    /// Set one element of an F32 tensor, panicking on any other variant so a
    /// wrong default dtype fails the test instead of silently skipping setup.
    fn set_f32(t: &mut Tensor, idx: [usize; 2], value: f32) {
        match t {
            Tensor::F32(arr) => arr[idx] = value,
            _ => panic!("expected F32 tensor in test"),
        }
    }

    /// Read one element of an F32 tensor, panicking on any other variant so
    /// assertions cannot be vacuously skipped.
    fn get_f32(t: &Tensor, idx: [usize; 2]) -> f32 {
        match t {
            Tensor::F32(arr) => arr[idx],
            _ => panic!("expected F32 tensor in test"),
        }
    }

    #[test]
    fn test_clip_grad_norm() {
        let mut gradients = vec![
            Tensor::zeros(&[2, 2]).expect("tensor operation failed"),
            Tensor::zeros(&[3, 3]).expect("tensor operation failed"),
        ];

        // First tensor has norm 5.0 (3-4-5 triangle), second has norm 10.0.
        set_f32(&mut gradients[0], [0, 0], 3.0);
        set_f32(&mut gradients[0], [0, 1], 4.0);
        set_f32(&mut gradients[1], [0, 0], 6.0);
        set_f32(&mut gradients[1], [0, 1], 8.0);

        // Global norm = sqrt(5^2 + 10^2) = sqrt(125) ≈ 11.1803.
        let expected_norm = 125.0_f32.sqrt();
        let norm =
            GradientUtils::clip_grad_norm(&mut gradients, 5.0).expect("operation failed in test");
        assert!((norm - expected_norm).abs() < 1e-4);

        // After clipping, every element is scaled by max_norm / total_norm.
        let coef = 5.0 / expected_norm;
        assert!((get_f32(&gradients[0], [0, 0]) - 3.0 * coef).abs() < 1e-5);
        assert!((get_f32(&gradients[0], [0, 1]) - 4.0 * coef).abs() < 1e-5);
        assert!((get_f32(&gradients[1], [0, 0]) - 6.0 * coef).abs() < 1e-5);
        assert!((get_f32(&gradients[1], [0, 1]) - 8.0 * coef).abs() < 1e-5);
    }

    #[test]
    fn test_clip_grad_norm_empty() {
        // An empty gradient list is a no-op and reports zero norm.
        let mut gradients: Vec<Tensor> = Vec::new();
        let norm =
            GradientUtils::clip_grad_norm(&mut gradients, 1.0).expect("operation failed in test");
        assert_eq!(norm, 0.0);
    }

    #[test]
    fn test_clip_grad_value() {
        let mut gradients = vec![Tensor::zeros(&[2, 2]).expect("tensor operation failed")];

        // Two values exceed clip_value in magnitude; two are within range.
        set_f32(&mut gradients[0], [0, 0], 10.0);
        set_f32(&mut gradients[0], [0, 1], -15.0);
        set_f32(&mut gradients[0], [1, 0], 2.0);
        set_f32(&mut gradients[0], [1, 1], -3.0);

        GradientUtils::clip_grad_value(&mut gradients, 5.0).expect("operation failed in test");

        assert_eq!(get_f32(&gradients[0], [0, 0]), 5.0); // 10.0 clipped to 5.0
        assert_eq!(get_f32(&gradients[0], [0, 1]), -5.0); // -15.0 clipped to -5.0
        assert_eq!(get_f32(&gradients[0], [1, 0]), 2.0); // unchanged
        assert_eq!(get_f32(&gradients[0], [1, 1]), -3.0); // unchanged
    }
}
181}