tensorlogic_train/optimizers/lion.rs

use crate::error::TrainResult;
use scirs2_core::ndarray::Array1;
use std::collections::HashMap;

/// Configuration for the Lion optimizer.
#[derive(Debug, Clone)]
pub struct LionConfig {
    /// Step size applied to the sign-based update.
    pub learning_rate: f64,
    /// Coefficient for interpolating momentum and gradient into the update direction.
    pub beta1: f64,
    /// Coefficient for the exponential moving average of gradients (momentum).
    pub beta2: f64,
    /// Decoupled weight-decay coefficient; 0.0 disables decay.
    pub weight_decay: f64,
}

impl Default for LionConfig {
    fn default() -> Self {
        Self {
            learning_rate: 1e-4,
            beta1: 0.9,
            beta2: 0.99,
            weight_decay: 0.0,
        }
    }
}

/// Lion optimizer (EvoLved Sign Momentum), after Chen et al. (2023),
/// "Symbolic Discovery of Optimization Algorithms". Each step updates
/// parameters with the sign of an interpolated momentum:
///
///   c_t     = beta1 * m_{t-1} + (1 - beta1) * g_t
///   theta_t = theta_{t-1} - lr * (sign(c_t) + weight_decay * theta_{t-1})
///   m_t     = beta2 * m_{t-1} + (1 - beta2) * g_t
pub struct LionOptimizer {
    config: LionConfig,
    /// Per-parameter momentum buffers, keyed by parameter name.
    momentum: HashMap<String, Array1<f64>>,
}

impl LionOptimizer {
    /// Creates a new Lion optimizer after validating the configuration.
    pub fn new(config: LionConfig) -> TrainResult<Self> {
        if config.learning_rate <= 0.0 {
            return Err(crate::error::TrainError::ConfigError(
                "Learning rate must be positive".to_string(),
            ));
        }
        if !(0.0..1.0).contains(&config.beta1) {
            return Err(crate::error::TrainError::ConfigError(
                "beta1 must be in [0, 1)".to_string(),
            ));
        }
        if !(0.0..1.0).contains(&config.beta2) {
            return Err(crate::error::TrainError::ConfigError(
                "beta2 must be in [0, 1)".to_string(),
            ));
        }
        if config.weight_decay < 0.0 {
            return Err(crate::error::TrainError::ConfigError(
                "weight_decay must be non-negative".to_string(),
            ));
        }

        Ok(Self {
            config,
            momentum: HashMap::new(),
        })
    }

    /// Applies one Lion update to every parameter that has a matching
    /// gradient; parameters without a gradient entry are left unchanged.
    pub fn step(
        &mut self,
        params: &mut HashMap<String, Array1<f64>>,
        gradients: &HashMap<String, Array1<f64>>,
    ) -> TrainResult<()> {
        for (name, param) in params.iter_mut() {
            if let Some(grad) = gradients.get(name) {
                // Lazily initialize the momentum buffer to zeros.
                let momentum = self
                    .momentum
                    .entry(name.clone())
                    .or_insert_with(|| Array1::zeros(param.len()));

                // Update direction: c_t = beta1 * m_{t-1} + (1 - beta1) * g_t.
                let update_direction = momentum.mapv(|m| m * self.config.beta1)
                    + grad.mapv(|g| g * (1.0 - self.config.beta1));

                for i in 0..param.len() {
                    // Only the sign of the direction is used, so the step
                    // magnitude is independent of the gradient scale. Note
                    // sign(0) = 0 here; f64::signum would map +/-0.0 to +/-1.0.
                    let sign_update = if update_direction[i] > 0.0 {
                        1.0
                    } else if update_direction[i] < 0.0 {
                        -1.0
                    } else {
                        0.0
                    };

                    // Decoupled weight decay, then the step:
                    // theta_t = theta_{t-1} - lr * (sign(c_t) + wd * theta_{t-1}).
                    let update = sign_update + self.config.weight_decay * param[i];
                    param[i] -= self.config.learning_rate * update;
                }

                // Momentum EMA: m_t = beta2 * m_{t-1} + (1 - beta2) * g_t.
                *momentum = momentum.mapv(|m| m * self.config.beta2)
                    + grad.mapv(|g| g * (1.0 - self.config.beta2));
            }
        }

        Ok(())
    }

    /// Returns the current learning rate.
    pub fn get_lr(&self) -> f64 {
        self.config.learning_rate
    }

    /// Sets the learning rate, e.g. from an external scheduler.
    pub fn set_lr(&mut self, lr: f64) {
        self.config.learning_rate = lr;
    }

    /// Serializes the momentum buffers under "momentum.<name>" keys.
    pub fn state_dict(&self) -> HashMap<String, Vec<f64>> {
        self.momentum
            .iter()
            .map(|(k, v)| (format!("momentum.{}", k), v.to_vec()))
            .collect()
    }

    /// Restores momentum buffers from a state dict produced by `state_dict`.
    /// Keys without the "momentum." prefix are ignored.
    pub fn load_state_dict(&mut self, state: &HashMap<String, Vec<f64>>) -> TrainResult<()> {
        for (key, value) in state {
            if let Some(param_name) = key.strip_prefix("momentum.") {
                self.momentum
                    .insert(param_name.to_string(), Array1::from_vec(value.clone()));
            }
        }
        Ok(())
    }

    /// Clears all accumulated momentum state.
    pub fn reset(&mut self) {
        self.momentum.clear();
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;
    use std::collections::HashMap;

    #[test]
    fn test_lion_optimizer() {
        let config = LionConfig::default();
        let mut optimizer = LionOptimizer::new(config).unwrap();

        let mut params = HashMap::new();
        params.insert("w".to_string(), Array1::from_vec(vec![1.0, 2.0, 3.0]));

        let mut gradients = HashMap::new();
        gradients.insert("w".to_string(), Array1::from_vec(vec![0.1, 0.2, 0.3]));

        optimizer.step(&mut params, &gradients).unwrap();

        // With positive gradients, every weight should decrease.
        let w = params.get("w").unwrap();
        assert!(w[0] < 1.0);
        assert!(w[1] < 2.0);
        assert!(w[2] < 3.0);
    }

    #[test]
    fn test_lion_with_weight_decay() {
        let config = LionConfig {
            learning_rate: 1e-3,
            beta1: 0.9,
            beta2: 0.99,
            weight_decay: 0.01,
        };
        let mut optimizer = LionOptimizer::new(config).unwrap();

        let mut params = HashMap::new();
        params.insert("w".to_string(), Array1::from_vec(vec![1.0, 1.0]));

        let mut gradients = HashMap::new();
        gradients.insert("w".to_string(), Array1::from_vec(vec![0.1, 0.1]));

        let initial_w = params.get("w").unwrap()[0];

        optimizer.step(&mut params, &gradients).unwrap();

        let updated_w = params.get("w").unwrap()[0];
        assert!(updated_w < initial_w);
    }
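
    // A small sketch of the decoupled weight decay in isolation: with a zero
    // gradient and no accumulated momentum, sign(c_t) is 0, so one step
    // multiplies each weight by (1 - lr * weight_decay).
    #[test]
    fn test_lion_weight_decay_only_shrinks_weights() {
        let config = LionConfig {
            learning_rate: 0.1,
            beta1: 0.9,
            beta2: 0.99,
            weight_decay: 0.5,
        };
        let mut optimizer = LionOptimizer::new(config).unwrap();

        let mut params = HashMap::new();
        params.insert("w".to_string(), Array1::from_vec(vec![1.0]));

        let mut gradients = HashMap::new();
        gradients.insert("w".to_string(), Array1::from_vec(vec![0.0]));

        optimizer.step(&mut params, &gradients).unwrap();

        // w = 1.0 * (1 - 0.1 * 0.5) = 0.95
        let w = params.get("w").unwrap();
        assert!((w[0] - 0.95).abs() < 1e-12);
    }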

    #[test]
    fn test_lion_sign_based_update() {
        let config = LionConfig {
            learning_rate: 1e-2,
            beta1: 0.0,
            beta2: 0.0,
            weight_decay: 0.0,
        };
        let mut optimizer = LionOptimizer::new(config).unwrap();

        let mut params = HashMap::new();
        params.insert("w".to_string(), Array1::from_vec(vec![1.0, 1.0, 1.0]));

        let mut gradients = HashMap::new();
        gradients.insert(
            "w".to_string(),
            Array1::from_vec(vec![0.1, 1.0, 100.0]),
        );

        optimizer.step(&mut params, &gradients).unwrap();

        // All three gradients are positive, so all three deltas equal lr.
        let w = params.get("w").unwrap();
        let delta0 = 1.0 - w[0];
        let delta1 = 1.0 - w[1];
        let delta2 = 1.0 - w[2];

        assert!((delta0 - delta1).abs() < 1e-10);
        assert!((delta1 - delta2).abs() < 1e-10);
    }

    #[test]
    fn test_lion_state_dict() {
        let config = LionConfig::default();
        let mut optimizer = LionOptimizer::new(config).unwrap();

        let mut params = HashMap::new();
        params.insert("w".to_string(), Array1::from_vec(vec![1.0, 2.0]));

        let mut gradients = HashMap::new();
        gradients.insert("w".to_string(), Array1::from_vec(vec![0.1, 0.2]));

        optimizer.step(&mut params, &gradients).unwrap();

        let state = optimizer.state_dict();
        assert!(state.contains_key("momentum.w"));

        // Round-trip the state through a fresh optimizer.
        let mut optimizer2 = LionOptimizer::new(LionConfig::default()).unwrap();
        optimizer2.load_state_dict(&state).unwrap();

        assert_eq!(
            optimizer.momentum.get("w").unwrap().to_vec(),
            optimizer2.momentum.get("w").unwrap().to_vec()
        );
    }

    #[test]
    fn test_lion_lr_schedule() {
        let config = LionConfig::default();
        let mut optimizer = LionOptimizer::new(config).unwrap();

        assert!((optimizer.get_lr() - 1e-4).abs() < 1e-10);

        optimizer.set_lr(1e-3);
        assert!((optimizer.get_lr() - 1e-3).abs() < 1e-10);
    }
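
    // A sketch of scheduler interaction: after set_lr, the next step moves
    // each weight by the new rate. With zero weight decay and a positive
    // update direction, the per-coordinate step is exactly the learning rate.
    #[test]
    fn test_lion_set_lr_affects_next_step() {
        let config = LionConfig {
            learning_rate: 0.1,
            beta1: 0.9,
            beta2: 0.99,
            weight_decay: 0.0,
        };
        let mut optimizer = LionOptimizer::new(config).unwrap();

        let mut params = HashMap::new();
        params.insert("w".to_string(), Array1::from_vec(vec![1.0]));

        let mut gradients = HashMap::new();
        gradients.insert("w".to_string(), Array1::from_vec(vec![1.0]));

        // First step at lr = 0.1: w = 1.0 - 0.1 = 0.9.
        optimizer.step(&mut params, &gradients).unwrap();
        assert!((params.get("w").unwrap()[0] - 0.9).abs() < 1e-12);

        // Second step at lr = 0.01: w = 0.9 - 0.01 = 0.89.
        optimizer.set_lr(0.01);
        optimizer.step(&mut params, &gradients).unwrap();
        assert!((params.get("w").unwrap()[0] - 0.89).abs() < 1e-12);
    }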

    #[test]
    fn test_lion_invalid_config() {
        let config = LionConfig {
            learning_rate: -1.0,
            ..Default::default()
        };
        assert!(LionOptimizer::new(config).is_err());

        let config = LionConfig {
            beta1: 1.5,
            ..Default::default()
        };
        assert!(LionOptimizer::new(config).is_err());

        let config = LionConfig {
            beta2: -0.1,
            ..Default::default()
        };
        assert!(LionOptimizer::new(config).is_err());

        let config = LionConfig {
            weight_decay: -0.1,
            ..Default::default()
        };
        assert!(LionOptimizer::new(config).is_err());
    }

    #[test]
    fn test_lion_reset() {
        let config = LionConfig::default();
        let mut optimizer = LionOptimizer::new(config).unwrap();

        let mut params = HashMap::new();
        params.insert("w".to_string(), Array1::from_vec(vec![1.0]));

        let mut gradients = HashMap::new();
        gradients.insert("w".to_string(), Array1::from_vec(vec![0.1]));

        optimizer.step(&mut params, &gradients).unwrap();
        assert!(!optimizer.momentum.is_empty());

        optimizer.reset();
        assert!(optimizer.momentum.is_empty());
    }
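
    // A sketch of the beta1 interpolation: after a large positive gradient,
    // the stored momentum can keep the update direction positive even when
    // the next gradient is slightly negative, since
    // c_t = beta1 * m_{t-1} + (1 - beta1) * g_t.
    #[test]
    fn test_lion_momentum_dominates_small_opposing_gradient() {
        let config = LionConfig {
            learning_rate: 0.1,
            beta1: 0.9,
            beta2: 0.99,
            weight_decay: 0.0,
        };
        let mut optimizer = LionOptimizer::new(config).unwrap();

        let mut params = HashMap::new();
        params.insert("w".to_string(), Array1::from_vec(vec![1.0]));

        // First step: strong positive gradient. w = 1.0 - 0.1 = 0.9, and the
        // momentum becomes m = (1 - 0.99) * 1.0, roughly 0.01.
        let mut gradients = HashMap::new();
        gradients.insert("w".to_string(), Array1::from_vec(vec![1.0]));
        optimizer.step(&mut params, &gradients).unwrap();

        // Second step: small negative gradient. The direction
        // c = 0.9 * 0.01 + 0.1 * (-0.05) = 0.004 is still positive,
        // so the weight keeps decreasing: w = 0.9 - 0.1 = 0.8.
        gradients.insert("w".to_string(), Array1::from_vec(vec![-0.05]));
        optimizer.step(&mut params, &gradients).unwrap();

        let w = params.get("w").unwrap();
        assert!((w[0] - 0.8).abs() < 1e-12);
    }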
}