1#![allow(unused_variables)] use crate::errors::{Result, TrustformersError};
4use crate::tensor::Tensor;
5use serde::{Deserialize, Serialize};
6
/// Supported weight/activation quantization schemes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationScheme {
    /// 8-bit integer quantization, one byte per value.
    Int8,
    /// 4-bit integer quantization, two values packed per byte.
    Int4,
    /// Dynamic quantization (parameters computed at quantize time).
    Dynamic,
    /// Dynamic quantization with an INT8 payload.
    DynamicINT8,
    /// GPTQ-style quantization (currently backed by grouped INT4 — see `Quantizer::quantize`).
    GPTQ,
    /// AWQ-style quantization (currently backed by grouped INT4 — see `Quantizer::quantize`).
    AWQ,
    /// bitsandbytes 8-bit quantization.
    BnB8bit,
    /// bitsandbytes 4-bit NF4 quantization.
    BnB4bit,
    /// bitsandbytes 4-bit FP4 quantization.
    BnB4bitFP4,
}
29
/// Configuration controlling how a tensor is quantized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationConfig {
    /// Which quantization scheme to apply.
    pub scheme: QuantizationScheme,
    /// Symmetric (zero-centered) vs. asymmetric (affine) quantization.
    pub symmetric: bool,
    /// Quantize each channel/group with its own scale instead of one global scale.
    pub per_channel: bool,
    /// Cap on the number of samples consumed by `Quantizer::calibrate`, if set.
    pub calibration_samples: Option<usize>,
    /// Group size for grouped INT4 schemes; `Quantizer` defaults this to 128 when unset.
    pub group_size: Option<usize>,
    /// Extra bitsandbytes options; only consulted by the BnB* schemes.
    pub bnb_config: Option<BnBConfig>,
}
40
/// bitsandbytes-style quantization options.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BnBConfig {
    /// Whether to double-quantize the per-block scales.
    /// NOTE(review): not consulted anywhere in this module yet — confirm intended.
    pub use_double_quant: bool,
    /// 4-/8-bit code type (NF4, FP4, or Int8).
    pub quant_type: BnBQuantType,
    /// Compute dtype used alongside the quantized weights.
    /// NOTE(review): not consulted anywhere in this module yet — confirm intended.
    pub compute_dtype: BnBComputeType,
    /// Storage dtype for 4-bit payloads, if specified.
    pub bnb_4bit_quant_storage: Option<BnBStorageType>,
}
49
/// bitsandbytes code type: NF4 (normal-float 4-bit), FP4 (float 4-bit), or plain Int8.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BnBQuantType {
    NF4,
    FP4,
    Int8,
}
56
/// Floating-point dtype used for computation alongside BnB-quantized weights.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BnBComputeType {
    Float16,
    BFloat16,
    Float32,
}
63
/// Storage dtype for BnB 4-bit payloads.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BnBStorageType {
    UInt8,
    Int8,
    Float16,
}
70
impl Default for QuantizationConfig {
    /// Per-tensor symmetric INT8 with 128 calibration samples and group size 128.
    fn default() -> Self {
        Self {
            scheme: QuantizationScheme::Int8,
            symmetric: true,
            per_channel: false,
            calibration_samples: Some(128),
            group_size: Some(128),
            bnb_config: None,
        }
    }
}
83
impl Default for BnBConfig {
    /// NF4 codes, f16 compute, u8 storage, no double quantization.
    fn default() -> Self {
        Self {
            use_double_quant: false,
            quant_type: BnBQuantType::NF4,
            compute_dtype: BnBComputeType::Float16,
            bnb_4bit_quant_storage: Some(BnBStorageType::UInt8),
        }
    }
}
94
/// A quantized tensor: raw bytes plus the metadata needed to decode them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizedTensor {
    /// Raw quantized payload; layout depends on `scheme`
    /// (e.g. one byte per value for INT8, two nibbles per byte for INT4).
    pub data: Vec<u8>,
    /// One scale per tensor, per channel, or per group depending on the scheme.
    pub scale: Vec<f32>,
    /// Zero points paired index-for-index with `scale`.
    pub zero_point: Vec<i32>,
    /// Logical shape of the original f32 tensor.
    pub shape: Vec<usize>,
    /// Scheme the payload was encoded with; drives `dequantize` dispatch.
    pub scheme: QuantizationScheme,
    /// Whether `scale`/`zero_point` are per-channel (axis 0) rather than per-tensor.
    pub per_channel: bool,
}
105
106impl QuantizedTensor {
107 pub fn new(
109 data: Vec<u8>,
110 scale: Vec<f32>,
111 zero_point: Vec<i32>,
112 shape: Vec<usize>,
113 scheme: QuantizationScheme,
114 per_channel: bool,
115 ) -> Self {
116 Self {
117 data,
118 scale,
119 zero_point,
120 shape,
121 scheme,
122 per_channel,
123 }
124 }
125
126 pub fn dequantize(&self) -> Result<Tensor> {
128 let total_elements: usize = self.shape.iter().product();
129 let mut result = Vec::with_capacity(total_elements);
130
131 match self.scheme {
132 QuantizationScheme::Int8 | QuantizationScheme::BnB8bit => {
133 if self.per_channel {
134 self.dequantize_per_channel_int8(&mut result)?;
135 } else {
136 self.dequantize_per_tensor_int8(&mut result)?;
137 }
138 },
139 QuantizationScheme::Int4 => {
140 if self.per_channel {
141 self.dequantize_per_channel_int4(&mut result)?;
142 } else {
143 self.dequantize_per_tensor_int4(&mut result)?;
144 }
145 },
146 QuantizationScheme::Dynamic | QuantizationScheme::DynamicINT8 => {
147 if self.per_channel {
149 self.dequantize_per_channel_int8(&mut result)?;
150 } else {
151 self.dequantize_per_tensor_int8(&mut result)?;
152 }
153 },
154 QuantizationScheme::GPTQ => {
155 if self.per_channel {
158 self.dequantize_gptq_per_channel(&mut result)?;
159 } else {
160 self.dequantize_gptq_per_tensor(&mut result)?;
161 }
162 },
163 QuantizationScheme::AWQ => {
164 if self.per_channel {
167 self.dequantize_awq_per_channel(&mut result)?;
168 } else {
169 self.dequantize_awq_per_tensor(&mut result)?;
170 }
171 },
172 QuantizationScheme::BnB4bit => {
173 if self.per_channel {
175 self.dequantize_bnb_4bit_per_channel(&mut result)?;
176 } else {
177 self.dequantize_bnb_4bit_per_tensor(&mut result)?;
178 }
179 },
180 QuantizationScheme::BnB4bitFP4 => {
181 if self.per_channel {
183 self.dequantize_bnb_fp4_per_channel(&mut result)?;
184 } else {
185 self.dequantize_bnb_fp4_per_tensor(&mut result)?;
186 }
187 },
188 }
189
190 Tensor::from_vec(result, &self.shape)
191 }
192
193 fn dequantize_per_tensor_int8(&self, result: &mut Vec<f32>) -> Result<()> {
194 if self.scale.len() != 1 || self.zero_point.len() != 1 {
195 return Err(TrustformersError::quantization_error(
196 "Per-tensor quantization requires single scale and zero point".into(),
197 ));
198 }
199
200 let scale = self.scale[0];
201 let zero_point = self.zero_point[0];
202
203 for &quantized_val in &self.data {
204 let int_val = quantized_val as i32 - zero_point;
205 let float_val = int_val as f32 * scale;
206 result.push(float_val);
207 }
208
209 Ok(())
210 }
211
212 fn dequantize_per_channel_int8(&self, result: &mut Vec<f32>) -> Result<()> {
213 let channels = self.scale.len();
214 let elements_per_channel = self.data.len() / channels;
215
216 for (channel_idx, (&scale, &zero_point)) in
217 self.scale.iter().zip(&self.zero_point).enumerate()
218 {
219 let start_idx = channel_idx * elements_per_channel;
220 let end_idx = start_idx + elements_per_channel;
221
222 for &quantized_val in &self.data[start_idx..end_idx] {
223 let int_val = quantized_val as i32 - zero_point;
224 let float_val = int_val as f32 * scale;
225 result.push(float_val);
226 }
227 }
228
229 Ok(())
230 }
231
232 fn dequantize_per_tensor_int4(&self, result: &mut Vec<f32>) -> Result<()> {
233 if self.scale.len() != 1 || self.zero_point.len() != 1 {
234 return Err(TrustformersError::quantization_error(
235 "Per-tensor quantization requires single scale and zero point".into(),
236 ));
237 }
238
239 let scale = self.scale[0];
240 let zero_point = self.zero_point[0];
241
242 for &byte in &self.data {
244 let high_nibble = (byte >> 4) as i32 - zero_point;
246 let high_val = high_nibble as f32 * scale;
247 result.push(high_val);
248
249 let low_nibble = (byte & 0x0F) as i32 - zero_point;
251 let low_val = low_nibble as f32 * scale;
252 result.push(low_val);
253 }
254
255 Ok(())
256 }
257
258 fn dequantize_per_channel_int4(&self, result: &mut Vec<f32>) -> Result<()> {
259 let channels = self.scale.len();
260 let bytes_per_channel = self.data.len() / channels;
261
262 for (channel_idx, (&scale, &zero_point)) in
263 self.scale.iter().zip(&self.zero_point).enumerate()
264 {
265 let start_idx = channel_idx * bytes_per_channel;
266 let end_idx = start_idx + bytes_per_channel;
267
268 for &byte in &self.data[start_idx..end_idx] {
269 let high_nibble = (byte >> 4) as i32 - zero_point;
271 let high_val = high_nibble as f32 * scale;
272 result.push(high_val);
273
274 let low_nibble = (byte & 0x0F) as i32 - zero_point;
276 let low_val = low_nibble as f32 * scale;
277 result.push(low_val);
278 }
279 }
280
281 Ok(())
282 }
283
284 fn dequantize_gptq_per_tensor(&self, result: &mut Vec<f32>) -> Result<()> {
286 if self.scale.len() != 1 || self.zero_point.len() != 1 {
287 return Err(TrustformersError::quantization_error(
288 "GPTQ per-tensor quantization requires single scale and zero point".into(),
289 ));
290 }
291
292 let scale = self.scale[0];
293 let zero_point = self.zero_point[0];
294
295 for &quantized_val in &self.data {
298 let int_val = quantized_val as i32 - zero_point;
299 let float_val = int_val as f32 * scale;
300 result.push(float_val);
301 }
302
303 Ok(())
304 }
305
306 fn dequantize_gptq_per_channel(&self, result: &mut Vec<f32>) -> Result<()> {
308 let channels = self.scale.len();
309 let elements_per_channel = self.data.len() / channels;
310
311 for (channel_idx, (&scale, &zero_point)) in
312 self.scale.iter().zip(&self.zero_point).enumerate()
313 {
314 let start_idx = channel_idx * elements_per_channel;
315 let end_idx = start_idx + elements_per_channel;
316
317 for &quantized_val in &self.data[start_idx..end_idx] {
318 let int_val = quantized_val as i32 - zero_point;
319 let float_val = int_val as f32 * scale;
320 result.push(float_val);
321 }
322 }
323
324 Ok(())
325 }
326
327 fn dequantize_awq_per_tensor(&self, result: &mut Vec<f32>) -> Result<()> {
329 if self.scale.len() != 1 || self.zero_point.len() != 1 {
330 return Err(TrustformersError::quantization_error(
331 "AWQ per-tensor quantization requires single scale and zero point".into(),
332 ));
333 }
334
335 let scale = self.scale[0];
336 let zero_point = self.zero_point[0];
337
338 for &quantized_val in &self.data {
341 let int_val = quantized_val as i32 - zero_point;
342 let float_val = int_val as f32 * scale;
343 result.push(float_val);
344 }
345
346 Ok(())
347 }
348
349 fn dequantize_awq_per_channel(&self, result: &mut Vec<f32>) -> Result<()> {
351 let channels = self.scale.len();
352 let elements_per_channel = self.data.len() / channels;
353
354 for (channel_idx, (&scale, &zero_point)) in
355 self.scale.iter().zip(&self.zero_point).enumerate()
356 {
357 let start_idx = channel_idx * elements_per_channel;
358 let end_idx = start_idx + elements_per_channel;
359
360 for &quantized_val in &self.data[start_idx..end_idx] {
361 let int_val = quantized_val as i32 - zero_point;
362 let float_val = int_val as f32 * scale;
363 result.push(float_val);
364 }
365 }
366
367 Ok(())
368 }
369
370 fn dequantize_bnb_4bit_per_tensor(&self, result: &mut Vec<f32>) -> Result<()> {
372 const NF4_LEVELS: [f32; 16] = [
375 -1.0,
376 -0.6961928009986877,
377 -0.5250730514526367,
378 -0.39491748809814453,
379 -0.28444138169288635,
380 -0.18477343022823334,
381 -0.09105003625154495,
382 0.0,
383 0.07958029955625534,
384 0.16093020141124725,
385 0.24611230194568634,
386 0.33791524171829224,
387 0.44070982933044434,
388 0.5626170039176941,
389 0.7229568362236023,
390 1.0,
391 ];
392
393 if self.scale.len() != 1 {
394 return Err(TrustformersError::quantization_error(
395 "BnB 4-bit per-tensor quantization requires single scale".into(),
396 ));
397 }
398
399 let scale = self.scale[0];
400
401 for &byte in &self.data {
402 let high_nibble = (byte >> 4) & 0x0F;
404 let high_val = NF4_LEVELS[high_nibble as usize] * scale;
405 result.push(high_val);
406
407 let low_nibble = byte & 0x0F;
409 let low_val = NF4_LEVELS[low_nibble as usize] * scale;
410 result.push(low_val);
411 }
412
413 Ok(())
414 }
415
416 fn dequantize_bnb_4bit_per_channel(&self, result: &mut Vec<f32>) -> Result<()> {
418 const NF4_LEVELS: [f32; 16] = [
419 -1.0,
420 -0.6961928009986877,
421 -0.5250730514526367,
422 -0.39491748809814453,
423 -0.28444138169288635,
424 -0.18477343022823334,
425 -0.09105003625154495,
426 0.0,
427 0.07958029955625534,
428 0.16093020141124725,
429 0.24611230194568634,
430 0.33791524171829224,
431 0.44070982933044434,
432 0.5626170039176941,
433 0.7229568362236023,
434 1.0,
435 ];
436
437 let channels = self.scale.len();
438 let bytes_per_channel = self.data.len() / channels;
439
440 for (channel_idx, &scale) in self.scale.iter().enumerate() {
441 let start_idx = channel_idx * bytes_per_channel;
442 let end_idx = start_idx + bytes_per_channel;
443
444 for &byte in &self.data[start_idx..end_idx] {
445 let high_nibble = (byte >> 4) & 0x0F;
447 let high_val = NF4_LEVELS[high_nibble as usize] * scale;
448 result.push(high_val);
449
450 let low_nibble = byte & 0x0F;
452 let low_val = NF4_LEVELS[low_nibble as usize] * scale;
453 result.push(low_val);
454 }
455 }
456
457 Ok(())
458 }
459
460 fn dequantize_bnb_fp4_per_tensor(&self, result: &mut Vec<f32>) -> Result<()> {
462 const FP4_LEVELS: [f32; 16] = [
464 -12.0, -8.0, -6.0, -4.0, -3.0, -2.0, -1.5, -1.0, 0.0, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, 8.0,
465 ];
466
467 if self.scale.len() != 1 {
468 return Err(TrustformersError::quantization_error(
469 "BnB FP4 per-tensor quantization requires single scale".into(),
470 ));
471 }
472
473 let scale = self.scale[0];
474
475 for &byte in &self.data {
476 let high_nibble = (byte >> 4) & 0x0F;
478 let high_val = FP4_LEVELS[high_nibble as usize] * scale;
479 result.push(high_val);
480
481 let low_nibble = byte & 0x0F;
483 let low_val = FP4_LEVELS[low_nibble as usize] * scale;
484 result.push(low_val);
485 }
486
487 Ok(())
488 }
489
490 fn dequantize_bnb_fp4_per_channel(&self, result: &mut Vec<f32>) -> Result<()> {
492 const FP4_LEVELS: [f32; 16] = [
493 -12.0, -8.0, -6.0, -4.0, -3.0, -2.0, -1.5, -1.0, 0.0, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, 8.0,
494 ];
495
496 let channels = self.scale.len();
497 let bytes_per_channel = self.data.len() / channels;
498
499 for (channel_idx, &scale) in self.scale.iter().enumerate() {
500 let start_idx = channel_idx * bytes_per_channel;
501 let end_idx = start_idx + bytes_per_channel;
502
503 for &byte in &self.data[start_idx..end_idx] {
504 let high_nibble = (byte >> 4) & 0x0F;
506 let high_val = FP4_LEVELS[high_nibble as usize] * scale;
507 result.push(high_val);
508
509 let low_nibble = byte & 0x0F;
511 let low_val = FP4_LEVELS[low_nibble as usize] * scale;
512 result.push(low_val);
513 }
514 }
515
516 Ok(())
517 }
518}
519
520pub struct Quantizer;
522
523impl Quantizer {
524 pub fn quantize(tensor: &Tensor, config: &QuantizationConfig) -> Result<QuantizedTensor> {
526 match config.scheme {
527 QuantizationScheme::Int8 => {
528 if config.per_channel {
529 Self::quantize_per_channel_int8(tensor, config.symmetric)
530 } else {
531 Self::quantize_per_tensor_int8(tensor, config.symmetric)
532 }
533 },
534 QuantizationScheme::Int4 => {
535 if config.per_channel {
536 Self::quantize_per_channel_int4(tensor, config.symmetric, config.group_size)
537 } else {
538 Self::quantize_per_tensor_int4(tensor, config.symmetric)
539 }
540 },
541 QuantizationScheme::Dynamic => Self::dynamic_quantize(tensor),
542 QuantizationScheme::DynamicINT8 => {
543 Self::dynamic_quantize(tensor)
545 },
546 QuantizationScheme::GPTQ => {
547 if config.per_channel {
551 Self::quantize_per_channel_int4(tensor, true, config.group_size)
552 } else {
553 Self::quantize_per_tensor_int4(tensor, true)
554 }
555 },
556 QuantizationScheme::AWQ => {
557 if config.per_channel {
561 Self::quantize_per_channel_int4(tensor, true, config.group_size)
562 } else {
563 Self::quantize_per_tensor_int4(tensor, true)
564 }
565 },
566 QuantizationScheme::BnB8bit => {
567 let bnb_config = config.bnb_config.clone().unwrap_or(BnBConfig {
569 use_double_quant: false,
570 quant_type: BnBQuantType::Int8,
571 compute_dtype: BnBComputeType::Float16,
572 bnb_4bit_quant_storage: None,
573 });
574 let quantizer = BnBQuantizer::new(bnb_config);
575 quantizer.quantize_bnb_int8(tensor)
576 },
577 QuantizationScheme::BnB4bit => {
578 let bnb_config = config.bnb_config.clone().unwrap_or(BnBConfig {
580 use_double_quant: false,
581 quant_type: BnBQuantType::NF4,
582 compute_dtype: BnBComputeType::Float16,
583 bnb_4bit_quant_storage: Some(BnBStorageType::UInt8),
584 });
585 let quantizer = BnBQuantizer::new(bnb_config);
586 quantizer.quantize_nf4(tensor)
587 },
588 QuantizationScheme::BnB4bitFP4 => {
589 let bnb_config = config.bnb_config.clone().unwrap_or(BnBConfig {
591 use_double_quant: false,
592 quant_type: BnBQuantType::FP4,
593 compute_dtype: BnBComputeType::Float16,
594 bnb_4bit_quant_storage: Some(BnBStorageType::UInt8),
595 });
596 let quantizer = BnBQuantizer::new(bnb_config);
597 quantizer.quantize_fp4(tensor)
598 },
599 }
600 }
601
602 fn quantize_per_tensor_int8(tensor: &Tensor, symmetric: bool) -> Result<QuantizedTensor> {
604 match tensor {
605 Tensor::F32(arr) => {
606 let data = arr.iter().cloned().collect::<Vec<f32>>();
607 let (scale, zero_point) = Self::compute_quantization_params(&data, symmetric, 8)?;
608
609 let quantized_data: Vec<u8> = data
610 .iter()
611 .map(|&val| Self::quantize_value_int8(val, scale, zero_point))
612 .collect();
613
614 Ok(QuantizedTensor::new(
615 quantized_data,
616 vec![scale],
617 vec![zero_point],
618 arr.shape().to_vec(),
619 QuantizationScheme::Int8,
620 false,
621 ))
622 },
623 _ => Err(TrustformersError::quantization_error(
624 "Unsupported tensor type for quantization".into(),
625 )),
626 }
627 }
628
629 fn quantize_per_channel_int8(tensor: &Tensor, symmetric: bool) -> Result<QuantizedTensor> {
631 match tensor {
632 Tensor::F32(arr) => {
633 let shape = arr.shape();
634 let channels = shape[0]; let elements_per_channel = arr.len() / channels;
636
637 let mut scales = Vec::with_capacity(channels);
638 let mut zero_points = Vec::with_capacity(channels);
639 let mut quantized_data = Vec::with_capacity(arr.len());
640
641 for channel in 0..channels {
642 let start_idx = channel * elements_per_channel;
643 let end_idx = start_idx + elements_per_channel;
644 let channel_data = arr
645 .iter()
646 .skip(start_idx)
647 .take(elements_per_channel)
648 .cloned()
649 .collect::<Vec<f32>>();
650
651 let (scale, zero_point) =
652 Self::compute_quantization_params(&channel_data, symmetric, 8)?;
653 scales.push(scale);
654 zero_points.push(zero_point);
655
656 let channel_quantized: Vec<u8> = channel_data
657 .iter()
658 .map(|&val| Self::quantize_value_int8(val, scale, zero_point))
659 .collect();
660
661 quantized_data.extend(channel_quantized);
662 }
663
664 Ok(QuantizedTensor::new(
665 quantized_data,
666 scales,
667 zero_points,
668 shape.to_vec(),
669 QuantizationScheme::Int8,
670 true,
671 ))
672 },
673 _ => Err(TrustformersError::quantization_error(
674 "Unsupported tensor type for quantization".into(),
675 )),
676 }
677 }
678
679 fn quantize_per_tensor_int4(tensor: &Tensor, symmetric: bool) -> Result<QuantizedTensor> {
681 match tensor {
682 Tensor::F32(arr) => {
683 let data = arr.iter().cloned().collect::<Vec<f32>>();
684 let (scale, zero_point) = Self::compute_quantization_params(&data, symmetric, 4)?;
685
686 let quantized_data = Self::pack_int4_values(&data, scale, zero_point)?;
687
688 Ok(QuantizedTensor::new(
689 quantized_data,
690 vec![scale],
691 vec![zero_point],
692 arr.shape().to_vec(),
693 QuantizationScheme::Int4,
694 false,
695 ))
696 },
697 _ => Err(TrustformersError::quantization_error(
698 "Unsupported tensor type for quantization".into(),
699 )),
700 }
701 }
702
703 fn quantize_per_channel_int4(
705 tensor: &Tensor,
706 symmetric: bool,
707 group_size: Option<usize>,
708 ) -> Result<QuantizedTensor> {
709 match tensor {
710 Tensor::F32(arr) => {
711 let shape = arr.shape();
712 let total_elements = arr.len();
713 let group_size = group_size.unwrap_or(128);
714 let num_groups = total_elements.div_ceil(group_size);
715
716 let mut scales = Vec::with_capacity(num_groups);
717 let mut zero_points = Vec::with_capacity(num_groups);
718 let mut quantized_data = Vec::with_capacity(total_elements / 2); for group_idx in 0..num_groups {
721 let start_idx = group_idx * group_size;
722 let end_idx = (start_idx + group_size).min(total_elements);
723
724 let group_data = arr
725 .iter()
726 .skip(start_idx)
727 .take(end_idx - start_idx)
728 .cloned()
729 .collect::<Vec<f32>>();
730
731 let (scale, zero_point) =
732 Self::compute_quantization_params(&group_data, symmetric, 4)?;
733 scales.push(scale);
734 zero_points.push(zero_point);
735
736 let group_quantized = Self::pack_int4_values(&group_data, scale, zero_point)?;
737 quantized_data.extend(group_quantized);
738 }
739
740 Ok(QuantizedTensor::new(
741 quantized_data,
742 scales,
743 zero_points,
744 shape.to_vec(),
745 QuantizationScheme::Int4,
746 true,
747 ))
748 },
749 _ => Err(TrustformersError::quantization_error(
750 "Unsupported tensor type for quantization".into(),
751 )),
752 }
753 }
754
755 fn dynamic_quantize(tensor: &Tensor) -> Result<QuantizedTensor> {
757 Self::quantize_per_tensor_int8(tensor, false)
759 }
760
761 fn compute_quantization_params(data: &[f32], symmetric: bool, bits: u8) -> Result<(f32, i32)> {
763 if data.is_empty() {
764 return Err(TrustformersError::quantization_error(
765 "Cannot quantize empty data".into(),
766 ));
767 }
768
769 let min_val = data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
770 let max_val = data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
771
772 let q_min = if symmetric { -(1 << (bits - 1)) } else { 0 };
773 let q_max = if symmetric { (1 << (bits - 1)) - 1 } else { (1 << bits) - 1 };
774
775 let (scale, zero_point) = if symmetric {
776 let abs_max = max_val.abs().max(min_val.abs());
777 let scale = abs_max / (q_max - q_min) as f32;
778 (scale, 0)
779 } else {
780 let scale = (max_val - min_val) / (q_max - q_min) as f32;
781 let zero_point = q_min - (min_val / scale).round() as i32;
782 let zero_point = zero_point.clamp(q_min, q_max);
783 (scale, zero_point)
784 };
785
786 Ok((scale, zero_point))
787 }
788
789 fn quantize_value_int8(value: f32, scale: f32, zero_point: i32) -> u8 {
791 let quantized = (value / scale).round() as i32 + zero_point;
792 quantized.clamp(0, 255) as u8
793 }
794
795 fn pack_int4_values(data: &[f32], scale: f32, zero_point: i32) -> Result<Vec<u8>> {
797 let mut packed = Vec::with_capacity(data.len().div_ceil(2));
798
799 for chunk in data.chunks(2) {
800 let val1 = Self::quantize_value_int4(chunk[0], scale, zero_point);
801 let val2 = if chunk.len() > 1 {
802 Self::quantize_value_int4(chunk[1], scale, zero_point)
803 } else {
804 0 };
806
807 let packed_byte = (val1 << 4) | val2;
809 packed.push(packed_byte);
810 }
811
812 Ok(packed)
813 }
814
815 fn quantize_value_int4(value: f32, scale: f32, zero_point: i32) -> u8 {
817 let quantized = (value / scale).round() as i32 + zero_point;
818 quantized.clamp(0, 15) as u8
819 }
820
821 pub fn calibrate(
823 samples: &[Tensor],
824 config: &QuantizationConfig,
825 ) -> Result<QuantizationConfig> {
826 let mut calibrated_config = config.clone();
829
830 if let Some(sample_count) = config.calibration_samples {
831 let num_samples = samples.len().min(sample_count);
832
833 let mut all_values = Vec::new();
835 for sample in samples.iter().take(num_samples) {
836 match sample {
837 Tensor::F32(arr) => {
838 all_values.extend(arr.iter().cloned());
839 },
840 _ => continue,
841 }
842 }
843
844 if !all_values.is_empty() {
845 let abs_max = all_values.iter().fold(0.0f32, |acc, &x| acc.max(x.abs()));
847
848 let min_val = all_values.iter().fold(f32::INFINITY, |a, &b| a.min(b));
850 let max_val = all_values.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
851
852 calibrated_config.symmetric =
853 (min_val.abs() - max_val.abs()).abs() / max_val.abs() < 0.1;
854 }
855 }
856
857 Ok(calibrated_config)
858 }
859}
860
/// GPTQ quantizer wrapper holding a quantization configuration.
pub struct GPTQQuantizer {
    config: QuantizationConfig,
}
865
impl GPTQQuantizer {
    /// Create a GPTQ quantizer with the given configuration.
    pub fn new(config: QuantizationConfig) -> Self {
        Self { config }
    }

