Skip to main content

scirs2_neural/quantization/
mod.rs

1//! Quantization support for neural networks
2//!
3//! This module provides comprehensive quantization capabilities including:
4//! - Post-training quantization (PTQ)
5//! - Quantization-aware training (QAT)
6//! - Mixed bit-width operations
7//! - Dynamic and static quantization schemes
8//! - INT8 symmetric/asymmetric quantization (per-tensor and per-channel)
9//! - INT4 group quantization with nibble packing
10//! - GPTQ Hessian-based optimal rounding
11//! - Calibration observers (MinMax, Percentile, MovingAverage)
12
13pub mod calibration;
14pub mod gptq;
15pub mod int4;
16pub mod int8;
17
18use crate::error::{Error, Result};
19use scirs2_core::ndarray::ArrayStatCompat;
20use scirs2_core::ndarray::{ArrayD, ArrayView, Zip};
21use serde::{Deserialize, Serialize};
22use std::collections::HashMap;
23
24/// Quantization configuration
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct QuantizationConfig {
27    /// Number of bits for quantization
28    pub bits: u8,
29    /// Whether to use signed quantization
30    pub signed: bool,
31    /// Quantization scheme
32    pub scheme: QuantizationScheme,
33    /// Calibration dataset size for PTQ
34    pub calibration_size: usize,
35    /// Quantization mode
36    pub mode: QuantizationMode,
37    /// Per-channel quantization for weights
38    pub per_channel: bool,
39    /// Quantization range clipping
40    pub range_clipping: f32,
41}
42
43impl Default for QuantizationConfig {
44    fn default() -> Self {
45        Self {
46            bits: 8,
47            signed: true,
48            scheme: QuantizationScheme::Symmetric,
49            calibration_size: 1000,
50            mode: QuantizationMode::Static,
51            per_channel: false,
52            range_clipping: 0.999,
53        }
54    }
55}
56
57/// Quantization scheme
58#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
59pub enum QuantizationScheme {
60    /// Symmetric quantization around zero
61    Symmetric,
62    /// Asymmetric quantization with zero-point offset
63    Asymmetric,
64    /// Power-of-two quantization for hardware efficiency
65    PowerOfTwo,
66}
67
68/// Quantization mode
69#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
70pub enum QuantizationMode {
71    /// Static quantization with fixed parameters
72    Static,
73    /// Dynamic quantization computed at runtime
74    Dynamic,
75    /// QAT (Quantization-Aware Training)
76    QAT,
77}
78
/// Affine mapping between `f32` values and integer codes for one tensor.
///
/// With a positive scale, a value `x` is encoded as
/// `clamp(round(x / scale) + zero_point, qmin, qmax)`.
/// A code `q` is decoded as `(q - zero_point) * scale`.
#[derive(Debug, Clone)]
pub struct QuantizationParams {
    /// Real-valued step between adjacent codes.
    pub scale: f32,
    /// Code that encodes the real value 0.0 (always 0 for symmetric schemes).
    pub zero_point: i32,
    /// Bit width the code range was derived from.
    pub bits: u8,
    /// Smallest representable code.
    pub qmin: i32,
    /// Largest representable code.
    pub qmax: i32,
}
93
94impl QuantizationParams {
95    /// Create new quantization parameters
96    pub fn new(bits: u8, signed: bool) -> Self {
97        let (qmin, qmax) = if signed {
98            (
99                -(1i32 << (bits as i32 - 1)),
100                (1i32 << (bits as i32 - 1)) - 1,
101            )
102        } else {
103            (0, (1i32 << bits as i32) - 1)
104        };
105        Self {
106            scale: 1.0,
107            zero_point: 0,
108            bits,
109            qmin,
110            qmax,
111        }
112    }
113
114    /// Calculate quantization parameters from tensor statistics
115    pub fn from_tensor(
116        tensor: &ArrayView<f32, scirs2_core::ndarray::IxDyn>,
117        config: &QuantizationConfig,
118    ) -> Result<Self> {
119        let mut params = Self::new(config.bits, config.signed);
120
121        // Calculate tensor statistics
122        let min_val = tensor.iter().cloned().fold(f32::INFINITY, f32::min);
123        let max_val = tensor.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
124
125        // Apply range clipping
126        let range = max_val - min_val;
127        let clipped_range = range * config.range_clipping;
128        let center = (max_val + min_val) / 2.0;
129        let clipped_min = center - clipped_range / 2.0;
130        let clipped_max = center + clipped_range / 2.0;
131
132        match config.scheme {
133            QuantizationScheme::Symmetric => {
134                let abs_max = clipped_max.abs().max(clipped_min.abs());
135                let denom = (params.qmax - params.qmin) as f32;
136                params.scale = if denom > 0.0 {
137                    (2.0 * abs_max) / denom
138                } else {
139                    1.0
140                };
141                params.zero_point = 0;
142            }
143            QuantizationScheme::Asymmetric => {
144                let denom = (params.qmax - params.qmin) as f32;
145                params.scale = if denom > 0.0 {
146                    (clipped_max - clipped_min) / denom
147                } else {
148                    1.0
149                };
150                if params.scale > 0.0 {
151                    params.zero_point = params.qmin - (clipped_min / params.scale).round() as i32;
152                }
153            }
154            QuantizationScheme::PowerOfTwo => {
155                let abs_max = clipped_max.abs().max(clipped_min.abs());
156                let divisor = (1i32 << (config.bits as i32 - 1)) as f32;
157                if divisor > 0.0 && abs_max > 0.0 {
158                    let scale_log2 = (abs_max / divisor).log2().ceil();
159                    params.scale = 2.0_f32.powf(scale_log2);
160                }
161                params.zero_point = 0;
162            }
163        }
164        Ok(params)
165    }
166}
167
168/// Quantized tensor representation
169#[derive(Debug, Clone)]
170pub struct QuantizedTensor {
171    /// Quantized integer data
172    pub data: ArrayD<i8>,
173    /// Quantization parameters
174    pub params: QuantizationParams,
175    /// Original tensor shape
176    pub shape: Vec<usize>,
177}
178
179impl QuantizedTensor {
180    /// Create new quantized tensor from float tensor
181    pub fn from_float(tensor: &ArrayD<f32>, config: &QuantizationConfig) -> Result<Self> {
182        let params = QuantizationParams::from_tensor(&tensor.view(), config)?;
183        let quantized_data = Self::quantize_tensor(tensor, &params)?;
184        Ok(Self {
185            data: quantized_data,
186            params,
187            shape: tensor.shape().to_vec(),
188        })
189    }
190
191    /// Quantize a float tensor to integers
192    fn quantize_tensor(tensor: &ArrayD<f32>, params: &QuantizationParams) -> Result<ArrayD<i8>> {
193        let quantized = tensor.mapv(|x| {
194            let q_val = if params.scale > 0.0 {
195                (x / params.scale).round() + params.zero_point as f32
196            } else {
197                params.zero_point as f32
198            };
199            let clamped = q_val.max(params.qmin as f32).min(params.qmax as f32);
200            clamped as i8
201        });
202        Ok(quantized)
203    }
204
205    /// Dequantize back to float tensor
206    pub fn dequantize(&self) -> ArrayD<f32> {
207        self.data
208            .mapv(|q| (q as f32 - self.params.zero_point as f32) * self.params.scale)
209    }
210
211    /// Get quantized tensor size in bytes
212    pub fn size_bytes(&self) -> usize {
213        self.data.len() + std::mem::size_of::<QuantizationParams>()
214    }
215
216    /// Get compression ratio compared to f32
217    pub fn compression_ratio(&self) -> f32 {
218        let original_size = self.data.len() * std::mem::size_of::<f32>();
219        let quantized_size = self.size_bytes();
220        if quantized_size > 0 {
221            original_size as f32 / quantized_size as f32
222        } else {
223            1.0
224        }
225    }
226}
227
228/// Post-training quantization (PTQ) implementation
229#[derive(Debug)]
230pub struct PostTrainingQuantizer {
231    /// Quantization configuration
232    config: QuantizationConfig,
233    /// Calibration statistics
234    calibration_stats: HashMap<String, TensorStats>,
235}
236
/// Calibration statistics accumulated for one named tensor.
#[derive(Debug)]
struct TensorStats {
    /// Smallest value observed so far.
    min: f32,
    /// Largest value observed so far.
    max: f32,
    /// Mean maintained by `TensorStats::update`.
    mean: f32,
    /// Standard deviation maintained by `TensorStats::update`.
    // Not consumed by parameter selection, hence the lint allowance.
    #[allow(dead_code)]
    std: f32,
    /// 256-bin value histogram binned over `[min, max]`.
    histogram: Vec<u32>,
}
247
248impl TensorStats {
249    fn new() -> Self {
250        Self {
251            min: f32::INFINITY,
252            max: f32::NEG_INFINITY,
253            mean: 0.0,
254            std: 0.0,
255            histogram: vec![0; 256],
256        }
257    }
258
259    fn update(&mut self, tensor: &ArrayView<f32, scirs2_core::ndarray::IxDyn>) {
260        if let Some(&min_v) = tensor
261            .iter()
262            .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
263        {
264            self.min = self.min.min(min_v);
265        }
266        if let Some(&max_v) = tensor
267            .iter()
268            .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
269        {
270            self.max = self.max.max(max_v);
271        }
272
273        let sum: f32 = tensor.sum();
274        let count = tensor.len() as f32;
275        if count > 0.0 {
276            self.mean = sum / count;
277        }
278        let variance: f32 =
279            tensor.iter().map(|&x| (x - self.mean).powi(2)).sum::<f32>() / count.max(1.0);
280        self.std = variance.sqrt();
281
282        // Update histogram
283        let range = self.max - self.min;
284        if range > 0.0 {
285            for &val in tensor.iter() {
286                let normalized = ((val - self.min) / range * 255.0).round() as usize;
287                let bin = normalized.min(255);
288                self.histogram[bin] += 1;
289            }
290        }
291    }
292}
293
294impl PostTrainingQuantizer {
295    /// Create new post-training quantizer
296    pub fn new(config: QuantizationConfig) -> Self {
297        Self {
298            config,
299            calibration_stats: HashMap::new(),
300        }
301    }
302
303    /// Add calibration data for a named tensor
304    pub fn add_calibration_data(&mut self, name: &str, tensor: &ArrayD<f32>) {
305        let stats = self
306            .calibration_stats
307            .entry(name.to_string())
308            .or_insert_with(TensorStats::new);
309        stats.update(&tensor.view());
310    }
311
312    /// Finalize calibration and compute optimal quantization parameters
313    pub fn finalize_calibration(&mut self) -> Result<HashMap<String, QuantizationParams>> {
314        let mut params_map = HashMap::new();
315        for (name, stats) in &self.calibration_stats {
316            let optimal_params = self.compute_optimal_params(stats)?;
317            params_map.insert(name.clone(), optimal_params);
318        }
319        Ok(params_map)
320    }
321
322    /// Compute optimal quantization parameters using KL divergence
323    fn compute_optimal_params(&self, stats: &TensorStats) -> Result<QuantizationParams> {
324        let mut best_params = QuantizationParams::new(self.config.bits, self.config.signed);
325        let mut best_kl_div = f32::INFINITY;
326
327        for threshold_idx in 128..=255 {
328            let threshold = stats.min + (threshold_idx as f32 / 255.0) * (stats.max - stats.min);
329            let mut params = QuantizationParams::new(self.config.bits, self.config.signed);
330
331            match self.config.scheme {
332                QuantizationScheme::Symmetric => {
333                    let denom = (params.qmax - params.qmin) as f32;
334                    params.scale = if denom > 0.0 {
335                        (2.0 * threshold) / denom
336                    } else {
337                        1.0
338                    };
339                    params.zero_point = 0;
340                }
341                QuantizationScheme::Asymmetric => {
342                    let denom = (params.qmax - params.qmin) as f32;
343                    params.scale = if denom > 0.0 {
344                        (threshold - stats.min) / denom
345                    } else {
346                        1.0
347                    };
348                    if params.scale > 0.0 {
349                        params.zero_point = params.qmin - (stats.min / params.scale).round() as i32;
350                    }
351                }
352                QuantizationScheme::PowerOfTwo => {
353                    let divisor = (1i32 << (self.config.bits as i32 - 1)) as f32;
354                    if divisor > 0.0 && threshold > 0.0 {
355                        let scale_log2 = (threshold / divisor).log2().ceil();
356                        params.scale = 2.0_f32.powf(scale_log2);
357                    }
358                    params.zero_point = 0;
359                }
360            }
361
362            let kl_div = self.compute_kl_divergence(&stats.histogram, &params);
363            if kl_div < best_kl_div {
364                best_kl_div = kl_div;
365                best_params = params;
366            }
367        }
368        Ok(best_params)
369    }
370
371    /// Compute KL divergence between original and quantized distributions
372    fn compute_kl_divergence(&self, histogram: &[u32], params: &QuantizationParams) -> f32 {
373        let total_count: u32 = histogram.iter().sum();
374        if total_count == 0 {
375            return 0.0;
376        }
377        let mut kl_div = 0.0;
378        for (i, &count) in histogram.iter().enumerate() {
379            if count > 0 {
380                let p = count as f32 / total_count as f32;
381                let bin_value = i as f32 / 255.0;
382                let quantized = if params.scale > 0.0 {
383                    (bin_value / params.scale)
384                        .round()
385                        .max(params.qmin as f32)
386                        .min(params.qmax as f32)
387                } else {
388                    0.0
389                };
390                let dequantized = quantized * params.scale;
391                let q = (dequantized * 255.0).round() as usize;
392                let q_count = if q < histogram.len() { histogram[q] } else { 1 };
393                let q_prob = (q_count as f32 / total_count as f32).max(1e-8);
394                kl_div += p * (p / q_prob).ln();
395            }
396        }
397        kl_div
398    }
399
400    /// Quantize a tensor using computed parameters
401    pub fn quantize_tensor(
402        &self,
403        tensor: &ArrayD<f32>,
404        params: &QuantizationParams,
405    ) -> Result<QuantizedTensor> {
406        let quantized_data = QuantizedTensor::quantize_tensor(tensor, params)?;
407        Ok(QuantizedTensor {
408            data: quantized_data,
409            params: params.clone(),
410            shape: tensor.shape().to_vec(),
411        })
412    }
413}
414
415/// Quantization-aware training (QAT) support
416pub struct QuantizationAwareTraining {
417    /// QAT configuration
418    config: QuantizationConfig,
419    /// Fake quantization parameters for layers
420    layer_params: HashMap<String, QuantizationParams>,
421    /// Training step counter
422    step_count: usize,
423    /// Warmup steps before quantization
424    warmup_steps: usize,
425}
426
427impl QuantizationAwareTraining {
428    /// Create new QAT instance
429    pub fn new(config: QuantizationConfig) -> Self {
430        Self {
431            config,
432            layer_params: HashMap::new(),
433            step_count: 0,
434            warmup_steps: 1000,
435        }
436    }
437
438    /// Set warmup steps
439    pub fn set_warmup_steps(&mut self, steps: usize) {
440        self.warmup_steps = steps;
441    }
442
443    /// Initialize quantization parameters for a layer
444    pub fn init_layer_params(&mut self, layer_name: &str, tensor: &ArrayD<f32>) -> Result<()> {
445        let params = QuantizationParams::from_tensor(&tensor.view(), &self.config)?;
446        self.layer_params.insert(layer_name.to_string(), params);
447        Ok(())
448    }
449
450    /// Apply fake quantization during training
451    pub fn fake_quantize(&mut self, layer_name: &str, tensor: &ArrayD<f32>) -> Result<ArrayD<f32>> {
452        self.step_count += 1;
453
454        // Skip quantization during warmup
455        if self.step_count < self.warmup_steps {
456            return Ok(tensor.clone());
457        }
458
459        let params = self.layer_params.get_mut(layer_name).ok_or_else(|| {
460            Error::InvalidArgument(format!("Layer {} not initialized", layer_name))
461        })?;
462
463        // Update parameters with exponential moving average
464        let new_params = QuantizationParams::from_tensor(&tensor.view(), &self.config)?;
465        let alpha = 0.01_f32;
466        params.scale = params.scale * (1.0 - alpha) + new_params.scale * alpha;
467        if self.config.scheme == QuantizationScheme::Asymmetric {
468            params.zero_point = ((params.zero_point as f32) * (1.0 - alpha)
469                + (new_params.zero_point as f32) * alpha)
470                .round() as i32;
471        }
472
473        // Apply fake quantization (quantize then dequantize)
474        let quantized = QuantizedTensor::quantize_tensor(tensor, params)?;
475        let dequantized = quantized.mapv(|q| (q as f32 - params.zero_point as f32) * params.scale);
476        Ok(dequantized)
477    }
478
479    /// Get final quantization parameters for deployment
480    pub fn get_quantization_params(&self) -> &HashMap<String, QuantizationParams> {
481        &self.layer_params
482    }
483
484    /// Simulate quantization noise for better training
485    pub fn add_quantization_noise(&self, tensor: &ArrayD<f32>, noise_scale: f32) -> ArrayD<f32> {
486        let mut rng = scirs2_core::random::rng();
487        tensor.mapv(|x| {
488            let noise: f32 = scirs2_core::random::RngExt::random::<f32>(&mut rng) - 0.5;
489            x + noise * noise_scale
490        })
491    }
492}
493
494/// Mixed bit-width quantization support
495pub struct MixedBitWidthQuantizer {
496    /// Per-layer bit configurations
497    layer_configs: HashMap<String, QuantizationConfig>,
498    /// Sensitivity analysis results
499    sensitivity_scores: HashMap<String, f32>,
500}
501
502impl Default for MixedBitWidthQuantizer {
503    fn default() -> Self {
504        Self::new()
505    }
506}
507
508impl MixedBitWidthQuantizer {
509    /// Create new mixed bit-width quantizer
510    pub fn new() -> Self {
511        Self {
512            layer_configs: HashMap::new(),
513            sensitivity_scores: HashMap::new(),
514        }
515    }
516
517    /// Set quantization configuration for a specific layer
518    pub fn set_layer_config(&mut self, layer_name: &str, config: QuantizationConfig) {
519        self.layer_configs.insert(layer_name.to_string(), config);
520    }
521
522    /// Perform sensitivity analysis to determine optimal bit allocation
523    pub fn analyze_sensitivity(
524        &mut self,
525        layer_outputs: &HashMap<String, ArrayD<f32>>,
526    ) -> Result<()> {
527        for (layer_name, output) in layer_outputs {
528            let variance = self.compute_variance(output);
529            let entropy = self.compute_entropy(output);
530            let gradient_norm = self.compute_gradient_norm(output);
531            let sensitivity = variance * 0.4 + entropy * 0.3 + gradient_norm * 0.3;
532            self.sensitivity_scores
533                .insert(layer_name.clone(), sensitivity);
534        }
535        self.assign_bit_widths()?;
536        Ok(())
537    }
538
539    /// Compute variance of activations
540    fn compute_variance(&self, tensor: &ArrayD<f32>) -> f32 {
541        let mean = tensor.mean_or(0.0);
542        let count = tensor.len() as f32;
543        if count > 0.0 {
544            tensor.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / count
545        } else {
546            0.0
547        }
548    }
549
550    /// Compute entropy of activation distribution
551    fn compute_entropy(&self, tensor: &ArrayD<f32>) -> f32 {
552        let mut histogram = vec![0u32; 256];
553        let min_val = tensor.iter().cloned().fold(f32::INFINITY, f32::min);
554        let max_val = tensor.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
555        let range = max_val - min_val;
556        if range == 0.0 {
557            return 0.0;
558        }
559        for &val in tensor.iter() {
560            let bin = ((val - min_val) / range * 255.0).round() as usize;
561            let bin = bin.min(255);
562            histogram[bin] += 1;
563        }
564        let total = tensor.len() as f32;
565        let mut entropy = 0.0;
566        for count in histogram {
567            if count > 0 {
568                let p = count as f32 / total;
569                entropy -= p * p.ln();
570            }
571        }
572        entropy
573    }
574
575    /// Compute gradient norm (simplified approximation)
576    fn compute_gradient_norm(&self, tensor: &ArrayD<f32>) -> f32 {
577        let mut grad_norm = 0.0f32;
578        for axis in 0..tensor.ndim() {
579            if tensor.shape()[axis] > 1 {
580                for _i in 0..tensor.shape()[axis] - 1 {
581                    grad_norm += 1.0;
582                }
583            }
584        }
585        let len = tensor.len() as f32;
586        if len > 0.0 {
587            grad_norm / len
588        } else {
589            0.0
590        }
591    }
592
593    /// Assign bit-widths based on sensitivity scores
594    fn assign_bit_widths(&mut self) -> Result<()> {
595        let mut scores: Vec<(String, f32)> = self
596            .sensitivity_scores
597            .iter()
598            .map(|(name, &score)| (name.clone(), score))
599            .collect();
600
601        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
602
603        for (i, (layer_name, _)) in scores.iter().enumerate() {
604            let bits = if i < scores.len() / 3 {
605                8 // High sensitivity
606            } else if i < 2 * scores.len() / 3 {
607                6 // Medium sensitivity
608            } else {
609                4 // Low sensitivity
610            };
611            let mut config = self
612                .layer_configs
613                .get(layer_name)
614                .cloned()
615                .unwrap_or_default();
616            config.bits = bits;
617            self.layer_configs.insert(layer_name.clone(), config);
618        }
619        Ok(())
620    }
621
622    /// Get optimal configuration for a layer
623    pub fn get_layer_config(&self, layer_name: &str) -> Option<&QuantizationConfig> {
624        self.layer_configs.get(layer_name)
625    }
626
627    /// Get sensitivity score for a layer
628    pub fn get_sensitivity_score(&self, layer_name: &str) -> Option<f32> {
629        self.sensitivity_scores.get(layer_name).copied()
630    }
631}
632
633/// Dynamic quantization at runtime
634pub struct DynamicQuantizer {
635    /// Configuration for dynamic quantization
636    config: QuantizationConfig,
637    /// Cache of recently computed parameters
638    params_cache: HashMap<String, QuantizationParams>,
639    /// Cache size limit
640    cache_size_limit: usize,
641}
642
643impl DynamicQuantizer {
644    /// Create new dynamic quantizer
645    pub fn new(config: QuantizationConfig) -> Self {
646        Self {
647            config,
648            params_cache: HashMap::new(),
649            cache_size_limit: 100,
650        }
651    }
652
653    /// Dynamically quantize tensor at runtime
654    pub fn quantize(
655        &mut self,
656        tensor: &ArrayD<f32>,
657        cache_key: Option<&str>,
658    ) -> Result<QuantizedTensor> {
659        let params = if let Some(key) = cache_key {
660            if let Some(cached_params) = self.params_cache.get(key) {
661                cached_params.clone()
662            } else {
663                let params = QuantizationParams::from_tensor(&tensor.view(), &self.config)?;
664                self.cache_params(key, params.clone());
665                params
666            }
667        } else {
668            QuantizationParams::from_tensor(&tensor.view(), &self.config)?
669        };
670
671        let quantized_data = QuantizedTensor::quantize_tensor(tensor, &params)?;
672        Ok(QuantizedTensor {
673            data: quantized_data,
674            params,
675            shape: tensor.shape().to_vec(),
676        })
677    }
678
679    /// Cache quantization parameters
680    fn cache_params(&mut self, key: &str, params: QuantizationParams) {
681        if self.params_cache.len() >= self.cache_size_limit {
682            if let Some(first_key) = self.params_cache.keys().next().cloned() {
683                self.params_cache.remove(&first_key);
684            }
685        }
686        self.params_cache.insert(key.to_string(), params);
687    }
688
689    /// Clear parameter cache
690    pub fn clear_cache(&mut self) {
691        self.params_cache.clear();
692    }
693
694    /// Get cache statistics
695    pub fn cache_stats(&self) -> (usize, usize) {
696        (self.params_cache.len(), self.cache_size_limit)
697    }
698}
699
700/// Quantization utilities and helper functions
701pub mod utils {
702    use super::*;
703
704    /// Compare quantized vs original tensor accuracy
705    pub fn compute_quantization_error(original: &ArrayD<f32>, quantized: &QuantizedTensor) -> f32 {
706        let dequantized = quantized.dequantize();
707        let mse = Zip::from(original)
708            .and(&dequantized)
709            .fold(0.0, |acc, &orig, &deq| acc + (orig - deq).powi(2));
710        let len = original.len() as f32;
711        if len > 0.0 {
712            mse / len
713        } else {
714            0.0
715        }
716    }
717
718    /// Estimate model size reduction from quantization
719    pub fn estimate_size_reduction(bit_width: u8) -> f32 {
720        if bit_width > 0 {
721            32.0 / bit_width as f32
722        } else {
723            1.0
724        }
725    }
726
727    /// Simulate quantization performance gains
728    pub fn estimate_performance_gain(bit_width: u8) -> f32 {
729        match bit_width {
730            8 => 2.0,
731            4 => 4.0,
732            1 => 16.0,
733            _ => 1.0,
734        }
735    }
736
737    /// Convert between different quantization schemes
738    pub fn convert_quantization_scheme(
739        tensor: &QuantizedTensor,
740        target_scheme: QuantizationScheme,
741        target_bits: u8,
742    ) -> Result<QuantizedTensor> {
743        let float_tensor = tensor.dequantize();
744        let config = QuantizationConfig {
745            scheme: target_scheme,
746            bits: target_bits,
747            ..Default::default()
748        };
749        QuantizedTensor::from_float(&float_tensor, &config)
750    }
751}
752
753/// Convenience function to quantize a model's parameters
754///
755/// Takes a map of named parameters as f32 tensors and returns
756/// quantized versions using the specified configuration.
757pub fn quantize_model(
758    parameters: &HashMap<String, ArrayD<f32>>,
759    config: &QuantizationConfig,
760) -> Result<HashMap<String, QuantizedTensor>> {
761    let mut quantized_params = HashMap::new();
762    for (name, tensor) in parameters {
763        let quantized = QuantizedTensor::from_float(tensor, config)?;
764        quantized_params.insert(name.clone(), quantized);
765    }
766    Ok(quantized_params)
767}
768
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array2};

    /// Build a `rows x cols` matrix of uniform samples from `[-1, 1)`.
    fn random_f32_array(rows: usize, cols: usize) -> Array2<f32> {
        let mut rng = scirs2_core::random::rng();
        let samples = (0..rows * cols)
            .map(|_| scirs2_core::random::RngExt::random_range(&mut rng, -1.0f32..1.0f32))
            .collect::<Vec<f32>>();
        Array2::from_shape_vec((rows, cols), samples).expect("Test: array creation")
    }

    /// Quantize `tensor` with `config`, panicking on failure.
    fn quantize(tensor: &ArrayD<f32>, config: &QuantizationConfig) -> QuantizedTensor {
        QuantizedTensor::from_float(tensor, config).expect("Test: quantization")
    }

    #[test]
    fn test_quantization_config_default() {
        let QuantizationConfig {
            bits, signed, scheme, ..
        } = QuantizationConfig::default();
        assert_eq!(bits, 8);
        assert!(signed);
        assert_eq!(scheme, QuantizationScheme::Symmetric);
    }

    #[test]
    fn test_quantization_params_creation() {
        let signed = QuantizationParams::new(8, true);
        assert_eq!(signed.bits, 8);
        assert_eq!((signed.qmin, signed.qmax), (-128, 127));

        let unsigned = QuantizationParams::new(8, false);
        assert_eq!((unsigned.qmin, unsigned.qmax), (0, 255));
    }

    #[test]
    fn test_symmetric_quantization() {
        let tensor = array![[1.0_f32, -1.0], [2.0, -2.0]].into_dyn();
        let quantized = quantize(&tensor, &QuantizationConfig::default());
        assert!(utils::compute_quantization_error(&tensor, &quantized) < 0.1);
    }

    #[test]
    fn test_asymmetric_quantization() {
        let config = QuantizationConfig {
            scheme: QuantizationScheme::Asymmetric,
            ..Default::default()
        };
        let tensor = array![[0.0_f32, 1.0], [2.0, 3.0]].into_dyn();
        let quantized = quantize(&tensor, &config);
        assert_ne!(quantized.params.zero_point, 0);
        assert!(utils::compute_quantization_error(&tensor, &quantized) < 0.1);
    }

    #[test]
    fn test_post_training_quantization() {
        let mut ptq = PostTrainingQuantizer::new(QuantizationConfig::default());
        ptq.add_calibration_data("layer1", &random_f32_array(100, 50).into_dyn());
        let params = ptq.finalize_calibration().expect("Test: calibration");
        assert!(params.contains_key("layer1"));
    }

    #[test]
    fn test_quantization_aware_training() {
        let mut qat = QuantizationAwareTraining::new(QuantizationConfig::default());
        let tensor = Array2::<f32>::ones((10, 10)).into_dyn();
        qat.init_layer_params("layer1", &tensor)
            .expect("Test: init params");
        let fake = qat
            .fake_quantize("layer1", &tensor)
            .expect("Test: fake quantize");
        assert_eq!(fake.shape(), tensor.shape());
    }

    #[test]
    fn test_mixed_bitwidth_quantization() {
        let outputs: HashMap<String, ArrayD<f32>> = HashMap::from([
            ("layer1".to_string(), random_f32_array(50, 50).into_dyn()),
            ("layer2".to_string(), Array2::<f32>::ones((50, 50)).into_dyn()),
        ]);
        let mut mbq = MixedBitWidthQuantizer::new();
        mbq.analyze_sensitivity(&outputs)
            .expect("Test: sensitivity analysis");
        assert!(mbq.get_sensitivity_score("layer1").is_some());
        assert!(mbq.get_layer_config("layer1").is_some());
    }

    #[test]
    fn test_dynamic_quantization() {
        let mut dq = DynamicQuantizer::new(QuantizationConfig::default());
        let tensor = random_f32_array(20, 20).into_dyn();
        let quantized = dq
            .quantize(&tensor, Some("test_key"))
            .expect("Test: dynamic quantize");
        assert_eq!(quantized.shape, tensor.shape().to_vec());
        assert_eq!(dq.cache_stats().0, 1);
    }

    #[test]
    fn test_quantization_utilities() {
        let original = random_f32_array(10, 10).into_dyn();
        let quantized = quantize(&original, &QuantizationConfig::default());
        assert!(utils::compute_quantization_error(&original, &quantized) >= 0.0);
        assert_eq!(utils::estimate_size_reduction(8), 4.0);
        assert_eq!(utils::estimate_performance_gain(8), 2.0);
    }

    #[test]
    fn test_compression_ratio() {
        let tensor = Array2::<f32>::ones((100, 100)).into_dyn();
        let ratio = quantize(&tensor, &QuantizationConfig::default()).compression_ratio();
        assert!(ratio > 1.0);
    }

    #[test]
    fn test_power_of_two_quantization() {
        let config = QuantizationConfig {
            scheme: QuantizationScheme::PowerOfTwo,
            ..Default::default()
        };
        let quantized = quantize(&random_f32_array(10, 10).into_dyn(), &config);
        let exponent = quantized.params.scale.log2();
        assert!((exponent.round() - exponent).abs() < 1e-6);
    }

    #[test]
    fn test_quantization_scheme_conversion() {
        let quantized = quantize(
            &random_f32_array(10, 10).into_dyn(),
            &QuantizationConfig::default(),
        );
        let converted =
            utils::convert_quantization_scheme(&quantized, QuantizationScheme::Asymmetric, 4)
                .expect("Test: scheme conversion");
        assert_eq!(converted.params.bits, 4);
    }

    #[test]
    fn test_quantize_model_fn() {
        let parameters = HashMap::from([
            ("weight".to_string(), random_f32_array(5, 5).into_dyn()),
            ("bias".to_string(), Array2::<f32>::zeros((1, 5)).into_dyn()),
        ]);
        let quantized = quantize_model(&parameters, &QuantizationConfig::default());
        assert!(quantized.is_ok());
        let qmap = quantized.expect("Test: quantize_model");
        assert_eq!(qmap.len(), 2);
        assert!(["weight", "bias"].iter().all(|name| qmap.contains_key(*name)));
    }

    #[test]
    fn test_dequantize_roundtrip() {
        let tensor = array![[1.0_f32, 2.0], [3.0, 4.0]].into_dyn();
        let restored = quantize(&tensor, &QuantizationConfig::default()).dequantize();
        // Symmetric INT8 with |x| <= 4 has a step of about 0.03, well inside the 0.5 tolerance.
        for (orig, deq) in tensor.iter().zip(restored.iter()) {
            assert!((orig - deq).abs() < 0.5, "orig={}, deq={}", orig, deq);
        }
    }
}