Skip to main content

trustformers_core/quantization/
base.rs

1#![allow(unused_variables)] // Quantization base implementation
2
3use crate::errors::{Result, TrustformersError};
4use crate::tensor::Tensor;
5use serde::{Deserialize, Serialize};
6
/// Quantization schemes supported by TrustformeRS
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationScheme {
    /// 8-bit integer quantization
    Int8,
    /// 4-bit integer quantization (weight-only)
    Int4,
    /// Dynamic quantization (runtime quantization)
    Dynamic,
    /// Dynamic 8-bit integer quantization (runtime quantization)
    DynamicINT8,
    /// GPTQ (Gradient-based Post-Training Quantization)
    GPTQ,
    /// AWQ (Activation-aware Weight Quantization)
    AWQ,
    /// BitsAndBytes 8-bit quantization
    BnB8bit,
    /// BitsAndBytes 4-bit NormalFloat (NF4) quantization
    BnB4bit,
    /// BitsAndBytes 4-bit float (FP4) quantization — dequantized via the
    /// FP4 codebook, not Float16
    BnB4bitFP4,
}
29
/// Quantization configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationConfig {
    /// Which quantization scheme to apply.
    pub scheme: QuantizationScheme,
    /// Symmetric (zero-centered) range vs asymmetric (min/max affine) range.
    pub symmetric: bool,
    /// One scale per channel/group instead of a single per-tensor scale.
    pub per_channel: bool,
    /// Max number of samples consumed by `Quantizer::calibrate`.
    pub calibration_samples: Option<usize>,
    pub group_size: Option<usize>,     // For grouped quantization
    pub bnb_config: Option<BnBConfig>, // BitsAndBytes specific configuration
}
40
/// BitsAndBytes quantization configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BnBConfig {
    /// Also quantize the quantization constants ("double quantization").
    pub use_double_quant: bool,
    /// Data type used for the quantized weights (NF4 / FP4 / Int8).
    pub quant_type: BnBQuantType,
    /// Floating-point type used for compute after dequantization.
    pub compute_dtype: BnBComputeType,
    /// Storage type for packed 4-bit weights; `None` for the 8-bit scheme.
    pub bnb_4bit_quant_storage: Option<BnBStorageType>,
}
49
/// BitsAndBytes weight data types.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BnBQuantType {
    NF4,  // NormalFloat 4-bit (non-uniform codebook)
    FP4,  // Float 4-bit (uniform-ish codebook)
    Int8, // Integer 8-bit
}
56
/// Floating-point type used for compute after BnB dequantization.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BnBComputeType {
    Float16,
    BFloat16,
    Float32,
}
63
/// Storage type for packed BnB 4-bit weights.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BnBStorageType {
    UInt8,
    Int8,
    Float16,
}
70
71impl Default for QuantizationConfig {
72    fn default() -> Self {
73        Self {
74            scheme: QuantizationScheme::Int8,
75            symmetric: true,
76            per_channel: false,
77            calibration_samples: Some(128),
78            group_size: Some(128),
79            bnb_config: None,
80        }
81    }
82}
83
84impl Default for BnBConfig {
85    fn default() -> Self {
86        Self {
87            use_double_quant: false,
88            quant_type: BnBQuantType::NF4,
89            compute_dtype: BnBComputeType::Float16,
90            bnb_4bit_quant_storage: Some(BnBStorageType::UInt8),
91        }
92    }
93}
94
/// Quantized tensor representation
///
/// `data` holds one value per byte for 8-bit schemes, or two 4-bit values
/// packed per byte (high nibble first, matching `Quantizer::pack_int4_values`)
/// for 4-bit schemes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizedTensor {
    /// Raw quantized payload (nibble-packed for 4-bit schemes).
    pub data: Vec<u8>,
    /// Dequantization scale(s): one entry per tensor, or one per channel/group.
    pub scale: Vec<f32>,
    /// Zero point(s) for affine schemes; length mirrors `scale`.
    pub zero_point: Vec<i32>,
    /// Logical shape of the original f32 tensor.
    pub shape: Vec<usize>,
    /// Scheme that produced `data`.
    pub scheme: QuantizationScheme,
    /// Whether `scale`/`zero_point` are per-channel/group rather than per-tensor.
    pub per_channel: bool,
}
105
106impl QuantizedTensor {
107    /// Create a new quantized tensor
108    pub fn new(
109        data: Vec<u8>,
110        scale: Vec<f32>,
111        zero_point: Vec<i32>,
112        shape: Vec<usize>,
113        scheme: QuantizationScheme,
114        per_channel: bool,
115    ) -> Self {
116        Self {
117            data,
118            scale,
119            zero_point,
120            shape,
121            scheme,
122            per_channel,
123        }
124    }
125
126    /// Dequantize back to f32 tensor
127    pub fn dequantize(&self) -> Result<Tensor> {
128        let total_elements: usize = self.shape.iter().product();
129        let mut result = Vec::with_capacity(total_elements);
130
131        match self.scheme {
132            QuantizationScheme::Int8 | QuantizationScheme::BnB8bit => {
133                if self.per_channel {
134                    self.dequantize_per_channel_int8(&mut result)?;
135                } else {
136                    self.dequantize_per_tensor_int8(&mut result)?;
137                }
138            },
139            QuantizationScheme::Int4 => {
140                if self.per_channel {
141                    self.dequantize_per_channel_int4(&mut result)?;
142                } else {
143                    self.dequantize_per_tensor_int4(&mut result)?;
144                }
145            },
146            QuantizationScheme::Dynamic | QuantizationScheme::DynamicINT8 => {
147                // Dynamic quantization uses runtime statistics
148                if self.per_channel {
149                    self.dequantize_per_channel_int8(&mut result)?;
150                } else {
151                    self.dequantize_per_tensor_int8(&mut result)?;
152                }
153            },
154            QuantizationScheme::GPTQ => {
155                // GPTQ (Gradient-based Post-Training Quantization)
156                // Uses optimized quantization with gradient information
157                if self.per_channel {
158                    self.dequantize_gptq_per_channel(&mut result)?;
159                } else {
160                    self.dequantize_gptq_per_tensor(&mut result)?;
161                }
162            },
163            QuantizationScheme::AWQ => {
164                // AWQ (Activation-aware Weight Quantization)
165                // Uses activation statistics for better quantization
166                if self.per_channel {
167                    self.dequantize_awq_per_channel(&mut result)?;
168                } else {
169                    self.dequantize_awq_per_tensor(&mut result)?;
170                }
171            },
172            QuantizationScheme::BnB4bit => {
173                // BitsAndBytes 4-bit NormalFloat quantization
174                if self.per_channel {
175                    self.dequantize_bnb_4bit_per_channel(&mut result)?;
176                } else {
177                    self.dequantize_bnb_4bit_per_tensor(&mut result)?;
178                }
179            },
180            QuantizationScheme::BnB4bitFP4 => {
181                // BitsAndBytes 4-bit Float16 quantization
182                if self.per_channel {
183                    self.dequantize_bnb_fp4_per_channel(&mut result)?;
184                } else {
185                    self.dequantize_bnb_fp4_per_tensor(&mut result)?;
186                }
187            },
188        }
189
190        Tensor::from_vec(result, &self.shape)
191    }
192
193    fn dequantize_per_tensor_int8(&self, result: &mut Vec<f32>) -> Result<()> {
194        if self.scale.len() != 1 || self.zero_point.len() != 1 {
195            return Err(TrustformersError::quantization_error(
196                "Per-tensor quantization requires single scale and zero point".into(),
197            ));
198        }
199
200        let scale = self.scale[0];
201        let zero_point = self.zero_point[0];
202
203        for &quantized_val in &self.data {
204            let int_val = quantized_val as i32 - zero_point;
205            let float_val = int_val as f32 * scale;
206            result.push(float_val);
207        }
208
209        Ok(())
210    }
211
212    fn dequantize_per_channel_int8(&self, result: &mut Vec<f32>) -> Result<()> {
213        let channels = self.scale.len();
214        let elements_per_channel = self.data.len() / channels;
215
216        for (channel_idx, (&scale, &zero_point)) in
217            self.scale.iter().zip(&self.zero_point).enumerate()
218        {
219            let start_idx = channel_idx * elements_per_channel;
220            let end_idx = start_idx + elements_per_channel;
221
222            for &quantized_val in &self.data[start_idx..end_idx] {
223                let int_val = quantized_val as i32 - zero_point;
224                let float_val = int_val as f32 * scale;
225                result.push(float_val);
226            }
227        }
228
229        Ok(())
230    }
231
232    fn dequantize_per_tensor_int4(&self, result: &mut Vec<f32>) -> Result<()> {
233        if self.scale.len() != 1 || self.zero_point.len() != 1 {
234            return Err(TrustformersError::quantization_error(
235                "Per-tensor quantization requires single scale and zero point".into(),
236            ));
237        }
238
239        let scale = self.scale[0];
240        let zero_point = self.zero_point[0];
241
242        // Each byte contains 2 4-bit values
243        for &byte in &self.data {
244            // Extract high 4 bits
245            let high_nibble = (byte >> 4) as i32 - zero_point;
246            let high_val = high_nibble as f32 * scale;
247            result.push(high_val);
248
249            // Extract low 4 bits
250            let low_nibble = (byte & 0x0F) as i32 - zero_point;
251            let low_val = low_nibble as f32 * scale;
252            result.push(low_val);
253        }
254
255        Ok(())
256    }
257
258    fn dequantize_per_channel_int4(&self, result: &mut Vec<f32>) -> Result<()> {
259        let channels = self.scale.len();
260        let bytes_per_channel = self.data.len() / channels;
261
262        for (channel_idx, (&scale, &zero_point)) in
263            self.scale.iter().zip(&self.zero_point).enumerate()
264        {
265            let start_idx = channel_idx * bytes_per_channel;
266            let end_idx = start_idx + bytes_per_channel;
267
268            for &byte in &self.data[start_idx..end_idx] {
269                // Extract high 4 bits
270                let high_nibble = (byte >> 4) as i32 - zero_point;
271                let high_val = high_nibble as f32 * scale;
272                result.push(high_val);
273
274                // Extract low 4 bits
275                let low_nibble = (byte & 0x0F) as i32 - zero_point;
276                let low_val = low_nibble as f32 * scale;
277                result.push(low_val);
278            }
279        }
280
281        Ok(())
282    }
283
284    /// GPTQ dequantization per tensor
285    fn dequantize_gptq_per_tensor(&self, result: &mut Vec<f32>) -> Result<()> {
286        if self.scale.len() != 1 || self.zero_point.len() != 1 {
287            return Err(TrustformersError::quantization_error(
288                "GPTQ per-tensor quantization requires single scale and zero point".into(),
289            ));
290        }
291
292        let scale = self.scale[0];
293        let zero_point = self.zero_point[0];
294
295        // GPTQ uses optimized quantization with gradient information
296        // Similar to int8 but with better error compensation
297        for &quantized_val in &self.data {
298            let int_val = quantized_val as i32 - zero_point;
299            let float_val = int_val as f32 * scale;
300            result.push(float_val);
301        }
302
303        Ok(())
304    }
305
306    /// GPTQ dequantization per channel
307    fn dequantize_gptq_per_channel(&self, result: &mut Vec<f32>) -> Result<()> {
308        let channels = self.scale.len();
309        let elements_per_channel = self.data.len() / channels;
310
311        for (channel_idx, (&scale, &zero_point)) in
312            self.scale.iter().zip(&self.zero_point).enumerate()
313        {
314            let start_idx = channel_idx * elements_per_channel;
315            let end_idx = start_idx + elements_per_channel;
316
317            for &quantized_val in &self.data[start_idx..end_idx] {
318                let int_val = quantized_val as i32 - zero_point;
319                let float_val = int_val as f32 * scale;
320                result.push(float_val);
321            }
322        }
323
324        Ok(())
325    }
326
327    /// AWQ dequantization per tensor
328    fn dequantize_awq_per_tensor(&self, result: &mut Vec<f32>) -> Result<()> {
329        if self.scale.len() != 1 || self.zero_point.len() != 1 {
330            return Err(TrustformersError::quantization_error(
331                "AWQ per-tensor quantization requires single scale and zero point".into(),
332            ));
333        }
334
335        let scale = self.scale[0];
336        let zero_point = self.zero_point[0];
337
338        // AWQ uses activation-aware weight quantization
339        // Similar to int8 but optimized for specific activation patterns
340        for &quantized_val in &self.data {
341            let int_val = quantized_val as i32 - zero_point;
342            let float_val = int_val as f32 * scale;
343            result.push(float_val);
344        }
345
346        Ok(())
347    }
348
349    /// AWQ dequantization per channel
350    fn dequantize_awq_per_channel(&self, result: &mut Vec<f32>) -> Result<()> {
351        let channels = self.scale.len();
352        let elements_per_channel = self.data.len() / channels;
353
354        for (channel_idx, (&scale, &zero_point)) in
355            self.scale.iter().zip(&self.zero_point).enumerate()
356        {
357            let start_idx = channel_idx * elements_per_channel;
358            let end_idx = start_idx + elements_per_channel;
359
360            for &quantized_val in &self.data[start_idx..end_idx] {
361                let int_val = quantized_val as i32 - zero_point;
362                let float_val = int_val as f32 * scale;
363                result.push(float_val);
364            }
365        }
366
367        Ok(())
368    }
369
370    /// BitsAndBytes 4-bit NormalFloat dequantization per tensor
371    fn dequantize_bnb_4bit_per_tensor(&self, result: &mut Vec<f32>) -> Result<()> {
372        // NF4 (NormalFloat 4-bit) uses a non-uniform quantization table
373        // optimized for normal distributions of weights
374        const NF4_LEVELS: [f32; 16] = [
375            -1.0,
376            -0.6961928009986877,
377            -0.5250730514526367,
378            -0.39491748809814453,
379            -0.28444138169288635,
380            -0.18477343022823334,
381            -0.09105003625154495,
382            0.0,
383            0.07958029955625534,
384            0.16093020141124725,
385            0.24611230194568634,
386            0.33791524171829224,
387            0.44070982933044434,
388            0.5626170039176941,
389            0.7229568362236023,
390            1.0,
391        ];
392
393        if self.scale.len() != 1 {
394            return Err(TrustformersError::quantization_error(
395                "BnB 4-bit per-tensor quantization requires single scale".into(),
396            ));
397        }
398
399        let scale = self.scale[0];
400
401        for &byte in &self.data {
402            // Extract high 4 bits
403            let high_nibble = (byte >> 4) & 0x0F;
404            let high_val = NF4_LEVELS[high_nibble as usize] * scale;
405            result.push(high_val);
406
407            // Extract low 4 bits
408            let low_nibble = byte & 0x0F;
409            let low_val = NF4_LEVELS[low_nibble as usize] * scale;
410            result.push(low_val);
411        }
412
413        Ok(())
414    }
415
416    /// BitsAndBytes 4-bit NormalFloat dequantization per channel
417    fn dequantize_bnb_4bit_per_channel(&self, result: &mut Vec<f32>) -> Result<()> {
418        const NF4_LEVELS: [f32; 16] = [
419            -1.0,
420            -0.6961928009986877,
421            -0.5250730514526367,
422            -0.39491748809814453,
423            -0.28444138169288635,
424            -0.18477343022823334,
425            -0.09105003625154495,
426            0.0,
427            0.07958029955625534,
428            0.16093020141124725,
429            0.24611230194568634,
430            0.33791524171829224,
431            0.44070982933044434,
432            0.5626170039176941,
433            0.7229568362236023,
434            1.0,
435        ];
436
437        let channels = self.scale.len();
438        let bytes_per_channel = self.data.len() / channels;
439
440        for (channel_idx, &scale) in self.scale.iter().enumerate() {
441            let start_idx = channel_idx * bytes_per_channel;
442            let end_idx = start_idx + bytes_per_channel;
443
444            for &byte in &self.data[start_idx..end_idx] {
445                // Extract high 4 bits
446                let high_nibble = (byte >> 4) & 0x0F;
447                let high_val = NF4_LEVELS[high_nibble as usize] * scale;
448                result.push(high_val);
449
450                // Extract low 4 bits
451                let low_nibble = byte & 0x0F;
452                let low_val = NF4_LEVELS[low_nibble as usize] * scale;
453                result.push(low_val);
454            }
455        }
456
457        Ok(())
458    }
459
460    /// BitsAndBytes 4-bit Float16 dequantization per tensor
461    fn dequantize_bnb_fp4_per_tensor(&self, result: &mut Vec<f32>) -> Result<()> {
462        // FP4 uses a uniform quantization table for Float16 values
463        const FP4_LEVELS: [f32; 16] = [
464            -12.0, -8.0, -6.0, -4.0, -3.0, -2.0, -1.5, -1.0, 0.0, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, 8.0,
465        ];
466
467        if self.scale.len() != 1 {
468            return Err(TrustformersError::quantization_error(
469                "BnB FP4 per-tensor quantization requires single scale".into(),
470            ));
471        }
472
473        let scale = self.scale[0];
474
475        for &byte in &self.data {
476            // Extract high 4 bits
477            let high_nibble = (byte >> 4) & 0x0F;
478            let high_val = FP4_LEVELS[high_nibble as usize] * scale;
479            result.push(high_val);
480
481            // Extract low 4 bits
482            let low_nibble = byte & 0x0F;
483            let low_val = FP4_LEVELS[low_nibble as usize] * scale;
484            result.push(low_val);
485        }
486
487        Ok(())
488    }
489
490    /// BitsAndBytes 4-bit Float16 dequantization per channel
491    fn dequantize_bnb_fp4_per_channel(&self, result: &mut Vec<f32>) -> Result<()> {
492        const FP4_LEVELS: [f32; 16] = [
493            -12.0, -8.0, -6.0, -4.0, -3.0, -2.0, -1.5, -1.0, 0.0, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, 8.0,
494        ];
495
496        let channels = self.scale.len();
497        let bytes_per_channel = self.data.len() / channels;
498
499        for (channel_idx, &scale) in self.scale.iter().enumerate() {
500            let start_idx = channel_idx * bytes_per_channel;
501            let end_idx = start_idx + bytes_per_channel;
502
503            for &byte in &self.data[start_idx..end_idx] {
504                // Extract high 4 bits
505                let high_nibble = (byte >> 4) & 0x0F;
506                let high_val = FP4_LEVELS[high_nibble as usize] * scale;
507                result.push(high_val);
508
509                // Extract low 4 bits
510                let low_nibble = byte & 0x0F;
511                let low_val = FP4_LEVELS[low_nibble as usize] * scale;
512                result.push(low_val);
513            }
514        }
515
516        Ok(())
517    }
518}
519
/// Quantization utilities
///
/// Stateless namespace: all functionality is exposed as associated
/// functions (`Quantizer::quantize`, `Quantizer::calibrate`).
pub struct Quantizer;
522
523impl Quantizer {
524    /// Quantize a tensor according to the given configuration
525    pub fn quantize(tensor: &Tensor, config: &QuantizationConfig) -> Result<QuantizedTensor> {
526        match config.scheme {
527            QuantizationScheme::Int8 => {
528                if config.per_channel {
529                    Self::quantize_per_channel_int8(tensor, config.symmetric)
530                } else {
531                    Self::quantize_per_tensor_int8(tensor, config.symmetric)
532                }
533            },
534            QuantizationScheme::Int4 => {
535                if config.per_channel {
536                    Self::quantize_per_channel_int4(tensor, config.symmetric, config.group_size)
537                } else {
538                    Self::quantize_per_tensor_int4(tensor, config.symmetric)
539                }
540            },
541            QuantizationScheme::Dynamic => Self::dynamic_quantize(tensor),
542            QuantizationScheme::DynamicINT8 => {
543                // Dynamic INT8 quantization - same as Dynamic but explicitly 8-bit
544                Self::dynamic_quantize(tensor)
545            },
546            QuantizationScheme::GPTQ => {
547                // GPTQ (Gradient-based Post-Training Quantization)
548                // For now, use standard INT4 quantization with optimized settings
549                // Full GPTQ implementation would require Hessian computation
550                if config.per_channel {
551                    Self::quantize_per_channel_int4(tensor, true, config.group_size)
552                } else {
553                    Self::quantize_per_tensor_int4(tensor, true)
554                }
555            },
556            QuantizationScheme::AWQ => {
557                // AWQ (Activation-aware Weight Quantization)
558                // For now, use standard INT4 quantization with symmetric mode
559                // Full AWQ implementation would use activation statistics
560                if config.per_channel {
561                    Self::quantize_per_channel_int4(tensor, true, config.group_size)
562                } else {
563                    Self::quantize_per_tensor_int4(tensor, true)
564                }
565            },
566            QuantizationScheme::BnB8bit => {
567                // BitsAndBytes 8-bit quantization
568                let bnb_config = config.bnb_config.clone().unwrap_or(BnBConfig {
569                    use_double_quant: false,
570                    quant_type: BnBQuantType::Int8,
571                    compute_dtype: BnBComputeType::Float16,
572                    bnb_4bit_quant_storage: None,
573                });
574                let quantizer = BnBQuantizer::new(bnb_config);
575                quantizer.quantize_bnb_int8(tensor)
576            },
577            QuantizationScheme::BnB4bit => {
578                // BitsAndBytes 4-bit NormalFloat quantization
579                let bnb_config = config.bnb_config.clone().unwrap_or(BnBConfig {
580                    use_double_quant: false,
581                    quant_type: BnBQuantType::NF4,
582                    compute_dtype: BnBComputeType::Float16,
583                    bnb_4bit_quant_storage: Some(BnBStorageType::UInt8),
584                });
585                let quantizer = BnBQuantizer::new(bnb_config);
586                quantizer.quantize_nf4(tensor)
587            },
588            QuantizationScheme::BnB4bitFP4 => {
589                // BitsAndBytes 4-bit Float16 quantization
590                let bnb_config = config.bnb_config.clone().unwrap_or(BnBConfig {
591                    use_double_quant: false,
592                    quant_type: BnBQuantType::FP4,
593                    compute_dtype: BnBComputeType::Float16,
594                    bnb_4bit_quant_storage: Some(BnBStorageType::UInt8),
595                });
596                let quantizer = BnBQuantizer::new(bnb_config);
597                quantizer.quantize_fp4(tensor)
598            },
599        }
600    }
601
602    /// Per-tensor 8-bit quantization
603    fn quantize_per_tensor_int8(tensor: &Tensor, symmetric: bool) -> Result<QuantizedTensor> {
604        match tensor {
605            Tensor::F32(arr) => {
606                let data = arr.iter().cloned().collect::<Vec<f32>>();
607                let (scale, zero_point) = Self::compute_quantization_params(&data, symmetric, 8)?;
608
609                let quantized_data: Vec<u8> = data
610                    .iter()
611                    .map(|&val| Self::quantize_value_int8(val, scale, zero_point))
612                    .collect();
613
614                Ok(QuantizedTensor::new(
615                    quantized_data,
616                    vec![scale],
617                    vec![zero_point],
618                    arr.shape().to_vec(),
619                    QuantizationScheme::Int8,
620                    false,
621                ))
622            },
623            _ => Err(TrustformersError::quantization_error(
624                "Unsupported tensor type for quantization".into(),
625            )),
626        }
627    }
628
629    /// Per-channel 8-bit quantization
630    fn quantize_per_channel_int8(tensor: &Tensor, symmetric: bool) -> Result<QuantizedTensor> {
631        match tensor {
632            Tensor::F32(arr) => {
633                let shape = arr.shape();
634                let channels = shape[0]; // Assume first dimension is channels
635                let elements_per_channel = arr.len() / channels;
636
637                let mut scales = Vec::with_capacity(channels);
638                let mut zero_points = Vec::with_capacity(channels);
639                let mut quantized_data = Vec::with_capacity(arr.len());
640
641                for channel in 0..channels {
642                    let start_idx = channel * elements_per_channel;
643                    let end_idx = start_idx + elements_per_channel;
644                    let channel_data = arr
645                        .iter()
646                        .skip(start_idx)
647                        .take(elements_per_channel)
648                        .cloned()
649                        .collect::<Vec<f32>>();
650
651                    let (scale, zero_point) =
652                        Self::compute_quantization_params(&channel_data, symmetric, 8)?;
653                    scales.push(scale);
654                    zero_points.push(zero_point);
655
656                    let channel_quantized: Vec<u8> = channel_data
657                        .iter()
658                        .map(|&val| Self::quantize_value_int8(val, scale, zero_point))
659                        .collect();
660
661                    quantized_data.extend(channel_quantized);
662                }
663
664                Ok(QuantizedTensor::new(
665                    quantized_data,
666                    scales,
667                    zero_points,
668                    shape.to_vec(),
669                    QuantizationScheme::Int8,
670                    true,
671                ))
672            },
673            _ => Err(TrustformersError::quantization_error(
674                "Unsupported tensor type for quantization".into(),
675            )),
676        }
677    }
678
679    /// Per-tensor 4-bit quantization (weight-only)
680    fn quantize_per_tensor_int4(tensor: &Tensor, symmetric: bool) -> Result<QuantizedTensor> {
681        match tensor {
682            Tensor::F32(arr) => {
683                let data = arr.iter().cloned().collect::<Vec<f32>>();
684                let (scale, zero_point) = Self::compute_quantization_params(&data, symmetric, 4)?;
685
686                let quantized_data = Self::pack_int4_values(&data, scale, zero_point)?;
687
688                Ok(QuantizedTensor::new(
689                    quantized_data,
690                    vec![scale],
691                    vec![zero_point],
692                    arr.shape().to_vec(),
693                    QuantizationScheme::Int4,
694                    false,
695                ))
696            },
697            _ => Err(TrustformersError::quantization_error(
698                "Unsupported tensor type for quantization".into(),
699            )),
700        }
701    }
702
703    /// Per-channel 4-bit quantization with optional grouping
704    fn quantize_per_channel_int4(
705        tensor: &Tensor,
706        symmetric: bool,
707        group_size: Option<usize>,
708    ) -> Result<QuantizedTensor> {
709        match tensor {
710            Tensor::F32(arr) => {
711                let shape = arr.shape();
712                let total_elements = arr.len();
713                let group_size = group_size.unwrap_or(128);
714                let num_groups = total_elements.div_ceil(group_size);
715
716                let mut scales = Vec::with_capacity(num_groups);
717                let mut zero_points = Vec::with_capacity(num_groups);
718                let mut quantized_data = Vec::with_capacity(total_elements / 2); // 4-bit packing
719
720                for group_idx in 0..num_groups {
721                    let start_idx = group_idx * group_size;
722                    let end_idx = (start_idx + group_size).min(total_elements);
723
724                    let group_data = arr
725                        .iter()
726                        .skip(start_idx)
727                        .take(end_idx - start_idx)
728                        .cloned()
729                        .collect::<Vec<f32>>();
730
731                    let (scale, zero_point) =
732                        Self::compute_quantization_params(&group_data, symmetric, 4)?;
733                    scales.push(scale);
734                    zero_points.push(zero_point);
735
736                    let group_quantized = Self::pack_int4_values(&group_data, scale, zero_point)?;
737                    quantized_data.extend(group_quantized);
738                }
739
740                Ok(QuantizedTensor::new(
741                    quantized_data,
742                    scales,
743                    zero_points,
744                    shape.to_vec(),
745                    QuantizationScheme::Int4,
746                    true,
747                ))
748            },
749            _ => Err(TrustformersError::quantization_error(
750                "Unsupported tensor type for quantization".into(),
751            )),
752        }
753    }
754
755    /// Dynamic quantization (quantize at runtime)
756    fn dynamic_quantize(tensor: &Tensor) -> Result<QuantizedTensor> {
757        // For dynamic quantization, we quantize to int8 per-tensor
758        Self::quantize_per_tensor_int8(tensor, false)
759    }
760
761    /// Compute quantization parameters (scale and zero point)
762    fn compute_quantization_params(data: &[f32], symmetric: bool, bits: u8) -> Result<(f32, i32)> {
763        if data.is_empty() {
764            return Err(TrustformersError::quantization_error(
765                "Cannot quantize empty data".into(),
766            ));
767        }
768
769        let min_val = data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
770        let max_val = data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
771
772        let q_min = if symmetric { -(1 << (bits - 1)) } else { 0 };
773        let q_max = if symmetric { (1 << (bits - 1)) - 1 } else { (1 << bits) - 1 };
774
775        let (scale, zero_point) = if symmetric {
776            let abs_max = max_val.abs().max(min_val.abs());
777            let scale = abs_max / (q_max - q_min) as f32;
778            (scale, 0)
779        } else {
780            let scale = (max_val - min_val) / (q_max - q_min) as f32;
781            let zero_point = q_min - (min_val / scale).round() as i32;
782            let zero_point = zero_point.clamp(q_min, q_max);
783            (scale, zero_point)
784        };
785
786        Ok((scale, zero_point))
787    }
788
789    /// Quantize a single float value to int8
790    fn quantize_value_int8(value: f32, scale: f32, zero_point: i32) -> u8 {
791        let quantized = (value / scale).round() as i32 + zero_point;
792        quantized.clamp(0, 255) as u8
793    }
794
795    /// Pack float values into 4-bit representation
796    fn pack_int4_values(data: &[f32], scale: f32, zero_point: i32) -> Result<Vec<u8>> {
797        let mut packed = Vec::with_capacity(data.len().div_ceil(2));
798
799        for chunk in data.chunks(2) {
800            let val1 = Self::quantize_value_int4(chunk[0], scale, zero_point);
801            let val2 = if chunk.len() > 1 {
802                Self::quantize_value_int4(chunk[1], scale, zero_point)
803            } else {
804                0 // Pad with zero
805            };
806
807            // Pack two 4-bit values into one byte
808            let packed_byte = (val1 << 4) | val2;
809            packed.push(packed_byte);
810        }
811
812        Ok(packed)
813    }
814
815    /// Quantize a single float value to int4
816    fn quantize_value_int4(value: f32, scale: f32, zero_point: i32) -> u8 {
817        let quantized = (value / scale).round() as i32 + zero_point;
818        quantized.clamp(0, 15) as u8
819    }
820
821    /// Calibrate quantization parameters using sample data
822    pub fn calibrate(
823        samples: &[Tensor],
824        config: &QuantizationConfig,
825    ) -> Result<QuantizationConfig> {
826        // This is a simplified calibration - in practice, you'd run the model
827        // on representative data and collect activation statistics
828        let mut calibrated_config = config.clone();
829
830        if let Some(sample_count) = config.calibration_samples {
831            let num_samples = samples.len().min(sample_count);
832
833            // Collect statistics from samples
834            let mut all_values = Vec::new();
835            for sample in samples.iter().take(num_samples) {
836                match sample {
837                    Tensor::F32(arr) => {
838                        all_values.extend(arr.iter().cloned());
839                    },
840                    _ => continue,
841                }
842            }
843
844            if !all_values.is_empty() {
845                // Update configuration based on calibration data
846                let abs_max = all_values.iter().fold(0.0f32, |acc, &x| acc.max(x.abs()));
847
848                // Adjust symmetric flag based on data distribution
849                let min_val = all_values.iter().fold(f32::INFINITY, |a, &b| a.min(b));
850                let max_val = all_values.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
851
852                calibrated_config.symmetric =
853                    (min_val.abs() - max_val.abs()).abs() / max_val.abs() < 0.1;
854            }
855        }
856
857        Ok(calibrated_config)
858    }
859}
860
/// GPTQ (Gradient-based Post-Training Quantization) implementation
pub struct GPTQQuantizer {
    /// Quantization settings applied when `quantize` is called.
    config: QuantizationConfig,
}
865
impl GPTQQuantizer {
    /// Create a quantizer that applies `config` to every tensor.
    pub fn new(config: QuantizationConfig) -> Self {
        Self { config }
    }