    /// Quantize `tensor`.
    ///
    /// NOTE(review): `hessian` is currently ignored; this is a placeholder
    /// that falls back to the generic `Quantizer` path rather than performing
    /// Hessian-based GPTQ error compensation — confirm intended.
    pub fn quantize(&self, tensor: &Tensor, hessian: Option<&Tensor>) -> Result<QuantizedTensor> {
        Quantizer::quantize(tensor, &self.config)
    }
}
879
/// AWQ quantizer wrapper holding a configuration and optional activation scales.
pub struct AWQQuantizer {
    config: QuantizationConfig,
    // Per-channel activation scales; see NOTE on `quantize` about current usage.
    activation_scales: Option<Vec<f32>>,
}
885
impl AWQQuantizer {
    /// Create an AWQ quantizer with no activation scales recorded yet.
    pub fn new(config: QuantizationConfig) -> Self {
        Self {
            config,
            activation_scales: None,
        }
    }

    /// Record activation scales gathered from calibration data.
    pub fn set_activation_scales(&mut self, scales: Vec<f32>) {
        self.activation_scales = Some(scales);
    }

    /// Quantize `tensor`.
    ///
    /// NOTE(review): `activation_scales` is currently ignored; this is a
    /// placeholder that falls back to the generic `Quantizer` path rather
    /// than activation-aware weight scaling — confirm intended.
    pub fn quantize(&self, tensor: &Tensor) -> Result<QuantizedTensor> {
        Quantizer::quantize(tensor, &self.config)
    }
}
906
/// bitsandbytes-style quantizer holding a `BnBConfig`.
pub struct BnBQuantizer {
    config: BnBConfig,
}
911
912impl BnBQuantizer {
913 pub fn new(config: BnBConfig) -> Self {
914 Self { config }
915 }
916
917 pub fn quantize(&self, tensor: &Tensor) -> Result<QuantizedTensor> {
919 match self.config.quant_type {
920 BnBQuantType::NF4 => self.quantize_nf4(tensor),
921 BnBQuantType::FP4 => self.quantize_fp4(tensor),
922 BnBQuantType::Int8 => self.quantize_bnb_int8(tensor),
923 }
924 }
925
926 fn quantize_nf4(&self, tensor: &Tensor) -> Result<QuantizedTensor> {
928 match tensor {
929 Tensor::F32(arr) => {
930 let data = arr.iter().cloned().collect::<Vec<f32>>();
931 let block_size = 64; let mut quantized_data = Vec::new();
934 let mut scales = Vec::new();
935 let mut zero_points = Vec::new();
936
937 for chunk in data.chunks(block_size) {
938 let (block_scale, block_quantized) = self.nf4_quantize_block(chunk)?;
939 scales.push(block_scale);
940 zero_points.push(0); quantized_data.extend(block_quantized);
942 }
943
944 Ok(QuantizedTensor::new(
945 quantized_data,
946 scales,
947 zero_points,
948 arr.shape().to_vec(),
949 QuantizationScheme::BnB4bit,
950 false,
951 ))
952 },
953 _ => Err(TrustformersError::quantization_error(
954 "Unsupported tensor type for BnB NF4".into(),
955 )),
956 }
957 }
958
959 fn quantize_fp4(&self, tensor: &Tensor) -> Result<QuantizedTensor> {
961 match tensor {
962 Tensor::F32(arr) => {
963 let data = arr.iter().cloned().collect::<Vec<f32>>();
964 let block_size = 64;
965
966 let mut quantized_data = Vec::new();
967 let mut scales = Vec::new();
968 let mut zero_points = Vec::new();
969
970 for chunk in data.chunks(block_size) {
971 let (block_scale, block_quantized) = self.fp4_quantize_block(chunk)?;
972 scales.push(block_scale);
973 zero_points.push(0); quantized_data.extend(block_quantized);
975 }
976
977 Ok(QuantizedTensor::new(
978 quantized_data,
979 scales,
980 zero_points,
981 arr.shape().to_vec(),
982 QuantizationScheme::BnB4bitFP4,
983 false,
984 ))
985 },
986 _ => Err(TrustformersError::quantization_error(
987 "Unsupported tensor type for BnB FP4".into(),
988 )),
989 }
990 }
991
992 fn quantize_bnb_int8(&self, tensor: &Tensor) -> Result<QuantizedTensor> {
994 match tensor {
995 Tensor::F32(arr) => {
996 let data = arr.iter().cloned().collect::<Vec<f32>>();
997 let (scale, zero_point) = Quantizer::compute_quantization_params(&data, false, 8)?;
998
999 let quantized_data: Vec<u8> = data
1000 .iter()
1001 .map(|&val| Quantizer::quantize_value_int8(val, scale, zero_point))
1002 .collect();
1003
1004 Ok(QuantizedTensor::new(
1005 quantized_data,
1006 vec![scale],
1007 vec![zero_point],
1008 arr.shape().to_vec(),
1009 QuantizationScheme::BnB8bit,
1010 false,
1011 ))
1012 },
1013 _ => Err(TrustformersError::quantization_error(
1014 "Unsupported tensor type for BnB Int8".into(),
1015 )),
1016 }
1017 }
1018
1019 fn nf4_quantize_block(&self, data: &[f32]) -> Result<(f32, Vec<u8>)> {
1021 const NF4_LEVELS: [f32; 16] = [
1023 -1.0,
1024 -0.6961928009986877,
1025 -0.5250730514526367,
1026 -0.39491748809814453,
1027 -0.28444138169288635,
1028 -0.18477343022823334,
1029 -0.09105003625154495,
1030 0.0,
1031 0.07958029955625534,
1032 0.16093020141124725,
1033 0.24611230194568634,
1034 0.33791524171829224,
1035 0.44070982933044434,
1036 0.5626170039176941,
1037 0.7229568362236023,
1038 1.0,
1039 ];
1040
1041 if data.is_empty() {
1042 return Err(TrustformersError::quantization_error(
1043 "Cannot quantize empty block".into(),
1044 ));
1045 }
1046
1047 let abs_max = data.iter().fold(0.0f32, |acc, &x| acc.max(x.abs()));
1049 let scale = abs_max;
1050
1051 if scale == 0.0 {
1052 return Ok((scale, vec![0; data.len()]));
1053 }
1054
1055 let mut quantized = Vec::with_capacity(data.len());
1057 for &val in data {
1058 let normalized = val / scale;
1059 let mut best_idx = 0;
1060 let mut best_dist = (normalized - NF4_LEVELS[0]).abs();
1061
1062 for (idx, &level) in NF4_LEVELS.iter().enumerate().skip(1) {
1063 let dist = (normalized - level).abs();
1064 if dist < best_dist {
1065 best_dist = dist;
1066 best_idx = idx;
1067 }
1068 }
1069
1070 quantized.push(best_idx as u8);
1071 }
1072
1073 Ok((scale, quantized))
1074 }
1075
1076 fn fp4_quantize_block(&self, data: &[f32]) -> Result<(f32, Vec<u8>)> {
1078 const FP4_LEVELS: [f32; 16] = [
1080 0.0, 0.0625, 0.125, 0.1875, 0.25, 0.3125, 0.375, 0.4375, 0.5, 0.625, 0.75, 0.875, 1.0,
1081 1.25, 1.5, 2.0,
1082 ];
1083
1084 if data.is_empty() {
1085 return Err(TrustformersError::quantization_error(
1086 "Cannot quantize empty block".into(),
1087 ));
1088 }
1089
1090 let abs_max = data.iter().fold(0.0f32, |acc, &x| acc.max(x.abs()));
1092 let scale = abs_max / 2.0; if scale == 0.0 {
1095 return Ok((scale, vec![0; data.len()]));
1096 }
1097
1098 let mut quantized = Vec::with_capacity(data.len());
1100 for &val in data {
1101 let abs_val = val.abs() / scale;
1102 let sign = if val >= 0.0 { 0 } else { 8 }; let mut best_idx = 0;
1105 let mut best_dist = (abs_val - FP4_LEVELS[0]).abs();
1106
1107 for (idx, &level) in FP4_LEVELS[..8].iter().enumerate().skip(1) {
1108 let dist = (abs_val - level).abs();
1109 if dist < best_dist {
1110 best_dist = dist;
1111 best_idx = idx;
1112 }
1113 }
1114
1115 quantized.push((sign | best_idx) as u8);
1116 }
1117
1118 Ok((scale, quantized))
1119 }
1120
1121 pub fn dequantize(&self, tensor: &QuantizedTensor) -> Result<Tensor> {
1123 match tensor.scheme {
1124 QuantizationScheme::BnB4bit => self.dequantize_nf4(tensor),
1125 QuantizationScheme::BnB4bitFP4 => self.dequantize_fp4(tensor),
1126 QuantizationScheme::BnB8bit => tensor.dequantize(), _ => Err(TrustformersError::quantization_error(format!(
1128 "BnB dequantization not supported for scheme {:?}",
1129 tensor.scheme
1130 ))),
1131 }
1132 }
1133
1134 fn dequantize_nf4(&self, tensor: &QuantizedTensor) -> Result<Tensor> {
1136 const NF4_LEVELS: [f32; 16] = [
1137 -1.0,
1138 -0.6961928009986877,
1139 -0.5250730514526367,
1140 -0.39491748809814453,
1141 -0.28444138169288635,
1142 -0.18477343022823334,
1143 -0.09105003625154495,
1144 0.0,
1145 0.07958029955625534,
1146 0.16093020141124725,
1147 0.24611230194568634,
1148 0.33791524171829224,
1149 0.44070982933044434,
1150 0.5626170039176941,
1151 0.7229568362236023,
1152 1.0,
1153 ];
1154
1155 let block_size = 64;
1156 let mut result = Vec::with_capacity(tensor.data.len());
1157 let num_blocks = tensor.scale.len();
1158
1159 for block_idx in 0..num_blocks {
1160 let scale = tensor.scale[block_idx];
1161 let start_idx = block_idx * block_size;
1162 let end_idx = (start_idx + block_size).min(tensor.data.len());
1163
1164 for &quantized_val in &tensor.data[start_idx..end_idx] {
1165 let idx = (quantized_val as usize).min(15);
1166 let dequantized = NF4_LEVELS[idx] * scale;
1167 result.push(dequantized);
1168 }
1169 }
1170
1171 Tensor::from_vec(result, &tensor.shape)
1172 }
1173
1174 fn dequantize_fp4(&self, tensor: &QuantizedTensor) -> Result<Tensor> {
1176 const FP4_LEVELS: [f32; 8] = [0.0, 0.0625, 0.125, 0.1875, 0.25, 0.3125, 0.375, 0.4375];
1177
1178 let block_size = 64;
1179 let mut result = Vec::with_capacity(tensor.data.len());
1180 let num_blocks = tensor.scale.len();
1181
1182 for block_idx in 0..num_blocks {
1183 let scale = tensor.scale[block_idx];
1184 let start_idx = block_idx * block_size;
1185 let end_idx = (start_idx + block_size).min(tensor.data.len());
1186
1187 for &quantized_val in &tensor.data[start_idx..end_idx] {
1188 let sign = if (quantized_val & 8) != 0 { -1.0 } else { 1.0 };
1189 let idx = (quantized_val & 7) as usize;
1190 let abs_val = FP4_LEVELS[idx];
1191 let dequantized = sign * abs_val * scale;
1192 result.push(dequantized);
1193 }
1194 }
1195
1196 Tensor::from_vec(result, &tensor.shape)
1197 }
1198}
1199
/// Configuration for quantization-aware training (QAT).
pub struct QATConfig {
    /// Apply the quantize-dequantize round-trip during `FakeQuantize::forward`.
    pub fake_quantize: bool,
    /// Record min/max range statistics during forward passes.
    pub observe: bool,
    /// Whether to use a reduced quantization range.
    /// NOTE(review): not consulted anywhere in this module yet — confirm intended.
    pub reduce_range: bool,
    /// Target quantization scheme.
    /// NOTE(review): `FakeQuantize::forward` hard-codes symmetric INT8 and
    /// does not read this field — confirm intended.
    pub qscheme: QuantizationScheme,
}
1207
impl Default for QATConfig {
    /// Observe and fake-quantize with full-range INT8.
    fn default() -> Self {
        Self {
            fake_quantize: true,
            observe: true,
            reduce_range: false,
            qscheme: QuantizationScheme::Int8,
        }
    }
}
1218
/// QAT module that simulates quantization error on f32 tensors.
pub struct FakeQuantize {
    config: QATConfig,
    // A single observer is lazily created on first forward pass (index 0).
    observers: Vec<Observer>,
}
1224
/// Running min/max range statistics over observed tensor values.
pub struct Observer {
    // Smallest value seen so far (starts at +inf).
    min_val: f32,
    // Largest value seen so far (starts at -inf).
    max_val: f32,
    // Total number of values observed; 0 means "no data yet".
    count: usize,
}
1231
impl Default for Observer {
    /// Equivalent to `Observer::new()`: an empty (inverted) range.
    fn default() -> Self {
        Self::new()
    }
}
1237
1238impl Observer {
1239 pub fn new() -> Self {
1240 Self {
1241 min_val: f32::INFINITY,
1242 max_val: f32::NEG_INFINITY,
1243 count: 0,
1244 }
1245 }
1246
1247 pub fn update(&mut self, tensor: &Tensor) {
1248 if let Tensor::F32(arr) = tensor {
1249 for &val in arr.iter() {
1250 self.min_val = self.min_val.min(val);
1251 self.max_val = self.max_val.max(val);
1252 self.count += 1;
1253 }
1254 }
1255 }
1256
1257 pub fn get_quantization_params(&self, symmetric: bool, bits: u8) -> Result<(f32, i32)> {
1258 if self.count == 0 {
1259 return Err(TrustformersError::quantization_error(
1260 "No observations for quantization".into(),
1261 ));
1262 }
1263
1264 let q_min = if symmetric { -(1 << (bits - 1)) } else { 0 };
1265 let q_max = if symmetric { (1 << (bits - 1)) - 1 } else { (1 << bits) - 1 };
1266
1267 let (scale, zero_point) = if symmetric {
1268 let abs_max = self.max_val.abs().max(self.min_val.abs());
1269 let scale = abs_max / (q_max - q_min) as f32;
1270 (scale, 0)
1271 } else {
1272 let scale = (self.max_val - self.min_val) / (q_max - q_min) as f32;
1273 let zero_point = q_min - (self.min_val / scale).round() as i32;
1274 let zero_point = zero_point.clamp(q_min, q_max);
1275 (scale, zero_point)
1276 };
1277
1278 Ok((scale, zero_point))
1279 }
1280}
1281
1282impl FakeQuantize {
1283 pub fn new(config: QATConfig) -> Self {
1284 Self {
1285 config,
1286 observers: Vec::new(),
1287 }
1288 }
1289
1290 pub fn forward(&mut self, tensor: &Tensor) -> Result<Tensor> {
1292 if self.config.observe {
1293 if self.observers.is_empty() {
1295 self.observers.push(Observer::new());
1296 }
1297 self.observers[0].update(tensor);
1298 }
1299
1300 if self.config.fake_quantize && !self.observers.is_empty() {
1301 let observer = &self.observers[0];
1303 let (scale, zero_point) = observer.get_quantization_params(true, 8)?;
1304
1305 match tensor {
1307 Tensor::F32(arr) => {
1308 let quantized_data: Vec<f32> = arr
1309 .iter()
1310 .map(|&val| {
1311 let q_val = Quantizer::quantize_value_int8(val, scale, zero_point);
1312 let int_val = q_val as i32 - zero_point;
1313 int_val as f32 * scale
1314 })
1315 .collect();
1316
1317 Tensor::from_vec(quantized_data, arr.shape())
1318 },
1319 _ => Ok(tensor.clone()),
1320 }
1321 } else {
1322 Ok(tensor.clone())
1323 }
1324 }
1325}
1326
#[cfg(test)]
mod tests {
    //! Unit tests for the quantization module: per-tensor/per-channel
    //! int8/int4, dynamic, GPTQ, AWQ, bitsandbytes (NF4/FP4/int8),
    //! calibration, QAT fake quantization, and config serialization.
    use super::*;

