1use crate::common::{OptimizerState, StateMemoryStats};
14use crate::traits::StatefulOptimizer;
15use serde::{Deserialize, Serialize};
16use std::collections::HashMap;
17use trustformers_core::errors::{Result, TrustformersError};
18use trustformers_core::tensor::Tensor;
19use trustformers_core::traits::Optimizer;
20
/// Lookup table of the 16 NF4 (4-bit NormalFloat) code values, spanning
/// [-1.0, 1.0]. Index 7 is exactly 0.0, so zero quantizes losslessly, and
/// the spacing is denser near zero than at the extremes.
/// NOTE(review): values appear to follow the QLoRA/bitsandbytes NF4 table —
/// verify against the reference implementation if bit-exactness matters.
const NF4_VALUES: [f32; 16] = [
    -1.0,
    -0.696_192_8,
    -0.525_073_05,
    -0.394_917_5,
    -0.284_441_38,
    -0.184_773_43,
    -0.091_050_036,
    0.0,
    0.079_580_3,
    0.160_930_2,
    0.246_112_3,
    0.337_915_24,
    0.440_709_83,
    0.562_617,
    0.722_956_84,
    1.0,
];
40
/// Configuration for the quantization scheme used to compress optimizer state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdvancedQuantizationConfig {
    /// Which quantization method to apply (see `QuantizationMethod`).
    pub method: QuantizationMethod,
    /// Number of elements that share one scale / zero-point pair.
    pub block_size: usize,
    /// EMA coefficient used when blending freshly computed gradient
    /// statistics into previously recorded ones.
    pub adaptation_rate: f32,
    /// Lower clamp for quantization scales.
    /// NOTE(review): not enforced anywhere in this file — confirm intended use.
    pub min_scale: f32,
    /// Upper clamp for quantization scales (same caveat as `min_scale`).
    pub max_scale: f32,
    /// Whether to additionally quantize the scale factors themselves.
    /// NOTE(review): stored but never acted on in this file.
    pub double_quantization: bool,
}
57
58impl Default for AdvancedQuantizationConfig {
59 fn default() -> Self {
60 Self {
61 method: QuantizationMethod::NF4,
62 block_size: 64,
63 adaptation_rate: 0.01,
64 min_scale: 1e-8,
65 max_scale: 1e8,
66 double_quantization: true,
67 }
68 }
69}
70
/// Supported quantization strategies for optimizer state tensors.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum QuantizationMethod {
    /// Plain 4-bit integer quantization.
    Int4,
    /// 4-bit NormalFloat quantization using the `NF4_VALUES` codebook.
    NF4,
    /// 8-bit integer quantization.
    Int8,
    /// Method selected dynamically at runtime.
    /// NOTE(review): no dispatch logic for this variant is visible in this file.
    Dynamic,
    /// Block-wise quantization with per-block parameters.
    BlockWise,
}
85
/// A tensor stored in quantized form together with the per-block metadata
/// needed to reconstruct (dequantize) it.
///
/// NOTE(review): `data` holds one `f32` per element (the codebook value),
/// so the in-memory footprint is NOT actually reduced; a packed 4-bit
/// representation would be required for real savings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizedTensor {
    /// Quantized element values (for NF4: entries of `NF4_VALUES`), one per element.
    pub data: Vec<f32>,
    /// Per-block scale factors.
    pub scales: Vec<f32>,
    /// Per-block zero points.
    pub zero_points: Vec<f32>,
    /// Logical shape of the original tensor.
    pub shape: Vec<usize>,
    /// Quantization method that produced `data`.
    pub method: QuantizationMethod,
    /// Number of elements covered by each scale / zero-point pair.
    pub block_size: usize,
}
102
103impl QuantizedTensor {
104 pub fn new(
106 data: Vec<f32>,
107 scales: Vec<f32>,
108 zero_points: Vec<f32>,
109 shape: Vec<usize>,
110 method: QuantizationMethod,
111 block_size: usize,
112 ) -> Self {
113 Self {
114 data,
115 scales,
116 zero_points,
117 shape,
118 method,
119 block_size,
120 }
121 }
122
123 pub fn memory_usage(&self) -> usize {
125 self.data.len() * 4 + self.scales.len() * 4 + self.zero_points.len() * 4
127 }
128
129 pub fn compression_ratio(&self) -> f32 {
131 let original_size = self.shape.iter().product::<usize>() * 4; match self.method {
135 QuantizationMethod::NF4 | QuantizationMethod::Int4 => 8.0, QuantizationMethod::Int8 => 4.0, _ => {
138 let compressed_size = self.memory_usage();
139 if compressed_size > 0 {
140 original_size as f32 / compressed_size as f32
141 } else {
142 1.0
143 }
144 },
145 }
146 }
147}
148
149pub struct QuantizationUtils;
151
152impl QuantizationUtils {
153 pub fn quantize_nf4(tensor: &Tensor, block_size: usize) -> Result<QuantizedTensor> {
155 let data = tensor.data()?;
156 let shape = tensor.shape();
157 let num_elements = data.len();
158 let num_blocks = num_elements.div_ceil(block_size);
159
160 let mut quantized_data = Vec::new();
161 let mut scales = Vec::with_capacity(num_blocks);
162 let mut zero_points = Vec::with_capacity(num_blocks);
163
164 for block_idx in 0..num_blocks {
165 let start = block_idx * block_size;
166 let end = (start + block_size).min(num_elements);
167 let block = &data[start..end];
168
169 let min_val = block.iter().fold(f32::INFINITY, |a, &b| a.min(b));
171 let max_val = block.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
172
173 let scale = (max_val - min_val) / 15.0; let zero_point = -min_val / scale;
175
176 scales.push(scale);
177 zero_points.push(zero_point);
178
179 for &value in block {
181 let normalized = (value - min_val) / scale;
182 let quantized = Self::find_closest_nf4(normalized / 15.0);
183 quantized_data.push(quantized);
184 }
185 }
186
187 Ok(QuantizedTensor::new(
188 quantized_data,
189 scales,
190 zero_points,
191 shape,
192 QuantizationMethod::NF4,
193 block_size,
194 ))
195 }
196
197 fn find_closest_nf4(value: f32) -> f32 {
199 let clamped = value.clamp(-1.0, 1.0);
200 let mut best_val = NF4_VALUES[0];
201 let mut best_diff = (NF4_VALUES[0] - clamped).abs();
202
203 for &nf4_val in NF4_VALUES.iter() {
204 let diff = (nf4_val - clamped).abs();
205 if diff < best_diff {
206 best_diff = diff;
207 best_val = nf4_val;
208 }
209 }
210
211 best_val
212 }
213
214 pub fn dequantize_nf4(quantized: &QuantizedTensor) -> Result<Tensor> {
216 let num_elements: usize = quantized.shape.iter().product();
217 let mut data = Vec::with_capacity(num_elements);
218 let block_size = quantized.block_size;
219 let num_blocks = num_elements.div_ceil(block_size);
220
221 let mut data_idx = 0;
222
223 for block_idx in 0..num_blocks {
224 let start = block_idx * block_size;
225 let end = (start + block_size).min(num_elements);
226 let block_len = end - start;
227
228 let scale = quantized.scales[block_idx];
229 let zero_point = quantized.zero_points[block_idx];
230
231 for _ in 0..block_len {
232 if data_idx < quantized.data.len() {
233 let nf4_val = quantized.data[data_idx];
234 let dequantized = (nf4_val * 15.0 + zero_point) * scale;
235 data.push(dequantized);
236 data_idx += 1;
237 }
238 }
239 }
240
241 Tensor::new(data)
242 }
243}
244
/// Summary statistics of a gradient vector, tracked per parameter.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GradientStatistics {
    /// Arithmetic mean of the elements.
    pub mean: f32,
    /// Population variance (divides by n, not n - 1).
    pub variance: f32,
    /// Standardized third moment; defined as 0.0 for near-constant data.
    pub skewness: f32,
    /// Excess kurtosis (fourth standardized moment minus 3); 0.0 for
    /// near-constant data.
    pub kurtosis: f32,
    /// Euclidean (L2) norm of the elements.
    pub l2_norm: f32,
}
254
255impl GradientStatistics {
256 pub fn compute(data: &[f32]) -> Self {
258 let n = data.len() as f32;
259 let mean = data.iter().sum::<f32>() / n;
260
261 let variance = data.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / n;
262
263 let std_dev = variance.sqrt();
264
265 let skewness = if std_dev > 1e-8 {
266 data.iter().map(|x| ((x - mean) / std_dev).powi(3)).sum::<f32>() / n
267 } else {
268 0.0
269 };
270
271 let kurtosis = if std_dev > 1e-8 {
272 data.iter().map(|x| ((x - mean) / std_dev).powi(4)).sum::<f32>() / n - 3.0
273 } else {
275 0.0
276 };
277
278 let l2_norm = data.iter().map(|x| x * x).sum::<f32>().sqrt();
279
280 Self {
281 mean,
282 variance,
283 skewness,
284 kurtosis,
285 l2_norm,
286 }
287 }
288}
289
/// Adam hyperparameters for the 4-bit optimizer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Adam4bitOptimizerConfig {
    /// Step size (alpha).
    pub learning_rate: f32,
    /// Exponential decay rate for the first-moment (momentum) estimate.
    pub beta1: f32,
    /// Exponential decay rate for the second-moment (variance) estimate.
    pub beta2: f32,
    /// Small constant added to the update denominator for numerical stability.
    pub epsilon: f32,
    /// L2 weight-decay coefficient, applied by folding `weight_decay * param`
    /// into the gradient; 0.0 disables it.
    pub weight_decay: f32,
}
299
300impl Default for Adam4bitOptimizerConfig {
301 fn default() -> Self {
302 Self {
303 learning_rate: 1e-3,
304 beta1: 0.9,
305 beta2: 0.999,
306 epsilon: 1e-8,
307 weight_decay: 0.0,
308 }
309 }
310}
311
/// Adam optimizer that keeps its momentum / variance state NF4-quantized
/// between steps to reduce optimizer memory.
#[derive(Debug)]
pub struct Adam4bit {
    // Quantization settings (method, block size, adaptation rate, ...).
    config: AdvancedQuantizationConfig,
    // Adam hyperparameters.
    optimizer_config: Adam4bitOptimizerConfig,
    // Shared optimizer bookkeeping (step counter etc.).
    state: OptimizerState,
    // First-moment (momentum) estimates, keyed by a pointer-derived
    // parameter identity string (see `Optimizer::update`).
    momentum_quantized: HashMap<String, QuantizedTensor>,
    // Second-moment (variance) estimates, same keying as `momentum_quantized`.
    variance_quantized: HashMap<String, QuantizedTensor>,
    // EMA-blended gradient statistics per parameter.
    gradient_stats: HashMap<String, GradientStatistics>,
}
324
325impl Adam4bit {
326 pub fn new(
328 learning_rate: f32,
329 beta1: f32,
330 beta2: f32,
331 epsilon: f32,
332 weight_decay: f32,
333 ) -> Self {
334 let optimizer_config = Adam4bitOptimizerConfig {
335 learning_rate,
336 beta1,
337 beta2,
338 epsilon,
339 weight_decay,
340 };
341
342 Self {
343 config: AdvancedQuantizationConfig::default(),
344 optimizer_config,
345 state: OptimizerState::new(),
346 momentum_quantized: HashMap::new(),
347 variance_quantized: HashMap::new(),
348 gradient_stats: HashMap::new(),
349 }
350 }
351
352 pub fn with_quantization_config(
354 optimizer_config: Adam4bitOptimizerConfig,
355 quantization_config: AdvancedQuantizationConfig,
356 ) -> Self {
357 Self {
358 config: quantization_config,
359 optimizer_config,
360 state: OptimizerState::new(),
361 momentum_quantized: HashMap::new(),
362 variance_quantized: HashMap::new(),
363 gradient_stats: HashMap::new(),
364 }
365 }
366
367 pub fn memory_savings(&self) -> f32 {
369 0.75
371 }
372
373 fn update_gradient_stats(&mut self, param_id: &str, gradient_data: &[f32]) {
375 let stats = GradientStatistics::compute(gradient_data);
376
377 if let Some(existing_stats) = self.gradient_stats.get_mut(param_id) {
379 let alpha = self.config.adaptation_rate;
380 existing_stats.mean = (1.0 - alpha) * existing_stats.mean + alpha * stats.mean;
381 existing_stats.variance =
382 (1.0 - alpha) * existing_stats.variance + alpha * stats.variance;
383 } else {
384 self.gradient_stats.insert(param_id.to_string(), stats);
385 }
386 }
387}
388
389impl Optimizer for Adam4bit {
390 fn update(&mut self, parameter: &mut Tensor, grad: &Tensor) -> Result<()> {
391 match (parameter, grad) {
392 (Tensor::F32(param), Tensor::F32(grad_arr)) => {
393 let param_id = format!("{:p}", param.as_ptr());
394 let size = grad_arr.len();
395
396 self.update_gradient_stats(
398 ¶m_id,
399 &grad_arr.iter().cloned().collect::<Vec<f32>>(),
400 );
401
402 if !self.momentum_quantized.contains_key(¶m_id) {
404 let zeros = vec![0.0; size];
405 let zero_tensor = Tensor::new(zeros)?;
406 let momentum_q =
407 QuantizationUtils::quantize_nf4(&zero_tensor, self.config.block_size)?;
408 let variance_q =
409 QuantizationUtils::quantize_nf4(&zero_tensor, self.config.block_size)?;
410
411 self.momentum_quantized.insert(param_id.clone(), momentum_q);
412 self.variance_quantized.insert(param_id.clone(), variance_q);
413 }
414
415 let momentum_q = self.momentum_quantized.get(¶m_id).unwrap();
417 let variance_q = self.variance_quantized.get(¶m_id).unwrap();
418
419 let momentum_tensor = QuantizationUtils::dequantize_nf4(momentum_q)?;
421 let variance_tensor = QuantizationUtils::dequantize_nf4(variance_q)?;
422
423 let momentum_data = momentum_tensor.data()?;
424 let variance_data = variance_tensor.data()?;
425
426 let mut new_momentum = Vec::with_capacity(size);
427 let mut new_variance = Vec::with_capacity(size);
428
429 let step = (self.state.step + 1) as f32;
430 let bias_correction1 = 1.0 - self.optimizer_config.beta1.powf(step);
431 let bias_correction2 = 1.0 - self.optimizer_config.beta2.powf(step);
432
433 for i in 0..size {
435 let mut g = grad_arr[i];
436
437 if self.optimizer_config.weight_decay > 0.0 {
439 g += self.optimizer_config.weight_decay * param[i];
440 }
441
442 let m = self.optimizer_config.beta1 * momentum_data[i]
444 + (1.0 - self.optimizer_config.beta1) * g;
445 let v = self.optimizer_config.beta2 * variance_data[i]
446 + (1.0 - self.optimizer_config.beta2) * g * g;
447
448 new_momentum.push(m);
449 new_variance.push(v);
450
451 let m_hat = m / bias_correction1;
453 let v_hat = v / bias_correction2;
454
455 param[i] -= self.optimizer_config.learning_rate * m_hat
457 / (v_hat.sqrt() + self.optimizer_config.epsilon);
458 }
459
460 let new_momentum_tensor = Tensor::new(new_momentum)?;
462 let new_variance_tensor = Tensor::new(new_variance)?;
463
464 let momentum_q_new =
465 QuantizationUtils::quantize_nf4(&new_momentum_tensor, self.config.block_size)?;
466 let variance_q_new =
467 QuantizationUtils::quantize_nf4(&new_variance_tensor, self.config.block_size)?;
468
469 self.momentum_quantized.insert(param_id.clone(), momentum_q_new);
470 self.variance_quantized.insert(param_id, variance_q_new);
471
472 Ok(())
473 },
474 _ => Err(TrustformersError::tensor_op_error(
475 "Unsupported tensor types for Adam4bit",
476 "Adam4bit::update",
477 )),
478 }
479 }
480
481 fn zero_grad(&mut self) {
482 }
484
485 fn step(&mut self) {
486 self.state.step();
487 }
488
489 fn get_lr(&self) -> f32 {
490 self.optimizer_config.learning_rate
491 }
492
493 fn set_lr(&mut self, lr: f32) {
494 self.optimizer_config.learning_rate = lr;
495 }
496}
497
498impl StatefulOptimizer for Adam4bit {
499 type Config = Adam4bitOptimizerConfig;
500 type State = OptimizerState;
501
502 fn config(&self) -> &Self::Config {
503 &self.optimizer_config
504 }
505
506 fn state(&self) -> &Self::State {
507 &self.state
508 }
509
510 fn state_mut(&mut self) -> &mut Self::State {
511 &mut self.state
512 }
513
514 fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
515 let mut state_dict = HashMap::new();
516
517 state_dict.insert(
519 "learning_rate".to_string(),
520 Tensor::new(vec![self.optimizer_config.learning_rate])?,
521 );
522 state_dict.insert(
523 "beta1".to_string(),
524 Tensor::new(vec![self.optimizer_config.beta1])?,
525 );
526 state_dict.insert(
527 "beta2".to_string(),
528 Tensor::new(vec![self.optimizer_config.beta2])?,
529 );
530 state_dict.insert(
531 "epsilon".to_string(),
532 Tensor::new(vec![self.optimizer_config.epsilon])?,
533 );
534 state_dict.insert(
535 "weight_decay".to_string(),
536 Tensor::new(vec![self.optimizer_config.weight_decay])?,
537 );
538 state_dict.insert(
539 "step".to_string(),
540 Tensor::new(vec![self.state.step as f32])?,
541 );
542
543 for (param_id, momentum_q) in &self.momentum_quantized {
545 state_dict.insert(
546 format!("momentum_q_{}", param_id),
547 Tensor::new(momentum_q.data.clone())?,
548 );
549 }
550
551 for (param_id, variance_q) in &self.variance_quantized {
552 state_dict.insert(
553 format!("variance_q_{}", param_id),
554 Tensor::new(variance_q.data.clone())?,
555 );
556 }
557
558 Ok(state_dict)
559 }
560
561 fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
562 if let Some(lr_tensor) = state.get("learning_rate") {
564 if let Ok(lr_vec) = lr_tensor.data() {
565 if !lr_vec.is_empty() {
566 self.optimizer_config.learning_rate = lr_vec[0];
567 }
568 }
569 }
570 Ok(())
574 }
575
576 fn memory_usage(&self) -> StateMemoryStats {
577 let total_memory =
578 self.momentum_quantized.values().map(|q| q.memory_usage()).sum::<usize>()
579 + self.variance_quantized.values().map(|q| q.memory_usage()).sum::<usize>();
580
581 StateMemoryStats {
582 momentum_elements: self.momentum_quantized.values().map(|q| q.data.len()).sum(),
583 variance_elements: self.variance_quantized.values().map(|q| q.data.len()).sum(),
584 third_moment_elements: 0,
585 total_bytes: total_memory,
586 num_parameters: self.momentum_quantized.len(),
587 }
588 }
589
590 fn reset_state(&mut self) {
591 self.state.clear();
592 self.momentum_quantized.clear();
593 self.variance_quantized.clear();
594 self.gradient_stats.clear();
595 }
596
597 fn num_parameters(&self) -> usize {
598 self.momentum_quantized.values().map(|q| q.data.len()).sum()
599 }
600}
601
#[cfg(test)]
mod tests {
    use super::*;

