//! Model quantization support: post-training quantization (PTQ),
//! quantization-aware training (QAT), calibration, and shared utilities.

use std::collections::HashMap;
use std::sync::{Arc, Mutex};

use crate::{AcousticError, Result};

pub mod calibration;
pub mod ptq;
pub mod qat;
pub mod utils;

/// Numeric precision of quantized values.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantizationPrecision {
    /// 8-bit integers.
    Int8,
    /// 16-bit integers.
    Int16,
    /// 4-bit integers.
    Int4,
    /// Mixed precision across layers.
    Mixed,
}

/// Strategy used to quantize a model.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantizationMethod {
    /// Post-training quantization: calibrate an already-trained model.
    PostTraining,
    /// Quantization-aware training: simulate quantization during training.
    AwareTraining,
    /// Dynamic quantization: derive ranges on the fly at inference time.
    Dynamic,
}

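/// Configuration for model quantization.
///
/// A sketch of overriding the defaults (struct-update syntax fills in the
/// remaining fields):
///
/// ```ignore
/// let config = QuantizationConfig {
///     precision: QuantizationPrecision::Int4,
///     ..Default::default()
/// };
/// ```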
#[derive(Debug, Clone)]
pub struct QuantizationConfig {
    /// Target numeric precision.
    pub precision: QuantizationPrecision,
    /// Quantization strategy.
    pub method: QuantizationMethod,
    /// Number of samples to collect during calibration.
    pub calibration_samples: usize,
    /// Names of layers to keep in full precision.
    pub skip_layers: Vec<String>,
    /// Use symmetric quantization (zero point fixed at 0).
    pub symmetric: bool,
    /// Derive parameters per channel rather than per tensor.
    pub per_channel: bool,
    /// Minimum acceptable accuracy after quantization.
    pub target_accuracy: f32,
}

impl Default for QuantizationConfig {
    fn default() -> Self {
        Self {
            precision: QuantizationPrecision::Int8,
            method: QuantizationMethod::PostTraining,
            calibration_samples: 1000,
            skip_layers: vec!["output".to_string()],
            symmetric: true,
            per_channel: true,
            target_accuracy: 0.95,
        }
    }
}

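/// Parameters of an affine quantization scheme:
/// `q = clamp(round(x / scale) + zero_point, qmin, qmax)`, with the inverse
/// `x ≈ (q - zero_point) * scale`.
///
/// A minimal round-trip sketch (values illustrative):
///
/// ```ignore
/// let params = QuantizationParams::symmetric(0.1, -128, 127);
/// assert_eq!(params.quantize(1.0), 10);
/// assert!((params.dequantize(10) - 1.0).abs() < 1e-6);
/// ```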
#[derive(Debug, Clone)]
pub struct QuantizationParams {
    /// Scale factor mapping real values to quantized steps.
    pub scale: f32,
    /// Quantized value corresponding to real zero.
    pub zero_point: i32,
    /// Smallest representable quantized value.
    pub qmin: i32,
    /// Largest representable quantized value.
    pub qmax: i32,
    /// Whether the scheme is symmetric around zero.
    pub symmetric: bool,
}

impl QuantizationParams {
    /// Creates parameters with an explicit scale and zero point.
    pub fn new(scale: f32, zero_point: i32, qmin: i32, qmax: i32, symmetric: bool) -> Self {
        Self {
            scale,
            zero_point,
            qmin,
            qmax,
            symmetric,
        }
    }

    /// Creates symmetric parameters (zero point fixed at 0).
    pub fn symmetric(scale: f32, qmin: i32, qmax: i32) -> Self {
        Self {
            scale,
            zero_point: 0,
            qmin,
            qmax,
            symmetric: true,
        }
    }

    /// Creates asymmetric parameters with an explicit zero point.
    pub fn asymmetric(scale: f32, zero_point: i32, qmin: i32, qmax: i32) -> Self {
        Self {
            scale,
            zero_point,
            qmin,
            qmax,
            symmetric: false,
        }
    }

    /// Quantizes a single value: `clamp(round(x / scale) + zero_point, qmin, qmax)`.
    pub fn quantize(&self, value: f32) -> i32 {
        let quantized = (value / self.scale).round() as i32 + self.zero_point;
        quantized.clamp(self.qmin, self.qmax)
    }

    /// Dequantizes a single value: `(q - zero_point) * scale`.
    pub fn dequantize(&self, quantized: i32) -> f32 {
        (quantized - self.zero_point) as f32 * self.scale
    }

    /// Quantizes a slice element-wise.
    pub fn quantize_tensor(&self, input: &[f32]) -> Vec<i32> {
        input.iter().map(|&x| self.quantize(x)).collect()
    }

    /// Dequantizes a slice element-wise.
    pub fn dequantize_tensor(&self, quantized: &[i32]) -> Vec<f32> {
        quantized.iter().map(|&x| self.dequantize(x)).collect()
    }
}

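/// A tensor stored in quantized form together with its parameters.
///
/// Note that `data` is held as `Vec<i32>` regardless of the target precision,
/// so memory figures reflect `i32` storage until values are packed down to
/// the target bit width.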
#[derive(Debug, Clone)]
pub struct QuantizedTensor {
    /// Quantized values.
    pub data: Vec<i32>,
    /// Parameters used to produce `data`.
    pub params: QuantizationParams,
    /// Tensor dimensions.
    pub shape: Vec<usize>,
    /// Layer or tensor name.
    pub name: String,
}

impl QuantizedTensor {
    /// Creates a quantized tensor from raw parts.
    pub fn new(
        data: Vec<i32>,
        params: QuantizationParams,
        shape: Vec<usize>,
        name: String,
    ) -> Self {
        Self {
            data,
            params,
            shape,
            name,
        }
    }

    /// Reconstructs the floating-point values.
    pub fn dequantize(&self) -> Vec<f32> {
        self.params.dequantize_tensor(&self.data)
    }

    /// Number of elements in the tensor.
    pub fn size(&self) -> usize {
        self.shape.iter().product()
    }

    /// Current in-memory size of the quantized data, in bytes.
    pub fn memory_usage(&self) -> usize {
        self.data.len() * std::mem::size_of::<i32>()
    }

    /// Ratio of the fp32 size to the current quantized storage size.
    ///
    /// Because `data` is stored as `i32`, this stays at 1.0 until values are
    /// packed to the target bit width.
    pub fn compression_ratio(&self) -> f32 {
        let fp32_size = self.size() * std::mem::size_of::<f32>();
        let quantized_size = self.memory_usage();
        fp32_size as f32 / quantized_size as f32
    }
}

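/// Derives per-layer quantization parameters from calibration data and
/// applies them to tensors.
///
/// Typical post-training flow, as a sketch (layer name and values are
/// illustrative):
///
/// ```ignore
/// let quantizer = ModelQuantizer::new(QuantizationConfig::default());
/// quantizer.add_calibration_data("layer1".to_string(), vec![-1.0, 0.5, 2.0])?;
/// quantizer.calibrate()?;
/// let tensor = quantizer.quantize_tensor("layer1", &[-1.0, 0.5, 2.0], vec![3])?;
/// assert_eq!(tensor.size(), 3);
/// ```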
pub struct ModelQuantizer {
    /// Quantization settings.
    config: QuantizationConfig,
    /// Per-layer parameters produced by calibration.
    layer_params: Arc<Mutex<HashMap<String, QuantizationParams>>>,
    /// Raw calibration samples keyed by layer name.
    calibration_cache: Arc<Mutex<HashMap<String, Vec<f32>>>>,
}

impl ModelQuantizer {
    /// Creates a quantizer with the given configuration.
    pub fn new(config: QuantizationConfig) -> Self {
        Self {
            config,
            layer_params: Arc::new(Mutex::new(HashMap::new())),
            calibration_cache: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Records calibration samples for a layer, replacing any previous data.
    pub fn add_calibration_data(&self, layer_name: String, data: Vec<f32>) -> Result<()> {
        let mut cache = self.calibration_cache.lock().unwrap();
        cache.insert(layer_name, data);
        Ok(())
    }

    /// Derives quantization parameters for every calibrated layer,
    /// skipping layers listed in `skip_layers`.
    pub fn calibrate(&self) -> Result<()> {
        let cache = self.calibration_cache.lock().unwrap();
        let mut params = self.layer_params.lock().unwrap();

        for (layer_name, data) in cache.iter() {
            if self.config.skip_layers.contains(layer_name) {
                continue;
            }

            let qparams = self.calculate_quantization_params(data)?;
            params.insert(layer_name.clone(), qparams);
        }

        Ok(())
    }

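    /// Derives scale and zero point from the observed value range.
    ///
    /// Symmetric: `scale = max(|min|, |max|) / qmax`, zero point fixed at 0.
    /// Asymmetric: `scale = (max - min) / (qmax - qmin)` and
    /// `zero_point = qmin - round(min / scale)`.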
    fn calculate_quantization_params(&self, data: &[f32]) -> Result<QuantizationParams> {
        if data.is_empty() {
            return Err(AcousticError::ProcessingError {
                message: "Cannot calculate quantization params from empty data".to_string(),
            });
        }

        let min_val = data.iter().copied().fold(f32::INFINITY, f32::min);
        let max_val = data.iter().copied().fold(f32::NEG_INFINITY, f32::max);

        let (qmin, qmax) = match self.config.precision {
            QuantizationPrecision::Int8 => (-128, 127),
            QuantizationPrecision::Int16 => (-32768, 32767),
            QuantizationPrecision::Int4 => (-8, 7),
            QuantizationPrecision::Mixed => (-128, 127),
        };

        if self.config.symmetric {
            let abs_max = max_val.abs().max(min_val.abs());
            // Guard against all-zero data, which would yield a zero scale.
            let scale = (abs_max / qmax as f32).max(f32::EPSILON);
            Ok(QuantizationParams::symmetric(scale, qmin, qmax))
        } else {
            // Guard against a degenerate (constant) range for the same reason.
            let scale = ((max_val - min_val) / (qmax - qmin) as f32).max(f32::EPSILON);
            let zero_point = qmin - (min_val / scale).round() as i32;
            Ok(QuantizationParams::asymmetric(
                scale, zero_point, qmin, qmax,
            ))
        }
    }

    /// Quantizes a tensor using parameters previously derived by `calibrate`.
    pub fn quantize_tensor(
        &self,
        layer_name: &str,
        data: &[f32],
        shape: Vec<usize>,
    ) -> Result<QuantizedTensor> {
        let params = self.layer_params.lock().unwrap();
        let qparams = params
            .get(layer_name)
            .ok_or_else(|| AcousticError::ProcessingError {
                message: format!("No quantization parameters found for layer: {layer_name}"),
            })?;

        let quantized_data = qparams.quantize_tensor(data);
        Ok(QuantizedTensor::new(
            quantized_data,
            qparams.clone(),
            shape,
            layer_name.to_string(),
        ))
    }

    /// Returns the quantization parameters for a layer, if calibrated.
    pub fn get_layer_params(&self, layer_name: &str) -> Option<QuantizationParams> {
        self.layer_params.lock().unwrap().get(layer_name).cloned()
    }

    /// Returns the active configuration.
    pub fn config(&self) -> &QuantizationConfig {
        &self.config
    }

    /// Fraction of cached layers that have been calibrated, in `0.0..=1.0`.
    pub fn calibration_progress(&self) -> f32 {
        let cache = self.calibration_cache.lock().unwrap();
        if cache.is_empty() {
            0.0
        } else {
            let params = self.layer_params.lock().unwrap();
            params.len() as f32 / cache.len() as f32
        }
    }
}

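/// Size and accuracy statistics for a quantization pass.
///
/// As a worked example, shrinking 1000 bytes to 250 bytes gives
/// `compression_savings = 1.0 - 250/1000 = 0.75`, i.e. 75% saved.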
#[derive(Debug, Clone)]
pub struct QuantizationStats {
    /// Model size before quantization, in bytes.
    pub original_size: usize,
    /// Model size after quantization, in bytes.
    pub quantized_size: usize,
    /// `original_size / quantized_size`.
    pub compression_ratio: f32,
    /// Number of layers that were quantized.
    pub quantized_layers: usize,
    /// Number of layers left in full precision.
    pub skipped_layers: usize,
    /// Estimated accuracy of the quantized model.
    pub estimated_accuracy: f32,
}

impl QuantizationStats {
    /// Fraction of the original size saved, in `0.0..=1.0`.
    pub fn compression_savings(&self) -> f32 {
        if self.original_size == 0 {
            0.0
        } else {
            1.0 - (self.quantized_size as f32 / self.original_size as f32)
        }
    }

    /// Bytes saved, expressed in megabytes.
    pub fn memory_savings_mb(&self) -> f32 {
        // saturating_sub avoids usize underflow if quantization grew the model.
        self.original_size.saturating_sub(self.quantized_size) as f32 / (1024.0 * 1024.0)
    }
}

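/// Latency and accuracy comparison between the original and quantized model.
///
/// For example, going from 100 ms to 50 ms is a `100/50 = 2.0x` speedup, and
/// accuracy falling from 0.95 to 0.92 is a degradation of `0.03`.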
#[derive(Debug, Clone)]
pub struct QuantizationBenchmark {
    /// Inference latency before quantization, in milliseconds.
    pub original_inference_ms: f32,
    /// Inference latency after quantization, in milliseconds.
    pub quantized_inference_ms: f32,
    /// `original_inference_ms / quantized_inference_ms`.
    pub speedup: f32,
    /// Accuracy before quantization.
    pub original_accuracy: f32,
    /// Accuracy after quantization.
    pub quantized_accuracy: f32,
    /// `original_accuracy - quantized_accuracy`.
    pub accuracy_degradation: f32,
}

impl QuantizationBenchmark {
    /// Computes speedup and accuracy degradation from raw measurements.
    pub fn new(
        original_inference_ms: f32,
        quantized_inference_ms: f32,
        original_accuracy: f32,
        quantized_accuracy: f32,
    ) -> Self {
        let speedup = original_inference_ms / quantized_inference_ms;
        let accuracy_degradation = original_accuracy - quantized_accuracy;

        Self {
            original_inference_ms,
            quantized_inference_ms,
            speedup,
            original_accuracy,
            quantized_accuracy,
            accuracy_degradation,
        }
    }

    /// True if the speedup target is met without exceeding the accuracy budget.
    pub fn meets_targets(&self, target_speedup: f32, max_accuracy_loss: f32) -> bool {
        self.speedup >= target_speedup && self.accuracy_degradation <= max_accuracy_loss
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantization_params() {
        let params = QuantizationParams::symmetric(0.1, -128, 127);
        assert_eq!(params.scale, 0.1);
        assert_eq!(params.zero_point, 0);
        assert!(params.symmetric);

        let value = 1.0;
        let quantized = params.quantize(value);
        let dequantized = params.dequantize(quantized);
        assert!((dequantized - value).abs() < 0.1);
    }

    #[test]
    fn test_quantized_tensor() {
        let data = vec![1, 2, 3, 4];
        let params = QuantizationParams::symmetric(0.1, -128, 127);
        let shape = vec![2, 2];
        let tensor = QuantizedTensor::new(data, params, shape, "test".to_string());

        assert_eq!(tensor.size(), 4);
        assert_eq!(tensor.memory_usage(), 16);
        // Data is stored as i32, so the ratio is 1.0 until values are packed.
        assert_eq!(tensor.compression_ratio(), 1.0);
    }

    #[test]
    fn test_model_quantizer() {
        let config = QuantizationConfig::default();
        let quantizer = ModelQuantizer::new(config);

        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        quantizer
            .add_calibration_data("layer1".to_string(), data.clone())
            .unwrap();

        quantizer.calibrate().unwrap();

        assert!(quantizer.get_layer_params("layer1").is_some());

        let quantized = quantizer.quantize_tensor("layer1", &data, vec![5]).unwrap();
        assert_eq!(quantized.size(), 5);
    }

    #[test]
    fn test_quantization_config() {
        let config = QuantizationConfig::default();
        assert_eq!(config.precision, QuantizationPrecision::Int8);
        assert_eq!(config.method, QuantizationMethod::PostTraining);
        assert_eq!(config.calibration_samples, 1000);
    }

    #[test]
    fn test_quantization_benchmark() {
        let benchmark = QuantizationBenchmark::new(100.0, 50.0, 0.95, 0.92);
        assert_eq!(benchmark.speedup, 2.0);
        assert!((benchmark.accuracy_degradation - 0.03).abs() < 1e-6);
        assert!(benchmark.meets_targets(1.5, 0.05));
        assert!(!benchmark.meets_targets(3.0, 0.02));
    }

    #[test]
    fn test_quantization_stats() {
        let stats = QuantizationStats {
            original_size: 1000,
            quantized_size: 250,
            compression_ratio: 4.0,
            quantized_layers: 8,
            skipped_layers: 2,
            estimated_accuracy: 0.94,
        };

        assert_eq!(stats.compression_savings(), 0.75);
        // 750 bytes saved / (1024 * 1024) ≈ 0.000715 MB.
        assert!((stats.memory_savings_mb() - 0.000715).abs() < 0.0001);
    }
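
    // A sketch exercising the asymmetric path; `layer_a` and the sample
    // values are illustrative. Round-trip error is bounded by one scale step.
    #[test]
    fn test_asymmetric_round_trip() {
        let config = QuantizationConfig {
            symmetric: false,
            ..Default::default()
        };
        let quantizer = ModelQuantizer::new(config);

        let data = vec![0.0, 0.5, 1.0, 1.5, 2.0];
        quantizer
            .add_calibration_data("layer_a".to_string(), data.clone())
            .unwrap();
        quantizer.calibrate().unwrap();

        let params = quantizer.get_layer_params("layer_a").unwrap();
        assert!(!params.symmetric);

        let restored = params.dequantize_tensor(&params.quantize_tensor(&data));
        for (orig, rest) in data.iter().zip(restored.iter()) {
            assert!((orig - rest).abs() <= params.scale);
        }
    }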
}