    // Symmetric per-tensor int8: one scale/zero-point pair, shape preserved
    // through a quantize/dequantize round trip.
    #[test]
    fn test_int8_per_tensor_quantization() -> Result<()> {
        let tensor = Tensor::randn(&[10, 20])?;
        let config = QuantizationConfig {
            scheme: QuantizationScheme::Int8,
            symmetric: true,
            per_channel: false,
            calibration_samples: None,
            group_size: None,
            bnb_config: None,
        };

        let quantized = Quantizer::quantize(&tensor, &config)?;
        assert_eq!(quantized.scheme, QuantizationScheme::Int8);
        assert!(!quantized.per_channel);
        assert_eq!(quantized.scale.len(), 1);
        assert_eq!(quantized.zero_point.len(), 1);

        let dequantized = quantized.dequantize()?;
        assert_eq!(dequantized.shape(), tensor.shape());
        Ok(())
    }

    // Asymmetric per-tensor int4 round trip preserves scheme and shape.
    #[test]
    fn test_int4_per_tensor_quantization() -> Result<()> {
        let tensor = Tensor::randn(&[8, 16])?;
        let config = QuantizationConfig {
            scheme: QuantizationScheme::Int4,
            symmetric: false,
            per_channel: false,
            calibration_samples: None,
            group_size: None,
            bnb_config: None,
        };

        let quantized = Quantizer::quantize(&tensor, &config)?;
        assert_eq!(quantized.scheme, QuantizationScheme::Int4);
        assert!(!quantized.per_channel);

        let dequantized = quantized.dequantize()?;
        assert_eq!(dequantized.shape(), tensor.shape());
        Ok(())
    }

