yscv_optim/clip.rs
1use yscv_autograd::{Graph, NodeId};
2
3/// Clips the total norm of gradients for the given nodes in-place.
4///
5/// Computes the combined norm (controlled by `norm_type`, typically 2.0 for L2)
6/// across all gradient tensors for `node_ids`. If the total norm exceeds
7/// `max_norm`, every gradient is scaled by `max_norm / total_norm`.
8///
9/// Returns the computed total norm before clipping (useful for monitoring).
10///
11/// Nodes without gradients are silently skipped.
12/// If `max_norm` is not positive or `node_ids` is empty, no clipping is
13/// performed and the function returns 0.0.
14pub fn clip_grad_norm_(
15 graph: &mut Graph,
16 node_ids: &[NodeId],
17 max_norm: f32,
18 norm_type: f32,
19) -> f32 {
20 if node_ids.is_empty() || !max_norm.is_finite() || max_norm <= 0.0 {
21 return 0.0;
22 }
23
24 // Accumulate the total norm across all gradient tensors (read-only pass).
25 let mut total_norm: f32 = if norm_type == f32::INFINITY {
26 // Inf-norm: max absolute value across all gradients.
27 let mut max_val: f32 = 0.0;
28 for &id in node_ids {
29 if let Ok(Some(grad)) = graph.grad(id) {
30 for &v in grad.data() {
31 let abs = v.abs();
32 if abs > max_val {
33 max_val = abs;
34 }
35 }
36 }
37 }
38 max_val
39 } else {
40 // General p-norm: (sum |g_i|^p)^(1/p).
41 let mut acc: f32 = 0.0;
42 for &id in node_ids {
43 if let Ok(Some(grad)) = graph.grad(id) {
44 for &v in grad.data() {
45 acc += v.abs().powf(norm_type);
46 }
47 }
48 }
49 acc.powf(1.0 / norm_type)
50 };
51
52 if !total_norm.is_finite() {
53 total_norm = 0.0;
54 }
55
56 // Scale gradients in-place if total norm exceeds max_norm.
57 if total_norm > max_norm {
58 let scale = max_norm / total_norm;
59 for &id in node_ids {
60 if let Ok(Some(grad)) = graph.grad_mut(id) {
61 for v in grad.data_mut() {
62 *v *= scale;
63 }
64 }
65 }
66 }
67
68 total_norm
69}
70
71/// Clamps every gradient element to the range `[-max_val, max_val]` in-place.
72///
73/// Nodes without gradients are silently skipped.
74/// If `max_val` is not positive or `node_ids` is empty, no clamping is performed.
75pub fn clip_grad_value_(graph: &mut Graph, node_ids: &[NodeId], max_val: f32) {
76 if node_ids.is_empty() || !max_val.is_finite() || max_val <= 0.0 {
77 return;
78 }
79
80 for &id in node_ids {
81 if let Ok(Some(grad)) = graph.grad_mut(id) {
82 for v in grad.data_mut() {
83 if *v > max_val {
84 *v = max_val;
85 } else if *v < -max_val {
86 *v = -max_val;
87 }
88 }
89 }
90 }
91}