1pub mod calibration;
14pub mod gptq;
15pub mod int4;
16pub mod int8;
17
18use crate::error::{Error, Result};
19use scirs2_core::ndarray::ArrayStatCompat;
20use scirs2_core::ndarray::{ArrayD, ArrayView, Zip};
21use serde::{Deserialize, Serialize};
22use std::collections::HashMap;
23
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Settings that control how float tensors are mapped to integers.
pub struct QuantizationConfig {
    /// Bit width of the quantized integers; together with `signed` it fixes
    /// the `[qmin, qmax]` range computed by `QuantizationParams::new`.
    pub bits: u8,
    /// `true` for a signed integer range, `false` for a range starting at 0.
    pub signed: bool,
    /// How scale and zero point are derived from the observed value range.
    pub scheme: QuantizationScheme,
    /// Calibration sample budget.
    /// NOTE(review): not read anywhere in this file — presumably consumed by
    /// callers or submodules; confirm.
    pub calibration_size: usize,
    /// Quantization workflow.
    /// NOTE(review): not read anywhere in this file — confirm where it is used.
    pub mode: QuantizationMode,
    /// Per-channel flag.
    /// NOTE(review): not read anywhere in this file; all code here derives a
    /// single set of parameters per tensor.
    pub per_channel: bool,
    /// Fraction of the observed `max - min` span (kept around its midpoint)
    /// that `QuantizationParams::from_tensor` uses when deriving parameters.
    pub range_clipping: f32,
}

