// torsh_quantization/config.rs

1//! Quantization configuration types and builders
2//!
3//! This module provides the core configuration types for quantization operations,
4//! including quantization schemes, backend configurations, and observer types.
5//!
6//! # Features
7//!
8//! - **Quantization Schemes**: Support for various quantization types (INT8, INT4, Binary, etc.)
9//! - **Backend Configuration**: Multiple backend support (FBGEMM, QNNPACK, Native)
10//! - **Observer Types**: Different calibration observers for optimal quantization parameters
11//! - **Mixed Precision**: Advanced configuration for layer-specific precision
12//! - **Builder Pattern**: Fluent API for configuration construction
13//! - **Validation**: Comprehensive configuration validation
14
#[cfg(feature = "std")]
use std::collections::HashMap;

#[cfg(not(feature = "std"))]
extern crate alloc;

#[cfg(not(feature = "std"))]
use alloc::{collections::BTreeMap as HashMap, string::String};

// `.to_string()` calls below need the `ToString` trait in scope when the
// std prelude is absent (no-std builds only import `String` otherwise).
#[cfg(not(feature = "std"))]
use alloc::string::ToString;

use torsh_core::{
    dtype::DType,
    error::{Result as TorshResult, TorshError},
};
28
/// Quantization scheme: how quantization parameters (scale/zero-point) are
/// laid out across a tensor and how many bits each value receives.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub enum QScheme {
    /// Per-tensor affine quantization (one scale + zero point for the whole tensor)
    PerTensorAffine,
    /// Per-channel affine quantization (scale + zero point per channel)
    PerChannelAffine,
    /// Symmetric per-tensor quantization (zero point fixed at 0)
    PerTensorSymmetric,
    /// Per-channel symmetric quantization (zero point fixed at 0 per channel)
    PerChannelSymmetric,
    /// INT4 quantization (4-bit, per-tensor parameters)
    Int4PerTensor,
    /// INT4 per-channel quantization (4-bit, per-channel parameters)
    Int4PerChannel,
    /// Mixed precision quantization (per-layer dtypes; see `MixedPrecisionConfig`)
    MixedPrecision,
    /// Binary quantization (1-bit)
    Binary,
    /// Ternary quantization (2-bit with -1, 0, 1)
    Ternary,
    /// Group-wise quantization (groups channels and quantizes per-group)
    GroupWise,
}
53
/// Quantization backend types — which kernel library executes quantized ops.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum QuantBackend {
    /// FBGEMM backend (CPU optimized)
    Fbgemm,
    /// QNNPACK backend (mobile optimized)
    Qnnpack,
    /// Native backend (portable fallback, default)
    Native,
    /// XNNPACK backend
    Xnnpack,
}
66
/// Quantization reduction type — whether to shrink the integer range by one
/// bit (some backends need the headroom to avoid accumulator overflow).
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum ReduceRange {
    /// No range reduction
    None,
    /// Reduce range for better accuracy
    Reduce,
}
75
/// Observer types for quantization parameter calculation during calibration.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum ObserverType {
    /// Min-max observer (tracks running min/max of observed values)
    MinMax,
    /// Moving average min-max observer (smoothed by `averaging_constant`)
    MovingAverage,
    /// Histogram observer
    Histogram,
    /// Percentile observer
    Percentile,
    /// KL divergence observer (required by the mixed-precision scheme; see `validate`)
    KLDivergence,
    /// Entropy-based observer (also accepted for mixed precision)
    Entropy,
}
92
/// Quantization configuration
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct QuantConfig {
    /// Storage data type for quantized values (e.g. I8, U8)
    pub dtype: DType,
    /// Parameter layout / bit-width scheme
    pub scheme: QScheme,
    /// When true, quantization is simulated in floating point (QAT fake-quant)
    pub enable_fake_quant: bool,
    /// Calibration observer used to derive scale/zero-point
    pub observer_type: ObserverType,
    /// Kernel backend that will execute quantized ops
    pub backend: QuantBackend,
    /// Optional one-bit range reduction (see `get_qint_range`)
    pub reduce_range: ReduceRange,
    /// Explicit lower integer bound; `None` derives it from dtype/scheme
    pub qint_min: Option<i32>,
    /// Explicit upper integer bound; `None` derives it from dtype/scheme
    pub qint_max: Option<i32>,
    /// Numerical floor used to keep scales away from zero; must be > 0 (validated)
    pub eps: f32,
    /// Moving-average smoothing factor; must lie in (0, 1) (validated)
    pub averaging_constant: f32,
    /// Channel axis for per-channel/group-wise schemes (required by `validate` for those)
    pub ch_axis: Option<usize>,
    /// Group size for group-wise quantization
    pub group_size: Option<usize>,
}
110
111impl Default for QuantConfig {
112    fn default() -> Self {
113        Self {
114            dtype: DType::I8,
115            scheme: QScheme::PerTensorAffine,
116            enable_fake_quant: false,
117            observer_type: ObserverType::MinMax,
118            backend: QuantBackend::Native,
119            reduce_range: ReduceRange::None,
120            qint_min: None,
121            qint_max: None,
122            eps: 1e-8,
123            averaging_constant: 0.01,
124            ch_axis: None,
125            group_size: None,
126        }
127    }
128}
129
/// Mixed precision configuration
///
/// Maps layer-name keys to the precision that layer should be quantized to;
/// layers without an entry fall back to `default_precision`.
// NOTE(review): unlike `QuantConfig`, this type does not derive serde
// Serialize/Deserialize — confirm whether that asymmetry is intentional.
#[derive(Debug, Clone)]
pub struct MixedPrecisionConfig {
    /// Precision for different layer types, keyed by layer name
    pub layer_precision: HashMap<String, DType>,
    /// Default precision for unspecified layers
    pub default_precision: DType,
    /// Sensitivity threshold for precision selection
    /// (presumably compared against a per-layer sensitivity metric — the
    /// consumer is not visible in this file)
    pub sensitivity_threshold: f32,
}
140
141impl Default for MixedPrecisionConfig {
142    fn default() -> Self {
143        let mut layer_precision = HashMap::new();
144        layer_precision.insert("embedding".to_string(), DType::I8);
145        layer_precision.insert("attention".to_string(), DType::F16);
146        layer_precision.insert("output".to_string(), DType::F32);
147
148        Self {
149            layer_precision,
150            default_precision: DType::I8,
151            sensitivity_threshold: 0.1,
152        }
153    }
154}
155
156impl QuantConfig {
157    /// Create a new quantization config with defaults
158    pub fn new() -> Self {
159        Self::default()
160    }
161
162    /// Create config for INT8 quantization
163    pub fn int8() -> Self {
164        Self {
165            dtype: DType::I8,
166            qint_min: Some(-128),
167            qint_max: Some(127),
168            ..Self::default()
169        }
170    }
171
172    /// Create config for INT4 quantization
173    pub fn int4() -> Self {
174        Self {
175            dtype: DType::I8, // Store as I8 but quantize to 4-bit range
176            scheme: QScheme::Int4PerTensor,
177            qint_min: Some(-8),
178            qint_max: Some(7),
179            observer_type: ObserverType::Histogram,
180            ..Self::default()
181        }
182    }
183
184    /// Create config for binary quantization
185    pub fn binary() -> Self {
186        Self {
187            dtype: DType::I8,
188            scheme: QScheme::Binary,
189            qint_min: Some(-1),
190            qint_max: Some(1),
191            observer_type: ObserverType::MinMax,
192            ..Self::default()
193        }
194    }
195
196    /// Create config for ternary quantization
197    pub fn ternary() -> Self {
198        Self {
199            dtype: DType::I8,
200            scheme: QScheme::Ternary,
201            qint_min: Some(-1),
202            qint_max: Some(1),
203            observer_type: ObserverType::MinMax,
204            ..Self::default()
205        }
206    }
207
208    /// Create config for mixed precision
209    pub fn mixed_precision() -> Self {
210        Self {
211            dtype: DType::I8, // Default precision
212            scheme: QScheme::MixedPrecision,
213            observer_type: ObserverType::KLDivergence,
214            ..Self::default()
215        }
216    }
217
218    /// Create config for UINT8 quantization
219    pub fn uint8() -> Self {
220        Self {
221            dtype: DType::U8,
222            qint_min: Some(0),
223            qint_max: Some(255),
224            ..Self::default()
225        }
226    }
227
228    /// Create config for per-channel quantization
229    pub fn per_channel(ch_axis: usize) -> Self {
230        Self {
231            scheme: QScheme::PerChannelAffine,
232            ch_axis: Some(ch_axis),
233            ..Self::default()
234        }
235    }
236
237    /// Create config for group-wise quantization
238    pub fn group_wise(ch_axis: usize, group_size: usize) -> Self {
239        Self {
240            scheme: QScheme::GroupWise,
241            ch_axis: Some(ch_axis),
242            group_size: Some(group_size),
243            observer_type: ObserverType::Histogram,
244            ..Self::default()
245        }
246    }
247
248    /// Create config for QAT (Quantization Aware Training)
249    pub fn qat() -> Self {
250        Self {
251            enable_fake_quant: true,
252            observer_type: ObserverType::MovingAverage,
253            ..Self::default()
254        }
255    }
256
257    /// Create config with specific backend
258    pub fn with_backend(mut self, backend: QuantBackend) -> Self {
259        self.backend = backend;
260        self
261    }
262
263    /// Set observer type
264    pub fn with_observer(mut self, observer_type: ObserverType) -> Self {
265        self.observer_type = observer_type;
266        self
267    }
268
269    /// Set quantization scheme
270    pub fn with_scheme(mut self, scheme: QScheme) -> Self {
271        self.scheme = scheme;
272        if matches!(
273            scheme,
274            QScheme::PerChannelAffine | QScheme::PerChannelSymmetric | QScheme::GroupWise
275        ) && self.ch_axis.is_none()
276        {
277            self.ch_axis = Some(0); // Default channel axis
278        }
279        if matches!(scheme, QScheme::GroupWise) && self.group_size.is_none() {
280            self.group_size = Some(32); // Default group size
281        }
282        self
283    }
284
285    /// Enable/disable fake quantization
286    pub fn with_fake_quant(mut self, enable: bool) -> Self {
287        self.enable_fake_quant = enable;
288        self
289    }
290
291    /// Set reduce range option
292    pub fn with_reduce_range(mut self, reduce_range: ReduceRange) -> Self {
293        self.reduce_range = reduce_range;
294        self
295    }
296
297    /// Set group size for group-wise quantization
298    pub fn with_group_size(mut self, group_size: usize) -> Self {
299        self.group_size = Some(group_size);
300        self
301    }
302
303    /// Get effective quantization range considering scheme and reduce_range
304    pub fn get_qint_range(&self) -> (i32, i32) {
305        let (base_min, base_max) = match self.scheme {
306            QScheme::Int4PerTensor | QScheme::Int4PerChannel => (-8, 7),
307            QScheme::Binary => (-1, 1),
308            QScheme::Ternary => (-1, 1),
309            _ => match self.dtype {
310                DType::I8 => (-128, 127),
311                DType::U8 => (0, 255),
312                _ => (self.qint_min.unwrap_or(-128), self.qint_max.unwrap_or(127)),
313            },
314        };
315
316        let (qmin, qmax) = match self.reduce_range {
317            ReduceRange::None => (base_min, base_max),
318            ReduceRange::Reduce => {
319                // Reduce range by 1 bit for better accuracy
320                let range = base_max - base_min;
321                let reduced_range = range / 2;
322                let mid = (base_min + base_max) / 2;
323                (mid - reduced_range / 2, mid + reduced_range / 2)
324            }
325        };
326
327        (qmin, qmax)
328    }
329
330    /// Validate configuration
331    pub fn validate(&self) -> TorshResult<()> {
332        // Check if per-channel scheme has channel axis
333        if matches!(
334            self.scheme,
335            QScheme::PerChannelAffine | QScheme::PerChannelSymmetric | QScheme::GroupWise
336        ) && self.ch_axis.is_none()
337        {
338            return Err(TorshError::InvalidArgument(
339                "Per-channel/Group-wise quantization requires channel axis".to_string(),
340            ));
341        }
342
343        // Check if group-wise scheme has group size
344        if matches!(self.scheme, QScheme::GroupWise) {
345            if self.group_size.is_none() {
346                return Err(TorshError::InvalidArgument(
347                    "Group-wise quantization requires group size".to_string(),
348                ));
349            }
350            if let Some(group_size) = self.group_size {
351                if group_size == 0 {
352                    return Err(TorshError::InvalidArgument(
353                        "Group size must be greater than 0".to_string(),
354                    ));
355                }
356            }
357        }
358
359        // Check if symmetric scheme is compatible with zero point
360        if matches!(
361            self.scheme,
362            QScheme::PerTensorSymmetric | QScheme::PerChannelSymmetric
363        ) {
364            // Symmetric quantization should have zero_point = 0
365        }
366
367        // Check if binary/ternary schemes are valid
368        if matches!(self.scheme, QScheme::Binary | QScheme::Ternary)
369            && !matches!(
370                self.observer_type,
371                ObserverType::MinMax | ObserverType::MovingAverage
372            )
373        {
374            return Err(TorshError::InvalidArgument(
375                "Binary/ternary quantization requires MinMax or MovingAverage observer".to_string(),
376            ));
377        }
378
379        // Check if mixed precision has valid configuration
380        if matches!(self.scheme, QScheme::MixedPrecision)
381            && !matches!(
382                self.observer_type,
383                ObserverType::KLDivergence | ObserverType::Entropy
384            )
385        {
386            return Err(TorshError::InvalidArgument(
387                "Mixed precision quantization requires KLDivergence or Entropy observer"
388                    .to_string(),
389            ));
390        }
391
392        // Validate eps
393        if self.eps <= 0.0 {
394            return Err(TorshError::InvalidArgument(
395                "eps must be positive".to_string(),
396            ));
397        }
398
399        // Validate averaging constant
400        if self.averaging_constant <= 0.0 || self.averaging_constant >= 1.0 {
401            return Err(TorshError::InvalidArgument(
402                "averaging_constant must be in (0, 1)".to_string(),
403            ));
404        }
405
406        Ok(())
407    }
408}
409
/// Configuration builder for specific quantization backends.
///
/// Accumulates settings via chained setters and validates the final
/// `QuantConfig` in `build`.
pub struct QuantConfigBuilder {
    // Work-in-progress configuration, seeded from `QuantConfig::default()`.
    config: QuantConfig,
}
414
415impl QuantConfigBuilder {
416    /// Start building a new configuration
417    pub fn new() -> Self {
418        Self {
419            config: QuantConfig::default(),
420        }
421    }
422
423    /// Set the data type
424    pub fn dtype(mut self, dtype: DType) -> Self {
425        self.config.dtype = dtype;
426        self
427    }
428
429    /// Set the quantization scheme
430    pub fn scheme(mut self, scheme: QScheme) -> Self {
431        self.config = self.config.with_scheme(scheme);
432        self
433    }
434
435    /// Set the observer type
436    pub fn observer(mut self, observer_type: ObserverType) -> Self {
437        self.config.observer_type = observer_type;
438        self
439    }
440
441    /// Set the backend
442    pub fn backend(mut self, backend: QuantBackend) -> Self {
443        self.config.backend = backend;
444        self
445    }
446
447    /// Enable/disable fake quantization
448    pub fn fake_quant(mut self, enable: bool) -> Self {
449        self.config.enable_fake_quant = enable;
450        self
451    }
452
453    /// Set channel axis for per-channel quantization
454    pub fn channel_axis(mut self, axis: usize) -> Self {
455        self.config.ch_axis = Some(axis);
456        self
457    }
458
459    /// Set group size for group-wise quantization
460    pub fn group_size(mut self, size: usize) -> Self {
461        self.config.group_size = Some(size);
462        self
463    }
464
465    /// Build the final configuration
466    pub fn build(self) -> TorshResult<QuantConfig> {
467        self.config.validate()?;
468        Ok(self.config)
469    }
470}
471
impl Default for QuantConfigBuilder {
    /// Equivalent to `QuantConfigBuilder::new()`.
    fn default() -> Self {
        Self::new()
    }
}
477
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quant_config_defaults() {
        let config = QuantConfig::default();
        assert_eq!(config.dtype, DType::I8);
        assert_eq!(config.scheme, QScheme::PerTensorAffine);
        assert!(!config.enable_fake_quant);
        assert_eq!(config.observer_type, ObserverType::MinMax);
        assert_eq!(config.backend, QuantBackend::Native);
        assert_eq!(config.reduce_range, ReduceRange::None);
    }

    #[test]
    fn test_quant_config_presets() {
        let int8_config = QuantConfig::int8();
        assert_eq!(int8_config.dtype, DType::I8);
        assert_eq!(int8_config.qint_min, Some(-128));
        assert_eq!(int8_config.qint_max, Some(127));

        let binary_config = QuantConfig::binary();
        assert_eq!(binary_config.scheme, QScheme::Binary);
        assert_eq!(binary_config.qint_min, Some(-1));
        assert_eq!(binary_config.qint_max, Some(1));

        let int4_config = QuantConfig::int4();
        assert_eq!(int4_config.scheme, QScheme::Int4PerTensor);
        assert_eq!(int4_config.observer_type, ObserverType::Histogram);
    }

    #[test]
    fn test_quant_config_builder() {
        let config = QuantConfigBuilder::new()
            .dtype(DType::I8)
            .scheme(QScheme::PerChannelAffine)
            .observer(ObserverType::Histogram)
            .backend(QuantBackend::Fbgemm)
            .channel_axis(1)
            .build()
            .unwrap();

        assert_eq!(config.dtype, DType::I8);
        assert_eq!(config.scheme, QScheme::PerChannelAffine);
        assert_eq!(config.observer_type, ObserverType::Histogram);
        assert_eq!(config.backend, QuantBackend::Fbgemm);
        assert_eq!(config.ch_axis, Some(1));
    }

    #[test]
    fn test_config_validation() {
        // Valid configuration.
        let valid_config = QuantConfig::per_channel(0);
        assert!(valid_config.validate().is_ok());

        // Per-channel scheme without a channel axis must be rejected.
        // (Struct-update literals avoid clippy::field_reassign_with_default.)
        let invalid_config = QuantConfig {
            scheme: QScheme::PerChannelAffine,
            ch_axis: None,
            ..QuantConfig::default()
        };
        assert!(invalid_config.validate().is_err());

        // Group-wise scheme without a group size must be rejected.
        let invalid_group = QuantConfig {
            scheme: QScheme::GroupWise,
            ch_axis: Some(0),
            group_size: None,
            ..QuantConfig::default()
        };
        assert!(invalid_group.validate().is_err());

        // Non-positive eps must be rejected.
        let invalid_eps = QuantConfig {
            eps: -1.0,
            ..QuantConfig::default()
        };
        assert!(invalid_eps.validate().is_err());

        // averaging_constant outside (0, 1) must be rejected.
        let invalid_avg = QuantConfig {
            averaging_constant: 1.5,
            ..QuantConfig::default()
        };
        assert!(invalid_avg.validate().is_err());
    }

    #[test]
    fn test_get_qint_range() {
        let int8_config = QuantConfig::int8();
        assert_eq!(int8_config.get_qint_range(), (-128, 127));

        let uint8_config = QuantConfig::uint8();
        assert_eq!(uint8_config.get_qint_range(), (0, 255));

        let int4_config = QuantConfig::int4();
        assert_eq!(int4_config.get_qint_range(), (-8, 7));

        let binary_config = QuantConfig::binary();
        assert_eq!(binary_config.get_qint_range(), (-1, 1));

        // Reduced range must be a strict subset of the full range.
        let reduced_config = QuantConfig::int8().with_reduce_range(ReduceRange::Reduce);
        let (min, max) = reduced_config.get_qint_range();
        assert!(min > -128 && max < 127);
    }

    #[test]
    fn test_mixed_precision_config() {
        let mixed_config = MixedPrecisionConfig::default();
        assert_eq!(mixed_config.default_precision, DType::I8);
        assert_eq!(mixed_config.sensitivity_threshold, 0.1);
        assert!(mixed_config.layer_precision.contains_key("embedding"));
    }

    #[test]
    fn test_config_serialization() {
        let config = QuantConfig::int8().with_observer(ObserverType::Histogram);

        // Round-trip through JSON and compare the fields that matter.
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: QuantConfig = serde_json::from_str(&json).unwrap();

        assert_eq!(config.dtype, deserialized.dtype);
        assert_eq!(config.scheme, deserialized.scheme);
        assert_eq!(config.observer_type, deserialized.observer_type);
    }
}