voirs_acoustic/quantization/mod.rs

//! Model quantization utilities for compression and optimization
//!
//! This module provides various quantization techniques including post-training quantization (PTQ),
//! quantization-aware training (QAT), and dynamic range calibration for neural acoustic models.
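//!
//! # Example
//!
//! A minimal post-training quantization sketch (the `voirs_acoustic::quantization` path is
//! assumed from this file's location, and `"layer1"` is just a placeholder layer name):
//!
//! ```ignore
//! use voirs_acoustic::quantization::{ModelQuantizer, QuantizationConfig};
//!
//! // Int8 post-training quantization with the default configuration.
//! let quantizer = ModelQuantizer::new(QuantizationConfig::default());
//!
//! // Collect representative activations per layer, then calibrate once.
//! quantizer.add_calibration_data("layer1".to_string(), vec![0.5, -1.2, 3.4]).unwrap();
//! quantizer.calibrate().unwrap();
//!
//! // Quantize a tensor using the calibrated parameters for that layer.
//! let quantized = quantizer.quantize_tensor("layer1", &[0.5, -1.2, 3.4], vec![3]).unwrap();
//! assert_eq!(quantized.size(), 3);
//! ```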

use std::collections::HashMap;
use std::sync::{Arc, Mutex};

use crate::{AcousticError, Result};

pub mod calibration;
pub mod ptq;
pub mod qat;
pub mod utils;

/// Quantization precision types
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantizationPrecision {
    /// 8-bit integer quantization
    Int8,
    /// 16-bit integer quantization
    Int16,
    /// 4-bit integer quantization (experimental)
    Int4,
    /// Mixed precision (some layers remain FP32)
    Mixed,
}

/// Quantization methods
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantizationMethod {
    /// Post-training quantization
    PostTraining,
    /// Quantization-aware training
    AwareTraining,
    /// Dynamic quantization
    Dynamic,
}

/// Quantization configuration
#[derive(Debug, Clone)]
pub struct QuantizationConfig {
    /// Target precision
    pub precision: QuantizationPrecision,
    /// Quantization method
    pub method: QuantizationMethod,
    /// Calibration dataset size
    pub calibration_samples: usize,
    /// Layers to skip during quantization
    pub skip_layers: Vec<String>,
    /// Symmetric vs asymmetric quantization
    pub symmetric: bool,
    /// Per-channel vs per-tensor quantization
    pub per_channel: bool,
    /// Target accuracy retention (0.0 to 1.0)
    pub target_accuracy: f32,
}

impl Default for QuantizationConfig {
    fn default() -> Self {
        Self {
            precision: QuantizationPrecision::Int8,
            method: QuantizationMethod::PostTraining,
            calibration_samples: 1000,
            skip_layers: vec!["output".to_string()], // Usually skip the output layer
            symmetric: true,
            per_channel: true,
            target_accuracy: 0.95, // 95% accuracy retention
        }
    }
}

/// Quantization parameters for a tensor
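///
/// A real value `x` maps to the quantized integer `q = round(x / scale) + zero_point`
/// (clamped to `[qmin, qmax]`), and back via `x ≈ (q - zero_point) * scale`.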
#[derive(Debug, Clone)]
pub struct QuantizationParams {
    /// Scale factor
    pub scale: f32,
    /// Zero point
    pub zero_point: i32,
    /// Quantization range minimum
    pub qmin: i32,
    /// Quantization range maximum
    pub qmax: i32,
    /// Whether quantization is symmetric
    pub symmetric: bool,
}

impl QuantizationParams {
    /// Create new quantization parameters
    pub fn new(scale: f32, zero_point: i32, qmin: i32, qmax: i32, symmetric: bool) -> Self {
        Self {
            scale,
            zero_point,
            qmin,
            qmax,
            symmetric,
        }
    }

    /// Create symmetric quantization parameters
    pub fn symmetric(scale: f32, qmin: i32, qmax: i32) -> Self {
        Self {
            scale,
            zero_point: 0,
            qmin,
            qmax,
            symmetric: true,
        }
    }

    /// Create asymmetric quantization parameters
    pub fn asymmetric(scale: f32, zero_point: i32, qmin: i32, qmax: i32) -> Self {
        Self {
            scale,
            zero_point,
            qmin,
            qmax,
            symmetric: false,
        }
    }

    /// Quantize a value
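    ///
    /// Computes `round(value / scale) + zero_point` and clamps the result to
    /// `[qmin, qmax]`. A small illustrative sketch:
    ///
    /// ```ignore
    /// let params = QuantizationParams::symmetric(0.1, -128, 127);
    /// assert_eq!(params.quantize(1.0), 10);    // round(1.0 / 0.1) + 0
    /// assert_eq!(params.quantize(100.0), 127); // clamped to qmax
    /// ```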
    pub fn quantize(&self, value: f32) -> i32 {
        let quantized = (value / self.scale).round() as i32 + self.zero_point;
        quantized.clamp(self.qmin, self.qmax)
    }

    /// Dequantize a value
    pub fn dequantize(&self, quantized: i32) -> f32 {
        (quantized - self.zero_point) as f32 * self.scale
    }

    /// Quantize a tensor
    pub fn quantize_tensor(&self, input: &[f32]) -> Vec<i32> {
        input.iter().map(|&x| self.quantize(x)).collect()
    }

    /// Dequantize a tensor
    pub fn dequantize_tensor(&self, quantized: &[i32]) -> Vec<f32> {
        quantized.iter().map(|&x| self.dequantize(x)).collect()
    }
}

/// Quantized tensor representation
#[derive(Debug, Clone)]
pub struct QuantizedTensor {
    /// Quantized data
    pub data: Vec<i32>,
    /// Quantization parameters
    pub params: QuantizationParams,
    /// Original tensor shape
    pub shape: Vec<usize>,
    /// Tensor name/identifier
    pub name: String,
}

impl QuantizedTensor {
    /// Create new quantized tensor
    pub fn new(
        data: Vec<i32>,
        params: QuantizationParams,
        shape: Vec<usize>,
        name: String,
    ) -> Self {
        Self {
            data,
            params,
            shape,
            name,
        }
    }

    /// Dequantize tensor to floating point
    pub fn dequantize(&self) -> Vec<f32> {
        self.params.dequantize_tensor(&self.data)
    }

    /// Get tensor size in elements
    pub fn size(&self) -> usize {
        self.shape.iter().product()
    }

    /// Get memory usage in bytes (for quantized representation)
    pub fn memory_usage(&self) -> usize {
        self.data.len() * std::mem::size_of::<i32>()
    }

    /// Get compression ratio compared to FP32
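    ///
    /// Note: because `data` is stored as `Vec<i32>`, the in-memory ratio reported here is
    /// 1.0; realizing the Int8/Int4 savings would require a packed storage representation.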
    pub fn compression_ratio(&self) -> f32 {
        let fp32_size = self.size() * std::mem::size_of::<f32>();
        let quantized_size = self.memory_usage();
        fp32_size as f32 / quantized_size as f32
    }
}

/// Model quantizer for applying quantization to acoustic models
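///
/// Typical flow: feed per-layer calibration data via `add_calibration_data`, derive
/// parameters once with `calibrate`, then quantize tensors with `quantize_tensor`.
/// Layers named in `config.skip_layers` are skipped during calibration.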
pub struct ModelQuantizer {
    /// Quantization configuration
    config: QuantizationConfig,
    /// Quantization parameters per layer
    layer_params: Arc<Mutex<HashMap<String, QuantizationParams>>>,
    /// Calibration data cache
    calibration_cache: Arc<Mutex<HashMap<String, Vec<f32>>>>,
}

impl ModelQuantizer {
    /// Create new model quantizer
    pub fn new(config: QuantizationConfig) -> Self {
        Self {
            config,
            layer_params: Arc::new(Mutex::new(HashMap::new())),
            calibration_cache: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Add calibration data for a layer
    pub fn add_calibration_data(&self, layer_name: String, data: Vec<f32>) -> Result<()> {
        let mut cache = self.calibration_cache.lock().unwrap();
        cache.insert(layer_name, data);
        Ok(())
    }

    /// Calibrate quantization parameters for all layers
    pub fn calibrate(&self) -> Result<()> {
        let cache = self.calibration_cache.lock().unwrap();
        let mut params = self.layer_params.lock().unwrap();

        for (layer_name, data) in cache.iter() {
            if self.config.skip_layers.contains(layer_name) {
                continue;
            }

            let qparams = self.calculate_quantization_params(data)?;
            params.insert(layer_name.clone(), qparams);
        }

        Ok(())
    }

    /// Calculate quantization parameters from calibration data
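    ///
    /// Symmetric mode uses `scale = max(|min|, |max|) / qmax` with a zero point of 0;
    /// asymmetric mode uses `scale = (max - min) / (qmax - qmin)` and
    /// `zero_point = qmin - round(min / scale)`.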
    fn calculate_quantization_params(&self, data: &[f32]) -> Result<QuantizationParams> {
        if data.is_empty() {
            return Err(AcousticError::ProcessingError {
                message: "Cannot calculate quantization params from empty data".to_string(),
            });
        }

        let min_val = data.iter().copied().fold(f32::INFINITY, f32::min);
        let max_val = data.iter().copied().fold(f32::NEG_INFINITY, f32::max);

        let (qmin, qmax) = match self.config.precision {
            QuantizationPrecision::Int8 => (-128, 127),
            QuantizationPrecision::Int16 => (-32768, 32767),
            QuantizationPrecision::Int4 => (-8, 7),
            QuantizationPrecision::Mixed => (-128, 127), // Default to Int8 for mixed
        };

        if self.config.symmetric {
            let abs_max = max_val.abs().max(min_val.abs());
            let scale = abs_max / (qmax as f32);
            Ok(QuantizationParams::symmetric(scale, qmin, qmax))
        } else {
            let scale = (max_val - min_val) / (qmax - qmin) as f32;
            let zero_point = qmin - (min_val / scale).round() as i32;
            Ok(QuantizationParams::asymmetric(
                scale, zero_point, qmin, qmax,
            ))
        }
    }

    /// Quantize a tensor using calibrated parameters
    pub fn quantize_tensor(
        &self,
        layer_name: &str,
        data: &[f32],
        shape: Vec<usize>,
    ) -> Result<QuantizedTensor> {
        let params = self.layer_params.lock().unwrap();
        let qparams = params
            .get(layer_name)
            .ok_or_else(|| AcousticError::ProcessingError {
                message: format!("No quantization parameters found for layer: {layer_name}"),
            })?;

        let quantized_data = qparams.quantize_tensor(data);
        Ok(QuantizedTensor::new(
            quantized_data,
            qparams.clone(),
            shape,
            layer_name.to_string(),
        ))
    }

    /// Get quantization parameters for a layer
    pub fn get_layer_params(&self, layer_name: &str) -> Option<QuantizationParams> {
        self.layer_params.lock().unwrap().get(layer_name).cloned()
    }

    /// Get configuration
    pub fn config(&self) -> &QuantizationConfig {
        &self.config
    }

    /// Get calibration progress
    pub fn calibration_progress(&self) -> f32 {
        let cache = self.calibration_cache.lock().unwrap();
        if cache.is_empty() {
            0.0
        } else {
            let params = self.layer_params.lock().unwrap();
            params.len() as f32 / cache.len() as f32
        }
    }
}

/// Quantization statistics for model analysis
#[derive(Debug, Clone)]
pub struct QuantizationStats {
    /// Total model size before quantization (bytes)
    pub original_size: usize,
    /// Total model size after quantization (bytes)
    pub quantized_size: usize,
    /// Compression ratio
    pub compression_ratio: f32,
    /// Number of quantized layers
    pub quantized_layers: usize,
    /// Number of skipped layers
    pub skipped_layers: usize,
    /// Estimated accuracy retention
    pub estimated_accuracy: f32,
}

impl QuantizationStats {
    /// Calculate compression savings
    pub fn compression_savings(&self) -> f32 {
        if self.original_size == 0 {
            0.0
        } else {
            1.0 - (self.quantized_size as f32 / self.original_size as f32)
        }
    }

    /// Calculate memory savings in MB
    pub fn memory_savings_mb(&self) -> f32 {
        (self.original_size - self.quantized_size) as f32 / (1024.0 * 1024.0)
    }
}

/// Quantization benchmark for performance testing
#[derive(Debug, Clone)]
pub struct QuantizationBenchmark {
    /// Original model inference time (ms)
    pub original_inference_ms: f32,
    /// Quantized model inference time (ms)
    pub quantized_inference_ms: f32,
    /// Speedup factor
    pub speedup: f32,
    /// Accuracy on test set (original)
    pub original_accuracy: f32,
    /// Accuracy on test set (quantized)
    pub quantized_accuracy: f32,
    /// Accuracy degradation
    pub accuracy_degradation: f32,
}

impl QuantizationBenchmark {
    /// Create new benchmark
    pub fn new(
        original_inference_ms: f32,
        quantized_inference_ms: f32,
        original_accuracy: f32,
        quantized_accuracy: f32,
    ) -> Self {
        let speedup = original_inference_ms / quantized_inference_ms;
        let accuracy_degradation = original_accuracy - quantized_accuracy;

        Self {
            original_inference_ms,
            quantized_inference_ms,
            speedup,
            original_accuracy,
            quantized_accuracy,
            accuracy_degradation,
        }
    }

    /// Check if quantization meets quality targets
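    ///
    /// Returns `true` when `speedup >= target_speedup` and
    /// `accuracy_degradation <= max_accuracy_loss`.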
    pub fn meets_targets(&self, target_speedup: f32, max_accuracy_loss: f32) -> bool {
        self.speedup >= target_speedup && self.accuracy_degradation <= max_accuracy_loss
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantization_params() {
        let params = QuantizationParams::symmetric(0.1, -128, 127);
        assert_eq!(params.scale, 0.1);
        assert_eq!(params.zero_point, 0);
        assert!(params.symmetric);

        // Test quantization
        let value = 1.0;
        let quantized = params.quantize(value);
        let dequantized = params.dequantize(quantized);
        assert!((dequantized - value).abs() < 0.1);
    }

    #[test]
    fn test_quantized_tensor() {
        let data = vec![1, 2, 3, 4];
        let params = QuantizationParams::symmetric(0.1, -128, 127);
        let shape = vec![2, 2];
        let tensor = QuantizedTensor::new(data, params, shape, "test".to_string());

        assert_eq!(tensor.size(), 4);
        assert_eq!(tensor.memory_usage(), 16); // 4 * 4 bytes
        assert_eq!(tensor.compression_ratio(), 1.0); // Same as i32 vs f32
    }

    #[test]
    fn test_model_quantizer() {
        let config = QuantizationConfig::default();
        let quantizer = ModelQuantizer::new(config);

        // Add calibration data
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        quantizer
            .add_calibration_data("layer1".to_string(), data.clone())
            .unwrap();

        // Calibrate
        quantizer.calibrate().unwrap();

        // Check parameters were calculated
        assert!(quantizer.get_layer_params("layer1").is_some());

        // Quantize tensor
        let quantized = quantizer.quantize_tensor("layer1", &data, vec![5]).unwrap();
        assert_eq!(quantized.size(), 5);
    }

    #[test]
    fn test_quantization_config() {
        let config = QuantizationConfig::default();
        assert_eq!(config.precision, QuantizationPrecision::Int8);
        assert_eq!(config.method, QuantizationMethod::PostTraining);
        assert_eq!(config.calibration_samples, 1000);
    }

    #[test]
    fn test_quantization_benchmark() {
        let benchmark = QuantizationBenchmark::new(100.0, 50.0, 0.95, 0.92);
        assert_eq!(benchmark.speedup, 2.0);
        assert!((benchmark.accuracy_degradation - 0.03).abs() < 1e-6);
        assert!(benchmark.meets_targets(1.5, 0.05));
        assert!(!benchmark.meets_targets(3.0, 0.02));
    }

    #[test]
    fn test_quantization_stats() {
        let stats = QuantizationStats {
            original_size: 1000,
            quantized_size: 250,
            compression_ratio: 4.0,
            quantized_layers: 8,
            skipped_layers: 2,
            estimated_accuracy: 0.94,
        };

        assert_eq!(stats.compression_savings(), 0.75);
        assert!((stats.memory_savings_mb() - 0.000715).abs() < 0.0001);
    }
}