1use crate::{Scirs2Tensor, TlBackendError, TlBackendResult};
9use scirs2_core::ndarray;
10use serde::{Deserialize, Serialize};
11
/// Target numeric format for a quantized tensor.
///
/// The nominal per-element width of each variant is what
/// `QuantizationType::bits` reports.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum QuantizationType {
    /// 8-bit integer quantization.
    Int8,
    /// 16-bit floating-point format.
    Fp16,
    /// 16-bit bfloat16 format.
    BFloat16,
    /// 4-bit integer quantization.
    Int4,
    /// No quantization; elements count as 64 bits wide.
    None,
}
26
27impl QuantizationType {
28 pub fn bits(&self) -> usize {
30 match self {
31 QuantizationType::Int4 => 4,
32 QuantizationType::Int8 => 8,
33 QuantizationType::Fp16 | QuantizationType::BFloat16 => 16,
34 QuantizationType::None => 64, }
36 }
37
38 pub fn compression_ratio(&self) -> f64 {
40 64.0 / self.bits() as f64
41 }
42
43 pub fn is_float(&self) -> bool {
45 matches!(
46 self,
47 QuantizationType::Fp16 | QuantizationType::BFloat16 | QuantizationType::None
48 )
49 }
50
51 pub fn is_integer(&self) -> bool {
53 matches!(self, QuantizationType::Int8 | QuantizationType::Int4)
54 }
55}
56
/// How the real-valued range is mapped onto the integer grid.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationScheme {
    /// Range is centred on zero (`[-abs_max, abs_max]`); the per-tensor
    /// constructor and calibration in this file use a zero point of 0.
    Symmetric,
    /// Range is `[min, max]` of the data; the zero point is derived from
    /// `min` and the scale.
    Asymmetric,
}
65
/// Whether one set of quantization parameters covers the whole tensor or
/// each slice along axis 0 gets its own.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationGranularity {
    /// A single scale / zero point (index 0) is applied to every element.
    PerTensor,
    /// Slice `c` along axis 0 uses `scale[c]` and `zero_point[c]`.
    PerChannel,
}
74
/// Parameters describing how real values map to quantized levels.
///
/// The vector fields hold one entry for per-tensor quantization (index 0 is
/// used) and are indexed by position along axis 0 for per-channel
/// quantization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationParams {
    /// Target numeric format.
    pub qtype: QuantizationType,

    /// Symmetric or asymmetric mapping.
    pub scheme: QuantizationScheme,

    /// Per-tensor or per-channel parameters.
    pub granularity: QuantizationGranularity,

    /// Step size between adjacent quantized levels (`x / scale` is rounded).
    pub scale: Vec<f64>,

    /// Offset added to the rounded value when quantizing and subtracted
    /// again when dequantizing.
    pub zero_point: Vec<i32>,

    /// Lower end of the value range the parameters were derived from.
    pub min_val: Vec<f64>,

    /// Upper end of the value range the parameters were derived from.
    pub max_val: Vec<f64>,
}
99
100impl QuantizationParams {
101 pub fn symmetric_per_tensor(qtype: QuantizationType, tensor: &Scirs2Tensor) -> Self {
103 let abs_max = tensor.iter().map(|&x| x.abs()).fold(0.0f64, f64::max);
104
105 let scale = match qtype {
106 QuantizationType::Int8 => abs_max / 127.0,
107 QuantizationType::Int4 => abs_max / 7.0,
108 QuantizationType::Fp16 | QuantizationType::BFloat16 => 1.0,
109 QuantizationType::None => 1.0,
110 };
111
112 Self {
113 qtype,
114 scheme: QuantizationScheme::Symmetric,
115 granularity: QuantizationGranularity::PerTensor,
116 scale: vec![scale],
117 zero_point: vec![0],
118 min_val: vec![-abs_max],
119 max_val: vec![abs_max],
120 }
121 }
122
123 pub fn asymmetric_per_tensor(qtype: QuantizationType, tensor: &Scirs2Tensor) -> Self {
125 let min_val = tensor.iter().fold(f64::INFINITY, |a, &b| a.min(b));
126 let max_val = tensor.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
127
128 let (scale, zero_point) = match qtype {
129 QuantizationType::Int8 => {
130 let scale = (max_val - min_val) / 255.0;
131 let zero_point = (-min_val / scale).round() as i32;
132 (scale, zero_point)
133 }
134 QuantizationType::Int4 => {
135 let scale = (max_val - min_val) / 15.0;
136 let zero_point = (-min_val / scale).round() as i32;
137 (scale, zero_point)
138 }
139 QuantizationType::Fp16 | QuantizationType::BFloat16 | QuantizationType::None => {
140 (1.0, 0)
141 }
142 };
143
144 Self {
145 qtype,
146 scheme: QuantizationScheme::Asymmetric,
147 granularity: QuantizationGranularity::PerTensor,
148 scale: vec![scale],
149 zero_point: vec![zero_point],
150 min_val: vec![min_val],
151 max_val: vec![max_val],
152 }
153 }
154
155 pub fn dynamic_range(&self) -> f64 {
157 self.max_val[0] - self.min_val[0]
158 }
159
160 pub fn quantization_error_bound(&self) -> f64 {
162 self.scale[0] / 2.0
163 }
164}
165
/// A tensor after quantization, together with the parameters needed to map
/// it back.
#[derive(Debug, Clone)]
pub struct QuantizedTensor {
    /// Quantized values, held in the same tensor type as the input (integer
    /// levels are stored as floating-point numbers, not packed).
    pub data: Scirs2Tensor,

