tensorlogic_train/optimizers/common.rs

use crate::TrainResult;
use scirs2_core::ndarray::{Array, Ix2};
use std::collections::HashMap;

/// Computes the global L2 norm over all gradient arrays in the map,
/// i.e. the square root of the sum of squares of every element.
pub fn compute_gradient_norm(gradients: &HashMap<String, Array<f64, Ix2>>) -> f64 {
    let mut total_norm_sq = 0.0;

    // Accumulate squared magnitudes across all named gradient arrays.
    for grad in gradients.values() {
        for &g in grad.iter() {
            total_norm_sq += g * g;
        }
    }

    total_norm_sq.sqrt()
}
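
// Hedged usage sketch (not part of the original file): a small test
// demonstrating the global-norm behaviour. The module and test names, the
// key "w1", and the values are assumptions chosen for illustration.
#[cfg(test)]
mod gradient_norm_sketch {
    use super::*;

    #[test]
    fn norm_spans_all_entries() {
        let mut grads: HashMap<String, Array<f64, Ix2>> = HashMap::new();
        // Four elements of 3.0 give a squared sum of 36.0, so the norm is 6.0.
        grads.insert("w1".to_string(), Array::from_elem((2, 2), 3.0));
        assert!((compute_gradient_norm(&grads) - 6.0).abs() < 1e-12);
    }
}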

/// Controls how gradient clipping is applied when `OptimizerConfig::grad_clip` is set.
#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum GradClipMode {
    /// Clamp each gradient element into `[-threshold, threshold]`.
    Value,
    /// Rescale all gradients so their global L2 norm does not exceed the threshold.
    Norm,
}

/// Shared hyperparameter set consumed by the optimizers in this module.
#[derive(Debug, Clone)]
pub struct OptimizerConfig {
    /// Step size for parameter updates.
    pub learning_rate: f64,
    /// Momentum coefficient for SGD-style optimizers.
    pub momentum: f64,
    /// Exponential decay rate for the first-moment estimate (Adam-style optimizers).
    pub beta1: f64,
    /// Exponential decay rate for the second-moment estimate (Adam-style optimizers).
    pub beta2: f64,
    /// Small constant added to denominators for numerical stability.
    pub epsilon: f64,
    /// L2 weight-decay coefficient.
    pub weight_decay: f64,
    /// Optional clipping threshold; `None` disables gradient clipping.
    pub grad_clip: Option<f64>,
    /// How `grad_clip` is interpreted (see [`GradClipMode`]).
    pub grad_clip_mode: GradClipMode,
}

impl Default for OptimizerConfig {
    fn default() -> Self {
        Self {
            learning_rate: 0.001,
            momentum: 0.9,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
            weight_decay: 0.01,
            grad_clip: None,
            grad_clip_mode: GradClipMode::Value,
        }
    }
}
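
// Hedged sketch (not part of the original file): one plausible way the
// `grad_clip` / `grad_clip_mode` pair could be applied before an optimizer
// step. The helper name `apply_grad_clip` is an assumption for illustration,
// not an API this module is known to expose. A config with norm clipping
// enabled might be built with struct-update syntax, e.g.
// `OptimizerConfig { grad_clip: Some(1.0), grad_clip_mode: GradClipMode::Norm, ..Default::default() }`.
pub fn apply_grad_clip(
    config: &OptimizerConfig,
    gradients: &mut HashMap<String, Array<f64, Ix2>>,
) {
    let Some(threshold) = config.grad_clip else { return };
    match config.grad_clip_mode {
        // Clamp every element into [-threshold, threshold].
        GradClipMode::Value => {
            for grad in gradients.values_mut() {
                grad.mapv_inplace(|g| g.clamp(-threshold, threshold));
            }
        }
        // Rescale all gradients so the global L2 norm is at most `threshold`.
        GradClipMode::Norm => {
            let norm = compute_gradient_norm(gradients);
            if norm > threshold {
                let scale = threshold / norm;
                for grad in gradients.values_mut() {
                    grad.mapv_inplace(|g| g * scale);
                }
            }
        }
    }
}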

/// Common interface implemented by every optimizer in this module.
pub trait Optimizer {
    /// Applies one update step to `parameters` using the supplied `gradients`.
    fn step(
        &mut self,
        parameters: &mut HashMap<String, Array<f64, Ix2>>,
        gradients: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<()>;

    /// Resets any accumulated gradient state held by the optimizer.
    fn zero_grad(&mut self);

    /// Returns the current learning rate.
    fn get_lr(&self) -> f64;

    /// Sets the learning rate, e.g. from a learning-rate scheduler.
    fn set_lr(&mut self, lr: f64);

    /// Serializes internal optimizer state for checkpointing.
    fn state_dict(&self) -> HashMap<String, Vec<f64>>;

    /// Restores internal optimizer state from a checkpoint.
    fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>);
}
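
// Hedged sketch (not part of the original file): a minimal plain-SGD
// implementation of `Optimizer`, showing how the trait methods fit together.
// The struct name `SketchSgd` is an assumption; the real optimizers in this
// module may be structured differently.
pub struct SketchSgd {
    lr: f64,
}

impl Optimizer for SketchSgd {
    fn step(
        &mut self,
        parameters: &mut HashMap<String, Array<f64, Ix2>>,
        gradients: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<()> {
        // theta <- theta - lr * grad, for every parameter with a matching gradient.
        for (name, param) in parameters.iter_mut() {
            if let Some(grad) = gradients.get(name) {
                param.zip_mut_with(grad, |p, &g| *p -= self.lr * g);
            }
        }
        Ok(())
    }

    // Plain SGD accumulates no gradient state, so there is nothing to clear.
    fn zero_grad(&mut self) {}

    fn get_lr(&self) -> f64 {
        self.lr
    }

    fn set_lr(&mut self, lr: f64) {
        self.lr = lr;
    }

    fn state_dict(&self) -> HashMap<String, Vec<f64>> {
        HashMap::from([("lr".to_string(), vec![self.lr])])
    }

    fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) {
        if let Some(&lr) = state.get("lr").and_then(|v| v.first()) {
            self.lr = lr;
        }
    }
}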