    /// Apply GPTQ quantization to a tensor
    /// This is a simplified version - full GPTQ requires Hessian computation
    ///
    /// NOTE(review): the `hessian` argument is currently ignored — this method
    /// simply delegates to the standard `Quantizer::quantize` path.
    pub fn quantize(&self, tensor: &Tensor, hessian: Option<&Tensor>) -> Result<QuantizedTensor> {
        // For now, fall back to standard quantization
        // In a full implementation, this would use the Hessian to minimize quantization error
        Quantizer::quantize(tensor, &self.config)
    }
}
879
/// AWQ (Activation-aware Weight Quantization) implementation
pub struct AWQQuantizer {
    /// Quantization settings applied when `quantize` is called.
    config: QuantizationConfig,
    /// Optional activation scales supplied via `set_activation_scales`;
    /// stored but not yet consumed by `quantize`.
    activation_scales: Option<Vec<f32>>,
}
885
impl AWQQuantizer {
    /// Create an AWQ quantizer with no activation scales set.
    pub fn new(config: QuantizationConfig) -> Self {
        Self {
            config,
            activation_scales: None,
        }
    }

    /// Set activation scales for weight quantization
    pub fn set_activation_scales(&mut self, scales: Vec<f32>) {
        self.activation_scales = Some(scales);
    }

