// tensorlogic_train/optimizers/lookahead.rs

use super::common::Optimizer;
use crate::{TrainError, TrainResult};
use scirs2_core::ndarray::{Array, Ix2};
use std::collections::HashMap;

/// Lookahead optimizer: wraps an inner "fast" optimizer and keeps a set of slow
/// weights that are pulled toward the fast weights every `k` steps via
/// `slow <- slow + alpha * (fast - slow)`; the fast weights are then reset to `slow`.
#[derive(Debug)]
pub struct LookaheadOptimizer<O: Optimizer> {
    /// Wrapped fast optimizer that performs the per-step updates.
    inner_optimizer: O,
    /// Slow copy of each parameter, keyed by parameter name.
    slow_weights: HashMap<String, Array<f64, Ix2>>,
    /// Interpolation factor for the slow-weight update, in `[0, 1]`.
    alpha: f64,
    /// Number of inner steps between slow-weight synchronizations.
    k: usize,
    /// Inner optimizer steps taken so far.
    step_counter: usize,
}

impl<O: Optimizer> LookaheadOptimizer<O> {
    /// Create a Lookahead wrapper around `inner_optimizer`.
    ///
    /// `alpha` must lie in `[0, 1]` and `k` must be at least 1.
    pub fn new(inner_optimizer: O, alpha: f64, k: usize) -> TrainResult<Self> {
        if !(0.0..=1.0).contains(&alpha) {
            return Err(TrainError::InvalidParameter(
                "alpha must be in [0, 1]".to_string(),
            ));
        }
        if k == 0 {
            return Err(TrainError::InvalidParameter(
                "k must be at least 1".to_string(),
            ));
        }
        Ok(Self {
            inner_optimizer,
            slow_weights: HashMap::new(),
            alpha,
            k,
            step_counter: 0,
        })
    }

    /// Lazily clone the current parameters into the slow-weight store on first use.
    fn initialize_slow_weights(&mut self, parameters: &HashMap<String, Array<f64, Ix2>>) {
        if self.slow_weights.is_empty() {
            for (name, param) in parameters {
                self.slow_weights.insert(name.clone(), param.clone());
            }
        }
    }

    /// Apply the Lookahead update: move each slow weight toward its fast
    /// counterpart by `alpha`, then reset the fast weight to the slow weight.
    fn synchronize_weights(&mut self, parameters: &mut HashMap<String, Array<f64, Ix2>>) {
        for (name, param) in parameters.iter_mut() {
            if let Some(slow_weight) = self.slow_weights.get_mut(name) {
                // slow <- slow + alpha * (fast - slow)
                *slow_weight = &*slow_weight + &((&*param - &*slow_weight) * self.alpha);
                *param = slow_weight.clone();
            }
        }
    }
}

impl<O: Optimizer> Optimizer for LookaheadOptimizer<O> {
    fn step(
        &mut self,
        parameters: &mut HashMap<String, Array<f64, Ix2>>,
        gradients: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<()> {
        self.initialize_slow_weights(parameters);
        // Let the inner (fast) optimizer take its step, then synchronize the
        // slow weights every k-th step.
        self.inner_optimizer.step(parameters, gradients)?;
        self.step_counter += 1;
        if self.step_counter.is_multiple_of(self.k) {
            self.synchronize_weights(parameters);
        }
        Ok(())
    }

    fn zero_grad(&mut self) {
        self.inner_optimizer.zero_grad();
    }

    fn get_lr(&self) -> f64 {
        self.inner_optimizer.get_lr()
    }

    fn set_lr(&mut self, lr: f64) {
        self.inner_optimizer.set_lr(lr);
    }

    fn state_dict(&self) -> HashMap<String, Vec<f64>> {
        // Flatten Lookahead's own state on top of the inner optimizer's state.
        let mut state = self.inner_optimizer.state_dict();
        state.insert("step_counter".to_string(), vec![self.step_counter as f64]);
        state.insert("alpha".to_string(), vec![self.alpha]);
        state.insert("k".to_string(), vec![self.k as f64]);
        for (name, slow_weight) in &self.slow_weights {
            state.insert(
                format!("slow_{}", name),
                slow_weight.iter().copied().collect(),
            );
        }
        state
    }

    fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) {
        self.inner_optimizer.load_state_dict(state.clone());
        if let Some(counter) = state.get("step_counter") {
            self.step_counter = counter[0] as usize;
        }
        if let Some(alpha_val) = state.get("alpha") {
            self.alpha = alpha_val[0];
        }
        if let Some(k_val) = state.get("k") {
            self.k = k_val[0] as usize;
        }
        // Restore slow weights only for parameters whose shape is already known;
        // the flat Vec<f64> alone does not carry shape information.
        for (key, values) in state {
            if let Some(name) = key.strip_prefix("slow_") {
                if let Some(slow_weight) = self.slow_weights.get(name) {
                    let shape = slow_weight.raw_dim();
                    if let Ok(arr) = Array::from_shape_vec(shape, values) {
                        self.slow_weights.insert(name.to_string(), arr);
                    }
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::super::adam::AdamOptimizer;
    use super::super::common::OptimizerConfig;
    use super::super::sgd::SgdOptimizer;
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_lookahead_optimizer() {
        let inner_config = OptimizerConfig {
            learning_rate: 0.01,
            ..Default::default()
        };
        let inner_optimizer = AdamOptimizer::new(inner_config);
        let mut optimizer = LookaheadOptimizer::new(inner_optimizer, 0.5, 5).unwrap();
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);
        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.1]]);
        for _ in 0..10 {
            optimizer.step(&mut params, &grads).unwrap();
        }
        let w = params.get("w").unwrap();
        assert!(w[[0, 0]] < 1.0);
        assert!(w[[0, 1]] < 2.0);
        assert_eq!(optimizer.get_lr(), 0.01);
        optimizer.set_lr(0.02);
        assert_eq!(optimizer.get_lr(), 0.02);
    }
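
    // Sketch of a state round trip: Lookahead's own bookkeeping (step_counter,
    // alpha, k) should survive state_dict / load_state_dict. This assumes the
    // inner optimizer tolerates the extra Lookahead keys, which load_state_dict
    // already relies on; slow weights are only restored for known parameter names.
    #[test]
    fn test_lookahead_state_dict_roundtrip() {
        let inner_optimizer = AdamOptimizer::new(OptimizerConfig::default());
        let mut optimizer = LookaheadOptimizer::new(inner_optimizer, 0.5, 5).unwrap();
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);
        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.1]]);
        for _ in 0..3 {
            optimizer.step(&mut params, &grads).unwrap();
        }
        let state = optimizer.state_dict();

        // Restore into a wrapper constructed with different hyperparameters.
        let inner_optimizer = AdamOptimizer::new(OptimizerConfig::default());
        let mut restored = LookaheadOptimizer::new(inner_optimizer, 0.9, 10).unwrap();
        restored.load_state_dict(state);
        assert_eq!(restored.step_counter, 3);
        assert_eq!(restored.alpha, 0.5);
        assert_eq!(restored.k, 5);
    }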

    #[test]
    fn test_lookahead_invalid_alpha() {
        let inner_optimizer = AdamOptimizer::new(OptimizerConfig::default());
        let result = LookaheadOptimizer::new(inner_optimizer, 1.5, 5);
        assert!(result.is_err());
        let inner_optimizer = AdamOptimizer::new(OptimizerConfig::default());
        let result = LookaheadOptimizer::new(inner_optimizer, -0.1, 5);
        assert!(result.is_err());
    }

    #[test]
    fn test_lookahead_invalid_k() {
        let inner_optimizer = AdamOptimizer::new(OptimizerConfig::default());
        let result = LookaheadOptimizer::new(inner_optimizer, 0.5, 0);
        assert!(result.is_err());
    }
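
    // Sketch checking the interpolation rule in isolation, without depending on
    // the inner optimizer's update: after initialize_slow_weights, moving the
    // fast weights and calling synchronize_weights should give
    // slow = slow + alpha * (fast - slow), with the fast weights reset to slow.
    #[test]
    fn test_lookahead_slow_weight_interpolation() {
        let inner_optimizer = SgdOptimizer::new(OptimizerConfig::default());
        let mut optimizer = LookaheadOptimizer::new(inner_optimizer, 0.5, 2).unwrap();
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);
        optimizer.initialize_slow_weights(&params);
        // Pretend the fast weights moved to [0.0, 0.0] after k inner steps.
        params.insert("w".to_string(), array![[0.0, 0.0]]);
        optimizer.synchronize_weights(&mut params);
        let w = params.get("w").unwrap();
        // slow = 1.0 + 0.5 * (0.0 - 1.0) = 0.5 and 2.0 + 0.5 * (0.0 - 2.0) = 1.0
        assert!((w[[0, 0]] - 0.5).abs() < 1e-12);
        assert!((w[[0, 1]] - 1.0).abs() < 1e-12);
    }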

    #[test]
    fn test_lookahead_synchronization() {
        let inner_config = OptimizerConfig {
            learning_rate: 0.1,
            ..Default::default()
        };
        let inner_optimizer = SgdOptimizer::new(inner_config);
        let mut optimizer = LookaheadOptimizer::new(inner_optimizer, 0.5, 3).unwrap();
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0]]);
        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1]]);
        let initial_w = params.get("w").unwrap()[[0, 0]];
        for _ in 0..3 {
            optimizer.step(&mut params, &grads).unwrap();
        }
        let w_after_sync = params.get("w").unwrap()[[0, 0]];
        assert_ne!(w_after_sync, initial_w);
        assert!(w_after_sync < initial_w);
    }
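
    // Sketch of the synchronization schedule: after every k-th step the fast
    // weights should equal the stored slow weights (step resets them to the slow
    // copy), regardless of what the inner optimizer does in between.
    #[test]
    fn test_lookahead_syncs_every_k_steps() {
        let inner_config = OptimizerConfig {
            learning_rate: 0.1,
            ..Default::default()
        };
        let inner_optimizer = SgdOptimizer::new(inner_config);
        let mut optimizer = LookaheadOptimizer::new(inner_optimizer, 0.5, 3).unwrap();
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0]]);
        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1]]);
        for i in 1..=6 {
            optimizer.step(&mut params, &grads).unwrap();
            if i % 3 == 0 {
                let fast = params.get("w").unwrap();
                let slow = optimizer.slow_weights.get("w").unwrap();
                assert_eq!(fast, slow);
            }
        }
    }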
}