tiny_recursive_rs/training/ema.rs

//! Exponential Moving Average (EMA) for model weights.
//!
//! Maintains a moving average of model parameters for improved stability
//! and generalization during training.
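//!
//! # Example
//!
//! A minimal usage sketch; `params` stands for the model's weights collected
//! as a `Vec<Tensor>` by the surrounding training loop, which is assumed here
//! (the snippet is marked `ignore` for that reason):
//!
//! ```ignore
//! let mut ema = EMA::new(EMAConfig { decay: 0.999 });
//!
//! // After each optimizer step, fold the fresh weights into the shadow copies.
//! ema.update(&params)?;
//!
//! // Before evaluation, overwrite the working weights with the EMA weights.
//! ema.copy_to(&mut params)?;
//! ```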
use candle_core::{Result, Tensor};
use std::collections::HashMap;

/// EMA configuration
#[derive(Debug, Clone)]
pub struct EMAConfig {
    /// Decay rate for the exponential moving average:
    /// `ema_weight = decay * ema_weight + (1 - decay) * weight`
    pub decay: f64,
}

impl Default for EMAConfig {
    fn default() -> Self {
        Self {
            decay: 0.9999, // Common value for model EMA
        }
    }
}

/// Exponential Moving Average
///
/// Maintains shadow copies of model parameters that are updated
/// with an exponential moving average on each call to `update`.
pub struct EMA {
    config: EMAConfig,
    shadow_params: HashMap<usize, Tensor>,
}

impl EMA {
    /// Create a new EMA with no shadow parameters yet
    pub fn new(config: EMAConfig) -> Self {
        Self {
            config,
            shadow_params: HashMap::new(),
        }
    }

    /// Update the EMA shadow parameters
    ///
    /// # Arguments
    /// * `params` - Current model parameters; must be passed in the same
    ///   order on every call, since shadow copies are keyed by position
    ///
    /// # Returns
    /// Result indicating success or error
    pub fn update(&mut self, params: &[Tensor]) -> Result<()> {
        for (i, param) in params.iter().enumerate() {
            // Get the shadow parameter, initializing it to the current
            // parameter value on the first update
            let shadow = self
                .shadow_params
                .entry(i)
                .or_insert_with(|| param.clone());

            // Update: shadow = decay * shadow + (1 - decay) * param
            *shadow = ((shadow.clone() * self.config.decay)?
                + (param * (1.0 - self.config.decay))?)?;
        }

        Ok(())
    }

    /// Get the EMA parameters
    ///
    /// # Returns
    /// Vector of EMA-averaged parameters, in parameter-index order
    pub fn get_params(&self) -> Vec<Tensor> {
        let mut params = Vec::new();
        for i in 0..self.shadow_params.len() {
            if let Some(shadow) = self.shadow_params.get(&i) {
                params.push(shadow.clone());
            }
        }
        params
    }

    /// Copy the EMA parameters into the model
    ///
    /// # Arguments
    /// * `params` - Model parameters to overwrite with the EMA values
    pub fn copy_to(&self, params: &mut [Tensor]) -> Result<()> {
        for (i, param) in params.iter_mut().enumerate() {
            if let Some(shadow) = self.shadow_params.get(&i) {
                *param = shadow.clone();
            }
        }
        Ok(())
    }

    /// Copy model parameters into the EMA shadow copies
    ///
    /// # Arguments
    /// * `params` - Model parameters to copy from
    pub fn copy_from(&mut self, params: &[Tensor]) {
        for (i, param) in params.iter().enumerate() {
            self.shadow_params.insert(i, param.clone());
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::Device;

    #[test]
    fn test_ema_creation() {
        let config = EMAConfig::default();
        let ema = EMA::new(config);

        assert_eq!(ema.shadow_params.len(), 0);
    }

    #[test]
    fn test_ema_update() -> Result<()> {
        let device = Device::Cpu;
        let param = Tensor::ones((10, 10), candle_core::DType::F32, &device)?;

        let config = EMAConfig { decay: 0.9 };
        let mut ema = EMA::new(config);

        // First update initializes shadow
        ema.update(&[param.clone()])?;

        // Shadow should be initialized to param value
        let shadow = &ema.shadow_params[&0];
        let diff = (shadow.clone() - param.clone())?.abs()?.sum_all()?.to_scalar::<f32>()?;
        assert!(diff < 1e-6);

        Ok(())
    }

    #[test]
    fn test_ema_smoothing() -> Result<()> {
        let device = Device::Cpu;

        let config = EMAConfig { decay: 0.9 };
        let mut ema = EMA::new(config);

        // Start with ones
        let param1 = Tensor::ones((5, 5), candle_core::DType::F32, &device)?;
        ema.update(&[param1.clone()])?;

        // Update with zeros - EMA should be between 0 and 1
        let param2 = Tensor::zeros((5, 5), candle_core::DType::F32, &device)?;
        ema.update(&[param2.clone()])?;

        let shadow = &ema.shadow_params[&0];
        let mean_val = shadow.mean_all()?.to_scalar::<f32>()?;

        // Should be decay * 1 + (1 - decay) * 0 = 0.9
        assert!((mean_val - 0.9).abs() < 1e-6);

        Ok(())
    }

    #[test]
    fn test_copy_to() -> Result<()> {
        let device = Device::Cpu;

        let config = EMAConfig { decay: 0.95 };
        let mut ema = EMA::new(config);

        // Initialize EMA
        let param = Tensor::ones((5, 5), candle_core::DType::F32, &device)?;
        ema.update(&[param.clone()])?;

        // Update EMA
        let param2 = Tensor::zeros((5, 5), candle_core::DType::F32, &device)?;
        ema.update(&[param2.clone()])?;

        // Copy EMA back to params
        let mut params = vec![Tensor::ones((5, 5), candle_core::DType::F32, &device)?];
        ema.copy_to(&mut params)?;

        // Params should now match EMA shadow
        let expected = 0.95; // decay * 1 + (1 - decay) * 0
        let actual = params[0].mean_all()?.to_scalar::<f32>()?;
        assert!((actual - expected).abs() < 1e-6);

        Ok(())
    }

    #[test]
    fn test_copy_from() -> Result<()> {
        let device = Device::Cpu;

        let config = EMAConfig::default();
        let mut ema = EMA::new(config);

        // Create params
        let param = Tensor::full(2.0f32, (5, 5), &device)?;

        // Copy params to EMA
        ema.copy_from(&[param.clone()]);

        // Shadow should match param
        let shadow = &ema.shadow_params[&0];
        let diff = (shadow.clone() - param)?.abs()?.sum_all()?.to_scalar::<f32>()?;
        assert!(diff < 1e-6);

        Ok(())
    }
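
    // Supplementary check (sketch): get_params should return the shadow
    // tensors in parameter-index order, matching the order passed to
    // copy_from. Uses only the APIs defined in this module.
    #[test]
    fn test_get_params_order() -> Result<()> {
        let device = Device::Cpu;
        let mut ema = EMA::new(EMAConfig::default());

        // Two parameters with distinct values so the ordering is observable.
        let p0 = Tensor::ones((3, 3), candle_core::DType::F32, &device)?;
        let p1 = Tensor::zeros((3, 3), candle_core::DType::F32, &device)?;
        ema.copy_from(&[p0, p1]);

        let params = ema.get_params();
        assert_eq!(params.len(), 2);
        assert!((params[0].mean_all()?.to_scalar::<f32>()? - 1.0).abs() < 1e-6);
        assert!(params[1].mean_all()?.to_scalar::<f32>()?.abs() < 1e-6);

        Ok(())
    }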
}