    /// Apply AWQ quantization to a tensor
    ///
    /// NOTE(review): `activation_scales` is currently unused — this method
    /// simply delegates to the standard `Quantizer::quantize` path.
    pub fn quantize(&self, tensor: &Tensor) -> Result<QuantizedTensor> {
        // For now, fall back to standard quantization
        // In a full implementation, this would use activation scales to improve quantization
        Quantizer::quantize(tensor, &self.config)
    }
}
906
/// BitsAndBytes quantizer implementation
pub struct BnBQuantizer {
    /// BitsAndBytes-specific settings (quant type, compute dtype, storage).
    config: BnBConfig,
}
911
912impl BnBQuantizer {
    /// Create a quantizer from a BitsAndBytes configuration.
    pub fn new(config: BnBConfig) -> Self {
        Self { config }
    }
916
    /// Quantize tensor using BitsAndBytes method
    ///
    /// Dispatches on `quant_type`: NF4 and FP4 use 64-element block-wise
    /// 4-bit codes; Int8 uses per-tensor asymmetric 8-bit quantization.
    pub fn quantize(&self, tensor: &Tensor) -> Result<QuantizedTensor> {
        match self.config.quant_type {
            BnBQuantType::NF4 => self.quantize_nf4(tensor),
            BnBQuantType::FP4 => self.quantize_fp4(tensor),
            BnBQuantType::Int8 => self.quantize_bnb_int8(tensor),
        }
    }
925
926    /// NormalFloat 4-bit quantization (BitsAndBytes NF4)
927    fn quantize_nf4(&self, tensor: &Tensor) -> Result<QuantizedTensor> {
928        match tensor {
929            Tensor::F32(arr) => {
930                let data = arr.iter().cloned().collect::<Vec<f32>>();
931                let block_size = 64; // Standard block size for NF4
932
933                let mut quantized_data = Vec::new();
934                let mut scales = Vec::new();
935                let mut zero_points = Vec::new();
936
937                for chunk in data.chunks(block_size) {
938                    let (block_scale, block_quantized) = self.nf4_quantize_block(chunk)?;
939                    scales.push(block_scale);
940                    zero_points.push(0); // NF4 is symmetric
941                    quantized_data.extend(block_quantized);
942                }
943
944                Ok(QuantizedTensor::new(
945                    quantized_data,
946                    scales,
947                    zero_points,
948                    arr.shape().to_vec(),
949                    QuantizationScheme::BnB4bit,
950                    false,
951                ))
952            },
953            _ => Err(TrustformersError::quantization_error(
954                "Unsupported tensor type for BnB NF4".into(),
955            )),
956        }
957    }
958
959    /// Float 4-bit quantization (BitsAndBytes FP4)
960    fn quantize_fp4(&self, tensor: &Tensor) -> Result<QuantizedTensor> {
961        match tensor {
962            Tensor::F32(arr) => {
963                let data = arr.iter().cloned().collect::<Vec<f32>>();
964                let block_size = 64;
965
966                let mut quantized_data = Vec::new();
967                let mut scales = Vec::new();
968                let mut zero_points = Vec::new();
969
970                for chunk in data.chunks(block_size) {
971                    let (block_scale, block_quantized) = self.fp4_quantize_block(chunk)?;
972                    scales.push(block_scale);
973                    zero_points.push(0); // FP4 is symmetric
974                    quantized_data.extend(block_quantized);
975                }
976
977                Ok(QuantizedTensor::new(
978                    quantized_data,
979                    scales,
980                    zero_points,
981                    arr.shape().to_vec(),
982                    QuantizationScheme::BnB4bitFP4,
983                    false,
984                ))
985            },
986            _ => Err(TrustformersError::quantization_error(
987                "Unsupported tensor type for BnB FP4".into(),
988            )),
989        }
990    }
991
992    /// BitsAndBytes 8-bit integer quantization
993    fn quantize_bnb_int8(&self, tensor: &Tensor) -> Result<QuantizedTensor> {
994        match tensor {
995            Tensor::F32(arr) => {
996                let data = arr.iter().cloned().collect::<Vec<f32>>();
997                let (scale, zero_point) = Quantizer::compute_quantization_params(&data, false, 8)?;
998
999                let quantized_data: Vec<u8> = data
1000                    .iter()
1001                    .map(|&val| Quantizer::quantize_value_int8(val, scale, zero_point))
1002                    .collect();
1003
1004                Ok(QuantizedTensor::new(
1005                    quantized_data,
1006                    vec![scale],
1007                    vec![zero_point],
1008                    arr.shape().to_vec(),
1009                    QuantizationScheme::BnB8bit,
1010                    false,
1011                ))
1012            },
1013            _ => Err(TrustformersError::quantization_error(
1014                "Unsupported tensor type for BnB Int8".into(),
1015            )),
1016        }
1017    }
1018
1019    /// NF4 block quantization with predefined quantization levels
1020    fn nf4_quantize_block(&self, data: &[f32]) -> Result<(f32, Vec<u8>)> {
1021        // NF4 quantization levels (based on normal distribution)
1022        const NF4_LEVELS: [f32; 16] = [
1023            -1.0,
1024            -0.6961928009986877,
1025            -0.5250730514526367,
1026            -0.39491748809814453,
1027            -0.28444138169288635,
1028            -0.18477343022823334,
1029            -0.09105003625154495,
1030            0.0,
1031            0.07958029955625534,
1032            0.16093020141124725,
1033            0.24611230194568634,
1034            0.33791524171829224,
1035            0.44070982933044434,
1036            0.5626170039176941,
1037            0.7229568362236023,
1038            1.0,
1039        ];
1040
1041        if data.is_empty() {
1042            return Err(TrustformersError::quantization_error(
1043                "Cannot quantize empty block".into(),
1044            ));
1045        }
1046
1047        // Compute block scale
1048        let abs_max = data.iter().fold(0.0f32, |acc, &x| acc.max(x.abs()));
1049        let scale = abs_max;
1050
1051        if scale == 0.0 {
1052            return Ok((scale, vec![0; data.len()]));
1053        }
1054
1055        // Quantize each value to nearest NF4 level
1056        let mut quantized = Vec::with_capacity(data.len());
1057        for &val in data {
1058            let normalized = val / scale;
1059            let mut best_idx = 0;
1060            let mut best_dist = (normalized - NF4_LEVELS[0]).abs();
1061
1062            for (idx, &level) in NF4_LEVELS.iter().enumerate().skip(1) {
1063                let dist = (normalized - level).abs();
1064                if dist < best_dist {
1065                    best_dist = dist;
1066                    best_idx = idx;
1067                }
1068            }
1069
1070            quantized.push(best_idx as u8);
1071        }
1072
1073        Ok((scale, quantized))
1074    }
1075
1076    /// FP4 block quantization with floating-point levels
1077    fn fp4_quantize_block(&self, data: &[f32]) -> Result<(f32, Vec<u8>)> {
1078        // FP4 quantization levels (exponential distribution)
1079        const FP4_LEVELS: [f32; 16] = [
1080            0.0, 0.0625, 0.125, 0.1875, 0.25, 0.3125, 0.375, 0.4375, 0.5, 0.625, 0.75, 0.875, 1.0,
1081            1.25, 1.5, 2.0,
1082        ];
1083
1084        if data.is_empty() {
1085            return Err(TrustformersError::quantization_error(
1086                "Cannot quantize empty block".into(),
1087            ));
1088        }
1089
1090        // Compute block scale
1091        let abs_max = data.iter().fold(0.0f32, |acc, &x| acc.max(x.abs()));
1092        let scale = abs_max / 2.0; // Scale for signed values
1093
1094        if scale == 0.0 {
1095            return Ok((scale, vec![0; data.len()]));
1096        }
1097
1098        // Quantize each value
1099        let mut quantized = Vec::with_capacity(data.len());
1100        for &val in data {
1101            let abs_val = val.abs() / scale;
1102            let sign = if val >= 0.0 { 0 } else { 8 }; // Use bit 3 for sign
1103
1104            let mut best_idx = 0;
1105            let mut best_dist = (abs_val - FP4_LEVELS[0]).abs();
1106
1107            for (idx, &level) in FP4_LEVELS[..8].iter().enumerate().skip(1) {
1108                let dist = (abs_val - level).abs();
1109                if dist < best_dist {
1110                    best_dist = dist;
1111                    best_idx = idx;
1112                }
1113            }
1114
1115            quantized.push((sign | best_idx) as u8);
1116        }
1117
1118        Ok((scale, quantized))
1119    }
1120
    /// Dequantize BitsAndBytes tensor
    ///
    /// Routes to the scheme-specific decoder; BnB8bit reuses the standard
    /// affine dequantization on `QuantizedTensor` itself.
    ///
    /// # Errors
    /// Returns an error for schemes this quantizer does not produce.
    pub fn dequantize(&self, tensor: &QuantizedTensor) -> Result<Tensor> {
        match tensor.scheme {
            QuantizationScheme::BnB4bit => self.dequantize_nf4(tensor),
            QuantizationScheme::BnB4bitFP4 => self.dequantize_fp4(tensor),
            QuantizationScheme::BnB8bit => tensor.dequantize(), // Use standard dequantization
            _ => Err(TrustformersError::quantization_error(format!(
                "BnB dequantization not supported for scheme {:?}",
                tensor.scheme
            ))),
        }
    }
1133
1134    /// Dequantize NF4 tensor
1135    fn dequantize_nf4(&self, tensor: &QuantizedTensor) -> Result<Tensor> {
1136        const NF4_LEVELS: [f32; 16] = [
1137            -1.0,
1138            -0.6961928009986877,
1139            -0.5250730514526367,
1140            -0.39491748809814453,
1141            -0.28444138169288635,
1142            -0.18477343022823334,
1143            -0.09105003625154495,
1144            0.0,
1145            0.07958029955625534,
1146            0.16093020141124725,
1147            0.24611230194568634,
1148            0.33791524171829224,
1149            0.44070982933044434,
1150            0.5626170039176941,
1151            0.7229568362236023,
1152            1.0,
1153        ];
1154
1155        let block_size = 64;
1156        let mut result = Vec::with_capacity(tensor.data.len());
1157        let num_blocks = tensor.scale.len();
1158
1159        for block_idx in 0..num_blocks {
1160            let scale = tensor.scale[block_idx];
1161            let start_idx = block_idx * block_size;
1162            let end_idx = (start_idx + block_size).min(tensor.data.len());
1163
1164            for &quantized_val in &tensor.data[start_idx..end_idx] {
1165                let idx = (quantized_val as usize).min(15);
1166                let dequantized = NF4_LEVELS[idx] * scale;
1167                result.push(dequantized);
1168            }
1169        }
1170
1171        Tensor::from_vec(result, &tensor.shape)
1172    }
1173
1174    /// Dequantize FP4 tensor
1175    fn dequantize_fp4(&self, tensor: &QuantizedTensor) -> Result<Tensor> {
1176        const FP4_LEVELS: [f32; 8] = [0.0, 0.0625, 0.125, 0.1875, 0.25, 0.3125, 0.375, 0.4375];
1177
1178        let block_size = 64;
1179        let mut result = Vec::with_capacity(tensor.data.len());
1180        let num_blocks = tensor.scale.len();
1181
1182        for block_idx in 0..num_blocks {
1183            let scale = tensor.scale[block_idx];
1184            let start_idx = block_idx * block_size;
1185            let end_idx = (start_idx + block_size).min(tensor.data.len());
1186
1187            for &quantized_val in &tensor.data[start_idx..end_idx] {
1188                let sign = if (quantized_val & 8) != 0 { -1.0 } else { 1.0 };
1189                let idx = (quantized_val & 7) as usize;
1190                let abs_val = FP4_LEVELS[idx];
1191                let dequantized = sign * abs_val * scale;
1192                result.push(dequantized);
1193            }
1194        }
1195
1196        Tensor::from_vec(result, &tensor.shape)
1197    }
1198}
1199
1200/// Quantization-aware training support
1201pub struct QATConfig {
1202    pub fake_quantize: bool,
1203    pub observe: bool,
1204    pub reduce_range: bool,
1205    pub qscheme: QuantizationScheme,
1206}
1207
impl Default for QATConfig {
    /// Defaults: fake quantization and observation enabled, full integer
    /// range, Int8 scheme.
    fn default() -> Self {
        Self {
            fake_quantize: true,
            observe: true,
            reduce_range: false,
            qscheme: QuantizationScheme::Int8,
        }
    }
}
1218
/// Fake quantization for QAT
///
/// Simulates quantization during training by quantizing and immediately
/// dequantizing tensors, so the model experiences quantization error while
/// values stay in floating point.
pub struct FakeQuantize {
    /// QAT settings controlling observation and fake quantization.
    config: QATConfig,
    /// Observer pool; currently only `observers[0]` is used (a single
    /// per-tensor observer).
    observers: Vec<Observer>,
}
1224
/// Observer for collecting statistics during QAT
pub struct Observer {
    // Running minimum over all observed elements.
    min_val: f32,
    // Running maximum over all observed elements.
    max_val: f32,
    // Number of elements observed so far.
    count: usize,
}
1231
impl Default for Observer {
    /// Equivalent to [`Observer::new`]: an empty observer with no statistics.
    fn default() -> Self {
        Self::new()
    }
}
1237
impl Observer {
    /// Start with an empty range; min/max are seeded with +/- infinity so the
    /// first observed value always replaces them.
    pub fn new() -> Self {
        Self {
            min_val: f32::INFINITY,
            max_val: f32::NEG_INFINITY,
            count: 0,
        }
    }