impl Default for QuantizationConfig {
    /// Returns the baseline configuration: 8-bit, signed, symmetric scheme,
    /// static mode, `per_channel` off, 1000 calibration samples and a range
    /// clipping factor of 0.999.
    fn default() -> Self {
        Self {
            bits: 8,
            signed: true,
            scheme: QuantizationScheme::Symmetric,
            calibration_size: 1000,
            mode: QuantizationMode::Static,
            per_channel: false,
            range_clipping: 0.999,
        }
    }
}
56
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
/// Strategy used to derive `scale` and `zero_point` from a value range.
pub enum QuantizationScheme {
    /// Zero point fixed at 0; the scale is derived from the largest absolute
    /// value of the (clipped) range.
    Symmetric,
    /// Scale derived from the full (clipped) `max - min` span, with a zero
    /// point computed from the lower bound.
    Asymmetric,
    /// Zero point fixed at 0; the scale is rounded up to a power of two.
    PowerOfTwo,
}
67
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
/// Quantization workflow selector stored in [`QuantizationConfig`].
///
/// NOTE(review): no code in this file branches on these variants; the
/// per-variant descriptions below are inferred from their names — confirm
/// against the consumers.
pub enum QuantizationMode {
    /// Presumably: parameters fixed ahead of time from calibration data.
    Static,
    /// Presumably: parameters derived from each tensor at run time.
    Dynamic,
    /// Presumably: quantization-aware training.
    QAT,
}
78
#[derive(Debug, Clone)]
/// Affine mapping between float values and quantized integers.
///
/// Quantization computes `round(x / scale) + zero_point` clamped to
/// `[qmin, qmax]`; dequantization computes `(q - zero_point) * scale`.
pub struct QuantizationParams {
    /// Float step size represented by one integer increment.
    pub scale: f32,
    /// Integer offset added after scaling.
    pub zero_point: i32,
    /// Bit width these parameters were created for.
    pub bits: u8,
    /// Smallest representable quantized value.
    pub qmin: i32,
    /// Largest representable quantized value.
    pub qmax: i32,
}
93
94impl QuantizationParams {
95 pub fn new(bits: u8, signed: bool) -> Self {
97 let (qmin, qmax) = if signed {
98 (
99 -(1i32 << (bits as i32 - 1)),
100 (1i32 << (bits as i32 - 1)) - 1,
101 )
102 } else {
103 (0, (1i32 << bits as i32) - 1)
104 };
105 Self {
106 scale: 1.0,
107 zero_point: 0,
108 bits,
109 qmin,
110 qmax,
111 }
112 }
113
114 pub fn from_tensor(
116 tensor: &ArrayView<f32, scirs2_core::ndarray::IxDyn>,
117 config: &QuantizationConfig,
118 ) -> Result<Self> {
119 let mut params = Self::new(config.bits, config.signed);
120
121 let min_val = tensor.iter().cloned().fold(f32::INFINITY, f32::min);
123 let max_val = tensor.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
124
125 let range = max_val - min_val;
127 let clipped_range = range * config.range_clipping;
128 let center = (max_val + min_val) / 2.0;
129 let clipped_min = center - clipped_range / 2.0;
130 let clipped_max = center + clipped_range / 2.0;
131
132 match config.scheme {
133 QuantizationScheme::Symmetric => {
134 let abs_max = clipped_max.abs().max(clipped_min.abs());
135 let denom = (params.qmax - params.qmin) as f32;
136 params.scale = if denom > 0.0 {
137 (2.0 * abs_max) / denom
138 } else {
139 1.0
140 };
141 params.zero_point = 0;
142 }
143 QuantizationScheme::Asymmetric => {
144 let denom = (params.qmax - params.qmin) as f32;
145 params.scale = if denom > 0.0 {
146 (clipped_max - clipped_min) / denom
147 } else {
148 1.0
149 };
150 if params.scale > 0.0 {
151 params.zero_point = params.qmin - (clipped_min / params.scale).round() as i32;
152 }
153 }
154 QuantizationScheme::PowerOfTwo => {
155 let abs_max = clipped_max.abs().max(clipped_min.abs());
156 let divisor = (1i32 << (config.bits as i32 - 1)) as f32;
157 if divisor > 0.0 && abs_max > 0.0 {
158 let scale_log2 = (abs_max / divisor).log2().ceil();
159 params.scale = 2.0_f32.powf(scale_log2);
160 }
161 params.zero_point = 0;
162 }
163 }
164 Ok(params)
165 }
166}
167
#[derive(Debug, Clone)]
/// A tensor stored as 8-bit integer codes plus the parameters needed to
/// convert it back to floats.
pub struct QuantizedTensor {
    /// Quantized values, one `i8` per element.
    pub data: ArrayD<i8>,
    /// Parameters that were used to produce `data`.
    pub params: QuantizationParams,
    /// Shape of the float tensor the data was quantized from.
    pub shape: Vec<usize>,
}
178
179impl QuantizedTensor {
180 pub fn from_float(tensor: &ArrayD<f32>, config: &QuantizationConfig) -> Result<Self> {
182 let params = QuantizationParams::from_tensor(&tensor.view(), config)?;
183 let quantized_data = Self::quantize_tensor(tensor, ¶ms)?;
184 Ok(Self {
185 data: quantized_data,
186 params,
187 shape: tensor.shape().to_vec(),
188 })
189 }
190
191 fn quantize_tensor(tensor: &ArrayD<f32>, params: &QuantizationParams) -> Result<ArrayD<i8>> {
193 let quantized = tensor.mapv(|x| {
194 let q_val = if params.scale > 0.0 {
195 (x / params.scale).round() + params.zero_point as f32
196 } else {
197 params.zero_point as f32
198 };
199 let clamped = q_val.max(params.qmin as f32).min(params.qmax as f32);
200 clamped as i8
201 });
202 Ok(quantized)
203 }
204
205 pub fn dequantize(&self) -> ArrayD<f32> {
207 self.data
208 .mapv(|q| (q as f32 - self.params.zero_point as f32) * self.params.scale)
209 }
210
211 pub fn size_bytes(&self) -> usize {
213 self.data.len() + std::mem::size_of::<QuantizationParams>()
214 }
215
216 pub fn compression_ratio(&self) -> f32 {
218 let original_size = self.data.len() * std::mem::size_of::<f32>();
219 let quantized_size = self.size_bytes();
220 if quantized_size > 0 {
221 original_size as f32 / quantized_size as f32
222 } else {
223 1.0
224 }
225 }
226}
227
#[derive(Debug)]
/// Collects per-tensor statistics from calibration data and turns them into
/// quantization parameters.
pub struct PostTrainingQuantizer {
    /// Bit width, signedness and scheme used for every derived parameter set.
    config: QuantizationConfig,
    /// Running statistics, keyed by the name passed to `add_calibration_data`.
    calibration_stats: HashMap<String, TensorStats>,
}
236
#[derive(Debug)]
/// Statistics accumulated for one named tensor during calibration.
struct TensorStats {
    /// Smallest value seen so far (starts at `+inf`).
    min: f32,
    /// Largest value seen so far (starts at `-inf`).
    max: f32,
    /// Mean of the most recent non-empty batch (not a running mean).
    mean: f32,
    /// Standard deviation computed from the latest batch.
    #[allow(dead_code)]
    std: f32,
    /// 256-bin histogram of observed values, binned linearly between `min`
    /// and `max`.
    histogram: Vec<u32>,
}
247
248impl TensorStats {
249 fn new() -> Self {
250 Self {
251 min: f32::INFINITY,
252 max: f32::NEG_INFINITY,
253 mean: 0.0,
254 std: 0.0,
255 histogram: vec![0; 256],
256 }
257 }
258
259 fn update(&mut self, tensor: &ArrayView<f32, scirs2_core::ndarray::IxDyn>) {
260 if let Some(&min_v) = tensor
261 .iter()
262 .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
263 {
264 self.min = self.min.min(min_v);
265 }
266 if let Some(&max_v) = tensor
267 .iter()
268 .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
269 {
270 self.max = self.max.max(max_v);
271 }
272
273 let sum: f32 = tensor.sum();
274 let count = tensor.len() as f32;
275 if count > 0.0 {
276 self.mean = sum / count;
277 }
278 let variance: f32 =
279 tensor.iter().map(|&x| (x - self.mean).powi(2)).sum::<f32>() / count.max(1.0);
280 self.std = variance.sqrt();
281
282 let range = self.max - self.min;
284 if range > 0.0 {
285 for &val in tensor.iter() {
286 let normalized = ((val - self.min) / range * 255.0).round() as usize;
287 let bin = normalized.min(255);
288 self.histogram[bin] += 1;
289 }
290 }
291 }
292}
293
294impl PostTrainingQuantizer {
295 pub fn new(config: QuantizationConfig) -> Self {
297 Self {
298 config,
299 calibration_stats: HashMap::new(),
300 }
301 }
302
303 pub fn add_calibration_data(&mut self, name: &str, tensor: &ArrayD<f32>) {
305 let stats = self
306 .calibration_stats
307 .entry(name.to_string())
308 .or_insert_with(TensorStats::new);
309 stats.update(&tensor.view());
310 }
311
312 pub fn finalize_calibration(&mut self) -> Result<HashMap<String, QuantizationParams>> {
314 let mut params_map = HashMap::new();
315 for (name, stats) in &self.calibration_stats {
316 let optimal_params = self.compute_optimal_params(stats)?;
317 params_map.insert(name.clone(), optimal_params);
318 }
319 Ok(params_map)
320 }
321
322 fn compute_optimal_params(&self, stats: &TensorStats) -> Result<QuantizationParams> {
324 let mut best_params = QuantizationParams::new(self.config.bits, self.config.signed);
325 let mut best_kl_div = f32::INFINITY;
326
327 for threshold_idx in 128..=255 {
328 let threshold = stats.min + (threshold_idx as f32 / 255.0) * (stats.max - stats.min);
329 let mut params = QuantizationParams::new(self.config.bits, self.config.signed);
330
331 match self.config.scheme {
332 QuantizationScheme::Symmetric => {
333 let denom = (params.qmax - params.qmin) as f32;
334 params.scale = if denom > 0.0 {
335 (2.0 * threshold) / denom
336 } else {
337 1.0
338 };
339 params.zero_point = 0;
340 }
341 QuantizationScheme::Asymmetric => {
342 let denom = (params.qmax - params.qmin) as f32;
343 params.scale = if denom > 0.0 {
344 (threshold - stats.min) / denom
345 } else {
346 1.0
347 };
348 if params.scale > 0.0 {
349 params.zero_point = params.qmin - (stats.min / params.scale).round() as i32;
350 }
351 }
352 QuantizationScheme::PowerOfTwo => {
353 let divisor = (1i32 << (self.config.bits as i32 - 1)) as f32;
354 if divisor > 0.0 && threshold > 0.0 {
355 let scale_log2 = (threshold / divisor).log2().ceil();
356 params.scale = 2.0_f32.powf(scale_log2);
357 }
358 params.zero_point = 0;
359 }
360 }
361
362 let kl_div = self.compute_kl_divergence(&stats.histogram, ¶ms);
363 if kl_div < best_kl_div {
364 best_kl_div = kl_div;
365 best_params = params;
366 }
367 }
368 Ok(best_params)
369 }
370
371 fn compute_kl_divergence(&self, histogram: &[u32], params: &QuantizationParams) -> f32 {
373 let total_count: u32 = histogram.iter().sum();
374 if total_count == 0 {
375 return 0.0;
376 }
377 let mut kl_div = 0.0;
378 for (i, &count) in histogram.iter().enumerate() {
379 if count > 0 {
380 let p = count as f32 / total_count as f32;
381 let bin_value = i as f32 / 255.0;
382 let quantized = if params.scale > 0.0 {
383 (bin_value / params.scale)
384 .round()
385 .max(params.qmin as f32)
386 .min(params.qmax as f32)
387 } else {
388 0.0
389 };
390 let dequantized = quantized * params.scale;
391 let q = (dequantized * 255.0).round() as usize;
392 let q_count = if q < histogram.len() { histogram[q] } else { 1 };
393 let q_prob = (q_count as f32 / total_count as f32).max(1e-8);
394 kl_div += p * (p / q_prob).ln();
395 }
396 }
397 kl_div
398 }
399
400 pub fn quantize_tensor(
402 &self,
403 tensor: &ArrayD<f32>,
404 params: &QuantizationParams,
405 ) -> Result<QuantizedTensor> {
406 let quantized_data = QuantizedTensor::quantize_tensor(tensor, params)?;
407 Ok(QuantizedTensor {
408 data: quantized_data,
409 params: params.clone(),
410 shape: tensor.shape().to_vec(),
411 })
412 }
413}
414
/// State for simulating quantization ("fake quantization") during training.
pub struct QuantizationAwareTraining {
    /// Configuration used whenever parameters are derived from a tensor.
    config: QuantizationConfig,
    /// Current parameters per layer, keyed by layer name.
    layer_params: HashMap<String, QuantizationParams>,
    /// Number of `fake_quantize` calls made so far (across all layers).
    step_count: usize,
    /// While `step_count` is below this value, `fake_quantize` returns its
    /// input unchanged.
    warmup_steps: usize,
}
426
427impl QuantizationAwareTraining {
428 pub fn new(config: QuantizationConfig) -> Self {
430 Self {
431 config,
432 layer_params: HashMap::new(),
433 step_count: 0,
434 warmup_steps: 1000,
435 }
436 }
437
438 pub fn set_warmup_steps(&mut self, steps: usize) {
440 self.warmup_steps = steps;
441 }
442
443 pub fn init_layer_params(&mut self, layer_name: &str, tensor: &ArrayD<f32>) -> Result<()> {
445 let params = QuantizationParams::from_tensor(&tensor.view(), &self.config)?;
446 self.layer_params.insert(layer_name.to_string(), params);
447 Ok(())
448 }
449
450 pub fn fake_quantize(&mut self, layer_name: &str, tensor: &ArrayD<f32>) -> Result<ArrayD<f32>> {
452 self.step_count += 1;
453
454 if self.step_count < self.warmup_steps {
456 return Ok(tensor.clone());
457 }
458
459 let params = self.layer_params.get_mut(layer_name).ok_or_else(|| {
460 Error::InvalidArgument(format!("Layer {} not initialized", layer_name))
461 })?;
462
463 let new_params = QuantizationParams::from_tensor(&tensor.view(), &self.config)?;
465 let alpha = 0.01_f32;
466 params.scale = params.scale * (1.0 - alpha) + new_params.scale * alpha;
467 if self.config.scheme == QuantizationScheme::Asymmetric {
468 params.zero_point = ((params.zero_point as f32) * (1.0 - alpha)
469 + (new_params.zero_point as f32) * alpha)
470 .round() as i32;
471 }
472
473 let quantized = QuantizedTensor::quantize_tensor(tensor, params)?;
475 let dequantized = quantized.mapv(|q| (q as f32 - params.zero_point as f32) * params.scale);
476 Ok(dequantized)
477 }
478
479 pub fn get_quantization_params(&self) -> &HashMap<String, QuantizationParams> {
481 &self.layer_params
482 }
483
484 pub fn add_quantization_noise(&self, tensor: &ArrayD<f32>, noise_scale: f32) -> ArrayD<f32> {
486 let mut rng = scirs2_core::random::rng();
487 tensor.mapv(|x| {
488 let noise: f32 = scirs2_core::random::RngExt::random::<f32>(&mut rng) - 0.5;
489 x + noise * noise_scale
490 })
491 }
492}
493
/// Chooses a bit width per layer based on a sensitivity score.
pub struct MixedBitWidthQuantizer {
    /// Configuration per layer name; `bits` is overwritten whenever bit
    /// widths are reassigned.
    layer_configs: HashMap<String, QuantizationConfig>,
    /// Most recent sensitivity score per layer name.
    sensitivity_scores: HashMap<String, f32>,
}

