// torsh_backend/quantization/params.rs
//! Quantization parameters and configuration
//!
//! This module provides the QuantizationParams struct and related functionality
//! for managing quantization configuration. It handles parameter calculation
//! from statistics, preset configurations for common quantization schemes,
//! and parameter validation.

use core::cmp::Ordering;

use super::types::{QuantizationScheme, QuantizedDType};
use crate::BackendResult;

#[cfg(not(feature = "std"))]
use alloc::vec::Vec;

/// Quantization parameters
///
/// Contains all the parameters needed to quantize and dequantize tensors,
/// including scale factors, zero points, and metadata about the quantization
/// scheme being used.
///
/// Invariant (checked by `validate()`): `scale` and `zero_point` always have
/// the same length — one entry per independent parameter set.
#[derive(Debug, Clone)]
pub struct QuantizationParams {
    /// Quantization data type
    ///
    /// Specifies the target quantized data type (e.g., Int8, UInt8, Int4)
    pub dtype: QuantizedDType,

    /// Quantization scheme
    ///
    /// Defines how the quantization mapping is performed (linear, symmetric, etc.)
    pub scheme: QuantizationScheme,

    /// Scale factor(s)
    ///
    /// Maps quantized values back to floating-point range.
    /// For per-channel quantization, contains one scale per channel.
    /// Formula: float_val = scale * (quantized_val - zero_point)
    pub scale: Vec<f32>,

    /// Zero point(s)
    ///
    /// The quantized value that corresponds to floating-point zero.
    /// For per-channel quantization, contains one zero point per channel.
    /// For symmetric quantization, this is always 0.
    pub zero_point: Vec<i32>,

    /// Block size for block-wise quantization
    ///
    /// When using block-wise quantization, specifies the size of each block
    /// that gets its own quantization parameters. None for other schemes.
    pub block_size: Option<usize>,

    /// Minimum value observed during calibration
    ///
    /// Used for parameter calculation and validation. Set during calibration
    /// or when computing parameters from statistics.
    pub min_val: Option<f32>,