    /// Fold a tensor's elements into the running min/max statistics.
    /// Non-F32 tensors are silently ignored.
    pub fn update(&mut self, tensor: &Tensor) {
        if let Tensor::F32(arr) = tensor {
            for &val in arr.iter() {
                self.min_val = self.min_val.min(val);
                self.max_val = self.max_val.max(val);
                self.count += 1;
            }
        }
    }

    /// Derive `(scale, zero_point)` from the observed range.
    ///
    /// Symmetric: `scale = absmax / (q_max - q_min)` with zero point 0.
    /// NOTE(review): dividing by the full range rather than `q_max` mirrors
    /// `Quantizer::compute_quantization_params` but differs from the common
    /// absmax / q_max convention — confirm intended.
    /// Asymmetric: affine scale over `[min, max]` with the zero point clamped
    /// into the quantized range.
    ///
    /// # Errors
    /// Fails when no elements were observed. Note: if every observation was 0
    /// the returned scale is 0, and downstream divisions by it yield inf/NaN.
    pub fn get_quantization_params(&self, symmetric: bool, bits: u8) -> Result<(f32, i32)> {
        if self.count == 0 {
            return Err(TrustformersError::quantization_error(
                "No observations for quantization".into(),
            ));
        }

        let q_min = if symmetric { -(1 << (bits - 1)) } else { 0 };
        let q_max = if symmetric { (1 << (bits - 1)) - 1 } else { (1 << bits) - 1 };

        let (scale, zero_point) = if symmetric {
            let abs_max = self.max_val.abs().max(self.min_val.abs());
            let scale = abs_max / (q_max - q_min) as f32;
            (scale, 0)
        } else {
            let scale = (self.max_val - self.min_val) / (q_max - q_min) as f32;
            let zero_point = q_min - (self.min_val / scale).round() as i32;
            let zero_point = zero_point.clamp(q_min, q_max);
            (scale, zero_point)
        };

        Ok((scale, zero_point))
    }
}
1281
1282impl FakeQuantize {
1283    pub fn new(config: QATConfig) -> Self {
1284        Self {
1285            config,
1286            observers: Vec::new(),
1287        }
1288    }
1289
1290    /// Apply fake quantization during training
1291    pub fn forward(&mut self, tensor: &Tensor) -> Result<Tensor> {
1292        if self.config.observe {
1293            // Update observer statistics
1294            if self.observers.is_empty() {
1295                self.observers.push(Observer::new());
1296            }
1297            self.observers[0].update(tensor);
1298        }
1299
1300        if self.config.fake_quantize && !self.observers.is_empty() {
1301            // Apply fake quantization
1302            let observer = &self.observers[0];
1303            let (scale, zero_point) = observer.get_quantization_params(true, 8)?;
1304
1305            // Quantize and immediately dequantize
1306            match tensor {
1307                Tensor::F32(arr) => {
1308                    let quantized_data: Vec<f32> = arr
1309                        .iter()
1310                        .map(|&val| {
1311                            let q_val = Quantizer::quantize_value_int8(val, scale, zero_point);
1312                            let int_val = q_val as i32 - zero_point;
1313                            int_val as f32 * scale
1314                        })
1315                        .collect();
1316
1317                    Tensor::from_vec(quantized_data, arr.shape())
1318                },
1319                _ => Ok(tensor.clone()),
1320            }
1321        } else {
1322            Ok(tensor.clone())
1323        }
1324    }
1325}
1326
#[cfg(test)]
mod tests {
    use super::*;

