tensorlogic_train/optimizers/adabelief.rs

use super::common::{compute_gradient_norm, GradClipMode, Optimizer, OptimizerConfig};
use crate::{TrainError, TrainResult};
use scirs2_core::ndarray::{Array, Ix2};
use std::collections::HashMap;

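/// AdaBelief optimizer (Zhuang et al., 2020).
///
/// AdaBelief follows the Adam update but scales the step by the "belief" in
/// the observed gradient: the second-moment estimate tracks the variance of
/// the gradient around its EMA rather than the raw squared gradient:
///
///   m_t     = beta1 * m_{t-1} + (1 - beta1) * g_t
///   s_t     = beta2 * s_{t-1} + (1 - beta2) * (g_t - m_t)^2
///   theta_t = theta_{t-1} - lr * m_hat_t / (sqrt(s_hat_t) + eps)
///
/// where m_hat and s_hat are the bias-corrected estimates. When the gradient
/// agrees with its EMA prediction, (g_t - m_t)^2 is small and the effective
/// step is large; noisy gradients shrink the step.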
#[derive(Debug)]
pub struct AdaBeliefOptimizer {
    config: OptimizerConfig,
    /// First-moment estimate (EMA of gradients) per parameter.
    m: HashMap<String, Array<f64, Ix2>>,
    /// Second-moment estimate (EMA of squared prediction error) per parameter.
    s: HashMap<String, Array<f64, Ix2>>,
    /// Number of steps taken, used for bias correction.
    t: usize,
}

impl AdaBeliefOptimizer {
    /// Creates a new AdaBelief optimizer from the given configuration.
    pub fn new(config: OptimizerConfig) -> Self {
        Self {
            config,
            m: HashMap::new(),
            s: HashMap::new(),
            t: 0,
        }
    }

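    /// Clips gradients in place according to the configured mode: element-wise
    /// clamping to [-clip, clip] for `GradClipMode::Value`, or rescaling every
    /// gradient by `clip / total_norm` for `GradClipMode::Norm` whenever the
    /// global L2 norm exceeds the threshold. `grad_clip = None` disables clipping.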
    fn clip_gradients(&self, gradients: &mut HashMap<String, Array<f64, Ix2>>) {
        if let Some(clip_value) = self.config.grad_clip {
            match self.config.grad_clip_mode {
                GradClipMode::Value => {
                    for grad in gradients.values_mut() {
                        grad.mapv_inplace(|g| g.clamp(-clip_value, clip_value));
                    }
                }
                GradClipMode::Norm => {
                    let total_norm = compute_gradient_norm(gradients);
                    if total_norm > clip_value {
                        let scale = clip_value / total_norm;
                        for grad in gradients.values_mut() {
                            grad.mapv_inplace(|g| g * scale);
                        }
                    }
                }
            }
        }
    }
}

impl Optimizer for AdaBeliefOptimizer {
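    /// Performs one AdaBelief update on every parameter. Gradients are clipped
    /// first (if configured), then the bias-corrected first and second moments
    /// drive the update; weight decay is applied in the decoupled (AdamW-style)
    /// form, shrinking parameters before the gradient step.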
    fn step(
        &mut self,
        parameters: &mut HashMap<String, Array<f64, Ix2>>,
        gradients: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<()> {
        let mut clipped_gradients = gradients.clone();
        self.clip_gradients(&mut clipped_gradients);

        self.t += 1;
        let lr = self.config.learning_rate;
        let beta1 = self.config.beta1;
        let beta2 = self.config.beta2;
        let eps = self.config.epsilon;
        let weight_decay = self.config.weight_decay;

        // Bias-correction terms shared by all parameters at step t.
        let bias_correction1 = 1.0 - beta1.powi(self.t as i32);
        let bias_correction2 = 1.0 - beta2.powi(self.t as i32);

        for (name, param) in parameters.iter_mut() {
            let grad = clipped_gradients.get(name).ok_or_else(|| {
                TrainError::OptimizerError(format!("Missing gradient for parameter: {}", name))
            })?;

            // Lazily initialize per-parameter state on first use.
            if !self.m.contains_key(name) {
                self.m.insert(name.clone(), Array::zeros(param.raw_dim()));
                self.s.insert(name.clone(), Array::zeros(param.raw_dim()));
            }
            let m = self
                .m
                .get_mut(name)
                .expect("m initialized for all parameters");
            let s = self
                .s
                .get_mut(name)
                .expect("s initialized for all parameters");

            // m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
            *m = &*m * beta1 + &(grad * (1.0 - beta1));

            // s_t = beta2 * s_{t-1} + (1 - beta2) * (g_t - m_t)^2
            // (the reference AdaBelief also adds eps inside this update; this
            // variant relies on the eps in the denominator below instead).
            let grad_diff = grad - &*m;
            let grad_diff_squared = grad_diff.mapv(|g| g * g);
            *s = &*s * beta2 + &(grad_diff_squared * (1.0 - beta2));

            let m_hat = &*m / bias_correction1;
            let s_hat = &*s / bias_correction2;

            // Decoupled weight decay, applied directly to the parameters.
            if weight_decay > 0.0 {
                param.mapv_inplace(|p| p * (1.0 - lr * weight_decay));
            }

            // theta_t = theta_{t-1} - lr * m_hat / (sqrt(s_hat) + eps)
            let update = m_hat / (s_hat.mapv(|v| v.sqrt()) + eps);
            *param = &*param - &(update * lr);
        }
        Ok(())
    }

    /// No-op: this optimizer does not own gradient buffers; gradients are
    /// passed into `step` by the caller.
    fn zero_grad(&mut self) {}

    fn get_lr(&self) -> f64 {
        self.config.learning_rate
    }

    fn set_lr(&mut self, lr: f64) {
        self.config.learning_rate = lr;
    }

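    /// Serializes the step counter and per-parameter moment estimates into a
    /// flat map of f64 vectors (keys "t", "m_<name>", "s_<name>").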
    fn state_dict(&self) -> HashMap<String, Vec<f64>> {
        let mut state = HashMap::new();
        state.insert("t".to_string(), vec![self.t as f64]);
        for (name, m_val) in &self.m {
            state.insert(format!("m_{}", name), m_val.iter().copied().collect());
        }
        for (name, s_val) in &self.s {
            state.insert(format!("s_{}", name), s_val.iter().copied().collect());
        }
        state
    }

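    /// Restores state saved by `state_dict`. Moment estimates are only
    /// restored for parameters whose state arrays already exist (their shapes
    /// are needed to rebuild the arrays); unknown keys are ignored.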
    fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) {
        if let Some(&t_val) = state.get("t").and_then(|v| v.first()) {
            self.t = t_val as usize;
        }
        for (key, values) in state {
            if let Some(name) = key.strip_prefix("m_") {
                // Reuse the existing array's shape to validate the payload.
                if let Some(m_array) = self.m.get(name) {
                    let shape = m_array.raw_dim();
                    if let Ok(arr) = Array::from_shape_vec(shape, values) {
                        self.m.insert(name.to_string(), arr);
                    }
                }
            } else if let Some(name) = key.strip_prefix("s_") {
                if let Some(s_array) = self.s.get(name) {
                    let shape = s_array.raw_dim();
                    if let Ok(arr) = Array::from_shape_vec(shape, values) {
                        self.s.insert(name.to_string(), arr);
                    }
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_adabelief_optimizer() {
        let config = OptimizerConfig {
            learning_rate: 0.001,
            weight_decay: 0.01,
            ..Default::default()
        };
        let mut optimizer = AdaBeliefOptimizer::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0], [3.0, 4.0]]);
        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.2], [0.3, 0.4]]);

        for _ in 0..5 {
            optimizer.step(&mut params, &grads).expect("step should succeed");
        }

        // Positive gradients (plus weight decay) should shrink the parameters.
        let w = params.get("w").expect("parameter w should exist");
        assert!(w[[0, 0]] < 1.0);
        assert!(w[[1, 1]] < 4.0);

        // Optimizer state should be serializable.
        let state = optimizer.state_dict();
        assert!(state.contains_key("t"));
        assert!(state.contains_key("m_w"));
        assert!(state.contains_key("s_w"));
    }
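
    // A minimal additional check, sketched under the assumption that the
    // grad_clip / grad_clip_mode fields read by clip_gradients above are
    // settable on OptimizerConfig via struct-update syntax. It exercises the
    // global-norm clipping path and verifies the step still moves the
    // parameters downhill.
    #[test]
    fn test_adabelief_grad_clip_norm() {
        let config = OptimizerConfig {
            learning_rate: 0.001,
            grad_clip: Some(0.5),
            grad_clip_mode: GradClipMode::Norm,
            ..Default::default()
        };
        let mut optimizer = AdaBeliefOptimizer::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0], [3.0, 4.0]]);
        let mut grads = HashMap::new();
        // Global L2 norm is 20.0, well above the 0.5 threshold, so the
        // gradients get rescaled before the update.
        grads.insert("w".to_string(), array![[10.0, 10.0], [10.0, 10.0]]);

        optimizer.step(&mut params, &grads).expect("step should succeed");

        let w = params.get("w").expect("parameter w should exist");
        assert!(w[[0, 0]] < 1.0);
        assert!(w[[1, 1]] < 4.0);
    }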
}