tensorlogic_train/optimizers/sgd.rs

//! SGD optimizer with momentum.

use super::common::{compute_gradient_norm, GradClipMode, Optimizer, OptimizerConfig};
use crate::{TrainError, TrainResult};
use scirs2_core::ndarray::{Array, Ix2};
use std::collections::HashMap;

/// SGD optimizer with momentum.
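///
/// A minimal usage sketch, mirroring the unit tests below (it assumes
/// `OptimizerConfig` implements `Default`, as those tests do):
///
/// ```ignore
/// use std::collections::HashMap;
/// use scirs2_core::array;
///
/// let config = OptimizerConfig {
///     learning_rate: 0.1,
///     momentum: 0.9,
///     ..Default::default()
/// };
/// let mut optimizer = SgdOptimizer::new(config);
///
/// let mut params = HashMap::new();
/// params.insert("w".to_string(), array![[1.0, 2.0]]);
/// let mut grads = HashMap::new();
/// grads.insert("w".to_string(), array![[0.1, 0.1]]);
///
/// // One step: v = momentum * v + lr * grad, then param -= v.
/// optimizer.step(&mut params, &grads).unwrap();
/// ```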
#[derive(Debug)]
pub struct SgdOptimizer {
    config: OptimizerConfig,
    /// Momentum buffers for each parameter.
    velocity: HashMap<String, Array<f64, Ix2>>,
}

impl SgdOptimizer {
    /// Create a new SGD optimizer.
    pub fn new(config: OptimizerConfig) -> Self {
        Self {
            config,
            velocity: HashMap::new(),
        }
    }

    /// Apply gradient clipping if configured.
    fn clip_gradients(&self, gradients: &mut HashMap<String, Array<f64, Ix2>>) {
        if let Some(clip_value) = self.config.grad_clip {
            match self.config.grad_clip_mode {
                GradClipMode::Value => {
                    // Clip by value (element-wise)
                    for grad in gradients.values_mut() {
                        grad.mapv_inplace(|g| g.max(-clip_value).min(clip_value));
                    }
                }
                GradClipMode::Norm => {
                    // Clip by global L2 norm
                    let total_norm = compute_gradient_norm(gradients);

                    if total_norm > clip_value {
                        let scale = clip_value / total_norm;
                        for grad in gradients.values_mut() {
                            grad.mapv_inplace(|g| g * scale);
                        }
                    }
                }
            }
        }
    }
}

impl Optimizer for SgdOptimizer {
    fn step(
        &mut self,
        parameters: &mut HashMap<String, Array<f64, Ix2>>,
        gradients: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<()> {
        let mut clipped_gradients = gradients.clone();
        self.clip_gradients(&mut clipped_gradients);

        for (name, param) in parameters.iter_mut() {
            let grad = clipped_gradients.get(name).ok_or_else(|| {
                TrainError::OptimizerError(format!("Missing gradient for parameter: {}", name))
            })?;

            // Initialize the velocity buffer if not present
            let velocity = self
                .velocity
                .entry(name.clone())
                .or_insert_with(|| Array::zeros(param.raw_dim()));

            // Update velocity: v = momentum * v + lr * grad
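            // Note: the learning rate is folded into the velocity buffer here,
            // so the parameter update below subtracts the velocity directly.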
            velocity.mapv_inplace(|v| self.config.momentum * v);
            *velocity = &*velocity + &(grad * self.config.learning_rate);

            // Update parameter: param = param - velocity
            *param = &*param - &*velocity;
        }

        Ok(())
    }

    fn zero_grad(&mut self) {
        // Gradients are managed externally, nothing to do here
    }

    fn get_lr(&self) -> f64 {
        self.config.learning_rate
    }

    fn set_lr(&mut self, lr: f64) {
        self.config.learning_rate = lr;
    }

    fn state_dict(&self) -> HashMap<String, Vec<f64>> {
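        // Flatten each velocity buffer into a Vec<f64> under the key "velocity_<name>".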
        let mut state = HashMap::new();
        for (name, velocity) in &self.velocity {
            state.insert(
                format!("velocity_{}", name),
                velocity.iter().copied().collect(),
            );
        }
        state
    }

    fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) {
        for (key, values) in state {
            if let Some(name) = key.strip_prefix("velocity_") {
                // Reconstruct array from values (assumes correct shape)
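                // Note: only parameters whose velocity buffer already exists are
                // restored, since the target shape is taken from the current buffer.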
                if let Some(velocity) = self.velocity.get(name) {
                    let shape = velocity.raw_dim();
                    if let Ok(new_velocity) = Array::from_shape_vec(shape, values) {
                        self.velocity.insert(name.to_string(), new_velocity);
                    }
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::array;

    #[test]
    fn test_sgd_optimizer() {
        let config = OptimizerConfig {
            learning_rate: 0.1,
            momentum: 0.9,
            ..Default::default()
        };
        let mut optimizer = SgdOptimizer::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.1]]);

        optimizer.step(&mut params, &grads).unwrap();

        let w = params.get("w").unwrap();
        assert!(w[[0, 0]] < 1.0); // Should decrease
        assert!(w[[0, 1]] < 2.0);

        // Test state dict
        let state = optimizer.state_dict();
        assert!(state.contains_key("velocity_w"));
    }
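
    // A hedged sketch of a test for norm-based clipping. It assumes that
    // `compute_gradient_norm` returns the global L2 norm of all gradients (as the
    // comment in `clip_gradients` indicates); with zero initial velocity, a single
    // step then applies an update of `lr * clipped_grad` regardless of momentum.
    #[test]
    fn test_gradient_clipping_by_norm() {
        let config = OptimizerConfig {
            learning_rate: 0.1,
            grad_clip: Some(1.0),
            grad_clip_mode: GradClipMode::Norm,
            ..Default::default()
        };
        let mut optimizer = SgdOptimizer::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 0.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[3.0, 4.0]]); // L2 norm = 5.0 > clip = 1.0

        optimizer.step(&mut params, &grads).unwrap();

        // After scaling by 1.0 / 5.0, the update norm should be at most lr * clip = 0.1.
        let w = params.get("w").unwrap();
        let dw0 = w[[0, 0]] - 1.0;
        let dw1 = w[[0, 1]];
        assert!((dw0 * dw0 + dw1 * dw1).sqrt() <= 0.1 + 1e-9);
    }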

    #[test]
    fn test_gradient_clipping() {
        let config = OptimizerConfig {
            learning_rate: 0.1,
            grad_clip: Some(0.05),
            grad_clip_mode: GradClipMode::Value,
            ..Default::default()
        };
        let mut optimizer = SgdOptimizer::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[1.0]]); // Large gradient

        optimizer.step(&mut params, &grads).unwrap();

        // The gradient is clipped to 0.05, so the update is only lr * 0.05 = 0.005
        let w = params.get("w").unwrap();
        assert!((w[[0, 0]] - 1.0).abs() < 0.1); // Small change due to clipping
174    }
175}