Skip to main content

ruvector_cnn/quantize/
params.rs

1//! Quantization parameters for INT8 quantization.
2//!
3//! This module defines the core quantization parameters used for both
4//! symmetric and asymmetric quantization schemes.
5
6use crate::error::{CnnError, CnnResult};
7use serde::{Deserialize, Serialize};
8
9/// Quantization parameters for a tensor or tensor slice.
10///
11/// Defines the mapping between floating-point values and quantized integers:
12/// - **Symmetric**: `x_q = round(x / scale)`
13/// - **Asymmetric**: `x_q = round(x / scale) + zero_point`
14///
15/// ## Invariants
16///
17/// - `scale > 0.0` (enforced at construction)
18/// - `qmin <= zero_point <= qmax`
19/// - For symmetric mode: `zero_point == 0`
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct QuantizationParams {
22    /// Scale factor for quantization.
23    /// Maps FP32 range to INT8 range.
24    pub scale: f32,
25
26    /// Zero point for asymmetric quantization.
27    /// Always 0 for symmetric quantization.
28    pub zero_point: i32,
29
30    /// Minimum quantized value (typically -128 for i8).
31    pub qmin: i8,
32
33    /// Maximum quantized value (typically 127 for i8).
34    pub qmax: i8,
35}
36
37impl QuantizationParams {
38    /// Create symmetric quantization parameters from min/max values.
39    ///
40    /// Symmetric quantization uses `zero_point = 0` and maps the range
41    /// `[-max_abs, max_abs]` to `[-127, 127]`.
42    ///
43    /// # Arguments
44    ///
45    /// * `min_val` - Minimum value in the FP32 tensor
46    /// * `max_val` - Maximum value in the FP32 tensor
47    ///
48    /// # Returns
49    ///
50    /// Quantization parameters with `zero_point = 0`.
51    ///
52    /// # Example
53    ///
54    /// ```rust,ignore
55    /// let params = QuantizationParams::from_minmax(-10.0, 10.0, QuantizationMode::Symmetric);
56    /// assert_eq!(params.zero_point, 0);
57    /// assert!(params.scale > 0.0);
58    /// ```
59    pub fn from_minmax(min_val: f32, max_val: f32, mode: QuantizationMode) -> CnnResult<Self> {
60        if min_val > max_val {
61            return Err(CnnError::InvalidParameter(format!(
62                "min_val ({}) must be <= max_val ({})",
63                min_val, max_val
64            )));
65        }
66
67        match mode {
68            QuantizationMode::Symmetric => {
69                // Symmetric: zero_point = 0, scale based on max absolute value
70                let max_abs = min_val.abs().max(max_val.abs());
71                let scale = if max_abs > 0.0 {
72                    max_abs / 127.0
73                } else {
74                    1.0 // Prevent division by zero
75                };
76
77                Ok(Self {
78                    scale,
79                    zero_point: 0,
80                    qmin: -127,
81                    qmax: 127,
82                })
83            }
84            QuantizationMode::Asymmetric => {
85                // Asymmetric: Map [min_val, max_val] to [-127, 127] (255 bins)
86                // to maintain compatibility with i8 storage
87                let scale = if max_val > min_val {
88                    (max_val - min_val) / 254.0  // Use 254 to avoid clipping at edges
89                } else {
90                    1.0
91                };
92
93                let zero_point = if scale > 0.0 {
94                    // Map min_val to qmin=-127
95                    ((-min_val / scale).round() - 127.0).clamp(-127.0, 127.0) as i32
96                } else {
97                    0
98                };
99
100                Ok(Self {
101                    scale,
102                    zero_point,
103                    qmin: -127,
104                    qmax: 127,
105                })
106            }
107        }
108    }
109
110    /// Create quantization parameters from percentile values.
111    ///
112    /// Used during calibration to exclude outliers that would skew the
113    /// quantization range. Typically uses 0.001 and 0.999 percentiles.
114    ///
115    /// # Arguments
116    ///
117    /// * `percentile_min` - Lower percentile value (e.g., -10.0)
118    /// * `percentile_max` - Upper percentile value (e.g., 10.0)
119    /// * `mode` - Quantization mode (symmetric or asymmetric)
120    ///
121    /// # Example
122    ///
123    /// ```rust,ignore
124    /// // Use 99.8% of the data range, excluding outliers
125    /// let params = QuantizationParams::from_percentile(
126    ///     -9.5, 9.5, QuantizationMode::Symmetric
127    /// );
128    /// ```
129    pub fn from_percentile(
130        percentile_min: f32,
131        percentile_max: f32,
132        mode: QuantizationMode,
133    ) -> CnnResult<Self> {
134        Self::from_minmax(percentile_min, percentile_max, mode)
135    }
136
137    /// Validate that the parameters satisfy invariants.
138    ///
139    /// # Invariants
140    ///
141    /// - `scale > 0.0`
142    /// - `qmin <= qmax`
143    /// - `qmin <= zero_point <= qmax`
144    /// - For symmetric mode: `zero_point == 0`
145    pub fn validate(&self) -> CnnResult<()> {
146        if self.scale <= 0.0 {
147            return Err(CnnError::QuantizationError(format!(
148                "scale must be positive, got {}",
149                self.scale
150            )));
151        }
152
153        if self.qmin > self.qmax {
154            return Err(CnnError::QuantizationError(format!(
155                "qmin ({}) must be <= qmax ({})",
156                self.qmin, self.qmax
157            )));
158        }
159
160        if self.zero_point < self.qmin as i32 || self.zero_point > self.qmax as i32 {
161            return Err(CnnError::QuantizationError(format!(
162                "zero_point ({}) must be in range [{}, {}]",
163                self.zero_point, self.qmin, self.qmax
164            )));
165        }
166
167        Ok(())
168    }
169
170    /// Quantize a single FP32 value to INT8.
171    ///
172    /// Formula:
173    /// - Symmetric: `x_q = round(x / scale)`
174    /// - Asymmetric: `x_q = round(x / scale) + zero_point`
175    ///
176    /// Result is clamped to `[qmin, qmax]`.
177    #[inline]
178    pub fn quantize_value(&self, value: f32) -> i8 {
179        let q = (value / self.scale).round() + self.zero_point as f32;
180        q.clamp(self.qmin as f32, self.qmax as f32) as i8
181    }
182
183    /// Dequantize a single INT8 value to FP32.
184    ///
185    /// Formula:
186    /// - Symmetric: `x = x_q * scale`
187    /// - Asymmetric: `x = (x_q - zero_point) * scale`
188    #[inline]
189    pub fn dequantize_value(&self, value: i8) -> f32 {
190        (value as f32 - self.zero_point as f32) * self.scale
191    }
192}
193
194/// Quantization scheme granularity.
195///
196/// Determines whether a single scale factor applies to the entire tensor
197/// or per output channel (for weights).
198#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
199pub enum QuantizationScheme {
200    /// Single scale factor for entire tensor.
201    /// Used for activations.
202    PerTensor,
203
204    /// Scale factor per output channel.
205    /// Used for Conv2d weights to preserve accuracy.
206    PerChannel,
207}
208
209/// Quantization mode (symmetric vs asymmetric).
210///
211/// - **Symmetric**: Zero point is always 0. Good for weights centered around 0.
212/// - **Asymmetric**: Zero point computed to maximize range utilization. Good for ReLU activations.
213#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
214pub enum QuantizationMode {
215    /// Symmetric quantization: `x_q = round(x / scale)`.
216    ///
217    /// Zero point is always 0. Simpler computation, but may waste range
218    /// if data is not centered around 0.
219    Symmetric,
220
221    /// Asymmetric quantization: `x_q = round(x / scale) + zero_point`.
222    ///
223    /// Full range utilization for asymmetric distributions (e.g., ReLU outputs).
224    Asymmetric,
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_symmetric_minmax() {
        let params = QuantizationParams::from_minmax(-10.0, 10.0, QuantizationMode::Symmetric)
            .unwrap();

        assert_eq!(params.zero_point, 0);
        assert!(params.scale > 0.0);
        assert_eq!(params.qmin, -127);
        assert_eq!(params.qmax, 127);

        // Validate
        params.validate().unwrap();
    }

    #[test]
    fn test_asymmetric_minmax() {
        let params = QuantizationParams::from_minmax(0.0, 10.0, QuantizationMode::Asymmetric)
            .unwrap();

        assert!(params.scale > 0.0);
        // The zero point must lie inside the quantized range the constructor produced
        // (-127..=127), not merely inside the full i8 range.
        assert!(params.zero_point >= params.qmin as i32);
        assert!(params.zero_point <= params.qmax as i32);

        // 0.0 is the low end of [0, 10], so it quantizes exactly to the zero point.
        assert_eq!(params.quantize_value(0.0) as i32, params.zero_point);

        params.validate().unwrap();
    }

    #[test]
    fn test_asymmetric_range_endpoints() {
        let params = QuantizationParams::from_minmax(0.0, 10.0, QuantizationMode::Asymmetric)
            .unwrap();

        // The calibrated range should span the whole quantized range.
        assert_eq!(params.quantize_value(0.0), params.qmin);
        assert_eq!(params.quantize_value(10.0), params.qmax);
    }

    #[test]
    fn test_asymmetric_signed_range_roundtrip() {
        let params = QuantizationParams::from_minmax(-10.0, 10.0, QuantizationMode::Asymmetric)
            .unwrap();
        params.validate().unwrap();

        for &value in &[-10.0f32, -3.3, 0.0, 2.5, 7.7, 10.0] {
            let dequantized = params.dequantize_value(params.quantize_value(value));
            // Rounding to the nearest step bounds the error by half a step (plus float slack).
            assert!(
                (dequantized - value).abs() <= params.scale * 0.5 + 1e-5,
                "value {} round-tripped to {}",
                value,
                dequantized
            );
        }
    }

    #[test]
    fn test_quantize_dequantize_symmetric() {
        let params = QuantizationParams::from_minmax(-10.0, 10.0, QuantizationMode::Symmetric)
            .unwrap();

        let value = 5.0f32;
        let quantized = params.quantize_value(value);
        let dequantized = params.dequantize_value(quantized);

        // Should be close (within quantization error)
        assert!((dequantized - value).abs() < 0.1);
    }

    #[test]
    fn test_quantize_dequantize_asymmetric() {
        let params = QuantizationParams::from_minmax(0.0, 10.0, QuantizationMode::Asymmetric)
            .unwrap();

        let value = 5.0f32;
        let quantized = params.quantize_value(value);
        let dequantized = params.dequantize_value(quantized);

        assert!((dequantized - value).abs() < 0.1);
    }

    #[test]
    fn test_zero_value_quantization() {
        let params = QuantizationParams::from_minmax(-10.0, 10.0, QuantizationMode::Symmetric)
            .unwrap();

        let quantized = params.quantize_value(0.0);
        assert_eq!(quantized, 0);

        let dequantized = params.dequantize_value(0);
        assert_eq!(dequantized, 0.0);
    }

    #[test]
    fn test_clipping() {
        let params = QuantizationParams::from_minmax(-1.0, 1.0, QuantizationMode::Symmetric)
            .unwrap();

        // Values outside range should be clipped
        let large = params.quantize_value(1000.0);
        assert_eq!(large, 127);

        let small = params.quantize_value(-1000.0);
        assert_eq!(small, -127);
    }

    #[test]
    fn test_invalid_range() {
        let result = QuantizationParams::from_minmax(10.0, -10.0, QuantizationMode::Symmetric);
        assert!(result.is_err());
    }

    #[test]
    fn test_percentile_constructor() {
        let params = QuantizationParams::from_percentile(-9.5, 9.5, QuantizationMode::Symmetric)
            .unwrap();

        assert_eq!(params.zero_point, 0);
        params.validate().unwrap();
    }

    #[test]
    fn test_percentile_matches_minmax() {
        let from_percentile =
            QuantizationParams::from_percentile(-9.5, 9.5, QuantizationMode::Asymmetric).unwrap();
        let from_minmax =
            QuantizationParams::from_minmax(-9.5, 9.5, QuantizationMode::Asymmetric).unwrap();

        assert_eq!(from_percentile.scale, from_minmax.scale);
        assert_eq!(from_percentile.zero_point, from_minmax.zero_point);
        assert_eq!(from_percentile.qmin, from_minmax.qmin);
        assert_eq!(from_percentile.qmax, from_minmax.qmax);
    }

    #[test]
    fn test_validation_negative_scale() {
        let mut params = QuantizationParams::from_minmax(-10.0, 10.0, QuantizationMode::Symmetric)
            .unwrap();

        params.scale = -1.0;
        assert!(params.validate().is_err());
    }

    #[test]
    fn test_validation_zero_scale() {
        let mut params = QuantizationParams::from_minmax(-10.0, 10.0, QuantizationMode::Symmetric)
            .unwrap();

        params.scale = 0.0;
        assert!(params.validate().is_err());
    }

    #[test]
    fn test_validation_invalid_qmin_qmax() {
        let mut params = QuantizationParams::from_minmax(-10.0, 10.0, QuantizationMode::Symmetric)
            .unwrap();

        params.qmin = 127;
        params.qmax = -127;
        assert!(params.validate().is_err());
    }

    #[test]
    fn test_validation_zero_point_out_of_range() {
        let mut params = QuantizationParams::from_minmax(-10.0, 10.0, QuantizationMode::Symmetric)
            .unwrap();

        params.zero_point = 200;
        assert!(params.validate().is_err());
    }
}