    // Quantizing a small mixed-sign vector should tag the result as NF4 and
    // report at least the theoretical compression ratio.
    #[test]
    fn test_nf4_quantization() {
        let data = vec![1.0, -0.5, 0.0, 0.8, -1.2];
        let tensor = Tensor::new(data.clone()).unwrap();

        let quantized = QuantizationUtils::quantize_nf4(&tensor, 64).unwrap();
        assert_eq!(quantized.method, QuantizationMethod::NF4);
        assert!(quantized.compression_ratio() >= 1.0);
    }

    // Sanity-check the moment computations on 1..=5 (mean 3, positive
    // variance and norm).
    #[test]
    fn test_gradient_statistics() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let stats = GradientStatistics::compute(&data);

        assert!((stats.mean - 3.0).abs() < 1e-6);
        assert!(stats.variance > 0.0);
        assert!(stats.l2_norm > 0.0);
    }

    // Construction stores the learning rate and reports the fixed 0.75
    // memory-savings estimate.
    #[test]
    fn test_adam4bit_creation() {
        let optimizer = Adam4bit::new(0.001, 0.9, 0.999, 1e-8, 0.01);
        assert_eq!(optimizer.get_lr(), 0.001);
        assert!(optimizer.memory_savings() > 0.5);
    }

    // memory_usage counts payload plus per-block metadata; NF4 always
    // reports the theoretical 8x ratio.
    #[test]
    fn test_quantized_tensor_memory() {
        let quantized = QuantizedTensor::new(
            vec![0.0, 1.0, 2.0, 3.0],
            vec![1.0],
            vec![0.0],
            vec![4],
            QuantizationMethod::NF4,
            64,
        );

        assert!(quantized.memory_usage() > 0);
        assert!(quantized.compression_ratio() >= 1.0);
    }
}