Skip to main content

tenflowers_neural/deployment/
quantization.rs

1#![allow(unreachable_patterns)] // GPU/ROCM patterns unreachable when features are disabled
2
3use crate::layers::Layer;
4use crate::model::{Model, Sequential};
5/// Model quantization techniques for mobile deployment.
6///
7/// This module provides various quantization methods to reduce model size and improve
8/// inference speed on mobile and edge devices with limited computational resources.
9use scirs2_core::num_traits;
10#[cfg(feature = "serialize")]
11use serde::{Deserialize, Serialize};
12use tenflowers_core::{DType, Tensor, TensorError};
13
/// Strategy used to compress a model through quantization.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
pub enum QuantizationStrategy {
    /// Post-training quantization (PTQ).
    PostTraining,
    /// Quantization-aware training (QAT).
    QuantizationAware,
    /// Dynamic quantization (weights only).
    Dynamic,
    /// Static quantization (weights and activations).
    Static,
}
27
/// Bit-width options for quantized values.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
pub enum QuantizationPrecision {
    /// 8-bit integer quantization.
    Int8,
    /// 16-bit integer quantization.
    Int16,
    /// 4-bit integer quantization (ultra-low precision).
    Int4,
    /// Mixed precision (different precisions for different layers).
    Mixed,
}
41
42/// Configuration for model quantization.
43#[derive(Debug, Clone)]
44#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
45pub struct QuantizationConfig {
46    /// Quantization strategy to use
47    pub strategy: QuantizationStrategy,
48    /// Target precision for quantization
49    pub precision: QuantizationPrecision,
50    /// Calibration dataset size for static quantization
51    pub calibration_samples: Option<usize>,
52    /// Whether to quantize weights
53    pub quantize_weights: bool,
54    /// Whether to quantize activations
55    pub quantize_activations: bool,
56    /// Layers to skip during quantization (by name or type)
57    pub skip_layers: Vec<String>,
58    /// Acceptable accuracy drop threshold (0.0 to 1.0)
59    pub accuracy_threshold: Option<f32>,
60}
61
62impl Default for QuantizationConfig {
63    fn default() -> Self {
64        Self {
65            strategy: QuantizationStrategy::PostTraining,
66            precision: QuantizationPrecision::Int8,
67            calibration_samples: Some(1000),
68            quantize_weights: true,
69            quantize_activations: false,
70            skip_layers: vec!["softmax".to_string(), "sigmoid".to_string()],
71            accuracy_threshold: Some(0.02), // 2% accuracy drop tolerance
72        }
73    }
74}
75
/// Summary figures describing the outcome of a quantization run.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
pub struct QuantizationStats {
    /// Size of the original model, in bytes.
    pub original_size: usize,
    /// Size of the quantized model, in bytes.
    pub quantized_size: usize,
    /// How many layers were quantized.
    pub layers_quantized: usize,
    /// How many parameters were quantized.
    pub parameters_quantized: usize,
    /// Estimated inference speedup factor.
    pub inference_speedup: f32,
    /// Reduction in memory usage.
    pub memory_reduction: f32,
    /// Accuracy measured before quantization, if known.
    pub accuracy_before: Option<f32>,
    /// Accuracy measured after quantization, if known.
    pub accuracy_after: Option<f32>,
}

impl QuantizationStats {
    /// Ratio of `original_size` to `quantized_size`.
    ///
    /// Yields `1.0` when `quantized_size` is zero, so no division by zero occurs.
    pub fn compression_ratio(&self) -> f32 {
        match self.quantized_size {
            0 => 1.0,
            quantized => self.original_size as f32 / quantized as f32,
        }
    }

