1use std::collections::HashMap;
2
3use yscv_tensor::Tensor;
4
5use super::error::OptimError;
6use super::{Adagrad, Adam, AdamW, Lamb, Lars, RAdam, RmsProp, Sgd};
7
/// A single optimizer update applied to one parameter tensor.
///
/// Implemented (via `impl_step_optimizer!`) for every concrete optimizer in
/// this module so that wrappers such as `Lookahead` can drive any of them
/// through one uniform interface.
pub trait StepOptimizer {
    /// Applies one in-place update of `weights` using `grad`.
    ///
    /// `parameter_id` identifies the parameter across calls so stateful
    /// optimizers can keep per-parameter buffers (momentum, variance, ...).
    fn step(
        &mut self,
        parameter_id: u64,
        weights: &mut Tensor,
        grad: &Tensor,
    ) -> Result<(), OptimError>;
}
17
/// Forwards `StepOptimizer::step` to the inherent `step` method of each listed
/// optimizer type, avoiding one hand-written identical impl per optimizer.
macro_rules! impl_step_optimizer {
    ($($ty:ty),*) => {
        $(
            impl StepOptimizer for $ty {
                fn step(
                    &mut self,
                    parameter_id: u64,
                    weights: &mut Tensor,
                    grad: &Tensor,
                ) -> Result<(), OptimError> {
                    // `<$ty>::step` names the inherent method explicitly, so
                    // this call cannot resolve back to the trait method being
                    // defined here (which would recurse infinitely).
                    <$ty>::step(self, parameter_id, weights, grad)
                }
            }
        )*
    };
}

// Every concrete optimizer in this module gets the trait for free.
impl_step_optimizer!(Sgd, Adam, AdamW, RmsProp, Adagrad, RAdam, Lamb, Lars);
36
37#[derive(Debug, Clone)]
48pub struct Lookahead<O> {
49 inner: O,
50 alpha: f32,
51 k: usize,
52 step_count: usize,
53 slow_weights: HashMap<u64, Vec<f32>>,
54}
55
56impl<O: StepOptimizer> Lookahead<O> {
57 pub fn new(inner: O, alpha: f32, k: usize) -> Self {
62 Self {
63 inner,
64 alpha,
65 k,
66 step_count: 0,
67 slow_weights: HashMap::new(),
68 }
69 }
70
71 pub fn step(
77 &mut self,
78 parameter_id: u64,
79 weights: &mut Tensor,
80 grad: &Tensor,
81 ) -> Result<(), OptimError> {
82 self.inner.step(parameter_id, weights, grad)?;
84
85 self.slow_weights
87 .entry(parameter_id)
88 .or_insert_with(|| weights.data().to_vec());
89
90 self.step_count += 1;
91
92 if self.step_count.is_multiple_of(self.k) {
94 let slow = self
95 .slow_weights
96 .get_mut(¶meter_id)
97 .expect("slow weights must exist");
98 let fast = weights.data_mut();
99 for (s, f) in slow.iter_mut().zip(fast.iter_mut()) {
100 *s += self.alpha * (*f - *s);
101 *f = *s;
102 }
103 }
104
105 Ok(())
106 }
107}