// rustorch_core/optimizer.rs
1use crate::Tensor;
2use rayon::prelude::*;
3use std::collections::HashMap;
4use std::sync::Arc;
5use wide::f32x8;
6
/// Adam optimizer (Kingma & Ba, 2015) with lazily-allocated per-parameter
/// first/second moment state, supporting both CPU and wgpu-backed tensors.
pub struct Adam {
    /// Parameters updated on each call to `step()`.
    params: Vec<Tensor>,
    /// Learning rate.
    lr: f32,
    /// Exponential decay rate for the first moment estimate (default 0.9).
    beta1: f32,
    /// Exponential decay rate for the second moment estimate (default 0.999).
    beta2: f32,
    /// Small constant added to the denominator for numerical stability.
    epsilon: f32,
    /// Number of optimizer steps taken so far (starts at 0, incremented first
    /// thing in `step()`, so bias correction sees t >= 1).
    step: u32,
    state: HashMap<usize, (Tensor, Tensor)>, // (m, v) keyed by param unique ID
}
16
17impl Adam {
18    pub fn new(params: Vec<Tensor>, lr: f32) -> Self {
19        Self {
20            params,
21            lr,
22            beta1: 0.9,
23            beta2: 0.999,
24            epsilon: 1e-8,
25            step: 0,
26            state: HashMap::new(),
27        }
28    }
29
30    pub fn step(&mut self) {
31        self.step += 1;
32
33        for param in self.params.iter() {
34            let grad = if let Some(g) = param.grad() {
35                g
36            } else {
37                continue;
38            };
39
40            let is_wgpu = param.storage().wgpu_buffer().is_some();
41
42            let param_id = Arc::as_ptr(&param.inner) as usize;
43
44            self.state.entry(param_id).or_insert_with(|| {
45                let m = if is_wgpu {
46                    Tensor::zeros(param.shape()).to_wgpu()
47                } else {
48                    Tensor::zeros(param.shape())
49                };
50                let v = if is_wgpu {
51                    Tensor::zeros(param.shape()).to_wgpu()
52                } else {
53                    Tensor::zeros(param.shape())
54                };
55                (m, v)
56            });
57
58            let (m, v) = self.state.get(&param_id).unwrap();
59
60            let beta1 = self.beta1;
61            let beta2 = self.beta2;
62            let eps = self.epsilon;
63            let lr = self.lr;
64            let step = self.step as f32;
65            let bias_correction1 = 1.0 - beta1.powf(step);
66            let bias_correction2 = 1.0 - beta2.powf(step);
67
68            if is_wgpu {
69                #[cfg(feature = "wgpu_backend")]
70                {
71                    let grad_wgpu = if grad.storage().wgpu_buffer().is_none() {
72                        grad.to_wgpu()
73                    } else {
74                        grad
75                    };
76
77                    let param_buf = param.storage().wgpu_buffer().unwrap();
78                    let grad_buf = grad_wgpu.storage().wgpu_buffer().unwrap();
79                    let m_buf = m.storage().wgpu_buffer().unwrap();
80                    let v_buf = v.storage().wgpu_buffer().unwrap();
81
82                    crate::backend::wgpu::adam_step_wgpu(
83                        param_buf,
84                        grad_buf,
85                        m_buf,
86                        v_buf,
87                        param.shape().iter().product(),
88                        self.lr,
89                        self.beta1,
90                        self.beta2,
91                        self.epsilon,
92                        self.step,
93                    );
94                    crate::backend::wgpu::flush_queue();
95                }
96            } else {
97                let mut p_data = param.data_mut();
98                let g_data = grad.data();
99                let mut m_data = m.data_mut();
100                let mut v_data = v.data_mut();
101                let lanes = 8usize;
102                let head = (p_data.len() / lanes) * lanes;
103
104                p_data[..head]
105                    .par_chunks_mut(lanes)
106                    .zip(m_data[..head].par_chunks_mut(lanes))
107                    .zip(v_data[..head].par_chunks_mut(lanes))
108                    .zip(g_data[..head].par_chunks(lanes))
109                    .for_each(|(((p_chunk, m_chunk), v_chunk), g_chunk)| {
110                        let p_vec = f32x8::from([
111                            p_chunk[0], p_chunk[1], p_chunk[2], p_chunk[3], p_chunk[4], p_chunk[5],
112                            p_chunk[6], p_chunk[7],
113                        ]);
114                        let m_vec = f32x8::from([
115                            m_chunk[0], m_chunk[1], m_chunk[2], m_chunk[3], m_chunk[4], m_chunk[5],
116                            m_chunk[6], m_chunk[7],
117                        ]);
118                        let v_vec = f32x8::from([
119                            v_chunk[0], v_chunk[1], v_chunk[2], v_chunk[3], v_chunk[4], v_chunk[5],
120                            v_chunk[6], v_chunk[7],
121                        ]);
122                        let g_vec = f32x8::from([
123                            g_chunk[0], g_chunk[1], g_chunk[2], g_chunk[3], g_chunk[4], g_chunk[5],
124                            g_chunk[6], g_chunk[7],
125                        ]);
126
127                        let m_t = f32x8::splat(beta1) * m_vec + f32x8::splat(1.0 - beta1) * g_vec;
128                        let v_t =
129                            f32x8::splat(beta2) * v_vec + f32x8::splat(1.0 - beta2) * g_vec * g_vec;
130                        let m_hat = m_t / f32x8::splat(bias_correction1);
131                        let v_hat = v_t / f32x8::splat(bias_correction2);
132                        let p_new =
133                            p_vec - f32x8::splat(lr) * m_hat / (v_hat.sqrt() + f32x8::splat(eps));
134
135                        let m_arr = m_t.to_array();
136                        let v_arr = v_t.to_array();
137                        let p_arr = p_new.to_array();
138
139                        m_chunk.copy_from_slice(&m_arr);
140                        v_chunk.copy_from_slice(&v_arr);
141                        p_chunk.copy_from_slice(&p_arr);
142                    });
143
144                for i in head..p_data.len() {
145                    let g = g_data[i];
146                    let m_t = beta1 * m_data[i] + (1.0 - beta1) * g;
147                    let v_t = beta2 * v_data[i] + (1.0 - beta2) * g * g;
148                    m_data[i] = m_t;
149                    v_data[i] = v_t;
150                    let m_hat = m_t / bias_correction1;
151                    let v_hat = v_t / bias_correction2;
152                    p_data[i] -= lr * m_hat / (v_hat.sqrt() + eps);
153                }
154            }
155        }
156    }
157
158    pub fn zero_grad(&self) {
159        for param in &self.params {
160            param.zero_grad();
161        }
162    }
163}