// trustformers_optim/common.rs
1//! Common optimization operations and utilities.
2//!
3//! This module provides shared functionality that is used across multiple optimizers,
4//! reducing code duplication and ensuring consistent behavior.
5//!
6//! # Features
7//!
8//! - **State Management**: Unified parameter state tracking
9//! - **Bias Correction**: Standard bias correction calculations for momentum methods
10//! - **Parameter Updates**: Common update patterns with weight decay variants
11//! - **Gradient Processing**: Shared gradient manipulation utilities
12//! - **Memory Management**: Efficient buffer allocation and reuse
13
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16use trustformers_core::errors::{Result, TrustformersError};
17use trustformers_core::tensor::Tensor;
18
/// Unified state management for optimizer parameters.
///
/// This struct provides a consistent interface for tracking optimizer state
/// across different algorithms, reducing code duplication and memory overhead.
///
/// All buffers are keyed by a per-parameter string ID (see `ParameterIds`) and
/// are allocated lazily via the `get_or_create_*` accessors.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizerState {
    /// Current step counter for bias correction and scheduling
    pub step: usize,

    /// First moment estimates (momentum buffers)
    pub momentum: HashMap<String, Vec<f32>>,

    /// Second moment estimates (squared gradient buffers)
    pub variance: HashMap<String, Vec<f32>>,

    /// Optional third moment estimates (for higher-order methods)
    pub third_moment: HashMap<String, Vec<f32>>,

    /// Per-parameter step counts (for adaptive methods)
    pub param_steps: HashMap<String, usize>,

    /// Velocity buffers for optimization methods like SGD with momentum
    pub velocity: HashMap<String, Vec<f32>>,
}
43
44impl OptimizerState {
45    /// Creates a new optimizer state with empty buffers.
46    pub fn new() -> Self {
47        Self {
48            step: 0,
49            momentum: HashMap::new(),
50            variance: HashMap::new(),
51            third_moment: HashMap::new(),
52            param_steps: HashMap::new(),
53            velocity: HashMap::new(),
54        }
55    }
56
57    /// Gets or creates momentum buffer for a parameter.
58    pub fn get_or_create_momentum(&mut self, param_id: String, size: usize) -> &mut Vec<f32> {
59        self.momentum.entry(param_id).or_insert_with(|| vec![0.0; size])
60    }
61
62    /// Gets or creates variance buffer for a parameter.
63    pub fn get_or_create_variance(&mut self, param_id: String, size: usize) -> &mut Vec<f32> {
64        self.variance.entry(param_id).or_insert_with(|| vec![0.0; size])
65    }
66
67    /// Gets or creates third moment buffer for a parameter.
68    pub fn get_or_create_third_moment(&mut self, param_id: String, size: usize) -> &mut Vec<f32> {
69        self.third_moment.entry(param_id).or_insert_with(|| vec![0.0; size])
70    }
71
72    /// Increments the global step counter.
73    pub fn step(&mut self) {
74        self.step += 1;
75    }
76
77    /// Increments the step counter for a specific parameter.
78    pub fn step_param(&mut self, param_id: String) {
79        *self.param_steps.entry(param_id).or_insert(0) += 1;
80    }
81
82    /// Gets the step count for a specific parameter.
83    pub fn get_param_step(&self, param_id: &str) -> usize {
84        self.param_steps.get(param_id).copied().unwrap_or(0)
85    }
86
87    /// Clears all state buffers to free memory.
88    pub fn clear(&mut self) {
89        self.step = 0;
90        self.momentum.clear();
91        self.variance.clear();
92        self.third_moment.clear();
93        self.param_steps.clear();
94    }
95
96    /// Gets memory usage statistics.
97    pub fn memory_usage(&self) -> StateMemoryStats {
98        let momentum_size: usize = self.momentum.values().map(|v| v.len()).sum();
99        let variance_size: usize = self.variance.values().map(|v| v.len()).sum();
100        let third_moment_size: usize = self.third_moment.values().map(|v| v.len()).sum();
101
102        StateMemoryStats {
103            momentum_elements: momentum_size,
104            variance_elements: variance_size,
105            third_moment_elements: third_moment_size,
106            total_bytes: (momentum_size + variance_size + third_moment_size)
107                * std::mem::size_of::<f32>(),
108            num_parameters: self.momentum.len(),
109        }
110    }
111}
112
113impl Default for OptimizerState {
114    fn default() -> Self {
115        Self::new()
116    }
117}
118
/// Memory usage statistics for optimizer state.
#[derive(Debug, Clone)]
pub struct StateMemoryStats {
    /// Total f32 elements across all momentum buffers.
    pub momentum_elements: usize,
    /// Total f32 elements across all variance buffers.
    pub variance_elements: usize,
    /// Total f32 elements across all third-moment buffers.
    pub third_moment_elements: usize,
    /// Approximate total buffer size in bytes (element counts × `size_of::<f32>()`).
    pub total_bytes: usize,
    /// Number of tracked parameters (as counted by momentum entries).
    pub num_parameters: usize,
}
128
/// Common bias correction utilities for momentum-based optimizers.
pub struct BiasCorrection;

