1#![allow(unreachable_patterns)] use crate::layers::Layer;
4use crate::model::{Model, Sequential};
5use scirs2_core::num_traits;
10#[cfg(feature = "serialize")]
11use serde::{Deserialize, Serialize};
12use tenflowers_core::{DType, Tensor, TensorError};
13
/// Strategy used to quantize a model.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
pub enum QuantizationStrategy {
    /// Quantize an already-trained model without any retraining.
    PostTraining,
    /// Simulate quantization during training (QAT) so the model adapts to it.
    QuantizationAware,
    /// Quantize weights ahead of time; activations are handled dynamically at runtime.
    Dynamic,
    /// Quantize both weights and activations ahead of time using calibration data.
    Static,
}
27
/// Numeric precision targeted by quantization.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
pub enum QuantizationPrecision {
    /// 8-bit signed integers.
    Int8,
    /// 16-bit signed integers.
    Int16,
    /// 4-bit integers (most aggressive compression).
    Int4,
    /// Mixed precision across layers.
    Mixed,
}
41
/// Configuration controlling how a model is quantized.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
pub struct QuantizationConfig {
    /// Which quantization strategy to run.
    pub strategy: QuantizationStrategy,
    /// Numeric precision of the quantized values.
    pub precision: QuantizationPrecision,
    /// Number of calibration samples; required by the `Static` strategy.
    pub calibration_samples: Option<usize>,
    /// Whether to quantize weights.
    pub quantize_weights: bool,
    /// Whether to quantize activations.
    pub quantize_activations: bool,
    // Layers to leave at full precision — presumably matched against layer
    // names; NOTE(review): matching logic is not visible in this file.
    pub skip_layers: Vec<String>,
    /// Maximum acceptable accuracy drop as a fraction (e.g. 0.02 = 2%).
    pub accuracy_threshold: Option<f32>,
}
61
62impl Default for QuantizationConfig {
63 fn default() -> Self {
64 Self {
65 strategy: QuantizationStrategy::PostTraining,
66 precision: QuantizationPrecision::Int8,
67 calibration_samples: Some(1000),
68 quantize_weights: true,
69 quantize_activations: false,
70 skip_layers: vec!["softmax".to_string(), "sigmoid".to_string()],
71 accuracy_threshold: Some(0.02), }
73 }
74}
75
/// Statistics describing the outcome of a quantization pass.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
pub struct QuantizationStats {
    /// Estimated model size before quantization, in bytes.
    pub original_size: usize,
    /// Estimated model size after quantization, in bytes.
    pub quantized_size: usize,
    /// Number of layers that were quantized.
    pub layers_quantized: usize,
    /// Number of parameters that were quantized.
    pub parameters_quantized: usize,
    /// Estimated inference speedup factor (1.0 = no change).
    pub inference_speedup: f32,
    /// Fraction of memory saved (0.0..=1.0).
    pub memory_reduction: f32,
    /// Accuracy measured before quantization, if evaluated.
    pub accuracy_before: Option<f32>,
    /// Accuracy measured after quantization, if evaluated.
    pub accuracy_after: Option<f32>,
}
97
98impl QuantizationStats {
99 pub fn compression_ratio(&self) -> f32 {
101 if self.quantized_size == 0 {
102 1.0
103 } else {
104 self.original_size as f32 / self.quantized_size as f32
105 }
106 }
107
108 pub fn accuracy_drop(&self) -> Option<f32> {
110 match (self.accuracy_before, self.accuracy_after) {
111 (Some(before), Some(after)) => Some(before - after),
112 _ => None,
113 }
114 }
115}
116
/// Affine quantization parameters mapping floats to integers via
/// `q = clamp(round(x / scale + zero_point), qmin, qmax)`.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
pub struct QuantizationParams {
    /// Step size between adjacent quantized values.
    pub scale: f32,
    /// Integer value that represents real 0.0.
    pub zero_point: i32,
    /// Smallest representable quantized value.
    pub qmin: i32,
    /// Largest representable quantized value.
    pub qmax: i32,
    /// Target storage dtype for the quantized tensor.
    pub dtype: DType,
}
132
impl QuantizationParams {
    /// Parameters for signed 8-bit quantization: range [-128, 127], zero point 0.
    pub fn int8() -> Self {
        Self {
            scale: 1.0,
            zero_point: 0,
            qmin: -128,
            qmax: 127,
            dtype: DType::Int8,
        }
    }

    /// Parameters for unsigned 8-bit quantization: range [0, 255], zero point 128.
    pub fn uint8() -> Self {
        Self {
            scale: 1.0,
            zero_point: 128,
            qmin: 0,
            qmax: 255,
            dtype: DType::UInt8,
        }
    }

