
tensorlogic_train/optimizers/lookahead.rs

//! Lookahead optimizer (wrapper that uses slow and fast weights).
//!
//! Lookahead maintains two sets of weights: fast weights updated by an inner optimizer,
//! and slow weights that are periodically updated as an exponential moving average of the
//! fast weights.
//!
//! Reference: Zhang et al., "Lookahead Optimizer: k steps forward, 1 step back", NeurIPS 2019
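//!
//! Update rule: every `k` inner-optimizer steps, the slow weights are pulled toward the
//! fast weights, `slow = slow + alpha * (fast - slow)`, and the fast weights are reset to
//! the new slow weights.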

use super::common::Optimizer;
use crate::{TrainError, TrainResult};
use scirs2_core::ndarray::{Array, Ix2};
use std::collections::HashMap;

/// Lookahead optimizer (wrapper that uses slow and fast weights).
///
/// Maintains two sets of weights: fast weights updated by an inner optimizer,
/// and slow weights that are periodically updated as an exponential moving average of the
/// fast weights.
///
/// Reference: Zhang et al., "Lookahead Optimizer: k steps forward, 1 step back", NeurIPS 2019
#[derive(Debug)]
pub struct LookaheadOptimizer<O: Optimizer> {
    /// Inner optimizer for fast weights.
    inner_optimizer: O,
    /// Slow weights (maintained separately).
    slow_weights: HashMap<String, Array<f64, Ix2>>,
    /// Interpolation coefficient (typically 0.5).
    alpha: f64,
    /// Number of inner optimizer steps before synchronization.
    k: usize,
    /// Current step counter.
    step_counter: usize,
}

impl<O: Optimizer> LookaheadOptimizer<O> {
    /// Create a new Lookahead optimizer.
    ///
    /// # Arguments
    /// * `inner_optimizer` - The inner optimizer (e.g., Adam, SGD)
    /// * `alpha` - Interpolation coefficient for slow weight update (typically 0.5)
    /// * `k` - Number of fast updates before slow weight synchronization (typically 5-10)
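    ///
    /// # Examples
    ///
    /// A minimal sketch (not compiled here; `AdamOptimizer` and `OptimizerConfig` are the
    /// types exercised in the tests below and are assumed to be in scope):
    ///
    /// ```ignore
    /// let inner = AdamOptimizer::new(OptimizerConfig::default());
    /// let mut optimizer = LookaheadOptimizer::new(inner, 0.5, 5).unwrap();
    /// // optimizer.step(&mut params, &grads)?; // fast update; syncs every 5th call
    /// ```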
    pub fn new(inner_optimizer: O, alpha: f64, k: usize) -> TrainResult<Self> {
        if !(0.0..=1.0).contains(&alpha) {
            return Err(TrainError::InvalidParameter(
                "alpha must be in [0, 1]".to_string(),
            ));
        }
        if k == 0 {
            return Err(TrainError::InvalidParameter(
                "k must be at least 1".to_string(),
            ));
        }
        Ok(Self {
            inner_optimizer,
            slow_weights: HashMap::new(),
            alpha,
            k,
            step_counter: 0,
        })
    }

    /// Initialize slow weights from current parameters.
    fn initialize_slow_weights(&mut self, parameters: &HashMap<String, Array<f64, Ix2>>) {
        if self.slow_weights.is_empty() {
            for (name, param) in parameters {
                self.slow_weights.insert(name.clone(), param.clone());
            }
        }
    }

    /// Synchronize slow weights with fast weights.
    fn synchronize_weights(&mut self, parameters: &mut HashMap<String, Array<f64, Ix2>>) {
        for (name, param) in parameters.iter_mut() {
            if let Some(slow_weight) = self.slow_weights.get_mut(name) {
                // Interpolate: slow = slow + alpha * (fast - slow).
                *slow_weight = &*slow_weight + &((&*param - &*slow_weight) * self.alpha);
                // Reset the fast weights to the updated slow weights ("1 step back").
                *param = slow_weight.clone();
            }
        }
    }
}

impl<O: Optimizer> Optimizer for LookaheadOptimizer<O> {
    fn step(
        &mut self,
        parameters: &mut HashMap<String, Array<f64, Ix2>>,
        gradients: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<()> {
        // On the first step, capture the initial parameters as the slow weights.
        self.initialize_slow_weights(parameters);
        // Fast update via the inner optimizer ("k steps forward").
        self.inner_optimizer.step(parameters, gradients)?;
        self.step_counter += 1;
        // Every k fast steps, synchronize the slow and fast weights.
        if self.step_counter.is_multiple_of(self.k) {
            self.synchronize_weights(parameters);
        }
        Ok(())
    }

    fn zero_grad(&mut self) {
        self.inner_optimizer.zero_grad();
    }

    fn get_lr(&self) -> f64 {
        self.inner_optimizer.get_lr()
    }

    fn set_lr(&mut self, lr: f64) {
        self.inner_optimizer.set_lr(lr);
    }

    fn state_dict(&self) -> HashMap<String, Vec<f64>> {
        let mut state = self.inner_optimizer.state_dict();
        state.insert("step_counter".to_string(), vec![self.step_counter as f64]);
        state.insert("alpha".to_string(), vec![self.alpha]);
        state.insert("k".to_string(), vec![self.k as f64]);
        // Flatten each slow-weight matrix into the state map under a "slow_" prefix.
        for (name, slow_weight) in &self.slow_weights {
            state.insert(
                format!("slow_{}", name),
                slow_weight.iter().copied().collect(),
            );
        }
        state
    }

    fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) {
        self.inner_optimizer.load_state_dict(state.clone());
        if let Some(counter) = state.get("step_counter") {
            self.step_counter = counter[0] as usize;
        }
        if let Some(alpha_val) = state.get("alpha") {
            self.alpha = alpha_val[0];
        }
        if let Some(k_val) = state.get("k") {
            self.k = k_val[0] as usize;
        }
        // Slow weights can only be restored for parameters whose shapes are already known,
        // i.e. names already present in `slow_weights`.
        for (key, values) in state {
            if let Some(name) = key.strip_prefix("slow_") {
                if let Some(slow_weight) = self.slow_weights.get(name) {
                    let shape = slow_weight.raw_dim();
                    if let Ok(arr) = Array::from_shape_vec(shape, values) {
                        self.slow_weights.insert(name.to_string(), arr);
                    }
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::super::adam::AdamOptimizer;
    use super::super::common::OptimizerConfig;
    use super::super::sgd::SgdOptimizer;
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_lookahead_optimizer() {
        let inner_config = OptimizerConfig {
            learning_rate: 0.01,
            ..Default::default()
        };
        let inner_optimizer = AdamOptimizer::new(inner_config);
        let mut optimizer = LookaheadOptimizer::new(inner_optimizer, 0.5, 5).unwrap();
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);
        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.1]]);
        for _ in 0..10 {
            optimizer.step(&mut params, &grads).unwrap();
        }
        let w = params.get("w").unwrap();
        assert!(w[[0, 0]] < 1.0);
        assert!(w[[0, 1]] < 2.0);
        assert_eq!(optimizer.get_lr(), 0.01);
        optimizer.set_lr(0.02);
        assert_eq!(optimizer.get_lr(), 0.02);
    }

    #[test]
    fn test_lookahead_invalid_alpha() {
        let inner_optimizer = AdamOptimizer::new(OptimizerConfig::default());
        let result = LookaheadOptimizer::new(inner_optimizer, 1.5, 5);
        assert!(result.is_err());
        let inner_optimizer = AdamOptimizer::new(OptimizerConfig::default());
        let result = LookaheadOptimizer::new(inner_optimizer, -0.1, 5);
        assert!(result.is_err());
    }

    #[test]
    fn test_lookahead_invalid_k() {
        let inner_optimizer = AdamOptimizer::new(OptimizerConfig::default());
        let result = LookaheadOptimizer::new(inner_optimizer, 0.5, 0);
        assert!(result.is_err());
    }

    #[test]
    fn test_lookahead_synchronization() {
        let inner_config = OptimizerConfig {
            learning_rate: 0.1,
            ..Default::default()
        };
        let inner_optimizer = SgdOptimizer::new(inner_config);
        let mut optimizer = LookaheadOptimizer::new(inner_optimizer, 0.5, 3).unwrap();
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0]]);
        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1]]);
        let initial_w = params.get("w").unwrap()[[0, 0]];
        for _ in 0..3 {
            optimizer.step(&mut params, &grads).unwrap();
        }
        let w_after_sync = params.get("w").unwrap()[[0, 0]];
        assert_ne!(w_after_sync, initial_w);
        assert!(w_after_sync < initial_w);
    }
}