
tensorlogic_scirs_backend/quantization.rs

//! Quantization Infrastructure for TensorLogic
//!
//! This module provides utilities for quantizing tensors to lower-precision
//! formats (INT8, FP16, BF16) for improved memory efficiency and performance.
//! While full quantized execution requires backend support, this infrastructure
//! prepares the framework for future quantization-aware training and inference.
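//!
//! # Typical flow (illustrative sketch)
//!
//! A minimal end-to-end sketch, assuming `Scirs2Tensor` is an `ArrayD<f64>`
//! as in this module's tests:
//!
//! ```ignore
//! let params = QuantizationParams::symmetric_per_tensor(QuantizationType::Int8, &tensor);
//! let quantized = QuantizedTensor::quantize(&tensor, params);
//! let restored = quantized.dequantize();
//! let mse = quantized.quantization_error(&tensor);
//! ```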

use crate::{Scirs2Tensor, TlBackendError, TlBackendResult};
use serde::{Deserialize, Serialize};

/// Quantization data type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum QuantizationType {
    /// 8-bit signed integer quantization
    Int8,
    /// 16-bit floating point (IEEE 754 half precision)
    Fp16,
    /// 16-bit brain floating point (truncated FP32)
    BFloat16,
    /// 4-bit integer quantization (experimental)
    Int4,
    /// No quantization (full precision)
    None,
}

impl QuantizationType {
    /// Get the number of bits used by this quantization type.
    pub fn bits(&self) -> usize {
        match self {
            QuantizationType::Int4 => 4,
            QuantizationType::Int8 => 8,
            QuantizationType::Fp16 | QuantizationType::BFloat16 => 16,
            QuantizationType::None => 64, // Assuming f64 for full precision
        }
    }

    /// Get the memory compression ratio compared to FP64.
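    ///
    /// # Example
    ///
    /// An illustrative check (ratios are relative to 64-bit floats):
    ///
    /// ```ignore
    /// assert_eq!(QuantizationType::Int8.compression_ratio(), 8.0);  // 64 / 8
    /// assert_eq!(QuantizationType::Fp16.compression_ratio(), 4.0); // 64 / 16
    /// ```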
    pub fn compression_ratio(&self) -> f64 {
        64.0 / self.bits() as f64
    }

    /// Check if this is a floating-point quantization.
    pub fn is_float(&self) -> bool {
        matches!(
            self,
            QuantizationType::Fp16 | QuantizationType::BFloat16 | QuantizationType::None
        )
    }

    /// Check if this is an integer quantization.
    pub fn is_integer(&self) -> bool {
        matches!(self, QuantizationType::Int8 | QuantizationType::Int4)
    }
}

/// Quantization scheme (symmetric vs asymmetric).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationScheme {
    /// Symmetric quantization: range is [-max, max]
    Symmetric,
    /// Asymmetric quantization: range is [min, max]
    Asymmetric,
}

/// Quantization granularity (per-tensor vs per-channel).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationGranularity {
    /// Single scale and zero point for the entire tensor
    PerTensor,
    /// Separate scale and zero point per output channel
    PerChannel,
}

/// Quantization parameters for a tensor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationParams {
    /// Quantization data type
    pub qtype: QuantizationType,

    /// Quantization scheme
    pub scheme: QuantizationScheme,

    /// Quantization granularity
    pub granularity: QuantizationGranularity,

    /// Scale factor(s) for dequantization
    pub scale: Vec<f64>,

    /// Zero point(s) for asymmetric quantization
    pub zero_point: Vec<i32>,

    /// Minimum value(s) in the original tensor
    pub min_val: Vec<f64>,

    /// Maximum value(s) in the original tensor
    pub max_val: Vec<f64>,
}

impl QuantizationParams {
    /// Create symmetric per-tensor quantization parameters.
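    ///
    /// # Example
    ///
    /// An illustrative sketch, assuming `Scirs2Tensor` is an `ArrayD<f64>`:
    ///
    /// ```ignore
    /// use scirs2_core::ndarray::ArrayD;
    ///
    /// let t = ArrayD::from_shape_vec(vec![3], vec![-4.0, 0.0, 4.0]).unwrap();
    /// let p = QuantizationParams::symmetric_per_tensor(QuantizationType::Int8, &t);
    /// assert_eq!(p.zero_point[0], 0);                    // symmetric => zero point 0
    /// assert!((p.scale[0] - 4.0 / 127.0).abs() < 1e-12); // scale = abs_max / 127
    /// ```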
    pub fn symmetric_per_tensor(qtype: QuantizationType, tensor: &Scirs2Tensor) -> Self {
        let abs_max = tensor.iter().map(|&x| x.abs()).fold(0.0f64, f64::max);

        let scale = match qtype {
            QuantizationType::Int8 => abs_max / 127.0,
            QuantizationType::Int4 => abs_max / 7.0,
            QuantizationType::Fp16 | QuantizationType::BFloat16 => 1.0,
            QuantizationType::None => 1.0,
        };
        // Guard against an all-zero tensor: scale == 0 would produce NaNs
        // during quantization (division by zero).
        let scale = if scale == 0.0 { 1.0 } else { scale };

        Self {
            qtype,
            scheme: QuantizationScheme::Symmetric,
            granularity: QuantizationGranularity::PerTensor,
            scale: vec![scale],
            zero_point: vec![0],
            min_val: vec![-abs_max],
            max_val: vec![abs_max],
        }
    }

    /// Create asymmetric per-tensor quantization parameters.
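    ///
    /// # Example
    ///
    /// An illustrative sketch, assuming `Scirs2Tensor` is an `ArrayD<f64>`:
    ///
    /// ```ignore
    /// use scirs2_core::ndarray::ArrayD;
    ///
    /// let t = ArrayD::from_shape_vec(vec![3], vec![0.0, 4.0, 8.0]).unwrap();
    /// let p = QuantizationParams::asymmetric_per_tensor(QuantizationType::Int8, &t);
    /// assert!((p.scale[0] - 8.0 / 255.0).abs() < 1e-12); // scale = (max - min) / 255
    /// assert_eq!(p.zero_point[0], 0);                    // min == 0 maps to 0
    /// ```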
    pub fn asymmetric_per_tensor(qtype: QuantizationType, tensor: &Scirs2Tensor) -> Self {
        let min_val = tensor.iter().fold(f64::INFINITY, |a, &b| a.min(b));
        let max_val = tensor.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));

        let (scale, zero_point) = match qtype {
            QuantizationType::Int8 => {
                let scale = (max_val - min_val) / 255.0;
                // Guard against a constant tensor (max == min): scale == 0
                // would make the zero point NaN.
                if scale == 0.0 {
                    (1.0, 0)
                } else {
                    (scale, (-min_val / scale).round() as i32)
                }
            }
            QuantizationType::Int4 => {
                let scale = (max_val - min_val) / 15.0;
                if scale == 0.0 {
                    (1.0, 0)
                } else {
                    (scale, (-min_val / scale).round() as i32)
                }
            }
            QuantizationType::Fp16 | QuantizationType::BFloat16 | QuantizationType::None => {
                (1.0, 0)
            }
        };

        Self {
            qtype,
            scheme: QuantizationScheme::Asymmetric,
            granularity: QuantizationGranularity::PerTensor,
            scale: vec![scale],
            zero_point: vec![zero_point],
            min_val: vec![min_val],
            max_val: vec![max_val],
        }
    }

    /// Get the dynamic range of this quantization.
    pub fn dynamic_range(&self) -> f64 {
        self.max_val[0] - self.min_val[0]
    }

    /// Get the quantization error bound.
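    ///
    /// For round-to-nearest integer quantization the worst-case rounding
    /// error of a single value is half a quantization step, i.e. `scale / 2`.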
    pub fn quantization_error_bound(&self) -> f64 {
        self.scale[0] / 2.0
    }
}

/// Simulated quantized tensor (stored as f64 but representing quantized values).
#[derive(Debug, Clone)]
pub struct QuantizedTensor {
    /// The quantized data (stored as f64 for compatibility)
    pub data: Scirs2Tensor,

    /// Quantization parameters
    pub params: QuantizationParams,
}

impl QuantizedTensor {
    /// Quantize a tensor using the given parameters.
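    ///
    /// # Example
    ///
    /// A round-trip sketch, assuming `Scirs2Tensor` is an `ArrayD<f64>`:
    ///
    /// ```ignore
    /// use scirs2_core::ndarray::ArrayD;
    ///
    /// let t = ArrayD::from_shape_vec(vec![3], vec![-10.0, 0.0, 10.0]).unwrap();
    /// let params = QuantizationParams::symmetric_per_tensor(QuantizationType::Int8, &t);
    /// let q = QuantizedTensor::quantize(&t, params);
    /// let restored = q.dequantize();
    /// // The round-trip error is bounded by half a quantization step per value.
    /// let bound = q.params.quantization_error_bound();
    /// for (orig, deq) in t.iter().zip(restored.iter()) {
    ///     assert!((orig - deq).abs() <= bound + 1e-12);
    /// }
    /// ```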
    pub fn quantize(tensor: &Scirs2Tensor, params: QuantizationParams) -> Self {
        let quantized_data = match params.qtype {
            QuantizationType::Int8 => quantize_int8(tensor, &params),
            QuantizationType::Int4 => quantize_int4(tensor, &params),
            QuantizationType::Fp16 => quantize_fp16(tensor),
            QuantizationType::BFloat16 => quantize_bf16(tensor),
            QuantizationType::None => tensor.clone(),
        };

        Self {
            data: quantized_data,
            params,
        }
    }

    /// Dequantize the tensor back to full precision.
    pub fn dequantize(&self) -> Scirs2Tensor {
        match self.params.qtype {
            QuantizationType::Int8 | QuantizationType::Int4 => {
                dequantize_integer(&self.data, &self.params)
            }
            QuantizationType::Fp16 | QuantizationType::BFloat16 => {
                // Already stored as f64 at reduced precision, just return
                self.data.clone()
            }
            QuantizationType::None => self.data.clone(),
        }
    }

    /// Get the memory size reduction ratio.
    pub fn memory_reduction(&self) -> f64 {
        self.params.qtype.compression_ratio()
    }

    /// Calculate the quantization error (MSE).
    pub fn quantization_error(&self, original: &Scirs2Tensor) -> f64 {
        let dequantized = self.dequantize();
        let diff = &dequantized - original;
        let squared_error: f64 = diff.iter().map(|&x| x * x).sum();
        squared_error / original.len() as f64
    }
}

/// Quantize tensor to INT8.
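///
/// For example, with symmetric scale `10/127` an input of `5.0` maps to
/// `round(5.0 * 127 / 10) = 64`, which dequantizes back to roughly `5.04`.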
fn quantize_int8(tensor: &Scirs2Tensor, params: &QuantizationParams) -> Scirs2Tensor {
    let scale = params.scale[0];
    let zero_point = params.zero_point[0] as f64;
    // The clamp range must match the convention used for the zero point:
    // symmetric parameters target signed int8, while asymmetric parameters
    // place the zero point in the unsigned range [0, 255].
    let (lo, hi) = match params.scheme {
        QuantizationScheme::Symmetric => (-128.0, 127.0),
        QuantizationScheme::Asymmetric => (0.0, 255.0),
    };

    tensor.mapv(|x| ((x / scale).round() + zero_point).clamp(lo, hi))
}

/// Quantize tensor to INT4.
fn quantize_int4(tensor: &Scirs2Tensor, params: &QuantizationParams) -> Scirs2Tensor {
    let scale = params.scale[0];
    let zero_point = params.zero_point[0] as f64;
    let (lo, hi) = match params.scheme {
        QuantizationScheme::Symmetric => (-8.0, 7.0),
        QuantizationScheme::Asymmetric => (0.0, 15.0),
    };

    tensor.mapv(|x| ((x / scale).round() + zero_point).clamp(lo, hi))
}

/// Round `x` to `bits` explicit mantissa bits, simulating a reduced-precision
/// float while keeping f64 storage.
fn round_mantissa(x: f64, bits: i32) -> f64 {
    if x == 0.0 || !x.is_finite() {
        return x;
    }
    let exp = x.abs().log2().floor() as i32;
    let scale = (2.0f64).powi(bits - exp);
    (x * scale).round() / scale
}

/// Simulate FP16 quantization (round to FP16 precision and range).
fn quantize_fp16(tensor: &Scirs2Tensor) -> Scirs2Tensor {
    // FP16 has 10 explicit mantissa bits vs FP64's 52, and a maximum
    // finite value of 65504.
    tensor.mapv(|x| round_mantissa(x, 10).clamp(-65504.0, 65504.0))
}

/// Simulate BFloat16 quantization.
fn quantize_bf16(tensor: &Scirs2Tensor) -> Scirs2Tensor {
    // BF16 has 7 explicit mantissa bits vs FP64's 52; its exponent range
    // matches FP32, so no extra clamping is applied here.
    tensor.mapv(|x| round_mantissa(x, 7))
}

/// Dequantize an integer-quantized tensor.
fn dequantize_integer(tensor: &Scirs2Tensor, params: &QuantizationParams) -> Scirs2Tensor {
    let scale = params.scale[0];
    let zero_point = params.zero_point[0];

    tensor.mapv(|q| (q - zero_point as f64) * scale)
}

/// Quantization-aware training configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QatConfig {
    /// Target quantization type
    pub target_qtype: QuantizationType,

    /// Quantization scheme
    pub scheme: QuantizationScheme,

    /// Number of warmup epochs before enabling quantization
    pub warmup_epochs: usize,

    /// Whether to use the straight-through estimator (STE) for gradients
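    ///
    /// With the STE, the non-differentiable rounding step is treated as the
    /// identity in the backward pass: gradients flow through the fake-quantize
    /// operation as if it were `y = x`.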
    pub use_ste: bool,

    /// Whether to learn scale and zero-point parameters
    pub learnable_params: bool,
}

impl Default for QatConfig {
    fn default() -> Self {
        Self {
            target_qtype: QuantizationType::Int8,
            scheme: QuantizationScheme::Symmetric,
            warmup_epochs: 2,
            use_ste: true,
            learnable_params: false,
        }
    }
}

/// Quantization statistics for analysis.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationStats {
    /// Number of quantized tensors
    pub num_tensors: usize,

    /// Total memory saved (in bytes)
    pub memory_saved: u64,

    /// Average quantization error (MSE)
    pub avg_error: f64,

    /// Maximum quantization error
    pub max_error: f64,

    /// Distribution of quantization types used
    pub type_distribution: Vec<(QuantizationType, usize)>,
}

impl QuantizationStats {
    /// Create empty statistics.
    pub fn new() -> Self {
        Self {
            num_tensors: 0,
            memory_saved: 0,
            avg_error: 0.0,
            max_error: 0.0,
            type_distribution: Vec::new(),
        }
    }

    /// Update statistics with a new quantized tensor.
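    ///
    /// # Example
    ///
    /// An illustrative calculation: quantizing a 1000-byte tensor at an 8x
    /// compression ratio saves `1000 * (1 - 1/8) = 875` bytes.
    ///
    /// ```ignore
    /// let mut stats = QuantizationStats::new();
    /// stats.update(1000, 8.0, 0.01);
    /// assert_eq!(stats.memory_saved, 875);
    /// assert_eq!(stats.avg_error, 0.01);
    /// ```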
    pub fn update(&mut self, original_size: u64, compression_ratio: f64, error: f64) {
        self.num_tensors += 1;
        self.memory_saved += (original_size as f64 * (1.0 - 1.0 / compression_ratio)) as u64;

        // Update the running average error
        let n = self.num_tensors as f64;
        self.avg_error = (self.avg_error * (n - 1.0) + error) / n;
        self.max_error = self.max_error.max(error);
    }

    /// Get memory reduction as a percentage of the given total memory.
    pub fn memory_reduction_pct(&self, total_memory: u64) -> f64 {
        if total_memory == 0 {
            0.0
        } else {
            (self.memory_saved as f64 / total_memory as f64) * 100.0
        }
    }
}

impl Default for QuantizationStats {
    fn default() -> Self {
        Self::new()
    }
}
/// Calibrate quantization parameters using sample data.
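///
/// # Example
///
/// An illustrative sketch, assuming `Scirs2Tensor` is an `ArrayD<f64>`:
///
/// ```ignore
/// use scirs2_core::ndarray::ArrayD;
///
/// let a = ArrayD::from_shape_vec(vec![3], vec![-10.0, 0.0, 10.0]).unwrap();
/// let b = ArrayD::from_shape_vec(vec![3], vec![-8.0, 2.0, 12.0]).unwrap();
/// let params = calibrate_quantization(
///     &[a, b],
///     QuantizationType::Int8,
///     QuantizationScheme::Symmetric,
/// )
/// .unwrap();
/// // Symmetric calibration uses the global |max| across all samples (12.0).
/// assert!((params.scale[0] - 12.0 / 127.0).abs() < 1e-12);
/// ```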
pub fn calibrate_quantization(
    samples: &[Scirs2Tensor],
    qtype: QuantizationType,
    scheme: QuantizationScheme,
) -> TlBackendResult<QuantizationParams> {
    if samples.is_empty() {
        return Err(TlBackendError::GraphError(
            "Cannot calibrate with empty samples".to_string(),
        ));
    }

    // Collect statistics across all samples
    let mut global_min = f64::INFINITY;
    let mut global_max = f64::NEG_INFINITY;
    let mut global_abs_max = 0.0f64;

    for sample in samples {
        let sample_min = sample.iter().fold(f64::INFINITY, |a, &b| a.min(b));
        let sample_max = sample.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
        let sample_abs_max = sample.iter().map(|&x| x.abs()).fold(0.0f64, f64::max);

        global_min = global_min.min(sample_min);
        global_max = global_max.max(sample_max);
        global_abs_max = global_abs_max.max(sample_abs_max);
    }

    let params = match scheme {
        QuantizationScheme::Symmetric => {
            let scale = match qtype {
                QuantizationType::Int8 => global_abs_max / 127.0,
                QuantizationType::Int4 => global_abs_max / 7.0,
                _ => 1.0,
            };
            // Guard against all-zero samples (see `symmetric_per_tensor`).
            let scale = if scale == 0.0 { 1.0 } else { scale };

            QuantizationParams {
                qtype,
                scheme,
                granularity: QuantizationGranularity::PerTensor,
                scale: vec![scale],
                zero_point: vec![0],
                min_val: vec![-global_abs_max],
                max_val: vec![global_abs_max],
            }
        }
        QuantizationScheme::Asymmetric => {
            let (scale, zero_point) = match qtype {
                QuantizationType::Int8 => {
                    let scale = (global_max - global_min) / 255.0;
                    // Guard against constant samples (see `asymmetric_per_tensor`).
                    if scale == 0.0 {
                        (1.0, 0)
                    } else {
                        (scale, (-global_min / scale).round() as i32)
                    }
                }
                QuantizationType::Int4 => {
                    let scale = (global_max - global_min) / 15.0;
                    if scale == 0.0 {
                        (1.0, 0)
                    } else {
                        (scale, (-global_min / scale).round() as i32)
                    }
                }
                _ => (1.0, 0),
            };

            QuantizationParams {
                qtype,
                scheme,
                granularity: QuantizationGranularity::PerTensor,
                scale: vec![scale],
                zero_point: vec![zero_point],
                min_val: vec![global_min],
                max_val: vec![global_max],
            }
        }
    };

    Ok(params)
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::ArrayD;

    #[test]
    fn test_quantization_type_properties() {
        assert_eq!(QuantizationType::Int8.bits(), 8);
        assert_eq!(QuantizationType::Int4.bits(), 4);
        assert_eq!(QuantizationType::Fp16.bits(), 16);
        assert_eq!(QuantizationType::BFloat16.bits(), 16);

        assert_eq!(QuantizationType::Int8.compression_ratio(), 8.0);
        assert_eq!(QuantizationType::Int4.compression_ratio(), 16.0);

        assert!(QuantizationType::Int8.is_integer());
        assert!(QuantizationType::Fp16.is_float());
    }

    #[test]
    fn test_symmetric_quantization_int8() {
        let data = vec![-10.0, -5.0, 0.0, 5.0, 10.0];
        let tensor = ArrayD::from_shape_vec(vec![5], data).unwrap();

        let params = QuantizationParams::symmetric_per_tensor(QuantizationType::Int8, &tensor);

        assert_eq!(params.scheme, QuantizationScheme::Symmetric);
        assert_eq!(params.zero_point[0], 0);
        assert!(params.scale[0] > 0.0);
    }

    #[test]
    fn test_asymmetric_quantization_int8() {
        let data = vec![0.0, 2.0, 4.0, 6.0, 8.0];
        let tensor = ArrayD::from_shape_vec(vec![5], data).unwrap();

        let params = QuantizationParams::asymmetric_per_tensor(QuantizationType::Int8, &tensor);

        assert_eq!(params.scheme, QuantizationScheme::Asymmetric);
        assert!(params.zero_point[0] >= 0);
        assert!(params.scale[0] > 0.0);
    }

    #[test]
    fn test_quantize_dequantize_int8() {
        let data = vec![-10.0, -5.0, 0.0, 5.0, 10.0];
        let tensor = ArrayD::from_shape_vec(vec![5], data).unwrap();

        let params = QuantizationParams::symmetric_per_tensor(QuantizationType::Int8, &tensor);
        let quantized = QuantizedTensor::quantize(&tensor, params);
        let dequantized = quantized.dequantize();

        // Check that dequantized values are close to the originals
        for (orig, deq) in tensor.iter().zip(dequantized.iter()) {
            assert!(
                (orig - deq).abs() < 0.1,
                "Original: {}, Dequantized: {}",
                orig,
                deq
            );
        }
    }

    #[test]
    fn test_quantization_error() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let tensor = ArrayD::from_shape_vec(vec![5], data).unwrap();

        let params = QuantizationParams::symmetric_per_tensor(QuantizationType::Int8, &tensor);
        let quantized = QuantizedTensor::quantize(&tensor, params);

        let error = quantized.quantization_error(&tensor);
        assert!(error >= 0.0);
        assert!(error < 1.0); // Error should be small for this simple case
    }

    #[test]
    fn test_memory_reduction() {
        let tensor = ArrayD::from_shape_vec(vec![100], vec![1.0; 100]).unwrap();
        let params = QuantizationParams::symmetric_per_tensor(QuantizationType::Int8, &tensor);
        let quantized = QuantizedTensor::quantize(&tensor, params);

        assert_eq!(quantized.memory_reduction(), 8.0); // 64-bit to 8-bit = 8x compression
    }

    #[test]
    fn test_calibrate_quantization() {
        let sample1 = ArrayD::from_shape_vec(vec![3], vec![-10.0, 0.0, 10.0]).unwrap();
        let sample2 = ArrayD::from_shape_vec(vec![3], vec![-8.0, 2.0, 12.0]).unwrap();
        let samples = vec![sample1, sample2];

        let params = calibrate_quantization(
            &samples,
            QuantizationType::Int8,
            QuantizationScheme::Symmetric,
        )
        .unwrap();

        assert!(params.scale[0] > 0.0);
        assert_eq!(params.zero_point[0], 0); // Symmetric
    }

    #[test]
    fn test_quantization_stats() {
        let mut stats = QuantizationStats::new();

        stats.update(1000, 8.0, 0.01);
        stats.update(2000, 8.0, 0.02);

        assert_eq!(stats.num_tensors, 2);
        assert!(stats.memory_saved > 0);
        assert!(stats.avg_error > 0.0);
        assert_eq!(stats.max_error, 0.02);
    }

    #[test]
    fn test_fp16_quantization() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let tensor = ArrayD::from_shape_vec(vec![5], data).unwrap();

        let quantized = quantize_fp16(&tensor);

        // These values are exactly representable in FP16, so the error
        // should be negligible
        for (orig, quant) in tensor.iter().zip(quantized.iter()) {
            assert!((orig - quant).abs() < 0.001);
        }
    }

    #[test]
    fn test_qat_config_default() {
        let config = QatConfig::default();

        assert_eq!(config.target_qtype, QuantizationType::Int8);
        assert_eq!(config.scheme, QuantizationScheme::Symmetric);
        assert!(config.use_ste);
    }
}