//! Quantization parameters and configuration
//!
//! This module provides the QuantizationParams struct and related functionality
//! for managing quantization configuration. It handles parameter calculation
//! from statistics, preset configurations for common quantization schemes,
//! and parameter validation.
8use super::types::{QuantizationScheme, QuantizedDType};
9use crate::BackendResult;
10
11#[cfg(not(feature = "std"))]
12use alloc::vec::Vec;
13
14/// Quantization parameters
15///
16/// Contains all the parameters needed to quantize and dequantize tensors,
17/// including scale factors, zero points, and metadata about the quantization
18/// scheme being used.
19#[derive(Debug, Clone)]
20pub struct QuantizationParams {
21    /// Quantization data type
22    ///
23    /// Specifies the target quantized data type (e.g., Int8, UInt8, Int4)
24    pub dtype: QuantizedDType,
25
26    /// Quantization scheme
27    ///
28    /// Defines how the quantization mapping is performed (linear, symmetric, etc.)
29    pub scheme: QuantizationScheme,
30
31    /// Scale factor(s)
32    ///
33    /// Maps quantized values back to floating-point range.
34    /// For per-channel quantization, contains one scale per channel.
35    /// Formula: float_val = scale * (quantized_val - zero_point)
36    pub scale: Vec<f32>,
37
38    /// Zero point(s)
39    ///
40    /// The quantized value that corresponds to floating-point zero.
41    /// For per-channel quantization, contains one zero point per channel.
42    /// For symmetric quantization, this is always 0.
43    pub zero_point: Vec<i32>,
44
45    /// Block size for block-wise quantization
46    ///
47    /// When using block-wise quantization, specifies the size of each block
48    /// that gets its own quantization parameters. None for other schemes.
49    pub block_size: Option<usize>,
50
51    /// Minimum value observed during calibration
52    ///
53    /// Used for parameter calculation and validation. Set during calibration
54    /// or when computing parameters from statistics.
55    pub min_val: Option<f32>,
56
57    /// Maximum value observed during calibration
58    ///
59    /// Used for parameter calculation and validation. Set during calibration
60    /// or when computing parameters from statistics.
61    pub max_val: Option<f32>,
62}
63
64impl Default for QuantizationParams {
65    /// Default quantization parameters
66    ///
67    /// Creates parameters for UInt8 linear quantization with scale=1.0
68    /// and zero_point=0, suitable for testing and initialization.
69    fn default() -> Self {
70        Self {
71            dtype: QuantizedDType::UInt8,
72            scheme: QuantizationScheme::Linear,
73            scale: vec![1.0],
74            zero_point: vec![0],
75            block_size: None,
76            min_val: None,
77            max_val: None,
78        }
79    }
80}
81
82impl QuantizationParams {
83    /// Create parameters for INT8 symmetric quantization
84    ///
85    /// INT8 symmetric quantization is commonly used for weights in neural networks
86    /// due to its simplicity and good hardware support. The zero point is always 0,
87    /// and the range is symmetric around zero.
88    ///
89    /// # Examples
90    ///
91    /// ```
92    /// use torsh_backend::quantization::QuantizationParams;
93    ///
94    /// let params = QuantizationParams::int8_symmetric();
95    /// assert_eq!(params.zero_point[0], 0);
96    /// ```
97    pub fn int8_symmetric() -> Self {
98        Self {
99            dtype: QuantizedDType::Int8,
100            scheme: QuantizationScheme::Symmetric,
101            scale: vec![1.0],
102            zero_point: vec![0],
103            block_size: None,
104            min_val: None,
105            max_val: None,
106        }
107    }
108
109    /// Create basic quantization parameters with custom scale and zero point
110    ///
111    /// This is a general-purpose constructor for creating quantization parameters
112    /// with custom scale and zero point values. Useful for benchmarking and
113    /// testing with specific parameter configurations.
114    ///
115    /// # Arguments
116    ///
117    /// * `scale` - Scale factor for the quantization
118    /// * `zero_point` - Zero point for the quantization
119    ///
120    /// # Examples
121    ///
122    /// ```
123    /// use torsh_backend::quantization::QuantizationParams;
124    ///
125    /// let params = QuantizationParams::new(255.0, 128);
126    /// assert_eq!(params.scale[0], 255.0);
127    /// assert_eq!(params.zero_point[0], 128);
128    /// ```
129    pub fn new(scale: f32, zero_point: i32) -> Self {
130        Self {
131            dtype: QuantizedDType::UInt8, // Default to UInt8 for general usage
132            scheme: QuantizationScheme::Asymmetric,
133            scale: vec![scale],
134            zero_point: vec![zero_point],
135            block_size: None,
136            min_val: None,
137            max_val: None,
138        }
139    }
140
141    /// Create parameters for UINT8 asymmetric quantization
142    ///
143    /// UInt8 asymmetric quantization is commonly used for activations,
144    /// especially after ReLU layers where values are non-negative.
145    /// The zero point is typically set to 128 for balanced range utilization.
146    ///
147    /// # Examples
148    ///
149    /// ```
150    /// use torsh_backend::quantization::QuantizationParams;
151    ///
152    /// let params = QuantizationParams::uint8_asymmetric();
153    /// assert_eq!(params.zero_point[0], 128);
154    /// ```
155    pub fn uint8_asymmetric() -> Self {
156        Self {
157            dtype: QuantizedDType::UInt8,
158            scheme: QuantizationScheme::Asymmetric,
159            scale: vec![1.0],
160            zero_point: vec![128],
161            block_size: None,
162            min_val: None,
163            max_val: None,
164        }
165    }
166
167    /// Create parameters for INT4 symmetric quantization
168    ///
169    /// INT4 quantization provides extreme compression at the cost of accuracy.
170    /// Symmetric INT4 is often used for weights in models where 4-bit precision
171    /// is sufficient.
172    ///
173    /// # Examples
174    ///
175    /// ```
176    /// use torsh_backend::quantization::QuantizationParams;
177    ///
178    /// let params = QuantizationParams::int4_symmetric();
179    /// assert_eq!(params.dtype.bits(), 4);
180    /// ```
181    pub fn int4_symmetric() -> Self {
182        Self {
183            dtype: QuantizedDType::Int4,
184            scheme: QuantizationScheme::Symmetric,
185            scale: vec![1.0],
186            zero_point: vec![0],
187            block_size: None,
188            min_val: None,
189            max_val: None,
190        }
191    }
192
193    /// Create parameters for channel-wise quantization
194    ///
195    /// Channel-wise quantization applies different quantization parameters
196    /// to each channel, providing better accuracy for models with varying
197    /// channel sensitivities at the cost of increased parameter storage.
198    ///
199    /// # Arguments
200    ///
201    /// * `num_channels` - Number of channels in the tensor
202    /// * `dtype` - Quantization data type to use
203    ///
204    /// # Examples
205    ///
206    /// ```
207    /// use torsh_backend::quantization::{QuantizationParams, QuantizedDType};
208    ///
209    /// let params = QuantizationParams::channel_wise(64, QuantizedDType::Int8);
210    /// assert_eq!(params.scale.len(), 64);
211    /// assert_eq!(params.zero_point.len(), 64);
212    /// ```
213    pub fn channel_wise(num_channels: usize, dtype: QuantizedDType) -> Self {
214        Self {
215            dtype,
216            scheme: QuantizationScheme::ChannelWise,
217            scale: vec![1.0; num_channels],
218            zero_point: vec![0; num_channels],
219            block_size: None,
220            min_val: None,
221            max_val: None,
222        }
223    }
224
225    /// Create parameters for block-wise quantization
226    ///
227    /// Block-wise quantization divides the tensor into blocks and applies
228    /// different quantization parameters to each block. This can provide
229    /// better accuracy than tensor-wise quantization while being more
230    /// memory-efficient than channel-wise quantization.
231    ///
232    /// # Arguments
233    ///
234    /// * `block_size` - Size of each quantization block
235    /// * `dtype` - Quantization data type to use
236    ///
237    /// # Examples
238    ///
239    /// ```
240    /// use torsh_backend::quantization::{QuantizationParams, QuantizedDType};
241    ///
242    /// let params = QuantizationParams::block_wise(128, QuantizedDType::Int8);
243    /// assert_eq!(params.block_size, Some(128));
244    /// ```
245    pub fn block_wise(block_size: usize, dtype: QuantizedDType) -> Self {
246        Self {
247            dtype,
248            scheme: QuantizationScheme::BlockWise,
249            scale: vec![1.0], // Will be expanded based on tensor size
250            zero_point: vec![0],
251            block_size: Some(block_size),
252            min_val: None,
253            max_val: None,
254        }
255    }
256
257    /// Calculate quantization parameters from input statistics
258    ///
259    /// Computes the optimal scale and zero point parameters based on the
260    /// observed minimum and maximum values in the data. The calculation
261    /// depends on the quantization scheme being used.
262    ///
263    /// # Arguments
264    ///
265    /// * `min_val` - Minimum value observed in the data
266    /// * `max_val` - Maximum value observed in the data
267    ///
268    /// # Returns
269    ///
270    /// Returns `Ok(())` if parameters were calculated successfully,
271    /// or an error if the statistics are invalid.
272    ///
273    /// # Examples
274    ///
275    /// ```
276    /// use torsh_backend::quantization::QuantizationParams;
277    ///
278    /// let mut params = QuantizationParams::int8_symmetric();
279    /// params.from_statistics(-2.0, 2.0).unwrap();
280    /// // Scale will be calculated to map [-2.0, 2.0] to [-128, 127]
281    /// ```
282    pub fn from_statistics(&mut self, min_val: f32, max_val: f32) -> BackendResult<()> {
283        // Validate input statistics
284        if min_val > max_val {
285            return Err(torsh_core::error::TorshError::InvalidArgument(
286                "min_val must be <= max_val".to_string(),
287            ));
288        }
289
290        if !min_val.is_finite() || !max_val.is_finite() {
291            return Err(torsh_core::error::TorshError::InvalidArgument(
292                "min_val and max_val must be finite".to_string(),
293            ));
294        }
295
296        self.min_val = Some(min_val);
297        self.max_val = Some(max_val);
298
299        let (qmin, qmax) = self.dtype.value_range();
300        let qmin = qmin as f32;
301        let qmax = qmax as f32;
302
303        match self.scheme {
304            QuantizationScheme::Symmetric => {
305                self.calculate_symmetric_params(min_val, max_val, qmin, qmax)?;
306            }
307            QuantizationScheme::Asymmetric | QuantizationScheme::Linear => {
308                self.calculate_asymmetric_params(min_val, max_val, qmin, qmax)?;
309            }
310            QuantizationScheme::Logarithmic => {
311                self.calculate_logarithmic_params(min_val, max_val, qmin, qmax)?;
312            }
313            QuantizationScheme::BlockWise | QuantizationScheme::ChannelWise => {
314                // For block-wise and channel-wise, use asymmetric as base
315                // Individual blocks/channels will be calculated separately
316                self.calculate_asymmetric_params(min_val, max_val, qmin, qmax)?;
317            }
318        }
319
320        Ok(())
321    }
322
323    /// Calculate symmetric quantization parameters
324    fn calculate_symmetric_params(
325        &mut self,
326        min_val: f32,
327        max_val: f32,
328        qmin: f32,
329        qmax: f32,
330    ) -> BackendResult<()> {
331        let max_range = max_val.abs().max(min_val.abs());
332        if max_range == 0.0 {
333            self.scale[0] = 1.0;
334        } else {
335            // For symmetric quantization, we map [-max_range, max_range] to [qmin, qmax]
336            self.scale[0] = (2.0 * max_range) / (qmax - qmin);
337        }
338        self.zero_point[0] = 0;
339        Ok(())
340    }
341
342    /// Calculate asymmetric quantization parameters
343    fn calculate_asymmetric_params(
344        &mut self,
345        min_val: f32,
346        max_val: f32,
347        qmin: f32,
348        qmax: f32,
349    ) -> BackendResult<()> {
350        if max_val == min_val {
351            // Degenerate case: all values are the same
352            self.scale[0] = 1.0;
353            self.zero_point[0] = qmin as i32;
354        } else {
355            // Calculate scale to map [min_val, max_val] to [qmin, qmax]
356            self.scale[0] = (max_val - min_val) / (qmax - qmin);
357
358            // Calculate zero point such that min_val maps to qmin
359            let zero_point_from_min = qmin - min_val / self.scale[0];
360            self.zero_point[0] = zero_point_from_min.round().clamp(qmin, qmax) as i32;
361        }
362        Ok(())
363    }
364
365    /// Calculate logarithmic quantization parameters
366    fn calculate_logarithmic_params(
367        &mut self,
368        min_val: f32,
369        max_val: f32,
370        qmin: f32,
371        qmax: f32,
372    ) -> BackendResult<()> {
373        // For logarithmic quantization, we need positive values
374        if min_val <= 0.0 {
375            return Err(torsh_core::error::TorshError::InvalidArgument(
376                "Logarithmic quantization requires positive values".to_string(),
377            ));
378        }
379
380        // Use logarithmic scale mapping
381        let log_min = min_val.ln();
382        let log_max = max_val.ln();
383
384        if log_max == log_min {
385            self.scale[0] = 1.0;
386            self.zero_point[0] = qmin as i32;
387        } else {
388            self.scale[0] = (log_max - log_min) / (qmax - qmin);
389            self.zero_point[0] = (qmin - log_min / self.scale[0]).round() as i32;
390        }
391        Ok(())
392    }
393
394    /// Validate that the parameters are consistent and usable
395    ///
396    /// Checks that all parameter vectors have consistent lengths,
397    /// scale factors are positive, and zero points are within valid ranges.
398    pub fn validate(&self) -> BackendResult<()> {
399        // Check that scale and zero_point vectors have consistent lengths
400        if self.scale.len() != self.zero_point.len() {
401            return Err(torsh_core::error::TorshError::InvalidArgument(
402                "Scale and zero_point vectors must have the same length".to_string(),
403            ));
404        }
405
406        // Check that all scale factors are positive and finite
407        for (i, &scale) in self.scale.iter().enumerate() {
408            if scale <= 0.0 || !scale.is_finite() {
409                return Err(torsh_core::error::TorshError::InvalidArgument(format!(
410                    "Scale factor at index {} must be positive and finite, got {}",
411                    i, scale
412                )));
413            }
414        }
415
416        // Check that zero points are within the valid range for the data type
417        let (qmin, qmax) = self.dtype.value_range();
418        for (i, &zero_point) in self.zero_point.iter().enumerate() {
419            if (zero_point as i64) < qmin || (zero_point as i64) > qmax {
420                return Err(torsh_core::error::TorshError::InvalidArgument(format!(
421                    "Zero point at index {} ({}) is outside valid range [{}, {}]",
422                    i, zero_point, qmin, qmax
423                )));
424            }
425        }
426
427        // Scheme-specific validation
428        match self.scheme {
429            QuantizationScheme::Symmetric => {
430                // Symmetric quantization should have zero_point = 0
431                for (i, &zero_point) in self.zero_point.iter().enumerate() {
432                    if zero_point != 0 {
433                        return Err(torsh_core::error::TorshError::InvalidArgument(format!(
434                            "Symmetric quantization requires zero_point[{}] = 0, got {}",
435                            i, zero_point
436                        )));
437                    }
438                }
439            }
440            QuantizationScheme::BlockWise => {
441                // Block-wise quantization should have a block size specified
442                if self.block_size.is_none() {
443                    return Err(torsh_core::error::TorshError::InvalidArgument(
444                        "Block-wise quantization requires block_size to be specified".to_string(),
445                    ));
446                }
447            }
448            QuantizationScheme::ChannelWise => {
449                // Channel-wise should have multiple parameters
450                if self.scale.len() == 1 {
451                    return Err(torsh_core::error::TorshError::InvalidArgument(
452                        "Channel-wise quantization requires multiple scale/zero_point values"
453                            .to_string(),
454                    ));
455                }
456            }
457            _ => {} // Other schemes have no specific requirements
458        }
459
460        Ok(())
461    }
462
463    /// Get the effective number of quantization parameter sets
464    ///
465    /// Returns the number of independent parameter sets (scale/zero_point pairs)
466    /// that this configuration represents. For tensor-wise quantization this is 1,
467    /// for channel-wise it's the number of channels.
468    pub fn num_parameter_sets(&self) -> usize {
469        self.scale.len()
470    }
471
472    /// Check if this configuration uses per-channel parameters
473    pub fn is_per_channel(&self) -> bool {
474        self.scheme.is_per_channel() && self.scale.len() > 1
475    }
476
477    /// Get the quantization error bound for this configuration
478    ///
479    /// Returns the maximum possible quantization error (in the original
480    /// floating-point scale) for this quantization configuration.
481    pub fn quantization_error_bound(&self) -> f32 {
482        // The maximum error is half the quantization step size
483        self.scale
484            .iter()
485            .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
486            .copied()
487            .unwrap_or(0.0)
488            * 0.5
489    }
490
491    /// Calculate the compression ratio achieved by this quantization
492    ///
493    /// Returns the ratio of original size to quantized size.
494    /// Assumes the original data was 32-bit floating point.
495    pub fn compression_ratio(&self) -> f32 {
496        let bits_per_value = self.dtype.bits() as f32;
497        32.0 / bits_per_value
498    }
499}
500
#[cfg(test)]
mod tests {
    use super::*;