    // Round trip through per-tensor symmetric int8: one (scale, zero_point)
    // pair for the whole tensor, shape preserved.
    #[test]
    fn test_int8_per_tensor_quantization() -> Result<()> {
        let tensor = Tensor::randn(&[10, 20])?;
        let config = QuantizationConfig {
            scheme: QuantizationScheme::Int8,
            symmetric: true,
            per_channel: false,
            calibration_samples: None,
            group_size: None,
            bnb_config: None,
        };

        let quantized = Quantizer::quantize(&tensor, &config)?;
        assert_eq!(quantized.scheme, QuantizationScheme::Int8);
        assert!(!quantized.per_channel);
        assert_eq!(quantized.scale.len(), 1);
        assert_eq!(quantized.zero_point.len(), 1);

        let dequantized = quantized.dequantize()?;
        assert_eq!(dequantized.shape(), tensor.shape());
        Ok(())
    }

    // Round trip through per-tensor asymmetric int4; only shape is checked.
    #[test]
    fn test_int4_per_tensor_quantization() -> Result<()> {
        let tensor = Tensor::randn(&[8, 16])?;
        let config = QuantizationConfig {
            scheme: QuantizationScheme::Int4,
            symmetric: false,
            per_channel: false,
            calibration_samples: None,
            group_size: None,
            bnb_config: None,
        };

        let quantized = Quantizer::quantize(&tensor, &config)?;
        assert_eq!(quantized.scheme, QuantizationScheme::Int4);
        assert!(!quantized.per_channel);

        let dequantized = quantized.dequantize()?;
        assert_eq!(dequantized.shape(), tensor.shape());
        Ok(())
    }

