use anyhow::Result;
use std::collections::HashMap;
use trustformers_core::tensor::Tensor;

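/// Affine 8-bit quantization parameters mapping the value range
/// `[min_val, max_val]` onto the signed byte range `[-128, 127]`.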
#[derive(Debug, Clone)]
pub struct QuantizationConfig {
    pub scale: f32,
    pub zero_point: i8,
    pub min_val: f32,
    pub max_val: f32,
}

impl QuantizationConfig {
    pub fn new(min_val: f32, max_val: f32) -> Self {
        // Guard against a degenerate range (all values equal), which would
        // otherwise yield a zero scale and NaN quantized values.
        let range = (max_val - min_val).max(f32::EPSILON);
        let scale = range / 255.0;
        // Quantized representation of 0.0, consistent with `quantize` below.
        let zero_point = ((-min_val / scale).round() - 128.0).clamp(-128.0, 127.0) as i8;

        Self {
            scale,
            zero_point,
            min_val,
            max_val,
        }
    }

    pub fn quantize(&self, value: f32) -> i8 {
        let quantized = ((value - self.min_val) / self.scale).round() - 128.0;
        quantized.clamp(-128.0, 127.0) as i8
    }

    pub fn dequantize(&self, quantized: i8) -> f32 {
        (quantized as f32 + 128.0) * self.scale + self.min_val
    }
}

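/// A flat tensor stored as 8-bit quantized values together with the
/// quantization parameters needed to reconstruct approximate f32 values.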
#[derive(Debug, Clone)]
pub struct QuantizedState {
    pub data: Vec<i8>,
    pub config: QuantizationConfig,
}

impl QuantizedState {
    pub fn new(values: &[f32]) -> Self {
        let min_val = values.iter().fold(f32::INFINITY, |a, &b| a.min(b));
        let max_val = values.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));

        let config = QuantizationConfig::new(min_val, max_val);
        let data: Vec<i8> = values.iter().map(|&v| config.quantize(v)).collect();

        Self { data, config }
    }

    pub fn to_f32(&self) -> Vec<f32> {
        self.data.iter().map(|&q| self.config.dequantize(q)).collect()
    }

    pub fn update(&mut self, new_values: &[f32]) {
        // Re-derive the quantization range so the state tracks the current
        // value distribution rather than the one seen at construction.
        *self = QuantizedState::new(new_values);
    }
}

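/// Adam optimizer that keeps its first- and second-moment state quantized to
/// 8 bits per element, cutting optimizer-state memory roughly 4x versus f32.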
#[derive(Debug)]
pub struct Adam8bit {
    pub learning_rate: f32,
    pub beta1: f32,
    pub beta2: f32,
    pub epsilon: f32,
    pub weight_decay: f32,
    pub step: usize,
    pub momentum_states: HashMap<String, QuantizedState>,
    pub variance_states: HashMap<String, QuantizedState>,
}

impl Default for Adam8bit {
    fn default() -> Self {
        Self {
            learning_rate: 1e-3,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
            weight_decay: 0.0,
            step: 0,
            momentum_states: HashMap::new(),
            variance_states: HashMap::new(),
        }
    }
}

impl Adam8bit {
    pub fn new(learning_rate: f32) -> Self {
        Self {
            learning_rate,
            ..Default::default()
        }
    }

    pub fn with_config(
        learning_rate: f32,
        beta1: f32,
        beta2: f32,
        epsilon: f32,
        weight_decay: f32,
    ) -> Self {
        Self {
            learning_rate,
            beta1,
            beta2,
            epsilon,
            weight_decay,
            step: 0,
            momentum_states: HashMap::new(),
            variance_states: HashMap::new(),
        }
    }

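    /// Applies one Adam update to every parameter: the quantized moment
    /// states are dequantized to f32, updated with the bias-corrected Adam
    /// rule, used to adjust the parameters, and then requantized.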
    pub fn step(
        &mut self,
        parameters: &mut HashMap<String, Tensor>,
        gradients: &HashMap<String, Tensor>,
    ) -> Result<()> {
        self.step += 1;

        let bias_correction1 = 1.0 - self.beta1.powi(self.step as i32);
        let bias_correction2 = 1.0 - self.beta2.powi(self.step as i32);

        for (name, param) in parameters.iter_mut() {
            let grad = gradients
                .get(name)
                .ok_or_else(|| anyhow::anyhow!("Missing gradient for parameter: {}", name))?;

            let mut param_data = param.data()?;
            let grad_data = grad.data()?;

            if param_data.len() != grad_data.len() {
                return Err(anyhow::anyhow!(
                    "Parameter and gradient size mismatch for: {}",
                    name
                ));
            }

            // Lazily initialize the quantized moment buffers on first use.
            if !self.momentum_states.contains_key(name) {
                let zeros = vec![0.0; param_data.len()];
                self.momentum_states.insert(name.clone(), QuantizedState::new(&zeros));
                self.variance_states.insert(name.clone(), QuantizedState::new(&zeros));
            }

            let momentum_state = self.momentum_states.get_mut(name).unwrap();
            let variance_state = self.variance_states.get_mut(name).unwrap();

            let mut momentum = momentum_state.to_f32();
            let mut variance = variance_state.to_f32();

            for i in 0..param_data.len() {
                let mut grad_val = grad_data[i];

                // Classic (L2-style) weight decay: fold the penalty into the gradient.
                if self.weight_decay > 0.0 {
                    grad_val += self.weight_decay * param_data[i];
                }

                // Exponential moving averages of the gradient and its square.
                momentum[i] = self.beta1 * momentum[i] + (1.0 - self.beta1) * grad_val;
                variance[i] = self.beta2 * variance[i] + (1.0 - self.beta2) * grad_val * grad_val;

                let corrected_momentum = momentum[i] / bias_correction1;
                let corrected_variance = variance[i] / bias_correction2;

                param_data[i] -= self.learning_rate * corrected_momentum
                    / (corrected_variance.sqrt() + self.epsilon);
            }

            // Requantize the updated moments before storing them back.
            momentum_state.update(&momentum);
            variance_state.update(&variance);

            *param = Tensor::new(param_data)?;
        }

        Ok(())
    }

    /// Total bytes used by the quantized moment buffers (one byte per
    /// element, excluding the small per-tensor `QuantizationConfig`).
    pub fn memory_usage(&self) -> usize {
        let mut total = 0;
        for state in self.momentum_states.values() {
            total += state.data.len();
        }
        for state in self.variance_states.values() {
            total += state.data.len();
        }
        total
    }

    /// Fraction of optimizer-state memory saved relative to keeping the same
    /// moments in f32 (4 bytes per element), i.e. 0.75 for 8-bit storage.
    pub fn memory_savings_vs_fp32(&self) -> f32 {
        let quantized_size = self.memory_usage();
        let fp32_equivalent = quantized_size * 4;
        1.0 - (quantized_size as f32 / fp32_equivalent as f32)
    }
}

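/// AdamW optimizer with decoupled weight decay and 8-bit quantized moment
/// state, mirroring `Adam8bit` but applying weight decay directly to the
/// parameters instead of folding it into the gradient.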
#[derive(Debug)]
pub struct AdamW8bit {
    pub learning_rate: f32,
    pub beta1: f32,
    pub beta2: f32,
    pub epsilon: f32,
    pub weight_decay: f32,
    pub step: usize,
    pub momentum_states: HashMap<String, QuantizedState>,
    pub variance_states: HashMap<String, QuantizedState>,
}

impl Default for AdamW8bit {
    fn default() -> Self {
        Self {
            learning_rate: 1e-3,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
            weight_decay: 1e-2,
            step: 0,
            momentum_states: HashMap::new(),
            variance_states: HashMap::new(),
        }
    }
}

impl AdamW8bit {
    pub fn new(learning_rate: f32) -> Self {
        Self {
            learning_rate,
            ..Default::default()
        }
    }

    pub fn with_config(
        learning_rate: f32,
        beta1: f32,
        beta2: f32,
        epsilon: f32,
        weight_decay: f32,
    ) -> Self {
        Self {
            learning_rate,
            beta1,
            beta2,
            epsilon,
            weight_decay,
            step: 0,
            momentum_states: HashMap::new(),
            variance_states: HashMap::new(),
        }
    }

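    /// Applies one AdamW update: the bias-corrected Adam step for the
    /// gradient, plus decoupled weight decay that scales each parameter by
    /// `1 - learning_rate * weight_decay` before the gradient step.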
    pub fn step(
        &mut self,
        parameters: &mut HashMap<String, Tensor>,
        gradients: &HashMap<String, Tensor>,
    ) -> Result<()> {
        self.step += 1;

        let bias_correction1 = 1.0 - self.beta1.powi(self.step as i32);
        let bias_correction2 = 1.0 - self.beta2.powi(self.step as i32);

        for (name, param) in parameters.iter_mut() {
            let grad = gradients
                .get(name)
                .ok_or_else(|| anyhow::anyhow!("Missing gradient for parameter: {}", name))?;

            let mut param_data = param.data()?;
            let grad_data = grad.data()?;

            if param_data.len() != grad_data.len() {
                return Err(anyhow::anyhow!(
                    "Parameter and gradient size mismatch for: {}",
                    name
                ));
            }

            // Lazily initialize the quantized moment buffers on first use.
            if !self.momentum_states.contains_key(name) {
                let zeros = vec![0.0; param_data.len()];
                self.momentum_states.insert(name.clone(), QuantizedState::new(&zeros));
                self.variance_states.insert(name.clone(), QuantizedState::new(&zeros));
            }

            let momentum_state = self.momentum_states.get_mut(name).unwrap();
            let variance_state = self.variance_states.get_mut(name).unwrap();

            let mut momentum = momentum_state.to_f32();
            let mut variance = variance_state.to_f32();

            for i in 0..param_data.len() {
                let grad_val = grad_data[i];

                // Exponential moving averages of the gradient and its square.
                momentum[i] = self.beta1 * momentum[i] + (1.0 - self.beta1) * grad_val;
                variance[i] = self.beta2 * variance[i] + (1.0 - self.beta2) * grad_val * grad_val;

                let corrected_momentum = momentum[i] / bias_correction1;
                let corrected_variance = variance[i] / bias_correction2;

                let update = corrected_momentum / (corrected_variance.sqrt() + self.epsilon);

                // Decoupled weight decay, applied to the parameter rather than the gradient.
                param_data[i] = param_data[i] * (1.0 - self.learning_rate * self.weight_decay)
                    - self.learning_rate * update;
            }

            // Requantize the updated moments before storing them back.
            momentum_state.update(&momentum);
            variance_state.update(&variance);

            *param = Tensor::new(param_data)?;
        }

        Ok(())
    }

    /// Total bytes used by the quantized moment buffers (one byte per
    /// element, excluding the small per-tensor `QuantizationConfig`).
    pub fn memory_usage(&self) -> usize {
        let mut total = 0;
        for state in self.momentum_states.values() {
            total += state.data.len();
        }
        for state in self.variance_states.values() {
            total += state.data.len();
        }
        total
    }

    /// Fraction of optimizer-state memory saved relative to keeping the same
    /// moments in f32 (4 bytes per element), i.e. 0.75 for 8-bit storage.
    pub fn memory_savings_vs_fp32(&self) -> f32 {
        let quantized_size = self.memory_usage();
        let fp32_equivalent = quantized_size * 4;
        1.0 - (quantized_size as f32 / fp32_equivalent as f32)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_quantization_config() {
        let config = QuantizationConfig::new(-1.0, 1.0);

        // The range endpoints map to the extremes of the i8 range.
        assert_eq!(config.quantize(-1.0), -128);
        assert_eq!(config.quantize(1.0), 127);

        // Zero lands near the middle of the quantized range.
        let mid_quantized = config.quantize(0.0);
        assert!(mid_quantized >= -1 && mid_quantized <= 1);

        // Round-tripping reconstructs the value within quantization error.
        let original = 0.5;
        let quantized = config.quantize(original);
        let reconstructed = config.dequantize(quantized);
        assert_abs_diff_eq!(original, reconstructed, epsilon = 0.02);
    }

    #[test]
    fn test_quantized_state() {
        let values = vec![0.1, -0.5, 0.8, -0.2];
        let state = QuantizedState::new(&values);

        let reconstructed = state.to_f32();

        assert_eq!(values.len(), reconstructed.len());

        // Relative ordering survives quantization.
        assert!(reconstructed[2] > reconstructed[0]);
        assert!(reconstructed[1] < reconstructed[0]);

        // Each value is reconstructed within a coarse tolerance.
        for (orig, recon) in values.iter().zip(reconstructed.iter()) {
            assert_abs_diff_eq!(orig, recon, epsilon = 0.1);
        }
    }

    #[test]
    fn test_adam8bit_creation() {
        let optimizer = Adam8bit::new(0.001);
        assert_eq!(optimizer.learning_rate, 0.001);
        assert_eq!(optimizer.beta1, 0.9);
        assert_eq!(optimizer.beta2, 0.999);
        assert_eq!(optimizer.step, 0);
    }

    #[test]
    fn test_adam8bit_memory_usage() {
        let mut optimizer = Adam8bit::new(0.001);

        let mut parameters = HashMap::new();
        let mut gradients = HashMap::new();

        let param_data = vec![1.0, 2.0, 3.0, 4.0];
        let grad_data = vec![0.1, 0.2, 0.3, 0.4];

        parameters.insert("layer1".to_string(), Tensor::new(param_data).unwrap());
        gradients.insert("layer1".to_string(), Tensor::new(grad_data).unwrap());

        optimizer.step(&mut parameters, &gradients).unwrap();

        // 4 elements x 2 moment buffers x 1 byte each.
        assert_eq!(optimizer.memory_usage(), 8);
        assert!(optimizer.memory_savings_vs_fp32() > 0.7);
    }

    #[test]
    fn test_adamw8bit_creation() {
        let optimizer = AdamW8bit::new(0.001);
        assert_eq!(optimizer.learning_rate, 0.001);
        assert_eq!(optimizer.weight_decay, 0.01);
    }
}