// trustformers_training/gradient.rs
use trustformers_core::errors::{tensor_op_error, Result};
2use trustformers_core::Tensor;
3
/// Stateless namespace for gradient post-processing helpers used by the
/// training loop: global-norm clipping, per-element value clipping,
/// gradient accumulation, and zeroing of accumulation buffers.
pub struct GradientUtils;
6
7impl GradientUtils {
8 pub fn clip_grad_norm(gradients: &mut Vec<Tensor>, max_norm: f32) -> Result<f32> {
20 if gradients.is_empty() {
21 return Ok(0.0);
22 }
23
24 let mut total_norm_sq = 0.0;
26 for grad in gradients.iter() {
27 let norm = grad.norm()?;
28 total_norm_sq += norm * norm;
29 }
30 let total_norm = total_norm_sq.sqrt();
31
32 if total_norm > max_norm {
34 let clip_coef = max_norm / total_norm;
35 for grad in gradients.iter_mut() {
36 *grad = grad.scale(clip_coef)?;
37 }
38 }
39
40 Ok(total_norm)
41 }
42
43 pub fn clip_grad_value(gradients: &mut Vec<Tensor>, clip_value: f32) -> Result<()> {
51 for grad in gradients.iter_mut() {
52 match grad {
53 Tensor::F32(arr) => {
54 arr.mapv_inplace(|x| x.clamp(-clip_value, clip_value));
55 },
56 Tensor::F64(arr) => {
57 let clip_value_f64 = clip_value as f64;
58 arr.mapv_inplace(|x| x.clamp(-clip_value_f64, clip_value_f64));
59 },
60 _ => {
61 return Err(tensor_op_error(
62 "gradient_clipping",
63 "Unsupported tensor type for gradient value clipping",
64 ))
65 },
66 }
67 }
68 Ok(())
69 }
70
71 pub fn accumulate_gradients(
81 accumulated_grads: &mut Vec<Tensor>,
82 new_grads: &[Tensor],
83 accumulation_steps: usize,
84 ) -> Result<()> {
85 if accumulated_grads.len() != new_grads.len() {
86 return Err(tensor_op_error(
87 "gradient_accumulation",
88 "mismatched number of gradients",
89 ));
90 }
91
92 let scale = 1.0 / accumulation_steps as f32;
93
94 for (acc_grad, new_grad) in accumulated_grads.iter_mut().zip(new_grads.iter()) {
95 let scaled_new_grad = new_grad.scale(scale)?;
96 *acc_grad = acc_grad.add(&scaled_new_grad)?;
97 }
98
99 Ok(())
100 }
101
102 pub fn zero_accumulated_gradients(accumulated_grads: &mut Vec<Tensor>) -> Result<()> {
104 for grad in accumulated_grads.iter_mut() {
105 match grad {
106 Tensor::F32(arr) => {
107 arr.fill(0.0);
108 },
109 Tensor::F64(arr) => {
110 arr.fill(0.0);
111 },
112 _ => {
113 return Err(tensor_op_error(
114 "gradient_zeroing",
115 "Unsupported tensor type for zeroing gradients",
116 ))
117 },
118 }
119 }
120 Ok(())
121 }
122}
123
#[cfg(test)]
mod tests {
    use super::*;
    use trustformers_core::Tensor;

    #[test]
    fn test_clip_grad_norm() {
        // Two gradients with per-tensor norms 5 and 10 give a global norm of
        // sqrt(25 + 100) = sqrt(125) ≈ 11.18, which exceeds max_norm = 5.
        let mut gradients = vec![
            Tensor::zeros(&[2, 2]).expect("tensor operation failed"),
            Tensor::zeros(&[3, 3]).expect("tensor operation failed"),
        ];

        if let Tensor::F32(ref mut arr) = gradients[0] {
            arr[[0, 0]] = 3.0;
            arr[[0, 1]] = 4.0;
        }

        if let Tensor::F32(ref mut arr) = gradients[1] {
            arr[[0, 0]] = 6.0;
            arr[[0, 1]] = 8.0;
        }

        let norm = GradientUtils::clip_grad_norm(&mut gradients, 5.0)
            .expect("operation failed in test");

        // The returned value is the pre-clip norm.
        assert!(norm > 11.0 && norm < 12.0);

        // After clipping, each element must have shrunk toward zero.
        if let Tensor::F32(ref arr) = gradients[0] {
            assert!(arr[[0, 0]] < 3.0);
            assert!(arr[[0, 1]] < 4.0);
        }
    }

    #[test]
    fn test_clip_grad_value() {
        let mut gradients = vec![Tensor::zeros(&[2, 2]).expect("tensor operation failed")];

        if let Tensor::F32(ref mut arr) = gradients[0] {
            arr[[0, 0]] = 10.0;
            arr[[0, 1]] = -15.0;
            arr[[1, 0]] = 2.0;
            arr[[1, 1]] = -3.0;
        }

        GradientUtils::clip_grad_value(&mut gradients, 5.0).expect("operation failed in test");

        // Out-of-range entries clamp to ±5; in-range entries are untouched.
        if let Tensor::F32(ref arr) = gradients[0] {
            assert_eq!(arr[[0, 0]], 5.0);
            assert_eq!(arr[[0, 1]], -5.0);
            assert_eq!(arr[[1, 0]], 2.0);
            assert_eq!(arr[[1, 1]], -3.0);
        }
    }
}