1use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16use trustformers_core::errors::{Result, TrustformersError};
17use trustformers_core::tensor::Tensor;
18
/// Serializable per-parameter optimizer state (step counters and moment
/// buffers), keyed by string parameter IDs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizerState {
    /// Global number of optimization steps taken so far.
    pub step: usize,

    /// First-moment (momentum / EMA-of-gradient) buffer per parameter ID.
    pub momentum: HashMap<String, Vec<f32>>,

    /// Second-moment (EMA-of-squared-gradient) buffer per parameter ID.
    pub variance: HashMap<String, Vec<f32>>,

    /// Third-moment buffer per parameter ID, for optimizers that track it.
    pub third_moment: HashMap<String, Vec<f32>>,

    /// Per-parameter step counts, for parameters updated at different rates.
    pub param_steps: HashMap<String, usize>,

    /// Velocity buffer per parameter ID — presumably for SGD-style momentum;
    /// NOTE(review): no `get_or_create_velocity` accessor exists, confirm
    /// callers populate this map directly.
    pub velocity: HashMap<String, Vec<f32>>,
}
43
44impl OptimizerState {
45 pub fn new() -> Self {
47 Self {
48 step: 0,
49 momentum: HashMap::new(),
50 variance: HashMap::new(),
51 third_moment: HashMap::new(),
52 param_steps: HashMap::new(),
53 velocity: HashMap::new(),
54 }
55 }
56
57 pub fn get_or_create_momentum(&mut self, param_id: String, size: usize) -> &mut Vec<f32> {
59 self.momentum.entry(param_id).or_insert_with(|| vec![0.0; size])
60 }
61
62 pub fn get_or_create_variance(&mut self, param_id: String, size: usize) -> &mut Vec<f32> {
64 self.variance.entry(param_id).or_insert_with(|| vec![0.0; size])
65 }
66
67 pub fn get_or_create_third_moment(&mut self, param_id: String, size: usize) -> &mut Vec<f32> {
69 self.third_moment.entry(param_id).or_insert_with(|| vec![0.0; size])
70 }
71
72 pub fn step(&mut self) {
74 self.step += 1;
75 }
76
77 pub fn step_param(&mut self, param_id: String) {
79 *self.param_steps.entry(param_id).or_insert(0) += 1;
80 }
81
82 pub fn get_param_step(&self, param_id: &str) -> usize {
84 self.param_steps.get(param_id).copied().unwrap_or(0)
85 }
86
87 pub fn clear(&mut self) {
89 self.step = 0;
90 self.momentum.clear();
91 self.variance.clear();
92 self.third_moment.clear();
93 self.param_steps.clear();
94 }
95
96 pub fn memory_usage(&self) -> StateMemoryStats {
98 let momentum_size: usize = self.momentum.values().map(|v| v.len()).sum();
99 let variance_size: usize = self.variance.values().map(|v| v.len()).sum();
100 let third_moment_size: usize = self.third_moment.values().map(|v| v.len()).sum();
101
102 StateMemoryStats {
103 momentum_elements: momentum_size,
104 variance_elements: variance_size,
105 third_moment_elements: third_moment_size,
106 total_bytes: (momentum_size + variance_size + third_moment_size)
107 * std::mem::size_of::<f32>(),
108 num_parameters: self.momentum.len(),
109 }
110 }
111}
112
113impl Default for OptimizerState {
114 fn default() -> Self {
115 Self::new()
116 }
117}
118
/// Snapshot of optimizer-state memory consumption, as reported by
/// [`OptimizerState::memory_usage`].
#[derive(Debug, Clone)]
pub struct StateMemoryStats {
    /// Total f32 elements across all momentum buffers.
    pub momentum_elements: usize,
    /// Total f32 elements across all variance buffers.
    pub variance_elements: usize,
    /// Total f32 elements across all third-moment buffers.
    pub third_moment_elements: usize,
    /// Total bytes consumed by the tracked state buffers.
    pub total_bytes: usize,
    /// Number of distinct tracked parameters.
    pub num_parameters: usize,
}
128
/// Helpers for Adam-style bias correction of exponential moving averages.
pub struct BiasCorrection;

