tensorlogic_train/optimizers/adabelief.rs

use super::common::{compute_gradient_norm, GradClipMode, Optimizer, OptimizerConfig};
use crate::{TrainError, TrainResult};
use scirs2_core::ndarray::{Array, Ix2};
use std::collections::HashMap;

/// AdaBelief optimizer.
///
/// A variant of Adam that adapts the step size by the "belief" in the
/// observed gradient: instead of tracking the second moment of the gradient
/// itself, it tracks the variance of the gradient around its exponential
/// moving average, taking larger steps when gradients are consistent.
#[derive(Debug)]
pub struct AdaBeliefOptimizer {
    config: OptimizerConfig,
    /// First moment: exponential moving average (EMA) of the gradients.
    m: HashMap<String, Array<f64, Ix2>>,
    /// EMA of the squared deviation of the gradient from `m` (the "belief").
    s: HashMap<String, Array<f64, Ix2>>,
    /// Step counter, used for bias correction.
    t: usize,
}

impl AdaBeliefOptimizer {
    /// Creates a new optimizer from the given configuration.
    pub fn new(config: OptimizerConfig) -> Self {
        Self {
            config,
            m: HashMap::new(),
            s: HashMap::new(),
            t: 0,
        }
    }

    /// Clips gradients in place according to the configuration.
    ///
    /// `Value` mode clamps each element to `[-clip_value, clip_value]`;
    /// `Norm` mode rescales all gradients uniformly so the global L2 norm
    /// does not exceed `clip_value`.
    fn clip_gradients(&self, gradients: &mut HashMap<String, Array<f64, Ix2>>) {
        if let Some(clip_value) = self.config.grad_clip {
            match self.config.grad_clip_mode {
                GradClipMode::Value => {
                    for grad in gradients.values_mut() {
                        grad.mapv_inplace(|g| g.clamp(-clip_value, clip_value));
                    }
                }
                GradClipMode::Norm => {
                    let total_norm = compute_gradient_norm(gradients);
                    if total_norm > clip_value {
                        let scale = clip_value / total_norm;
                        for grad in gradients.values_mut() {
                            grad.mapv_inplace(|g| g * scale);
                        }
                    }
                }
            }
        }
    }
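
    // Worked example for `Norm` mode: a single 2x2 gradient
    // [[3.0, 4.0], [0.0, 0.0]] has global L2 norm 5.0; with
    // grad_clip = Some(1.0) the scale is 1.0 / 5.0 = 0.2, so the
    // clipped gradient becomes [[0.6, 0.8], [0.0, 0.0]].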
}

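// The update implemented by `step` below, following Zhuang et al. (2020),
// "AdaBelief Optimizer: Adapting Stepsizes by the Belief in Observed
// Gradients", with AdamW-style decoupled weight decay:
//
//   m_t     = beta1 * m_{t-1} + (1 - beta1) * g_t
//   s_t     = beta2 * s_{t-1} + (1 - beta2) * (g_t - m_t)^2
//   m_hat   = m_t / (1 - beta1^t)
//   s_hat   = s_t / (1 - beta2^t)
//   theta_t = theta_{t-1} * (1 - lr * wd) - lr * m_hat / (sqrt(s_hat) + eps)
//
// Note: the paper's algorithm also adds `eps` inside the `s_t` update; this
// implementation adds it only in the denominator.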
impl Optimizer for AdaBeliefOptimizer {
    fn step(
        &mut self,
        parameters: &mut HashMap<String, Array<f64, Ix2>>,
        gradients: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<()> {
        // Clip a copy of the gradients so the caller's buffers stay intact.
        let mut clipped_gradients = gradients.clone();
        self.clip_gradients(&mut clipped_gradients);

        self.t += 1;
        let lr = self.config.learning_rate;
        let beta1 = self.config.beta1;
        let beta2 = self.config.beta2;
        let eps = self.config.epsilon;
        let weight_decay = self.config.weight_decay;
        let bias_correction1 = 1.0 - beta1.powi(self.t as i32);
        let bias_correction2 = 1.0 - beta2.powi(self.t as i32);

        for (name, param) in parameters.iter_mut() {
            let grad = clipped_gradients.get(name).ok_or_else(|| {
                TrainError::OptimizerError(format!("Missing gradient for parameter: {}", name))
            })?;

            // Lazily initialize the moment buffers for this parameter.
            let m = self
                .m
                .entry(name.clone())
                .or_insert_with(|| Array::zeros(param.raw_dim()));
            let s = self
                .s
                .entry(name.clone())
                .or_insert_with(|| Array::zeros(param.raw_dim()));

            // First moment: EMA of the gradient.
            *m = &*m * beta1 + &(grad * (1.0 - beta1));

            // Second moment of the deviation from the EMA (the "belief").
            let grad_diff = grad - &*m;
            let grad_diff_squared = grad_diff.mapv(|g| g * g);
            *s = &*s * beta2 + &(grad_diff_squared * (1.0 - beta2));

            // Bias-corrected estimates.
            let m_hat = &*m / bias_correction1;
            let s_hat = &*s / bias_correction2;

            // Decoupled (AdamW-style) weight decay.
            if weight_decay > 0.0 {
                param.mapv_inplace(|p| p * (1.0 - lr * weight_decay));
            }

            let update = m_hat / (s_hat.mapv(|v| v.sqrt()) + eps);
            *param = &*param - &(update * lr);
        }
        Ok(())
    }

    fn zero_grad(&mut self) {
        // Nothing to do: gradients are owned by the caller and passed to `step`.
    }

    fn get_lr(&self) -> f64 {
        self.config.learning_rate
    }

    fn set_lr(&mut self, lr: f64) {
        self.config.learning_rate = lr;
    }

    fn state_dict(&self) -> HashMap<String, Vec<f64>> {
        let mut state = HashMap::new();
        state.insert("t".to_string(), vec![self.t as f64]);
        // Flatten each moment buffer; shapes are recovered from the live
        // buffers in `load_state_dict`.
        for (name, m_val) in &self.m {
            state.insert(format!("m_{}", name), m_val.iter().copied().collect());
        }
        for (name, s_val) in &self.s {
            state.insert(format!("s_{}", name), s_val.iter().copied().collect());
        }
        state
    }

    fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) {
        if let Some(t_val) = state.get("t") {
            self.t = t_val[0] as usize;
        }
        for (key, values) in state {
            if let Some(name) = key.strip_prefix("m_") {
                // Only buffers that already exist are restored: the flattened
                // state carries no shape information of its own.
                if let Some(m_array) = self.m.get(name) {
                    let shape = m_array.raw_dim();
                    if let Ok(arr) = Array::from_shape_vec(shape, values) {
                        self.m.insert(name.to_string(), arr);
                    }
                }
            } else if let Some(name) = key.strip_prefix("s_") {
                if let Some(s_array) = self.s.get(name) {
                    let shape = s_array.raw_dim();
                    if let Ok(arr) = Array::from_shape_vec(shape, values) {
                        self.s.insert(name.to_string(), arr);
                    }
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_adabelief_optimizer() {
        let config = OptimizerConfig {
            learning_rate: 0.001,
            weight_decay: 0.01,
            ..Default::default()
        };
        let mut optimizer = AdaBeliefOptimizer::new(config);
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0], [3.0, 4.0]]);
        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.2], [0.3, 0.4]]);
        for _ in 0..5 {
            optimizer.step(&mut params, &grads).unwrap();
        }
        // Positive gradients (plus weight decay) should push every weight down.
        let w = params.get("w").unwrap();
        assert!(w[[0, 0]] < 1.0);
        assert!(w[[1, 1]] < 4.0);
        let state = optimizer.state_dict();
        assert!(state.contains_key("t"));
        assert!(state.contains_key("m_w"));
        assert!(state.contains_key("s_w"));
    }
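
    // Two extra sketches that rely only on the configuration fields and
    // trait methods exercised above: one for Norm-mode clipping, one for
    // the state_dict/load_state_dict round trip.
    #[test]
    fn test_norm_clipping_step() {
        let config = OptimizerConfig {
            learning_rate: 0.001,
            grad_clip: Some(1.0),
            grad_clip_mode: GradClipMode::Norm,
            ..Default::default()
        };
        let mut optimizer = AdaBeliefOptimizer::new(config);
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0], [3.0, 4.0]]);
        let mut grads = HashMap::new();
        // Global L2 norm is 5.0 > 1.0, so every element is scaled by 0.2.
        grads.insert("w".to_string(), array![[3.0, 4.0], [0.0, 0.0]]);
        optimizer.step(&mut params, &grads).unwrap();
        let w = params.get("w").unwrap();
        assert!(w[[0, 0]] < 1.0); // moved against the (clipped) gradient
        assert!(w.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_state_dict_roundtrip() {
        let mut optimizer = AdaBeliefOptimizer::new(OptimizerConfig::default());
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0], [3.0, 4.0]]);
        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.2], [0.3, 0.4]]);
        optimizer.step(&mut params, &grads).unwrap();
        let state = optimizer.state_dict();
        // Restoring into an optimizer whose buffers already exist (the
        // shapes come from the live buffers) preserves the step counter.
        optimizer.load_state_dict(state);
        let restored = optimizer.state_dict();
        assert_eq!(restored.get("t"), Some(&vec![1.0]));
    }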
}