ruvector_cnn/quantize/params.rs
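//! Quantization parameters for INT8 quantization.
//!
//! This module defines [`QuantizationParams`] (scale, zero point, and clamping
//! range) together with the [`QuantizationMode`] (symmetric vs. asymmetric) and
//! [`QuantizationScheme`] (per-tensor vs. per-channel) enums that describe how
//! those parameters are derived and applied.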
5
6use crate::error::{CnnError, CnnResult};
7use serde::{Deserialize, Serialize};
8
9/// Quantization parameters for a tensor or tensor slice.
10///
11/// Defines the mapping between floating-point values and quantized integers:
12/// - **Symmetric**: `x_q = round(x / scale)`
13/// - **Asymmetric**: `x_q = round(x / scale) + zero_point`
14///
15/// ## Invariants
16///
17/// - `scale > 0.0` (enforced at construction)
18/// - `qmin <= zero_point <= qmax`
19/// - For symmetric mode: `zero_point == 0`
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct QuantizationParams {
22 /// Scale factor for quantization.
23 /// Maps FP32 range to INT8 range.
24 pub scale: f32,
25
26 /// Zero point for asymmetric quantization.
27 /// Always 0 for symmetric quantization.
28 pub zero_point: i32,
29
30 /// Minimum quantized value (typically -128 for i8).
31 pub qmin: i8,
32
33 /// Maximum quantized value (typically 127 for i8).
34 pub qmax: i8,
35}
36
37impl QuantizationParams {
38 /// Create symmetric quantization parameters from min/max values.
39 ///
40 /// Symmetric quantization uses `zero_point = 0` and maps the range
41 /// `[-max_abs, max_abs]` to `[-127, 127]`.
42 ///
43 /// # Arguments
44 ///
45 /// * `min_val` - Minimum value in the FP32 tensor
46 /// * `max_val` - Maximum value in the FP32 tensor
47 ///
48 /// # Returns
49 ///
50 /// Quantization parameters with `zero_point = 0`.
51 ///
52 /// # Example
53 ///
54 /// ```rust,ignore
55 /// let params = QuantizationParams::from_minmax(-10.0, 10.0, QuantizationMode::Symmetric);
56 /// assert_eq!(params.zero_point, 0);
57 /// assert!(params.scale > 0.0);
58 /// ```
59 pub fn from_minmax(min_val: f32, max_val: f32, mode: QuantizationMode) -> CnnResult<Self> {
60 if min_val > max_val {
61 return Err(CnnError::InvalidParameter(format!(
62 "min_val ({}) must be <= max_val ({})",
63 min_val, max_val
64 )));
65 }
66
67 match mode {
68 QuantizationMode::Symmetric => {
69 // Symmetric: zero_point = 0, scale based on max absolute value
70 let max_abs = min_val.abs().max(max_val.abs());
71 let scale = if max_abs > 0.0 {
72 max_abs / 127.0
73 } else {
74 1.0 // Prevent division by zero
75 };
76
77 Ok(Self {
78 scale,
79 zero_point: 0,
80 qmin: -127,
81 qmax: 127,
82 })
83 }
84 QuantizationMode::Asymmetric => {
85 // Asymmetric: Map [min_val, max_val] to [-127, 127] (255 bins)
86 // to maintain compatibility with i8 storage
87 let scale = if max_val > min_val {
88 (max_val - min_val) / 254.0 // Use 254 to avoid clipping at edges
89 } else {
90 1.0
91 };
92
93 let zero_point = if scale > 0.0 {
94 // Map min_val to qmin=-127
95 ((-min_val / scale).round() - 127.0).clamp(-127.0, 127.0) as i32
96 } else {
97 0
98 };
99
100 Ok(Self {
101 scale,
102 zero_point,
103 qmin: -127,
104 qmax: 127,
105 })
106 }
107 }
108 }
109
110 /// Create quantization parameters from percentile values.
111 ///
112 /// Used during calibration to exclude outliers that would skew the
113 /// quantization range. Typically uses 0.001 and 0.999 percentiles.
114 ///
115 /// # Arguments
116 ///
117 /// * `percentile_min` - Lower percentile value (e.g., -10.0)
118 /// * `percentile_max` - Upper percentile value (e.g., 10.0)
119 /// * `mode` - Quantization mode (symmetric or asymmetric)
120 ///
121 /// # Example
122 ///
123 /// ```rust,ignore
124 /// // Use 99.8% of the data range, excluding outliers
125 /// let params = QuantizationParams::from_percentile(
126 /// -9.5, 9.5, QuantizationMode::Symmetric
127 /// );
128 /// ```
129 pub fn from_percentile(
130 percentile_min: f32,
131 percentile_max: f32,
132 mode: QuantizationMode,
133 ) -> CnnResult<Self> {
134 Self::from_minmax(percentile_min, percentile_max, mode)
135 }
136
137 /// Validate that the parameters satisfy invariants.
138 ///
139 /// # Invariants
140 ///
141 /// - `scale > 0.0`
142 /// - `qmin <= qmax`
143 /// - `qmin <= zero_point <= qmax`
144 /// - For symmetric mode: `zero_point == 0`
145 pub fn validate(&self) -> CnnResult<()> {
146 if self.scale <= 0.0 {
147 return Err(CnnError::QuantizationError(format!(
148 "scale must be positive, got {}",
149 self.scale
150 )));
151 }
152
153 if self.qmin > self.qmax {
154 return Err(CnnError::QuantizationError(format!(
155 "qmin ({}) must be <= qmax ({})",
156 self.qmin, self.qmax
157 )));
158 }
159
160 if self.zero_point < self.qmin as i32 || self.zero_point > self.qmax as i32 {
161 return Err(CnnError::QuantizationError(format!(
162 "zero_point ({}) must be in range [{}, {}]",
163 self.zero_point, self.qmin, self.qmax
164 )));
165 }
166
167 Ok(())
168 }
169
170 /// Quantize a single FP32 value to INT8.
171 ///
172 /// Formula:
173 /// - Symmetric: `x_q = round(x / scale)`
174 /// - Asymmetric: `x_q = round(x / scale) + zero_point`
175 ///
176 /// Result is clamped to `[qmin, qmax]`.
177 #[inline]
178 pub fn quantize_value(&self, value: f32) -> i8 {
179 let q = (value / self.scale).round() + self.zero_point as f32;
180 q.clamp(self.qmin as f32, self.qmax as f32) as i8
181 }
182
183 /// Dequantize a single INT8 value to FP32.
184 ///
185 /// Formula:
186 /// - Symmetric: `x = x_q * scale`
187 /// - Asymmetric: `x = (x_q - zero_point) * scale`
188 #[inline]
189 pub fn dequantize_value(&self, value: i8) -> f32 {
190 (value as f32 - self.zero_point as f32) * self.scale
191 }
192}
193
194/// Quantization scheme granularity.
195///
196/// Determines whether a single scale factor applies to the entire tensor
197/// or per output channel (for weights).
198#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
199pub enum QuantizationScheme {
200 /// Single scale factor for entire tensor.
201 /// Used for activations.
202 PerTensor,
203
204 /// Scale factor per output channel.
205 /// Used for Conv2d weights to preserve accuracy.
206 PerChannel,
207}
208
209/// Quantization mode (symmetric vs asymmetric).
210///
211/// - **Symmetric**: Zero point is always 0. Good for weights centered around 0.
212/// - **Asymmetric**: Zero point computed to maximize range utilization. Good for ReLU activations.
213#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
214pub enum QuantizationMode {
215 /// Symmetric quantization: `x_q = round(x / scale)`.
216 ///
217 /// Zero point is always 0. Simpler computation, but may waste range
218 /// if data is not centered around 0.
219 Symmetric,
220
221 /// Asymmetric quantization: `x_q = round(x / scale) + zero_point`.
222 ///
223 /// Full range utilization for asymmetric distributions (e.g., ReLU outputs).
224 Asymmetric,
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_symmetric_minmax() {
        let params = QuantizationParams::from_minmax(-10.0, 10.0, QuantizationMode::Symmetric)
            .unwrap();

        assert_eq!(params.zero_point, 0);
        assert!(params.scale > 0.0);
        assert_eq!(params.qmin, -127);
        assert_eq!(params.qmax, 127);

        // Validate
        params.validate().unwrap();
    }

    #[test]
    fn test_asymmetric_minmax() {
        let params = QuantizationParams::from_minmax(0.0, 10.0, QuantizationMode::Asymmetric)
            .unwrap();

        // The zero point must lie inside the parameters' own clamping range
        // [qmin, qmax], not merely inside the full i8 range.
        assert!(params.scale > 0.0);
        assert!(params.zero_point >= params.qmin as i32);
        assert!(params.zero_point <= params.qmax as i32);

        params.validate().unwrap();
    }

    #[test]
    fn test_asymmetric_zero_is_exact() {
        let params = QuantizationParams::from_minmax(0.0, 10.0, QuantizationMode::Asymmetric)
            .unwrap();

        // 0.0 must quantize to the zero point and dequantize back to exactly 0.0.
        let quantized = params.quantize_value(0.0);
        assert_eq!(quantized as i32, params.zero_point);
        assert_eq!(params.dequantize_value(quantized), 0.0);
    }

    #[test]
    fn test_quantize_dequantize_symmetric() {
        let params = QuantizationParams::from_minmax(-10.0, 10.0, QuantizationMode::Symmetric)
            .unwrap();

        let value = 5.0f32;
        let quantized = params.quantize_value(value);
        let dequantized = params.dequantize_value(quantized);

        // Should be close (within quantization error)
        assert!((dequantized - value).abs() < 0.1);
    }

    #[test]
    fn test_quantize_dequantize_asymmetric() {
        let params = QuantizationParams::from_minmax(0.0, 10.0, QuantizationMode::Asymmetric)
            .unwrap();

        let value = 5.0f32;
        let quantized = params.quantize_value(value);
        let dequantized = params.dequantize_value(quantized);

        assert!((dequantized - value).abs() < 0.1);
    }

    #[test]
    fn test_quantize_dequantize_asymmetric_straddling_zero() {
        // [-3, 5] straddles zero, so none of these samples is clamped and the
        // round-trip error stays within one quantization step.
        let params = QuantizationParams::from_minmax(-3.0, 5.0, QuantizationMode::Asymmetric)
            .unwrap();
        params.validate().unwrap();

        for &value in &[-3.0f32, -1.5, 0.0, 2.25, 5.0] {
            let restored = params.dequantize_value(params.quantize_value(value));
            assert!(
                (restored - value).abs() <= params.scale,
                "value {} round-tripped to {}",
                value,
                restored
            );
        }
    }

    #[test]
    fn test_zero_value_quantization() {
        let params = QuantizationParams::from_minmax(-10.0, 10.0, QuantizationMode::Symmetric)
            .unwrap();

        let quantized = params.quantize_value(0.0);
        assert_eq!(quantized, 0);

        let dequantized = params.dequantize_value(0);
        assert_eq!(dequantized, 0.0);
    }

    #[test]
    fn test_clipping() {
        let params = QuantizationParams::from_minmax(-1.0, 1.0, QuantizationMode::Symmetric)
            .unwrap();

        // Values outside range should be clipped
        let large = params.quantize_value(1000.0);
        assert_eq!(large, 127);

        let small = params.quantize_value(-1000.0);
        assert_eq!(small, -127);
    }

    #[test]
    fn test_invalid_range() {
        let result = QuantizationParams::from_minmax(10.0, -10.0, QuantizationMode::Symmetric);
        assert!(result.is_err());
    }

    #[test]
    fn test_percentile_constructor() {
        let params = QuantizationParams::from_percentile(-9.5, 9.5, QuantizationMode::Symmetric)
            .unwrap();

        assert_eq!(params.zero_point, 0);
        params.validate().unwrap();
    }

    #[test]
    fn test_validation_negative_scale() {
        let mut params = QuantizationParams::from_minmax(-10.0, 10.0, QuantizationMode::Symmetric)
            .unwrap();

        params.scale = -1.0;
        assert!(params.validate().is_err());
    }

    #[test]
    fn test_validation_zero_scale() {
        let mut params = QuantizationParams::from_minmax(-10.0, 10.0, QuantizationMode::Symmetric)
            .unwrap();

        params.scale = 0.0;
        assert!(params.validate().is_err());
    }

    #[test]
    fn test_validation_invalid_qmin_qmax() {
        let mut params = QuantizationParams::from_minmax(-10.0, 10.0, QuantizationMode::Symmetric)
            .unwrap();

        params.qmin = 127;
        params.qmax = -127;
        assert!(params.validate().is_err());
    }

    #[test]
    fn test_validation_zero_point_out_of_range() {
        let mut params = QuantizationParams::from_minmax(-10.0, 10.0, QuantizationMode::Symmetric)
            .unwrap();

        params.zero_point = 200;
        assert!(params.validate().is_err());
    }
}