1use std::collections::HashMap;
2
3use yscv_tensor::Tensor;
4
5use super::error::OptimError;
6use super::{Adagrad, Adam, AdamW, Lamb, Lars, RAdam, RmsProp, Sgd};
7
mod sealed {
    /// Private marker trait implementing the "sealed trait" pattern:
    /// downstream crates can name and use `StepOptimizer`, but cannot
    /// implement it, because they cannot reach this private trait.
    pub trait Sealed {}
}
11
/// Common stepping interface over the concrete optimizers in this module.
///
/// Sealed: only the types listed in the `impl_step_optimizer!` invocation
/// below implement it, so wrappers such as `Lookahead` can rely on a
/// closed, known set of implementors.
pub trait StepOptimizer: sealed::Sealed {
    /// Applies one optimization update to `weights` for the parameter
    /// identified by `parameter_id`, using the gradient `grad`.
    ///
    /// # Errors
    /// Propagates any [`OptimError`] reported by the underlying optimizer.
    fn step(
        &mut self,
        parameter_id: u64,
        weights: &mut Tensor,
        grad: &Tensor,
    ) -> Result<(), OptimError>;
}
23
/// Implements `sealed::Sealed` and `StepOptimizer` for each listed optimizer
/// type by forwarding the trait's `step` to the type's inherent `step`
/// method.
///
/// Note: `<$ty>::step(self, ...)` resolves to the *inherent* method on the
/// concrete type (inherent methods take precedence over trait methods), so
/// this forwarding does not recurse into the trait impl itself.
macro_rules! impl_step_optimizer {
    ($($ty:ty),*) => {
        $(
            impl sealed::Sealed for $ty {}
            impl StepOptimizer for $ty {
                fn step(
                    &mut self,
                    parameter_id: u64,
                    weights: &mut Tensor,
                    grad: &Tensor,
                ) -> Result<(), OptimError> {
                    // Forward to the inherent method on the concrete type.
                    <$ty>::step(self, parameter_id, weights, grad)
                }
            }
        )*
    };
}

impl_step_optimizer!(Sgd, Adam, AdamW, RmsProp, Adagrad, RAdam, Lamb, Lars);
43
/// Lookahead wrapper around another optimizer: the wrapped optimizer
/// updates the "fast" weights on every step, and every `k` steps the
/// kept "slow" weights are pulled toward the fast weights and written
/// back over them.
#[derive(Debug, Clone)]
pub struct Lookahead<O> {
    /// Wrapped optimizer performing the per-step (fast) updates.
    inner: O,
    /// Interpolation factor for the slow-weight update:
    /// `slow += alpha * (fast - slow)`.
    alpha: f32,
    /// Number of `step` calls between slow/fast synchronizations.
    k: usize,
    /// Running count of `step` calls (shared across all parameter ids).
    step_count: usize,
    /// Slow-weight copies, keyed by parameter id.
    slow_weights: HashMap<u64, Vec<f32>>,
}
62
63impl<O: StepOptimizer> Lookahead<O> {
64 pub fn new(inner: O, alpha: f32, k: usize) -> Self {
69 Self {
70 inner,
71 alpha,
72 k,
73 step_count: 0,
74 slow_weights: HashMap::new(),
75 }
76 }
77
78 pub fn step(
84 &mut self,
85 parameter_id: u64,
86 weights: &mut Tensor,
87 grad: &Tensor,
88 ) -> Result<(), OptimError> {
89 self.inner.step(parameter_id, weights, grad)?;
91
92 self.slow_weights
94 .entry(parameter_id)
95 .or_insert_with(|| weights.data().to_vec());
96
97 self.step_count += 1;
98
99 if self.step_count.is_multiple_of(self.k) {
101 let slow = self
102 .slow_weights
103 .get_mut(¶meter_id)
104 .expect("slow weights must exist");
105 let fast = weights.data_mut();
106 for (s, f) in slow.iter_mut().zip(fast.iter_mut()) {
107 *s += self.alpha * (*f - *s);
108 *f = *s;
109 }
110 }
111
112 Ok(())
113 }
114}