impl BiasCorrection {
    /// Computes the bias correction factor for an exponential moving average.
    ///
    /// Formula: `1 - beta^step`
    ///
    /// # Arguments
    ///
    /// * `beta` - The exponential decay rate (e.g., 0.9 for momentum, 0.999 for variance)
    /// * `step` - The current step number (1-indexed)
    pub fn compute_correction(beta: f32, step: usize) -> f32 {
        let decay = beta.powi(step as i32);
        1.0 - decay
    }

    /// Divides a biased estimate by its correction factor.
    ///
    /// # Arguments
    ///
    /// * `value` - The biased estimate
    /// * `beta` - The exponential decay rate
    /// * `step` - The current step number (1-indexed)
    pub fn apply_correction(value: f32, beta: f32, step: usize) -> f32 {
        let correction = Self::compute_correction(beta, step);
        value / correction
    }

    /// Computes first- and second-moment bias corrections in one call.
    ///
    /// # Returns
    ///
    /// Tuple of (bias_correction1, bias_correction2) for Adam-style optimizers.
    pub fn compute_adam_corrections(beta1: f32, beta2: f32, step: usize) -> (f32, f32) {
        let c1 = Self::compute_correction(beta1, step);
        let c2 = Self::compute_correction(beta2, step);
        (c1, c2)
    }
}
168
/// Weight decay application strategies.
///
/// Derives `Copy`/`PartialEq`/`Eq` (backward-compatible additions) so the
/// mode can be passed by value and compared directly in optimizer configs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WeightDecayMode {
    /// L2 regularization applied to gradients (traditional Adam)
    L2Regularization,
    /// Decoupled weight decay applied directly to parameters (AdamW style)
    Decoupled,
}
177
/// Common parameter update operations shared across optimizers.
pub struct ParameterUpdate;

impl ParameterUpdate {
    /// Folds L2 regularization into the gradient.
    ///
    /// # Arguments
    ///
    /// * `grad` - The gradient value
    /// * `param` - The parameter value
    /// * `weight_decay` - The weight decay coefficient
    pub fn apply_l2_regularization(grad: f32, param: f32, weight_decay: f32) -> f32 {
        param * weight_decay + grad
    }

    /// Shrinks the parameter in place (AdamW-style decoupled weight decay).
    ///
    /// # Arguments
    ///
    /// * `param` - The parameter value to update
    /// * `lr` - The learning rate
    /// * `weight_decay` - The weight decay coefficient
    pub fn apply_decoupled_weight_decay(param: &mut f32, lr: f32, weight_decay: f32) {
        let shrink = 1.0 - lr * weight_decay;
        *param *= shrink;
    }