    /// Maximum value observed during calibration
    ///
    /// Used for parameter calculation and validation. Set during calibration
    /// or when computing parameters from statistics.
    pub max_val: Option<f32>,
}
63
64impl Default for QuantizationParams {
65 /// Default quantization parameters
66 ///
67 /// Creates parameters for UInt8 linear quantization with scale=1.0
68 /// and zero_point=0, suitable for testing and initialization.
69 fn default() -> Self {
70 Self {
71 dtype: QuantizedDType::UInt8,
72 scheme: QuantizationScheme::Linear,
73 scale: vec![1.0],
74 zero_point: vec![0],
75 block_size: None,
76 min_val: None,
77 max_val: None,
78 }
79 }
80}
81
82impl QuantizationParams {
83 /// Create parameters for INT8 symmetric quantization
84 ///
85 /// INT8 symmetric quantization is commonly used for weights in neural networks
86 /// due to its simplicity and good hardware support. The zero point is always 0,
87 /// and the range is symmetric around zero.
88 ///
89 /// # Examples
90 ///
91 /// ```
92 /// use torsh_backend::quantization::QuantizationParams;
93 ///
94 /// let params = QuantizationParams::int8_symmetric();
95 /// assert_eq!(params.zero_point[0], 0);
96 /// ```
97 pub fn int8_symmetric() -> Self {
98 Self {
99 dtype: QuantizedDType::Int8,
100 scheme: QuantizationScheme::Symmetric,
101 scale: vec![1.0],
102 zero_point: vec![0],
103 block_size: None,
104 min_val: None,
105 max_val: None,
106 }
107 }
108
109 /// Create basic quantization parameters with custom scale and zero point
110 ///
111 /// This is a general-purpose constructor for creating quantization parameters
112 /// with custom scale and zero point values. Useful for benchmarking and
113 /// testing with specific parameter configurations.
114 ///
115 /// # Arguments
116 ///
117 /// * `scale` - Scale factor for the quantization
118 /// * `zero_point` - Zero point for the quantization
119 ///
120 /// # Examples
121 ///
122 /// ```
123 /// use torsh_backend::quantization::QuantizationParams;
124 ///
125 /// let params = QuantizationParams::new(255.0, 128);
126 /// assert_eq!(params.scale[0], 255.0);
127 /// assert_eq!(params.zero_point[0], 128);
128 /// ```
129 pub fn new(scale: f32, zero_point: i32) -> Self {
130 Self {
131 dtype: QuantizedDType::UInt8, // Default to UInt8 for general usage
132 scheme: QuantizationScheme::Asymmetric,
133 scale: vec![scale],
134 zero_point: vec![zero_point],
135 block_size: None,
136 min_val: None,
137 max_val: None,
138 }
139 }
140
141 /// Create parameters for UINT8 asymmetric quantization
142 ///
143 /// UInt8 asymmetric quantization is commonly used for activations,
144 /// especially after ReLU layers where values are non-negative.
145 /// The zero point is typically set to 128 for balanced range utilization.
146 ///
147 /// # Examples
148 ///
149 /// ```
150 /// use torsh_backend::quantization::QuantizationParams;
151 ///
152 /// let params = QuantizationParams::uint8_asymmetric();
153 /// assert_eq!(params.zero_point[0], 128);
154 /// ```
155 pub fn uint8_asymmetric() -> Self {
156 Self {
157 dtype: QuantizedDType::UInt8,
158 scheme: QuantizationScheme::Asymmetric,
159 scale: vec![1.0],
160 zero_point: vec![128],
161 block_size: None,
162 min_val: None,
163 max_val: None,
164 }
165 }
166
167 /// Create parameters for INT4 symmetric quantization
168 ///
169 /// INT4 quantization provides extreme compression at the cost of accuracy.
170 /// Symmetric INT4 is often used for weights in models where 4-bit precision
171 /// is sufficient.
172 ///
173 /// # Examples
174 ///
175 /// ```
176 /// use torsh_backend::quantization::QuantizationParams;
177 ///
178 /// let params = QuantizationParams::int4_symmetric();
179 /// assert_eq!(params.dtype.bits(), 4);
180 /// ```
181 pub fn int4_symmetric() -> Self {
182 Self {
183 dtype: QuantizedDType::Int4,
184 scheme: QuantizationScheme::Symmetric,
185 scale: vec![1.0],
186 zero_point: vec![0],
187 block_size: None,
188 min_val: None,
189 max_val: None,
190 }
191 }
192
193 /// Create parameters for channel-wise quantization
194 ///
195 /// Channel-wise quantization applies different quantization parameters
196 /// to each channel, providing better accuracy for models with varying
197 /// channel sensitivities at the cost of increased parameter storage.
198 ///
199 /// # Arguments
200 ///
201 /// * `num_channels` - Number of channels in the tensor
202 /// * `dtype` - Quantization data type to use
203 ///
204 /// # Examples
205 ///
206 /// ```
207 /// use torsh_backend::quantization::{QuantizationParams, QuantizedDType};
208 ///
209 /// let params = QuantizationParams::channel_wise(64, QuantizedDType::Int8);
210 /// assert_eq!(params.scale.len(), 64);
211 /// assert_eq!(params.zero_point.len(), 64);
212 /// ```
213 pub fn channel_wise(num_channels: usize, dtype: QuantizedDType) -> Self {
214 Self {
215 dtype,
216 scheme: QuantizationScheme::ChannelWise,
217 scale: vec![1.0; num_channels],
218 zero_point: vec![0; num_channels],
219 block_size: None,
220 min_val: None,
221 max_val: None,
222 }
223 }
224
225 /// Create parameters for block-wise quantization
226 ///
227 /// Block-wise quantization divides the tensor into blocks and applies
228 /// different quantization parameters to each block. This can provide
229 /// better accuracy than tensor-wise quantization while being more
230 /// memory-efficient than channel-wise quantization.
231 ///
232 /// # Arguments
233 ///
234 /// * `block_size` - Size of each quantization block
235 /// * `dtype` - Quantization data type to use
236 ///
237 /// # Examples
238 ///
239 /// ```
240 /// use torsh_backend::quantization::{QuantizationParams, QuantizedDType};
241 ///
242 /// let params = QuantizationParams::block_wise(128, QuantizedDType::Int8);
243 /// assert_eq!(params.block_size, Some(128));
244 /// ```
245 pub fn block_wise(block_size: usize, dtype: QuantizedDType) -> Self {
246 Self {
247 dtype,
248 scheme: QuantizationScheme::BlockWise,
249 scale: vec![1.0], // Will be expanded based on tensor size
250 zero_point: vec![0],
251 block_size: Some(block_size),
252 min_val: None,
253 max_val: None,
254 }
255 }
256
257 /// Calculate quantization parameters from input statistics
258 ///
259 /// Computes the optimal scale and zero point parameters based on the
260 /// observed minimum and maximum values in the data. The calculation
261 /// depends on the quantization scheme being used.
262 ///
263 /// # Arguments
264 ///
265 /// * `min_val` - Minimum value observed in the data
266 /// * `max_val` - Maximum value observed in the data
267 ///
268 /// # Returns
269 ///
270 /// Returns `Ok(())` if parameters were calculated successfully,
271 /// or an error if the statistics are invalid.
272 ///
273 /// # Examples
274 ///
275 /// ```
276 /// use torsh_backend::quantization::QuantizationParams;
277 ///
278 /// let mut params = QuantizationParams::int8_symmetric();
279 /// params.from_statistics(-2.0, 2.0).unwrap();
280 /// // Scale will be calculated to map [-2.0, 2.0] to [-128, 127]
281 /// ```
282 pub fn from_statistics(&mut self, min_val: f32, max_val: f32) -> BackendResult<()> {
283 // Validate input statistics
284 if min_val > max_val {
285 return Err(torsh_core::error::TorshError::InvalidArgument(
286 "min_val must be <= max_val".to_string(),
287 ));
288 }
289
290 if !min_val.is_finite() || !max_val.is_finite() {
291 return Err(torsh_core::error::TorshError::InvalidArgument(
292 "min_val and max_val must be finite".to_string(),
293 ));
294 }
295
296 self.min_val = Some(min_val);
297 self.max_val = Some(max_val);
298
299 let (qmin, qmax) = self.dtype.value_range();
300 let qmin = qmin as f32;
301 let qmax = qmax as f32;
302
303 match self.scheme {
304 QuantizationScheme::Symmetric => {
305 self.calculate_symmetric_params(min_val, max_val, qmin, qmax)?;
306 }
307 QuantizationScheme::Asymmetric | QuantizationScheme::Linear => {
308 self.calculate_asymmetric_params(min_val, max_val, qmin, qmax)?;
309 }
310 QuantizationScheme::Logarithmic => {
311 self.calculate_logarithmic_params(min_val, max_val, qmin, qmax)?;
312 }
313 QuantizationScheme::BlockWise | QuantizationScheme::ChannelWise => {
314 // For block-wise and channel-wise, use asymmetric as base
315 // Individual blocks/channels will be calculated separately
316 self.calculate_asymmetric_params(min_val, max_val, qmin, qmax)?;
317 }
318 }
319
320 Ok(())
321 }
322
323 /// Calculate symmetric quantization parameters
324 fn calculate_symmetric_params(
325 &mut self,
326 min_val: f32,
327 max_val: f32,
328 qmin: f32,
329 qmax: f32,
330 ) -> BackendResult<()> {
331 let max_range = max_val.abs().max(min_val.abs());
332 if max_range == 0.0 {
333 self.scale[0] = 1.0;
334 } else {
335 // For symmetric quantization, we map [-max_range, max_range] to [qmin, qmax]
336 self.scale[0] = (2.0 * max_range) / (qmax - qmin);
337 }
338 self.zero_point[0] = 0;
339 Ok(())
340 }
341
342 /// Calculate asymmetric quantization parameters
343 fn calculate_asymmetric_params(
344 &mut self,
345 min_val: f32,
346 max_val: f32,
347 qmin: f32,
348 qmax: f32,
349 ) -> BackendResult<()> {
350 if max_val == min_val {
351 // Degenerate case: all values are the same
352 self.scale[0] = 1.0;
353 self.zero_point[0] = qmin as i32;
354 } else {
355 // Calculate scale to map [min_val, max_val] to [qmin, qmax]
356 self.scale[0] = (max_val - min_val) / (qmax - qmin);
357
358 // Calculate zero point such that min_val maps to qmin
359 let zero_point_from_min = qmin - min_val / self.scale[0];
360 self.zero_point[0] = zero_point_from_min.round().clamp(qmin, qmax) as i32;
361 }
362 Ok(())
363 }
364
365 /// Calculate logarithmic quantization parameters
366 fn calculate_logarithmic_params(
367 &mut self,
368 min_val: f32,
369 max_val: f32,
370 qmin: f32,
371 qmax: f32,
372 ) -> BackendResult<()> {
373 // For logarithmic quantization, we need positive values
374 if min_val <= 0.0 {
375 return Err(torsh_core::error::TorshError::InvalidArgument(
376 "Logarithmic quantization requires positive values".to_string(),
377 ));
378 }
379
380 // Use logarithmic scale mapping
381 let log_min = min_val.ln();
382 let log_max = max_val.ln();
383
384 if log_max == log_min {
385 self.scale[0] = 1.0;
386 self.zero_point[0] = qmin as i32;
387 } else {
388 self.scale[0] = (log_max - log_min) / (qmax - qmin);
389 self.zero_point[0] = (qmin - log_min / self.scale[0]).round() as i32;
390 }
391 Ok(())
392 }
393
394 /// Validate that the parameters are consistent and usable
395 ///
396 /// Checks that all parameter vectors have consistent lengths,
397 /// scale factors are positive, and zero points are within valid ranges.
398 pub fn validate(&self) -> BackendResult<()> {
399 // Check that scale and zero_point vectors have consistent lengths
400 if self.scale.len() != self.zero_point.len() {
401 return Err(torsh_core::error::TorshError::InvalidArgument(
402 "Scale and zero_point vectors must have the same length".to_string(),
403 ));
404 }
405
406 // Check that all scale factors are positive and finite
407 for (i, &scale) in self.scale.iter().enumerate() {
408 if scale <= 0.0 || !scale.is_finite() {
409 return Err(torsh_core::error::TorshError::InvalidArgument(format!(
410 "Scale factor at index {} must be positive and finite, got {}",
411 i, scale
412 )));
413 }
414 }
415
416 // Check that zero points are within the valid range for the data type
417 let (qmin, qmax) = self.dtype.value_range();
418 for (i, &zero_point) in self.zero_point.iter().enumerate() {
419 if (zero_point as i64) < qmin || (zero_point as i64) > qmax {
420 return Err(torsh_core::error::TorshError::InvalidArgument(format!(
421 "Zero point at index {} ({}) is outside valid range [{}, {}]",
422 i, zero_point, qmin, qmax
423 )));
424 }
425 }
426
427 // Scheme-specific validation
428 match self.scheme {
429 QuantizationScheme::Symmetric => {
430 // Symmetric quantization should have zero_point = 0
431 for (i, &zero_point) in self.zero_point.iter().enumerate() {
432 if zero_point != 0 {
433 return Err(torsh_core::error::TorshError::InvalidArgument(format!(
434 "Symmetric quantization requires zero_point[{}] = 0, got {}",
435 i, zero_point
436 )));
437 }
438 }
439 }
440 QuantizationScheme::BlockWise => {
441 // Block-wise quantization should have a block size specified
442 if self.block_size.is_none() {
443 return Err(torsh_core::error::TorshError::InvalidArgument(
444 "Block-wise quantization requires block_size to be specified".to_string(),
445 ));
446 }
447 }
448 QuantizationScheme::ChannelWise => {
449 // Channel-wise should have multiple parameters
450 if self.scale.len() == 1 {
451 return Err(torsh_core::error::TorshError::InvalidArgument(
452 "Channel-wise quantization requires multiple scale/zero_point values"
453 .to_string(),
454 ));
455 }
456 }
457 _ => {} // Other schemes have no specific requirements
458 }
459
460 Ok(())
461 }
462
463 /// Get the effective number of quantization parameter sets
464 ///
465 /// Returns the number of independent parameter sets (scale/zero_point pairs)
466 /// that this configuration represents. For tensor-wise quantization this is 1,
467 /// for channel-wise it's the number of channels.
468 pub fn num_parameter_sets(&self) -> usize {
469 self.scale.len()
470 }
471
472 /// Check if this configuration uses per-channel parameters
473 pub fn is_per_channel(&self) -> bool {
474 self.scheme.is_per_channel() && self.scale.len() > 1
475 }
476
477 /// Get the quantization error bound for this configuration
478 ///
479 /// Returns the maximum possible quantization error (in the original
480 /// floating-point scale) for this quantization configuration.
481 pub fn quantization_error_bound(&self) -> f32 {
482 // The maximum error is half the quantization step size
483 self.scale
484 .iter()
485 .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
486 .copied()
487 .unwrap_or(0.0)
488 * 0.5
489 }
490
491 /// Calculate the compression ratio achieved by this quantization
492 ///
493 /// Returns the ratio of original size to quantized size.
494 /// Assumes the original data was 32-bit floating point.
495 pub fn compression_ratio(&self) -> f32 {
496 let bits_per_value = self.dtype.bits() as f32;
497 32.0 / bits_per_value
498 }
499}
500
501#[cfg(test)]
502mod tests {
503 use super::*;
504
505 #[test]
506 fn test_default_params() {
507 let params = QuantizationParams::default();
508 assert_eq!(params.dtype, QuantizedDType::UInt8);
509 assert_eq!(params.scheme, QuantizationScheme::Linear);
510 assert_eq!(params.scale, vec![1.0]);
511 assert_eq!(params.zero_point, vec![0]);
512 }
513
514 #[test]
515 fn test_preset_configs() {
516 let int8_sym = QuantizationParams::int8_symmetric();
517 assert_eq!(int8_sym.dtype, QuantizedDType::Int8);
518 assert_eq!(int8_sym.scheme, QuantizationScheme::Symmetric);
519 assert_eq!(int8_sym.zero_point[0], 0);
520
521 let uint8_asym = QuantizationParams::uint8_asymmetric();
522 assert_eq!(uint8_asym.dtype, QuantizedDType::UInt8);
523 assert_eq!(uint8_asym.scheme, QuantizationScheme::Asymmetric);
524 assert_eq!(uint8_asym.zero_point[0], 128);
525
526 let int4_sym = QuantizationParams::int4_symmetric();
527 assert_eq!(int4_sym.dtype, QuantizedDType::Int4);
528 assert_eq!(int4_sym.zero_point[0], 0);
529 }
530
531 #[test]
532 fn test_channel_wise_params() {
533 let params = QuantizationParams::channel_wise(64, QuantizedDType::Int8);
534 assert_eq!(params.scheme, QuantizationScheme::ChannelWise);
535 assert_eq!(params.scale.len(), 64);
536 assert_eq!(params.zero_point.len(), 64);
537 assert!(params.is_per_channel());
538 }
539
540 #[test]
541 fn test_block_wise_params() {
542 let params = QuantizationParams::block_wise(128, QuantizedDType::Int8);
543 assert_eq!(params.scheme, QuantizationScheme::BlockWise);
544 assert_eq!(params.block_size, Some(128));
545 }
546
547 #[test]
548 fn test_from_statistics_symmetric() {
549 let mut params = QuantizationParams::int8_symmetric();
550 params.from_statistics(-2.0, 2.0).unwrap();
551
552 assert_eq!(params.zero_point[0], 0);
553 assert!(params.scale[0] > 0.0);
554 assert_eq!(params.min_val, Some(-2.0));
555 assert_eq!(params.max_val, Some(2.0));
556 }
557
558 #[test]
559 fn test_from_statistics_asymmetric() {
560 let mut params = QuantizationParams::uint8_asymmetric();
561 params.from_statistics(0.0, 255.0).unwrap();
562
563 assert!(params.scale[0] > 0.0);
564 assert!(params.zero_point[0] >= 0 && params.zero_point[0] <= 255);
565 }
566
567 #[test]
568 fn test_from_statistics_invalid() {
569 let mut params = QuantizationParams::default();
570
571 // min > max should fail
572 assert!(params.from_statistics(2.0, 1.0).is_err());
573
574 // Non-finite values should fail
575 assert!(params.from_statistics(f32::NAN, 1.0).is_err());
576 assert!(params.from_statistics(1.0, f32::INFINITY).is_err());
577 }
578
579 #[test]
580 fn test_validation() {
581 let mut params = QuantizationParams::default();
582 assert!(params.validate().is_ok());
583
584 // Mismatched vector lengths should fail
585 params.scale.push(2.0);
586 assert!(params.validate().is_err());
587
588 // Reset and test negative scale
589 params.scale = vec![-1.0];
590 params.zero_point = vec![0];
591 assert!(params.validate().is_err());
592 }
593
594 #[test]
595 fn test_validation_symmetric() {
596 let mut params = QuantizationParams::int8_symmetric();
597 params.zero_point[0] = 10; // Should fail for symmetric
598 assert!(params.validate().is_err());
599 }
600
601 #[test]
602 fn test_compression_ratio() {
603 let int8_params = QuantizationParams::int8_symmetric();
604 assert_eq!(int8_params.compression_ratio(), 4.0); // 32 bits -> 8 bits
605
606 let int4_params = QuantizationParams::int4_symmetric();
607 assert_eq!(int4_params.compression_ratio(), 8.0); // 32 bits -> 4 bits
608 }
609
610 #[test]
611 fn test_error_bound() {
612 let mut params = QuantizationParams::int8_symmetric();
613 params.scale = vec![0.1];
614 assert_eq!(params.quantization_error_bound(), 0.05); // Half the scale
615 }
616
617 #[test]
618 fn test_num_parameter_sets() {
619 let tensor_wise = QuantizationParams::default();
620 assert_eq!(tensor_wise.num_parameter_sets(), 1);
621
622 let channel_wise = QuantizationParams::channel_wise(64, QuantizedDType::Int8);
623 assert_eq!(channel_wise.num_parameter_sets(), 64);
624 }
625}