    // Per-channel int8: expects one scale/zero-point per channel (first dim).
    #[test]
    fn test_per_channel_quantization() -> Result<()> {
        let tensor = Tensor::randn(&[4, 32])?;
        let config = QuantizationConfig {
            scheme: QuantizationScheme::Int8,
            symmetric: true,
            per_channel: true,
            calibration_samples: None,
            group_size: None,
            bnb_config: None,
        };

        let quantized = Quantizer::quantize(&tensor, &config)?;
        assert!(quantized.per_channel);
        assert_eq!(quantized.scale.len(), 4); // Number of channels
        assert_eq!(quantized.zero_point.len(), 4);

        let dequantized = quantized.dequantize()?;
        assert_eq!(dequantized.shape(), tensor.shape());
        Ok(())
    }

    // Dynamic scheme: only verifies the round trip preserves shape.
    #[test]
    fn test_dynamic_quantization() -> Result<()> {
        let tensor = Tensor::randn(&[16, 32])?;
        let config = QuantizationConfig {
            scheme: QuantizationScheme::Dynamic,
            symmetric: false,
            per_channel: false,
            calibration_samples: None,
            group_size: None,
            bnb_config: None,
        };

        let quantized = Quantizer::quantize(&tensor, &config)?;
        let dequantized = quantized.dequantize()?;
        assert_eq!(dequantized.shape(), tensor.shape());
        Ok(())
    }