    // Per-channel mode: one scale/zero-point per row (first dim = 4).
    #[test]
    fn test_per_channel_quantization() -> Result<()> {
        let tensor = Tensor::randn(&[4, 32])?;
        let config = QuantizationConfig {
            scheme: QuantizationScheme::Int8,
            symmetric: true,
            per_channel: true,
            calibration_samples: None,
            group_size: None,
            bnb_config: None,
        };

        let quantized = Quantizer::quantize(&tensor, &config)?;
        assert!(quantized.per_channel);
        assert_eq!(quantized.scale.len(), 4); assert_eq!(quantized.zero_point.len(), 4);

        let dequantized = quantized.dequantize()?;
        assert_eq!(dequantized.shape(), tensor.shape());
        Ok(())
    }

    // Dynamic scheme: parameters derived at quantize time; shape round-trips.
    #[test]
    fn test_dynamic_quantization() -> Result<()> {
        let tensor = Tensor::randn(&[16, 32])?;
        let config = QuantizationConfig {
            scheme: QuantizationScheme::Dynamic,
            symmetric: false,
            per_channel: false,
            calibration_samples: None,
            group_size: None,
            bnb_config: None,
        };

        let quantized = Quantizer::quantize(&tensor, &config)?;
        let dequantized = quantized.dequantize()?;
        assert_eq!(dequantized.shape(), tensor.shape());
        Ok(())
    }

