// tensorlogic_train/quantization.rs

//! Model quantization utilities for compression and acceleration.
//!
//! This module provides quantization techniques to reduce model size and improve inference speed:
//! - Post-training quantization (PTQ) for immediate deployment
//! - Quantization-aware training (QAT) for better accuracy
//! - Multiple bit-width support (int8, int4, int2)
//! - Per-tensor and per-channel quantization
//!
//! # Examples
//!
//! ```
//! use tensorlogic_train::{QuantizationConfig, Quantizer, QuantizationMode};
//! use scirs2_core::ndarray::Array2;
//!
//! // Post-training quantization (PTQ)
//! let weights = Array2::<f32>::zeros((10, 10));
//! let config = QuantizationConfig::int8_symmetric();
//! let quantized = Quantizer::quantize_tensor(&weights.view(), &config);
//!
//! // Dequantize for inference
//! let dequantized = Quantizer::dequantize_tensor(&quantized);
//! ```

use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// Quantization mode determines the quantization strategy.
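///
/// Symmetric quantization fixes the zero-point at 0 and sizes the range from
/// the absolute maximum; asymmetric quantization derives both a scale and a
/// zero-point so that the full observed `[min, max]` interval maps onto the
/// integer range.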
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationMode {
    /// Symmetric quantization: range is [-max, max]
    Symmetric,
    /// Asymmetric quantization: range is [min, max]
    Asymmetric,
}

/// Bit-width for quantization.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BitWidth {
    /// 8-bit integer quantization (most common)
    Int8,
    /// 4-bit integer quantization (higher compression)
    Int4,
    /// 2-bit integer quantization (extreme compression)
    Int2,
}

impl BitWidth {
    /// Returns the number of quantization levels for this bit-width.
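    ///
    /// ```
    /// // Assumes `BitWidth` is re-exported at the crate root, like the types
    /// // used in the module-level example.
    /// use tensorlogic_train::BitWidth;
    ///
    /// assert_eq!(BitWidth::Int8.levels(), 256);
    /// assert_eq!(BitWidth::Int2.levels(), 4);
    /// ```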
    pub fn levels(&self) -> i32 {
        match self {
            BitWidth::Int8 => 256, // 2^8
            BitWidth::Int4 => 16,  // 2^4
            BitWidth::Int2 => 4,   // 2^2
        }
    }

    /// Returns the minimum quantized value.
    pub fn qmin(&self) -> i32 {
        match self {
            BitWidth::Int8 => -128,
            BitWidth::Int4 => -8,
            BitWidth::Int2 => -2,
        }
    }

    /// Returns the maximum quantized value.
    pub fn qmax(&self) -> i32 {
        match self {
            BitWidth::Int8 => 127,
            BitWidth::Int4 => 7,
            BitWidth::Int2 => 1,
        }
    }
}

/// Quantization granularity (per-tensor or per-channel).
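///
/// Per-channel quantization keeps a separate range per output channel, which
/// typically reduces reconstruction error for weight matrices whose channels
/// have very different magnitudes.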
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Granularity {
    /// Single scale/zero-point for entire tensor
    PerTensor,
    /// Separate scale/zero-point per output channel (axis 0)
    PerChannel,
}

/// Configuration for quantization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationConfig {
    /// Quantization mode (symmetric or asymmetric)
    pub mode: QuantizationMode,
    /// Bit-width for quantization
    pub bit_width: BitWidth,
    /// Granularity (per-tensor or per-channel)
    pub granularity: Granularity,
    /// Small epsilon to avoid division by zero
    pub eps: f32,
}

impl QuantizationConfig {
    /// Creates a default int8 symmetric per-tensor configuration.
    pub fn int8_symmetric() -> Self {
        Self {
            mode: QuantizationMode::Symmetric,
            bit_width: BitWidth::Int8,
            granularity: Granularity::PerTensor,
            eps: 1e-8,
        }
    }

    /// Creates a default int8 asymmetric per-tensor configuration.
    pub fn int8_asymmetric() -> Self {
        Self {
            mode: QuantizationMode::Asymmetric,
            bit_width: BitWidth::Int8,
            granularity: Granularity::PerTensor,
            eps: 1e-8,
        }
    }

    /// Creates a default int4 symmetric per-channel configuration.
    pub fn int4_per_channel() -> Self {
        Self {
            mode: QuantizationMode::Symmetric,
            bit_width: BitWidth::Int4,
            granularity: Granularity::PerChannel,
            eps: 1e-8,
        }
    }

    /// Creates a custom configuration.
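    ///
    /// ```
    /// // Assumes these types are re-exported at the crate root, like the
    /// // types used in the module-level example.
    /// use tensorlogic_train::{QuantizationConfig, QuantizationMode, BitWidth, Granularity};
    ///
    /// let config = QuantizationConfig::new(
    ///     QuantizationMode::Asymmetric,
    ///     BitWidth::Int4,
    ///     Granularity::PerChannel,
    /// );
    /// assert_eq!(config.eps, 1e-8);
    /// ```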
    pub fn new(mode: QuantizationMode, bit_width: BitWidth, granularity: Granularity) -> Self {
        Self {
            mode,
            bit_width,
            granularity,
            eps: 1e-8,
        }
    }
}

/// Quantization parameters (scale and zero-point).
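///
/// For per-tensor quantization both arrays have length 1; for per-channel
/// quantization they hold one entry per output channel (axis 0).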
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationParams {
    /// Scale factor(s) for quantization
    pub scale: Array1<f32>,
    /// Zero-point(s) for asymmetric quantization
    pub zero_point: Array1<i32>,
    /// Original configuration used
    pub config: QuantizationConfig,
}

/// Quantized tensor representation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizedTensor {
    /// Quantized integer values
    pub data: Array2<i8>,
    /// Quantization parameters
    pub params: QuantizationParams,
}

/// Main quantizer for model compression.
pub struct Quantizer;

impl Quantizer {
    /// Quantizes a 2D tensor using the specified configuration.
    ///
    /// # Arguments
    /// * `tensor` - Input floating-point tensor
    /// * `config` - Quantization configuration
    ///
    /// # Returns
    /// Quantized tensor with parameters
    pub fn quantize_tensor(
        tensor: &ArrayView2<f32>,
        config: &QuantizationConfig,
    ) -> QuantizedTensor {
        match config.granularity {
            Granularity::PerTensor => Self::quantize_per_tensor(tensor, config),
            Granularity::PerChannel => Self::quantize_per_channel(tensor, config),
        }
    }

    /// Per-tensor quantization (single scale/zero-point).
    fn quantize_per_tensor(
        tensor: &ArrayView2<f32>,
        config: &QuantizationConfig,
    ) -> QuantizedTensor {
        let (scale, zero_point) = Self::compute_params_tensor(tensor, config);

        let quantized = tensor.mapv(|x| {
            let q = (x / scale).round() + zero_point as f32;
            Self::clamp_to_qrange(q as i32, config.bit_width) as i8
        });

        QuantizedTensor {
            data: quantized,
            params: QuantizationParams {
                scale: Array1::from_vec(vec![scale]),
                zero_point: Array1::from_vec(vec![zero_point]),
                config: config.clone(),
            },
        }
    }

    /// Per-channel quantization (separate scale/zero-point per channel).
    fn quantize_per_channel(
        tensor: &ArrayView2<f32>,
        config: &QuantizationConfig,
    ) -> QuantizedTensor {
        let num_channels = tensor.shape()[0];
        let mut scales = Vec::with_capacity(num_channels);
        let mut zero_points = Vec::with_capacity(num_channels);

        // Compute parameters per channel
        for i in 0..num_channels {
            let channel = tensor.index_axis(Axis(0), i);
            let (scale, zero_point) = Self::compute_params_channel(&channel, config);
            scales.push(scale);
            zero_points.push(zero_point);
        }

        // Quantize each channel
        let mut quantized = Array2::<i8>::zeros(tensor.dim());
        for (i, mut row) in quantized.axis_iter_mut(Axis(0)).enumerate() {
            let channel = tensor.index_axis(Axis(0), i);
            let scale = scales[i];
            let zero_point = zero_points[i];

            for (j, &val) in channel.iter().enumerate() {
                let q = (val / scale).round() + zero_point as f32;
                row[j] = Self::clamp_to_qrange(q as i32, config.bit_width) as i8;
            }
        }

        QuantizedTensor {
            data: quantized,
            params: QuantizationParams {
                scale: Array1::from_vec(scales),
                zero_point: Array1::from_vec(zero_points),
                config: config.clone(),
            },
        }
    }

    /// Computes quantization parameters for entire tensor.
    fn compute_params_tensor(tensor: &ArrayView2<f32>, config: &QuantizationConfig) -> (f32, i32) {
        let min = tensor.iter().cloned().fold(f32::INFINITY, f32::min);
        let max = tensor.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        Self::compute_scale_zero_point(min, max, config)
    }

    /// Computes quantization parameters for a single channel.
    fn compute_params_channel(
        channel: &ArrayView1<f32>,
        config: &QuantizationConfig,
    ) -> (f32, i32) {
        let min = channel.iter().cloned().fold(f32::INFINITY, f32::min);
        let max = channel.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        Self::compute_scale_zero_point(min, max, config)
    }

    /// Computes scale and zero-point from min/max values.
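    ///
    /// Symmetric: `scale = 2 * max(|min|, |max|) / (qmax - qmin)`, zero-point 0.
    /// Asymmetric: `scale = (max - min) / (qmax - qmin)` and
    /// `zero_point = round(qmin - min / scale)`, clamped to the quantized range.
    ///
    /// For example, asymmetric int8 over data in `[0.0, 3.0]` gives
    /// `scale = 3.0 / 255 ≈ 0.0118` and `zero_point = round(-128 - 0.0) = -128`.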
    fn compute_scale_zero_point(min: f32, max: f32, config: &QuantizationConfig) -> (f32, i32) {
        let qmin = config.bit_width.qmin() as f32;
        let qmax = config.bit_width.qmax() as f32;

        match config.mode {
            QuantizationMode::Symmetric => {
                let abs_max = min.abs().max(max.abs());
                let scale = (2.0 * abs_max / (qmax - qmin)).max(config.eps);
                (scale, 0)
            }
            QuantizationMode::Asymmetric => {
                let scale = ((max - min) / (qmax - qmin)).max(config.eps);
                let zero_point = (qmin - min / scale).round() as i32;
                let zero_point = Self::clamp_to_qrange(zero_point, config.bit_width);
                (scale, zero_point)
            }
        }
    }

    /// Clamps a value to the quantization range.
    fn clamp_to_qrange(value: i32, bit_width: BitWidth) -> i32 {
        value.clamp(bit_width.qmin(), bit_width.qmax())
    }

    /// Dequantizes a quantized tensor back to float32.
    ///
    /// # Arguments
    /// * `quantized` - Quantized tensor with parameters
    ///
    /// # Returns
    /// Dequantized floating-point tensor
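    /// reconstructed element-wise as `scale * (q - zero_point)`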
    pub fn dequantize_tensor(quantized: &QuantizedTensor) -> Array2<f32> {
        match quantized.params.config.granularity {
            Granularity::PerTensor => {
                let scale = quantized.params.scale[0];
                let zero_point = quantized.params.zero_point[0];
                quantized
                    .data
                    .mapv(|q| scale * (q as f32 - zero_point as f32))
            }
            Granularity::PerChannel => {
                let mut result = Array2::<f32>::zeros(quantized.data.dim());
                for (i, mut row) in result.axis_iter_mut(Axis(0)).enumerate() {
                    let scale = quantized.params.scale[i];
                    let zero_point = quantized.params.zero_point[i];
                    let q_row = quantized.data.index_axis(Axis(0), i);

                    for (j, &q) in q_row.iter().enumerate() {
                        row[j] = scale * (q as f32 - zero_point as f32);
                    }
                }
                result
            }
        }
    }

    /// Computes the compression ratio achieved by quantization.
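    ///
    /// ```
    /// use tensorlogic_train::{Quantizer, QuantizationConfig};
    ///
    /// // f32 (32 bits) to int8 (8 bits) gives a 4x ratio.
    /// assert_eq!(Quantizer::compression_ratio(&QuantizationConfig::int8_symmetric()), 4.0);
    /// ```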
    pub fn compression_ratio(config: &QuantizationConfig) -> f32 {
        let original_bits = 32.0; // f32
        let quantized_bits = match config.bit_width {
            BitWidth::Int8 => 8.0,
            BitWidth::Int4 => 4.0,
            BitWidth::Int2 => 2.0,
        };
        original_bits / quantized_bits
    }

    /// Estimates the quantization error (MSE) for a tensor.
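    ///
    /// The error is `mean((x - dequantize(quantize(x)))^2)` over all elements,
    /// i.e. the MSE between the original tensor and its round-trip reconstruction.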
    pub fn quantization_error(original: &ArrayView2<f32>, quantized: &QuantizedTensor) -> f32 {
        let dequantized = Self::dequantize_tensor(quantized);
        let diff = original - &dequantized.view();
        diff.mapv(|x| x * x).mean().unwrap_or(0.0)
    }
}

/// Quantization-aware training (QAT) utilities.
pub struct QuantizationAwareTraining {
    /// Layer name to quantization config mapping
    layer_configs: HashMap<String, QuantizationConfig>,
    /// Whether to simulate quantization during training
    simulate_quantization: bool,
}

impl QuantizationAwareTraining {
    /// Creates a new QAT instance.
    pub fn new(simulate_quantization: bool) -> Self {
        Self {
            layer_configs: HashMap::new(),
            simulate_quantization,
        }
    }

    /// Registers a layer for quantization-aware training.
    pub fn register_layer(&mut self, layer_name: String, config: QuantizationConfig) {
        self.layer_configs.insert(layer_name, config);
    }

    /// Simulates quantization during forward pass (straight-through estimator).
    ///
    /// This applies fake quantization: quantize then immediately dequantize,
    /// allowing gradients to flow through.
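    ///
    /// ```
    /// // A minimal sketch; assumes `QuantizationAwareTraining` is re-exported
    /// // at the crate root like the other public types.
    /// use tensorlogic_train::{QuantizationAwareTraining, QuantizationConfig};
    /// use scirs2_core::ndarray::Array2;
    ///
    /// let mut qat = QuantizationAwareTraining::new(true);
    /// qat.register_layer("fc1".to_string(), QuantizationConfig::int8_symmetric());
    ///
    /// let weights = Array2::<f32>::from_elem((2, 2), 0.5);
    /// let fake = qat.fake_quantize(&weights, "fc1");
    /// assert_eq!(fake.dim(), weights.dim());
    /// ```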
    pub fn fake_quantize(&self, tensor: &Array2<f32>, layer_name: &str) -> Array2<f32> {
        if !self.simulate_quantization {
            return tensor.clone();
        }

        if let Some(config) = self.layer_configs.get(layer_name) {
            let quantized = Quantizer::quantize_tensor(&tensor.view(), config);
            Quantizer::dequantize_tensor(&quantized)
        } else {
            tensor.clone()
        }
    }

    /// Gets the quantization config for a layer.
    pub fn get_config(&self, layer_name: &str) -> Option<&QuantizationConfig> {
        self.layer_configs.get(layer_name)
    }

    /// Returns all registered layer names.
    pub fn registered_layers(&self) -> Vec<&String> {
        self.layer_configs.keys().collect()
    }
}

/// Dynamic range calibration for post-training quantization.
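///
/// # Examples
///
/// ```
/// // A minimal sketch; assumes `DynamicRangeCalibrator` is re-exported at the
/// // crate root like the other public types.
/// use tensorlogic_train::{DynamicRangeCalibrator, QuantizationConfig};
/// use scirs2_core::ndarray::Array2;
///
/// let mut calibrator = DynamicRangeCalibrator::new();
/// let batch = Array2::from_shape_vec((1, 3), vec![-1.0f32, 0.0, 2.0]).unwrap();
/// calibrator.collect("conv1".to_string(), &batch.view());
///
/// assert_eq!(calibrator.get_range("conv1"), Some((-1.0, 2.0)));
/// let configs = calibrator.finalize(&QuantizationConfig::int8_symmetric());
/// assert!(configs.contains_key("conv1"));
/// ```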
pub struct DynamicRangeCalibrator {
    /// Collected min/max statistics per layer
    statistics: HashMap<String, (f32, f32)>,
    /// Number of samples collected
    num_samples: usize,
}

impl DynamicRangeCalibrator {
    /// Creates a new calibrator.
    pub fn new() -> Self {
        Self {
            statistics: HashMap::new(),
            num_samples: 0,
        }
    }

    /// Collects statistics from a batch of activations.
    pub fn collect(&mut self, layer_name: String, tensor: &ArrayView2<f32>) {
        let min = tensor.iter().cloned().fold(f32::INFINITY, f32::min);
        let max = tensor.iter().cloned().fold(f32::NEG_INFINITY, f32::max);

        self.statistics
            .entry(layer_name)
            .and_modify(|(prev_min, prev_max)| {
                *prev_min = prev_min.min(min);
                *prev_max = prev_max.max(max);
            })
            .or_insert((min, max));

        self.num_samples += 1;
    }

    /// Finalizes calibration, assigning the default configuration to every
    /// layer with collected statistics. The observed ranges themselves remain
    /// available via [`get_range`](Self::get_range).
    pub fn finalize(
        &self,
        default_config: &QuantizationConfig,
    ) -> HashMap<String, QuantizationConfig> {
        self.statistics
            .keys()
            .map(|name| (name.clone(), default_config.clone()))
            .collect()
    }

    /// Gets the collected range for a layer.
    pub fn get_range(&self, layer_name: &str) -> Option<(f32, f32)> {
        self.statistics.get(layer_name).copied()
    }

    /// Resets all collected statistics.
    pub fn reset(&mut self) {
        self.statistics.clear();
        self.num_samples = 0;
    }
}

impl Default for DynamicRangeCalibrator {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_int8_symmetric_quantization() {
        let tensor = Array2::from_shape_vec((2, 3), vec![-1.0, 0.0, 1.0, -2.0, 2.0, 0.5]).unwrap();
        let config = QuantizationConfig::int8_symmetric();

        let quantized = Quantizer::quantize_tensor(&tensor.view(), &config);
        let dequantized = Quantizer::dequantize_tensor(&quantized);

        // Check shape preserved
        assert_eq!(dequantized.dim(), tensor.dim());

        // Check approximate reconstruction
        for (orig, deq) in tensor.iter().zip(dequantized.iter()) {
            assert_relative_eq!(orig, deq, epsilon = 0.1);
        }
    }

    #[test]
    fn test_int8_asymmetric_quantization() {
        let tensor = Array2::from_shape_vec((2, 2), vec![0.0, 1.0, 2.0, 3.0]).unwrap();
        let config = QuantizationConfig::int8_asymmetric();

        let quantized = Quantizer::quantize_tensor(&tensor.view(), &config);
        assert_eq!(quantized.params.config.mode, QuantizationMode::Asymmetric);

        let dequantized = Quantizer::dequantize_tensor(&quantized);
        assert_relative_eq!(dequantized[[0, 0]], 0.0, epsilon = 0.05);
        assert_relative_eq!(dequantized[[1, 1]], 3.0, epsilon = 0.05);
    }

    #[test]
    fn test_int4_per_channel_quantization() {
        let tensor =
            Array2::from_shape_vec((2, 4), vec![-1.0, 0.0, 1.0, 2.0, -10.0, -5.0, 5.0, 10.0])
                .unwrap();
        let config = QuantizationConfig::int4_per_channel();

        let quantized = Quantizer::quantize_tensor(&tensor.view(), &config);

        // Should have 2 scales (one per channel)
        assert_eq!(quantized.params.scale.len(), 2);
        assert_eq!(quantized.params.zero_point.len(), 2);

        let dequantized = Quantizer::dequantize_tensor(&quantized);
        assert_eq!(dequantized.dim(), tensor.dim());
    }

    #[test]
    fn test_bit_width_levels() {
        assert_eq!(BitWidth::Int8.levels(), 256);
        assert_eq!(BitWidth::Int4.levels(), 16);
        assert_eq!(BitWidth::Int2.levels(), 4);
    }

    #[test]
    fn test_bit_width_ranges() {
        assert_eq!(BitWidth::Int8.qmin(), -128);
        assert_eq!(BitWidth::Int8.qmax(), 127);
        assert_eq!(BitWidth::Int4.qmin(), -8);
        assert_eq!(BitWidth::Int4.qmax(), 7);
    }

    #[test]
    fn test_compression_ratio() {
        let config_int8 = QuantizationConfig::int8_symmetric();
        assert_eq!(Quantizer::compression_ratio(&config_int8), 4.0);

        let config_int4 = QuantizationConfig::new(
            QuantizationMode::Symmetric,
            BitWidth::Int4,
            Granularity::PerTensor,
        );
        assert_eq!(Quantizer::compression_ratio(&config_int4), 8.0);
    }

    #[test]
    fn test_quantization_error() {
        let tensor = Array2::from_shape_vec((3, 3), vec![1.0; 9]).unwrap();
        let config = QuantizationConfig::int8_symmetric();

        let quantized = Quantizer::quantize_tensor(&tensor.view(), &config);
        let error = Quantizer::quantization_error(&tensor.view(), &quantized);

        // Error should be small for uniform values
        assert!(error < 0.01);
    }

    #[test]
    fn test_qat_registration() {
        let mut qat = QuantizationAwareTraining::new(true);
        qat.register_layer("layer1".to_string(), QuantizationConfig::int8_symmetric());
        qat.register_layer("layer2".to_string(), QuantizationConfig::int4_per_channel());

        assert_eq!(qat.registered_layers().len(), 2);
        assert!(qat.get_config("layer1").is_some());
        assert!(qat.get_config("layer3").is_none());
    }

    #[test]
    fn test_fake_quantization() {
        let mut qat = QuantizationAwareTraining::new(true);
        qat.register_layer("fc1".to_string(), QuantizationConfig::int8_symmetric());

        let tensor = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let fake_quantized = qat.fake_quantize(&tensor, "fc1");

        // Shape is preserved; round-tripped values may differ slightly due to
        // quantization
        assert_eq!(fake_quantized.dim(), tensor.dim());
    }

    #[test]
    fn test_dynamic_range_calibrator() {
        let mut calibrator = DynamicRangeCalibrator::new();

        let tensor1 = Array2::from_shape_vec((2, 2), vec![0.0, 1.0, 2.0, 3.0]).unwrap();
        let tensor2 = Array2::from_shape_vec((2, 2), vec![-1.0, 0.0, 1.0, 4.0]).unwrap();

        calibrator.collect("layer1".to_string(), &tensor1.view());
        calibrator.collect("layer1".to_string(), &tensor2.view());

        let (min, max) = calibrator.get_range("layer1").unwrap();
        assert_eq!(min, -1.0);
        assert_eq!(max, 4.0);
    }

    #[test]
    fn test_calibrator_finalize() {
        let mut calibrator = DynamicRangeCalibrator::new();
        let tensor = Array2::from_shape_vec((2, 2), vec![1.0; 4]).unwrap();

        calibrator.collect("layer1".to_string(), &tensor.view());
        calibrator.collect("layer2".to_string(), &tensor.view());

        let config = QuantizationConfig::int8_symmetric();
        let configs = calibrator.finalize(&config);

        assert_eq!(configs.len(), 2);
        assert!(configs.contains_key("layer1"));
        assert!(configs.contains_key("layer2"));
    }

    #[test]
    fn test_calibrator_reset() {
        let mut calibrator = DynamicRangeCalibrator::new();
        let tensor = Array2::from_shape_vec((2, 2), vec![1.0; 4]).unwrap();

        calibrator.collect("layer1".to_string(), &tensor.view());
        assert_eq!(calibrator.num_samples, 1);

        calibrator.reset();
        assert_eq!(calibrator.num_samples, 0);
        assert!(calibrator.get_range("layer1").is_none());
    }

    #[test]
    fn test_zero_tensor_quantization() {
        let tensor = Array2::zeros((3, 3));
        let config = QuantizationConfig::int8_symmetric();

        let quantized = Quantizer::quantize_tensor(&tensor.view(), &config);
        let dequantized = Quantizer::dequantize_tensor(&quantized);

        assert_eq!(dequantized, tensor);
    }

    #[test]
    fn test_extreme_values_quantization() {
        let tensor = Array2::from_shape_vec(
            (2, 2),
            vec![f32::MIN / 1e6, f32::MAX / 1e6, -1000.0, 1000.0],
        )
        .unwrap();
        let config = QuantizationConfig::int8_symmetric();

        let quantized = Quantizer::quantize_tensor(&tensor.view(), &config);
        let dequantized = Quantizer::dequantize_tensor(&quantized);

        // Should handle extreme values without panicking
        assert_eq!(dequantized.dim(), tensor.dim());
    }
}