// rustorch_core/optimizer.rs
use crate::Tensor;
use rayon::prelude::*;
use std::collections::HashMap;
use std::sync::Arc;
use wide::f32x8;

7pub struct Adam {
8 params: Vec<Tensor>,
9 lr: f32,
10 beta1: f32,
11 beta2: f32,
12 epsilon: f32,
13 step: u32,
14 state: HashMap<usize, (Tensor, Tensor)>, }
16
17impl Adam {
18 pub fn new(params: Vec<Tensor>, lr: f32) -> Self {
19 Self {
20 params,
21 lr,
22 beta1: 0.9,
23 beta2: 0.999,
24 epsilon: 1e-8,
25 step: 0,
26 state: HashMap::new(),
27 }
28 }
29
30 pub fn step(&mut self) {
31 self.step += 1;
32
33 for param in self.params.iter() {
34 let grad = if let Some(g) = param.grad() {
35 g
36 } else {
37 continue;
38 };
39
40 let is_wgpu = param.storage().wgpu_buffer().is_some();
41
42 let param_id = Arc::as_ptr(¶m.inner) as usize;
43
44 self.state.entry(param_id).or_insert_with(|| {
45 let m = if is_wgpu {
46 Tensor::zeros(param.shape()).to_wgpu()
47 } else {
48 Tensor::zeros(param.shape())
49 };
50 let v = if is_wgpu {
51 Tensor::zeros(param.shape()).to_wgpu()
52 } else {
53 Tensor::zeros(param.shape())
54 };
55 (m, v)
56 });
57
58 let (m, v) = self.state.get(¶m_id).unwrap();
59
60 let beta1 = self.beta1;
61 let beta2 = self.beta2;
62 let eps = self.epsilon;
63 let lr = self.lr;
64 let step = self.step as f32;
65 let bias_correction1 = 1.0 - beta1.powf(step);
66 let bias_correction2 = 1.0 - beta2.powf(step);
67
68 if is_wgpu {
69 #[cfg(feature = "wgpu_backend")]
70 {
71 let grad_wgpu = if grad.storage().wgpu_buffer().is_none() {
72 grad.to_wgpu()
73 } else {
74 grad
75 };
76
77 let param_buf = param.storage().wgpu_buffer().unwrap();
78 let grad_buf = grad_wgpu.storage().wgpu_buffer().unwrap();
79 let m_buf = m.storage().wgpu_buffer().unwrap();
80 let v_buf = v.storage().wgpu_buffer().unwrap();
81
82 crate::backend::wgpu::adam_step_wgpu(
83 param_buf,
84 grad_buf,
85 m_buf,
86 v_buf,
87 param.shape().iter().product(),
88 self.lr,
89 self.beta1,
90 self.beta2,
91 self.epsilon,
92 self.step,
93 );
94 crate::backend::wgpu::flush_queue();
95 }
96 } else {
97 let mut p_data = param.data_mut();
98 let g_data = grad.data();
99 let mut m_data = m.data_mut();
100 let mut v_data = v.data_mut();
101 let lanes = 8usize;
102 let head = (p_data.len() / lanes) * lanes;
103
104 p_data[..head]
105 .par_chunks_mut(lanes)
106 .zip(m_data[..head].par_chunks_mut(lanes))
107 .zip(v_data[..head].par_chunks_mut(lanes))
108 .zip(g_data[..head].par_chunks(lanes))
109 .for_each(|(((p_chunk, m_chunk), v_chunk), g_chunk)| {
110 let p_vec = f32x8::from([
111 p_chunk[0], p_chunk[1], p_chunk[2], p_chunk[3], p_chunk[4], p_chunk[5],
112 p_chunk[6], p_chunk[7],
113 ]);
114 let m_vec = f32x8::from([
115 m_chunk[0], m_chunk[1], m_chunk[2], m_chunk[3], m_chunk[4], m_chunk[5],
116 m_chunk[6], m_chunk[7],
117 ]);
118 let v_vec = f32x8::from([
119 v_chunk[0], v_chunk[1], v_chunk[2], v_chunk[3], v_chunk[4], v_chunk[5],
120 v_chunk[6], v_chunk[7],
121 ]);
122 let g_vec = f32x8::from([
123 g_chunk[0], g_chunk[1], g_chunk[2], g_chunk[3], g_chunk[4], g_chunk[5],
124 g_chunk[6], g_chunk[7],
125 ]);
126
127 let m_t = f32x8::splat(beta1) * m_vec + f32x8::splat(1.0 - beta1) * g_vec;
128 let v_t =
129 f32x8::splat(beta2) * v_vec + f32x8::splat(1.0 - beta2) * g_vec * g_vec;
130 let m_hat = m_t / f32x8::splat(bias_correction1);
131 let v_hat = v_t / f32x8::splat(bias_correction2);
132 let p_new =
133 p_vec - f32x8::splat(lr) * m_hat / (v_hat.sqrt() + f32x8::splat(eps));
134
135 let m_arr = m_t.to_array();
136 let v_arr = v_t.to_array();
137 let p_arr = p_new.to_array();
138
139 m_chunk.copy_from_slice(&m_arr);
140 v_chunk.copy_from_slice(&v_arr);
141 p_chunk.copy_from_slice(&p_arr);
142 });
143
144 for i in head..p_data.len() {
145 let g = g_data[i];
146 let m_t = beta1 * m_data[i] + (1.0 - beta1) * g;
147 let v_t = beta2 * v_data[i] + (1.0 - beta2) * g * g;
148 m_data[i] = m_t;
149 v_data[i] = v_t;
150 let m_hat = m_t / bias_correction1;
151 let v_hat = v_t / bias_correction2;
152 p_data[i] -= lr * m_hat / (v_hat.sqrt() + eps);
153 }
154 }
155 }
156 }
157
158 pub fn zero_grad(&self) {
159 for param in &self.params {
160 param.zero_grad();
161 }
162 }
163}