1use std::collections::HashMap;
2use std::collections::hash_map::Entry;
3
4use yscv_autograd::{Graph, NodeId};
5use yscv_tensor::Tensor;
6
7use super::validate::{validate_beta1, validate_beta2, validate_epsilon, validate_lr};
8use super::{LearningRate, OptimError};
9
/// Per-parameter optimizer state: the two exponential moving averages Adam
/// maintains, plus the update counter that drives bias correction.
#[derive(Debug, Clone)]
struct AdamState {
    // Exponential moving average of gradients (m_t in the Adam paper).
    first_moment: Tensor,
    // Exponential moving average of squared gradients (v_t).
    second_moment: Tensor,
    // Number of updates applied so far (t); starts at 0, incremented per step.
    step: u64,
}
16
17impl AdamState {
18 fn new(shape: &[usize]) -> Result<Self, OptimError> {
19 Ok(Self {
20 first_moment: Tensor::zeros(shape.to_vec())?,
21 second_moment: Tensor::zeros(shape.to_vec())?,
22 step: 0,
23 })
24 }
25
26 fn reset(&mut self, shape: &[usize]) -> Result<(), OptimError> {
27 *self = Self::new(shape)?;
28 Ok(())
29 }
30}
31
/// Adam optimizer with optional L2-style weight decay.
///
/// Hyperparameters are set via `new` and the `with_*` builders; per-parameter
/// moment state is created lazily on the first `step` for each parameter id.
#[derive(Debug, Clone)]
pub struct Adam {
    // Learning rate (validated by `validate_lr`).
    lr: f32,
    // Decay rate for the first-moment (mean) estimate; default 0.9.
    beta1: f32,
    // Decay rate for the second-moment (uncentered variance) estimate; default 0.999.
    beta2: f32,
    // Small constant added to the denominator for numerical stability; default 1e-8.
    epsilon: f32,
    // L2 weight-decay coefficient folded into the gradient; 0.0 disables it.
    weight_decay: f32,
    // Lazily-populated moment state, keyed by caller-supplied parameter id.
    state: HashMap<u64, AdamState>,
}
42
43impl Adam {
44 pub fn new(lr: f32) -> Result<Self, OptimError> {
46 validate_lr(lr)?;
47 Ok(Self {
48 lr,
49 beta1: 0.9,
50 beta2: 0.999,
51 epsilon: 1e-8,
52 weight_decay: 0.0,
53 state: HashMap::new(),
54 })
55 }
56
57 pub fn with_beta1(mut self, beta1: f32) -> Result<Self, OptimError> {
59 validate_beta1(beta1)?;
60 self.beta1 = beta1;
61 Ok(self)
62 }
63
64 pub fn with_beta2(mut self, beta2: f32) -> Result<Self, OptimError> {
66 validate_beta2(beta2)?;
67 self.beta2 = beta2;
68 Ok(self)
69 }
70
71 pub fn with_epsilon(mut self, epsilon: f32) -> Result<Self, OptimError> {
73 validate_epsilon(epsilon)?;
74 self.epsilon = epsilon;
75 Ok(self)
76 }
77
78 pub fn with_weight_decay(mut self, weight_decay: f32) -> Result<Self, OptimError> {
80 if !weight_decay.is_finite() || weight_decay < 0.0 {
81 return Err(OptimError::InvalidWeightDecay { weight_decay });
82 }
83 self.weight_decay = weight_decay;
84 Ok(self)
85 }
86
87 pub fn clear_state(&mut self) {
89 self.state.clear();
90 }
91
92 pub fn learning_rate(&self) -> f32 {
94 self.lr
95 }
96
97 pub fn set_learning_rate(&mut self, lr: f32) -> Result<(), OptimError> {
99 validate_lr(lr)?;
100 self.lr = lr;
101 Ok(())
102 }
103
104 pub fn step(
106 &mut self,
107 parameter_id: u64,
108 weights: &mut Tensor,
109 grad: &Tensor,
110 ) -> Result<(), OptimError> {
111 if weights.shape() != grad.shape() {
112 return Err(OptimError::ShapeMismatch {
113 weights: weights.shape().to_vec(),
114 grad: grad.shape().to_vec(),
115 });
116 }
117
118 let state = match self.state.entry(parameter_id) {
119 Entry::Occupied(entry) => entry.into_mut(),
120 Entry::Vacant(entry) => entry.insert(AdamState::new(weights.shape())?),
121 };
122 if state.first_moment.shape() != weights.shape()
123 || state.second_moment.shape() != weights.shape()
124 {
125 state.reset(weights.shape())?;
126 }
127
128 state.step = state.step.saturating_add(1);
129 let step_f64 = state.step as f64;
130 let bias_correction1 =
131 (1.0 - (self.beta1 as f64).powf(step_f64)).max(f64::MIN_POSITIVE) as f32;
132 let bias_correction2 =
133 (1.0 - (self.beta2 as f64).powf(step_f64)).max(f64::MIN_POSITIVE) as f32;
134
135 let first_moment = state.first_moment.data_mut();
136 let second_moment = state.second_moment.data_mut();
137 let grad_values = grad.data();
138 let weights_data = weights.data_mut();
139
140 let beta1 = self.beta1;
141 let beta2 = self.beta2;
142 let one_minus_beta1 = 1.0 - beta1;
143 let one_minus_beta2 = 1.0 - beta2;
144 let bias_correction1_inv = 1.0 / bias_correction1;
145 let bias_correction2_inv = 1.0 / bias_correction2;
146 let lr = self.lr;
147 let epsilon = self.epsilon;
148 let weight_decay = self.weight_decay;
149
150 adam_update_inner(
151 weights_data,
152 grad_values,
153 first_moment,
154 second_moment,
155 beta1,
156 beta2,
157 one_minus_beta1,
158 one_minus_beta2,
159 bias_correction1_inv,
160 bias_correction2_inv,
161 lr,
162 epsilon,
163 weight_decay,
164 );
165
166 Ok(())
167 }
168
169 pub fn step_graph_node(&mut self, graph: &mut Graph, node: NodeId) -> Result<(), OptimError> {
171 if !graph.requires_grad(node)? {
172 return Ok(());
173 }
174
175 let grad = match graph.grad(node)? {
176 Some(grad) => grad.clone(),
177 None => return Err(OptimError::MissingGradient { node: node.0 }),
178 };
179 let weights = graph.value_mut(node)?;
180 self.step(node.0 as u64, weights, &grad)
181 }
182}
183
184impl LearningRate for Adam {
185 fn learning_rate(&self) -> f32 {
186 Adam::learning_rate(self)
187 }
188
189 fn set_learning_rate(&mut self, lr: f32) -> Result<(), OptimError> {
190 Adam::set_learning_rate(self, lr)
191 }
192}
193
194#[allow(clippy::too_many_arguments, unsafe_code)]
196fn adam_update_inner(
197 weights: &mut [f32],
198 grad: &[f32],
199 first_moment: &mut [f32],
200 second_moment: &mut [f32],
201 beta1: f32,
202 beta2: f32,
203 one_minus_beta1: f32,
204 one_minus_beta2: f32,
205 bc1_inv: f32,
206 bc2_inv: f32,
207 lr: f32,
208 epsilon: f32,
209 weight_decay: f32,
210) {
211 let len = weights.len();
212
213 #[cfg(target_arch = "aarch64")]
214 if !cfg!(miri) && std::arch::is_aarch64_feature_detected!("neon") {
215 unsafe {
216 adam_update_neon(
217 weights,
218 grad,
219 first_moment,
220 second_moment,
221 beta1,
222 beta2,
223 one_minus_beta1,
224 one_minus_beta2,
225 bc1_inv,
226 bc2_inv,
227 lr,
228 epsilon,
229 weight_decay,
230 );
231 }
232 return;
233 }
234
235 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
236 if !cfg!(miri) && std::is_x86_feature_detected!("avx") {
237 unsafe {
238 adam_update_avx(
239 weights,
240 grad,
241 first_moment,
242 second_moment,
243 beta1,
244 beta2,
245 one_minus_beta1,
246 one_minus_beta2,
247 bc1_inv,
248 bc2_inv,
249 lr,
250 epsilon,
251 weight_decay,
252 );
253 }
254 return;
255 }
256
257 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
258 if !cfg!(miri) && std::is_x86_feature_detected!("sse") {
259 unsafe {
260 adam_update_sse(
261 weights,
262 grad,
263 first_moment,
264 second_moment,
265 beta1,
266 beta2,
267 one_minus_beta1,
268 one_minus_beta2,
269 bc1_inv,
270 bc2_inv,
271 lr,
272 epsilon,
273 weight_decay,
274 );
275 }
276 return;
277 }
278
279 let wp = weights.as_mut_ptr();
280 let gp = grad.as_ptr();
281 let mp = first_moment.as_mut_ptr();
282 let vp = second_moment.as_mut_ptr();
283 for i in 0..len {
284 unsafe {
285 let g = *gp.add(i) + weight_decay * *wp.add(i);
286 let m = beta1 * *mp.add(i) + one_minus_beta1 * g;
287 let v = beta2 * *vp.add(i) + one_minus_beta2 * g * g;
288 *mp.add(i) = m;
289 *vp.add(i) = v;
290 let m_hat = m * bc1_inv;
291 let v_hat = v * bc2_inv;
292 *wp.add(i) -= lr * m_hat / (v_hat.sqrt() + epsilon);
293 }
294 }
295}
296
/// AArch64 NEON kernel for one Adam update over flat `f32` buffers: 4-wide
/// vectorized main loop plus a scalar tail, numerically matching the scalar
/// fallback in `adam_update_inner`.
///
/// # Safety
/// - The NEON target feature must be available at runtime (callers detect it
///   before dispatching here).
/// - `grad`, `first_moment`, and `second_moment` must each hold at least
///   `weights.len()` elements: the loops are bounded by `weights.len()` but
///   read/write all four buffers through raw pointers without bounds checks.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(clippy::too_many_arguments, unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn adam_update_neon(
    weights: &mut [f32],
    grad: &[f32],
    first_moment: &mut [f32],
    second_moment: &mut [f32],
    beta1: f32,
    beta2: f32,
    one_minus_beta1: f32,
    one_minus_beta2: f32,
    bc1_inv: f32,
    bc2_inv: f32,
    lr: f32,
    epsilon: f32,
    weight_decay: f32,
) {
    use std::arch::aarch64::*;
    let len = weights.len();
    let wp = weights.as_mut_ptr();
    let gp = grad.as_ptr();
    let mp = first_moment.as_mut_ptr();
    let vp = second_moment.as_mut_ptr();
    // Broadcast every scalar hyperparameter into a 4-lane vector once.
    let beta1_v = vdupq_n_f32(beta1);
    let beta2_v = vdupq_n_f32(beta2);
    let omb1_v = vdupq_n_f32(one_minus_beta1);
    let omb2_v = vdupq_n_f32(one_minus_beta2);
    let bc1_v = vdupq_n_f32(bc1_inv);
    let bc2_v = vdupq_n_f32(bc2_inv);
    let lr_v = vdupq_n_f32(lr);
    let eps_v = vdupq_n_f32(epsilon);
    let wd_v = vdupq_n_f32(weight_decay);
    let mut i = 0usize;
    // Main loop: 4 elements per iteration. vfmaq_f32(a, b, c) = a + b * c.
    while i + 4 <= len {
        let w = vld1q_f32(wp.add(i));
        let raw_g = vld1q_f32(gp.add(i));
        // g = grad + weight_decay * w (L2 decay folded into the gradient).
        let g = vfmaq_f32(raw_g, wd_v, w);
        let m_old = vld1q_f32(mp.add(i));
        let v_old = vld1q_f32(vp.add(i));
        // m = beta1 * m_old + (1 - beta1) * g
        let m_new = vfmaq_f32(vmulq_f32(g, omb1_v), m_old, beta1_v);
        let grad_sq = vmulq_f32(g, g);
        // v = beta2 * v_old + (1 - beta2) * g^2
        let v_new = vfmaq_f32(vmulq_f32(grad_sq, omb2_v), v_old, beta2_v);
        vst1q_f32(mp.add(i), m_new);
        vst1q_f32(vp.add(i), v_new);
        // Bias-corrected moments, then w -= lr * m_hat / (sqrt(v_hat) + eps).
        let m_hat = vmulq_f32(m_new, bc1_v);
        let v_hat = vmulq_f32(v_new, bc2_v);
        let update = vdivq_f32(vmulq_f32(m_hat, lr_v), vaddq_f32(vsqrtq_f32(v_hat), eps_v));
        vst1q_f32(wp.add(i), vsubq_f32(w, update));
        i += 4;
    }
    // Scalar tail for the remaining (< 4) elements.
    while i < len {
        let g = *gp.add(i) + weight_decay * *wp.add(i);
        let m = beta1 * *mp.add(i) + one_minus_beta1 * g;
        let v = beta2 * *vp.add(i) + one_minus_beta2 * g * g;
        *mp.add(i) = m;
        *vp.add(i) = v;
        *wp.add(i) -= lr * (m * bc1_inv) / ((v * bc2_inv).sqrt() + epsilon);
        i += 1;
    }
}
360
/// x86/x86_64 AVX kernel for one Adam update over flat `f32` buffers: 8-wide
/// vectorized main loop plus a scalar tail, numerically matching the scalar
/// fallback in `adam_update_inner`.
///
/// # Safety
/// - The AVX target feature must be available at runtime (callers detect it
///   before dispatching here).
/// - `grad`, `first_moment`, and `second_moment` must each hold at least
///   `weights.len()` elements: the loops are bounded by `weights.len()` but
///   read/write all four buffers through raw pointers without bounds checks.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx")]
#[allow(clippy::too_many_arguments, unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn adam_update_avx(
    weights: &mut [f32],
    grad: &[f32],
    first_moment: &mut [f32],
    second_moment: &mut [f32],
    beta1: f32,
    beta2: f32,
    one_minus_beta1: f32,
    one_minus_beta2: f32,
    bc1_inv: f32,
    bc2_inv: f32,
    lr: f32,
    epsilon: f32,
    weight_decay: f32,
) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    let len = weights.len();
    let wp = weights.as_mut_ptr();
    let gp = grad.as_ptr();
    let mp = first_moment.as_mut_ptr();
    let vp = second_moment.as_mut_ptr();
    // Broadcast every scalar hyperparameter into an 8-lane vector once.
    let beta1_v = _mm256_set1_ps(beta1);
    let beta2_v = _mm256_set1_ps(beta2);
    let omb1_v = _mm256_set1_ps(one_minus_beta1);
    let omb2_v = _mm256_set1_ps(one_minus_beta2);
    let bc1_v = _mm256_set1_ps(bc1_inv);
    let bc2_v = _mm256_set1_ps(bc2_inv);
    let lr_v = _mm256_set1_ps(lr);
    let eps_v = _mm256_set1_ps(epsilon);
    let wd_v = _mm256_set1_ps(weight_decay);
    let mut i = 0usize;
    // Main loop: 8 elements per iteration (unaligned loads/stores, so no
    // alignment requirement on the tensor buffers).
    while i + 8 <= len {
        let w = _mm256_loadu_ps(wp.add(i));
        let raw_g = _mm256_loadu_ps(gp.add(i));
        // g = grad + weight_decay * w (L2 decay folded into the gradient).
        let g = _mm256_add_ps(raw_g, _mm256_mul_ps(wd_v, w));
        let m_old = _mm256_loadu_ps(mp.add(i));
        let v_old = _mm256_loadu_ps(vp.add(i));
        // m = beta1 * m_old + (1 - beta1) * g
        let m_new = _mm256_add_ps(_mm256_mul_ps(beta1_v, m_old), _mm256_mul_ps(omb1_v, g));
        let grad_sq = _mm256_mul_ps(g, g);
        // v = beta2 * v_old + (1 - beta2) * g^2
        let v_new = _mm256_add_ps(
            _mm256_mul_ps(beta2_v, v_old),
            _mm256_mul_ps(omb2_v, grad_sq),
        );
        _mm256_storeu_ps(mp.add(i), m_new);
        _mm256_storeu_ps(vp.add(i), v_new);
        // Bias-corrected moments, then w -= lr * m_hat / (sqrt(v_hat) + eps).
        let m_hat = _mm256_mul_ps(m_new, bc1_v);
        let v_hat = _mm256_mul_ps(v_new, bc2_v);
        let update = _mm256_div_ps(
            _mm256_mul_ps(m_hat, lr_v),
            _mm256_add_ps(_mm256_sqrt_ps(v_hat), eps_v),
        );
        _mm256_storeu_ps(wp.add(i), _mm256_sub_ps(w, update));
        i += 8;
    }
    // Scalar tail for the remaining (< 8) elements.
    while i < len {
        let g = *gp.add(i) + weight_decay * *wp.add(i);
        let m = beta1 * *mp.add(i) + one_minus_beta1 * g;
        let v = beta2 * *vp.add(i) + one_minus_beta2 * g * g;
        *mp.add(i) = m;
        *vp.add(i) = v;
        *wp.add(i) -= lr * (m * bc1_inv) / ((v * bc2_inv).sqrt() + epsilon);
        i += 1;
    }
}
433
/// x86/x86_64 SSE kernel for one Adam update over flat `f32` buffers: 4-wide
/// vectorized main loop plus a scalar tail, numerically matching the scalar
/// fallback in `adam_update_inner`. Used when AVX is unavailable.
///
/// # Safety
/// - The SSE target feature must be available at runtime (callers detect it
///   before dispatching here).
/// - `grad`, `first_moment`, and `second_moment` must each hold at least
///   `weights.len()` elements: the loops are bounded by `weights.len()` but
///   read/write all four buffers through raw pointers without bounds checks.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse")]
#[allow(clippy::too_many_arguments, unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn adam_update_sse(
    weights: &mut [f32],
    grad: &[f32],
    first_moment: &mut [f32],
    second_moment: &mut [f32],
    beta1: f32,
    beta2: f32,
    one_minus_beta1: f32,
    one_minus_beta2: f32,
    bc1_inv: f32,
    bc2_inv: f32,
    lr: f32,
    epsilon: f32,
    weight_decay: f32,
) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    let len = weights.len();
    let wp = weights.as_mut_ptr();
    let gp = grad.as_ptr();
    let mp = first_moment.as_mut_ptr();
    let vp = second_moment.as_mut_ptr();
    // Broadcast every scalar hyperparameter into a 4-lane vector once.
    let beta1_v = _mm_set1_ps(beta1);
    let beta2_v = _mm_set1_ps(beta2);
    let omb1_v = _mm_set1_ps(one_minus_beta1);
    let omb2_v = _mm_set1_ps(one_minus_beta2);
    let bc1_v = _mm_set1_ps(bc1_inv);
    let bc2_v = _mm_set1_ps(bc2_inv);
    let lr_v = _mm_set1_ps(lr);
    let eps_v = _mm_set1_ps(epsilon);
    let wd_v = _mm_set1_ps(weight_decay);
    let mut i = 0usize;
    // Main loop: 4 elements per iteration (unaligned loads/stores, so no
    // alignment requirement on the tensor buffers).
    while i + 4 <= len {
        let w = _mm_loadu_ps(wp.add(i));
        let raw_g = _mm_loadu_ps(gp.add(i));
        // g = grad + weight_decay * w (L2 decay folded into the gradient).
        let g = _mm_add_ps(raw_g, _mm_mul_ps(wd_v, w));
        let m_old = _mm_loadu_ps(mp.add(i));
        let v_old = _mm_loadu_ps(vp.add(i));
        // m = beta1 * m_old + (1 - beta1) * g
        let m_new = _mm_add_ps(_mm_mul_ps(beta1_v, m_old), _mm_mul_ps(omb1_v, g));
        let grad_sq = _mm_mul_ps(g, g);
        // v = beta2 * v_old + (1 - beta2) * g^2
        let v_new = _mm_add_ps(_mm_mul_ps(beta2_v, v_old), _mm_mul_ps(omb2_v, grad_sq));
        _mm_storeu_ps(mp.add(i), m_new);
        _mm_storeu_ps(vp.add(i), v_new);
        // Bias-corrected moments, then w -= lr * m_hat / (sqrt(v_hat) + eps).
        let m_hat = _mm_mul_ps(m_new, bc1_v);
        let v_hat = _mm_mul_ps(v_new, bc2_v);
        let update = _mm_div_ps(
            _mm_mul_ps(m_hat, lr_v),
            _mm_add_ps(_mm_sqrt_ps(v_hat), eps_v),
        );
        _mm_storeu_ps(wp.add(i), _mm_sub_ps(w, update));
        i += 4;
    }
    // Scalar tail for the remaining (< 4) elements.
    while i < len {
        let g = *gp.add(i) + weight_decay * *wp.add(i);
        let m = beta1 * *mp.add(i) + one_minus_beta1 * g;
        let v = beta2 * *vp.add(i) + one_minus_beta2 * g * g;
        *mp.add(i) = m;
        *vp.add(i) = v;
        *wp.add(i) -= lr * (m * bc1_inv) / ((v * bc2_inv).sqrt() + epsilon);
        i += 1;
    }
}