    // Symmetric params must have zero_point == 0; both modes need scale > 0.
    #[test]
    fn test_quantization_params_computation() -> Result<()> {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let (scale, zero_point) = Quantizer::compute_quantization_params(&data, true, 8)?;
        assert_eq!(zero_point, 0);
        assert!(scale > 0.0);

        let (scale, zero_point) = Quantizer::compute_quantization_params(&data, false, 8)?;
        assert!(scale > 0.0);
        Ok(())
    }

    // GPTQ quantizer round trip with default config (no Hessian input).
    #[test]
    fn test_gptq_quantizer() -> Result<()> {
        let tensor = Tensor::randn(&[16, 32])?;
        let config = QuantizationConfig::default();
        let gptq = GPTQQuantizer::new(config);

        let quantized = gptq.quantize(&tensor, None)?;
        let dequantized = quantized.dequantize()?;
        assert_eq!(dequantized.shape(), tensor.shape());
        Ok(())
    }

    // AWQ quantizer round trip using unit activation scales (one per row).
    #[test]
    fn test_awq_quantizer() -> Result<()> {
        let tensor = Tensor::randn(&[16, 32])?;
        let config = QuantizationConfig::default();
        let mut awq = AWQQuantizer::new(config);

        let scales = vec![1.0; 16];
        awq.set_activation_scales(scales);

        let quantized = awq.quantize(&tensor)?;
        let dequantized = quantized.dequantize()?;
        assert_eq!(dequantized.shape(), tensor.shape());
        Ok(())
    }

