voirs_cloning/quantization.rs

//! Model quantization for edge deployment
//!
//! This module provides INT4, INT8, INT16, and FP16 quantization support for model
//! deployment on edge devices. Quantization reduces memory usage and computational
//! requirements while maintaining quality.

use candle_core::{Device, Result as CandleResult, Tensor};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// Quantization configuration
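///
/// # Example
///
/// Building and validating a config; a minimal sketch (the module path
/// `voirs_cloning::quantization` is assumed here):
///
/// ```ignore
/// use voirs_cloning::quantization::QuantizationConfig;
///
/// let mut config = QuantizationConfig::mobile_optimized();
/// config.calibration_samples = 200; // more samples, better calibration
/// assert!(config.validate().is_ok());
/// ```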
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct QuantizationConfig {
    /// Target quantization precision
    pub precision: QuantizationPrecision,
    /// Quantization method to use
    pub method: QuantizationMethod,
    /// Calibration dataset size for post-training quantization
    pub calibration_samples: usize,
    /// Enable dynamic quantization
    pub dynamic_quantization: bool,
    /// Percentile for outlier clipping (0.01 = 99th percentile)
    pub outlier_percentile: f32,
    /// Layer-specific quantization settings
    pub layer_configs: HashMap<String, LayerQuantizationConfig>,
    /// Enable quantization-aware training mode
    pub quantization_aware_training: bool,
}

impl Default for QuantizationConfig {
    fn default() -> Self {
        Self {
            precision: QuantizationPrecision::Int8,
            method: QuantizationMethod::PostTrainingQuantization,
            calibration_samples: 100,
            dynamic_quantization: false,
            outlier_percentile: 0.01,
            layer_configs: HashMap::new(),
            quantization_aware_training: false,
        }
    }
}

impl QuantizationConfig {
    /// Create a new quantization config
    pub fn new() -> Self {
        Self::default()
    }

    /// Create a config optimized for mobile deployment
    pub fn mobile_optimized() -> Self {
        Self {
            precision: QuantizationPrecision::Int8,
            method: QuantizationMethod::PostTrainingQuantization,
            calibration_samples: 50,
            dynamic_quantization: true,
            outlier_percentile: 0.005,
            layer_configs: HashMap::new(),
            quantization_aware_training: false,
        }
    }

    /// Create a config optimized for edge devices with severe memory constraints
    pub fn edge_optimized() -> Self {
        let mut layer_configs = HashMap::new();
        // Quantize embedding layers more aggressively
        layer_configs.insert(
            "embedding".to_string(),
            LayerQuantizationConfig {
                precision: QuantizationPrecision::Int4,
                quantize_weights: true,
                quantize_activations: true,
                symmetric: true,
            },
        );

        Self {
            precision: QuantizationPrecision::Int8,
            method: QuantizationMethod::PostTrainingQuantization,
            calibration_samples: 25,
            dynamic_quantization: true,
            outlier_percentile: 0.001,
            layer_configs,
            quantization_aware_training: false,
        }
    }

    /// Validate the configuration
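    ///
    /// # Example
    ///
    /// The failure path for a zero-sample calibration set (a sketch; module
    /// path assumed):
    ///
    /// ```ignore
    /// use voirs_cloning::quantization::QuantizationConfig;
    ///
    /// let mut config = QuantizationConfig::default();
    /// config.calibration_samples = 0; // invalid: at least one sample required
    /// assert!(config.validate().is_err());
    /// ```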
    pub fn validate(&self) -> crate::Result<()> {
        if self.calibration_samples == 0 {
            return Err(crate::Error::Config(
                "Calibration samples must be greater than 0".to_string(),
            ));
        }

        if !(0.0..0.1).contains(&self.outlier_percentile) {
            return Err(crate::Error::Config(
                "Outlier percentile must be in [0.0, 0.1)".to_string(),
            ));
        }

        Ok(())
    }
}
104
105/// Quantization precision levels
106#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
107pub enum QuantizationPrecision {
108    /// 4-bit integer (extreme compression)
109    Int4,
110    /// 8-bit integer (standard quantization)
111    Int8,
112    /// 16-bit integer
113    Int16,
114    /// 16-bit floating point (half precision)
115    Float16,
116    /// Mixed precision (different precisions for different layers)
117    Mixed,
118}
119
120impl QuantizationPrecision {
121    /// Get the bits per parameter for this precision
122    pub fn bits_per_param(&self) -> u8 {
123        match self {
124            QuantizationPrecision::Int4 => 4,
125            QuantizationPrecision::Int8 => 8,
126            QuantizationPrecision::Int16 => 16,
127            QuantizationPrecision::Float16 => 16,
128            QuantizationPrecision::Mixed => 8, // Average estimate
129        }
130    }
131
132    /// Get memory reduction ratio compared to FP32
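    ///
    /// # Example
    ///
    /// FP32 weights use 32 bits per parameter, so INT8 yields a 4x and FP16
    /// a 2x reduction (a sketch; module path assumed):
    ///
    /// ```ignore
    /// use voirs_cloning::quantization::QuantizationPrecision;
    ///
    /// assert_eq!(QuantizationPrecision::Int8.memory_reduction_ratio(), 4.0);
    /// assert_eq!(QuantizationPrecision::Float16.memory_reduction_ratio(), 2.0);
    /// ```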
    pub fn memory_reduction_ratio(&self) -> f32 {
        32.0 / self.bits_per_param() as f32
    }
}

/// Quantization methods
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationMethod {
    /// Post-training quantization (no retraining required)
    PostTrainingQuantization,
    /// Quantization-aware training (requires fine-tuning)
    QuantizationAwareTraining,
    /// Dynamic quantization (runtime quantization)
    DynamicQuantization,
    /// Knowledge distillation with quantization
    KnowledgeDistillation,
}

/// Layer-specific quantization configuration
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct LayerQuantizationConfig {
    /// Precision for this layer
    pub precision: QuantizationPrecision,
    /// Whether to quantize weights
    pub quantize_weights: bool,
    /// Whether to quantize activations
    pub quantize_activations: bool,
    /// Use symmetric quantization
    pub symmetric: bool,
}

impl Default for LayerQuantizationConfig {
    fn default() -> Self {
        Self {
            precision: QuantizationPrecision::Int8,
            quantize_weights: true,
            quantize_activations: true,
            symmetric: false,
        }
    }
}

/// Quantization statistics for calibration
#[derive(Debug, Clone)]
pub struct QuantizationStats {
    /// Minimum value observed
    pub min_val: f32,
    /// Maximum value observed
    pub max_val: f32,
    /// Mean value
    pub mean: f32,
    /// Standard deviation
    pub std: f32,
    /// Number of samples
    pub num_samples: usize,
}

impl QuantizationStats {
    /// Create new empty stats
    pub fn new() -> Self {
        Self {
            min_val: f32::INFINITY,
            max_val: f32::NEG_INFINITY,
            mean: 0.0,
            std: 0.0,
            num_samples: 0,
        }
    }

    /// Update stats with new tensor values
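    ///
    /// # Example
    ///
    /// Feeding one tensor through the collector (a sketch; the module path
    /// `voirs_cloning::quantization` is assumed):
    ///
    /// ```ignore
    /// use candle_core::{Device, Tensor};
    /// use voirs_cloning::quantization::QuantizationStats;
    ///
    /// let tensor = Tensor::from_slice(&[-1.0f32, 0.0, 2.0], (3,), &Device::Cpu).unwrap();
    /// let mut stats = QuantizationStats::new();
    /// stats.update(&tensor).unwrap();
    /// assert_eq!(stats.min_val, -1.0);
    /// assert_eq!(stats.max_val, 2.0);
    /// assert_eq!(stats.num_samples, 3);
    /// ```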
    pub fn update(&mut self, tensor: &Tensor) -> CandleResult<()> {
        let flat_tensor = tensor.flatten_all()?;
        let values: Vec<f32> = flat_tensor.to_vec1()?;
        if values.is_empty() {
            return Ok(());
        }

        for &val in &values {
            self.min_val = self.min_val.min(val);
            self.max_val = self.max_val.max(val);
        }

        // Merge the batch statistics into the running mean and variance using
        // Chan et al.'s parallel combination formula (exact, unlike a naive
        // pooling of per-batch deviations).
        let old_count = self.num_samples as f32;
        let batch_count = values.len() as f32;
        self.num_samples += values.len();
        let total = self.num_samples as f32;

        let batch_mean = values.iter().sum::<f32>() / batch_count;
        let batch_m2: f32 = values.iter().map(|&x| (x - batch_mean).powi(2)).sum();

        let delta = batch_mean - self.mean;
        let old_m2 = self.std.powi(2) * old_count;

        self.mean += delta * batch_count / total;
        self.std =
            ((old_m2 + batch_m2 + delta.powi(2) * old_count * batch_count / total) / total).sqrt();

        Ok(())
    }

    /// Get quantization scale and zero point for given precision
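    ///
    /// # Example
    ///
    /// For values observed in `[-1.0, 0.0]`, asymmetric INT8 gives
    /// `scale = (max - min) / 255 ≈ 0.00392` and
    /// `zero_point = round(0 - min / scale) = 255` (a sketch; module path
    /// assumed):
    ///
    /// ```ignore
    /// use voirs_cloning::quantization::{QuantizationPrecision, QuantizationStats};
    ///
    /// let mut stats = QuantizationStats::new();
    /// stats.min_val = -1.0;
    /// stats.max_val = 0.0;
    /// let (scale, zero_point) =
    ///     stats.get_quantization_params(QuantizationPrecision::Int8, false);
    /// assert!((scale - 1.0 / 255.0).abs() < 1e-6);
    /// assert_eq!(zero_point, 255);
    /// ```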
    pub fn get_quantization_params(
        &self,
        precision: QuantizationPrecision,
        symmetric: bool,
    ) -> (f32, i32) {
        let (min_quant, max_quant) = match precision {
            QuantizationPrecision::Int4 => {
                if symmetric {
                    (-8, 7)
                } else {
                    (0, 15)
                }
            }
            QuantizationPrecision::Int8 => {
                if symmetric {
                    (-128, 127)
                } else {
                    (0, 255)
                }
            }
            QuantizationPrecision::Int16 => {
                if symmetric {
                    (-32768, 32767)
                } else {
                    (0, 65535)
                }
            }
            _ => (0, 255), // Fallback to Int8
        };

        if symmetric {
            let abs_max = self.max_val.abs().max(self.min_val.abs());
            // Guard against a zero scale when all observed values are zero
            let scale = (abs_max / max_quant as f32).max(f32::EPSILON);
            (scale, 0) // Zero point is 0 for symmetric quantization
        } else {
            let scale = ((self.max_val - self.min_val) / (max_quant - min_quant) as f32)
                .max(f32::EPSILON);
            let zero_point = (min_quant as f32 - self.min_val / scale).round() as i32;
            (scale, zero_point.clamp(min_quant, max_quant))
        }
    }
}

impl Default for QuantizationStats {
    fn default() -> Self {
        Self::new()
    }
}

/// Model quantizer
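///
/// # Example
///
/// A minimal calibrate-then-quantize sketch (the module path
/// `voirs_cloning::quantization` and the layer name "encoder" are
/// illustrative assumptions):
///
/// ```ignore
/// use candle_core::{Device, Tensor};
/// use voirs_cloning::quantization::{
///     ModelQuantizer, QuantizationConfig, QuantizationPrecision,
/// };
///
/// let device = Device::Cpu;
/// let mut quantizer =
///     ModelQuantizer::new(QuantizationConfig::default(), device.clone()).unwrap();
///
/// // Collect value ranges from representative tensors.
/// let weights = Tensor::from_slice(&[0.5f32, -0.25, 0.75, -1.0], (4,), &device).unwrap();
/// quantizer.start_calibration();
/// quantizer.calibrate("encoder", &weights).unwrap();
/// quantizer.finish_calibration();
///
/// // Quantize using the calibrated statistics.
/// let q = quantizer
///     .quantize_tensor(&weights, "encoder", QuantizationPrecision::Int8)
///     .unwrap();
/// ```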
#[derive(Debug)]
pub struct ModelQuantizer {
    /// Quantization configuration
    config: QuantizationConfig,
    /// Statistics collector for calibration
    stats_collector: HashMap<String, QuantizationStats>,
    /// Device for computations
    device: Device,
    /// Whether calibration is active
    calibration_active: bool,
}

impl ModelQuantizer {
    /// Create a new model quantizer
    pub fn new(config: QuantizationConfig, device: Device) -> crate::Result<Self> {
        config.validate()?;

        Ok(Self {
            config,
            stats_collector: HashMap::new(),
            device,
            calibration_active: false,
        })
    }

    /// Get quantization config
    pub fn config(&self) -> &QuantizationConfig {
        &self.config
    }

    /// Start calibration phase
    pub fn start_calibration(&mut self) {
        self.calibration_active = true;
        self.stats_collector.clear();
    }

    /// Finish calibration phase
    pub fn finish_calibration(&mut self) {
        self.calibration_active = false;
    }

    /// Calibrate with a tensor (collect statistics)
    pub fn calibrate(&mut self, layer_name: &str, tensor: &Tensor) -> CandleResult<()> {
        if !self.calibration_active {
            return Ok(());
        }

        let stats = self
            .stats_collector
            .entry(layer_name.to_string())
            .or_default();

        stats.update(tensor)?;
        Ok(())
    }

    /// Quantize a tensor to specified precision
    pub fn quantize_tensor(
        &self,
        tensor: &Tensor,
        layer_name: &str,
        precision: QuantizationPrecision,
    ) -> CandleResult<QuantizedTensor> {
        let layer_config = self
            .config
            .layer_configs
            .get(layer_name)
            .cloned()
            .unwrap_or_default();

        let stats = self.stats_collector.get(layer_name);

        match precision {
            QuantizationPrecision::Int8 => {
                self.quantize_int8(tensor, stats, layer_config.symmetric)
            }
            QuantizationPrecision::Int4 => {
                self.quantize_int4(tensor, stats, layer_config.symmetric)
            }
            QuantizationPrecision::Float16 => self.quantize_float16(tensor),
            QuantizationPrecision::Int16 => {
                self.quantize_int16(tensor, stats, layer_config.symmetric)
            }
            QuantizationPrecision::Mixed => {
                // Default to Int8 for mixed precision
                self.quantize_int8(tensor, stats, layer_config.symmetric)
            }
        }
    }

    /// Quantize to INT8
    fn quantize_int8(
        &self,
        tensor: &Tensor,
        stats: Option<&QuantizationStats>,
        symmetric: bool,
    ) -> CandleResult<QuantizedTensor> {
        let (scale, zero_point) = if let Some(stats) = stats {
            stats.get_quantization_params(QuantizationPrecision::Int8, symmetric)
        } else {
            // Fallback to dynamic quantization
            self.compute_dynamic_quantization_params(
                tensor,
                QuantizationPrecision::Int8,
                symmetric,
            )?
        };

        let scale_tensor = Tensor::new(&[scale], tensor.device())?.broadcast_as(tensor.shape())?;
        let quantized = if symmetric {
            ((tensor / scale_tensor)?.round()?.clamp(-128.0, 127.0)?)
                .to_dtype(candle_core::DType::I64)?
        } else {
            // The zero-point tensor must be F32 to match the input dtype
            let zero_tensor =
                Tensor::new(&[zero_point as f32], tensor.device())?.broadcast_as(tensor.shape())?;
            (((tensor / scale_tensor)? + zero_tensor)?
                .round()?
                .clamp(0.0, 255.0)?)
            .to_dtype(candle_core::DType::I64)?
        };

        Ok(QuantizedTensor {
            data: quantized,
            scale,
            zero_point,
            precision: QuantizationPrecision::Int8,
            symmetric,
            original_shape: tensor.shape().clone(),
        })
    }

    /// Quantize to INT4
    fn quantize_int4(
        &self,
        tensor: &Tensor,
        stats: Option<&QuantizationStats>,
        symmetric: bool,
    ) -> CandleResult<QuantizedTensor> {
        let (scale, zero_point) = if let Some(stats) = stats {
            stats.get_quantization_params(QuantizationPrecision::Int4, symmetric)
        } else {
            self.compute_dynamic_quantization_params(
                tensor,
                QuantizationPrecision::Int4,
                symmetric,
            )?
        };

        let scale_tensor = Tensor::new(&[scale], tensor.device())?.broadcast_as(tensor.shape())?;
        let quantized = if symmetric {
            ((tensor / scale_tensor)?.round()?.clamp(-8.0, 7.0)?)
                .to_dtype(candle_core::DType::I64)?
        } else {
            // The zero-point tensor must be F32 to match the input dtype
            let zero_tensor =
                Tensor::new(&[zero_point as f32], tensor.device())?.broadcast_as(tensor.shape())?;
            (((tensor / scale_tensor)? + zero_tensor)?
                .round()?
                .clamp(0.0, 15.0)?)
            .to_dtype(candle_core::DType::I64)?
        };

        Ok(QuantizedTensor {
            data: quantized,
            scale,
            zero_point,
            precision: QuantizationPrecision::Int4,
            symmetric,
            original_shape: tensor.shape().clone(),
        })
    }

    /// Quantize to INT16
    fn quantize_int16(
        &self,
        tensor: &Tensor,
        stats: Option<&QuantizationStats>,
        symmetric: bool,
    ) -> CandleResult<QuantizedTensor> {
        let (scale, zero_point) = if let Some(stats) = stats {
            stats.get_quantization_params(QuantizationPrecision::Int16, symmetric)
        } else {
            self.compute_dynamic_quantization_params(
                tensor,
                QuantizationPrecision::Int16,
                symmetric,
            )?
        };

        let scale_tensor = Tensor::new(&[scale], tensor.device())?.broadcast_as(tensor.shape())?;
        let quantized = if symmetric {
            ((tensor / scale_tensor)?.round()?.clamp(-32768.0, 32767.0)?)
                .to_dtype(candle_core::DType::I64)?
        } else {
            // The zero-point tensor must be F32 to match the input dtype
            let zero_tensor =
                Tensor::new(&[zero_point as f32], tensor.device())?.broadcast_as(tensor.shape())?;
            (((tensor / scale_tensor)? + zero_tensor)?
                .round()?
                .clamp(0.0, 65535.0)?)
            .to_dtype(candle_core::DType::I64)?
        };

        Ok(QuantizedTensor {
            data: quantized,
            scale,
            zero_point,
            precision: QuantizationPrecision::Int16,
            symmetric,
            original_shape: tensor.shape().clone(),
        })
    }

    /// Quantize to Float16
    fn quantize_float16(&self, tensor: &Tensor) -> CandleResult<QuantizedTensor> {
        let quantized = tensor.to_dtype(candle_core::DType::F16)?;

        Ok(QuantizedTensor {
            data: quantized,
            scale: 1.0,
            zero_point: 0,
            precision: QuantizationPrecision::Float16,
            symmetric: true,
            original_shape: tensor.shape().clone(),
        })
    }

    /// Compute dynamic quantization parameters
    fn compute_dynamic_quantization_params(
        &self,
        tensor: &Tensor,
        precision: QuantizationPrecision,
        symmetric: bool,
    ) -> CandleResult<(f32, i32)> {
        // Flatten first so min/max reduce to scalars regardless of input rank
        let flat = tensor.flatten_all()?;
        let min_val = flat.min(0)?.to_vec0::<f32>()?;
        let max_val = flat.max(0)?.to_vec0::<f32>()?;

        let mut temp_stats = QuantizationStats::new();
        temp_stats.min_val = min_val;
        temp_stats.max_val = max_val;

        Ok(temp_stats.get_quantization_params(precision, symmetric))
    }

    /// Get quantization statistics summary
    pub fn get_stats_summary(&self) -> HashMap<String, QuantizationStatsSummary> {
        self.stats_collector
            .iter()
            .map(|(layer, stats)| {
                let summary = QuantizationStatsSummary {
                    layer_name: layer.clone(),
                    min_val: stats.min_val,
                    max_val: stats.max_val,
                    mean: stats.mean,
                    std: stats.std,
                    dynamic_range: stats.max_val - stats.min_val,
                    num_samples: stats.num_samples,
                };
                (layer.clone(), summary)
            })
            .collect()
    }

    /// Estimate memory savings from quantization
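    ///
    /// # Example
    ///
    /// For a 100 MB FP32 model under the default INT8 config, the estimate
    /// is a 25 MB model and a 75% saving (a sketch; module path assumed):
    ///
    /// ```ignore
    /// use candle_core::Device;
    /// use voirs_cloning::quantization::{ModelQuantizer, QuantizationConfig};
    ///
    /// let quantizer =
    ///     ModelQuantizer::new(QuantizationConfig::default(), Device::Cpu).unwrap();
    /// let analysis = quantizer.estimate_memory_savings(100.0);
    /// assert_eq!(analysis.quantized_size_mb, 25.0);
    /// assert_eq!(analysis.savings_percent, 75.0);
    /// ```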
    pub fn estimate_memory_savings(
        &self,
        original_model_size_mb: f32,
    ) -> QuantizationMemoryAnalysis {
        let reduction_ratio = self.config.precision.memory_reduction_ratio();
        let quantized_size_mb = original_model_size_mb / reduction_ratio;
        let savings_mb = original_model_size_mb - quantized_size_mb;
        let savings_percent = (savings_mb / original_model_size_mb) * 100.0;

        QuantizationMemoryAnalysis {
            original_size_mb: original_model_size_mb,
            quantized_size_mb,
            savings_mb,
            savings_percent,
            compression_ratio: reduction_ratio,
            precision: self.config.precision,
        }
    }
}

/// Quantized tensor representation
#[derive(Debug, Clone)]
pub struct QuantizedTensor {
    /// Quantized data
    pub data: Tensor,
    /// Quantization scale
    pub scale: f32,
    /// Zero point for asymmetric quantization
    pub zero_point: i32,
    /// Quantization precision
    pub precision: QuantizationPrecision,
    /// Whether quantization is symmetric
    pub symmetric: bool,
    /// Original tensor shape
    pub original_shape: candle_core::Shape,
}

impl QuantizedTensor {
    /// Dequantize back to float32
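    ///
    /// # Example
    ///
    /// A quantize/dequantize round trip; values are recovered to within
    /// roughly one quantization step (a sketch; module path and the layer
    /// name "layer" are assumptions):
    ///
    /// ```ignore
    /// use candle_core::{Device, Tensor};
    /// use voirs_cloning::quantization::{
    ///     ModelQuantizer, QuantizationConfig, QuantizationPrecision,
    /// };
    ///
    /// let device = Device::Cpu;
    /// let quantizer =
    ///     ModelQuantizer::new(QuantizationConfig::default(), device.clone()).unwrap();
    /// let tensor = Tensor::from_slice(&[0.1f32, -0.2, 0.3], (3,), &device).unwrap();
    ///
    /// let q = quantizer
    ///     .quantize_tensor(&tensor, "layer", QuantizationPrecision::Int8)
    ///     .unwrap();
    /// let restored = q.dequantize().unwrap();
    /// let diff: Vec<f32> = (&restored - &tensor).unwrap().abs().unwrap().to_vec1().unwrap();
    /// assert!(diff.iter().all(|&e| e <= q.scale));
    /// ```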
    pub fn dequantize(&self) -> CandleResult<Tensor> {
        match self.precision {
            QuantizationPrecision::Float16 => {
                // For FP16, just convert back to F32
                self.data.to_dtype(candle_core::DType::F32)
            }
            _ => {
                // For integer quantization, apply scale and zero point
                let float_data = self.data.to_dtype(candle_core::DType::F32)?;
                let scale_tensor = Tensor::new(&[self.scale], self.data.device())?
                    .broadcast_as(float_data.shape())?;
                if self.symmetric {
                    Ok((&float_data * scale_tensor)?)
                } else {
                    // Keep the zero-point tensor in F32 to match float_data
                    let zero_tensor = Tensor::new(&[self.zero_point as f32], self.data.device())?
                        .broadcast_as(float_data.shape())?;
                    Ok(((&float_data - zero_tensor)? * scale_tensor)?)
                }
            }
        }
    }

    /// Get memory usage of quantized tensor in bytes
    pub fn memory_usage_bytes(&self) -> usize {
        let num_elements = self.data.elem_count();
        match self.precision {
            // 4-bit values pack two elements per byte (ceiling division)
            QuantizationPrecision::Int4 => num_elements.div_ceil(2),
            QuantizationPrecision::Int8 => num_elements,
            QuantizationPrecision::Int16 | QuantizationPrecision::Float16 => num_elements * 2,
            QuantizationPrecision::Mixed => num_elements, // Average estimate
        }
    }
}

/// Summary statistics for quantization
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationStatsSummary {
    /// Layer name
    pub layer_name: String,
    /// Minimum value
    pub min_val: f32,
    /// Maximum value
    pub max_val: f32,
    /// Mean value
    pub mean: f32,
    /// Standard deviation
    pub std: f32,
    /// Dynamic range
    pub dynamic_range: f32,
    /// Number of samples used for calibration
    pub num_samples: usize,
}

/// Memory analysis for quantization
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationMemoryAnalysis {
    /// Original model size in MB
    pub original_size_mb: f32,
    /// Quantized model size in MB
    pub quantized_size_mb: f32,
    /// Memory savings in MB
    pub savings_mb: f32,
    /// Percentage savings
    pub savings_percent: f32,
    /// Compression ratio
    pub compression_ratio: f32,
    /// Quantization precision used
    pub precision: QuantizationPrecision,
}

/// Quantization result with performance metrics
#[derive(Debug, Clone)]
pub struct QuantizationResult {
    /// Quantized tensors by layer name
    pub quantized_tensors: HashMap<String, QuantizedTensor>,
    /// Memory analysis
    pub memory_analysis: QuantizationMemoryAnalysis,
    /// Statistics summary
    pub stats_summary: HashMap<String, QuantizationStatsSummary>,
    /// Quantization configuration used
    pub config: QuantizationConfig,
    /// Processing time in milliseconds
    pub processing_time_ms: u64,
}

#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::{DType, Device, Shape, Tensor};

    #[test]
    fn test_quantization_config_default() {
        let config = QuantizationConfig::default();
        assert_eq!(config.precision, QuantizationPrecision::Int8);
        assert_eq!(config.calibration_samples, 100);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_quantization_config_mobile() {
        let config = QuantizationConfig::mobile_optimized();
        assert_eq!(config.precision, QuantizationPrecision::Int8);
        assert_eq!(config.calibration_samples, 50);
        assert!(config.dynamic_quantization);
    }

    #[test]
    fn test_quantization_config_edge() {
        let config = QuantizationConfig::edge_optimized();
        assert_eq!(config.precision, QuantizationPrecision::Int8);
        assert_eq!(config.calibration_samples, 25);
        assert!(config.layer_configs.contains_key("embedding"));
    }

    #[test]
    fn test_quantization_precision_bits() {
        assert_eq!(QuantizationPrecision::Int4.bits_per_param(), 4);
        assert_eq!(QuantizationPrecision::Int8.bits_per_param(), 8);
        assert_eq!(QuantizationPrecision::Int16.bits_per_param(), 16);
        assert_eq!(QuantizationPrecision::Float16.bits_per_param(), 16);
    }

    #[test]
    fn test_quantization_precision_memory_reduction() {
        // FP32 to INT8 should give 4x reduction
        assert_eq!(QuantizationPrecision::Int8.memory_reduction_ratio(), 4.0);
        // FP32 to INT4 should give 8x reduction
        assert_eq!(QuantizationPrecision::Int4.memory_reduction_ratio(), 8.0);
        // FP32 to FP16 should give 2x reduction
        assert_eq!(QuantizationPrecision::Float16.memory_reduction_ratio(), 2.0);
    }

    #[test]
    fn test_quantization_stats() {
        let device = Device::Cpu;
        let data = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0], (5,), &device).unwrap();

        let mut stats = QuantizationStats::new();
        stats.update(&data).unwrap();

        assert_eq!(stats.min_val, 1.0);
        assert_eq!(stats.max_val, 5.0);
        assert_eq!(stats.num_samples, 5);

        let (scale, zero_point) = stats.get_quantization_params(QuantizationPrecision::Int8, false);
        assert!(scale > 0.0);
        assert!(zero_point >= 0 && zero_point <= 255);
    }

    #[test]
    fn test_model_quantizer_creation() {
        let config = QuantizationConfig::default();
        let device = Device::Cpu;

        let quantizer = ModelQuantizer::new(config, device);
        assert!(quantizer.is_ok());
    }

    #[test]
    fn test_quantized_tensor_memory_usage() {
        let device = Device::Cpu;
        let data = Tensor::zeros((100,), DType::I64, &device).unwrap();

        let quantized = QuantizedTensor {
            data,
            scale: 1.0,
            zero_point: 0,
            precision: QuantizationPrecision::Int8,
            symmetric: true,
            original_shape: Shape::from_dims(&[100]),
        };

        // 100 elements * 1 byte per element = 100 bytes
        assert_eq!(quantized.memory_usage_bytes(), 100);
    }

    #[test]
    fn test_quantized_tensor_memory_usage_int4() {
        let device = Device::Cpu;
        let data = Tensor::zeros((100,), DType::I64, &device).unwrap();

        let quantized = QuantizedTensor {
            data,
            scale: 1.0,
            zero_point: 0,
            precision: QuantizationPrecision::Int4,
            symmetric: true,
            original_shape: Shape::from_dims(&[100]),
        };

        // 100 elements, packed 2 per byte = 50 bytes
        assert_eq!(quantized.memory_usage_bytes(), 50);
    }

    #[test]
    fn test_layer_quantization_config_default() {
        let config = LayerQuantizationConfig::default();
        assert_eq!(config.precision, QuantizationPrecision::Int8);
        assert!(config.quantize_weights);
        assert!(config.quantize_activations);
        assert!(!config.symmetric);
    }
}