impl Default for MixedBitWidthQuantizer {
    /// Equivalent to [`MixedBitWidthQuantizer::new`].
    fn default() -> Self {
        Self::new()
    }
}
507
508impl MixedBitWidthQuantizer {
509 pub fn new() -> Self {
511 Self {
512 layer_configs: HashMap::new(),
513 sensitivity_scores: HashMap::new(),
514 }
515 }
516
517 pub fn set_layer_config(&mut self, layer_name: &str, config: QuantizationConfig) {
519 self.layer_configs.insert(layer_name.to_string(), config);
520 }
521
522 pub fn analyze_sensitivity(
524 &mut self,
525 layer_outputs: &HashMap<String, ArrayD<f32>>,
526 ) -> Result<()> {
527 for (layer_name, output) in layer_outputs {
528 let variance = self.compute_variance(output);
529 let entropy = self.compute_entropy(output);
530 let gradient_norm = self.compute_gradient_norm(output);
531 let sensitivity = variance * 0.4 + entropy * 0.3 + gradient_norm * 0.3;
532 self.sensitivity_scores
533 .insert(layer_name.clone(), sensitivity);
534 }
535 self.assign_bit_widths()?;
536 Ok(())
537 }
538
539 fn compute_variance(&self, tensor: &ArrayD<f32>) -> f32 {
541 let mean = tensor.mean_or(0.0);
542 let count = tensor.len() as f32;
543 if count > 0.0 {
544 tensor.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / count
545 } else {
546 0.0
547 }
548 }
549
550 fn compute_entropy(&self, tensor: &ArrayD<f32>) -> f32 {
552 let mut histogram = vec![0u32; 256];
553 let min_val = tensor.iter().cloned().fold(f32::INFINITY, f32::min);
554 let max_val = tensor.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
555 let range = max_val - min_val;
556 if range == 0.0 {
557 return 0.0;
558 }
559 for &val in tensor.iter() {
560 let bin = ((val - min_val) / range * 255.0).round() as usize;
561 let bin = bin.min(255);
562 histogram[bin] += 1;
563 }
564 let total = tensor.len() as f32;
565 let mut entropy = 0.0;
566 for count in histogram {
567 if count > 0 {
568 let p = count as f32 / total;
569 entropy -= p * p.ln();
570 }
571 }
572 entropy
573 }
574
575 fn compute_gradient_norm(&self, tensor: &ArrayD<f32>) -> f32 {
577 let mut grad_norm = 0.0f32;
578 for axis in 0..tensor.ndim() {
579 if tensor.shape()[axis] > 1 {
580 for _i in 0..tensor.shape()[axis] - 1 {
581 grad_norm += 1.0;
582 }
583 }
584 }
585 let len = tensor.len() as f32;
586 if len > 0.0 {
587 grad_norm / len
588 } else {
589 0.0
590 }
591 }
592
593 fn assign_bit_widths(&mut self) -> Result<()> {
595 let mut scores: Vec<(String, f32)> = self
596 .sensitivity_scores
597 .iter()
598 .map(|(name, &score)| (name.clone(), score))
599 .collect();
600
601 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
602
603 for (i, (layer_name, _)) in scores.iter().enumerate() {
604 let bits = if i < scores.len() / 3 {
605 8 } else if i < 2 * scores.len() / 3 {
607 6 } else {
609 4 };
611 let mut config = self
612 .layer_configs
613 .get(layer_name)
614 .cloned()
615 .unwrap_or_default();
616 config.bits = bits;
617 self.layer_configs.insert(layer_name.clone(), config);
618 }
619 Ok(())
620 }
621
622 pub fn get_layer_config(&self, layer_name: &str) -> Option<&QuantizationConfig> {
624 self.layer_configs.get(layer_name)
625 }
626
627 pub fn get_sensitivity_score(&self, layer_name: &str) -> Option<f32> {
629 self.sensitivity_scores.get(layer_name).copied()
630 }
631}
632
/// Quantizes tensors on the fly, optionally reusing cached parameters.
pub struct DynamicQuantizer {
    /// Configuration used whenever parameters are derived from a tensor.
    config: QuantizationConfig,
    /// Parameters remembered per caller-supplied cache key.
    params_cache: HashMap<String, QuantizationParams>,
    /// Entry count at which an existing entry is evicted before inserting.
    cache_size_limit: usize,
}
642
643impl DynamicQuantizer {
644 pub fn new(config: QuantizationConfig) -> Self {
646 Self {
647 config,
648 params_cache: HashMap::new(),
649 cache_size_limit: 100,
650 }
651 }
652
653 pub fn quantize(
655 &mut self,
656 tensor: &ArrayD<f32>,
657 cache_key: Option<&str>,
658 ) -> Result<QuantizedTensor> {
659 let params = if let Some(key) = cache_key {
660 if let Some(cached_params) = self.params_cache.get(key) {
661 cached_params.clone()
662 } else {
663 let params = QuantizationParams::from_tensor(&tensor.view(), &self.config)?;
664 self.cache_params(key, params.clone());
665 params
666 }
667 } else {
668 QuantizationParams::from_tensor(&tensor.view(), &self.config)?
669 };
670
671 let quantized_data = QuantizedTensor::quantize_tensor(tensor, ¶ms)?;
672 Ok(QuantizedTensor {
673 data: quantized_data,
674 params,
675 shape: tensor.shape().to_vec(),
676 })
677 }
678
679 fn cache_params(&mut self, key: &str, params: QuantizationParams) {
681 if self.params_cache.len() >= self.cache_size_limit {
682 if let Some(first_key) = self.params_cache.keys().next().cloned() {
683 self.params_cache.remove(&first_key);
684 }
685 }
686 self.params_cache.insert(key.to_string(), params);
687 }
688
689 pub fn clear_cache(&mut self) {
691 self.params_cache.clear();
692 }
693
694 pub fn cache_stats(&self) -> (usize, usize) {
696 (self.params_cache.len(), self.cache_size_limit)
697 }
698}
699
700pub mod utils {
702 use super::*;
703
704 pub fn compute_quantization_error(original: &ArrayD<f32>, quantized: &QuantizedTensor) -> f32 {
706 let dequantized = quantized.dequantize();
707 let mse = Zip::from(original)
708 .and(&dequantized)
709 .fold(0.0, |acc, &orig, &deq| acc + (orig - deq).powi(2));
710 let len = original.len() as f32;
711 if len > 0.0 {
712 mse / len
713 } else {
714 0.0
715 }
716 }
717
718 pub fn estimate_size_reduction(bit_width: u8) -> f32 {
720 if bit_width > 0 {
721 32.0 / bit_width as f32
722 } else {
723 1.0
724 }
725 }
726
727 pub fn estimate_performance_gain(bit_width: u8) -> f32 {
729 match bit_width {
730 8 => 2.0,
731 4 => 4.0,
732 1 => 16.0,
733 _ => 1.0,
734 }
735 }
736
737 pub fn convert_quantization_scheme(
739 tensor: &QuantizedTensor,
740 target_scheme: QuantizationScheme,
741 target_bits: u8,
742 ) -> Result<QuantizedTensor> {
743 let float_tensor = tensor.dequantize();
744 let config = QuantizationConfig {
745 scheme: target_scheme,
746 bits: target_bits,
747 ..Default::default()
748 };
749 QuantizedTensor::from_float(&float_tensor, &config)
750 }
751}
752
753pub fn quantize_model(
758 parameters: &HashMap<String, ArrayD<f32>>,
759 config: &QuantizationConfig,
760) -> Result<HashMap<String, QuantizedTensor>> {
761 let mut quantized_params = HashMap::new();
762 for (name, tensor) in parameters {
763 let quantized = QuantizedTensor::from_float(tensor, config)?;
764 quantized_params.insert(name.clone(), quantized);
765 }
766 Ok(quantized_params)
767}
768
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array2};

    /// Builds a `rows x cols` matrix of uniform random values in `[-1, 1)`.
    fn random_f32_array(rows: usize, cols: usize) -> Array2<f32> {
        let mut rng = scirs2_core::random::rng();
        let data: Vec<f32> = (0..rows * cols)
            .map(|_| scirs2_core::random::RngExt::random_range(&mut rng, -1.0f32..1.0f32))
            .collect();
        Array2::from_shape_vec((rows, cols), data).expect("Test: array creation")
    }

    #[test]
    fn test_quantization_config_default() {
        let config = QuantizationConfig::default();
        assert_eq!(config.bits, 8);
        assert!(config.signed);
        assert_eq!(config.scheme, QuantizationScheme::Symmetric);
    }

    #[test]
    fn test_quantization_params_creation() {
        let params = QuantizationParams::new(8, true);
        assert_eq!(params.bits, 8);
        assert_eq!(params.qmin, -128);
        assert_eq!(params.qmax, 127);

        let unsigned = QuantizationParams::new(8, false);
        assert_eq!(unsigned.qmin, 0);
        assert_eq!(unsigned.qmax, 255);
    }

    #[test]
    fn test_symmetric_quantization() {
        let config = QuantizationConfig::default();
        let tensor = array![[1.0_f32, -1.0], [2.0, -2.0]].into_dyn();
        let quantized = QuantizedTensor::from_float(&tensor, &config).expect("Test: quantization");
        let error = utils::compute_quantization_error(&tensor, &quantized);
        assert!(error < 0.1);
    }

    #[test]
    fn test_asymmetric_quantization() {
        let config = QuantizationConfig {
            scheme: QuantizationScheme::Asymmetric,
            ..Default::default()
        };
        let tensor = array![[0.0_f32, 1.0], [2.0, 3.0]].into_dyn();
        let quantized = QuantizedTensor::from_float(&tensor, &config).expect("Test: quantization");
        assert!(quantized.params.zero_point != 0);
        let error = utils::compute_quantization_error(&tensor, &quantized);
        assert!(error < 0.1);
    }

    #[test]
    fn test_post_training_quantization() {
        let mut ptq = PostTrainingQuantizer::new(QuantizationConfig::default());
        let calib_data = random_f32_array(100, 50).into_dyn();
        ptq.add_calibration_data("layer1", &calib_data);
        let params = ptq.finalize_calibration().expect("Test: calibration");
        assert!(params.contains_key("layer1"));
    }

    #[test]
    fn test_quantization_aware_training() {
        let mut qat = QuantizationAwareTraining::new(QuantizationConfig::default());
        // Without this the default warm-up would turn the call below into a
        // plain copy and the fake-quantization path would never run.
        qat.set_warmup_steps(0);
        let tensor = Array2::<f32>::ones((10, 10)).into_dyn();
        qat.init_layer_params("layer1", &tensor)
            .expect("Test: init params");
        let fake_quantized = qat
            .fake_quantize("layer1", &tensor)
            .expect("Test: fake quantize");
        assert_eq!(fake_quantized.shape(), tensor.shape());
        // The round trip must stay close to the input values.
        for (orig, fq) in tensor.iter().zip(fake_quantized.iter()) {
            assert!((orig - fq).abs() < 0.05, "orig={}, fq={}", orig, fq);
        }
    }

    #[test]
    fn test_mixed_bitwidth_quantization() {
        let mut mbq = MixedBitWidthQuantizer::new();
        let mut outputs = HashMap::new();
        outputs.insert("layer1".to_string(), random_f32_array(50, 50).into_dyn());
        outputs.insert(
            "layer2".to_string(),
            Array2::<f32>::ones((50, 50)).into_dyn(),
        );
        mbq.analyze_sensitivity(&outputs)
            .expect("Test: sensitivity analysis");
        assert!(mbq.get_sensitivity_score("layer1").is_some());
        assert!(mbq.get_layer_config("layer1").is_some());
    }

    #[test]
    fn test_dynamic_quantization() {
        let mut dq = DynamicQuantizer::new(QuantizationConfig::default());
        let tensor = random_f32_array(20, 20).into_dyn();
        let quantized = dq
            .quantize(&tensor, Some("test_key"))
            .expect("Test: dynamic quantize");
        assert_eq!(quantized.shape, tensor.shape().to_vec());
        let (cache_size, _) = dq.cache_stats();
        assert_eq!(cache_size, 1);
    }

    #[test]
    fn test_quantization_utilities() {
        let original = random_f32_array(10, 10).into_dyn();
        let quantized = QuantizedTensor::from_float(&original, &QuantizationConfig::default())
            .expect("Test: quantization");
        let error = utils::compute_quantization_error(&original, &quantized);
        assert!(error >= 0.0);
        let size_reduction = utils::estimate_size_reduction(8);
        assert_eq!(size_reduction, 4.0);
        let perf_gain = utils::estimate_performance_gain(8);
        assert_eq!(perf_gain, 2.0);
    }

    #[test]
    fn test_compression_ratio() {
        let tensor = Array2::<f32>::ones((100, 100)).into_dyn();
        let quantized = QuantizedTensor::from_float(&tensor, &QuantizationConfig::default())
            .expect("Test: quantization");
        let ratio = quantized.compression_ratio();
        assert!(ratio > 1.0);
    }

    #[test]
    fn test_power_of_two_quantization() {
        let config = QuantizationConfig {
            scheme: QuantizationScheme::PowerOfTwo,
            ..Default::default()
        };
        let tensor = random_f32_array(10, 10).into_dyn();
        let quantized = QuantizedTensor::from_float(&tensor, &config).expect("Test: quantization");
        // A power-of-two scale has an integral base-2 logarithm.
        let scale_log2 = quantized.params.scale.log2();
        assert!((scale_log2.round() - scale_log2).abs() < 1e-6);
    }

    #[test]
    fn test_quantization_scheme_conversion() {
        let config = QuantizationConfig::default();
        let tensor = random_f32_array(10, 10).into_dyn();
        let quantized = QuantizedTensor::from_float(&tensor, &config).expect("Test: quantization");
        let converted =
            utils::convert_quantization_scheme(&quantized, QuantizationScheme::Asymmetric, 4)
                .expect("Test: scheme conversion");
        assert_eq!(converted.params.bits, 4);
    }

    #[test]
    fn test_quantize_model_fn() {
        let mut parameters = HashMap::new();
        parameters.insert("weight".to_string(), random_f32_array(5, 5).into_dyn());
        parameters.insert("bias".to_string(), Array2::<f32>::zeros((1, 5)).into_dyn());
        let config = QuantizationConfig::default();
        let quantized = quantize_model(&parameters, &config);
        assert!(quantized.is_ok());
        let qmap = quantized.expect("Test: quantize_model");
        assert_eq!(qmap.len(), 2);
        assert!(qmap.contains_key("weight"));
        assert!(qmap.contains_key("bias"));
    }

    #[test]
    fn test_dequantize_roundtrip() {
        let config = QuantizationConfig::default();
        let tensor = array![[1.0_f32, 2.0], [3.0, 4.0]].into_dyn();
        let quantized = QuantizedTensor::from_float(&tensor, &config).expect("Test: quantization");
        let dequantized = quantized.dequantize();
        for (orig, deq) in tensor.iter().zip(dequantized.iter()) {
            assert!((orig - deq).abs() < 0.5, "orig={}, deq={}", orig, deq);
        }
    }
}