    // Calibration limited to 2 of 3 samples must keep the configured scheme.
    #[test]
    fn test_calibration() -> Result<()> {
        let samples = vec![
            Tensor::randn(&[16, 32])?,
            Tensor::randn(&[16, 32])?,
            Tensor::randn(&[16, 32])?,
        ];

        let config = QuantizationConfig {
            calibration_samples: Some(2),
            ..Default::default()
        };

        let calibrated = Quantizer::calibrate(&samples, &config)?;
        assert_eq!(calibrated.scheme, config.scheme);
        Ok(())
    }

    // bitsandbytes NF4: quantizer-tagged scheme BnB4bit, shape round-trips
    // through the quantizer's own dequantize.
    #[test]
    fn test_bnb_nf4_quantization() -> Result<()> {
        let tensor = Tensor::randn(&[128])?;
        let config = BnBConfig {
            quant_type: BnBQuantType::NF4,
            compute_dtype: BnBComputeType::Float16,
            use_double_quant: false,
            bnb_4bit_quant_storage: Some(BnBStorageType::UInt8),
        };

        let bnb = BnBQuantizer::new(config);
        let quantized = bnb.quantize(&tensor)?;
        assert_eq!(quantized.scheme, QuantizationScheme::BnB4bit);

        let dequantized = bnb.dequantize(&quantized)?;
        assert_eq!(dequantized.shape(), tensor.shape());
        Ok(())
    }