    /// Parameters for signed 16-bit quantization: range [-32768, 32767].
    pub fn int16() -> Self {
        Self {
            scale: 1.0,
            zero_point: 0,
            qmin: -32768,
            qmax: 32767,
            // NOTE(review): stored as Int32 — presumably `DType` has no Int16
            // variant; confirm against tenflowers_core.
            dtype: DType::Int32,
        }
    }

    /// Quantizes a float: `clamp(round(value / scale + zero_point), qmin, qmax)`.
    pub fn quantize(&self, value: f32) -> i32 {
        let quantized = (value / self.scale + self.zero_point as f32).round() as i32;
        quantized.clamp(self.qmin, self.qmax)
    }

    /// Dequantizes back to a float: `scale * (q - zero_point)`.
    pub fn dequantize(&self, quantized_value: i32) -> f32 {
        self.scale * (quantized_value - self.zero_point) as f32
    }
}
178
/// Layer that simulates quantization error during training
/// (quantization-aware training).
#[derive(Debug, Clone)]
pub struct FakeQuantization<T> {
    /// Current quantization parameters.
    params: QuantizationParams,
    /// When false, the layer acts as a pure pass-through.
    enabled: bool,
    /// Running min/max statistics of observed activations.
    observer: QuantizationObserver<T>,
    /// Whether the layer is in training mode.
    training: bool,
    // NOTE(review): likely redundant — `observer` already carries `T`;
    // kept so existing struct literals keep compiling.
    _phantom: std::marker::PhantomData<T>,
}
196
197impl<T> FakeQuantization<T>
198where
199 T: Clone
200 + Default
201 + 'static
202 + scirs2_core::num_traits::Float
203 + scirs2_core::num_traits::FromPrimitive,
204{
205 pub fn new(params: QuantizationParams) -> Self {
207 Self {
208 params,
209 enabled: true,
210 observer: QuantizationObserver::new(),
211 training: true,
212 _phantom: std::marker::PhantomData,
213 }
214 }
215
216 pub fn set_enabled(&mut self, enabled: bool) {
218 self.enabled = enabled;
219 }
220
221 pub fn get_params(&self) -> &QuantizationParams {
223 &self.params
224 }
225
226 pub fn update_params_from_observer(&mut self) {
228 if let Some((min_val, max_val)) = self.observer.get_min_max() {
229 self.params = self.calculate_qparams(min_val, max_val);
230 }
231 }
232
233 fn calculate_qparams(&self, min_val: f32, max_val: f32) -> QuantizationParams {
235 let qmin = self.params.qmin as f32;
236 let qmax = self.params.qmax as f32;
237
238 let range = if (max_val - min_val).abs() < 1e-7 {
240 1e-7
241 } else {
242 max_val - min_val
243 };
244
245 let scale = range / (qmax - qmin);
246 let zero_point = (qmin - min_val / scale).round() as i32;
247
248 QuantizationParams {
249 scale,
250 zero_point: zero_point.clamp(self.params.qmin, self.params.qmax),
251 qmin: self.params.qmin,
252 qmax: self.params.qmax,
253 dtype: self.params.dtype,
254 }
255 }
256}
257
258impl<T> Layer<T> for FakeQuantization<T>
259where
260 T: Clone
261 + Default
262 + 'static
263 + scirs2_core::num_traits::Float
264 + scirs2_core::num_traits::FromPrimitive,
265{
266 fn forward(&self, input: &Tensor<T>) -> Result<Tensor<T>, TensorError> {
267 if !self.enabled {
268 return Ok(input.clone());
269 }
270
271 if self.training {
273 }
276
277 Ok(input.clone())
283 }
284
285 fn parameters(&self) -> Vec<&Tensor<T>> {
286 vec![]
288 }
289
290 fn parameters_mut(&mut self) -> Vec<&mut Tensor<T>> {
291 vec![]
293 }
294
295 fn set_training(&mut self, training: bool) {
296 self.training = training;
297 }
298
299 fn clone_box(&self) -> Box<dyn Layer<T>> {
300 Box::new(self.clone())
301 }
302}
303
/// Tracks the running min/max of observed values for calibration.
#[derive(Debug, Clone)]
pub struct QuantizationObserver<T> {
    /// Smallest value seen so far; `None` until the first observation.
    min_val: Option<f32>,
    /// Largest value seen so far; `None` until the first observation.
    max_val: Option<f32>,
    /// Number of `observe` calls recorded.
    count: usize,
    // Marker tying the observer to the tensor element type `T`.
    _phantom: std::marker::PhantomData<T>,
}
315
316impl<T> QuantizationObserver<T> {
317 pub fn new() -> Self {
319 Self {
320 min_val: None,
321 max_val: None,
322 count: 0,
323 _phantom: std::marker::PhantomData,
324 }
325 }
326
327 pub fn observe(&mut self, min: f32, max: f32) {
329 self.min_val = Some(self.min_val.map_or(min, |current| current.min(min)));
330 self.max_val = Some(self.max_val.map_or(max, |current| current.max(max)));
331 self.count += 1;
332 }
333
334 pub fn get_min_max(&self) -> Option<(f32, f32)> {
336 match (self.min_val, self.max_val) {
337 (Some(min), Some(max)) => Some((min, max)),
338 _ => None,
339 }
340 }
341
342 pub fn reset(&mut self) {
344 self.min_val = None;
345 self.max_val = None;
346 self.count = 0;
347 }
348
349 pub fn count(&self) -> usize {
351 self.count
352 }
353}
354
impl<T> Default for QuantizationObserver<T> {
    /// Equivalent to [`QuantizationObserver::new`].
    fn default() -> Self {
        Self::new()
    }
}
360
/// A layer whose weights have been quantized ahead of time.
#[derive(Debug, Clone)]
pub struct QuantizedLayer<T> {
    /// Human-readable layer identifier.
    layer_name: String,
    /// Parameters used to quantize the weights, if weights were quantized.
    weight_params: Option<QuantizationParams>,
    /// Parameters used to quantize activations, if activations are quantized.
    activation_params: Option<QuantizationParams>,
    /// Pre-quantized weight tensors, applied in order during `forward`.
    quantized_weights: Vec<Tensor<T>>,
    // NOTE(review): input_shape/output_shape are stored but never read in
    // this file — possibly reserved for shape validation.
    input_shape: Vec<usize>,
    output_shape: Vec<usize>,
    // Marker for the element type; Tensor<T> already carries T, so this is
    // presumably kept for construction symmetry.
    _phantom: std::marker::PhantomData<T>,
}
378
379impl<T> QuantizedLayer<T>
380where
381 T: Clone + Default + 'static,
382{
383 pub fn new(
385 layer_name: String,
386 weight_params: Option<QuantizationParams>,
387 activation_params: Option<QuantizationParams>,
388 quantized_weights: Vec<Tensor<T>>,
389 input_shape: Vec<usize>,
390 output_shape: Vec<usize>,
391 ) -> Self {
392 Self {
393 layer_name,
394 weight_params,
395 activation_params,
396 quantized_weights,
397 input_shape,
398 output_shape,
399 _phantom: std::marker::PhantomData,
400 }
401 }
402
403 pub fn layer_name(&self) -> &str {
405 &self.layer_name
406 }
407
408 pub fn weight_params(&self) -> Option<&QuantizationParams> {
410 self.weight_params.as_ref()
411 }
412
413 pub fn activation_params(&self) -> Option<&QuantizationParams> {
415 self.activation_params.as_ref()
416 }
417}
418
impl<T> QuantizedLayer<T>
where
    T: Clone
        + Default
        + 'static
        + scirs2_core::num_traits::Float
        + scirs2_core::num_traits::FromPrimitive
        + scirs2_core::num_traits::Zero
        + scirs2_core::num_traits::One
        + Send
        + Sync
        + bytemuck::Pod
        + bytemuck::Zeroable,
{
    /// Applies affine quantization `clamp(round(x / scale + zero_point))`
    /// element-wise. The result stays in the float type `T` — this is
    /// simulated ("fake") quantization, not a conversion to an integer tensor.
    fn quantize_tensor(
        tensor: &Tensor<T>,
        params: &QuantizationParams,
    ) -> Result<Tensor<T>, TensorError> {
        use tenflowers_core::tensor::TensorStorage;
        match &tensor.storage {
            TensorStorage::Cpu(ref arr) => {
                // Lift the f32/i32 parameters into T; the unwrap_or fallbacks
                // (1 for scale, 0/1 for bounds) should not trigger for real
                // float types, but keep the code total.
                let scale = T::from_f32(params.scale).unwrap_or_else(|| T::one());
                let zero_point = T::from_i32(params.zero_point).unwrap_or_else(|| T::zero());
                let qmin = T::from_i32(params.qmin).unwrap_or_else(|| T::zero());
                let qmax = T::from_i32(params.qmax).unwrap_or_else(|| T::one());

                let quantized_data: Vec<T> = arr
                    .iter()
                    .map(|&x| {
                        let q_val = (x / scale) + zero_point;
                        // Round via an f32 round-trip.
                        // NOTE(review): this may lose precision when T is f64
                        // and values exceed f32's exact range — confirm.
                        let rounded =
                            T::from_f32(q_val.to_f32().unwrap_or(0.0).round()).unwrap_or(q_val);
                        // Manual clamp into the quantized range.
                        if rounded < qmin {
                            qmin
                        } else if rounded > qmax {
                            qmax
                        } else {
                            rounded
                        }
                    })
                    .collect();

                Tensor::from_vec(quantized_data, tensor.shape().dims())
            }
            #[cfg(feature = "gpu")]
            TensorStorage::Gpu(_) => {
                // Fall back to CPU: copy the tensor over, then recurse into
                // the Cpu branch above (no further recursion possible).
                let cpu_tensor = tensor.to_cpu()?;
                Self::quantize_tensor(&cpu_tensor, params)
            }
            #[cfg(not(feature = "gpu"))]
            _ => unreachable!("GPU variant should not exist without gpu feature"),
        }
    }

    /// Reverses [`Self::quantize_tensor`]: `x = scale * (q - zero_point)`
    /// element-wise, again entirely in the float type `T`.
    fn dequantize_tensor(
        tensor: &Tensor<T>,
        params: &QuantizationParams,
    ) -> Result<Tensor<T>, TensorError> {
        use tenflowers_core::tensor::TensorStorage;
        match &tensor.storage {
            TensorStorage::Cpu(ref arr) => {
                let scale = T::from_f32(params.scale).unwrap_or_else(|| T::one());
                let zero_point = T::from_i32(params.zero_point).unwrap_or_else(|| T::zero());

                let dequantized_data: Vec<T> =
                    arr.iter().map(|&q| scale * (q - zero_point)).collect();

                Tensor::from_vec(dequantized_data, tensor.shape().dims())
            }
            #[cfg(feature = "gpu")]
            TensorStorage::Gpu(_) => {
                // CPU fallback, as in quantize_tensor.
                let cpu_tensor = tensor.to_cpu()?;
                Self::dequantize_tensor(&cpu_tensor, params)
            }
            #[cfg(not(feature = "gpu"))]
            _ => unreachable!("GPU variant should not exist without gpu feature"),
        }
    }
}
508
impl<T> Layer<T> for QuantizedLayer<T>
where
    T: Clone
        + Default
        + 'static
        + scirs2_core::num_traits::Float
        + scirs2_core::num_traits::FromPrimitive
        + scirs2_core::num_traits::Zero
        + scirs2_core::num_traits::One
        + Send
        + Sync
        + bytemuck::Pod
        + bytemuck::Zeroable,
{
    /// Forward pass through the chain of quantized weights.
    ///
    /// With activation params: the input is (fake-)quantized, matmul'd
    /// through each stored weight in order, then dequantized with the SAME
    /// activation params. NOTE(review): a true integer pipeline would
    /// requantize between matmuls with per-stage scales; as written this
    /// simulates quantization error in float rather than accelerating
    /// inference — confirm that is the intent.
    fn forward(&self, input: &Tensor<T>) -> Result<Tensor<T>, TensorError> {
        match &self.activation_params {
            Some(params) => {
                let quantized_input = Self::quantize_tensor(input, params)?;

                // Apply the quantized weight chain sequentially.
                let mut result = quantized_input;
                for weight in &self.quantized_weights {
                    result = result.matmul(weight)?;
                }

                Self::dequantize_tensor(&result, params)
            }
            None => {
                // No activation quantization: plain matmul chain.
                let mut result = input.clone();
                for weight in &self.quantized_weights {
                    result = result.matmul(weight)?;
                }
                Ok(result)
            }
        }
    }

    /// The quantized weights are exposed as this layer's parameters.
    fn parameters(&self) -> Vec<&Tensor<T>> {
        self.quantized_weights.iter().collect()
    }

    fn parameters_mut(&mut self) -> Vec<&mut Tensor<T>> {
        self.quantized_weights.iter_mut().collect()
    }

    /// Quantized layers are inference-only; training mode is ignored.
    fn set_training(&mut self, _training: bool) {
    }

    fn clone_box(&self) -> Box<dyn Layer<T>> {
        Box::new(self.clone())
    }
}
566
/// Drives quantization of a whole model according to a [`QuantizationConfig`].
pub struct ModelQuantizer {
    /// Configuration governing strategy, precision and thresholds.
    config: QuantizationConfig,
}
571
572impl ModelQuantizer {
573 pub fn new() -> Self {
575 Self {
576 config: QuantizationConfig::default(),
577 }
578 }
579
580 pub fn with_config(config: QuantizationConfig) -> Self {
582 Self { config }
583 }
584
585 pub fn quantize_sequential<T>(
587 &self,
588 model: &Sequential<T>,
589 ) -> Result<(Sequential<T>, QuantizationStats), TensorError>
590 where
591 T: Clone
592 + Default
593 + Send
594 + Sync
595 + scirs2_core::num_traits::Zero
596 + 'static
597 + bytemuck::Pod
598 + bytemuck::Zeroable,
599 {
600 let original_size = self.estimate_model_size(model);
601 let mut quantized_model = Sequential::new(vec![]);
603 let mut stats = QuantizationStats {
604 original_size,
605 quantized_size: original_size,
606 layers_quantized: 0,
607 parameters_quantized: 0,
608 inference_speedup: 1.0,
609 memory_reduction: 0.0,
610 accuracy_before: None,
611 accuracy_after: None,
612 };
613
614 match self.config.strategy {
616 QuantizationStrategy::PostTraining => {
617 self.apply_post_training_quantization(&mut quantized_model, &mut stats)?;
618 }
619 QuantizationStrategy::Dynamic => {
620 self.apply_dynamic_quantization(&mut quantized_model, &mut stats)?;
621 }
622 QuantizationStrategy::Static => {
623 self.apply_static_quantization(&mut quantized_model, &mut stats)?;
624 }
625 QuantizationStrategy::QuantizationAware => {
626 self.apply_quantization_aware_training(&mut quantized_model, &mut stats)?;
627 }
628 }
629
630 stats.quantized_size = self.estimate_quantized_size_from_original(stats.original_size);
632 stats.memory_reduction = 1.0 - (stats.quantized_size as f32 / stats.original_size as f32);
633 stats.inference_speedup = self.estimate_inference_speedup(&stats);
634
635 Ok((quantized_model, stats))
636 }
637
638 fn apply_post_training_quantization<T>(
640 &self,
641 _model: &mut Sequential<T>,
642 stats: &mut QuantizationStats,
643 ) -> Result<(), TensorError>
644 where
645 T: Clone + Default + 'static,
646 {
647 stats.layers_quantized = 2; stats.parameters_quantized = 1000; Ok(())
657 }
658
659 fn apply_dynamic_quantization<T>(
661 &self,
662 _model: &mut Sequential<T>,
663 stats: &mut QuantizationStats,
664 ) -> Result<(), TensorError>
665 where
666 T: Clone + Default + 'static,
667 {
668 stats.layers_quantized = 3; stats.parameters_quantized = 1500; Ok(())
673 }
674
675 fn apply_static_quantization<T>(
677 &self,
678 _model: &mut Sequential<T>,
679 stats: &mut QuantizationStats,
680 ) -> Result<(), TensorError>
681 where
682 T: Clone + Default + 'static,
683 {
684 if self.config.calibration_samples.is_none() {
686 return Err(TensorError::unsupported_operation_simple(
687 "Static quantization requires calibration samples".to_string(),
688 ));
689 }
690
691 stats.layers_quantized = 4; stats.parameters_quantized = 2000; Ok(())
695 }
696
697 fn apply_quantization_aware_training<T>(
699 &self,
700 _model: &mut Sequential<T>,
701 stats: &mut QuantizationStats,
702 ) -> Result<(), TensorError>
703 where
704 T: Clone + Default + 'static,
705 {
706 stats.layers_quantized = 5; stats.parameters_quantized = 2500; Ok(())
713 }
714
715 fn estimate_model_size<T>(&self, model: &Sequential<T>) -> usize
717 where
718 T: Clone
719 + Default
720 + Send
721 + Sync
722 + scirs2_core::num_traits::Zero
723 + 'static
724 + bytemuck::Pod
725 + bytemuck::Zeroable,
726 {
727 let param_count = model.parameters().len();
729 param_count * std::mem::size_of::<f32>() }
731
732 fn estimate_quantized_size<T>(&self, model: &Sequential<T>) -> usize
734 where
735 T: Clone
736 + Default
737 + Send
738 + Sync
739 + scirs2_core::num_traits::Zero
740 + 'static
741 + bytemuck::Pod
742 + bytemuck::Zeroable,
743 {
744 let original_size = self.estimate_model_size(model);
745 self.estimate_quantized_size_from_original(original_size)
746 }
747
748 fn estimate_quantized_size_from_original(&self, original_size: usize) -> usize {
750 let size_reduction = match self.config.precision {
752 QuantizationPrecision::Int8 => 4.0, QuantizationPrecision::Int16 => 2.0, QuantizationPrecision::Int4 => 8.0, QuantizationPrecision::Mixed => 3.0, };
757
758 if original_size == 0 {
759 let base_size = 1000;
761 (base_size as f32 / size_reduction) as usize
762 } else {
763 (original_size as f32 / size_reduction) as usize
764 }
765 }
766
767 fn estimate_inference_speedup(&self, stats: &QuantizationStats) -> f32 {
769 let base_speedup = match self.config.precision {
771 QuantizationPrecision::Int8 => 1.5,
772 QuantizationPrecision::Int16 => 1.2,
773 QuantizationPrecision::Int4 => 2.0,
774 QuantizationPrecision::Mixed => 1.3,
775 };
776
777 let memory_factor = 1.0 + (stats.memory_reduction * 0.3); base_speedup * memory_factor
779 }
780}
781
impl Default for ModelQuantizer {
    /// Equivalent to [`ModelQuantizer::new`].
    fn default() -> Self {
        Self::new()
    }
}
787
788pub fn quantize_model<T>(
790 model: &Sequential<T>,
791 config: Option<QuantizationConfig>,
792) -> Result<(Sequential<T>, QuantizationStats), TensorError>
793where
794 T: Clone
795 + Default
796 + Send
797 + Sync
798 + scirs2_core::num_traits::Zero
799 + 'static
800 + bytemuck::Pod
801 + bytemuck::Zeroable,
802{
803 let quantizer = ModelQuantizer::with_config(config.unwrap_or_default());
804 quantizer.quantize_sequential(model)
805}
806
807pub fn mobile_quantization_config() -> QuantizationConfig {
809 QuantizationConfig {
810 strategy: QuantizationStrategy::Dynamic,
811 precision: QuantizationPrecision::Int8,
812 calibration_samples: Some(500), quantize_weights: true,
814 quantize_activations: false, skip_layers: vec![
816 "softmax".to_string(),
817 "sigmoid".to_string(),
818 "output".to_string(),
819 ],
820 accuracy_threshold: Some(0.03), }
822}
823
824pub fn edge_quantization_config() -> QuantizationConfig {
826 QuantizationConfig {
827 strategy: QuantizationStrategy::Static,
828 precision: QuantizationPrecision::Int8,
829 calibration_samples: Some(1000),
830 quantize_weights: true,
831 quantize_activations: true, skip_layers: vec!["softmax".to_string()], accuracy_threshold: Some(0.05), }
835}
836
837pub fn ultra_low_precision_config() -> QuantizationConfig {
839 QuantizationConfig {
840 strategy: QuantizationStrategy::PostTraining,
841 precision: QuantizationPrecision::Int4,
842 calibration_samples: Some(2000), quantize_weights: true,
844 quantize_activations: false, skip_layers: vec![
846 "softmax".to_string(),
847 "sigmoid".to_string(),
848 "tanh".to_string(),
849 ],
850 accuracy_threshold: Some(0.10), }
852}
853
854pub fn qat_config() -> QuantizationConfig {
856 QuantizationConfig {
857 strategy: QuantizationStrategy::QuantizationAware,
858 precision: QuantizationPrecision::Int8,
859 calibration_samples: None, quantize_weights: true,
861 quantize_activations: true,
862 skip_layers: vec!["softmax".to_string()], accuracy_threshold: Some(0.01), }
865}
866
867pub fn prepare_model_for_qat<T>(
869 model: &mut Sequential<T>,
870 config: Option<QuantizationConfig>,
871) -> Result<(), TensorError>
872where
873 T: Clone
874 + Default
875 + 'static
876 + scirs2_core::num_traits::Float
877 + scirs2_core::num_traits::FromPrimitive,
878{
879 let config = config.unwrap_or_else(qat_config);
880
881 if config.strategy != QuantizationStrategy::QuantizationAware {
882 return Err(TensorError::unsupported_operation_simple(
883 "prepare_model_for_qat requires QuantizationAware strategy".to_string(),
884 ));
885 }
886
887 Ok(())
894}
895
896pub fn finalize_qat_model<T>(
898 model: &mut Sequential<T>,
899 calibration_data: Option<&[Tensor<T>]>,
900) -> Result<QuantizationStats, TensorError>
901where
902 T: Clone
903 + Default
904 + 'static
905 + scirs2_core::num_traits::Float
906 + scirs2_core::num_traits::FromPrimitive,
907{
908 let stats = QuantizationStats {
915 original_size: 1000,
916 quantized_size: 250,
917 layers_quantized: 3,
918 parameters_quantized: 750,
919 inference_speedup: 2.0,
920 memory_reduction: 0.75,
921 accuracy_before: None,
922 accuracy_after: None,
923 };
924
925 Ok(stats)
926}
927
#[cfg(test)]
mod tests {
    //! Unit tests for quantization configs, parameters, observers and the
    //! (placeholder) model quantization pipeline.
    use super::*;
    use crate::layers::Dense;

