
tensorlogic_train/optimizers/common.rs

//! Common optimizer utilities and traits.

use crate::TrainResult;
use scirs2_core::ndarray::{Array, Ix2};
use std::collections::HashMap;

/// Compute the global L2 norm of all gradients.
///
/// # Arguments
/// * `gradients` - Gradients for all parameters
///
/// # Returns
/// The L2 norm of all gradients combined.
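///
/// # Examples
///
/// A minimal sketch; the import path below is assumed from this file's
/// location, and it assumes `scirs2_core::ndarray` re-exports the usual
/// `ndarray` constructors.
///
/// ```
/// use std::collections::HashMap;
/// use scirs2_core::ndarray::{Array, Ix2};
/// use tensorlogic_train::optimizers::common::compute_gradient_norm;
///
/// let mut gradients: HashMap<String, Array<f64, Ix2>> = HashMap::new();
/// gradients.insert(
///     "w".to_string(),
///     Array::from_shape_vec((1, 2), vec![3.0, 4.0]).unwrap(),
/// );
/// // Global norm: sqrt(3^2 + 4^2) = 5
/// assert!((compute_gradient_norm(&gradients) - 5.0).abs() < 1e-12);
/// ```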
pub fn compute_gradient_norm(gradients: &HashMap<String, Array<f64, Ix2>>) -> f64 {
    let mut total_norm_sq = 0.0;

    for grad in gradients.values() {
        for &g in grad.iter() {
            total_norm_sq += g * g;
        }
    }

    total_norm_sq.sqrt()
}

/// Gradient clipping mode.
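///
/// # Examples
///
/// A sketch of how the two modes are commonly applied (the `threshold` and
/// `gradients` bindings below are placeholders, not items from this crate):
/// `Value` clamps each entry into `[-threshold, threshold]`, while `Norm`
/// rescales every gradient when the global L2 norm exceeds the threshold.
///
/// ```ignore
/// match mode {
///     GradClipMode::Value => {
///         for grad in gradients.values_mut() {
///             grad.mapv_inplace(|g| g.clamp(-threshold, threshold));
///         }
///     }
///     GradClipMode::Norm => {
///         let norm = compute_gradient_norm(&gradients);
///         if norm > threshold {
///             let scale = threshold / norm;
///             for grad in gradients.values_mut() {
///                 grad.mapv_inplace(|g| g * scale);
///             }
///         }
///     }
/// }
/// ```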
#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum GradClipMode {
    /// Clip by value (element-wise).
    Value,
    /// Clip by global L2 norm.
    Norm,
}

/// Configuration for optimizers.
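///
/// # Examples
///
/// A construction sketch; the import path is assumed from this file's
/// location, and the field values are illustrative.
///
/// ```
/// use tensorlogic_train::optimizers::common::{GradClipMode, OptimizerConfig};
///
/// // Start from the defaults and override only the fields that differ.
/// let config = OptimizerConfig {
///     learning_rate: 3e-4,
///     grad_clip: Some(1.0),
///     grad_clip_mode: GradClipMode::Norm,
///     ..Default::default()
/// };
/// assert_eq!(config.beta1, 0.9);
/// ```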
#[derive(Debug, Clone)]
pub struct OptimizerConfig {
    /// Learning rate.
    pub learning_rate: f64,
    /// Momentum (for SGD).
    pub momentum: f64,
    /// Beta1 (for Adam/AdamW).
    pub beta1: f64,
    /// Beta2 (for Adam/AdamW).
    pub beta2: f64,
    /// Epsilon for numerical stability.
    pub epsilon: f64,
    /// Weight decay (for AdamW).
    pub weight_decay: f64,
    /// Gradient clipping threshold (None = no clipping).
    pub grad_clip: Option<f64>,
    /// Gradient clipping mode.
    pub grad_clip_mode: GradClipMode,
}

impl Default for OptimizerConfig {
    fn default() -> Self {
        Self {
            learning_rate: 0.001,
            momentum: 0.9,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
            weight_decay: 0.01,
            grad_clip: None,
            grad_clip_mode: GradClipMode::Value,
        }
    }
}

/// Trait for optimizers.
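///
/// # Examples
///
/// A sketch of how a training step might drive any implementation; the
/// function and bindings below are illustrative, and the import paths are
/// assumed from this file's location.
///
/// ```no_run
/// use std::collections::HashMap;
/// use scirs2_core::ndarray::{Array, Ix2};
/// use tensorlogic_train::TrainResult;
/// use tensorlogic_train::optimizers::common::Optimizer;
///
/// fn train_step<O: Optimizer>(
///     optimizer: &mut O,
///     parameters: &mut HashMap<String, Array<f64, Ix2>>,
///     gradients: &HashMap<String, Array<f64, Ix2>>,
/// ) -> TrainResult<()> {
///     // Apply one update, then clear any accumulated gradient state.
///     optimizer.step(parameters, gradients)?;
///     optimizer.zero_grad();
///     Ok(())
/// }
/// ```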
pub trait Optimizer {
    /// Update parameters with computed gradients.
    fn step(
        &mut self,
        parameters: &mut HashMap<String, Array<f64, Ix2>>,
        gradients: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<()>;

    /// Zero all gradients.
    fn zero_grad(&mut self);

    /// Get current learning rate.
    fn get_lr(&self) -> f64;

    /// Set learning rate.
    fn set_lr(&mut self, lr: f64);

    /// Get optimizer state for checkpointing.
    fn state_dict(&self) -> HashMap<String, Vec<f64>>;

    /// Load optimizer state from checkpoint.
    fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>);
}