// tensorlogic_train/optimizers/sgd.rs
use super::common::{compute_gradient_norm, GradClipMode, Optimizer, OptimizerConfig};
use crate::{TrainError, TrainResult};
use scirs2_core::ndarray::{Array, Ix2};
use std::collections::HashMap;

/// Stochastic gradient descent optimizer with momentum and optional
/// gradient clipping, driven by an [`OptimizerConfig`].
#[derive(Debug)]
pub struct SgdOptimizer {
    // Hyperparameters: learning rate, momentum, clipping mode/threshold.
    config: OptimizerConfig,
    // Per-parameter momentum buffers, keyed by parameter name.
    // Lazily created by `step` the first time each parameter is seen.
    velocity: HashMap<String, Array<f64, Ix2>>,
}
15
16impl SgdOptimizer {
17 pub fn new(config: OptimizerConfig) -> Self {
19 Self {
20 config,
21 velocity: HashMap::new(),
22 }
23 }
24
25 fn clip_gradients(&self, gradients: &mut HashMap<String, Array<f64, Ix2>>) {
27 if let Some(clip_value) = self.config.grad_clip {
28 match self.config.grad_clip_mode {
29 GradClipMode::Value => {
30 for grad in gradients.values_mut() {
32 grad.mapv_inplace(|g| g.max(-clip_value).min(clip_value));
33 }
34 }
35 GradClipMode::Norm => {
36 let total_norm = compute_gradient_norm(gradients);
38
39 if total_norm > clip_value {
40 let scale = clip_value / total_norm;
41 for grad in gradients.values_mut() {
42 grad.mapv_inplace(|g| g * scale);
43 }
44 }
45 }
46 }
47 }
48 }
49}
50
51impl Optimizer for SgdOptimizer {
52 fn step(
53 &mut self,
54 parameters: &mut HashMap<String, Array<f64, Ix2>>,
55 gradients: &HashMap<String, Array<f64, Ix2>>,
56 ) -> TrainResult<()> {
57 let mut clipped_gradients = gradients.clone();
58 self.clip_gradients(&mut clipped_gradients);
59
60 for (name, param) in parameters.iter_mut() {
61 let grad = clipped_gradients.get(name).ok_or_else(|| {
62 TrainError::OptimizerError(format!("Missing gradient for parameter: {}", name))
63 })?;
64
65 if !self.velocity.contains_key(name) {
67 self.velocity
68 .insert(name.clone(), Array::zeros(param.raw_dim()));
69 }
70
71 let velocity = self
72 .velocity
73 .get_mut(name)
74 .expect("velocity initialized for all parameters");
75
76 velocity.mapv_inplace(|v| self.config.momentum * v);
78 *velocity = &*velocity + &(grad * self.config.learning_rate);
79
80 *param = &*param - &*velocity;
82 }
83
84 Ok(())
85 }
86
87 fn zero_grad(&mut self) {
88 }
90
91 fn get_lr(&self) -> f64 {
92 self.config.learning_rate
93 }
94
95 fn set_lr(&mut self, lr: f64) {
96 self.config.learning_rate = lr;
97 }
98
99 fn state_dict(&self) -> HashMap<String, Vec<f64>> {
100 let mut state = HashMap::new();
101 for (name, velocity) in &self.velocity {
102 state.insert(
103 format!("velocity_{}", name),
104 velocity.iter().copied().collect(),
105 );
106 }
107 state
108 }
109
110 fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) {
111 for (key, values) in state {
112 if let Some(name) = key.strip_prefix("velocity_") {
113 if let Some(velocity) = self.velocity.get(name) {
115 let shape = velocity.raw_dim();
116 if let Ok(new_velocity) = Array::from_shape_vec(shape, values) {
117 self.velocity.insert(name.to_string(), new_velocity);
118 }
119 }
120 }
121 }
122 }
123}
124
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::array;

    /// One momentum-SGD step should move both weights downhill and
    /// populate the serialized velocity state.
    #[test]
    fn test_sgd_optimizer() {
        let config = OptimizerConfig {
            learning_rate: 0.1,
            momentum: 0.9,
            ..Default::default()
        };
        let mut optimizer = SgdOptimizer::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.1]]);

        optimizer
            .step(&mut params, &grads)
            .expect("gradient present for every parameter");

        // Positive gradients with positive lr must decrease the weights.
        let w = params.get("w").expect("parameter `w` still present");
        assert!(w[[0, 0]] < 1.0);
        assert!(w[[0, 1]] < 2.0);

        let state = optimizer.state_dict();
        assert!(state.contains_key("velocity_w"));
    }

    /// With value clipping at 0.05, a unit gradient is clamped before the
    /// update, so the weight moves by at most lr * 0.05 = 0.005.
    #[test]
    fn test_gradient_clipping() {
        let config = OptimizerConfig {
            learning_rate: 0.1,
            grad_clip: Some(0.05),
            grad_clip_mode: GradClipMode::Value,
            ..Default::default()
        };
        let mut optimizer = SgdOptimizer::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[1.0]]);

        optimizer
            .step(&mut params, &grads)
            .expect("gradient present for every parameter");

        // Unclipped, the step would be 0.1; clipped it is 0.005.
        let w = params.get("w").expect("parameter `w` still present");
        assert!((w[[0, 0]] - 1.0).abs() < 0.1);
    }
}