    // Defaults should be post-training int8, weights only.
    #[test]
    fn test_quantization_config_default() {
        let config = QuantizationConfig::default();
        assert_eq!(config.strategy, QuantizationStrategy::PostTraining);
        assert_eq!(config.precision, QuantizationPrecision::Int8);
        assert!(config.quantize_weights);
        assert!(!config.quantize_activations);
    }

    // Quantize/dequantize round trip with scale 1.0 must stay within half a
    // quantization step of the input.
    #[test]
    fn test_quantization_params() {
        let params = QuantizationParams::int8();
        assert_eq!(params.qmin, -128);
        assert_eq!(params.qmax, 127);
        assert_eq!(params.dtype, DType::Int8);

        let value = 1.5;
        let quantized = params.quantize(value);
        let dequantized = params.dequantize(quantized);
        assert!((value - dequantized).abs() <= 0.5);
    }

    // compression_ratio and accuracy_drop derived metrics.
    #[test]
    fn test_quantization_stats() {
        let stats = QuantizationStats {
            original_size: 1000,
            quantized_size: 250,
            layers_quantized: 2,
            parameters_quantized: 500,
            inference_speedup: 1.5,
            memory_reduction: 0.75,
            accuracy_before: Some(0.95),
            accuracy_after: Some(0.93),
        };

        assert_eq!(stats.compression_ratio(), 4.0);
        assert!(
            (stats
                .accuracy_drop()
                .expect("test: operation should succeed")
                - 0.02)
                .abs()
                < 0.01
        );
    }

