tensorlogic_infer/quantization.rs

//! Advanced quantization support for model compression and acceleration.
//!
//! This module provides comprehensive quantization capabilities including:
//! - Multiple quantization schemes (INT8, INT4, FP8, binary)
//! - Quantization-Aware Training (QAT) support
//! - Post-Training Quantization (PTQ) with calibration
//! - Per-channel and per-tensor quantization
//! - Symmetric and asymmetric quantization
//! - Dynamic and static quantization modes
//! - Quantization simulation for accuracy validation
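//!
//! A minimal end-to-end sketch of the PTQ flow (illustrative values; the
//! module path is assumed to be `tensorlogic_infer::quantization`):
//!
//! ```
//! use tensorlogic_infer::quantization::{NodeId, Quantizer};
//!
//! let mut quantizer = Quantizer::int8();
//! // Feed observed activation ranges until enough samples are collected
//! // (the default config requires 100 calibration samples).
//! for _ in 0..100 {
//!     quantizer.calibrate(NodeId(0), -6.0, 6.0);
//! }
//! quantizer.finalize_calibration().unwrap();
//! assert!(quantizer.get_params(NodeId(0)).is_some());
//! ```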

use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tensorlogic_ir::OpType;
use thiserror::Error;

/// Node identifier (0-based index into graph.nodes).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub struct NodeId(pub usize);

/// Quantization-related errors.
#[derive(Error, Debug, Clone, PartialEq)]
pub enum QuantizationError {
    #[error("Unsupported data type for quantization: {0}")]
    UnsupportedDataType(String),

    #[error("Invalid quantization range: min={min}, max={max}")]
    InvalidRange { min: f64, max: f64 },

    #[error("Calibration failed: {0}")]
    CalibrationFailed(String),

    #[error("Quantization not supported for operation: {0:?}")]
    UnsupportedOperation(OpType),

    #[error("Invalid quantization parameters: {0}")]
    InvalidParameters(String),

    #[error("Insufficient calibration data")]
    InsufficientData,
}

/// Quantization data types.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum QuantizationType {
    /// 8-bit integer quantization
    Int8,
    /// 4-bit integer quantization
    Int4,
    /// 2-bit integer quantization (extreme compression)
    Int2,
    /// 8-bit floating point, E4M3 format (4 exponent bits, 3 mantissa bits)
    FP8E4M3,
    /// 8-bit floating point, E5M2 format (5 exponent bits, 2 mantissa bits)
    FP8E5M2,
    /// 16-bit floating point
    FP16,
    /// 16-bit brain float
    BF16,
    /// Binary quantization (1-bit)
    Binary,
    /// Ternary quantization (-1, 0, 1)
    Ternary,
}

impl QuantizationType {
    /// Returns the number of bits used by this quantization type.
    pub fn bits(&self) -> u32 {
        match self {
            Self::Binary => 1,
            Self::Int2 => 2,
            Self::Int4 => 4,
            Self::Int8 | Self::FP8E4M3 | Self::FP8E5M2 => 8,
            Self::FP16 | Self::BF16 => 16,
            Self::Ternary => 2, // effectively 1.58 bits, but stored as 2
        }
    }

    /// Returns the theoretical compression ratio vs FP32.
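    ///
    /// A quick sanity check (the ratio is simply 32 / bits):
    ///
    /// ```
    /// use tensorlogic_infer::quantization::QuantizationType;
    ///
    /// assert_eq!(QuantizationType::Int8.compression_ratio(), 4.0);
    /// assert_eq!(QuantizationType::Int4.compression_ratio(), 8.0);
    /// ```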
    pub fn compression_ratio(&self) -> f64 {
        32.0 / self.bits() as f64
    }

    /// Returns whether this type supports floating point values.
    pub fn is_floating_point(&self) -> bool {
        matches!(
            self,
            Self::FP8E4M3 | Self::FP8E5M2 | Self::FP16 | Self::BF16
        )
    }
}

/// Quantization granularity.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationGranularity {
    /// Quantize entire tensor with single scale/zero-point
    PerTensor,
    /// Quantize each channel independently
    PerChannel { axis: usize },
    /// Quantize groups of channels
    PerGroup { axis: usize, group_size: usize },
}

/// Quantization symmetry mode.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationSymmetry {
    /// Symmetric quantization (zero_point = 0)
    Symmetric,
    /// Asymmetric quantization (arbitrary zero_point)
    Asymmetric,
}

/// Quantization mode.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationMode {
    /// Static quantization (pre-computed scales)
    Static,
    /// Dynamic quantization (compute scales at runtime)
    Dynamic,
    /// Quantization-aware training simulation
    QAT,
}

/// Calibration strategy for post-training quantization.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CalibrationStrategy {
    /// Use min/max values observed during calibration
    MinMax,
    /// Use integer percentiles to clip outliers (e.g., lower = 1, upper = 99)
    Percentile { lower: u32, upper: u32 },
    /// Minimize mean squared error
    MSE,
    /// Minimize KL divergence between distributions
    KLDivergence,
    /// Entropy-based calibration
    Entropy,
}

/// Quantization parameters for a tensor.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct QuantizationParams {
    /// Quantization type
    pub qtype: QuantizationType,
    /// Scale factor(s)
    pub scale: Vec<f64>,
    /// Zero point(s)
    pub zero_point: Vec<i32>,
    /// Granularity
    pub granularity: QuantizationGranularity,
    /// Symmetry mode
    pub symmetry: QuantizationSymmetry,
    /// Observed min/max during calibration (for validation)
    pub observed_min: Option<f64>,
    pub observed_max: Option<f64>,
}

impl QuantizationParams {
    /// Create symmetric per-tensor quantization parameters.
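    ///
    /// Sketch with illustrative values: `abs_max` maps to the integer
    /// maximum, so `scale = abs_max / qmax` and the zero point is 0:
    ///
    /// ```
    /// use tensorlogic_infer::quantization::{QuantizationParams, QuantizationType};
    ///
    /// let p = QuantizationParams::symmetric_per_tensor(QuantizationType::Int8, 2.54).unwrap();
    /// assert!((p.scale[0] - 2.54 / 127.0).abs() < 1e-12);
    /// assert_eq!(p.zero_point[0], 0);
    /// ```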
    pub fn symmetric_per_tensor(
        qtype: QuantizationType,
        abs_max: f64,
    ) -> Result<Self, QuantizationError> {
        if abs_max <= 0.0 {
            return Err(QuantizationError::InvalidRange {
                min: -abs_max,
                max: abs_max,
            });
        }

        let qmax = match qtype {
            QuantizationType::Int8 => 127.0,
            QuantizationType::Int4 => 7.0,
            QuantizationType::Int2 => 1.0,
            QuantizationType::Binary => 1.0,
            QuantizationType::Ternary => 1.0,
            _ => {
                return Err(QuantizationError::UnsupportedDataType(format!(
                    "{:?}",
                    qtype
                )))
            }
        };

        let scale = abs_max / qmax;

        Ok(Self {
            qtype,
            scale: vec![scale],
            zero_point: vec![0],
            granularity: QuantizationGranularity::PerTensor,
            symmetry: QuantizationSymmetry::Symmetric,
            observed_min: Some(-abs_max),
            observed_max: Some(abs_max),
        })
    }

    /// Create asymmetric per-tensor quantization parameters.
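    ///
    /// Worked example for Int8 over [-10, 20]: `scale = 30 / 255`, and the
    /// zero point is chosen so that `min` maps exactly onto `qmin`:
    ///
    /// ```
    /// use tensorlogic_infer::quantization::{QuantizationParams, QuantizationType};
    ///
    /// let p = QuantizationParams::asymmetric_per_tensor(QuantizationType::Int8, -10.0, 20.0)
    ///     .unwrap();
    /// assert!((p.scale[0] - 30.0 / 255.0).abs() < 1e-12);
    /// assert_eq!(p.zero_point[0], -43); // round(-128 - (-10.0 / scale))
    /// ```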
    pub fn asymmetric_per_tensor(
        qtype: QuantizationType,
        min: f64,
        max: f64,
    ) -> Result<Self, QuantizationError> {
        if min >= max {
            return Err(QuantizationError::InvalidRange { min, max });
        }

        let (qmin, qmax) = match qtype {
            QuantizationType::Int8 => (-128.0, 127.0),
            QuantizationType::Int4 => (-8.0, 7.0),
            QuantizationType::Int2 => (-2.0, 1.0),
            _ => {
                return Err(QuantizationError::UnsupportedDataType(format!(
                    "{:?}",
                    qtype
                )))
            }
        };

        let scale = (max - min) / (qmax - qmin);
        // Choose the zero point so that `min` dequantizes back to `min`:
        // z = qmin - min / scale.
        let zero_point = (qmin - min / scale).round() as i32;

        Ok(Self {
            qtype,
            scale: vec![scale],
            zero_point: vec![zero_point],
            granularity: QuantizationGranularity::PerTensor,
            symmetry: QuantizationSymmetry::Asymmetric,
            observed_min: Some(min),
            observed_max: Some(max),
        })
    }

    /// Quantize a floating-point value to integer.
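    ///
    /// Round-trip sketch: with scale 1.0 in-range integers map exactly, and
    /// out-of-range values are clamped to the integer limits:
    ///
    /// ```
    /// use tensorlogic_infer::quantization::{QuantizationParams, QuantizationType};
    ///
    /// let p = QuantizationParams::symmetric_per_tensor(QuantizationType::Int8, 127.0).unwrap();
    /// assert_eq!(p.quantize(42.4), 42);   // round(42.4 / 1.0) + 0
    /// assert_eq!(p.quantize(200.0), 127); // clamped to qmax
    /// assert_eq!(p.dequantize(42), 42.0);
    /// ```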
    pub fn quantize(&self, value: f64) -> i32 {
        // Per-tensor fast path: only the first scale/zero-point is used here.
        let scale = self.scale[0];
        let zero_point = self.zero_point[0];
        ((value / scale).round() as i32 + zero_point).clamp(self.qmin(), self.qmax())
    }

    /// Dequantize an integer value to floating-point.
    pub fn dequantize(&self, qvalue: i32) -> f64 {
        let scale = self.scale[0];
        let zero_point = self.zero_point[0];
        (qvalue - zero_point) as f64 * scale
    }

    /// Get quantization minimum value.
    fn qmin(&self) -> i32 {
        match self.qtype {
            QuantizationType::Int8 => -128,
            QuantizationType::Int4 => -8,
            QuantizationType::Int2 => -2,
            QuantizationType::Binary => 0,
            QuantizationType::Ternary => -1,
            _ => 0, // fallback for floating-point types, which have no integer range
        }
    }

    /// Get quantization maximum value.
    fn qmax(&self) -> i32 {
        match self.qtype {
            QuantizationType::Int8 => 127,
            QuantizationType::Int4 => 7,
            QuantizationType::Int2 => 1,
            QuantizationType::Binary => 1,
            QuantizationType::Ternary => 1,
            _ => 255, // fallback for floating-point types, which have no integer range
        }
    }
}

/// Quantization configuration for a graph or model.
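///
/// A builder sketch using the presets and combinators defined below:
///
/// ```
/// use tensorlogic_infer::quantization::{
///     CalibrationStrategy, QuantizationConfig, QuantizationSymmetry,
/// };
///
/// let config = QuantizationConfig::int4()
///     .per_channel(0)
///     .asymmetric()
///     .with_calibration(CalibrationStrategy::MSE);
/// assert_eq!(config.symmetry, QuantizationSymmetry::Asymmetric);
/// ```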
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationConfig {
    /// Default quantization type
    pub default_qtype: QuantizationType,
    /// Quantization mode
    pub mode: QuantizationMode,
    /// Granularity
    pub granularity: QuantizationGranularity,
    /// Symmetry mode
    pub symmetry: QuantizationSymmetry,
    /// Calibration strategy (for PTQ)
    pub calibration: CalibrationStrategy,
    /// Number of calibration samples
    pub calibration_samples: usize,
    /// Operations for which quantization is skipped
    pub skip_ops: Vec<OpType>,
    /// Per-node quantization overrides
    pub node_overrides: HashMap<NodeId, QuantizationType>,
}

impl Default for QuantizationConfig {
    fn default() -> Self {
        Self {
            default_qtype: QuantizationType::Int8,
            mode: QuantizationMode::Static,
            granularity: QuantizationGranularity::PerTensor,
            symmetry: QuantizationSymmetry::Symmetric,
            calibration: CalibrationStrategy::MinMax,
            calibration_samples: 100,
            skip_ops: vec![],
            node_overrides: HashMap::new(),
        }
    }
}

impl QuantizationConfig {
    /// Create a configuration for int8 quantization.
    pub fn int8() -> Self {
        Self {
            default_qtype: QuantizationType::Int8,
            ..Default::default()
        }
    }

    /// Create a configuration for int4 quantization.
    pub fn int4() -> Self {
        Self {
            default_qtype: QuantizationType::Int4,
            ..Default::default()
        }
    }

    /// Create a configuration for FP8 quantization.
    pub fn fp8() -> Self {
        Self {
            default_qtype: QuantizationType::FP8E4M3,
            symmetry: QuantizationSymmetry::Symmetric,
            ..Default::default()
        }
    }

    /// Create a configuration for quantization-aware training.
    pub fn qat(qtype: QuantizationType) -> Self {
        Self {
            default_qtype: qtype,
            mode: QuantizationMode::QAT,
            ..Default::default()
        }
    }

    /// Enable per-channel quantization.
    pub fn per_channel(mut self, axis: usize) -> Self {
        self.granularity = QuantizationGranularity::PerChannel { axis };
        self
    }

    /// Enable asymmetric quantization.
    pub fn asymmetric(mut self) -> Self {
        self.symmetry = QuantizationSymmetry::Asymmetric;
        self
    }

    /// Set calibration strategy.
    pub fn with_calibration(mut self, strategy: CalibrationStrategy) -> Self {
        self.calibration = strategy;
        self
    }
}

/// Statistics collected during calibration.
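///
/// Running-range sketch: repeated updates only ever widen the tracked
/// interval for a node:
///
/// ```
/// use tensorlogic_infer::quantization::{CalibrationStats, NodeId};
///
/// let mut stats = CalibrationStats::new();
/// stats.update(NodeId(0), -5.0, 10.0);
/// stats.update(NodeId(0), -8.0, 12.0);
/// assert_eq!(stats.min_values[&NodeId(0)], -8.0);
/// assert_eq!(stats.max_values[&NodeId(0)], 12.0);
/// ```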
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CalibrationStats {
    /// Minimum values observed per node
    pub min_values: HashMap<NodeId, f64>,
    /// Maximum values observed per node
    pub max_values: HashMap<NodeId, f64>,
    /// Histogram bins for distribution analysis
    pub histograms: HashMap<NodeId, Vec<u32>>,
    /// Number of samples observed
    pub num_samples: usize,
}

impl CalibrationStats {
    /// Create new calibration statistics.
    pub fn new() -> Self {
        Self::default()
    }

    /// Update statistics with a new observation.
    pub fn update(&mut self, node_id: NodeId, min: f64, max: f64) {
        self.min_values
            .entry(node_id)
            .and_modify(|v| *v = v.min(min))
            .or_insert(min);
        self.max_values
            .entry(node_id)
            .and_modify(|v| *v = v.max(max))
            .or_insert(max);
        self.num_samples += 1;
    }

    /// Get computed quantization parameters for a node.
    ///
    /// Note: this currently derives per-tensor parameters from the observed
    /// min/max range, regardless of `config.granularity` or the selected
    /// calibration strategy.
    pub fn compute_params(
        &self,
        node_id: NodeId,
        config: &QuantizationConfig,
    ) -> Result<QuantizationParams, QuantizationError> {
        let min = self
            .min_values
            .get(&node_id)
            .ok_or(QuantizationError::InsufficientData)?;
        let max = self
            .max_values
            .get(&node_id)
            .ok_or(QuantizationError::InsufficientData)?;

        let qtype = config
            .node_overrides
            .get(&node_id)
            .copied()
            .unwrap_or(config.default_qtype);

        match config.symmetry {
            QuantizationSymmetry::Symmetric => {
                let abs_max = min.abs().max(max.abs());
                QuantizationParams::symmetric_per_tensor(qtype, abs_max)
            }
            QuantizationSymmetry::Asymmetric => {
                QuantizationParams::asymmetric_per_tensor(qtype, *min, *max)
            }
        }
    }
}

/// Quantizer for converting graphs to quantized representations.
pub struct Quantizer {
    config: QuantizationConfig,
    stats: CalibrationStats,
    params: HashMap<NodeId, QuantizationParams>,
}

impl Quantizer {
    /// Create a new quantizer with the given configuration.
    pub fn new(config: QuantizationConfig) -> Self {
        Self {
            config,
            stats: CalibrationStats::new(),
            params: HashMap::new(),
        }
    }

    /// Create a quantizer for int8 quantization.
    pub fn int8() -> Self {
        Self::new(QuantizationConfig::int8())
    }

    /// Create a quantizer for int4 quantization.
    pub fn int4() -> Self {
        Self::new(QuantizationConfig::int4())
    }

    /// Get the configuration.
    pub fn config(&self) -> &QuantizationConfig {
        &self.config
    }

    /// Get calibration statistics.
    pub fn stats(&self) -> &CalibrationStats {
        &self.stats
    }

    /// Get quantization parameters for a node.
    pub fn get_params(&self, node_id: NodeId) -> Option<&QuantizationParams> {
        self.params.get(&node_id)
    }

    /// Add calibration data for a node.
    pub fn calibrate(&mut self, node_id: NodeId, min: f64, max: f64) {
        self.stats.update(node_id, min, max);
    }

    /// Finalize calibration and compute quantization parameters.
    pub fn finalize_calibration(&mut self) -> Result<(), QuantizationError> {
        if self.stats.num_samples < self.config.calibration_samples {
            return Err(QuantizationError::InsufficientData);
        }

        // Compute params for all calibrated nodes
        for &node_id in self.stats.min_values.keys() {
            let params = self.stats.compute_params(node_id, &self.config)?;
            self.params.insert(node_id, params);
        }

        Ok(())
    }

    /// Get quantization summary statistics.
    pub fn summary(&self) -> QuantizationSummary {
        let mut type_counts = HashMap::new();
        for params in self.params.values() {
            *type_counts.entry(params.qtype).or_insert(0) += 1;
        }

        let total_params = self.params.len();
        let avg_compression = self
            .params
            .values()
            .map(|p| p.qtype.compression_ratio())
            .sum::<f64>()
            / total_params.max(1) as f64;

        QuantizationSummary {
            num_quantized_nodes: total_params,
            type_distribution: type_counts,
            avg_compression_ratio: avg_compression,
            calibration_samples: self.stats.num_samples,
        }
    }
}

/// Summary of quantization results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationSummary {
    /// Number of quantized nodes
    pub num_quantized_nodes: usize,
    /// Distribution of quantization types
    pub type_distribution: HashMap<QuantizationType, usize>,
    /// Average compression ratio
    pub avg_compression_ratio: f64,
    /// Number of calibration samples used
    pub calibration_samples: usize,
}

impl QuantizationSummary {
    /// Get estimated memory savings.
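    ///
    /// Worked example: a uniformly INT8 model has ratio 4.0, so the estimated
    /// savings are (1 - 1/4) * 100 = 75%:
    ///
    /// ```
    /// use std::collections::HashMap;
    /// use tensorlogic_infer::quantization::QuantizationSummary;
    ///
    /// let summary = QuantizationSummary {
    ///     num_quantized_nodes: 1,
    ///     type_distribution: HashMap::new(),
    ///     avg_compression_ratio: 4.0,
    ///     calibration_samples: 100,
    /// };
    /// assert_eq!(summary.memory_savings(), 75.0);
    /// ```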
    pub fn memory_savings(&self) -> f64 {
        if self.avg_compression_ratio > 1.0 {
            (1.0 - 1.0 / self.avg_compression_ratio) * 100.0
        } else {
            0.0
        }
    }
}

/// Fake quantization for QAT (simulates quantization during training).
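///
/// QAT sketch: the forward pass quantizes and immediately dequantizes, so
/// downstream computation sees realistic quantization noise (straight-through
/// gradient handling is outside the scope of this module):
///
/// ```
/// use tensorlogic_infer::quantization::{FakeQuantize, QuantizationParams, QuantizationType};
///
/// let params = QuantizationParams::symmetric_per_tensor(QuantizationType::Int8, 10.0).unwrap();
/// let fq = FakeQuantize::new(params);
/// let y = fq.forward(3.5);
/// // In-range values land within half a quantization step of the input.
/// assert!((y - 3.5).abs() <= 5.0 / 127.0);
/// ```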
pub struct FakeQuantize {
    params: QuantizationParams,
    enabled: bool,
}

impl FakeQuantize {
    /// Create a new fake quantization module.
    pub fn new(params: QuantizationParams) -> Self {
        Self {
            params,
            enabled: true,
        }
    }

    /// Enable or disable fake quantization.
    pub fn set_enabled(&mut self, enabled: bool) {
        self.enabled = enabled;
    }

    /// Apply fake quantization to a value.
    pub fn forward(&self, value: f64) -> f64 {
        if !self.enabled {
            return value;
        }

        // Quantize then dequantize (simulating quantization noise)
        let qvalue = self.params.quantize(value);
        self.params.dequantize(qvalue)
    }

    /// Simulate quantization on a batch of values.
    pub fn forward_batch(&self, values: &[f64]) -> Vec<f64> {
        values.iter().map(|&v| self.forward(v)).collect()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantization_type_properties() {
        assert_eq!(QuantizationType::Int8.bits(), 8);
        assert_eq!(QuantizationType::Int4.bits(), 4);
        assert_eq!(QuantizationType::Binary.bits(), 1);
        assert_eq!(QuantizationType::Int8.compression_ratio(), 4.0);
        assert!(QuantizationType::FP16.is_floating_point());
        assert!(!QuantizationType::Int8.is_floating_point());
    }

    #[test]
    fn test_symmetric_quantization() {
        let params =
            QuantizationParams::symmetric_per_tensor(QuantizationType::Int8, 127.0).unwrap();
        assert_eq!(params.scale[0], 1.0);
        assert_eq!(params.zero_point[0], 0);

        // Test quantize/dequantize
        assert_eq!(params.quantize(0.0), 0);
        assert_eq!(params.quantize(127.0), 127);
        assert_eq!(params.quantize(-127.0), -127);
        assert!((params.dequantize(127) - 127.0).abs() < 1e-10);
    }

    #[test]
    fn test_asymmetric_quantization() {
        let params =
            QuantizationParams::asymmetric_per_tensor(QuantizationType::Int8, -10.0, 20.0).unwrap();

        assert!(params.scale[0] > 0.0);
        assert_ne!(params.zero_point[0], 0);

        // Test round-trip
        let original = 5.0;
        let quantized = params.quantize(original);
        let dequantized = params.dequantize(quantized);
        assert!((dequantized - original).abs() < 1.0); // Allow quantization error
    }

    #[test]
    fn test_quantization_config() {
        let config = QuantizationConfig::int8();
        assert_eq!(config.default_qtype, QuantizationType::Int8);

        let config = QuantizationConfig::int4().per_channel(0).asymmetric();
        assert_eq!(config.default_qtype, QuantizationType::Int4);
        assert!(matches!(
            config.granularity,
            QuantizationGranularity::PerChannel { axis: 0 }
        ));
        assert_eq!(config.symmetry, QuantizationSymmetry::Asymmetric);
    }

    #[test]
    fn test_calibration_stats() {
        let mut stats = CalibrationStats::new();
        stats.update(NodeId(0), -5.0, 10.0);
        stats.update(NodeId(0), -8.0, 12.0);

        assert_eq!(stats.min_values[&NodeId(0)], -8.0);
        assert_eq!(stats.max_values[&NodeId(0)], 12.0);
    }

    #[test]
    fn test_quantizer() {
        let mut quantizer = Quantizer::int8();

        // Calibrate
        quantizer.calibrate(NodeId(0), -10.0, 10.0);
        quantizer.calibrate(NodeId(0), -8.0, 12.0);

        // The default config requires 100 calibration samples, so feed
        // enough observations before finalizing.
        for _ in 0..100 {
            quantizer.calibrate(NodeId(0), -10.0, 10.0);
        }

        assert!(quantizer.finalize_calibration().is_ok());
        assert!(quantizer.get_params(NodeId(0)).is_some());

        let summary = quantizer.summary();
        assert_eq!(summary.num_quantized_nodes, 1);
        assert!(summary.avg_compression_ratio > 1.0);
    }

    #[test]
    fn test_fake_quantize() {
        let params =
            QuantizationParams::symmetric_per_tensor(QuantizationType::Int8, 10.0).unwrap();
        let fake_quant = FakeQuantize::new(params);

        let original = 3.5;
        let faked = fake_quant.forward(original);

        // Should be close but not exact due to quantization
        assert!((faked - original).abs() < 1.0);
    }

    #[test]
    fn test_quantization_summary() {
        let mut quantizer = Quantizer::int8();
        for _ in 0..100 {
            quantizer.calibrate(NodeId(0), -10.0, 10.0);
        }
        quantizer.finalize_calibration().unwrap();

        let summary = quantizer.summary();
        assert!(summary.memory_savings() > 0.0);
        assert!(summary.memory_savings() < 100.0);
    }

    #[test]
    fn test_int4_quantization() {
        let params = QuantizationParams::symmetric_per_tensor(QuantizationType::Int4, 7.0).unwrap();

        let value = 5.0;
        let qvalue = params.quantize(value);
        assert!((-8..=7).contains(&qvalue));
    }

    #[test]
    fn test_invalid_range() {
        let result = QuantizationParams::asymmetric_per_tensor(QuantizationType::Int8, 10.0, 5.0);
        assert!(matches!(
            result,
            Err(QuantizationError::InvalidRange { .. })
        ));
    }
}