    /// Parameters that were used to produce `data`.
    pub params: QuantizationParams,
}
175
176impl QuantizedTensor {
177 pub fn quantize(tensor: &Scirs2Tensor, params: QuantizationParams) -> Self {
179 let quantized_data = match params.qtype {
180 QuantizationType::Int8 => quantize_int8(tensor, ¶ms),
181 QuantizationType::Int4 => quantize_int4(tensor, ¶ms),
182 QuantizationType::Fp16 => quantize_fp16(tensor),
183 QuantizationType::BFloat16 => quantize_bf16(tensor),
184 QuantizationType::None => tensor.clone(),
185 };
186
187 Self {
188 data: quantized_data,
189 params,
190 }
191 }
192
193 pub fn dequantize(&self) -> Scirs2Tensor {
195 match self.params.qtype {
196 QuantizationType::Int8 | QuantizationType::Int4 => {
197 dequantize_integer(&self.data, &self.params)
198 }
199 QuantizationType::Fp16 | QuantizationType::BFloat16 => {
200 self.data.clone()
202 }
203 QuantizationType::None => self.data.clone(),
204 }
205 }
206
207 pub fn memory_reduction(&self) -> f64 {
209 self.params.qtype.compression_ratio()
210 }
211
212 pub fn quantization_error(&self, original: &Scirs2Tensor) -> f64 {
214 let dequantized = self.dequantize();
215 let diff = &dequantized - original;
216 let squared_error: f64 = diff.iter().map(|&x| x * x).sum();
217 squared_error / original.len() as f64
218 }
219}
220
221fn quantize_int8(tensor: &Scirs2Tensor, params: &QuantizationParams) -> Scirs2Tensor {
227 match params.granularity {
228 QuantizationGranularity::PerTensor => {
229 let scale = params.scale[0];
230 let zero_point = params.zero_point[0] as f64;
231 tensor.mapv(|x| ((x / scale).round() + zero_point).clamp(-128.0, 127.0))
232 }
233 QuantizationGranularity::PerChannel => {
234 let n_channels = tensor.shape()[0];
235 let mut out = tensor.clone();
236 for (c, mut slab) in out.axis_iter_mut(ndarray::Axis(0)).enumerate() {
237 if c >= params.scale.len() {
238 break;
240 }
241 let s = params.scale[c];
242 let zp = params.zero_point[c] as f64;
243 slab.mapv_inplace(|x| ((x / s).round() + zp).clamp(-128.0, 127.0));
244 }
245 let _ = n_channels; out
247 }
248 }
249}
250
251fn quantize_int4(tensor: &Scirs2Tensor, params: &QuantizationParams) -> Scirs2Tensor {
257 match params.granularity {
258 QuantizationGranularity::PerTensor => {
259 let scale = params.scale[0];
260 let zero_point = params.zero_point[0] as f64;
261 tensor.mapv(|x| ((x / scale).round() + zero_point).clamp(-8.0, 7.0))
262 }
263 QuantizationGranularity::PerChannel => {
264 let n_channels = tensor.shape()[0];
265 let mut out = tensor.clone();
266 for (c, mut slab) in out.axis_iter_mut(ndarray::Axis(0)).enumerate() {
267 if c >= params.scale.len() {
268 break;
269 }
270 let s = params.scale[c];
271 let zp = params.zero_point[c] as f64;
272 slab.mapv_inplace(|x| ((x / s).round() + zp).clamp(-8.0, 7.0));
273 }
274 let _ = n_channels;
275 out
276 }
277 }
278}
279
280fn quantize_fp16(tensor: &Scirs2Tensor) -> Scirs2Tensor {
282 tensor.mapv(|x| {
283 let scaled = x * (1024.0f64).powi(2);
286 (scaled.round() / (1024.0f64).powi(2)).clamp(-65504.0, 65504.0)
287 })
288}
289
290fn quantize_bf16(tensor: &Scirs2Tensor) -> Scirs2Tensor {
292 tensor.mapv(|x| {
293 let scaled = x * (128.0f64).powi(2);
295 scaled.round() / (128.0f64).powi(2)
296 })
297}
298
299fn dequantize_integer(tensor: &Scirs2Tensor, params: &QuantizationParams) -> Scirs2Tensor {
305 match params.granularity {
306 QuantizationGranularity::PerTensor => {
307 let scale = params.scale[0];
308 let zero_point = params.zero_point[0] as f64;
309 tensor.mapv(|q| (q - zero_point) * scale)
310 }
311 QuantizationGranularity::PerChannel => {
312 let mut out = tensor.clone();
313 for (c, mut slab) in out.axis_iter_mut(ndarray::Axis(0)).enumerate() {
314 if c >= params.scale.len() {
315 break;
316 }
317 let s = params.scale[c];
318 let zp = params.zero_point[c] as f64;
319 slab.mapv_inplace(|q| (q - zp) * s);
320 }
321 out
322 }
323 }
324}
325
/// Configuration record for quantization-aware training.
///
/// NOTE(review): nothing in this file reads these fields; the per-field
/// descriptions below are inferred from the names and should be confirmed
/// against the code that consumes the config.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QatConfig {
    /// Format the trained model is meant to be quantized to.
    pub target_qtype: QuantizationType,

    /// Symmetric or asymmetric mapping to use.
    pub scheme: QuantizationScheme,

    /// Presumably the number of epochs to train before quantization is
    /// simulated — TODO confirm.
    pub warmup_epochs: usize,

    /// Presumably enables a straight-through estimator for gradients —
    /// TODO confirm.
    pub use_ste: bool,

    /// Presumably whether scale / zero point are trained rather than fixed —
    /// TODO confirm.
    pub learnable_params: bool,
}
344
345impl Default for QatConfig {
346 fn default() -> Self {
347 Self {
348 target_qtype: QuantizationType::Int8,
349 scheme: QuantizationScheme::Symmetric,
350 warmup_epochs: 2,
351 use_ste: true,
352 learnable_params: false,
353 }
354 }
355}
356
/// Running totals collected while quantizing a set of tensors.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationStats {
    /// Number of `update` calls recorded so far.
    pub num_tensors: usize,

    /// Accumulated `original_size * (1 - 1 / compression_ratio)`, in the same
    /// unit as the `original_size` values passed to `update`.
    pub memory_saved: u64,

    /// Running mean of the error values passed to `update`.
    pub avg_error: f64,

    /// Largest error value passed to `update`.
    pub max_error: f64,

    /// Count of tensors per quantization type. Starts empty; `update` does
    /// not modify it, so callers populate it themselves.
    pub type_distribution: Vec<(QuantizationType, usize)>,
}
375
376impl QuantizationStats {
377 pub fn new() -> Self {
379 Self {
380 num_tensors: 0,
381 memory_saved: 0,
382 avg_error: 0.0,
383 max_error: 0.0,
384 type_distribution: Vec::new(),
385 }
386 }
387
388 pub fn update(&mut self, original_size: u64, compression_ratio: f64, error: f64) {
390 self.num_tensors += 1;
391 self.memory_saved += (original_size as f64 * (1.0 - 1.0 / compression_ratio)) as u64;
392
393 let n = self.num_tensors as f64;
395 self.avg_error = (self.avg_error * (n - 1.0) + error) / n;
396 self.max_error = self.max_error.max(error);
397 }
398
399 pub fn memory_reduction_pct(&self, total_memory: u64) -> f64 {
401 if total_memory == 0 {
402 0.0
403 } else {
404 (self.memory_saved as f64 / total_memory as f64) * 100.0
405 }
406 }
407}
408
409impl Default for QuantizationStats {
410 fn default() -> Self {
411 Self::new()
412 }
413}
414
415pub fn calibrate_quantization(
417 samples: &[Scirs2Tensor],
418 qtype: QuantizationType,
419 scheme: QuantizationScheme,
420) -> TlBackendResult<QuantizationParams> {
421 if samples.is_empty() {
422 return Err(TlBackendError::GraphError(
423 "Cannot calibrate with empty samples".to_string(),
424 ));
425 }
426
427 let mut global_min = f64::INFINITY;
429 let mut global_max = f64::NEG_INFINITY;
430 let mut global_abs_max = 0.0f64;
431
432 for sample in samples {
433 let sample_min = sample.iter().fold(f64::INFINITY, |a, &b| a.min(b));
434 let sample_max = sample.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
435 let sample_abs_max = sample.iter().map(|&x| x.abs()).fold(0.0f64, f64::max);
436
437 global_min = global_min.min(sample_min);
438 global_max = global_max.max(sample_max);
439 global_abs_max = global_abs_max.max(sample_abs_max);
440 }
441
442 let params = match scheme {
443 QuantizationScheme::Symmetric => {
444 let scale = match qtype {
445 QuantizationType::Int8 => global_abs_max / 127.0,
446 QuantizationType::Int4 => global_abs_max / 7.0,
447 _ => 1.0,
448 };
449
450 QuantizationParams {
451 qtype,
452 scheme,
453 granularity: QuantizationGranularity::PerTensor,
454 scale: vec![scale],
455 zero_point: vec![0],
456 min_val: vec![-global_abs_max],
457 max_val: vec![global_abs_max],
458 }
459 }
460 QuantizationScheme::Asymmetric => {
461 let (scale, zero_point) = match qtype {
462 QuantizationType::Int8 => {
463 let scale = (global_max - global_min) / 255.0;
464 let zero_point = (-global_min / scale).round() as i32;
465 (scale, zero_point)
466 }
467 QuantizationType::Int4 => {
468 let scale = (global_max - global_min) / 15.0;
469 let zero_point = (-global_min / scale).round() as i32;
470 (scale, zero_point)
471 }
472 _ => (1.0, 0),
473 };
474
475 QuantizationParams {
476 qtype,
477 scheme,
478 granularity: QuantizationGranularity::PerTensor,
479 scale: vec![scale],
480 zero_point: vec![zero_point],
481 min_val: vec![global_min],
482 max_val: vec![global_max],
483 }
484 }
485 };
486
487 Ok(params)
488}
489
/// Unit tests for the quantization types, parameter constructors and the
/// quantize / dequantize round trip.
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::ArrayD;

    /// Bit widths, compression ratios and the float/integer classification.
    #[test]
    fn test_quantization_type_properties() {
        assert_eq!(QuantizationType::Int8.bits(), 8);
        assert_eq!(QuantizationType::Int4.bits(), 4);
        assert_eq!(QuantizationType::Fp16.bits(), 16);
        assert_eq!(QuantizationType::BFloat16.bits(), 16);

        // Ratios are relative to a 64-bit element.
        assert_eq!(QuantizationType::Int8.compression_ratio(), 8.0);
        assert_eq!(QuantizationType::Int4.compression_ratio(), 16.0);

        assert!(QuantizationType::Int8.is_integer());
        assert!(QuantizationType::Fp16.is_float());
    }

    /// Symmetric parameters have a zero point of 0 and a positive scale.
    #[test]
    fn test_symmetric_quantization_int8() {
        let data = vec![-10.0, -5.0, 0.0, 5.0, 10.0];
        let tensor = ArrayD::from_shape_vec(vec![5], data.clone()).expect("unwrap");

        let params = QuantizationParams::symmetric_per_tensor(QuantizationType::Int8, &tensor);

        assert_eq!(params.scheme, QuantizationScheme::Symmetric);
        assert_eq!(params.zero_point[0], 0);
        assert!(params.scale[0] > 0.0);
    }

    /// Asymmetric parameters for non-negative data have a non-negative zero
    /// point and a positive scale.
    #[test]
    fn test_asymmetric_quantization_int8() {
        let data = vec![0.0, 2.0, 4.0, 6.0, 8.0];
        let tensor = ArrayD::from_shape_vec(vec![5], data).expect("unwrap");

        let params = QuantizationParams::asymmetric_per_tensor(QuantizationType::Int8, &tensor);

        assert_eq!(params.scheme, QuantizationScheme::Asymmetric);
        assert!(params.zero_point[0] >= 0);
        assert!(params.scale[0] > 0.0);
    }

    /// A symmetric Int8 round trip reproduces every element to within 0.1.
    #[test]
    fn test_quantize_dequantize_int8() {
        let data = vec![-10.0, -5.0, 0.0, 5.0, 10.0];
        let tensor = ArrayD::from_shape_vec(vec![5], data.clone()).expect("unwrap");

        let params = QuantizationParams::symmetric_per_tensor(QuantizationType::Int8, &tensor);
        let quantized = QuantizedTensor::quantize(&tensor, params);
        let dequantized = quantized.dequantize();

        for (orig, deq) in tensor.iter().zip(dequantized.iter()) {
            // The scale is 10 / 127, so the rounding error stays well below 0.1.
            assert!(
                (orig - deq).abs() < 0.1,
                "Original: {}, Dequantized: {}",
                orig,
                deq
            );
        }
    }

    /// The mean squared error of an Int8 round trip is small and non-negative.
    #[test]
    fn test_quantization_error() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let tensor = ArrayD::from_shape_vec(vec![5], data).expect("unwrap");

        let params = QuantizationParams::symmetric_per_tensor(QuantizationType::Int8, &tensor);
        let quantized = QuantizedTensor::quantize(&tensor, params);

        let error = quantized.quantization_error(&tensor);
        assert!(error >= 0.0);
        assert!(error < 1.0);
    }

    /// `memory_reduction` reports the nominal ratio of the target format.
    #[test]
    fn test_memory_reduction() {
        let tensor = ArrayD::from_shape_vec(vec![100], vec![1.0; 100]).expect("unwrap");
        let params = QuantizationParams::symmetric_per_tensor(QuantizationType::Int8, &tensor);
        let quantized = QuantizedTensor::quantize(&tensor, params);

        // 64 bits / 8 bits.
        assert_eq!(quantized.memory_reduction(), 8.0);
    }

    /// Calibration over several samples yields symmetric parameters with a
    /// positive scale and a zero point of 0.
    #[test]
    fn test_calibrate_quantization() {
        let sample1 = ArrayD::from_shape_vec(vec![3], vec![-10.0, 0.0, 10.0]).expect("unwrap");
        let sample2 = ArrayD::from_shape_vec(vec![3], vec![-8.0, 2.0, 12.0]).expect("unwrap");
        let samples = vec![sample1, sample2];

        let params = calibrate_quantization(
            &samples,
            QuantizationType::Int8,
            QuantizationScheme::Symmetric,
        )
        .expect("unwrap");

        assert!(params.scale[0] > 0.0);
        assert_eq!(params.zero_point[0], 0);
    }

    /// `update` counts tensors and tracks saved memory, mean and maximum error.
    #[test]
    fn test_quantization_stats() {
        let mut stats = QuantizationStats::new();

        stats.update(1000, 8.0, 0.01);
        stats.update(2000, 8.0, 0.02);

        assert_eq!(stats.num_tensors, 2);
        assert!(stats.memory_saved > 0);
        assert!(stats.avg_error > 0.0);
        assert_eq!(stats.max_error, 0.02);
    }

    /// Small integers survive the fp16 approximation almost unchanged.
    #[test]
    fn test_fp16_quantization() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let tensor = ArrayD::from_shape_vec(vec![5], data.clone()).expect("unwrap");

        let quantized = quantize_fp16(&tensor);

        for (orig, quant) in tensor.iter().zip(quantized.iter()) {
            assert!((orig - quant).abs() < 0.001);
        }
    }

    /// The default QAT configuration targets symmetric Int8 with `use_ste` on.
    #[test]
    fn test_qat_config_default() {
        let config = QatConfig::default();

        assert_eq!(config.target_qtype, QuantizationType::Int8);
        assert_eq!(config.scheme, QuantizationScheme::Symmetric);
        assert!(config.use_ste);
    }

    /// Builds symmetric per-channel Int8 parameters for a two-row tensor whose
    /// rows have very different magnitudes (±100 and ±1).
    fn make_per_channel_params_int8() -> QuantizationParams {
        // Row 0 spans ±100, so its largest value lands on level 127.
        let scale_0 = 100.0_f64 / 127.0;
        // Row 1 spans ±1 and needs a scale 100 times finer.
        let scale_1 = 1.0_f64 / 127.0;
        QuantizationParams {
            qtype: QuantizationType::Int8,
            scheme: QuantizationScheme::Symmetric,
            granularity: QuantizationGranularity::PerChannel,
            scale: vec![scale_0, scale_1],
            zero_point: vec![0, 0],
            min_val: vec![-100.0, -1.0],
            max_val: vec![100.0, 1.0],
        }
    }

    /// Sanity check on the fixture: the two channel scales really differ.
    #[test]
    fn test_per_channel_uses_different_scales() {
        let params = make_per_channel_params_int8();
        assert!(
            (params.scale[0] - params.scale[1]).abs() > 0.1,
            "scale[0]={} scale[1]={} should differ",
            params.scale[0],
            params.scale[1]
        );
    }

    /// Each row is quantized with its own scale: the largest element of both
    /// rows maps to about level 127, and both rows round-trip accurately.
    #[test]
    fn test_per_channel_quantize_int8_uses_channel_scale() {
        // Row 0: [100, -100, 50]; row 1: [1, -1, 0.5].
        let data = vec![100.0, -100.0, 50.0, 1.0, -1.0, 0.5];
        let tensor = ArrayD::from_shape_vec(vec![2, 3], data).expect("build tensor");

        let params = make_per_channel_params_int8();
        let quantized_tensor = QuantizedTensor::quantize(&tensor, params.clone());

        // First quantized element of each row.
        let row0_q_first = quantized_tensor
            .data
            .slice(ndarray::s![0, ..])
            .iter()
            .copied()
            .next()
            .unwrap_or(f64::NAN);
        let row1_q_first = quantized_tensor
            .data
            .slice(ndarray::s![1, ..])
            .iter()
            .copied()
            .next()
            .unwrap_or(f64::NAN);

        // 100 / (100 / 127) and 1 / (1 / 127) both equal 127.
        assert!(
            (row0_q_first - 127.0).abs() < 2.0,
            "row0[0]={row0_q_first} expected ≈127"
        );
        assert!(
            (row1_q_first - 127.0).abs() < 2.0,
            "row1[0]={row1_q_first} expected ≈127"
        );

        let dequantized = quantized_tensor.dequantize();

        let orig_r0_c0 = 100.0_f64;
        let deq_r0_c0 = dequantized
            .slice(ndarray::s![0, 0])
            .first()
            .copied()
            .unwrap_or(f64::NAN);
        assert!(
            (orig_r0_c0 - deq_r0_c0).abs() < 1.0,
            "round-trip row0[0]: orig={} deq={}",
            orig_r0_c0,
            deq_r0_c0
        );

        // The tolerance is far below row 0's step (100 / 127), so this only
        // passes when row 1 uses its own, finer scale.
        let orig_r1_c0 = 1.0_f64;
        let deq_r1_c0 = dequantized
            .slice(ndarray::s![1, 0])
            .first()
            .copied()
            .unwrap_or(f64::NAN);
        assert!(
            (orig_r1_c0 - deq_r1_c0).abs() < 0.02,
            "round-trip row1[0]: orig={} deq={}",
            orig_r1_c0,
            deq_r1_c0
        );
    }

    /// Every element of the small-magnitude row survives the round trip to
    /// within 0.02, which a shared per-tensor scale could not achieve.
    #[test]
    fn test_per_channel_roundtrip_preserves_row_fidelity() {
        // Row 0: [100, -100, 50]; row 1: [1, -1, 0.5].
        let data = vec![100.0, -100.0, 50.0, 1.0, -1.0, 0.5];
        let tensor = ArrayD::from_shape_vec(vec![2, 3], data).expect("build tensor");

        let params = make_per_channel_params_int8();
        let quantized = QuantizedTensor::quantize(&tensor, params);
        let dequantized = quantized.dequantize();

        // Expected contents of row 1.
        let orig_vals = [1.0_f64, -1.0, 0.5];
        for (col, &expected) in orig_vals.iter().enumerate() {
            let got = *dequantized
                .slice(ndarray::s![1, col..col + 1])
                .iter()
                .next()
                .expect("element");
            assert!(
                (expected - got).abs() < 0.02,
                "row1 col{}: expected={} got={}",
                col,
                expected,
                got
            );
        }
    }
}