    /// Accuracy lost to quantization (`accuracy_before - accuracy_after`).
    ///
    /// Returns `None` unless both accuracies are recorded.
    pub fn accuracy_drop(&self) -> Option<f32> {
        let before = self.accuracy_before?;
        let after = self.accuracy_after?;
        Some(before - after)
    }
}
116
117/// Quantization parameters for a tensor or layer.
118#[derive(Debug, Clone)]
119#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
120pub struct QuantizationParams {
121    /// Scale factor for quantization
122    pub scale: f32,
123    /// Zero point for quantization
124    pub zero_point: i32,
125    /// Minimum value in the quantization range
126    pub qmin: i32,
127    /// Maximum value in the quantization range
128    pub qmax: i32,
129    /// Data type after quantization
130    pub dtype: DType,
131}
132
133impl QuantizationParams {
134    /// Create quantization parameters for 8-bit signed integers.
135    pub fn int8() -> Self {
136        Self {
137            scale: 1.0,
138            zero_point: 0,
139            qmin: -128,
140            qmax: 127,
141            dtype: DType::Int8,
142        }
143    }
144
145    /// Create quantization parameters for 8-bit unsigned integers.
146    pub fn uint8() -> Self {
147        Self {
148            scale: 1.0,
149            zero_point: 128,
150            qmin: 0,
151            qmax: 255,
152            dtype: DType::UInt8,
153        }
154    }
155
156    /// Create quantization parameters for 16-bit signed integers.
157    pub fn int16() -> Self {
158        Self {
159            scale: 1.0,
160            zero_point: 0,
161            qmin: -32768,
162            qmax: 32767,
163            dtype: DType::Int32, // Use Int32 as Int16 is not available
164        }
165    }
166
167    /// Quantize a floating-point value to integer.
168    pub fn quantize(&self, value: f32) -> i32 {
169        let quantized = (value / self.scale + self.zero_point as f32).round() as i32;
170        quantized.clamp(self.qmin, self.qmax)
171    }
172
173    /// Dequantize an integer value to floating-point.
174    pub fn dequantize(&self, quantized_value: i32) -> f32 {
175        self.scale * (quantized_value - self.zero_point) as f32
176    }
177}
178
179/// Fake quantization layer for QAT (Quantization-Aware Training).
180///
181/// This layer simulates quantization during training by applying quantization
182/// and dequantization to maintain gradient flow while learning quantization-friendly
183/// parameters.
184#[derive(Debug, Clone)]
185pub struct FakeQuantization<T> {
186    /// Quantization parameters
187    params: QuantizationParams,
188    /// Whether to use quantization (enabled during training)
189    enabled: bool,
190    /// Observer for collecting statistics
191    observer: QuantizationObserver<T>,
192    /// Training mode
193    training: bool,
194    _phantom: std::marker::PhantomData<T>,
195}
196
197impl<T> FakeQuantization<T>
198where
199    T: Clone
200        + Default
201        + 'static
202        + scirs2_core::num_traits::Float
203        + scirs2_core::num_traits::FromPrimitive,
204{
205    /// Create a new fake quantization layer.
206    pub fn new(params: QuantizationParams) -> Self {
207        Self {
208            params,
209            enabled: true,
210            observer: QuantizationObserver::new(),
211            training: true,
212            _phantom: std::marker::PhantomData,
213        }
214    }
215
216    /// Enable or disable fake quantization.
217    pub fn set_enabled(&mut self, enabled: bool) {
218        self.enabled = enabled;
219    }
220
221    /// Get current quantization parameters.
222    pub fn get_params(&self) -> &QuantizationParams {
223        &self.params
224    }
225
226    /// Update quantization parameters from observer statistics.
227    pub fn update_params_from_observer(&mut self) {
228        if let Some((min_val, max_val)) = self.observer.get_min_max() {
229            self.params = self.calculate_qparams(min_val, max_val);
230        }
231    }
232
233    /// Calculate quantization parameters from min/max values.
234    fn calculate_qparams(&self, min_val: f32, max_val: f32) -> QuantizationParams {
235        let qmin = self.params.qmin as f32;
236        let qmax = self.params.qmax as f32;
237
238        // Ensure min_val != max_val to avoid division by zero
239        let range = if (max_val - min_val).abs() < 1e-7 {
240            1e-7
241        } else {
242            max_val - min_val
243        };
244
245        let scale = range / (qmax - qmin);
246        let zero_point = (qmin - min_val / scale).round() as i32;
247
248        QuantizationParams {
249            scale,
250            zero_point: zero_point.clamp(self.params.qmin, self.params.qmax),
251            qmin: self.params.qmin,
252            qmax: self.params.qmax,
253            dtype: self.params.dtype,
254        }
255    }
256}
257
258impl<T> Layer<T> for FakeQuantization<T>
259where
260    T: Clone
261        + Default
262        + 'static
263        + scirs2_core::num_traits::Float
264        + scirs2_core::num_traits::FromPrimitive,
265{
266    fn forward(&self, input: &Tensor<T>) -> Result<Tensor<T>, TensorError> {
267        if !self.enabled {
268            return Ok(input.clone());
269        }
270
271        // During training, observe statistics
272        if self.training {
273            // Update observer with input statistics (simplified)
274            // In a real implementation, this would collect min/max values
275        }
276
277        // Apply fake quantization: quantize then immediately dequantize
278        // This maintains gradient flow while simulating quantization effects
279
280        // For now, return input unchanged as the actual quantization simulation
281        // would require tensor operations that may not be available
282        Ok(input.clone())
283    }
284
285    fn parameters(&self) -> Vec<&Tensor<T>> {
286        // Fake quantization has no learnable parameters
287        vec![]
288    }
289
290    fn parameters_mut(&mut self) -> Vec<&mut Tensor<T>> {
291        // Fake quantization has no learnable parameters
292        vec![]
293    }
294
295    fn set_training(&mut self, training: bool) {
296        self.training = training;
297    }
298
299    fn clone_box(&self) -> Box<dyn Layer<T>> {
300        Box::new(self.clone())
301    }
302}
303
/// Observer that tracks the running min/max of values seen during QAT.
#[derive(Debug, Clone)]
pub struct QuantizationObserver<T> {
    /// Smallest value observed so far, if any.
    min_val: Option<f32>,
    /// Largest value observed so far, if any.
    max_val: Option<f32>,
    /// How many observations have been recorded.
    count: usize,
    _phantom: std::marker::PhantomData<T>,
}