    // Sanity checks on (scale, zero_point): symmetric pins zero_point to 0.
    #[test]
    fn test_quantization_params_computation() -> Result<()> {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        // Symmetric quantization
        let (scale, zero_point) = Quantizer::compute_quantization_params(&data, true, 8)?;
        assert_eq!(zero_point, 0);
        assert!(scale > 0.0);

        // Asymmetric quantization
        let (scale, zero_point) = Quantizer::compute_quantization_params(&data, false, 8)?;
        assert!(scale > 0.0);
        Ok(())
    }

    // GPTQ currently delegates to the standard quantizer (no Hessian given).
    #[test]
    fn test_gptq_quantizer() -> Result<()> {
        let tensor = Tensor::randn(&[16, 32])?;
        let config = QuantizationConfig::default();
        let gptq = GPTQQuantizer::new(config);

        let quantized = gptq.quantize(&tensor, None)?;
        let dequantized = quantized.dequantize()?;
        assert_eq!(dequantized.shape(), tensor.shape());
        Ok(())
    }

    // AWQ with activation scales set still delegates to the standard path.
    #[test]
    fn test_awq_quantizer() -> Result<()> {
        let tensor = Tensor::randn(&[16, 32])?;
        let config = QuantizationConfig::default();
        let mut awq = AWQQuantizer::new(config);

        let scales = vec![1.0; 16];
        awq.set_activation_scales(scales);

        let quantized = awq.quantize(&tensor)?;
        let dequantized = quantized.dequantize()?;
        assert_eq!(dequantized.shape(), tensor.shape());
        Ok(())
    }