    // Default must be tensor-wise UInt8 linear with identity parameters.
    #[test]
    fn test_default_params() {
        let params = QuantizationParams::default();
        assert_eq!(params.dtype, QuantizedDType::UInt8);
        assert_eq!(params.scheme, QuantizationScheme::Linear);
        assert_eq!(params.scale, vec![1.0]);
        assert_eq!(params.zero_point, vec![0]);
    }

    // Each preset constructor must set the advertised dtype/scheme/zero point.
    #[test]
    fn test_preset_configs() {
        let int8_sym = QuantizationParams::int8_symmetric();
        assert_eq!(int8_sym.dtype, QuantizedDType::Int8);
        assert_eq!(int8_sym.scheme, QuantizationScheme::Symmetric);
        assert_eq!(int8_sym.zero_point[0], 0);

        let uint8_asym = QuantizationParams::uint8_asymmetric();
        assert_eq!(uint8_asym.dtype, QuantizedDType::UInt8);
        assert_eq!(uint8_asym.scheme, QuantizationScheme::Asymmetric);
        assert_eq!(uint8_asym.zero_point[0], 128);

        let int4_sym = QuantizationParams::int4_symmetric();
        assert_eq!(int4_sym.dtype, QuantizedDType::Int4);
        assert_eq!(int4_sym.zero_point[0], 0);
    }

