tensorlogic_train/optimizers/sgd.rs

//! SGD optimizer with momentum.

use super::common::{compute_gradient_norm, GradClipMode, Optimizer, OptimizerConfig};
use crate::{TrainError, TrainResult};
use scirs2_core::ndarray::{Array, Ix2};
use std::collections::HashMap;

/// SGD optimizer with momentum.
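///
/// Each parameter keeps a velocity buffer updated as `v = momentum * v + lr * g`,
/// after which the parameter moves by `-v`.
///
/// # Example
///
/// A minimal single-step sketch (marked `ignore` since the exact crate paths
/// may differ from what is shown here):
///
/// ```ignore
/// use std::collections::HashMap;
/// use scirs2_core::array;
///
/// let config = OptimizerConfig {
///     learning_rate: 0.1,
///     momentum: 0.9,
///     ..Default::default()
/// };
/// let mut optimizer = SgdOptimizer::new(config);
///
/// let mut params = HashMap::from([("w".to_string(), array![[1.0, 2.0]])]);
/// let grads = HashMap::from([("w".to_string(), array![[0.1, 0.1]])]);
///
/// // First step: v = lr * g = 0.01, so w becomes [[0.99, 1.99]].
/// optimizer.step(&mut params, &grads).unwrap();
/// ```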
#[derive(Debug)]
pub struct SgdOptimizer {
    config: OptimizerConfig,
    /// Momentum buffers for each parameter.
    velocity: HashMap<String, Array<f64, Ix2>>,
}

impl SgdOptimizer {
    /// Create a new SGD optimizer.
    pub fn new(config: OptimizerConfig) -> Self {
        Self {
            config,
            velocity: HashMap::new(),
        }
    }

    /// Apply gradient clipping if configured.
    fn clip_gradients(&self, gradients: &mut HashMap<String, Array<f64, Ix2>>) {
        if let Some(clip_value) = self.config.grad_clip {
            match self.config.grad_clip_mode {
                GradClipMode::Value => {
                    // Clip by value (element-wise)
                    for grad in gradients.values_mut() {
                        grad.mapv_inplace(|g| g.clamp(-clip_value, clip_value));
                    }
                }
                GradClipMode::Norm => {
                    // Clip by global L2 norm
                    let total_norm = compute_gradient_norm(gradients);

                    if total_norm > clip_value {
                        let scale = clip_value / total_norm;
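                        // Worked example: clip_value = 1.0 and total_norm = 4.0
                        // give scale = 0.25. Every element shrinks by the same
                        // factor, so the global norm lands exactly on clip_value
                        // and the gradient direction is preserved.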
                        for grad in gradients.values_mut() {
                            grad.mapv_inplace(|g| g * scale);
                        }
                    }
                }
            }
        }
    }
}

impl Optimizer for SgdOptimizer {
    fn step(
        &mut self,
        parameters: &mut HashMap<String, Array<f64, Ix2>>,
        gradients: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<()> {
        let mut clipped_gradients = gradients.clone();
        self.clip_gradients(&mut clipped_gradients);

        // Hoist hyperparameters into locals so the update closures below do
        // not need to borrow `self` while a velocity buffer is mutably borrowed.
        let lr = self.config.learning_rate;
        let momentum = self.config.momentum;

        for (name, param) in parameters.iter_mut() {
            let grad = clipped_gradients.get(name).ok_or_else(|| {
                TrainError::OptimizerError(format!("Missing gradient for parameter: {}", name))
            })?;

            // Lazily create the momentum buffer on first use.
            let velocity = self
                .velocity
                .entry(name.clone())
                .or_insert_with(|| Array::zeros(param.raw_dim()));

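            // Classical (heavy-ball) momentum; no Nesterov-style lookahead.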
            // Update velocity in place: v = momentum * v + lr * grad
            velocity.zip_mut_with(grad, |v, &g| *v = momentum * *v + lr * g);

            // Update parameter in place: param = param - velocity
            param.zip_mut_with(&*velocity, |p, &v| *p -= v);
        }

        Ok(())
    }

    fn zero_grad(&mut self) {
        // Gradients are managed externally; nothing to do here.
    }

    fn get_lr(&self) -> f64 {
        self.config.learning_rate
    }

    fn set_lr(&mut self, lr: f64) {
        self.config.learning_rate = lr;
    }

    fn state_dict(&self) -> HashMap<String, Vec<f64>> {
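        // Flatten each velocity buffer to a Vec<f64>. Shapes are not stored;
        // load_state_dict recovers them from the live buffers.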
        let mut state = HashMap::new();
        for (name, velocity) in &self.velocity {
            state.insert(
                format!("velocity_{}", name),
                velocity.iter().copied().collect(),
            );
        }
        state
    }

    fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) {
        for (key, values) in state {
            if let Some(name) = key.strip_prefix("velocity_") {
                // The flat state carries no shape information, so only buffers
                // whose shape is already known from an existing velocity entry
                // can be restored; mismatched or unknown entries are skipped.
                if let Some(velocity) = self.velocity.get(name) {
                    let shape = velocity.raw_dim();
                    if let Ok(new_velocity) = Array::from_shape_vec(shape, values) {
                        self.velocity.insert(name.to_string(), new_velocity);
                    }
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::array;

    #[test]
    fn test_sgd_optimizer() {
        let config = OptimizerConfig {
            learning_rate: 0.1,
            momentum: 0.9,
            ..Default::default()
        };
        let mut optimizer = SgdOptimizer::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.1]]);

        optimizer.step(&mut params, &grads).expect("step should succeed");

        // First step: v = lr * g = 0.01, so each parameter decreases by 0.01.
        let w = params.get("w").expect("param should exist");
        assert!(w[[0, 0]] < 1.0);
        assert!(w[[0, 1]] < 2.0);

        // The velocity buffer should appear in the state dict.
        let state = optimizer.state_dict();
        assert!(state.contains_key("velocity_w"));
    }
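
    // A sketch of a save/restore round trip. Note that load_state_dict can
    // only restore buffers whose shapes it already knows, so we take one
    // step first to create the velocity entry.
    #[test]
    fn test_state_dict_round_trip() {
        let config = OptimizerConfig {
            learning_rate: 0.1,
            momentum: 0.9,
            ..Default::default()
        };
        let mut optimizer = SgdOptimizer::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);
        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.1]]);
        optimizer.step(&mut params, &grads).expect("step should succeed");

        let saved = optimizer.state_dict();
        optimizer.load_state_dict(saved.clone());
        assert_eq!(optimizer.state_dict(), saved);
    }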

    #[test]
    fn test_gradient_clipping() {
        let config = OptimizerConfig {
            learning_rate: 0.1,
            grad_clip: Some(0.05),
            grad_clip_mode: GradClipMode::Value,
            ..Default::default()
        };
        let mut optimizer = SgdOptimizer::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[1.0]]); // Large gradient

        optimizer.step(&mut params, &grads).expect("step should succeed");

        // The gradient is clipped to 0.05, so the first-step update is
        // lr * 0.05 = 0.005 rather than lr * 1.0 = 0.1.
        let w = params.get("w").expect("param should exist");
        assert!((w[[0, 0]] - 1.0).abs() < 0.1);
    }
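
    // A sketch of a norm-mode clipping test, assuming compute_gradient_norm
    // returns the global L2 norm across all parameters (as the Norm branch
    // above expects).
    #[test]
    fn test_gradient_clipping_norm() {
        let config = OptimizerConfig {
            learning_rate: 1.0,
            momentum: 0.0,
            grad_clip: Some(1.0),
            grad_clip_mode: GradClipMode::Norm,
            ..Default::default()
        };
        let mut optimizer = SgdOptimizer::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[0.0, 0.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[3.0, 4.0]]); // L2 norm = 5.0

        optimizer.step(&mut params, &grads).expect("step should succeed");

        // Gradients are rescaled by 1.0 / 5.0, so with lr = 1.0 the update
        // is [0.6, 0.8] and the parameters become [-0.6, -0.8].
        let w = params.get("w").expect("param should exist");
        assert!((w[[0, 0]] + 0.6).abs() < 1e-12);
        assert!((w[[0, 1]] + 0.8).abs() < 1e-12);
    }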
}