    /// Applies the Adam update rule to a single parameter.
    ///
    /// # Arguments
    ///
    /// * `param` - The parameter to update
    /// * `lr` - Learning rate
    /// * `m_hat` - Bias-corrected first moment
    /// * `v_hat` - Bias-corrected second moment
    /// * `eps` - Epsilon for numerical stability
    pub fn adam_update(param: &mut f32, lr: f32, m_hat: f32, v_hat: f32, eps: f32) {
        let denom = v_hat.sqrt() + eps;
        *param -= lr * m_hat / denom;
    }

    /// Applies an SGD-with-momentum step to a single parameter.
    ///
    /// # Arguments
    ///
    /// * `param` - The parameter to update
    /// * `lr` - Learning rate
    /// * `momentum` - Momentum buffer value
    pub fn sgd_momentum_update(param: &mut f32, lr: f32, momentum: f32) {
        let delta = lr * momentum;
        *param -= delta;
    }

    /// Advances the SGD momentum buffer and returns the effective gradient.
    ///
    /// # Arguments
    ///
    /// * `momentum` - The momentum buffer to update
    /// * `grad` - The gradient
    /// * `momentum_coeff` - Momentum coefficient (typically 0.9)
    /// * `dampening` - Dampening factor (typically 0.0)
    /// * `nesterov` - Whether to use Nesterov momentum
    pub fn update_sgd_momentum(
        momentum: &mut f32,
        grad: f32,
        momentum_coeff: f32,
        dampening: f32,
        nesterov: bool,
    ) -> f32 {
        *momentum *= momentum_coeff;
        *momentum += (1.0 - dampening) * grad;
        if !nesterov {
            *momentum
        } else {
            // Nesterov look-ahead: combine the raw gradient with the buffer.
            grad + momentum_coeff * *momentum
        }
    }

    /// Advances an exponential moving average in place (Adam-style).
    ///
    /// # Arguments
    ///
    /// * `ema` - The exponential moving average to update
    /// * `value` - The new value
    /// * `beta` - The decay coefficient
    pub fn update_ema(ema: &mut f32, value: f32, beta: f32) {
        *ema = (1.0 - beta) * value + beta * *ema;
    }
}
263
/// Gradient processing utilities.
#[derive(Debug, Clone)]
pub struct GradientProcessor;

impl GradientProcessor {
    /// Rescales the gradient in place if its L2 norm exceeds `max_norm`.
    ///
    /// # Arguments
    ///
    /// * `grad` - The gradient to clip
    /// * `max_norm` - Maximum allowed norm
    pub fn clip_by_norm(grad: &mut [f32], max_norm: f32) {
        let squared_sum = grad.iter().fold(0.0f32, |acc, g| acc + g * g);
        let norm = squared_sum.sqrt();
        if norm > max_norm {
            let scale = max_norm / norm;
            grad.iter_mut().for_each(|g| *g *= scale);
        }
    }

    /// Clamps every gradient element into `[min_value, max_value]` in place.
    ///
    /// # Arguments
    ///
    /// * `grad` - The gradient to clip
    /// * `min_value` - Minimum allowed value
    /// * `max_value` - Maximum allowed value
    pub fn clip_by_value(grad: &mut [f32], min_value: f32, max_value: f32) {
        grad.iter_mut().for_each(|g| *g = g.clamp(min_value, max_value));
    }

    /// Multiplies every gradient element by `scale` (mixed precision training).
    ///
    /// # Arguments
    ///
    /// * `grad` - The gradient to scale
    /// * `scale` - The scaling factor
    pub fn scale_gradient(grad: &mut [f32], scale: f32) {
        grad.iter_mut().for_each(|g| *g *= scale);
    }