impl BiasCorrection {
    /// Returns `1 - beta^step`, the denominator used to de-bias an EMA.
    ///
    /// NOTE(review): `step == 0` yields 0.0, so callers should only divide
    /// by this for step counts >= 1 — confirm call sites.
    pub fn compute_correction(beta: f32, step: usize) -> f32 {
        let decay = beta.powi(step as i32);
        1.0 - decay
    }

    /// De-biases `value` by dividing it by `1 - beta^step`.
    pub fn apply_correction(value: f32, beta: f32, step: usize) -> f32 {
        let denom = Self::compute_correction(beta, step);
        value / denom
    }

    /// Convenience: the `(beta1, beta2)` correction pair used by Adam.
    pub fn compute_adam_corrections(beta1: f32, beta2: f32, step: usize) -> (f32, f32) {
        let first = Self::compute_correction(beta1, step);
        let second = Self::compute_correction(beta2, step);
        (first, second)
    }
}
168
/// How weight decay is folded into a parameter update.
///
/// Derives `Copy`/`PartialEq`/`Eq` so the mode can be passed by value and
/// compared directly (it is a plain, fieldless enum).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WeightDecayMode {
    /// Classic L2 regularization: the decay term is added to the gradient
    /// before the optimizer update (see `ParameterUpdate::apply_l2_regularization`).
    L2Regularization,
    /// Decoupled weight decay: the parameter is shrunk directly, independent
    /// of the gradient (see `ParameterUpdate::apply_decoupled_weight_decay`).
    Decoupled,
}
177
/// Namespace for scalar parameter-update primitives shared by optimizers.
pub struct ParameterUpdate;

impl ParameterUpdate {
    /// Folds L2 regularization into the gradient:
    /// returns `grad + weight_decay * param`.
    pub fn apply_l2_regularization(grad: f32, param: f32, weight_decay: f32) -> f32 {
        let penalty = weight_decay * param;
        grad + penalty
    }

    /// Shrinks the parameter in place by the factor `1 - lr * weight_decay`
    /// (decoupled weight decay).
    pub fn apply_decoupled_weight_decay(param: &mut f32, lr: f32, weight_decay: f32) {
        let keep = 1.0 - lr * weight_decay;
        *param *= keep;
    }

    /// Applies one Adam step with bias-corrected moments `m_hat` / `v_hat`.
    pub fn adam_update(param: &mut f32, lr: f32, m_hat: f32, v_hat: f32, eps: f32) {
        let denom = v_hat.sqrt() + eps;
        *param -= lr * m_hat / denom;
    }

    /// Moves the parameter one step along the momentum direction.
    pub fn sgd_momentum_update(param: &mut f32, lr: f32, momentum: f32) {
        let delta = lr * momentum;
        *param -= delta;
    }

    /// Updates the momentum buffer in place and returns the effective update
    /// direction; with `nesterov` the look-ahead `grad + coeff * momentum`
    /// is returned instead of the raw momentum.
    pub fn update_sgd_momentum(
        momentum: &mut f32,
        grad: f32,
        momentum_coeff: f32,
        dampening: f32,
        nesterov: bool,
    ) -> f32 {
        *momentum = momentum_coeff * *momentum + (1.0 - dampening) * grad;
        if !nesterov {
            *momentum
        } else {
            grad + momentum_coeff * *momentum
        }
    }

    /// Blends `value` into the exponential moving average `ema`
    /// with decay factor `beta`.
    pub fn update_ema(ema: &mut f32, value: f32, beta: f32) {
        let blended = beta * *ema + (1.0 - beta) * value;
        *ema = blended;
    }
}
263
/// Namespace for in-place gradient transformations.
#[derive(Debug, Clone)]
pub struct GradientProcessor;

impl GradientProcessor {
    /// Rescales the gradient so its L2 norm does not exceed `max_norm`;
    /// gradients already within the bound are left untouched.
    pub fn clip_by_norm(grad: &mut [f32], max_norm: f32) {
        let sq_sum: f32 = grad.iter().map(|g| g * g).sum();
        let norm = sq_sum.sqrt();
        if norm > max_norm {
            let scale = max_norm / norm;
            grad.iter_mut().for_each(|g| *g *= scale);
        }
    }

