1use anyhow::Result;
2use std::collections::HashMap;
3use trustformers_core::tensor::Tensor;
4
/// Parameters of a linear mapping between `f32` values and signed 8-bit codes.
///
/// Instances are normally produced by `QuantizationConfig::new`, which derives
/// `scale` and `zero_point` from the supplied `[min_val, max_val]` range.
#[derive(Debug, Clone)]
pub struct QuantizationConfig {
    /// Real-valued size of one quantization step (derived by `new` from the range).
    pub scale: f32,
    /// Integer offset computed by `new` from `min_val` and `scale`.
    /// `quantize` and `dequantize` do not read this field; they work from
    /// `min_val` and `scale` directly.
    pub zero_point: i8,
    /// Lower end of the represented range; it is encoded as the code `-128`.
    pub min_val: f32,
    /// Upper end of the represented range, as passed to `new`.
    pub max_val: f32,
}
12
13impl QuantizationConfig {
14 pub fn new(min_val: f32, max_val: f32) -> Self {
15 let scale = (max_val - min_val) / 255.0;
16 let zero_point = (-min_val / scale).round().clamp(-128.0, 127.0) as i8;
17
18 Self {
19 scale,
20 zero_point,
21 min_val,
22 max_val,
23 }
24 }
25
26 pub fn quantize(&self, value: f32) -> i8 {
27 let quantized = ((value - self.min_val) / self.scale).round() - 128.0;
28 quantized.clamp(-128.0, 127.0) as i8
29 }
30
31 pub fn dequantize(&self, quantized: i8) -> f32 {
32 (quantized as f32 + 128.0) * self.scale + self.min_val
33 }
34}
35
/// A buffer of `f32` values stored as 8-bit codes together with the
/// configuration needed to decode them.
#[derive(Debug, Clone)]
pub struct QuantizedState {
    /// One code per source value, in the same order as the input slice.
    pub data: Vec<i8>,
    /// Mapping derived from the minimum and maximum of the values that were
    /// last encoded into `data`.
    pub config: QuantizationConfig,
}
41
42impl QuantizedState {
43 pub fn new(values: &[f32]) -> Self {
44 let min_val = values.iter().fold(f32::INFINITY, |a, &b| a.min(b));
45 let max_val = values.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
46
47 let config = QuantizationConfig::new(min_val, max_val);
48 let data: Vec<i8> = values.iter().map(|&v| config.quantize(v)).collect();
49
50 Self { data, config }
51 }
52
53 pub fn to_f32(&self) -> Vec<f32> {
54 self.data.iter().map(|&q| self.config.dequantize(q)).collect()
55 }
56
57 pub fn update(&mut self, new_values: &[f32]) {
58 let min_val = new_values.iter().fold(f32::INFINITY, |a, &b| a.min(b));
59 let max_val = new_values.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
60
61 self.config = QuantizationConfig::new(min_val, max_val);
62 self.data = new_values.iter().map(|&v| self.config.quantize(v)).collect();
63 }
64}
65
/// Adam optimizer whose per-parameter moment buffers are kept as 8-bit
/// quantized states between steps.
#[derive(Debug)]
pub struct Adam8bit {
    /// Multiplier applied to each parameter update.
    pub learning_rate: f32,
    /// Decay factor of the first-moment (momentum) running average.
    pub beta1: f32,
    /// Decay factor of the second-moment (variance) running average.
    pub beta2: f32,
    /// Constant added to the square root of the variance in the denominator.
    pub epsilon: f32,
    /// When greater than zero, `weight_decay * parameter` is added to the
    /// gradient before the moments are updated.
    pub weight_decay: f32,
    /// Number of `step` calls made so far; used for bias correction.
    pub step: usize,
    /// Quantized first-moment buffer per parameter name.
    pub momentum_states: HashMap<String, QuantizedState>,
    /// Quantized second-moment buffer per parameter name.
    pub variance_states: HashMap<String, QuantizedState>,
}
77
78impl Default for Adam8bit {
79 fn default() -> Self {
80 Self {
81 learning_rate: 1e-3,
82 beta1: 0.9,
83 beta2: 0.999,
84 epsilon: 1e-8,
85 weight_decay: 0.0,
86 step: 0,
87 momentum_states: HashMap::new(),
88 variance_states: HashMap::new(),
89 }
90 }
91}
92
93impl Adam8bit {
94 pub fn new(learning_rate: f32) -> Self {
95 Self {
96 learning_rate,
97 ..Default::default()
98 }
99 }
100
101 pub fn with_config(
102 learning_rate: f32,
103 beta1: f32,
104 beta2: f32,
105 epsilon: f32,
106 weight_decay: f32,
107 ) -> Self {
108 Self {
109 learning_rate,
110 beta1,
111 beta2,
112 epsilon,
113 weight_decay,
114 step: 0,
115 momentum_states: HashMap::new(),
116 variance_states: HashMap::new(),
117 }
118 }
119
120 pub fn step(
121 &mut self,
122 parameters: &mut HashMap<String, Tensor>,
123 gradients: &HashMap<String, Tensor>,
124 ) -> Result<()> {
125 self.step += 1;
126
127 let bias_correction1 = 1.0 - self.beta1.powi(self.step as i32);
128 let bias_correction2 = 1.0 - self.beta2.powi(self.step as i32);
129
130 for (name, param) in parameters.iter_mut() {
131 let grad = gradients
132 .get(name)
133 .ok_or_else(|| anyhow::anyhow!("Missing gradient for parameter: {}", name))?;
134
135 let mut param_data = param.data()?;
136 let grad_data = grad.data()?;
137
138 if param_data.len() != grad_data.len() {
139 return Err(anyhow::anyhow!(
140 "Parameter and gradient size mismatch for: {}",
141 name
142 ));
143 }
144
145 if !self.momentum_states.contains_key(name) {
146 let zeros = vec![0.0; param_data.len()];
147 self.momentum_states.insert(name.clone(), QuantizedState::new(&zeros));
148 self.variance_states.insert(name.clone(), QuantizedState::new(&zeros));
149 }
150
151 let momentum_state = self
152 .momentum_states
153 .get_mut(name)
154 .expect("momentum_state should exist after initialization");
155 let variance_state = self
156 .variance_states
157 .get_mut(name)
158 .expect("variance_state should exist after initialization");
159
160 let mut momentum = momentum_state.to_f32();
161 let mut variance = variance_state.to_f32();
162
163 for i in 0..param_data.len() {
164 let mut grad_val = grad_data[i];
165
166 if self.weight_decay > 0.0 {
167 grad_val += self.weight_decay * param_data[i];
168 }
169
170 momentum[i] = self.beta1 * momentum[i] + (1.0 - self.beta1) * grad_val;
171 variance[i] = self.beta2 * variance[i] + (1.0 - self.beta2) * grad_val * grad_val;
172
173 let corrected_momentum = momentum[i] / bias_correction1;
174 let corrected_variance = variance[i] / bias_correction2;
175
176 param_data[i] -= self.learning_rate * corrected_momentum
177 / (corrected_variance.sqrt() + self.epsilon);
178 }
179
180 momentum_state.update(&momentum);
181 variance_state.update(&variance);
182
183 *param = Tensor::new(param_data)?;
185 }
186
187 Ok(())
188 }
189
190 pub fn memory_usage(&self) -> usize {
191 let mut total = 0;
192 for state in self.momentum_states.values() {
193 total += state.data.len();
194 }
195 for state in self.variance_states.values() {
196 total += state.data.len();
197 }
198 total
199 }
200
201 pub fn memory_savings_vs_fp32(&self) -> f32 {
202 let quantized_size = self.memory_usage();
203 let fp32_equivalent = quantized_size * 4;
204 1.0 - (quantized_size as f32 / fp32_equivalent as f32)
205 }
206}
207
/// AdamW optimizer whose per-parameter moment buffers are kept as 8-bit
/// quantized states between steps.
#[derive(Debug)]
pub struct AdamW8bit {
    /// Multiplier applied to each parameter update and to the weight decay.
    pub learning_rate: f32,
    /// Decay factor of the first-moment (momentum) running average.
    pub beta1: f32,
    /// Decay factor of the second-moment (variance) running average.
    pub beta2: f32,
    /// Constant added to the square root of the variance in the denominator.
    pub epsilon: f32,
    /// Decay coefficient applied directly to the parameters: each step scales
    /// them by `1 - learning_rate * weight_decay`.
    pub weight_decay: f32,
    /// Number of `step` calls made so far; used for bias correction.
    pub step: usize,
    /// Quantized first-moment buffer per parameter name.
    pub momentum_states: HashMap<String, QuantizedState>,
    /// Quantized second-moment buffer per parameter name.
    pub variance_states: HashMap<String, QuantizedState>,
}
219
220impl Default for AdamW8bit {
221 fn default() -> Self {
222 Self {
223 learning_rate: 1e-3,
224 beta1: 0.9,
225 beta2: 0.999,
226 epsilon: 1e-8,
227 weight_decay: 1e-2,
228 step: 0,
229 momentum_states: HashMap::new(),
230 variance_states: HashMap::new(),
231 }
232 }
233}
234
235impl AdamW8bit {
236 pub fn new(learning_rate: f32) -> Self {
237 Self {
238 learning_rate,
239 ..Default::default()
240 }
241 }
242
243 pub fn with_config(
244 learning_rate: f32,
245 beta1: f32,
246 beta2: f32,
247 epsilon: f32,
248 weight_decay: f32,
249 ) -> Self {
250 Self {
251 learning_rate,
252 beta1,
253 beta2,
254 epsilon,
255 weight_decay,
256 step: 0,
257 momentum_states: HashMap::new(),
258 variance_states: HashMap::new(),
259 }
260 }
261
262 pub fn step(
263 &mut self,
264 parameters: &mut HashMap<String, Tensor>,
265 gradients: &HashMap<String, Tensor>,
266 ) -> Result<()> {
267 self.step += 1;
268
269 let bias_correction1 = 1.0 - self.beta1.powi(self.step as i32);
270 let bias_correction2 = 1.0 - self.beta2.powi(self.step as i32);
271
272 for (name, param) in parameters.iter_mut() {
273 let grad = gradients
274 .get(name)
275 .ok_or_else(|| anyhow::anyhow!("Missing gradient for parameter: {}", name))?;
276
277 let mut param_data = param.data()?;
278 let grad_data = grad.data()?;
279
280 if param_data.len() != grad_data.len() {
281 return Err(anyhow::anyhow!(
282 "Parameter and gradient size mismatch for: {}",
283 name
284 ));
285 }
286
287 if !self.momentum_states.contains_key(name) {
288 let zeros = vec![0.0; param_data.len()];
289 self.momentum_states.insert(name.clone(), QuantizedState::new(&zeros));
290 self.variance_states.insert(name.clone(), QuantizedState::new(&zeros));
291 }
292
293 let momentum_state = self
294 .momentum_states
295 .get_mut(name)
296 .expect("momentum_state should exist after initialization");
297 let variance_state = self
298 .variance_states
299 .get_mut(name)
300 .expect("variance_state should exist after initialization");
301
302 let mut momentum = momentum_state.to_f32();
303 let mut variance = variance_state.to_f32();
304
305 for i in 0..param_data.len() {
306 let grad_val = grad_data[i];
307
308 momentum[i] = self.beta1 * momentum[i] + (1.0 - self.beta1) * grad_val;
309 variance[i] = self.beta2 * variance[i] + (1.0 - self.beta2) * grad_val * grad_val;
310
311 let corrected_momentum = momentum[i] / bias_correction1;
312 let corrected_variance = variance[i] / bias_correction2;
313
314 let update = corrected_momentum / (corrected_variance.sqrt() + self.epsilon);
315
316 param_data[i] = param_data[i] * (1.0 - self.learning_rate * self.weight_decay)
317 - self.learning_rate * update;
318 }
319
320 momentum_state.update(&momentum);
321 variance_state.update(&variance);
322
323 *param = Tensor::new(param_data)?;
325 }
326
327 Ok(())
328 }
329
330 pub fn memory_usage(&self) -> usize {
331 let mut total = 0;
332 for state in self.momentum_states.values() {
333 total += state.data.len();
334 }
335 for state in self.variance_states.values() {
336 total += state.data.len();
337 }
338 total
339 }
340
341 pub fn memory_savings_vs_fp32(&self) -> f32 {
342 let quantized_size = self.memory_usage();
343 let fp32_equivalent = quantized_size * 4;
344 1.0 - (quantized_size as f32 / fp32_equivalent as f32)
345 }
346}
347
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Range endpoints map to the extreme codes, the midpoint lands near
    /// zero, and an interior value survives a round trip.
    #[test]
    fn test_quantization_config() {
        let cfg = QuantizationConfig::new(-1.0, 1.0);

        // Endpoints of the range.
        assert_eq!(cfg.quantize(-1.0), -128);
        assert_eq!(cfg.quantize(1.0), 127);

        // The midpoint may round to either side of zero.
        let midpoint_code = cfg.quantize(0.0);
        assert!((-1..=1).contains(&midpoint_code));

        // Round trip of an interior value.
        let original = 0.5;
        let code = cfg.quantize(original);
        let reconstructed = cfg.dequantize(code);
        assert_abs_diff_eq!(original, reconstructed, epsilon = 0.02);
    }

    /// Decoding a quantized buffer keeps length, ordering and approximate
    /// values.
    #[test]
    fn test_quantized_state() {
        let values = vec![0.1, -0.5, 0.8, -0.2];
        let decoded = QuantizedState::new(&values).to_f32();

        assert_eq!(values.len(), decoded.len());

        // Relative ordering of the inputs is preserved.
        assert!(decoded[2] > decoded[0]);
        assert!(decoded[1] < decoded[0]);

        for (expected, actual) in values.iter().zip(decoded.iter()) {
            assert_abs_diff_eq!(expected, actual, epsilon = 0.1);
        }
    }

    /// `Adam8bit::new` keeps the given learning rate and default betas.
    #[test]
    fn test_adam8bit_creation() {
        let adam = Adam8bit::new(0.001);
        assert_eq!(adam.learning_rate, 0.001);
        assert_eq!(adam.beta1, 0.9);
        assert_eq!(adam.beta2, 0.999);
        assert_eq!(adam.step, 0);
    }

    /// After one step over a four-element parameter, both moment buffers
    /// together occupy eight bytes.
    #[test]
    fn test_adam8bit_memory_usage() {
        let mut adam = Adam8bit::new(0.001);

        let mut parameters = HashMap::new();
        parameters.insert(
            "layer1".to_string(),
            Tensor::new(vec![1.0, 2.0, 3.0, 4.0]).expect("Failed to create tensor"),
        );

        let mut gradients = HashMap::new();
        gradients.insert(
            "layer1".to_string(),
            Tensor::new(vec![0.1, 0.2, 0.3, 0.4]).expect("Failed to create tensor"),
        );

        adam.step(&mut parameters, &gradients).expect("Step failed");

        assert_eq!(adam.memory_usage(), 8);
        assert!(adam.memory_savings_vs_fp32() > 0.7);
    }

    /// `AdamW8bit::new` keeps the given learning rate and the default decay.
    #[test]
    fn test_adamw8bit_creation() {
        let adamw = AdamW8bit::new(0.001);
        assert_eq!(adamw.learning_rate, 0.001);
        assert_eq!(adamw.weight_decay, 0.01);
    }
}