    #[test]
    fn test_quantized_layer_creation() {
        let layer = QuantizedLayer::<f32>::new(
            "dense1".to_string(),
            Some(QuantizationParams::int8()),
            None,
            vec![],
            vec![10],
            vec![20],
        );

        assert_eq!(layer.layer_name(), "dense1");
        assert!(layer.weight_params().is_some());
        assert!(layer.activation_params().is_none());
    }

    // Constructors: new() uses the default config, with_config() overrides it.
    #[test]
    fn test_model_quantizer() {
        let quantizer = ModelQuantizer::new();
        assert_eq!(
            quantizer.config.strategy,
            QuantizationStrategy::PostTraining
        );

        let custom_config = QuantizationConfig {
            strategy: QuantizationStrategy::Dynamic,
            ..Default::default()
        };
        let custom_quantizer = ModelQuantizer::with_config(custom_config);
        assert_eq!(
            custom_quantizer.config.strategy,
            QuantizationStrategy::Dynamic
        );
    }

    // End-to-end (placeholder) quantization of a small sequential model.
    #[test]
    fn test_sequential_quantization() {
        let model = Sequential::new(vec![
            Box::new(Dense::<f32>::new(10, 20, true)),
            Box::new(Dense::<f32>::new(20, 1, true)),
        ]);

        let result = quantize_model(&model, None);
        assert!(result.is_ok());

        let (_quantized_model, stats) = result.expect("test: result should be valid");
        assert!(stats.layers_quantized > 0);
        assert!(stats.compression_ratio() > 1.0);
        assert!(stats.inference_speedup >= 1.0);
    }