    /// Reports whether the gradient is free of NaN/Inf values.
    ///
    /// # Arguments
    ///
    /// * `grad` - The gradient to check
    ///
    /// # Returns
    ///
    /// True if all gradients are finite.
    pub fn is_finite(grad: &[f32]) -> bool {
        !grad.iter().any(|g| !g.is_finite())
    }
}
323
324/// Utility functions for creating parameter IDs.
325pub struct ParameterIds;
326
327impl ParameterIds {
328    /// Creates a unique parameter ID from tensor pointer.
329    ///
330    /// # Arguments
331    ///
332    /// * `tensor` - The tensor to create ID for
333    pub fn from_tensor(tensor: &Tensor) -> Result<String> {
334        match tensor {
335            Tensor::F32(data) => Ok(format!("{:p}", data.as_ptr())),
336            _ => Err(TrustformersError::tensor_op_error(
337                "Unsupported tensor type for parameter ID",
338                "from_tensor",
339            )),
340        }
341    }
342
343    /// Creates a parameter ID from name.
344    ///
345    /// # Arguments
346    ///
347    /// * `name` - The parameter name
348    pub fn from_name(name: &str) -> String {
349        name.to_string()
350    }
351}
352
#[cfg(test)]
mod tests {
    use super::*;

    // A fresh state starts at step 0 with no lazily-allocated buffers.
    #[test]
    fn test_optimizer_state_creation() {
        let state = OptimizerState::new();
        assert_eq!(state.step, 0);
        assert!(state.momentum.is_empty());
        assert!(state.variance.is_empty());
    }

    // 1 - beta^1 == 1 - beta at the first step; apply_correction divides
    // the biased estimate by that factor.
    #[test]
    fn test_bias_correction() {
        let correction1 = BiasCorrection::compute_correction(0.9, 1);
        assert!((correction1 - 0.1).abs() < 1e-6);

        let correction2 = BiasCorrection::compute_correction(0.999, 1);
        assert!((correction2 - 0.001).abs() < 1e-6);

        let corrected = BiasCorrection::apply_correction(0.09, 0.9, 1);
        assert!((corrected - 0.9).abs() < 1e-6);
    }

    // Decoupled decay: param *= 1 - lr*wd. Adam: param -= lr*m/(sqrt(v)+eps).
    #[test]
    fn test_parameter_update() {
        let mut param = 1.0;
        ParameterUpdate::apply_decoupled_weight_decay(&mut param, 0.01, 0.1);
        assert!((param - 0.999).abs() < 1e-6);

        let mut param2 = 1.0;
        ParameterUpdate::adam_update(&mut param2, 0.01, 0.1, 0.01, 1e-8);
        assert!((param2 - 0.99).abs() < 1e-6);
    }

    // Clipping [3, 4] (norm 5) to max norm 1 rescales to unit norm;
    // NaN entries must make is_finite report false.
    #[test]
    fn test_gradient_processing() {
        let mut grad = vec![3.0, 4.0];
        GradientProcessor::clip_by_norm(&mut grad, 1.0);
        let norm: f32 = grad.iter().map(|g| g * g).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-6);

        assert!(GradientProcessor::is_finite(&grad));

        let bad_grad = vec![f32::NAN, 1.0];
        assert!(!GradientProcessor::is_finite(&bad_grad));
    }

    // One parameter with 100-element momentum + 100-element variance buffers:
    // 200 f32 elements total.
    #[test]
    fn test_memory_stats() {
        let mut state = OptimizerState::new();
        state.get_or_create_momentum("param1".to_string(), 100);
        state.get_or_create_variance("param1".to_string(), 100);

        let stats = state.memory_usage();
        assert_eq!(stats.momentum_elements, 100);
        assert_eq!(stats.variance_elements, 100);
        assert_eq!(stats.num_parameters, 1);
        assert_eq!(stats.total_bytes, 200 * std::mem::size_of::<f32>());
    }

    // EMA with beta=0.9: 0 -> 0.1 -> 0.19 for a constant input of 1.0.
    #[test]
    fn test_ema_update() {
        let mut ema = 0.0;
        ParameterUpdate::update_ema(&mut ema, 1.0, 0.9);
        assert!((ema - 0.1).abs() < 1e-6);

        ParameterUpdate::update_ema(&mut ema, 1.0, 0.9);
        assert!((ema - 0.19).abs() < 1e-6);
    }
}
423}