    /// Clamps every gradient component into `[min_value, max_value]`.
    pub fn clip_by_value(grad: &mut [f32], min_value: f32, max_value: f32) {
        grad.iter_mut()
            .for_each(|g| *g = g.clamp(min_value, max_value));
    }

    /// Multiplies every gradient component by `scale`.
    pub fn scale_gradient(grad: &mut [f32], scale: f32) {
        grad.iter_mut().for_each(|g| *g *= scale);
    }

    /// Returns true when no component is NaN or infinite.
    pub fn is_finite(grad: &[f32]) -> bool {
        !grad.iter().any(|g| !g.is_finite())
    }
}
323
324pub struct ParameterIds;
326
327impl ParameterIds {
328 pub fn from_tensor(tensor: &Tensor) -> Result<String> {
334 match tensor {
335 Tensor::F32(data) => Ok(format!("{:p}", data.as_ptr())),
336 _ => Err(TrustformersError::tensor_op_error(
337 "Unsupported tensor type for parameter ID",
338 "from_tensor",
339 )),
340 }
341 }
342
343 pub fn from_name(name: &str) -> String {
349 name.to_string()
350 }
351}
352
#[cfg(test)]
mod tests {
    use super::*;

    // A fresh state starts at step 0 with no buffers allocated.
    #[test]
    fn test_optimizer_state_creation() {
        let state = OptimizerState::new();
        assert_eq!(state.step, 0);
        assert!(state.momentum.is_empty());
        assert!(state.variance.is_empty());
    }

    // 1 - beta^1 == 1 - beta, and apply_correction divides by that factor.
    #[test]
    fn test_bias_correction() {
        let correction1 = BiasCorrection::compute_correction(0.9, 1);
        assert!((correction1 - 0.1).abs() < 1e-6);

        let correction2 = BiasCorrection::compute_correction(0.999, 1);
        assert!((correction2 - 0.001).abs() < 1e-6);

        // 0.09 / (1 - 0.9) == 0.9
        let corrected = BiasCorrection::apply_correction(0.09, 0.9, 1);
        assert!((corrected - 0.9).abs() < 1e-6);
    }

    // Decoupled decay: 1 * (1 - 0.01*0.1) = 0.999.
    // Adam: 1 - 0.01 * 0.1 / (sqrt(0.01) + eps) ~= 0.99.
    #[test]
    fn test_parameter_update() {
        let mut param = 1.0;
        ParameterUpdate::apply_decoupled_weight_decay(&mut param, 0.01, 0.1);
        assert!((param - 0.999).abs() < 1e-6);

        let mut param2 = 1.0;
        ParameterUpdate::adam_update(&mut param2, 0.01, 0.1, 0.01, 1e-8);
        assert!((param2 - 0.99).abs() < 1e-6);
    }

    // [3,4] has norm 5; clipping to 1 must rescale to unit norm.
    #[test]
    fn test_gradient_processing() {
        let mut grad = vec![3.0, 4.0];
        GradientProcessor::clip_by_norm(&mut grad, 1.0);
        let norm: f32 = grad.iter().map(|g| g * g).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-6);

        assert!(GradientProcessor::is_finite(&grad));

        let bad_grad = vec![f32::NAN, 1.0];
        assert!(!GradientProcessor::is_finite(&bad_grad));
    }

    // Element counts are summed per map; bytes assume f32 storage.
    #[test]
    fn test_memory_stats() {
        let mut state = OptimizerState::new();
        state.get_or_create_momentum("param1".to_string(), 100);
        state.get_or_create_variance("param1".to_string(), 100);

        let stats = state.memory_usage();
        assert_eq!(stats.momentum_elements, 100);
        assert_eq!(stats.variance_elements, 100);
        assert_eq!(stats.num_parameters, 1);
        assert_eq!(stats.total_bytes, 200 * std::mem::size_of::<f32>());
    }

    // EMA from 0 with beta=0.9: 0.1, then 0.9*0.1 + 0.1 = 0.19.
    #[test]
    fn test_ema_update() {
        let mut ema = 0.0;
        ParameterUpdate::update_ema(&mut ema, 1.0, 0.9);
        assert!((ema - 0.1).abs() < 1e-6);

        ParameterUpdate::update_ema(&mut ema, 1.0, 0.9);
        assert!((ema - 0.19).abs() < 1e-6);
    }
}