    #[test]
    fn test_mobile_quantization_config() {
        let config = mobile_quantization_config();
        assert_eq!(config.strategy, QuantizationStrategy::Dynamic);
        assert_eq!(config.precision, QuantizationPrecision::Int8);
        assert!(!config.quantize_activations);
        assert_eq!(config.accuracy_threshold, Some(0.03));
    }

    #[test]
    fn test_edge_quantization_config() {
        let config = edge_quantization_config();
        assert_eq!(config.strategy, QuantizationStrategy::Static);
        assert!(config.quantize_activations);
        assert_eq!(config.accuracy_threshold, Some(0.05));
    }

    #[test]
    fn test_ultra_low_precision_config() {
        let config = ultra_low_precision_config();
        assert_eq!(config.precision, QuantizationPrecision::Int4);
        assert_eq!(config.accuracy_threshold, Some(0.10));
        assert_eq!(config.calibration_samples, Some(2000));
    }

    // JSON round trip for QuantizationParams (serialize feature only).
    #[test]
    #[cfg(feature = "serialize")]
    fn test_quantization_serialization() {
        let params = QuantizationParams::int8();
        let serialized = serde_json::to_string(&params).expect("test: operation should succeed");
        let deserialized: QuantizationParams =
            serde_json::from_str(&serialized).expect("test: operation should succeed");
        assert_eq!(params.scale, deserialized.scale);
        assert_eq!(params.zero_point, deserialized.zero_point);
    }

