//! Tensor quantization support: quantization types, schemes, and granularity,
//! per-tensor parameter computation, (de)quantization kernels, calibration
//! over sample batches, and quantization-aware-training (QAT) configuration.

use crate::{Scirs2Tensor, TlBackendError, TlBackendResult};
use serde::{Deserialize, Serialize};

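/// Numeric formats a tensor can be quantized to. `None` leaves values in the
/// full-precision `f64` baseline that the other variants are measured against.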
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum QuantizationType {
    Int8,
    Fp16,
    BFloat16,
    Int4,
    None,
}

impl QuantizationType {
    /// Bit width of the quantized representation. `None` reports 64 bits,
    /// i.e. the full-precision `f64` baseline.
    pub fn bits(&self) -> usize {
        match self {
            QuantizationType::Int4 => 4,
            QuantizationType::Int8 => 8,
            QuantizationType::Fp16 | QuantizationType::BFloat16 => 16,
            QuantizationType::None => 64,
        }
    }

    /// Compression ratio relative to the 64-bit `f64` baseline
    /// (e.g. 8.0 for `Int8`, 16.0 for `Int4`).
    pub fn compression_ratio(&self) -> f64 {
        64.0 / self.bits() as f64
    }

    /// True for floating-point formats (including the `None` baseline).
    pub fn is_float(&self) -> bool {
        matches!(
            self,
            QuantizationType::Fp16 | QuantizationType::BFloat16 | QuantizationType::None
        )
    }

    /// True for integer formats.
    pub fn is_integer(&self) -> bool {
        matches!(self, QuantizationType::Int8 | QuantizationType::Int4)
    }
}

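/// How real values are mapped onto the integer grid: `Symmetric` fixes the
/// zero point at 0, while `Asymmetric` uses an affine mapping with a nonzero
/// zero point so the full [min, max] range is covered.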
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationScheme {
    Symmetric,
    Asymmetric,
}

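/// Scope over which quantization parameters are computed: one (scale,
/// zero-point) pair for the whole tensor, or one pair per channel.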
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationGranularity {
    PerTensor,
    PerChannel,
}

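/// Parameters describing a quantization mapping. The vectors hold one entry
/// for `PerTensor` granularity and one entry per channel for `PerChannel`.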
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationParams {
    /// Target quantization type.
    pub qtype: QuantizationType,
    /// Symmetric or asymmetric mapping.
    pub scheme: QuantizationScheme,
    /// Per-tensor or per-channel parameters.
    pub granularity: QuantizationGranularity,
    /// Scale factor(s) mapping real values to the quantized grid.
    pub scale: Vec<f64>,
    /// Zero point(s); always 0 for symmetric quantization.
    pub zero_point: Vec<i32>,
    /// Minimum representable real value(s).
    pub min_val: Vec<f64>,
    /// Maximum representable real value(s).
    pub max_val: Vec<f64>,
}

impl QuantizationParams {
    /// Compute symmetric per-tensor parameters: the scale maps the largest
    /// absolute value onto the top of the signed integer range, and the zero
    /// point is fixed at 0.
    pub fn symmetric_per_tensor(qtype: QuantizationType, tensor: &Scirs2Tensor) -> Self {
        let abs_max = tensor.iter().map(|&x| x.abs()).fold(0.0f64, f64::max);

        let scale = match qtype {
            QuantizationType::Int8 => abs_max / 127.0,
            QuantizationType::Int4 => abs_max / 7.0,
            QuantizationType::Fp16 | QuantizationType::BFloat16 => 1.0,
            QuantizationType::None => 1.0,
        };
        // Guard against an all-zero tensor, which would otherwise yield a
        // zero scale and NaN values during quantization.
        let scale = if scale > 0.0 { scale } else { 1.0 };

        Self {
            qtype,
            scheme: QuantizationScheme::Symmetric,
            granularity: QuantizationGranularity::PerTensor,
            scale: vec![scale],
            zero_point: vec![0],
            min_val: vec![-abs_max],
            max_val: vec![abs_max],
        }
    }

    /// Compute asymmetric per-tensor parameters: an affine mapping that
    /// stretches [min, max] over the full unsigned integer range, with the
    /// zero point chosen so that 0.0 is exactly representable.
    pub fn asymmetric_per_tensor(qtype: QuantizationType, tensor: &Scirs2Tensor) -> Self {
        let min_val = tensor.iter().fold(f64::INFINITY, |a, &b| a.min(b));
        let max_val = tensor.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
        // Guard against a constant tensor (max == min), which would otherwise
        // produce a zero scale and an undefined zero point.
        let range = max_val - min_val;

        let (scale, zero_point) = match qtype {
            QuantizationType::Int8 => {
                let scale = if range > 0.0 { range / 255.0 } else { 1.0 };
                let zero_point = (-min_val / scale).round() as i32;
                (scale, zero_point)
            }
            QuantizationType::Int4 => {
                let scale = if range > 0.0 { range / 15.0 } else { 1.0 };
                let zero_point = (-min_val / scale).round() as i32;
                (scale, zero_point)
            }
            QuantizationType::Fp16 | QuantizationType::BFloat16 | QuantizationType::None => {
                (1.0, 0)
            }
        };

        Self {
            qtype,
            scheme: QuantizationScheme::Asymmetric,
            granularity: QuantizationGranularity::PerTensor,
            scale: vec![scale],
            zero_point: vec![zero_point],
            min_val: vec![min_val],
            max_val: vec![max_val],
        }
    }

    /// Width of the representable range (per-tensor entry).
    pub fn dynamic_range(&self) -> f64 {
        self.max_val[0] - self.min_val[0]
    }

    /// Worst-case rounding error of a single value: half the quantization
    /// step (per-tensor entry).
    pub fn quantization_error_bound(&self) -> f64 {
        self.scale[0] / 2.0
    }
}

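/// A tensor together with the parameters used to quantize it. Quantized
/// values are stored in the same `f64`-backed container as the original (a
/// simulation of quantization, not a packed low-bit encoding), so
/// `memory_reduction` reports the ratio a packed encoding would achieve.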
#[derive(Debug, Clone)]
pub struct QuantizedTensor {
    /// Quantized values (integer grid points or reduced-precision floats).
    pub data: Scirs2Tensor,
    /// Parameters used to quantize, needed for dequantization.
    pub params: QuantizationParams,
}

impl QuantizedTensor {
    /// Quantize `tensor` with the given parameters.
    pub fn quantize(tensor: &Scirs2Tensor, params: QuantizationParams) -> Self {
        let quantized_data = match params.qtype {
            QuantizationType::Int8 => quantize_int8(tensor, &params),
            QuantizationType::Int4 => quantize_int4(tensor, &params),
            QuantizationType::Fp16 => quantize_fp16(tensor),
            QuantizationType::BFloat16 => quantize_bf16(tensor),
            QuantizationType::None => tensor.clone(),
        };

        Self {
            data: quantized_data,
            params,
        }
    }

    /// Reconstruct an approximation of the original tensor. Integer types are
    /// mapped back through the affine parameters; float types are already
    /// stored at reduced precision, so they are returned as-is.
    pub fn dequantize(&self) -> Scirs2Tensor {
        match self.params.qtype {
            QuantizationType::Int8 | QuantizationType::Int4 => {
                dequantize_integer(&self.data, &self.params)
            }
            QuantizationType::Fp16 | QuantizationType::BFloat16 => self.data.clone(),
            QuantizationType::None => self.data.clone(),
        }
    }

    /// Memory reduction factor a packed encoding of this type would achieve.
    pub fn memory_reduction(&self) -> f64 {
        self.params.qtype.compression_ratio()
    }

    /// Mean squared error between `original` and the dequantized tensor.
    pub fn quantization_error(&self, original: &Scirs2Tensor) -> f64 {
        let dequantized = self.dequantize();
        let diff = &dequantized - original;
        let squared_error: f64 = diff.iter().map(|&x| x * x).sum();
        squared_error / original.len() as f64
    }
}

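/// Map values onto the int8 grid: divide by the scale, round, shift by the
/// zero point, then saturate to the representable range.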
fn quantize_int8(tensor: &Scirs2Tensor, params: &QuantizationParams) -> Scirs2Tensor {
    let scale = params.scale[0];
    let zero_point = params.zero_point[0];
    // Symmetric parameters target the signed range [-128, 127]; asymmetric
    // parameters (zero point derived from -min/scale) target the unsigned
    // range [0, 255], so clamp accordingly.
    let (lo, hi) = match params.scheme {
        QuantizationScheme::Symmetric => (-128.0, 127.0),
        QuantizationScheme::Asymmetric => (0.0, 255.0),
    };

    tensor.mapv(|x| {
        let quantized = (x / scale).round() + zero_point as f64;
        quantized.clamp(lo, hi)
    })
}

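/// Map values onto the int4 grid; same construction as int8 with a 4-bit
/// range.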
fn quantize_int4(tensor: &Scirs2Tensor, params: &QuantizationParams) -> Scirs2Tensor {
    let scale = params.scale[0];
    let zero_point = params.zero_point[0];
    // Signed range [-8, 7] for symmetric parameters, unsigned [0, 15] for
    // asymmetric ones.
    let (lo, hi) = match params.scheme {
        QuantizationScheme::Symmetric => (-8.0, 7.0),
        QuantizationScheme::Asymmetric => (0.0, 15.0),
    };

    tensor.mapv(|x| {
        let quantized = (x / scale).round() + zero_point as f64;
        quantized.clamp(lo, hi)
    })
}

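/// Simulate fp16 storage: round to a fixed step of 2^-20 and clamp to fp16's
/// maximum finite value (65504). This is a coarse approximation; real fp16
/// rounding is relative to magnitude, not a fixed absolute step.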
fn quantize_fp16(tensor: &Scirs2Tensor) -> Scirs2Tensor {
    tensor.mapv(|x| {
        let scaled = x * (1024.0f64).powi(2);
        (scaled.round() / (1024.0f64).powi(2)).clamp(-65504.0, 65504.0)
    })
}

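/// Simulate bfloat16 storage: round to a fixed step of 2^-14. No clamp is
/// needed since bfloat16 shares f32's exponent range. As with fp16 above,
/// this approximates the format's precision, not its exact rounding.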
fn quantize_bf16(tensor: &Scirs2Tensor) -> Scirs2Tensor {
    tensor.mapv(|x| {
        let scaled = x * (128.0f64).powi(2);
        scaled.round() / (128.0f64).powi(2)
    })
}

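/// Invert the integer quantization: subtract the zero point and multiply by
/// the scale.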
fn dequantize_integer(tensor: &Scirs2Tensor, params: &QuantizationParams) -> Scirs2Tensor {
    let scale = params.scale[0];
    let zero_point = params.zero_point[0];

    tensor.mapv(|q| (q - zero_point as f64) * scale)
}

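/// Configuration for quantization-aware training (QAT).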
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QatConfig {
    /// Quantization type to target at inference time.
    pub target_qtype: QuantizationType,
    /// Quantization scheme to train against.
    pub scheme: QuantizationScheme,
    /// Number of full-precision warm-up epochs before fake quantization is
    /// applied.
    pub warmup_epochs: usize,
    /// Use the straight-through estimator for gradients through the rounding
    /// step.
    pub use_ste: bool,
    /// Treat scale/zero-point as learnable parameters.
    pub learnable_params: bool,
}

impl Default for QatConfig {
    fn default() -> Self {
        Self {
            target_qtype: QuantizationType::Int8,
            scheme: QuantizationScheme::Symmetric,
            warmup_epochs: 2,
            use_ste: true,
            learnable_params: false,
        }
    }
}

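/// Aggregate statistics over a set of quantized tensors.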
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationStats {
    /// Number of tensors quantized so far.
    pub num_tensors: usize,
    /// Total memory saved, in the same unit as the sizes passed to `update`.
    pub memory_saved: u64,
    /// Running mean of per-tensor quantization error.
    pub avg_error: f64,
    /// Largest per-tensor quantization error seen.
    pub max_error: f64,
    /// Count of quantized tensors per quantization type.
    pub type_distribution: Vec<(QuantizationType, usize)>,
}

impl QuantizationStats {
    pub fn new() -> Self {
        Self {
            num_tensors: 0,
            memory_saved: 0,
            avg_error: 0.0,
            max_error: 0.0,
            type_distribution: Vec::new(),
        }
    }

    /// Record one quantized tensor: accumulate the memory saved and fold the
    /// error into the running mean and maximum.
    pub fn update(&mut self, original_size: u64, compression_ratio: f64, error: f64) {
        self.num_tensors += 1;
        self.memory_saved += (original_size as f64 * (1.0 - 1.0 / compression_ratio)) as u64;

        // Incremental running mean: avg_n = (avg_{n-1} * (n - 1) + error) / n.
        let n = self.num_tensors as f64;
        self.avg_error = (self.avg_error * (n - 1.0) + error) / n;
        self.max_error = self.max_error.max(error);
    }

    /// Memory saved as a percentage of `total_memory`.
    pub fn memory_reduction_pct(&self, total_memory: u64) -> f64 {
        if total_memory == 0 {
            0.0
        } else {
            (self.memory_saved as f64 / total_memory as f64) * 100.0
        }
    }
}

impl Default for QuantizationStats {
    fn default() -> Self {
        Self::new()
    }
}

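/// Derive per-tensor quantization parameters from a batch of representative
/// sample tensors by tracking the global min/max (asymmetric) or the global
/// absolute maximum (symmetric) across all samples.
///
/// Illustrative usage sketch (assumes `Scirs2Tensor` is ndarray's
/// `ArrayD<f64>`, as in the tests below):
///
/// ```ignore
/// use scirs2_core::ndarray::ArrayD;
///
/// let samples = vec![ArrayD::from_shape_vec(vec![3], vec![-1.0, 0.0, 2.0]).unwrap()];
/// let params = calibrate_quantization(
///     &samples,
///     QuantizationType::Int8,
///     QuantizationScheme::Symmetric,
/// )?;
/// let quantized = QuantizedTensor::quantize(&samples[0], params);
/// ```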
pub fn calibrate_quantization(
    samples: &[Scirs2Tensor],
    qtype: QuantizationType,
    scheme: QuantizationScheme,
) -> TlBackendResult<QuantizationParams> {
    if samples.is_empty() {
        return Err(TlBackendError::GraphError(
            "Cannot calibrate with empty samples".to_string(),
        ));
    }

    let mut global_min = f64::INFINITY;
    let mut global_max = f64::NEG_INFINITY;
    let mut global_abs_max = 0.0f64;

    for sample in samples {
        let sample_min = sample.iter().fold(f64::INFINITY, |a, &b| a.min(b));
        let sample_max = sample.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
        let sample_abs_max = sample.iter().map(|&x| x.abs()).fold(0.0f64, f64::max);

        global_min = global_min.min(sample_min);
        global_max = global_max.max(sample_max);
        global_abs_max = global_abs_max.max(sample_abs_max);
    }

    let params = match scheme {
        QuantizationScheme::Symmetric => {
            let scale = match qtype {
                QuantizationType::Int8 => global_abs_max / 127.0,
                QuantizationType::Int4 => global_abs_max / 7.0,
                _ => 1.0,
            };
            // Same degenerate-input guard as in `QuantizationParams`:
            // all-zero samples would otherwise yield a zero scale.
            let scale = if scale > 0.0 { scale } else { 1.0 };

            QuantizationParams {
                qtype,
                scheme,
                granularity: QuantizationGranularity::PerTensor,
                scale: vec![scale],
                zero_point: vec![0],
                min_val: vec![-global_abs_max],
                max_val: vec![global_abs_max],
            }
        }
        QuantizationScheme::Asymmetric => {
            let range = global_max - global_min;
            let (scale, zero_point) = match qtype {
                QuantizationType::Int8 => {
                    let scale = if range > 0.0 { range / 255.0 } else { 1.0 };
                    let zero_point = (-global_min / scale).round() as i32;
                    (scale, zero_point)
                }
                QuantizationType::Int4 => {
                    let scale = if range > 0.0 { range / 15.0 } else { 1.0 };
                    let zero_point = (-global_min / scale).round() as i32;
                    (scale, zero_point)
                }
                _ => (1.0, 0),
            };

            QuantizationParams {
                qtype,
                scheme,
                granularity: QuantizationGranularity::PerTensor,
                scale: vec![scale],
                zero_point: vec![zero_point],
                min_val: vec![global_min],
                max_val: vec![global_max],
            }
        }
    };

    Ok(params)
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::ArrayD;

    #[test]
    fn test_quantization_type_properties() {
        assert_eq!(QuantizationType::Int8.bits(), 8);
        assert_eq!(QuantizationType::Int4.bits(), 4);
        assert_eq!(QuantizationType::Fp16.bits(), 16);
        assert_eq!(QuantizationType::BFloat16.bits(), 16);

        assert_eq!(QuantizationType::Int8.compression_ratio(), 8.0);
        assert_eq!(QuantizationType::Int4.compression_ratio(), 16.0);

        assert!(QuantizationType::Int8.is_integer());
        assert!(QuantizationType::Fp16.is_float());
    }

    #[test]
    fn test_symmetric_quantization_int8() {
        let data = vec![-10.0, -5.0, 0.0, 5.0, 10.0];
        let tensor = ArrayD::from_shape_vec(vec![5], data.clone()).unwrap();

        let params = QuantizationParams::symmetric_per_tensor(QuantizationType::Int8, &tensor);

        assert_eq!(params.scheme, QuantizationScheme::Symmetric);
        assert_eq!(params.zero_point[0], 0);
        assert!(params.scale[0] > 0.0);
    }

    #[test]
    fn test_asymmetric_quantization_int8() {
        let data = vec![0.0, 2.0, 4.0, 6.0, 8.0];
        let tensor = ArrayD::from_shape_vec(vec![5], data).unwrap();

        let params = QuantizationParams::asymmetric_per_tensor(QuantizationType::Int8, &tensor);

        assert_eq!(params.scheme, QuantizationScheme::Asymmetric);
        assert!(params.zero_point[0] >= 0);
        assert!(params.scale[0] > 0.0);
    }

    #[test]
    fn test_quantize_dequantize_int8() {
        let data = vec![-10.0, -5.0, 0.0, 5.0, 10.0];
        let tensor = ArrayD::from_shape_vec(vec![5], data.clone()).unwrap();

        let params = QuantizationParams::symmetric_per_tensor(QuantizationType::Int8, &tensor);
        let quantized = QuantizedTensor::quantize(&tensor, params);
        let dequantized = quantized.dequantize();

        for (orig, deq) in tensor.iter().zip(dequantized.iter()) {
            assert!(
                (orig - deq).abs() < 0.1,
                "Original: {}, Dequantized: {}",
                orig,
                deq
            );
        }
    }

    #[test]
    fn test_quantization_error() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let tensor = ArrayD::from_shape_vec(vec![5], data).unwrap();

        let params = QuantizationParams::symmetric_per_tensor(QuantizationType::Int8, &tensor);
        let quantized = QuantizedTensor::quantize(&tensor, params);

        let error = quantized.quantization_error(&tensor);
        assert!(error >= 0.0);
        assert!(error < 1.0);
    }

    #[test]
    fn test_memory_reduction() {
        let tensor = ArrayD::from_shape_vec(vec![100], vec![1.0; 100]).unwrap();
        let params = QuantizationParams::symmetric_per_tensor(QuantizationType::Int8, &tensor);
        let quantized = QuantizedTensor::quantize(&tensor, params);

        assert_eq!(quantized.memory_reduction(), 8.0);
    }

    #[test]
    fn test_calibrate_quantization() {
        let sample1 = ArrayD::from_shape_vec(vec![3], vec![-10.0, 0.0, 10.0]).unwrap();
        let sample2 = ArrayD::from_shape_vec(vec![3], vec![-8.0, 2.0, 12.0]).unwrap();
        let samples = vec![sample1, sample2];

        let params = calibrate_quantization(
            &samples,
            QuantizationType::Int8,
            QuantizationScheme::Symmetric,
        )
        .unwrap();

        assert!(params.scale[0] > 0.0);
        assert_eq!(params.zero_point[0], 0);
    }

    #[test]
    fn test_quantization_stats() {
        let mut stats = QuantizationStats::new();

        stats.update(1000, 8.0, 0.01);
        stats.update(2000, 8.0, 0.02);

        assert_eq!(stats.num_tensors, 2);
        assert!(stats.memory_saved > 0);
        assert!(stats.avg_error > 0.0);
        assert_eq!(stats.max_error, 0.02);
    }

    #[test]
    fn test_fp16_quantization() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let tensor = ArrayD::from_shape_vec(vec![5], data.clone()).unwrap();

        let quantized = quantize_fp16(&tensor);

        for (orig, quant) in tensor.iter().zip(quantized.iter()) {
            assert!((orig - quant).abs() < 0.001);
        }
    }

    #[test]
    fn test_qat_config_default() {
        let config = QatConfig::default();

        assert_eq!(config.target_qtype, QuantizationType::Int8);
        assert_eq!(config.scheme, QuantizationScheme::Symmetric);
        assert!(config.use_ste);
    }
}