tensorlogic_train/optimizers/sgd.rs

use super::common::{compute_gradient_norm, GradClipMode, Optimizer, OptimizerConfig};
use crate::{TrainError, TrainResult};
use scirs2_core::ndarray::{Array, Ix2};
use std::collections::HashMap;
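
/// Stochastic gradient descent (SGD) optimizer with classical momentum and
/// optional gradient clipping, as configured through `OptimizerConfig`.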
#[derive(Debug)]
pub struct SgdOptimizer {
    config: OptimizerConfig,
    velocity: HashMap<String, Array<f64, Ix2>>,
}

impl SgdOptimizer {
    /// Creates a new SGD optimizer from the given configuration.
    pub fn new(config: OptimizerConfig) -> Self {
        Self {
            config,
            velocity: HashMap::new(),
        }
    }

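    /// Clips gradients in place, either element-wise by value or by rescaling
    /// every gradient when the global L2 norm exceeds the configured threshold.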
    fn clip_gradients(&self, gradients: &mut HashMap<String, Array<f64, Ix2>>) {
        if let Some(clip_value) = self.config.grad_clip {
            match self.config.grad_clip_mode {
                GradClipMode::Value => {
                    // Clamp each gradient element into [-clip_value, clip_value].
                    for grad in gradients.values_mut() {
                        grad.mapv_inplace(|g| g.max(-clip_value).min(clip_value));
                    }
                }
                GradClipMode::Norm => {
                    let total_norm = compute_gradient_norm(gradients);

                    // Rescale so the total norm ends up exactly at clip_value.
                    if total_norm > clip_value {
                        let scale = clip_value / total_norm;
                        for grad in gradients.values_mut() {
                            grad.mapv_inplace(|g| g * scale);
                        }
                    }
                }
            }
        }
    }
}

impl Optimizer for SgdOptimizer {
    fn step(
        &mut self,
        parameters: &mut HashMap<String, Array<f64, Ix2>>,
        gradients: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<()> {
        // Work on a clipped copy so the caller's gradients are left untouched.
        let mut clipped_gradients = gradients.clone();
        self.clip_gradients(&mut clipped_gradients);

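        // Apply the momentum update to every parameter that has a gradient.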
        for (name, param) in parameters.iter_mut() {
            let grad = clipped_gradients.get(name).ok_or_else(|| {
                TrainError::OptimizerError(format!("Missing gradient for parameter: {}", name))
            })?;

            // Lazily create a zero velocity buffer the first time a parameter is seen.
            if !self.velocity.contains_key(name) {
                self.velocity
                    .insert(name.clone(), Array::zeros(param.raw_dim()));
            }

            let velocity = self.velocity.get_mut(name).unwrap();

            // Classical momentum: v <- momentum * v + lr * g, then p <- p - v.
            velocity.mapv_inplace(|v| self.config.momentum * v);
            *velocity = &*velocity + &(grad * self.config.learning_rate);

            *param = &*param - &*velocity;
        }

        Ok(())
    }

    fn zero_grad(&mut self) {
        // SGD keeps no per-step gradient state of its own, so this is a no-op.
    }

    fn get_lr(&self) -> f64 {
        self.config.learning_rate
    }

    fn set_lr(&mut self, lr: f64) {
        self.config.learning_rate = lr;
    }

    fn state_dict(&self) -> HashMap<String, Vec<f64>> {
        // Flatten each velocity buffer into a vector keyed by "velocity_<name>".
        let mut state = HashMap::new();
        for (name, velocity) in &self.velocity {
            state.insert(
                format!("velocity_{}", name),
                velocity.iter().copied().collect(),
            );
        }
        state
    }

    fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) {
        for (key, values) in state {
            if let Some(name) = key.strip_prefix("velocity_") {
                // The existing buffer supplies the shape, so a velocity can only be
                // restored for parameters that already have one (e.g. after a step).
                if let Some(velocity) = self.velocity.get(name) {
                    let shape = velocity.raw_dim();
                    if let Ok(new_velocity) = Array::from_shape_vec(shape, values) {
                        self.velocity.insert(name.to_string(), new_velocity);
                    }
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::array;

    #[test]
    fn test_sgd_optimizer() {
        let config = OptimizerConfig {
            learning_rate: 0.1,
            momentum: 0.9,
            ..Default::default()
        };
        let mut optimizer = SgdOptimizer::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.1]]);

        optimizer.step(&mut params, &grads).unwrap();

        // Positive gradients should move both weights downhill.
        let w = params.get("w").unwrap();
        assert!(w[[0, 0]] < 1.0);
        assert!(w[[0, 1]] < 2.0);

        // The velocity buffer created during the step appears in the state dict.
        let state = optimizer.state_dict();
        assert!(state.contains_key("velocity_w"));
    }

    #[test]
    fn test_gradient_clipping() {
        let config = OptimizerConfig {
            learning_rate: 0.1,
            grad_clip: Some(0.05),
            grad_clip_mode: GradClipMode::Value,
            ..Default::default()
        };
        let mut optimizer = SgdOptimizer::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[1.0]]);

        optimizer.step(&mut params, &grads).unwrap();

        // The raw gradient of 1.0 is clipped to 0.05, so the update is at most
        // learning_rate * 0.05 = 0.005, well inside the 0.1 tolerance.
        let w = params.get("w").unwrap();
        assert!((w[[0, 0]] - 1.0).abs() < 0.1);
    }
}