
tensorlogic_train/optimizers/lion.rs

//! Lion (EvoLved Sign Momentum) optimizer.
//!
//! Paper: "Symbolic Discovery of Optimization Algorithms" (Chen et al., 2023)
//! <https://arxiv.org/abs/2302.06675>
//!
//! Lion is a simple yet effective optimizer that:
//! - Uses only the sign of the interpolated momentum for updates
//! - Has fewer hyperparameters than AdamW (learning rate, two betas, weight decay; no epsilon)
//! - Requires less memory than Adam (no second-moment buffer)
//! - Often achieves better performance with larger batch sizes
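//!
//! A minimal usage sketch (the import path below is assumed from this file's
//! location in the crate and is not verified here):
//!
//! ```ignore
//! use std::collections::HashMap;
//!
//! use scirs2_core::ndarray::Array1;
//! use tensorlogic_train::optimizers::lion::{LionConfig, LionOptimizer};
//!
//! // Default hyperparameters: lr = 1e-4, beta1 = 0.9, beta2 = 0.99, no weight decay.
//! let mut optimizer = LionOptimizer::new(LionConfig::default()).unwrap();
//!
//! // Parameters and gradients are keyed by name.
//! let mut params = HashMap::from([("w".to_string(), Array1::from_vec(vec![1.0, 2.0]))]);
//! let gradients = HashMap::from([("w".to_string(), Array1::from_vec(vec![0.1, -0.2]))]);
//!
//! // One Lion step: each coordinate moves by exactly lr, in the -sign(c_t) direction.
//! optimizer.step(&mut params, &gradients).unwrap();
//! ```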

use crate::error::TrainResult;
use scirs2_core::ndarray::Array1;
use std::collections::HashMap;

/// Lion optimizer configuration.
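///
/// Individual fields can be overridden with struct-update syntax; a small
/// sketch (the values here are illustrative, not tuned recommendations):
///
/// ```ignore
/// let config = LionConfig {
///     learning_rate: 3e-5,  // Lion typically wants a smaller LR than Adam
///     weight_decay: 0.1,    // decoupled weight decay, applied inside the update
///     ..Default::default()  // keeps beta1 = 0.9, beta2 = 0.99
/// };
/// ```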
#[derive(Debug, Clone)]
pub struct LionConfig {
    /// Learning rate (default: 1e-4, typically 3-10x smaller than Adam)
    pub learning_rate: f64,
    /// Momentum coefficient for update direction (default: 0.9)
    pub beta1: f64,
    /// Momentum coefficient for state update (default: 0.99)
    pub beta2: f64,
    /// Weight decay coefficient (default: 0.0)
    pub weight_decay: f64,
}

impl Default for LionConfig {
    fn default() -> Self {
        Self {
            learning_rate: 1e-4,
            beta1: 0.9,
            beta2: 0.99,
            weight_decay: 0.0,
        }
    }
}

/// Lion optimizer.
///
/// Update rule:
/// 1. c_t = β1 * m_{t-1} + (1 - β1) * g_t  (interpolation)
/// 2. θ_t = θ_{t-1} - lr * (sign(c_t) + λ * θ_{t-1})  (parameter update with weight decay)
/// 3. m_t = β2 * m_{t-1} + (1 - β2) * g_t  (momentum update)
///
/// Key differences from Adam:
/// - Uses sign(momentum) instead of normalized gradients
/// - Only tracks the first moment (momentum), no second moment
/// - Typically requires smaller learning rates than Adam
/// - More memory efficient
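///
/// # Example
///
/// Because only the sign of the interpolated momentum is used, gradient
/// magnitude does not change the step size: with zero weight decay every
/// coordinate moves by exactly `lr`. A small sketch of that behaviour
/// (mirrors the sign-based test below; not a tuned configuration):
///
/// ```ignore
/// use std::collections::HashMap;
/// use scirs2_core::ndarray::Array1;
///
/// let mut optimizer = LionOptimizer::new(LionConfig {
///     learning_rate: 1e-2,
///     beta1: 0.0, // disable interpolation so c_t = g_t
///     beta2: 0.0,
///     weight_decay: 0.0,
/// })
/// .unwrap();
///
/// let mut params = HashMap::from([("w".to_string(), Array1::from_vec(vec![1.0, 1.0]))]);
/// let gradients = HashMap::from([("w".to_string(), Array1::from_vec(vec![0.1, 100.0]))]);
///
/// optimizer.step(&mut params, &gradients).unwrap();
///
/// // Both coordinates decreased by exactly lr = 0.01 despite a 1000x gradient gap.
/// let w = params.get("w").unwrap();
/// assert!((w[0] - 0.99).abs() < 1e-12);
/// assert!((w[1] - 0.99).abs() < 1e-12);
/// ```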
pub struct LionOptimizer {
    config: LionConfig,
    /// Momentum buffers (first moment)
    momentum: HashMap<String, Array1<f64>>,
}

impl LionOptimizer {
    /// Create a new Lion optimizer.
    pub fn new(config: LionConfig) -> TrainResult<Self> {
        if config.learning_rate <= 0.0 {
            return Err(crate::error::TrainError::ConfigError(
                "Learning rate must be positive".to_string(),
            ));
        }
        if !(0.0..1.0).contains(&config.beta1) {
            return Err(crate::error::TrainError::ConfigError(
                "beta1 must be in [0, 1)".to_string(),
            ));
        }
        if !(0.0..1.0).contains(&config.beta2) {
            return Err(crate::error::TrainError::ConfigError(
                "beta2 must be in [0, 1)".to_string(),
            ));
        }
        if config.weight_decay < 0.0 {
            return Err(crate::error::TrainError::ConfigError(
                "weight_decay must be non-negative".to_string(),
            ));
        }

        Ok(Self {
            config,
            momentum: HashMap::new(),
        })
    }

    /// Perform a single optimization step.
    ///
    /// Parameters that have no matching gradient entry are left unchanged.
    pub fn step(
        &mut self,
        params: &mut HashMap<String, Array1<f64>>,
        gradients: &HashMap<String, Array1<f64>>,
    ) -> TrainResult<()> {
        // Copy hyperparameters into locals so the closures below don't need to
        // borrow `self` while a momentum buffer is mutably borrowed.
        let lr = self.config.learning_rate;
        let beta1 = self.config.beta1;
        let beta2 = self.config.beta2;
        let weight_decay = self.config.weight_decay;

        for (name, param) in params.iter_mut() {
            if let Some(grad) = gradients.get(name) {
                // Initialize the momentum buffer lazily with zeros
                let momentum = self
                    .momentum
                    .entry(name.clone())
                    .or_insert_with(|| Array1::zeros(param.len()));

                // Step 1: Interpolate for the update direction
                // c_t = β1 * m_{t-1} + (1 - β1) * g_t
                let update_direction =
                    momentum.mapv(|m| m * beta1) + grad.mapv(|g| g * (1.0 - beta1));

                // Step 2: Parameter update using the sign of the update direction
                // θ_t = θ_{t-1} - lr * (sign(c_t) + λ * θ_{t-1})
                for i in 0..param.len() {
                    let sign_update = if update_direction[i] > 0.0 {
                        1.0
                    } else if update_direction[i] < 0.0 {
                        -1.0
                    } else {
                        0.0
                    };

                    let update = sign_update + weight_decay * param[i];
                    param[i] -= lr * update;
                }

                // Step 3: Update momentum
                // m_t = β2 * m_{t-1} + (1 - β2) * g_t
                *momentum =
                    momentum.mapv(|m| m * beta2) + grad.mapv(|g| g * (1.0 - beta2));
            }
        }

        Ok(())
    }

    /// Get the current learning rate.
    pub fn get_lr(&self) -> f64 {
        self.config.learning_rate
    }

    /// Set the learning rate.
    pub fn set_lr(&mut self, lr: f64) {
        self.config.learning_rate = lr;
    }

    /// Get optimizer state for checkpointing.
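    ///
    /// Momentum buffers are returned in a flat map keyed as `momentum.<param_name>`,
    /// which is the format [`Self::load_state_dict`] expects back. A minimal
    /// round-trip sketch (assumes an optimizer that has already taken a step):
    ///
    /// ```ignore
    /// let state = optimizer.state_dict();        // e.g. contains "momentum.w"
    /// let mut restored = LionOptimizer::new(LionConfig::default()).unwrap();
    /// restored.load_state_dict(&state).unwrap(); // momentum buffers now match
    /// ```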
    pub fn state_dict(&self) -> HashMap<String, Vec<f64>> {
        self.momentum
            .iter()
            .map(|(k, v)| (format!("momentum.{}", k), v.to_vec()))
            .collect()
    }

    /// Load optimizer state from checkpoint.
    pub fn load_state_dict(&mut self, state: &HashMap<String, Vec<f64>>) -> TrainResult<()> {
        for (key, value) in state {
            if let Some(param_name) = key.strip_prefix("momentum.") {
                self.momentum
                    .insert(param_name.to_string(), Array1::from_vec(value.clone()));
            }
        }
        Ok(())
    }

    /// Reset optimizer state.
    pub fn reset(&mut self) {
        self.momentum.clear();
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;
    use std::collections::HashMap;

    #[test]
    fn test_lion_optimizer() {
        let config = LionConfig::default();
        let mut optimizer = LionOptimizer::new(config).unwrap();

        let mut params = HashMap::new();
        params.insert("w".to_string(), Array1::from_vec(vec![1.0, 2.0, 3.0]));

        let mut gradients = HashMap::new();
        gradients.insert("w".to_string(), Array1::from_vec(vec![0.1, 0.2, 0.3]));

        // Perform optimization step
        optimizer.step(&mut params, &gradients).unwrap();

        // Parameters should have changed
        let w = params.get("w").unwrap();
        assert!(w[0] < 1.0);
        assert!(w[1] < 2.0);
        assert!(w[2] < 3.0);
    }

    #[test]
    fn test_lion_with_weight_decay() {
        let config = LionConfig {
            learning_rate: 1e-3,
            beta1: 0.9,
            beta2: 0.99,
            weight_decay: 0.01,
        };
        let mut optimizer = LionOptimizer::new(config).unwrap();

        let mut params = HashMap::new();
        params.insert("w".to_string(), Array1::from_vec(vec![1.0, 1.0]));

        let mut gradients = HashMap::new();
        gradients.insert("w".to_string(), Array1::from_vec(vec![0.1, 0.1]));

        let initial_w = params.get("w").unwrap()[0];

        optimizer.step(&mut params, &gradients).unwrap();

        let updated_w = params.get("w").unwrap()[0];
        // Both the sign update and the decoupled weight decay push the positive
        // parameter downward, so it must decrease.
        assert!(updated_w < initial_w);
    }

    #[test]
    fn test_lion_sign_based_update() {
        let config = LionConfig {
            learning_rate: 1e-2,
            beta1: 0.0, // No momentum for clearer test
            beta2: 0.0,
            weight_decay: 0.0,
        };
        let mut optimizer = LionOptimizer::new(config).unwrap();

        let mut params = HashMap::new();
        params.insert("w".to_string(), Array1::from_vec(vec![1.0, 1.0, 1.0]));

        let mut gradients = HashMap::new();
        gradients.insert(
            "w".to_string(),
            Array1::from_vec(vec![0.1, 1.0, 100.0]), // Different magnitudes
        );

        optimizer.step(&mut params, &gradients).unwrap();

        let w = params.get("w").unwrap();
        // All updates should be the same magnitude (sign-based)
        let delta0 = 1.0 - w[0];
        let delta1 = 1.0 - w[1];
        let delta2 = 1.0 - w[2];

        assert!((delta0 - delta1).abs() < 1e-10);
        assert!((delta1 - delta2).abs() < 1e-10);
    }

    #[test]
    fn test_lion_state_dict() {
        let config = LionConfig::default();
        let mut optimizer = LionOptimizer::new(config).unwrap();

        let mut params = HashMap::new();
        params.insert("w".to_string(), Array1::from_vec(vec![1.0, 2.0]));

        let mut gradients = HashMap::new();
        gradients.insert("w".to_string(), Array1::from_vec(vec![0.1, 0.2]));

        optimizer.step(&mut params, &gradients).unwrap();

        // Save state
        let state = optimizer.state_dict();
        assert!(state.contains_key("momentum.w"));

        // Create new optimizer and load state
        let mut optimizer2 = LionOptimizer::new(LionConfig::default()).unwrap();
        optimizer2.load_state_dict(&state).unwrap();

        // States should match
        assert_eq!(
            optimizer.momentum.get("w").unwrap().to_vec(),
            optimizer2.momentum.get("w").unwrap().to_vec()
        );
    }

    #[test]
    fn test_lion_lr_schedule() {
        let config = LionConfig::default();
        let mut optimizer = LionOptimizer::new(config).unwrap();

        assert!((optimizer.get_lr() - 1e-4).abs() < 1e-10);

        optimizer.set_lr(1e-3);
        assert!((optimizer.get_lr() - 1e-3).abs() < 1e-10);
    }

    #[test]
    fn test_lion_invalid_config() {
        let config = LionConfig {
            learning_rate: -1.0,
            ..Default::default()
        };
        assert!(LionOptimizer::new(config).is_err());

        let config = LionConfig {
            beta1: 1.5,
            ..Default::default()
        };
        assert!(LionOptimizer::new(config).is_err());

        let config = LionConfig {
            beta2: -0.1,
            ..Default::default()
        };
        assert!(LionOptimizer::new(config).is_err());

        let config = LionConfig {
            weight_decay: -0.1,
            ..Default::default()
        };
        assert!(LionOptimizer::new(config).is_err());
    }

    #[test]
    fn test_lion_reset() {
        let config = LionConfig::default();
        let mut optimizer = LionOptimizer::new(config).unwrap();

        let mut params = HashMap::new();
        params.insert("w".to_string(), Array1::from_vec(vec![1.0]));

        let mut gradients = HashMap::new();
        gradients.insert("w".to_string(), Array1::from_vec(vec![0.1]));

        optimizer.step(&mut params, &gradients).unwrap();
        assert!(!optimizer.momentum.is_empty());

        optimizer.reset();
        assert!(optimizer.momentum.is_empty());
    }
}