// yscv_optim/lookahead.rs

1use std::collections::HashMap;
2
3use yscv_tensor::Tensor;
4
5use super::error::OptimError;
6use super::{Adagrad, Adam, AdamW, Lamb, Lars, RAdam, RmsProp, Sgd};
7
// Private module implementing the sealed-trait pattern: `Sealed` is nameable
// only inside this crate, so downstream crates cannot implement traits that
// require it as a supertrait.
mod sealed {
    /// Marker supertrait restricting `StepOptimizer` impls to this crate.
    pub trait Sealed {}
}
11
/// Trait for optimizers that support a per-parameter `step` update.
///
/// This trait is sealed and cannot be implemented outside this crate.
pub trait StepOptimizer: sealed::Sealed {
    /// Applies one optimisation update to `weights` using `grad`.
    ///
    /// * `parameter_id` — stable identifier for this parameter; presumably
    ///   used by optimizers to keep per-parameter state across calls — TODO
    ///   confirm against the concrete optimizer implementations.
    /// * `weights` — parameter tensor, updated in place.
    /// * `grad` — gradient for this parameter.
    ///
    /// # Errors
    ///
    /// Propagates any [`OptimError`] reported by the concrete optimizer.
    fn step(
        &mut self,
        parameter_id: u64,
        weights: &mut Tensor,
        grad: &Tensor,
    ) -> Result<(), OptimError>;
}
23
// Implements `sealed::Sealed` and `StepOptimizer` for each listed optimizer
// type by forwarding to the type's own inherent `step` method (assumed to
// share this exact signature — TODO confirm against each optimizer's impl).
macro_rules! impl_step_optimizer {
    ($($ty:ty),*) => {
        $(
            impl sealed::Sealed for $ty {}
            impl StepOptimizer for $ty {
                fn step(
                    &mut self,
                    parameter_id: u64,
                    weights: &mut Tensor,
                    grad: &Tensor,
                ) -> Result<(), OptimError> {
                    // `<$ty>::step` selects the inherent method, avoiding
                    // infinite recursion into the trait method defined here.
                    <$ty>::step(self, parameter_id, weights, grad)
                }
            }
        )*
    };
}

impl_step_optimizer!(Sgd, Adam, AdamW, RmsProp, Adagrad, RAdam, Lamb, Lars);
43
/// Lookahead optimizer wrapper.
///
/// Maintains "slow weights" that are periodically interpolated toward the fast
/// weights produced by the inner optimizer.  Every `k` calls to `step`, the
/// slow weights are updated via:
///
/// ```text
/// slow_w = slow_w + alpha * (fast_w - slow_w)
/// fast_w = slow_w
/// ```
#[derive(Debug, Clone)]
pub struct Lookahead<O> {
    // Inner ("fast") optimizer that performs the per-call updates.
    inner: O,
    // Interpolation coefficient used in the slow-weight update.
    alpha: f32,
    // Synchronisation period: slow/fast weights are merged every `k` calls.
    k: usize,
    // Number of `step` calls seen so far (global, not per-parameter).
    step_count: usize,
    // Slow-weight copy for each parameter, keyed by `parameter_id`.
    slow_weights: HashMap<u64, Vec<f32>>,
}
62
63impl<O: StepOptimizer> Lookahead<O> {
64    /// Creates a new `Lookahead` wrapper around the given optimizer.
65    ///
66    /// * `alpha` — interpolation coefficient (typically 0.5).
67    /// * `k` — synchronisation period (typically 5).
68    pub fn new(inner: O, alpha: f32, k: usize) -> Self {
69        Self {
70            inner,
71            alpha,
72            k,
73            step_count: 0,
74            slow_weights: HashMap::new(),
75        }
76    }
77
78    /// Performs one optimisation step.
79    ///
80    /// 1. Delegates to the inner optimizer to update the fast weights.
81    /// 2. Increments the internal step counter.
82    /// 3. Every `k` steps, synchronises slow and fast weights.
83    pub fn step(
84        &mut self,
85        parameter_id: u64,
86        weights: &mut Tensor,
87        grad: &Tensor,
88    ) -> Result<(), OptimError> {
89        // Inner (fast) update.
90        self.inner.step(parameter_id, weights, grad)?;
91
92        // Initialise slow weights on first encounter.
93        self.slow_weights
94            .entry(parameter_id)
95            .or_insert_with(|| weights.data().to_vec());
96
97        self.step_count += 1;
98
99        // Synchronise every k steps.
100        if self.step_count.is_multiple_of(self.k) {
101            let slow = self
102                .slow_weights
103                .get_mut(&parameter_id)
104                .expect("slow weights must exist");
105            let fast = weights.data_mut();
106            for (s, f) in slow.iter_mut().zip(fast.iter_mut()) {
107                *s += self.alpha * (*f - *s);
108                *f = *s;
109            }
110        }
111
112        Ok(())
113    }
114}