tiny_recursive_rs/training/
ema.rs1use candle_core::{Result, Tensor};
6use std::collections::HashMap;
7
/// Settings for the exponential-moving-average tracker [`EMA`].
#[derive(Debug, Clone)]
pub struct EMAConfig {
    /// Weight given to the existing shadow value on every update; the incoming
    /// parameter is weighted by `1.0 - decay`.
    pub decay: f64,
}
15
16impl Default for EMAConfig {
17 fn default() -> Self {
18 Self {
19 decay: 0.9999, }
21 }
22}
23
24pub struct EMA {
29 config: EMAConfig,
30 shadow_params: HashMap<usize, Tensor>,
31}
32
33impl EMA {
34 pub fn new(config: EMAConfig) -> Self {
36 Self {
37 config,
38 shadow_params: HashMap::new(),
39 }
40 }
41
42 pub fn update(&mut self, params: &[Tensor]) -> Result<()> {
50 for (i, param) in params.iter().enumerate() {
51 let shadow = self.shadow_params.entry(i).or_insert_with(|| {
53 param.clone()
55 });
56
57 *shadow = ((shadow.clone() * self.config.decay)?
59 + (param * (1.0 - self.config.decay))?)?;
60 }
61
62 Ok(())
63 }
64
65 pub fn get_params(&self) -> Vec<Tensor> {
70 let mut params = Vec::new();
71 for i in 0..self.shadow_params.len() {
72 if let Some(shadow) = self.shadow_params.get(&i) {
73 params.push(shadow.clone());
74 }
75 }
76 params
77 }
78
79 pub fn copy_to(&self, params: &mut [Tensor]) -> Result<()> {
84 for (i, param) in params.iter_mut().enumerate() {
85 if let Some(shadow) = self.shadow_params.get(&i) {
86 *param = shadow.clone();
87 }
88 }
89 Ok(())
90 }
91
92 pub fn copy_from(&mut self, params: &[Tensor]) {
97 for (i, param) in params.iter().enumerate() {
98 self.shadow_params.insert(i, param.clone());
99 }
100 }
101}
102
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::{DType, Device};

    /// Sum of absolute element-wise differences between two tensors.
    fn l1_distance(lhs: &Tensor, rhs: &Tensor) -> Result<f32> {
        (lhs - rhs)?.abs()?.sum_all()?.to_scalar::<f32>()
    }

    /// A freshly built tracker holds no shadow parameters.
    #[test]
    fn test_ema_creation() {
        let ema = EMA::new(EMAConfig::default());

        assert_eq!(ema.shadow_params.len(), 0);
    }

    /// The first update leaves the shadow equal to the parameter.
    #[test]
    fn test_ema_update() -> Result<()> {
        let ones = Tensor::ones((10, 10), DType::F32, &Device::Cpu)?;
        let mut ema = EMA::new(EMAConfig { decay: 0.9 });

        ema.update(&[ones.clone()])?;

        let diff = l1_distance(&ema.shadow_params[&0], &ones)?;
        assert!(diff < 1e-6);

        Ok(())
    }

    /// ones followed by zeros with decay 0.9 yields a mean of 0.9.
    #[test]
    fn test_ema_smoothing() -> Result<()> {
        let cpu = Device::Cpu;
        let mut ema = EMA::new(EMAConfig { decay: 0.9 });

        let first = Tensor::ones((5, 5), DType::F32, &cpu)?;
        ema.update(&[first.clone()])?;

        let second = Tensor::zeros((5, 5), DType::F32, &cpu)?;
        ema.update(&[second.clone()])?;

        let mean_val = ema.shadow_params[&0].mean_all()?.to_scalar::<f32>()?;
        assert!((mean_val - 0.9).abs() < 1e-6);

        Ok(())
    }

    /// `copy_to` writes the averaged value over the caller's tensor.
    #[test]
    fn test_copy_to() -> Result<()> {
        let cpu = Device::Cpu;
        let mut ema = EMA::new(EMAConfig { decay: 0.95 });

        let first = Tensor::ones((5, 5), DType::F32, &cpu)?;
        ema.update(&[first.clone()])?;

        let second = Tensor::zeros((5, 5), DType::F32, &cpu)?;
        ema.update(&[second.clone()])?;

        let mut targets = vec![Tensor::ones((5, 5), DType::F32, &cpu)?];
        ema.copy_to(&mut targets)?;

        // 0.95 * 1 + 0.05 * 0
        let expected = 0.95_f32;
        let actual = targets[0].mean_all()?.to_scalar::<f32>()?;
        assert!((actual - expected).abs() < 1e-6);

        Ok(())
    }

    /// `copy_from` stores the given tensor as the shadow verbatim.
    #[test]
    fn test_copy_from() -> Result<()> {
        let mut ema = EMA::new(EMAConfig::default());
        let twos = Tensor::full(2.0f32, (5, 5), &Device::Cpu)?;

        ema.copy_from(&[twos.clone()]);

        let diff = l1_distance(&ema.shadow_params[&0], &twos)?;
        assert!(diff < 1e-6);

        Ok(())
    }
}