    // channel_wise(n, _) must allocate one scale/zero_point pair per channel.
    #[test]
    fn test_channel_wise_params() {
        let params = QuantizationParams::channel_wise(64, QuantizedDType::Int8);
        assert_eq!(params.scheme, QuantizationScheme::ChannelWise);
        assert_eq!(params.scale.len(), 64);
        assert_eq!(params.zero_point.len(), 64);
        assert!(params.is_per_channel());
    }

    // block_wise must record the scheme and the requested block size.
    #[test]
    fn test_block_wise_params() {
        let params = QuantizationParams::block_wise(128, QuantizedDType::Int8);
        assert_eq!(params.scheme, QuantizationScheme::BlockWise);
        assert_eq!(params.block_size, Some(128));
    }

    // Symmetric calibration: zero point stays 0, stats are recorded,
    // and a positive scale is derived from the observed range.
    #[test]
    fn test_from_statistics_symmetric() {
        let mut params = QuantizationParams::int8_symmetric();
        params.from_statistics(-2.0, 2.0).unwrap();

        assert_eq!(params.zero_point[0], 0);
        assert!(params.scale[0] > 0.0);
        assert_eq!(params.min_val, Some(-2.0));
        assert_eq!(params.max_val, Some(2.0));
    }

    // Asymmetric calibration: scale is positive and the zero point lands
    // inside the UInt8 representable range.
    #[test]
    fn test_from_statistics_asymmetric() {
        let mut params = QuantizationParams::uint8_asymmetric();
        params.from_statistics(0.0, 255.0).unwrap();

        assert!(params.scale[0] > 0.0);
        assert!(params.zero_point[0] >= 0 && params.zero_point[0] <= 255);
    }