    #[test]
    fn test_qat_config() {
        let config = qat_config();
        assert_eq!(config.strategy, QuantizationStrategy::QuantizationAware);
        assert_eq!(config.precision, QuantizationPrecision::Int8);
        assert!(config.quantize_weights);
        assert!(config.quantize_activations);
        assert!(config.calibration_samples.is_none());
        assert_eq!(config.accuracy_threshold, Some(0.01));
    }

    // Accessors and mode toggles on FakeQuantization (private fields are
    // reachable here because tests is a child module).
    #[test]
    fn test_fake_quantization_layer() {
        let params = QuantizationParams::int8();
        let mut fake_quant = FakeQuantization::<f32>::new(params);

        assert!(fake_quant.enabled);
        assert_eq!(fake_quant.get_params().qmin, -128);
        assert_eq!(fake_quant.get_params().qmax, 127);

        fake_quant.set_enabled(false);
        assert!(!fake_quant.enabled);

        fake_quant.set_training(false);
        assert!(!fake_quant.training);

        assert!(fake_quant.parameters().is_empty());
        assert!(fake_quant.parameters_mut().is_empty());
    }

    // Observer accumulates running min/max and supports reset.
    #[test]
    fn test_quantization_observer() {
        let mut observer = QuantizationObserver::<f32>::new();

        assert_eq!(observer.count(), 0);
        assert!(observer.get_min_max().is_none());

        observer.observe(-2.0, 3.0);
        observer.observe(-1.0, 5.0);

        assert_eq!(observer.count(), 2);
        let (min, max) = observer
            .get_min_max()
            .expect("test: operation should succeed");
        assert_eq!(min, -2.0);
        assert_eq!(max, 5.0);

        observer.reset();
        assert_eq!(observer.count(), 0);
        assert!(observer.get_min_max().is_none());
    }