    // Calibration caps the sample count and must keep the scheme unchanged.
    #[test]
    fn test_calibration() -> Result<()> {
        let samples = vec![
            Tensor::randn(&[16, 32])?,
            Tensor::randn(&[16, 32])?,
            Tensor::randn(&[16, 32])?,
        ];

        let config = QuantizationConfig {
            calibration_samples: Some(2),
            ..Default::default()
        };

        let calibrated = Quantizer::calibrate(&samples, &config)?;
        assert_eq!(calibrated.scheme, config.scheme);
        Ok(())
    }

    // NF4 round trip via the BnB-specific dequantizer.
    #[test]
    fn test_bnb_nf4_quantization() -> Result<()> {
        let tensor = Tensor::randn(&[128])?;
        let config = BnBConfig {
            quant_type: BnBQuantType::NF4,
            compute_dtype: BnBComputeType::Float16,
            use_double_quant: false,
            bnb_4bit_quant_storage: Some(BnBStorageType::UInt8),
        };

        let bnb = BnBQuantizer::new(config);
        let quantized = bnb.quantize(&tensor)?;
        assert_eq!(quantized.scheme, QuantizationScheme::BnB4bit);

        let dequantized = bnb.dequantize(&quantized)?;
        assert_eq!(dequantized.shape(), tensor.shape());
        Ok(())
    }

    // FP4 round trip via the BnB-specific dequantizer.
    #[test]
    fn test_bnb_fp4_quantization() -> Result<()> {
        let tensor = Tensor::randn(&[128])?;
        let config = BnBConfig {
            quant_type: BnBQuantType::FP4,
            compute_dtype: BnBComputeType::Float16,
            use_double_quant: false,
            bnb_4bit_quant_storage: Some(BnBStorageType::UInt8),
        };

        let bnb = BnBQuantizer::new(config);
        let quantized = bnb.quantize(&tensor)?;
        assert_eq!(quantized.scheme, QuantizationScheme::BnB4bitFP4);

        let dequantized = bnb.dequantize(&quantized)?;
        assert_eq!(dequantized.shape(), tensor.shape());
        Ok(())
    }

    // BnB int8 round trip uses the standard QuantizedTensor dequantization.
    #[test]
    fn test_bnb_int8_quantization() -> Result<()> {
        let tensor = Tensor::randn(&[64, 64])?;
        let config = BnBConfig {
            quant_type: BnBQuantType::Int8,
            compute_dtype: BnBComputeType::Float32,
            use_double_quant: false,
            bnb_4bit_quant_storage: None,
        };

        let bnb = BnBQuantizer::new(config);
        let quantized = bnb.quantize(&tensor)?;
        assert_eq!(quantized.scheme, QuantizationScheme::BnB8bit);

        let dequantized = quantized.dequantize()?;
        assert_eq!(dequantized.shape(), tensor.shape());
        Ok(())
    }

    // FakeQuantize must preserve shape on every forward pass.
    #[test]
    fn test_qat_fake_quantize() -> Result<()> {
        let tensor = Tensor::randn(&[32, 32])?;
        let config = QATConfig::default();
        let mut fake_quant = FakeQuantize::new(config);

        // First pass should build observer statistics
        let result1 = fake_quant.forward(&tensor)?;
        assert_eq!(result1.shape(), tensor.shape());

        // Second pass should apply fake quantization
        let result2 = fake_quant.forward(&tensor)?;
        assert_eq!(result2.shape(), tensor.shape());
        Ok(())
    }

    // Observer counts elements and yields symmetric params (zero_point 0).
    #[test]
    fn test_observer_statistics() -> Result<()> {
        let mut observer = Observer::new();
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5])?;

        observer.update(&tensor);
        assert_eq!(observer.count, 5);

        let (scale, zero_point) = observer.get_quantization_params(true, 8)?;
        assert!(scale > 0.0);
        assert_eq!(zero_point, 0); // Symmetric quantization
        Ok(())
    }

    // BnBConfig must survive a JSON serde round trip.
    #[test]
    fn test_bnb_config_serialization() -> Result<()> {
        let config = BnBConfig {
            quant_type: BnBQuantType::NF4,
            compute_dtype: BnBComputeType::Float16,
            use_double_quant: true,
            bnb_4bit_quant_storage: Some(BnBStorageType::UInt8),
        };

        let serialized = serde_json::to_string(&config)
            .map_err(|e| TrustformersError::serialization_error(e.to_string()))?;
        let deserialized: BnBConfig = serde_json::from_str(&serialized)
            .map_err(|e| TrustformersError::serialization_error(e.to_string()))?;

        assert_eq!(config.quant_type, deserialized.quant_type);
        assert_eq!(config.compute_dtype, deserialized.compute_dtype);
        assert_eq!(config.use_double_quant, deserialized.use_double_quant);
        Ok(())
    }
}