    // from_statistics must reject inverted ranges and non-finite inputs.
    #[test]
    fn test_from_statistics_invalid() {
        let mut params = QuantizationParams::default();

        // min > max should fail
        assert!(params.from_statistics(2.0, 1.0).is_err());

        // Non-finite values should fail
        assert!(params.from_statistics(f32::NAN, 1.0).is_err());
        assert!(params.from_statistics(1.0, f32::INFINITY).is_err());
    }

    // validate() must catch length mismatches and non-positive scales.
    #[test]
    fn test_validation() {
        let mut params = QuantizationParams::default();
        assert!(params.validate().is_ok());

        // Mismatched vector lengths should fail
        params.scale.push(2.0);
        assert!(params.validate().is_err());

        // Reset and test negative scale
        params.scale = vec![-1.0];
        params.zero_point = vec![0];
        assert!(params.validate().is_err());
    }

    // Symmetric scheme must reject any non-zero zero point.
    #[test]
    fn test_validation_symmetric() {
        let mut params = QuantizationParams::int8_symmetric();
        params.zero_point[0] = 10; // Should fail for symmetric
        assert!(params.validate().is_err());
    }

    // Compression ratio is 32 / bits-per-quantized-value.
    #[test]
    fn test_compression_ratio() {
        let int8_params = QuantizationParams::int8_symmetric();
        assert_eq!(int8_params.compression_ratio(), 4.0); // 32 bits -> 8 bits

        let int4_params = QuantizationParams::int4_symmetric();
        assert_eq!(int4_params.compression_ratio(), 8.0); // 32 bits -> 4 bits
    }

    // Error bound is half of the (largest) scale factor.
    #[test]
    fn test_error_bound() {
        let mut params = QuantizationParams::int8_symmetric();
        params.scale = vec![0.1];
        assert_eq!(params.quantization_error_bound(), 0.05); // Half the scale
    }

    // num_parameter_sets reflects the number of scale/zero_point pairs.
    #[test]
    fn test_num_parameter_sets() {
        let tensor_wise = QuantizationParams::default();
        assert_eq!(tensor_wise.num_parameter_sets(), 1);

        let channel_wise = QuantizationParams::channel_wise(64, QuantizedDType::Int8);
        assert_eq!(channel_wise.num_parameter_sets(), 64);
    }
}