Skip to main content

yscv_model/
ema.rs

1use yscv_tensor::Tensor;
2
3/// Exponential Moving Average of model parameters.
4///
5/// Maintains shadow copies of parameters that are updated as a weighted
6/// running average: `shadow = decay * shadow + (1 - decay) * param`.
7/// This is commonly used to produce a smoothed version of model weights
8/// that often generalises better at inference time.
9pub struct ExponentialMovingAverage {
10    decay: f32,
11    shadow_params: Vec<Tensor>,
12    num_updates: usize,
13}
14
15impl ExponentialMovingAverage {
16    /// Creates a new EMA tracker with the given decay factor (e.g. 0.999).
17    pub fn new(decay: f32) -> Self {
18        Self {
19            decay,
20            shadow_params: Vec::new(),
21            num_updates: 0,
22        }
23    }
24
25    /// Registers initial parameter values as shadow copies.
26    pub fn register(&mut self, params: &[Tensor]) {
27        self.shadow_params = params.to_vec();
28    }
29
30    /// Updates shadow parameters: `shadow = decay * shadow + (1 - decay) * param`.
31    ///
32    /// Panics if the number of tensors does not match the registered count or
33    /// if any tensor length differs from its shadow counterpart.
34    pub fn update(&mut self, params: &[Tensor]) {
35        assert_eq!(
36            params.len(),
37            self.shadow_params.len(),
38            "param count mismatch: expected {} but got {}",
39            self.shadow_params.len(),
40            params.len(),
41        );
42        let decay = self.decay;
43        let one_minus_decay = 1.0 - decay;
44        for (shadow, param) in self.shadow_params.iter_mut().zip(params.iter()) {
45            let s = shadow.data_mut();
46            let p = param.data();
47            assert_eq!(s.len(), p.len(), "tensor length mismatch in EMA update");
48            let len = s.len();
49            for i in 0..len {
50                s[i] = decay * s[i] + one_minus_decay * p[i];
51            }
52        }
53        self.num_updates += 1;
54    }
55
56    /// Returns a reference to the shadow parameters.
57    pub fn shadow_params(&self) -> &[Tensor] {
58        &self.shadow_params
59    }
60
61    /// Copies shadow parameter values into the provided mutable slice.
62    ///
63    /// Panics if the slice length does not match the shadow parameter count or
64    /// if any tensor length differs.
65    pub fn apply_shadow(&self, params: &mut [Tensor]) {
66        assert_eq!(
67            params.len(),
68            self.shadow_params.len(),
69            "param count mismatch in apply_shadow",
70        );
71        for (dst, src) in params.iter_mut().zip(self.shadow_params.iter()) {
72            let d = dst.data_mut();
73            let s = src.data();
74            assert_eq!(d.len(), s.len(), "tensor length mismatch in apply_shadow");
75            d.copy_from_slice(s);
76        }
77    }
78
79    /// Returns the number of update steps performed so far.
80    pub fn num_updates(&self) -> usize {
81        self.num_updates
82    }
83}