use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

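/// Quantization scheme: symmetric maps values around a fixed zero point of 0,
/// while asymmetric uses an explicit zero point to cover skewed value ranges.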
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationMode {
    Symmetric,
    Asymmetric,
}

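/// Target integer bit width for quantized values.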
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BitWidth {
    Int8,
    Int4,
    Int2,
}

impl BitWidth {
    pub fn levels(&self) -> i32 {
        match self {
            BitWidth::Int8 => 256,
            BitWidth::Int4 => 16,
            BitWidth::Int2 => 4,
        }
    }

    pub fn qmin(&self) -> i32 {
        match self {
            BitWidth::Int8 => -128,
            BitWidth::Int4 => -8,
            BitWidth::Int2 => -2,
        }
    }

    pub fn qmax(&self) -> i32 {
        match self {
            BitWidth::Int8 => 127,
            BitWidth::Int4 => 7,
            BitWidth::Int2 => 1,
        }
    }
}

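/// Quantization granularity: one scale/zero point for the whole tensor, or
/// one pair per channel (rows along `Axis(0)`).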
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Granularity {
    PerTensor,
    PerChannel,
}

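/// Configuration describing how a tensor should be quantized.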
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationConfig {
    pub mode: QuantizationMode,
    pub bit_width: BitWidth,
    pub granularity: Granularity,
    pub eps: f32,
}

impl QuantizationConfig {
    pub fn int8_symmetric() -> Self {
        Self {
            mode: QuantizationMode::Symmetric,
            bit_width: BitWidth::Int8,
            granularity: Granularity::PerTensor,
            eps: 1e-8,
        }
    }

    pub fn int8_asymmetric() -> Self {
        Self {
            mode: QuantizationMode::Asymmetric,
            bit_width: BitWidth::Int8,
            granularity: Granularity::PerTensor,
            eps: 1e-8,
        }
    }

    pub fn int4_per_channel() -> Self {
        Self {
            mode: QuantizationMode::Symmetric,
            bit_width: BitWidth::Int4,
            granularity: Granularity::PerChannel,
            eps: 1e-8,
        }
    }

    pub fn new(mode: QuantizationMode, bit_width: BitWidth, granularity: Granularity) -> Self {
        Self {
            mode,
            bit_width,
            granularity,
            eps: 1e-8,
        }
    }
}

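/// Scale and zero-point parameters produced during quantization.
/// Per-tensor quantization stores a single element; per-channel stores one per channel.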
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationParams {
    pub scale: Array1<f32>,
    pub zero_point: Array1<i32>,
    pub config: QuantizationConfig,
}

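/// A quantized 2-D tensor together with the parameters needed to dequantize it.
/// Values are stored as `i8` regardless of the configured bit width.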
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizedTensor {
    pub data: Array2<i8>,
    pub params: QuantizationParams,
}

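/// Stateless post-training quantization routines.
///
/// A minimal usage sketch (the `use` path below is illustrative and depends on
/// where this module is exposed in the crate):
///
/// ```ignore
/// use your_crate::quantization::{QuantizationConfig, Quantizer};
/// use scirs2_core::ndarray::Array2;
///
/// let weights = Array2::from_shape_vec((2, 2), vec![0.1, -0.4, 0.7, 1.2]).unwrap();
/// let config = QuantizationConfig::int8_symmetric();
/// let q = Quantizer::quantize_tensor(&weights.view(), &config);
/// let approx = Quantizer::dequantize_tensor(&q);
/// assert_eq!(approx.dim(), weights.dim());
/// ```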
pub struct Quantizer;

impl Quantizer {
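    /// Quantize a 2-D tensor according to `config`, dispatching on granularity.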
    pub fn quantize_tensor(
        tensor: &ArrayView2<f32>,
        config: &QuantizationConfig,
    ) -> QuantizedTensor {
        match config.granularity {
            Granularity::PerTensor => Self::quantize_per_tensor(tensor, config),
            Granularity::PerChannel => Self::quantize_per_channel(tensor, config),
        }
    }

    fn quantize_per_tensor(
        tensor: &ArrayView2<f32>,
        config: &QuantizationConfig,
    ) -> QuantizedTensor {
        let (scale, zero_point) = Self::compute_params_tensor(tensor, config);

        let quantized = tensor.mapv(|x| {
            let q = (x / scale).round() + zero_point as f32;
            Self::clamp_to_qrange(q as i32, config.bit_width) as i8
        });

        QuantizedTensor {
            data: quantized,
            params: QuantizationParams {
                scale: Array1::from_vec(vec![scale]),
                zero_point: Array1::from_vec(vec![zero_point]),
                config: config.clone(),
            },
        }
    }

    fn quantize_per_channel(
        tensor: &ArrayView2<f32>,
        config: &QuantizationConfig,
    ) -> QuantizedTensor {
        let num_channels = tensor.shape()[0];
        let mut scales = Vec::with_capacity(num_channels);
        let mut zero_points = Vec::with_capacity(num_channels);

        for i in 0..num_channels {
            let channel = tensor.index_axis(Axis(0), i);
            let (scale, zero_point) = Self::compute_params_channel(&channel, config);
            scales.push(scale);
            zero_points.push(zero_point);
        }

        let mut quantized = Array2::<i8>::zeros(tensor.dim());
        for (i, mut row) in quantized.axis_iter_mut(Axis(0)).enumerate() {
            let channel = tensor.index_axis(Axis(0), i);
            let scale = scales[i];
            let zero_point = zero_points[i];

            for (j, &val) in channel.iter().enumerate() {
                let q = (val / scale).round() + zero_point as f32;
                row[j] = Self::clamp_to_qrange(q as i32, config.bit_width) as i8;
            }
        }

        QuantizedTensor {
            data: quantized,
            params: QuantizationParams {
                scale: Array1::from_vec(scales),
                zero_point: Array1::from_vec(zero_points),
                config: config.clone(),
            },
        }
    }

    fn compute_params_tensor(tensor: &ArrayView2<f32>, config: &QuantizationConfig) -> (f32, i32) {
        let min = tensor.iter().cloned().fold(f32::INFINITY, f32::min);
        let max = tensor.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        Self::compute_scale_zero_point(min, max, config)
    }

    fn compute_params_channel(
        channel: &ArrayView1<f32>,
        config: &QuantizationConfig,
    ) -> (f32, i32) {
        let min = channel.iter().cloned().fold(f32::INFINITY, f32::min);
        let max = channel.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        Self::compute_scale_zero_point(min, max, config)
    }

    fn compute_scale_zero_point(min: f32, max: f32, config: &QuantizationConfig) -> (f32, i32) {
        let qmin = config.bit_width.qmin() as f32;
        let qmax = config.bit_width.qmax() as f32;

        match config.mode {
            QuantizationMode::Symmetric => {
                let abs_max = min.abs().max(max.abs());
                let scale = (2.0 * abs_max / (qmax - qmin)).max(config.eps);
                (scale, 0)
            }
            QuantizationMode::Asymmetric => {
                let scale = ((max - min) / (qmax - qmin)).max(config.eps);
                let zero_point = (qmin - min / scale).round() as i32;
                let zero_point = Self::clamp_to_qrange(zero_point, config.bit_width);
                (scale, zero_point)
            }
        }
    }

    fn clamp_to_qrange(value: i32, bit_width: BitWidth) -> i32 {
        value.max(bit_width.qmin()).min(bit_width.qmax())
    }

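    /// Reconstruct an `f32` tensor from a quantized tensor using its stored
    /// scale and zero-point parameters.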
    pub fn dequantize_tensor(quantized: &QuantizedTensor) -> Array2<f32> {
        match quantized.params.config.granularity {
            Granularity::PerTensor => {
                let scale = quantized.params.scale[0];
                let zero_point = quantized.params.zero_point[0];
                quantized
                    .data
                    .mapv(|q| scale * (q as f32 - zero_point as f32))
            }
            Granularity::PerChannel => {
                let mut result = Array2::<f32>::zeros(quantized.data.dim());
                for (i, mut row) in result.axis_iter_mut(Axis(0)).enumerate() {
                    let scale = quantized.params.scale[i];
                    let zero_point = quantized.params.zero_point[i];
                    let q_row = quantized.data.index_axis(Axis(0), i);

                    for (j, &q) in q_row.iter().enumerate() {
                        row[j] = scale * (q as f32 - zero_point as f32);
                    }
                }
                result
            }
        }
    }

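    /// Theoretical compression ratio versus 32-bit floats (e.g. 4x for Int8).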
    pub fn compression_ratio(config: &QuantizationConfig) -> f32 {
        let original_bits = 32.0;
        let quantized_bits = match config.bit_width {
            BitWidth::Int8 => 8.0,
            BitWidth::Int4 => 4.0,
            BitWidth::Int2 => 2.0,
        };
        original_bits / quantized_bits
    }

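    /// Mean squared error between the original tensor and its quantize/dequantize round trip.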
    pub fn quantization_error(original: &ArrayView2<f32>, quantized: &QuantizedTensor) -> f32 {
        let dequantized = Self::dequantize_tensor(quantized);
        let diff = original - &dequantized.view();
        diff.mapv(|x| x * x).mean().unwrap_or(0.0)
    }
}

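/// Helper for quantization-aware training: holds per-layer configurations and
/// applies fake quantization (quantize followed by dequantize) during the forward pass.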
pub struct QuantizationAwareTraining {
    layer_configs: HashMap<String, QuantizationConfig>,
    simulate_quantization: bool,
}

impl QuantizationAwareTraining {
    pub fn new(simulate_quantization: bool) -> Self {
        Self {
            layer_configs: HashMap::new(),
            simulate_quantization,
        }
    }

    pub fn register_layer(&mut self, layer_name: String, config: QuantizationConfig) {
        self.layer_configs.insert(layer_name, config);
    }

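    /// Simulate quantization for the named layer by quantizing and immediately
    /// dequantizing; returns the input unchanged if simulation is disabled or
    /// the layer has no registered config.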
    pub fn fake_quantize(&self, tensor: &Array2<f32>, layer_name: &str) -> Array2<f32> {
        if !self.simulate_quantization {
            return tensor.clone();
        }

        if let Some(config) = self.layer_configs.get(layer_name) {
            let quantized = Quantizer::quantize_tensor(&tensor.view(), config);
            Quantizer::dequantize_tensor(&quantized)
        } else {
            tensor.clone()
        }
    }

    pub fn get_config(&self, layer_name: &str) -> Option<&QuantizationConfig> {
        self.layer_configs.get(layer_name)
    }

    pub fn registered_layers(&self) -> Vec<&String> {
        self.layer_configs.keys().collect()
    }
}

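/// Collects per-layer min/max statistics over calibration batches so that
/// quantization ranges can be chosen from observed activations.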
pub struct DynamicRangeCalibrator {
    statistics: HashMap<String, (f32, f32)>,
    num_samples: usize,
}

impl DynamicRangeCalibrator {
    pub fn new() -> Self {
        Self {
            statistics: HashMap::new(),
            num_samples: 0,
        }
    }

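    /// Record the min/max of `tensor`, widening any range already stored for the layer.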
    pub fn collect(&mut self, layer_name: String, tensor: &ArrayView2<f32>) {
        let min = tensor.iter().cloned().fold(f32::INFINITY, f32::min);
        let max = tensor.iter().cloned().fold(f32::NEG_INFINITY, f32::max);

        self.statistics
            .entry(layer_name)
            .and_modify(|(prev_min, prev_max)| {
                *prev_min = prev_min.min(min);
                *prev_max = prev_max.max(max);
            })
            .or_insert((min, max));

        self.num_samples += 1;
    }

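    /// Produce a per-layer configuration map for every observed layer. Each
    /// layer currently receives a clone of `default_config`; the collected
    /// ranges remain available via `get_range`.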
    pub fn finalize(
        &self,
        default_config: &QuantizationConfig,
    ) -> HashMap<String, QuantizationConfig> {
        self.statistics
            .keys()
            .map(|name| (name.clone(), default_config.clone()))
            .collect()
    }

    pub fn get_range(&self, layer_name: &str) -> Option<(f32, f32)> {
        self.statistics.get(layer_name).copied()
    }

    pub fn reset(&mut self) {
        self.statistics.clear();
        self.num_samples = 0;
    }
}

impl Default for DynamicRangeCalibrator {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_int8_symmetric_quantization() {
        let tensor = Array2::from_shape_vec((2, 3), vec![-1.0, 0.0, 1.0, -2.0, 2.0, 0.5]).unwrap();
        let config = QuantizationConfig::int8_symmetric();

        let quantized = Quantizer::quantize_tensor(&tensor.view(), &config);
        let dequantized = Quantizer::dequantize_tensor(&quantized);

        assert_eq!(dequantized.dim(), tensor.dim());

        for (orig, deq) in tensor.iter().zip(dequantized.iter()) {
            assert_relative_eq!(orig, deq, epsilon = 0.1);
        }
    }

    #[test]
    fn test_int8_asymmetric_quantization() {
        let tensor = Array2::from_shape_vec((2, 2), vec![0.0, 1.0, 2.0, 3.0]).unwrap();
        let config = QuantizationConfig::int8_asymmetric();

        let quantized = Quantizer::quantize_tensor(&tensor.view(), &config);
        assert_eq!(quantized.params.config.mode, QuantizationMode::Asymmetric);

        let dequantized = Quantizer::dequantize_tensor(&quantized);
        assert_relative_eq!(dequantized[[0, 0]], 0.0, epsilon = 0.05);
        assert_relative_eq!(dequantized[[1, 1]], 3.0, epsilon = 0.05);
    }

    #[test]
    fn test_int4_per_channel_quantization() {
        let tensor =
            Array2::from_shape_vec((2, 4), vec![-1.0, 0.0, 1.0, 2.0, -10.0, -5.0, 5.0, 10.0])
                .unwrap();
        let config = QuantizationConfig::int4_per_channel();

        let quantized = Quantizer::quantize_tensor(&tensor.view(), &config);

        assert_eq!(quantized.params.scale.len(), 2);
        assert_eq!(quantized.params.zero_point.len(), 2);

        let dequantized = Quantizer::dequantize_tensor(&quantized);
        assert_eq!(dequantized.dim(), tensor.dim());
    }

    #[test]
    fn test_bit_width_levels() {
        assert_eq!(BitWidth::Int8.levels(), 256);
        assert_eq!(BitWidth::Int4.levels(), 16);
        assert_eq!(BitWidth::Int2.levels(), 4);
    }

    #[test]
    fn test_bit_width_ranges() {
        assert_eq!(BitWidth::Int8.qmin(), -128);
        assert_eq!(BitWidth::Int8.qmax(), 127);
        assert_eq!(BitWidth::Int4.qmin(), -8);
        assert_eq!(BitWidth::Int4.qmax(), 7);
    }

    #[test]
    fn test_compression_ratio() {
        let config_int8 = QuantizationConfig::int8_symmetric();
        assert_eq!(Quantizer::compression_ratio(&config_int8), 4.0);

        let config_int4 = QuantizationConfig::new(
            QuantizationMode::Symmetric,
            BitWidth::Int4,
            Granularity::PerTensor,
        );
        assert_eq!(Quantizer::compression_ratio(&config_int4), 8.0);
    }

    #[test]
    fn test_quantization_error() {
        let tensor = Array2::from_shape_vec((3, 3), vec![1.0; 9]).unwrap();
        let config = QuantizationConfig::int8_symmetric();

        let quantized = Quantizer::quantize_tensor(&tensor.view(), &config);
        let error = Quantizer::quantization_error(&tensor.view(), &quantized);

        assert!(error < 0.01);
    }

    #[test]
    fn test_qat_registration() {
        let mut qat = QuantizationAwareTraining::new(true);
        qat.register_layer("layer1".to_string(), QuantizationConfig::int8_symmetric());
        qat.register_layer("layer2".to_string(), QuantizationConfig::int4_per_channel());

        assert_eq!(qat.registered_layers().len(), 2);
        assert!(qat.get_config("layer1").is_some());
        assert!(qat.get_config("layer3").is_none());
    }

    #[test]
    fn test_fake_quantization() {
        let mut qat = QuantizationAwareTraining::new(true);
        qat.register_layer("fc1".to_string(), QuantizationConfig::int8_symmetric());

        let tensor = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let fake_quantized = qat.fake_quantize(&tensor, "fc1");

        assert_eq!(fake_quantized.dim(), tensor.dim());
    }

    #[test]
    fn test_dynamic_range_calibrator() {
        let mut calibrator = DynamicRangeCalibrator::new();

        let tensor1 = Array2::from_shape_vec((2, 2), vec![0.0, 1.0, 2.0, 3.0]).unwrap();
        let tensor2 = Array2::from_shape_vec((2, 2), vec![-1.0, 0.0, 1.0, 4.0]).unwrap();

        calibrator.collect("layer1".to_string(), &tensor1.view());
        calibrator.collect("layer1".to_string(), &tensor2.view());

        let (min, max) = calibrator.get_range("layer1").unwrap();
        assert_eq!(min, -1.0);
        assert_eq!(max, 4.0);
    }

    #[test]
    fn test_calibrator_finalize() {
        let mut calibrator = DynamicRangeCalibrator::new();
        let tensor = Array2::from_shape_vec((2, 2), vec![1.0; 4]).unwrap();

        calibrator.collect("layer1".to_string(), &tensor.view());
        calibrator.collect("layer2".to_string(), &tensor.view());

        let config = QuantizationConfig::int8_symmetric();
        let configs = calibrator.finalize(&config);

        assert_eq!(configs.len(), 2);
        assert!(configs.contains_key("layer1"));
        assert!(configs.contains_key("layer2"));
    }

    #[test]
    fn test_calibrator_reset() {
        let mut calibrator = DynamicRangeCalibrator::new();
        let tensor = Array2::from_shape_vec((2, 2), vec![1.0; 4]).unwrap();

        calibrator.collect("layer1".to_string(), &tensor.view());
        assert_eq!(calibrator.num_samples, 1);

        calibrator.reset();
        assert_eq!(calibrator.num_samples, 0);
        assert!(calibrator.get_range("layer1").is_none());
    }

    #[test]
    fn test_zero_tensor_quantization() {
        let tensor = Array2::zeros((3, 3));
        let config = QuantizationConfig::int8_symmetric();

        let quantized = Quantizer::quantize_tensor(&tensor.view(), &config);
        let dequantized = Quantizer::dequantize_tensor(&quantized);

        assert_eq!(dequantized, tensor);
    }

    #[test]
    fn test_extreme_values_quantization() {
        let tensor = Array2::from_shape_vec(
            (2, 2),
            vec![f32::MIN / 1e6, f32::MAX / 1e6, -1000.0, 1000.0],
        )
        .unwrap();
        let config = QuantizationConfig::int8_symmetric();

        let quantized = Quantizer::quantize_tensor(&tensor.view(), &config);
        let dequantized = Quantizer::dequantize_tensor(&quantized);

        assert_eq!(dequantized.dim(), tensor.dim());
    }
}