    // bitsandbytes FP4 variant maps to the distinct BnB4bitFP4 scheme.
    #[test]
    fn test_bnb_fp4_quantization() -> Result<()> {
        let tensor = Tensor::randn(&[128])?;
        let config = BnBConfig {
            quant_type: BnBQuantType::FP4,
            compute_dtype: BnBComputeType::Float16,
            use_double_quant: false,
            bnb_4bit_quant_storage: Some(BnBStorageType::UInt8),
        };

        let bnb = BnBQuantizer::new(config);
        let quantized = bnb.quantize(&tensor)?;
        assert_eq!(quantized.scheme, QuantizationScheme::BnB4bitFP4);

        let dequantized = bnb.dequantize(&quantized)?;
        assert_eq!(dequantized.shape(), tensor.shape());
        Ok(())
    }

    // bitsandbytes int8: note this path dequantizes via the generic
    // QuantizedTensor::dequantize, not the BnB quantizer.
    #[test]
    fn test_bnb_int8_quantization() -> Result<()> {
        let tensor = Tensor::randn(&[64, 64])?;
        let config = BnBConfig {
            quant_type: BnBQuantType::Int8,
            compute_dtype: BnBComputeType::Float32,
            use_double_quant: false,
            bnb_4bit_quant_storage: None,
        };

        let bnb = BnBQuantizer::new(config);
        let quantized = bnb.quantize(&tensor)?;
        assert_eq!(quantized.scheme, QuantizationScheme::BnB8bit);

        let dequantized = quantized.dequantize()?;
        assert_eq!(dequantized.shape(), tensor.shape());
        Ok(())
    }

