scirs2_neural/
quantization.rs

1//! Quantization support for neural networks
2//!
3//! This module provides comprehensive quantization capabilities including:
4//! - Post-training quantization (PTQ)
5//! - Quantization-aware training (QAT)
6//! - Mixed bit-width operations
7//! - Dynamic and static quantization schemes
8
9use crate::error::{Error, Result};
10use ndarray::{ArrayD, ArrayView, Zip};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13
/// Quantization configuration
///
/// Bundles the settings consumed by the quantizers in this module. The
/// `Default` implementation selects signed, symmetric, static 8-bit
/// quantization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationConfig {
    /// Number of bits for quantization
    ///
    /// Together with `signed` this fixes the integer range `[qmin, qmax]`
    /// built by `QuantizationParams::new`.
    pub bits: u8,
    /// Whether to use signed quantization
    ///
    /// Signed ranges are `[-2^(bits-1), 2^(bits-1) - 1]`; unsigned ranges are
    /// `[0, 2^bits - 1]`.
    pub signed: bool,
    /// Quantization scheme
    pub scheme: QuantizationScheme,
    /// Calibration dataset size for PTQ
    ///
    /// NOTE(review): no code in this file reads this field — confirm where
    /// (or whether) it is consumed.
    pub calibration_size: usize,
    /// Quantization mode
    ///
    /// NOTE(review): no code in this file branches on this field — confirm.
    pub mode: QuantizationMode,
    /// Per-channel quantization for weights
    ///
    /// NOTE(review): no code in this file reads this field; all parameters
    /// computed here are per-tensor — confirm.
    pub per_channel: bool,
    /// Quantization range clipping
    ///
    /// Fraction of the observed `[min, max]` span that
    /// `QuantizationParams::from_tensor` keeps; the span is shrunk about its
    /// midpoint by this factor before scale and zero point are derived.
    pub range_clipping: f32,
}
32
33impl Default for QuantizationConfig {
34    fn default() -> Self {
35        Self {
36            bits: 8,
37            signed: true,
38            scheme: QuantizationScheme::Symmetric,
39            calibration_size: 1000,
40            mode: QuantizationMode::Static,
41            per_channel: false,
42            range_clipping: 0.999,
43        }
44    }
45}
46
/// Quantization scheme
///
/// Selects how `QuantizationParams::from_tensor` derives `scale` and
/// `zero_point` from a tensor's clipped value range.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationScheme {
    /// Symmetric quantization around zero
    ///
    /// The scale is derived from the largest absolute value of the clipped
    /// range and the zero point is fixed at 0.
    Symmetric,
    /// Asymmetric quantization with zero-point offset
    ///
    /// The scale spans the clipped `[min, max]` interval and the zero point is
    /// computed from the clipped minimum and `qmin`.
    Asymmetric,
    /// Power-of-two quantization for hardware efficiency
    ///
    /// The scale is `2^ceil(log2(abs_max / 2^(bits - 1)))`, i.e. rounded up to
    /// a power of two; the zero point is fixed at 0.
    PowerOfTwo,
}
57
/// Quantization mode
///
/// NOTE(review): within this file the mode is only stored in
/// `QuantizationConfig` (default `Static`); no code here branches on it, so
/// the variant descriptions below state intent only — confirm against callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationMode {
    /// Static quantization with fixed parameters
    Static,
    /// Dynamic quantization computed at runtime
    Dynamic,
    /// QAT (Quantization-Aware Training)
    QAT,
}
68
/// Quantization parameters for a tensor
///
/// Describes the affine mapping used throughout this module: a float `x` is
/// quantized as `round(x / scale) + zero_point`, clamped to `[qmin, qmax]`,
/// and an integer code `q` is dequantized as `(q - zero_point) * scale`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationParams {
    /// Scale factor for quantization
    ///
    /// Size of one quantization step in float units.
    pub scale: f32,
    /// Zero point for asymmetric quantization
    ///
    /// Integer offset added after scaling; the symmetric and power-of-two
    /// schemes set it to 0.
    pub zero_point: i32,
    /// Number of quantization bits
    pub bits: u8,
    /// Minimum quantization value
    pub qmin: i32,
    /// Maximum quantization value
    pub qmax: i32,
}
83
84impl QuantizationParams {
85    /// Create new quantization parameters
86    pub fn new(bits: u8, signed: bool) -> Self {
87        let (qmin, qmax) = if signed {
88            (-(1 << (bits - 1)), (1 << (bits - 1)) - 1)
89        } else {
90            (0, (1 << bits) - 1)
91        };
92
93        Self {
94            scale: 1.0,
95            zero_point: 0,
96            bits,
97            qmin,
98            qmax,
99        }
100    }
101
102    /// Calculate quantization parameters from tensor statistics
103    pub fn from_tensor(
104        tensor: &ArrayView<f32, ndarray::IxDyn>,
105        config: &QuantizationConfig,
106    ) -> Result<Self> {
107        let mut params = Self::new(config.bits, config.signed);
108
109        // Calculate tensor statistics
110        let min_val = tensor.iter().cloned().fold(f32::INFINITY, f32::min);
111        let max_val = tensor.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
112
113        // Apply range clipping
114        let range = max_val - min_val;
115        let clipped_range = range * config.range_clipping;
116        let center = (max_val + min_val) / 2.0;
117        let clipped_min = center - clipped_range / 2.0;
118        let clipped_max = center + clipped_range / 2.0;
119
120        match config.scheme {
121            QuantizationScheme::Symmetric => {
122                let abs_max = clipped_max.abs().max(clipped_min.abs());
123                params.scale = (2.0 * abs_max) / (params.qmax - params.qmin) as f32;
124                params.zero_point = 0;
125            }
126            QuantizationScheme::Asymmetric => {
127                params.scale = (clipped_max - clipped_min) / (params.qmax - params.qmin) as f32;
128                params.zero_point = params.qmin - (clipped_min / params.scale).round() as i32;
129            }
130            QuantizationScheme::PowerOfTwo => {
131                let abs_max = clipped_max.abs().max(clipped_min.abs());
132                let scale_log2 = (abs_max / (1 << (config.bits - 1)) as f32).log2().ceil();
133                params.scale = 2.0_f32.powf(scale_log2);
134                params.zero_point = 0;
135            }
136        }
137
138        Ok(params)
139    }
140}
141
/// Quantized tensor representation
///
/// Stores the integer codes together with the parameters needed to map them
/// back to floats (see `dequantize`).
#[derive(Debug, Clone)]
pub struct QuantizedTensor {
    /// Quantized integer data
    ///
    /// One `i8` code per element of the source tensor.
    pub data: ArrayD<i8>,
    /// Quantization parameters
    pub params: QuantizationParams,
    /// Original tensor shape
    ///
    /// Copied from the float tensor at quantization time.
    pub shape: Vec<usize>,
}
152
153impl QuantizedTensor {
154    /// Create new quantized tensor from float tensor
155    pub fn from_float(tensor: &ArrayD<f32>, config: &QuantizationConfig) -> Result<Self> {
156        let params = QuantizationParams::from_tensor(&tensor.view(), config)?;
157        let quantized_data = Self::quantize_tensor(tensor, &params)?;
158
159        Ok(Self {
160            data: quantized_data,
161            params,
162            shape: tensor.shape().to_vec(),
163        })
164    }
165
166    /// Quantize a float tensor to integers
167    fn quantize_tensor(tensor: &ArrayD<f32>, params: &QuantizationParams) -> Result<ArrayD<i8>> {
168        let quantized = tensor.mapv(|x| {
169            let q_val = (x / params.scale).round() + params.zero_point as f32;
170            let clamped = q_val.max(params.qmin as f32).min(params.qmax as f32);
171            clamped as i8
172        });
173
174        Ok(quantized)
175    }
176
177    /// Dequantize back to float tensor
178    pub fn dequantize(&self) -> ArrayD<f32> {
179        self.data
180            .mapv(|q| (q as f32 - self.params.zero_point as f32) * self.params.scale)
181    }
182
183    /// Get quantized tensor size in bytes
184    pub fn size_bytes(&self) -> usize {
185        self.data.len() + std::mem::size_of::<QuantizationParams>()
186    }
187
188    /// Get compression ratio compared to f32
189    pub fn compression_ratio(&self) -> f32 {
190        let original_size = self.data.len() * std::mem::size_of::<f32>();
191        let quantized_size = self.size_bytes();
192        original_size as f32 / quantized_size as f32
193    }
194}
195
/// Post-training quantization (PTQ) implementation
///
/// Collects per-tensor statistics from calibration data and turns them into
/// quantization parameters in `finalize_calibration`.
#[derive(Debug)]
pub struct PostTrainingQuantizer {
    /// Quantization configuration
    config: QuantizationConfig,
    /// Calibration statistics
    ///
    /// Keyed by the tensor name passed to `add_calibration_data`.
    calibration_stats: HashMap<String, TensorStats>,
}
204
/// Tensor statistics for calibration
///
/// `min` and `max` accumulate across calls to `update`; `mean` and `std` are
/// overwritten by each call and therefore describe only the latest batch.
#[derive(Debug, Clone)]
struct TensorStats {
    // Smallest value seen so far (starts at +infinity).
    min: f32,
    // Largest value seen so far (starts at -infinity).
    max: f32,
    // Mean of the most recently supplied batch.
    mean: f32,
    // Population standard deviation of the most recently supplied batch.
    std: f32,
    // 256 bins of counts; each value is binned relative to the running
    // `[min, max]` at the time it is added.
    histogram: Vec<u32>,
}
214
215impl TensorStats {
216    fn new() -> Self {
217        Self {
218            min: f32::INFINITY,
219            max: f32::NEG_INFINITY,
220            mean: 0.0,
221            std: 0.0,
222            histogram: vec![0; 256],
223        }
224    }
225
226    fn update(&mut self, tensor: &ArrayView<f32, ndarray::IxDyn>) {
227        self.min = self.min.min(
228            *tensor
229                .iter()
230                .min_by(|a, b| a.partial_cmp(b).unwrap())
231                .unwrap(),
232        );
233        self.max = self.max.max(
234            *tensor
235                .iter()
236                .max_by(|a, b| a.partial_cmp(b).unwrap())
237                .unwrap(),
238        );
239
240        let sum: f32 = tensor.sum();
241        let count = tensor.len() as f32;
242        self.mean = sum / count;
243
244        let variance: f32 = tensor.iter().map(|&x| (x - self.mean).powi(2)).sum::<f32>() / count;
245        self.std = variance.sqrt();
246
247        // Update histogram
248        for &val in tensor.iter() {
249            let normalized = ((val - self.min) / (self.max - self.min) * 255.0).round() as usize;
250            let bin = normalized.min(255);
251            self.histogram[bin] += 1;
252        }
253    }
254}
255
256impl PostTrainingQuantizer {
257    /// Create new post-training quantizer
258    pub fn new(config: QuantizationConfig) -> Self {
259        Self {
260            config,
261            calibration_stats: HashMap::new(),
262        }
263    }
264
265    /// Add calibration data for a named tensor
266    pub fn add_calibration_data(&mut self, name: &str, tensor: &ArrayD<f32>) {
267        let stats = self
268            .calibration_stats
269            .entry(name.to_string())
270            .or_insert_with(TensorStats::new);
271        stats.update(&tensor.view());
272    }
273
274    /// Finalize calibration and compute optimal quantization parameters
275    pub fn finalize_calibration(&mut self) -> Result<HashMap<String, QuantizationParams>> {
276        let mut params_map = HashMap::new();
277
278        for (name, stats) in &self.calibration_stats {
279            // Use the stats directly for optimal parameter computation
280
281            // Use KL divergence for optimal quantization range selection
282            let optimal_params = self.compute_optimal_params(stats)?;
283            params_map.insert(name.clone(), optimal_params);
284        }
285
286        Ok(params_map)
287    }
288
289    /// Compute optimal quantization parameters using KL divergence
290    fn compute_optimal_params(&self, stats: &TensorStats) -> Result<QuantizationParams> {
291        let mut best_params = QuantizationParams::new(self.config.bits, self.config.signed);
292        let mut best_kl_div = f32::INFINITY;
293
294        // Try different threshold values
295        for threshold_idx in 128..=255 {
296            let threshold = stats.min + (threshold_idx as f32 / 255.0) * (stats.max - stats.min);
297
298            // Compute quantization parameters for this threshold
299            let mut params = QuantizationParams::new(self.config.bits, self.config.signed);
300
301            match self.config.scheme {
302                QuantizationScheme::Symmetric => {
303                    params.scale = (2.0 * threshold) / (params.qmax - params.qmin) as f32;
304                    params.zero_point = 0;
305                }
306                QuantizationScheme::Asymmetric => {
307                    params.scale = (threshold - stats.min) / (params.qmax - params.qmin) as f32;
308                    params.zero_point = params.qmin - (stats.min / params.scale).round() as i32;
309                }
310                QuantizationScheme::PowerOfTwo => {
311                    let scale_log2 = (threshold / (1 << (self.config.bits - 1)) as f32)
312                        .log2()
313                        .ceil();
314                    params.scale = 2.0_f32.powf(scale_log2);
315                    params.zero_point = 0;
316                }
317            }
318
319            // Compute KL divergence (simplified approximation)
320            let kl_div = self.compute_kl_divergence(&stats.histogram, &params);
321
322            if kl_div < best_kl_div {
323                best_kl_div = kl_div;
324                best_params = params;
325            }
326        }
327
328        Ok(best_params)
329    }
330
331    /// Compute KL divergence between original and quantized distributions
332    fn compute_kl_divergence(&self, histogram: &[u32], params: &QuantizationParams) -> f32 {
333        let total_count: u32 = histogram.iter().sum();
334        if total_count == 0 {
335            return 0.0;
336        }
337
338        let mut kl_div = 0.0;
339        for (i, &count) in histogram.iter().enumerate() {
340            if count > 0 {
341                let p = count as f32 / total_count as f32;
342
343                // Simulate quantization effect
344                let bin_value = i as f32 / 255.0;
345                let quantized = (bin_value / params.scale)
346                    .round()
347                    .max(params.qmin as f32)
348                    .min(params.qmax as f32);
349                let dequantized = quantized * params.scale;
350
351                // Approximate quantized distribution
352                let q = (dequantized * 255.0).round() as usize;
353                let q_count = if q < histogram.len() { histogram[q] } else { 1 };
354                let q_prob = (q_count as f32 / total_count as f32).max(1e-8);
355
356                kl_div += p * (p / q_prob).ln();
357            }
358        }
359
360        kl_div
361    }
362
363    /// Quantize a tensor using computed parameters
364    pub fn quantize_tensor(
365        &self,
366        tensor: &ArrayD<f32>,
367        params: &QuantizationParams,
368    ) -> Result<QuantizedTensor> {
369        let quantized_data = QuantizedTensor::quantize_tensor(tensor, params)?;
370
371        Ok(QuantizedTensor {
372            data: quantized_data,
373            params: params.clone(),
374            shape: tensor.shape().to_vec(),
375        })
376    }
377}
378
/// Quantization-aware training (QAT) support
///
/// Tracks per-layer quantization parameters and applies fake quantization
/// (quantize followed by dequantize) once a warmup period has elapsed.
#[derive(Debug)]
pub struct QuantizationAwareTraining {
    /// QAT configuration
    config: QuantizationConfig,
    /// Fake quantization parameters for layers
    ///
    /// Keyed by layer name; entries are created by `init_layer_params` and
    /// refined by `fake_quantize`.
    layer_params: HashMap<String, QuantizationParams>,
    /// Training step counter
    ///
    /// Incremented on every call to `fake_quantize`, regardless of layer.
    step_count: usize,
    /// Warmup steps before quantization
    ///
    /// `fake_quantize` returns its input unchanged while `step_count` is
    /// below this value (1000 unless changed via `set_warmup_steps`).
    warmup_steps: usize,
}
391
392impl QuantizationAwareTraining {
393    /// Create new QAT instance
394    pub fn new(config: QuantizationConfig) -> Self {
395        Self {
396            config,
397            layer_params: HashMap::new(),
398            step_count: 0,
399            warmup_steps: 1000,
400        }
401    }
402
403    /// Set warmup steps
404    pub fn set_warmup_steps(&mut self, steps: usize) {
405        self.warmup_steps = steps;
406    }
407
408    /// Initialize quantization parameters for a layer
409    pub fn init_layer_params(&mut self, layer_name: &str, tensor: &ArrayD<f32>) -> Result<()> {
410        let params = QuantizationParams::from_tensor(&tensor.view(), &self.config)?;
411        self.layer_params.insert(layer_name.to_string(), params);
412        Ok(())
413    }
414
415    /// Apply fake quantization during training
416    pub fn fake_quantize(&mut self, layer_name: &str, tensor: &ArrayD<f32>) -> Result<ArrayD<f32>> {
417        self.step_count += 1;
418
419        // Skip quantization during warmup
420        if self.step_count < self.warmup_steps {
421            return Ok(tensor.clone());
422        }
423
424        let params = self.layer_params.get_mut(layer_name).ok_or_else(|| {
425            Error::InvalidArgument(format!("Layer {} not initialized", layer_name))
426        })?;
427
428        // Update parameters with exponential moving average
429        let new_params = QuantizationParams::from_tensor(&tensor.view(), &self.config)?;
430        let alpha = 0.01; // EMA factor
431
432        params.scale = params.scale * (1.0 - alpha) + new_params.scale * alpha;
433        if self.config.scheme == QuantizationScheme::Asymmetric {
434            params.zero_point = ((params.zero_point as f32) * (1.0 - alpha)
435                + (new_params.zero_point as f32) * alpha)
436                .round() as i32;
437        }
438
439        // Apply fake quantization (quantize then dequantize)
440        let quantized = QuantizedTensor::quantize_tensor(tensor, params)?;
441        let dequantized = quantized.mapv(|q| (q as f32 - params.zero_point as f32) * params.scale);
442
443        Ok(dequantized)
444    }
445
446    /// Get final quantization parameters for deployment
447    pub fn get_quantization_params(&self) -> &HashMap<String, QuantizationParams> {
448        &self.layer_params
449    }
450
451    /// Simulate quantization noise for better training
452    pub fn add_quantization_noise(&self, tensor: &ArrayD<f32>, noise_scale: f32) -> ArrayD<f32> {
453        use rand::Rng;
454        let mut rng = rand::rng();
455
456        tensor.mapv(|x| {
457            let noise = rng.random::<f32>() - 0.5; // Uniform noise [-0.5, 0.5]
458            x + noise * noise_scale
459        })
460    }
461}
462
/// Mixed bit-width quantization support
///
/// Holds one quantization configuration per layer and can assign bit widths
/// automatically from a sensitivity ranking (see `analyze_sensitivity`).
#[derive(Debug)]
pub struct MixedBitWidthQuantizer {
    /// Per-layer bit configurations
    ///
    /// Keyed by layer name; filled by `set_layer_config` and by the automatic
    /// bit-width assignment.
    layer_configs: HashMap<String, QuantizationConfig>,
    /// Sensitivity analysis results
    ///
    /// Keyed by layer name; a higher score ranks the layer earlier when bit
    /// widths are assigned.
    sensitivity_scores: HashMap<String, f32>,
}
471
472impl Default for MixedBitWidthQuantizer {
473    fn default() -> Self {
474        Self::new()
475    }
476}
477
478impl MixedBitWidthQuantizer {
479    /// Create new mixed bit-width quantizer
480    pub fn new() -> Self {
481        Self {
482            layer_configs: HashMap::new(),
483            sensitivity_scores: HashMap::new(),
484        }
485    }
486
487    /// Set quantization configuration for a specific layer
488    pub fn set_layer_config(&mut self, layer_name: &str, config: QuantizationConfig) {
489        self.layer_configs.insert(layer_name.to_string(), config);
490    }
491
492    /// Perform sensitivity analysis to determine optimal bit allocation
493    pub fn analyze_sensitivity(
494        &mut self,
495        layer_outputs: &HashMap<String, ArrayD<f32>>,
496    ) -> Result<()> {
497        for (layer_name, output) in layer_outputs {
498            // Compute sensitivity score based on activation distribution
499            let variance = self.compute_variance(output);
500            let entropy = self.compute_entropy(output);
501            let gradient_norm = self.compute_gradient_norm(output);
502
503            // Combined sensitivity score
504            let sensitivity = variance * 0.4 + entropy * 0.3 + gradient_norm * 0.3;
505            self.sensitivity_scores
506                .insert(layer_name.clone(), sensitivity);
507        }
508
509        // Assign bit-widths based on sensitivity scores
510        self.assign_bit_widths()?;
511
512        Ok(())
513    }
514
515    /// Compute variance of activations
516    fn compute_variance(&self, tensor: &ArrayD<f32>) -> f32 {
517        let mean = tensor.mean().unwrap_or(0.0);
518        let variance =
519            tensor.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / tensor.len() as f32;
520        variance
521    }
522
523    /// Compute entropy of activation distribution
524    fn compute_entropy(&self, tensor: &ArrayD<f32>) -> f32 {
525        let mut histogram = vec![0; 256];
526        let min_val = tensor.iter().cloned().fold(f32::INFINITY, f32::min);
527        let max_val = tensor.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
528        let range = max_val - min_val;
529
530        if range == 0.0 {
531            return 0.0;
532        }
533
534        for &val in tensor.iter() {
535            let bin = ((val - min_val) / range * 255.0).round() as usize;
536            let bin = bin.min(255);
537            histogram[bin] += 1;
538        }
539
540        let total = tensor.len() as f32;
541        let mut entropy = 0.0;
542        for count in histogram {
543            if count > 0 {
544                let p = count as f32 / total;
545                entropy -= p * p.ln();
546            }
547        }
548
549        entropy
550    }
551
552    /// Compute gradient norm (simplified approximation)
553    fn compute_gradient_norm(&self, tensor: &ArrayD<f32>) -> f32 {
554        // Approximate gradient as the standard deviation of adjacent differences
555        let mut grad_norm = 0.0;
556        for axis in 0..tensor.ndim() {
557            if tensor.shape()[axis] > 1 {
558                for _i in 0..tensor.shape()[axis] - 1 {
559                    // Simplified gradient computation along each axis
560                    grad_norm += 1.0; // Placeholder - would compute actual gradients in real implementation
561                }
562            }
563        }
564        grad_norm / tensor.len() as f32
565    }
566
567    /// Assign bit-widths based on sensitivity scores
568    fn assign_bit_widths(&mut self) -> Result<()> {
569        let mut scores: Vec<(String, f32)> = self
570            .sensitivity_scores
571            .iter()
572            .map(|(name, &score)| (name.clone(), score))
573            .collect();
574
575        // Sort by sensitivity (higher sensitivity gets more bits)
576        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
577
578        // Assign bit-widths: high sensitivity layers get 8 bits, others get 4-6 bits
579        for (i, (layer_name, _)) in scores.iter().enumerate() {
580            let bits = if i < scores.len() / 3 {
581                8 // High sensitivity
582            } else if i < 2 * scores.len() / 3 {
583                6 // Medium sensitivity
584            } else {
585                4 // Low sensitivity
586            };
587
588            let mut config = self
589                .layer_configs
590                .get(layer_name)
591                .cloned()
592                .unwrap_or_default();
593            config.bits = bits;
594            self.layer_configs.insert(layer_name.clone(), config);
595        }
596
597        Ok(())
598    }
599
600    /// Get optimal configuration for a layer
601    pub fn get_layer_config(&self, layer_name: &str) -> Option<&QuantizationConfig> {
602        self.layer_configs.get(layer_name)
603    }
604
605    /// Get sensitivity score for a layer
606    pub fn get_sensitivity_score(&self, layer_name: &str) -> Option<f32> {
607        self.sensitivity_scores.get(layer_name).copied()
608    }
609}
610
/// Dynamic quantization at runtime
///
/// Computes quantization parameters from each tensor as it is quantized and
/// can reuse parameters stored under a caller-supplied cache key.
#[derive(Debug)]
pub struct DynamicQuantizer {
    /// Configuration for dynamic quantization
    config: QuantizationConfig,
    /// Cache of recently computed parameters
    ///
    /// Keyed by the cache key passed to `quantize`.
    params_cache: HashMap<String, QuantizationParams>,
    /// Cache size limit
    ///
    /// Once the cache holds this many entries, one entry is evicted before a
    /// new one is stored (100 as set by `new`).
    cache_size_limit: usize,
}
621
622impl DynamicQuantizer {
623    /// Create new dynamic quantizer
624    pub fn new(config: QuantizationConfig) -> Self {
625        Self {
626            config,
627            params_cache: HashMap::new(),
628            cache_size_limit: 100,
629        }
630    }
631
632    /// Dynamically quantize tensor at runtime
633    pub fn quantize(
634        &mut self,
635        tensor: &ArrayD<f32>,
636        cache_key: Option<&str>,
637    ) -> Result<QuantizedTensor> {
638        let params = if let Some(key) = cache_key {
639            if let Some(cached_params) = self.params_cache.get(key) {
640                cached_params.clone()
641            } else {
642                let params = QuantizationParams::from_tensor(&tensor.view(), &self.config)?;
643                self.cache_params(key, params.clone());
644                params
645            }
646        } else {
647            QuantizationParams::from_tensor(&tensor.view(), &self.config)?
648        };
649
650        let quantized_data = QuantizedTensor::quantize_tensor(tensor, &params)?;
651
652        Ok(QuantizedTensor {
653            data: quantized_data,
654            params,
655            shape: tensor.shape().to_vec(),
656        })
657    }
658
659    /// Cache quantization parameters
660    fn cache_params(&mut self, key: &str, params: QuantizationParams) {
661        if self.params_cache.len() >= self.cache_size_limit {
662            // Simple LRU eviction - remove first entry
663            if let Some(first_key) = self.params_cache.keys().next().cloned() {
664                self.params_cache.remove(&first_key);
665            }
666        }
667        self.params_cache.insert(key.to_string(), params);
668    }
669
670    /// Clear parameter cache
671    pub fn clear_cache(&mut self) {
672        self.params_cache.clear();
673    }
674
675    /// Get cache statistics
676    pub fn cache_stats(&self) -> (usize, usize) {
677        (self.params_cache.len(), self.cache_size_limit)
678    }
679}
680
681/// Quantization utilities and helper functions
682pub mod utils {
683    use super::*;
684
685    /// Compare quantized vs original tensor accuracy
686    pub fn compute_quantization_error(original: &ArrayD<f32>, quantized: &QuantizedTensor) -> f32 {
687        let dequantized = quantized.dequantize();
688        let mse = Zip::from(original)
689            .and(&dequantized)
690            .fold(0.0, |acc, &orig, &deq| acc + (orig - deq).powi(2));
691        mse / original.len() as f32
692    }
693
694    /// Estimate model size reduction from quantization
695    pub fn estimate_size_reduction(bit_width: u8) -> f32 {
696        32.0 / bit_width as f32
697    }
698
699    /// Simulate quantization performance gains
700    pub fn estimate_performance_gain(bit_width: u8) -> f32 {
701        // Empirical approximation based on common hardware
702        match bit_width {
703            8 => 2.0,  // ~2x speedup with INT8
704            4 => 4.0,  // ~4x speedup with INT4
705            1 => 16.0, // ~16x speedup with binary
706            _ => 1.0,
707        }
708    }
709
710    /// Convert between different quantization schemes
711    pub fn convert_quantization_scheme(
712        tensor: &QuantizedTensor,
713        target_scheme: QuantizationScheme,
714        target_bits: u8,
715    ) -> Result<QuantizedTensor> {
716        // First dequantize to float
717        let float_tensor = tensor.dequantize();
718
719        // Create new config with target scheme
720        let config = QuantizationConfig {
721            scheme: target_scheme,
722            bits: target_bits,
723            ..Default::default()
724        };
725
726        // Re-quantize with new scheme
727        QuantizedTensor::from_float(&float_tensor, &config)
728    }
729}
730
731#[cfg(test)]
732mod tests {
733    use super::*;
734    use ndarray::{array, Array2};
735    use ndarray_rand::rand::distributions::Standard;
736    use ndarray_rand::RandomExt;
737
738    #[test]
739    fn test_quantization_config_default() {
740        let config = QuantizationConfig::default();
741        assert_eq!(config.bits, 8);
742        assert!(config.signed);
743        assert_eq!(config.scheme, QuantizationScheme::Symmetric);
744    }
745
746    #[test]
747    fn test_quantization_params_creation() {
748        let params = QuantizationParams::new(8, true);
749        assert_eq!(params.bits, 8);
750        assert_eq!(params.qmin, -128);
751        assert_eq!(params.qmax, 127);
752    }
753
754    #[test]
755    fn test_symmetric_quantization() {
756        let tensor = array![[1.0, -1.0], [2.0, -2.0]].into_dyn();
757        let config = QuantizationConfig {
758            scheme: QuantizationScheme::Symmetric,
759            ..Default::default()
760        };
761
762        let quantized = QuantizedTensor::from_float(&tensor, &config).unwrap();
763        let _dequantized = quantized.dequantize();
764
765        // Check that quantization preserves approximate values
766        let error = utils::compute_quantization_error(&tensor, &quantized);
767        assert!(error < 0.1); // Small quantization error
768    }
769
770    #[test]
771    fn test_asymmetric_quantization() {
772        let tensor = array![[0.0, 1.0], [2.0, 3.0]].into_dyn();
773        let config = QuantizationConfig {
774            scheme: QuantizationScheme::Asymmetric,
775            ..Default::default()
776        };
777
778        let quantized = QuantizedTensor::from_float(&tensor, &config).unwrap();
779        let _dequantized = quantized.dequantize();
780
781        assert!(quantized.params.zero_point != 0); // Should have non-zero zero-point
782
783        let error = utils::compute_quantization_error(&tensor, &quantized);
784        assert!(error < 0.1);
785    }
786
787    #[test]
788    fn test_post_training_quantization() {
789        let mut ptq = PostTrainingQuantizer::new(QuantizationConfig::default());
790
791        // Add calibration data
792        let calib_data = Array2::random((100, 50), Standard).into_dyn();
793        ptq.add_calibration_data("layer1", &calib_data);
794
795        let params = ptq.finalize_calibration().unwrap();
796        assert!(params.contains_key("layer1"));
797    }
798
799    #[test]
800    fn test_quantization_aware_training() {
801        let mut qat = QuantizationAwareTraining::new(QuantizationConfig::default());
802        let tensor = Array2::ones((10, 10)).into_dyn();
803
804        qat.init_layer_params("layer1", &tensor).unwrap();
805        let fake_quantized = qat.fake_quantize("layer1", &tensor).unwrap();
806
807        assert_eq!(fake_quantized.shape(), tensor.shape());
808    }
809
810    #[test]
811    fn test_mixed_bitwidth_quantization() {
812        let mut mbq = MixedBitWidthQuantizer::new();
813
814        let mut outputs = HashMap::new();
815        outputs.insert(
816            "layer1".to_string(),
817            Array2::random((50, 50), Standard).into_dyn(),
818        );
819        outputs.insert("layer2".to_string(), Array2::ones((50, 50)).into_dyn());
820
821        mbq.analyze_sensitivity(&outputs).unwrap();
822
823        assert!(mbq.get_sensitivity_score("layer1").is_some());
824        assert!(mbq.get_layer_config("layer1").is_some());
825    }
826
827    #[test]
828    fn test_dynamic_quantization() {
829        let mut dq = DynamicQuantizer::new(QuantizationConfig::default());
830        let tensor = Array2::random((20, 20), Standard).into_dyn();
831
832        let quantized = dq.quantize(&tensor, Some("test_key")).unwrap();
833        assert_eq!(quantized.shape, tensor.shape().to_vec());
834
835        let (cache_size, _) = dq.cache_stats();
836        assert_eq!(cache_size, 1);
837    }
838
839    #[test]
840    fn test_quantization_utilities() {
841        let original = Array2::random((10, 10), Standard).into_dyn();
842        let quantized =
843            QuantizedTensor::from_float(&original, &QuantizationConfig::default()).unwrap();
844
845        let error = utils::compute_quantization_error(&original, &quantized);
846        assert!(error >= 0.0);
847
848        let size_reduction = utils::estimate_size_reduction(8);
849        assert_eq!(size_reduction, 4.0);
850
851        let perf_gain = utils::estimate_performance_gain(8);
852        assert_eq!(perf_gain, 2.0);
853    }
854
855    #[test]
856    fn test_compression_ratio() {
857        let tensor = Array2::ones((100, 100)).into_dyn();
858        let quantized =
859            QuantizedTensor::from_float(&tensor, &QuantizationConfig::default()).unwrap();
860
861        let ratio = quantized.compression_ratio();
862        assert!(ratio > 1.0); // Should be compressed
863    }
864
865    #[test]
866    fn test_power_of_two_quantization() {
867        let tensor = Array2::random((10, 10), Standard).into_dyn();
868        let config = QuantizationConfig {
869            scheme: QuantizationScheme::PowerOfTwo,
870            ..Default::default()
871        };
872
873        let quantized = QuantizedTensor::from_float(&tensor, &config).unwrap();
874
875        // Scale should be a power of 2
876        let scale_log2 = quantized.params.scale.log2();
877        assert!((scale_log2.round() - scale_log2).abs() < 1e-6);
878    }
879
880    #[test]
881    fn test_quantization_scheme_conversion() {
882        let tensor = Array2::random((10, 10), Standard).into_dyn();
883        let quantized =
884            QuantizedTensor::from_float(&tensor, &QuantizationConfig::default()).unwrap();
885
886        let converted =
887            utils::convert_quantization_scheme(&quantized, QuantizationScheme::Asymmetric, 4)
888                .unwrap();
889
890        assert_eq!(converted.params.bits, 4);
891    }
892}