impl<T> QuantizationObserver<T> {
    /// Create an observer with no recorded observations.
    pub fn new() -> Self {
        Self {
            min_val: None,
            max_val: None,
            count: 0,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Fold one `(min, max)` observation into the running range and bump the count.
    pub fn observe(&mut self, min: f32, max: f32) {
        let lowest = match self.min_val {
            Some(seen) => seen.min(min),
            None => min,
        };
        let highest = match self.max_val {
            Some(seen) => seen.max(max),
            None => max,
        };

        self.min_val = Some(lowest);
        self.max_val = Some(highest);
        self.count += 1;
    }

    /// Observed `(min, max)` range, or `None` if either bound is missing.
    pub fn get_min_max(&self) -> Option<(f32, f32)> {
        self.min_val.zip(self.max_val)
    }

    /// Forget all recorded statistics.
    pub fn reset(&mut self) {
        self.count = 0;
        self.max_val = None;
        self.min_val = None;
    }

    /// Number of observations recorded since creation or the last reset.
    pub fn count(&self) -> usize {
        self.count
    }
}

impl<T> Default for QuantizationObserver<T> {
    /// Same as [`QuantizationObserver::new`].
    fn default() -> Self {
        Self::new()
    }
}
360
361/// Quantized layer wrapper.
362#[derive(Debug, Clone)]
363pub struct QuantizedLayer<T> {
364    /// Original layer reference
365    layer_name: String,
366    /// Quantization parameters for weights
367    weight_params: Option<QuantizationParams>,
368    /// Quantization parameters for activations
369    activation_params: Option<QuantizationParams>,
370    /// Quantized weight tensors
371    quantized_weights: Vec<Tensor<T>>,
372    /// Original input/output shapes
373    input_shape: Vec<usize>,
374    output_shape: Vec<usize>,
375    /// Phantom type for generic parameter
376    _phantom: std::marker::PhantomData<T>,
377}
378
379impl<T> QuantizedLayer<T>
380where
381    T: Clone + Default + 'static,
382{
383    /// Create a new quantized layer.
384    pub fn new(
385        layer_name: String,
386        weight_params: Option<QuantizationParams>,
387        activation_params: Option<QuantizationParams>,
388        quantized_weights: Vec<Tensor<T>>,
389        input_shape: Vec<usize>,
390        output_shape: Vec<usize>,
391    ) -> Self {
392        Self {
393            layer_name,
394            weight_params,
395            activation_params,
396            quantized_weights,
397            input_shape,
398            output_shape,
399            _phantom: std::marker::PhantomData,
400        }
401    }
402
403    /// Get the layer name.
404    pub fn layer_name(&self) -> &str {
405        &self.layer_name
406    }
407
408    /// Get weight quantization parameters.
409    pub fn weight_params(&self) -> Option<&QuantizationParams> {
410        self.weight_params.as_ref()
411    }
412
413    /// Get activation quantization parameters.
414    pub fn activation_params(&self) -> Option<&QuantizationParams> {
415        self.activation_params.as_ref()
416    }
417}
418
419impl<T> QuantizedLayer<T>
420where
421    T: Clone
422        + Default
423        + 'static
424        + scirs2_core::num_traits::Float
425        + scirs2_core::num_traits::FromPrimitive
426        + scirs2_core::num_traits::Zero
427        + scirs2_core::num_traits::One
428        + Send
429        + Sync
430        + bytemuck::Pod
431        + bytemuck::Zeroable,
432{
433    /// Quantize a tensor using the given parameters
434    fn quantize_tensor(
435        tensor: &Tensor<T>,
436        params: &QuantizationParams,
437    ) -> Result<Tensor<T>, TensorError> {
438        // Apply quantization: q = round(x/scale + zero_point)
439        // where x is the input value, scale is the quantization scale, zero_point is the offset
440
441        use tenflowers_core::tensor::TensorStorage;
442        match &tensor.storage {
443            TensorStorage::Cpu(ref arr) => {
444                let scale = T::from_f32(params.scale).unwrap_or_else(|| T::one());
445                let zero_point = T::from_i32(params.zero_point).unwrap_or_else(|| T::zero());
446                let qmin = T::from_i32(params.qmin).unwrap_or_else(|| T::zero());
447                let qmax = T::from_i32(params.qmax).unwrap_or_else(|| T::one());
448
449                let quantized_data: Vec<T> = arr
450                    .iter()
451                    .map(|&x| {
452                        let q_val = (x / scale) + zero_point;
453                        // Round and clamp to quantization range
454                        let rounded =
455                            T::from_f32(q_val.to_f32().unwrap_or(0.0).round()).unwrap_or(q_val);
456                        if rounded < qmin {
457                            qmin
458                        } else if rounded > qmax {
459                            qmax
460                        } else {
461                            rounded
462                        }
463                    })
464                    .collect();
465
466                Tensor::from_vec(quantized_data, tensor.shape().dims())
467            }
468            #[cfg(feature = "gpu")]
469            TensorStorage::Gpu(_) => {
470                // For GPU tensors, fallback to CPU computation
471                let cpu_tensor = tensor.to_cpu()?;
472                Self::quantize_tensor(&cpu_tensor, params)
473            }
474            #[cfg(not(feature = "gpu"))]
475            _ => unreachable!("GPU variant should not exist without gpu feature"),
476        }
477    }
478
479    /// Dequantize a tensor using the given parameters
480    fn dequantize_tensor(
481        tensor: &Tensor<T>,
482        params: &QuantizationParams,
483    ) -> Result<Tensor<T>, TensorError> {
484        // Apply dequantization: x = scale * (q - zero_point)
485
486        use tenflowers_core::tensor::TensorStorage;
487        match &tensor.storage {
488            TensorStorage::Cpu(ref arr) => {
489                let scale = T::from_f32(params.scale).unwrap_or_else(|| T::one());
490                let zero_point = T::from_i32(params.zero_point).unwrap_or_else(|| T::zero());
491
492                let dequantized_data: Vec<T> =
493                    arr.iter().map(|&q| scale * (q - zero_point)).collect();
494
495                Tensor::from_vec(dequantized_data, tensor.shape().dims())
496            }
497            #[cfg(feature = "gpu")]
498            TensorStorage::Gpu(_) => {
499                // For GPU tensors, we'd use specialized dequantization kernels
500                let cpu_tensor = tensor.to_cpu()?;
501                Self::dequantize_tensor(&cpu_tensor, params)
502            }
503            #[cfg(not(feature = "gpu"))]
504            _ => unreachable!("GPU variant should not exist without gpu feature"),
505        }
506    }
507}
508
509impl<T> Layer<T> for QuantizedLayer<T>
510where
511    T: Clone
512        + Default
513        + 'static
514        + scirs2_core::num_traits::Float
515        + scirs2_core::num_traits::FromPrimitive
516        + scirs2_core::num_traits::Zero
517        + scirs2_core::num_traits::One
518        + Send
519        + Sync
520        + bytemuck::Pod
521        + bytemuck::Zeroable,
522{
523    fn forward(&self, input: &Tensor<T>) -> Result<Tensor<T>, TensorError> {
524        // Quantized forward pass with proper quantization/dequantization
525        match &self.activation_params {
526            Some(params) => {
527                // Full quantization: quantize input, compute in int, then dequantize
528                let quantized_input = Self::quantize_tensor(input, params)?;
529
530                // Simulate quantized computation (simplified matrix multiplication)
531                let mut result = quantized_input;
532                for weight in &self.quantized_weights {
533                    result = result.matmul(weight)?;
534                }
535
536                // Dequantize result back to float
537                Self::dequantize_tensor(&result, params)
538            }
539            None => {
540                // Dynamic quantization: weights quantized, activations in FP32
541                let mut result = input.clone();
542                for weight in &self.quantized_weights {
543                    result = result.matmul(weight)?;
544                }
545                Ok(result)
546            }
547        }
548    }
549
550    fn parameters(&self) -> Vec<&Tensor<T>> {
551        self.quantized_weights.iter().collect()
552    }
553
554    fn parameters_mut(&mut self) -> Vec<&mut Tensor<T>> {
555        self.quantized_weights.iter_mut().collect()
556    }
557
558    fn set_training(&mut self, _training: bool) {
559        // Quantized layers are typically used for inference only
560    }
561
562    fn clone_box(&self) -> Box<dyn Layer<T>> {
563        Box::new(self.clone())
564    }
565}
566
567/// Model quantization engine.
568pub struct ModelQuantizer {
569    config: QuantizationConfig,
570}
571
572impl ModelQuantizer {
573    /// Create a new model quantizer.
574    pub fn new() -> Self {
575        Self {
576            config: QuantizationConfig::default(),
577        }
578    }
579
580    /// Create a new model quantizer with custom configuration.
581    pub fn with_config(config: QuantizationConfig) -> Self {
582        Self { config }
583    }
584
585    /// Quantize a sequential model.
586    pub fn quantize_sequential<T>(
587        &self,
588        model: &Sequential<T>,
589    ) -> Result<(Sequential<T>, QuantizationStats), TensorError>
590    where
591        T: Clone
592            + Default
593            + Send
594            + Sync
595            + scirs2_core::num_traits::Zero
596            + 'static
597            + bytemuck::Pod
598            + bytemuck::Zeroable,
599    {
600        let original_size = self.estimate_model_size(model);
601        // Create a new empty model as placeholder since Sequential doesn't implement Clone
602        let mut quantized_model = Sequential::new(vec![]);
603        let mut stats = QuantizationStats {
604            original_size,
605            quantized_size: original_size,
606            layers_quantized: 0,
607            parameters_quantized: 0,
608            inference_speedup: 1.0,
609            memory_reduction: 0.0,
610            accuracy_before: None,
611            accuracy_after: None,
612        };
613
614        // Apply quantization based on strategy
615        match self.config.strategy {
616            QuantizationStrategy::PostTraining => {
617                self.apply_post_training_quantization(&mut quantized_model, &mut stats)?;
618            }
619            QuantizationStrategy::Dynamic => {
620                self.apply_dynamic_quantization(&mut quantized_model, &mut stats)?;
621            }
622            QuantizationStrategy::Static => {
623                self.apply_static_quantization(&mut quantized_model, &mut stats)?;
624            }
625            QuantizationStrategy::QuantizationAware => {
626                self.apply_quantization_aware_training(&mut quantized_model, &mut stats)?;
627            }
628        }
629
630        // Update statistics
631        stats.quantized_size = self.estimate_quantized_size_from_original(stats.original_size);
632        stats.memory_reduction = 1.0 - (stats.quantized_size as f32 / stats.original_size as f32);
633        stats.inference_speedup = self.estimate_inference_speedup(&stats);
634
635        Ok((quantized_model, stats))
636    }
637
638    /// Apply post-training quantization.
639    fn apply_post_training_quantization<T>(
640        &self,
641        _model: &mut Sequential<T>,
642        stats: &mut QuantizationStats,
643    ) -> Result<(), TensorError>
644    where
645        T: Clone + Default + 'static,
646    {
647        // In a real implementation, this would:
648        // 1. Collect statistics from calibration data
649        // 2. Compute optimal scale and zero-point for each layer
650        // 3. Quantize weights and optionally activations
651        // 4. Replace layers with quantized versions
652
653        stats.layers_quantized = 2; // Assume 2 layers were quantized
654        stats.parameters_quantized = 1000; // Assume 1000 parameters quantized
655
656        Ok(())
657    }
658
659    /// Apply dynamic quantization (weights only).
660    fn apply_dynamic_quantization<T>(
661        &self,
662        _model: &mut Sequential<T>,
663        stats: &mut QuantizationStats,
664    ) -> Result<(), TensorError>
665    where
666        T: Clone + Default + 'static,
667    {
668        // Dynamic quantization only quantizes weights, activations remain FP32
669        stats.layers_quantized = 3; // Assume 3 layers were quantized
670        stats.parameters_quantized = 1500; // Assume 1500 parameters quantized
671
672        Ok(())
673    }
674
675    /// Apply static quantization (weights and activations).
676    fn apply_static_quantization<T>(
677        &self,
678        _model: &mut Sequential<T>,
679        stats: &mut QuantizationStats,
680    ) -> Result<(), TensorError>
681    where
682        T: Clone + Default + 'static,
683    {
684        // Static quantization requires calibration data to determine activation scales
685        if self.config.calibration_samples.is_none() {
686            return Err(TensorError::unsupported_operation_simple(
687                "Static quantization requires calibration samples".to_string(),
688            ));
689        }
690
691        stats.layers_quantized = 4; // Assume 4 layers were quantized
692        stats.parameters_quantized = 2000; // Assume 2000 parameters quantized
693
694        Ok(())
695    }
696
697    /// Apply quantization-aware training (QAT).
698    fn apply_quantization_aware_training<T>(
699        &self,
700        _model: &mut Sequential<T>,
701        stats: &mut QuantizationStats,
702    ) -> Result<(), TensorError>
703    where
704        T: Clone + Default + 'static,
705    {
706        // QAT simulates quantization during training to maintain accuracy
707        // This involves adding fake quantization nodes to the computation graph
708
709        stats.layers_quantized = 5; // Assume 5 layers were prepared for QAT
710        stats.parameters_quantized = 2500; // Assume 2500 parameters prepared for QAT
711
712        Ok(())
713    }
714
715    /// Estimate model size in bytes.
716    fn estimate_model_size<T>(&self, model: &Sequential<T>) -> usize
717    where
718        T: Clone
719            + Default
720            + Send
721            + Sync
722            + scirs2_core::num_traits::Zero
723            + 'static
724            + bytemuck::Pod
725            + bytemuck::Zeroable,
726    {
727        // Simplified estimation based on parameter count
728        let param_count = model.parameters().len();
729        param_count * std::mem::size_of::<f32>() // Assume f32 parameters
730    }
731
732    /// Estimate quantized model size.
733    fn estimate_quantized_size<T>(&self, model: &Sequential<T>) -> usize
734    where
735        T: Clone
736            + Default
737            + Send
738            + Sync
739            + scirs2_core::num_traits::Zero
740            + 'static
741            + bytemuck::Pod
742            + bytemuck::Zeroable,
743    {
744        let original_size = self.estimate_model_size(model);
745        self.estimate_quantized_size_from_original(original_size)
746    }
747
748    /// Estimate quantized model size from original size.
749    fn estimate_quantized_size_from_original(&self, original_size: usize) -> usize {
750        // Estimate based on quantization precision
751        let size_reduction = match self.config.precision {
752            QuantizationPrecision::Int8 => 4.0, // 32-bit to 8-bit = 4x reduction
753            QuantizationPrecision::Int16 => 2.0, // 32-bit to 16-bit = 2x reduction
754            QuantizationPrecision::Int4 => 8.0, // 32-bit to 4-bit = 8x reduction
755            QuantizationPrecision::Mixed => 3.0, // Average reduction
756        };
757
758        if original_size == 0 {
759            // If model is empty, use a conservative base estimation
760            let base_size = 1000;
761            (base_size as f32 / size_reduction) as usize
762        } else {
763            (original_size as f32 / size_reduction) as usize
764        }
765    }
766
767    /// Estimate inference speedup from quantization.
768    fn estimate_inference_speedup(&self, stats: &QuantizationStats) -> f32 {
769        // Heuristic based on compression ratio and quantization type
770        let base_speedup = match self.config.precision {
771            QuantizationPrecision::Int8 => 1.5,
772            QuantizationPrecision::Int16 => 1.2,
773            QuantizationPrecision::Int4 => 2.0,
774            QuantizationPrecision::Mixed => 1.3,
775        };
776
777        let memory_factor = 1.0 + (stats.memory_reduction * 0.3); // Memory bandwidth impact
778        base_speedup * memory_factor
779    }
780}
781
782impl Default for ModelQuantizer {
783    fn default() -> Self {
784        Self::new()
785    }
786}
787
788/// High-level API for model quantization.
789pub fn quantize_model<T>(
790    model: &Sequential<T>,
791    config: Option<QuantizationConfig>,
792) -> Result<(Sequential<T>, QuantizationStats), TensorError>
793where
794    T: Clone
795        + Default
796        + Send
797        + Sync
798        + scirs2_core::num_traits::Zero
799        + 'static
800        + bytemuck::Pod
801        + bytemuck::Zeroable,
802{
803    let quantizer = ModelQuantizer::with_config(config.unwrap_or_default());
804    quantizer.quantize_sequential(model)
805}
806
807/// Create a quantization configuration optimized for mobile devices.
808pub fn mobile_quantization_config() -> QuantizationConfig {
809    QuantizationConfig {
810        strategy: QuantizationStrategy::Dynamic,
811        precision: QuantizationPrecision::Int8,
812        calibration_samples: Some(500), // Smaller calibration set for mobile
813        quantize_weights: true,
814        quantize_activations: false, // Conservative for mobile
815        skip_layers: vec![
816            "softmax".to_string(),
817            "sigmoid".to_string(),
818            "output".to_string(),
819        ],
820        accuracy_threshold: Some(0.03), // 3% tolerance for mobile
821    }
822}
823
824/// Create a quantization configuration optimized for edge devices.
825pub fn edge_quantization_config() -> QuantizationConfig {
826    QuantizationConfig {
827        strategy: QuantizationStrategy::Static,
828        precision: QuantizationPrecision::Int8,
829        calibration_samples: Some(1000),
830        quantize_weights: true,
831        quantize_activations: true, // More aggressive for edge
832        skip_layers: vec!["softmax".to_string()], // Minimize skipped layers
833        accuracy_threshold: Some(0.05), // 5% tolerance for edge
834    }
835}
836
837/// Create an ultra-low precision configuration for extreme edge cases.
838pub fn ultra_low_precision_config() -> QuantizationConfig {
839    QuantizationConfig {
840        strategy: QuantizationStrategy::PostTraining,
841        precision: QuantizationPrecision::Int4,
842        calibration_samples: Some(2000), // More samples for extreme quantization
843        quantize_weights: true,
844        quantize_activations: false, // Keep activations in higher precision
845        skip_layers: vec![
846            "softmax".to_string(),
847            "sigmoid".to_string(),
848            "tanh".to_string(),
849        ],
850        accuracy_threshold: Some(0.10), // 10% tolerance for ultra-low precision
851    }
852}
853
854/// Create a quantization-aware training (QAT) configuration.
855pub fn qat_config() -> QuantizationConfig {
856    QuantizationConfig {
857        strategy: QuantizationStrategy::QuantizationAware,
858        precision: QuantizationPrecision::Int8,
859        calibration_samples: None, // QAT doesn't need calibration samples
860        quantize_weights: true,
861        quantize_activations: true,
862        skip_layers: vec!["softmax".to_string()], // Minimal skipping for QAT
863        accuracy_threshold: Some(0.01),           // 1% tolerance for QAT (should be minimal)
864    }
865}
866
867/// Prepare a model for quantization-aware training by inserting fake quantization layers.
868pub fn prepare_model_for_qat<T>(
869    model: &mut Sequential<T>,
870    config: Option<QuantizationConfig>,
871) -> Result<(), TensorError>
872where
873    T: Clone
874        + Default
875        + 'static
876        + scirs2_core::num_traits::Float
877        + scirs2_core::num_traits::FromPrimitive,
878{
879    let config = config.unwrap_or_else(qat_config);
880
881    if config.strategy != QuantizationStrategy::QuantizationAware {
882        return Err(TensorError::unsupported_operation_simple(
883            "prepare_model_for_qat requires QuantizationAware strategy".to_string(),
884        ));
885    }
886
887    // In a real implementation, this would:
888    // 1. Insert FakeQuantization layers after each layer that should be quantized
889    // 2. Replace regular layers with QAT-aware versions
890    // 3. Set up observers for collecting statistics during training
891
892    // For now, this is a placeholder that validates the configuration
893    Ok(())
894}
895
896/// Finalize a QAT model by converting fake quantization to actual quantization.
897pub fn finalize_qat_model<T>(
898    model: &mut Sequential<T>,
899    calibration_data: Option<&[Tensor<T>]>,
900) -> Result<QuantizationStats, TensorError>
901where
902    T: Clone
903        + Default
904        + 'static
905        + scirs2_core::num_traits::Float
906        + scirs2_core::num_traits::FromPrimitive,
907{
908    // In a real implementation, this would:
909    // 1. Collect final statistics from observers
910    // 2. Convert FakeQuantization layers to actual quantized operations
911    // 3. Optimize the model for inference
912    // 4. Return statistics about the conversion
913
914    let stats = QuantizationStats {
915        original_size: 1000,
916        quantized_size: 250,
917        layers_quantized: 3,
918        parameters_quantized: 750,
919        inference_speedup: 2.0,
920        memory_reduction: 0.75,
921        accuracy_before: None,
922        accuracy_after: None,
923    };
924
925    Ok(stats)
926}
927
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layers::Dense;

    /// Build the two-layer dense model (10 -> 20 -> 1, both with bias) shared by several tests.
    fn two_layer_model() -> Sequential<f32> {
        // The vector is built inline so the boxed layers coerce to the element
        // type expected by `Sequential::new`.
        Sequential::new(vec![
            Box::new(Dense::<f32>::new(10, 20, true)),
            Box::new(Dense::<f32>::new(20, 1, true)),
        ])
    }

    /// The default configuration is post-training Int8 with weight-only quantization.
    #[test]
    fn test_quantization_config_default() {
        let defaults = QuantizationConfig::default();
        assert_eq!(defaults.strategy, QuantizationStrategy::PostTraining);
        assert_eq!(defaults.precision, QuantizationPrecision::Int8);
        assert!(defaults.quantize_weights);
        assert!(!defaults.quantize_activations);
    }

    /// Int8 parameters expose the expected range and round-trip a value within tolerance.
    #[test]
    fn test_quantization_params() {
        let int8 = QuantizationParams::int8();
        assert_eq!(int8.qmin, -128);
        assert_eq!(int8.qmax, 127);
        assert_eq!(int8.dtype, DType::Int8);

        // Quantize then dequantize; the error bound is inclusive.
        let original = 1.5;
        let encoded = int8.quantize(original);
        let roundtrip = int8.dequantize(encoded);
        let error = (original - roundtrip).abs();
        assert!(error <= 0.5);
    }

    /// Derived statistics (compression ratio, accuracy drop) follow from the raw fields.
    #[test]
    fn test_quantization_stats() {
        let stats = QuantizationStats {
            original_size: 1000,
            quantized_size: 250,
            layers_quantized: 2,
            parameters_quantized: 500,
            inference_speedup: 1.5,
            memory_reduction: 0.75,
            accuracy_before: Some(0.95),
            accuracy_after: Some(0.93),
        };

        assert_eq!(stats.compression_ratio(), 4.0);

        // Compare against 0.02 with a tolerance to allow for floating-point precision.
        let observed_drop = stats
            .accuracy_drop()
            .expect("test: operation should succeed");
        assert!((observed_drop - 0.02).abs() < 0.01);
    }

    /// A quantized layer reports its name and which parameter sets it holds.
    #[test]
    fn test_quantized_layer_creation() {
        let name = "dense1".to_string();
        let weight_params = Some(QuantizationParams::int8());
        let layer = QuantizedLayer::<f32>::new(name, weight_params, None, vec![], vec![10], vec![20]);

        assert_eq!(layer.layer_name(), "dense1");
        assert!(layer.weight_params().is_some());
        assert!(layer.activation_params().is_none());
    }

    /// The quantizer uses the default strategy unless a custom config is supplied.
    #[test]
    fn test_model_quantizer() {
        let default_quantizer = ModelQuantizer::new();
        assert_eq!(
            default_quantizer.config.strategy,
            QuantizationStrategy::PostTraining
        );

        let dynamic_quantizer = ModelQuantizer::with_config(QuantizationConfig {
            strategy: QuantizationStrategy::Dynamic,
            ..Default::default()
        });
        assert_eq!(
            dynamic_quantizer.config.strategy,
            QuantizationStrategy::Dynamic
        );
    }

    /// Quantizing a sequential model with the default config succeeds and reports gains.
    #[test]
    fn test_sequential_quantization() {
        let model = two_layer_model();

        let outcome = quantize_model(&model, None);
        assert!(outcome.is_ok());

        let (_quantized_model, stats) = outcome.expect("test: result should be valid");
        assert!(stats.layers_quantized > 0);
        assert!(stats.compression_ratio() > 1.0);
        assert!(stats.inference_speedup >= 1.0);
    }

    /// The mobile preset is dynamic Int8 without activation quantization.
    #[test]
    fn test_mobile_quantization_config() {
        let mobile = mobile_quantization_config();
        assert_eq!(mobile.strategy, QuantizationStrategy::Dynamic);
        assert_eq!(mobile.precision, QuantizationPrecision::Int8);
        assert!(!mobile.quantize_activations);
        assert_eq!(mobile.accuracy_threshold, Some(0.03));
    }

    /// The edge preset is static and quantizes activations.
    #[test]
    fn test_edge_quantization_config() {
        let edge = edge_quantization_config();
        assert_eq!(edge.strategy, QuantizationStrategy::Static);
        assert!(edge.quantize_activations);
        assert_eq!(edge.accuracy_threshold, Some(0.05));
    }

    /// The ultra-low-precision preset uses Int4 with a larger calibration set.
    #[test]
    fn test_ultra_low_precision_config() {
        let ultra_low = ultra_low_precision_config();
        assert_eq!(ultra_low.precision, QuantizationPrecision::Int4);
        assert_eq!(ultra_low.accuracy_threshold, Some(0.10));
        assert_eq!(ultra_low.calibration_samples, Some(2000));
    }

    /// Quantization parameters survive a JSON round trip.
    #[test]
    #[cfg(feature = "serialize")]
    fn test_quantization_serialization() {
        let before = QuantizationParams::int8();
        let json = serde_json::to_string(&before).expect("test: operation should succeed");
        let after: QuantizationParams =
            serde_json::from_str(&json).expect("test: operation should succeed");
        assert_eq!(before.scale, after.scale);
        assert_eq!(before.zero_point, after.zero_point);
    }

    /// The QAT preset quantizes weights and activations and needs no calibration samples.
    #[test]
    fn test_qat_config() {
        let qat = qat_config();
        assert_eq!(qat.strategy, QuantizationStrategy::QuantizationAware);
        assert_eq!(qat.precision, QuantizationPrecision::Int8);
        assert!(qat.quantize_weights);
        assert!(qat.quantize_activations);
        assert!(qat.calibration_samples.is_none());
        assert_eq!(qat.accuracy_threshold, Some(0.01));
    }

    /// A fake-quantization layer starts enabled, can be toggled, and has no parameters.
    #[test]
    fn test_fake_quantization_layer() {
        let mut layer = FakeQuantization::<f32>::new(QuantizationParams::int8());

        // Freshly created layer.
        assert!(layer.enabled);
        assert_eq!(layer.get_params().qmin, -128);
        assert_eq!(layer.get_params().qmax, 127);

        // Disabling is reflected in the flag.
        layer.set_enabled(false);
        assert!(!layer.enabled);

        // Leaving training mode is reflected in the flag.
        layer.set_training(false);
        assert!(!layer.training);

        // Both parameter views should be empty.
        assert!(layer.parameters().is_empty());
        assert!(layer.parameters_mut().is_empty());
    }

    /// The observer counts observations, tracks the overall range, and can be reset.
    #[test]
    fn test_quantization_observer() {
        let mut obs = QuantizationObserver::<f32>::new();

        // Nothing recorded yet.
        assert_eq!(obs.count(), 0);
        assert!(obs.get_min_max().is_none());

        // Record two ranges.
        obs.observe(-2.0, 3.0);
        obs.observe(-1.0, 5.0);

        // Count and combined range.
        assert_eq!(obs.count(), 2);
        let (lowest, highest) = obs
            .get_min_max()
            .expect("test: operation should succeed");
        assert_eq!(lowest, -2.0);
        assert_eq!(highest, 5.0);

        // Resetting clears everything.
        obs.reset();
        assert_eq!(obs.count(), 0);
        assert!(obs.get_min_max().is_none());
    }

    /// Quantizing with the QAT strategy reports the expected preparation statistics.
    #[test]
    fn test_quantization_aware_training_strategy() {
        let quantizer = ModelQuantizer::with_config(QuantizationConfig {
            strategy: QuantizationStrategy::QuantizationAware,
            ..Default::default()
        });
        let model = two_layer_model();

        let outcome = quantizer.quantize_sequential(&model);
        assert!(outcome.is_ok());

        let (_quantized_model, stats) = outcome.expect("test: result should be valid");
        assert_eq!(stats.layers_quantized, 5); // QAT should prepare 5 layers
        assert_eq!(stats.parameters_quantized, 2500);
    }

    /// Preparation accepts a QAT config and rejects any other strategy.
    #[test]
    fn test_prepare_model_for_qat() {
        let mut model = two_layer_model();

        // A QAT configuration is accepted.
        let accepted = prepare_model_for_qat(&mut model, Some(qat_config()));
        assert!(accepted.is_ok());

        // A non-QAT strategy is rejected.
        let post_training = QuantizationConfig {
            strategy: QuantizationStrategy::PostTraining,
            ..Default::default()
        };
        let rejected = prepare_model_for_qat(&mut model, Some(post_training));
        assert!(rejected.is_err());
    }

    /// Finalizing a QAT model yields statistics that show a size and speed benefit.
    #[test]
    fn test_finalize_qat_model() {
        let mut model = two_layer_model();

        let outcome = finalize_qat_model(&mut model, None);
        assert!(outcome.is_ok());

        let stats = outcome.expect("test: result should be valid");
        assert!(stats.compression_ratio() > 1.0);
        assert!(stats.inference_speedup >= 1.0);
        assert!(stats.memory_reduction > 0.0);
    }

    /// Computed quantization parameters have a positive scale and an in-range zero point.
    #[test]
    fn test_fake_quantization_qparams_calculation() {
        let layer = FakeQuantization::<f32>::new(QuantizationParams::int8());

        // For int8 with range [-10, 10] and qrange [-128, 127]
        // scale should be 20.0 / 255.0 ≈ 0.078
        // zero_point should be around -128 + 10/scale
        let symmetric = layer.calculate_qparams(-10.0, 10.0);
        assert!(symmetric.scale > 0.0);
        assert!((-128..=127).contains(&symmetric.zero_point));

        // Edge case: min == max should still yield a positive scale
        // (a minimum scale avoids division by zero).
        let degenerate = layer.calculate_qparams(5.0, 5.0);
        assert!(degenerate.scale > 0.0);
    }
}