    // Two consecutive QAT forward passes (observe + fake-quantize both on
    // by default) preserve the input shape.
    #[test]
    fn test_qat_fake_quantize() -> Result<()> {
        let tensor = Tensor::randn(&[32, 32])?;
        let config = QATConfig::default();
        let mut fake_quant = FakeQuantize::new(config);

        let result1 = fake_quant.forward(&tensor)?;
        assert_eq!(result1.shape(), tensor.shape());

        let result2 = fake_quant.forward(&tensor)?;
        assert_eq!(result2.shape(), tensor.shape());
        Ok(())
    }

    // Observer counts every element and symmetric params pin zero_point = 0.
    #[test]
    fn test_observer_statistics() -> Result<()> {
        let mut observer = Observer::new();
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5])?;

        observer.update(&tensor);
        assert_eq!(observer.count, 5);

        let (scale, zero_point) = observer.get_quantization_params(true, 8)?;
        assert!(scale > 0.0);
        assert_eq!(zero_point, 0); Ok(())
    }

    // BnBConfig survives a serde_json round trip with fields intact.
    #[test]
    fn test_bnb_config_serialization() -> Result<()> {
        let config = BnBConfig {
            quant_type: BnBQuantType::NF4,
            compute_dtype: BnBComputeType::Float16,
            use_double_quant: true,
            bnb_4bit_quant_storage: Some(BnBStorageType::UInt8),
        };

        let serialized = serde_json::to_string(&config)
            .map_err(|e| TrustformersError::serialization_error(e.to_string()))?;
        let deserialized: BnBConfig = serde_json::from_str(&serialized)
            .map_err(|e| TrustformersError::serialization_error(e.to_string()))?;

        assert_eq!(config.quant_type, deserialized.quant_type);
        assert_eq!(config.compute_dtype, deserialized.compute_dtype);
        assert_eq!(config.use_double_quant, deserialized.use_double_quant);
        Ok(())
    }
}