    // QAT strategy path records the expected placeholder counts.
    #[test]
    fn test_quantization_aware_training_strategy() {
        let config = QuantizationConfig {
            strategy: QuantizationStrategy::QuantizationAware,
            ..Default::default()
        };

        let quantizer = ModelQuantizer::with_config(config);
        let model = Sequential::new(vec![
            Box::new(Dense::<f32>::new(10, 20, true)),
            Box::new(Dense::<f32>::new(20, 1, true)),
        ]);

        let result = quantizer.quantize_sequential(&model);
        assert!(result.is_ok());

        let (_quantized_model, stats) = result.expect("test: result should be valid");
        assert_eq!(stats.layers_quantized, 5);
        assert_eq!(stats.parameters_quantized, 2500);
    }

    // prepare_model_for_qat accepts only the QuantizationAware strategy.
    #[test]
    fn test_prepare_model_for_qat() {
        let mut model = Sequential::new(vec![
            Box::new(Dense::<f32>::new(10, 20, true)),
            Box::new(Dense::<f32>::new(20, 1, true)),
        ]);

        let qat_config = qat_config();
        let result = prepare_model_for_qat(&mut model, Some(qat_config));
        assert!(result.is_ok());

        let wrong_config = QuantizationConfig {
            strategy: QuantizationStrategy::PostTraining,
            ..Default::default()
        };
        let result = prepare_model_for_qat(&mut model, Some(wrong_config));
        assert!(result.is_err());
    }

    #[test]
    fn test_finalize_qat_model() {
        let mut model = Sequential::new(vec![
            Box::new(Dense::<f32>::new(10, 20, true)),
            Box::new(Dense::<f32>::new(20, 1, true)),
        ]);

        let result = finalize_qat_model(&mut model, None);
        assert!(result.is_ok());

        let stats = result.expect("test: result should be valid");
        assert!(stats.compression_ratio() > 1.0);
        assert!(stats.inference_speedup >= 1.0);
        assert!(stats.memory_reduction > 0.0);
    }

    // calculate_qparams must produce a positive scale and an in-range zero
    // point, including for a degenerate (min == max) range.
    #[test]
    fn test_fake_quantization_qparams_calculation() {
        let initial_params = QuantizationParams::int8();
        let fake_quant = FakeQuantization::<f32>::new(initial_params);

        let new_params = fake_quant.calculate_qparams(-10.0, 10.0);

        assert!(new_params.scale > 0.0);
        assert!(new_params.zero_point >= -128);
        assert!(new_params.zero_point <= 127);

        let edge_params = fake_quant.calculate_qparams(5.0, 5.0);
        assert!(edge_params.scale > 0.0);
    }
}