1use candle_core::{Device, Result as CandleResult, Tensor};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
/// Top-level settings controlling how a model is quantized.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct QuantizationConfig {
    /// Default precision applied to layers without a per-layer override.
    pub precision: QuantizationPrecision,
    /// Strategy used to derive the quantization parameters.
    pub method: QuantizationMethod,
    /// Number of calibration samples to observe; must be > 0 (see `validate`).
    pub calibration_samples: usize,
    /// When true, (scale, zero_point) are computed per-tensor at quantization
    /// time instead of from calibration statistics.
    pub dynamic_quantization: bool,
    /// Fraction of extreme values treated as outliers; must lie between
    /// 0.0 and 0.1 (see `validate`). Not otherwise consumed in this file —
    /// presumably used by calibration logic elsewhere; TODO confirm.
    pub outlier_percentile: f32,
    /// Per-layer overrides keyed by layer name.
    pub layer_configs: HashMap<String, LayerQuantizationConfig>,
    /// Whether quantization-aware training is enabled. Only stored here;
    /// not read by this file.
    pub quantization_aware_training: bool,
}
28
29impl Default for QuantizationConfig {
30 fn default() -> Self {
31 Self {
32 precision: QuantizationPrecision::Int8,
33 method: QuantizationMethod::PostTrainingQuantization,
34 calibration_samples: 100,
35 dynamic_quantization: false,
36 outlier_percentile: 0.01,
37 layer_configs: HashMap::new(),
38 quantization_aware_training: false,
39 }
40 }
41}
42
43impl QuantizationConfig {
44 pub fn new() -> Self {
46 Self::default()
47 }
48
49 pub fn mobile_optimized() -> Self {
51 Self {
52 precision: QuantizationPrecision::Int8,
53 method: QuantizationMethod::PostTrainingQuantization,
54 calibration_samples: 50,
55 dynamic_quantization: true,
56 outlier_percentile: 0.005,
57 layer_configs: HashMap::new(),
58 quantization_aware_training: false,
59 }
60 }
61
62 pub fn edge_optimized() -> Self {
64 let mut layer_configs = HashMap::new();
65 layer_configs.insert(
67 "embedding".to_string(),
68 LayerQuantizationConfig {
69 precision: QuantizationPrecision::Int4,
70 quantize_weights: true,
71 quantize_activations: true,
72 symmetric: true,
73 },
74 );
75
76 Self {
77 precision: QuantizationPrecision::Int8,
78 method: QuantizationMethod::PostTrainingQuantization,
79 calibration_samples: 25,
80 dynamic_quantization: true,
81 outlier_percentile: 0.001,
82 layer_configs,
83 quantization_aware_training: false,
84 }
85 }
86
87 pub fn validate(&self) -> crate::Result<()> {
89 if self.calibration_samples == 0 {
90 return Err(crate::Error::Config(
91 "Calibration samples must be greater than 0".to_string(),
92 ));
93 }
94
95 if !(0.0..0.1).contains(&self.outlier_percentile) {
96 return Err(crate::Error::Config(
97 "Outlier percentile must be between 0.0 and 0.1".to_string(),
98 ));
99 }
100
101 Ok(())
102 }
103}
104
/// Numeric precision targets supported by the quantizer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum QuantizationPrecision {
    /// 4-bit integers (two values packed per byte when stored).
    Int4,
    /// 8-bit integers.
    Int8,
    /// 16-bit integers.
    Int16,
    /// IEEE-754 half-precision floating point.
    Float16,
    /// Mixed precision; this file accounts for it as 8 bits per parameter
    /// and quantizes it through the Int8 path.
    Mixed,
}
119
120impl QuantizationPrecision {
121 pub fn bits_per_param(&self) -> u8 {
123 match self {
124 QuantizationPrecision::Int4 => 4,
125 QuantizationPrecision::Int8 => 8,
126 QuantizationPrecision::Int16 => 16,
127 QuantizationPrecision::Float16 => 16,
128 QuantizationPrecision::Mixed => 8, }
130 }
131
132 pub fn memory_reduction_ratio(&self) -> f32 {
134 32.0 / self.bits_per_param() as f32
135 }
136}
137
/// Strategy used to derive quantization parameters.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationMethod {
    /// Quantize a fully trained model using calibration statistics.
    PostTrainingQuantization,
    /// Simulate quantization during training.
    QuantizationAwareTraining,
    /// Derive parameters per-tensor at inference time.
    DynamicQuantization,
    /// Train a smaller/quantized student against a full-precision teacher.
    KnowledgeDistillation,
}
150
/// Per-layer override of the global quantization settings.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct LayerQuantizationConfig {
    /// Precision for this layer. NOTE(review): `quantize_tensor` currently
    /// uses its own `precision` argument, not this field — confirm intent.
    pub precision: QuantizationPrecision,
    /// Whether the layer's weights should be quantized.
    pub quantize_weights: bool,
    /// Whether the layer's activations should be quantized.
    pub quantize_activations: bool,
    /// Symmetric (zero_point = 0) vs. asymmetric quantization scheme.
    pub symmetric: bool,
}
163
164impl Default for LayerQuantizationConfig {
165 fn default() -> Self {
166 Self {
167 precision: QuantizationPrecision::Int8,
168 quantize_weights: true,
169 quantize_activations: true,
170 symmetric: false,
171 }
172 }
173}
174
/// Running statistics accumulated over calibration tensors for one layer.
#[derive(Debug, Clone)]
pub struct QuantizationStats {
    /// Smallest value observed so far (+inf before any update).
    pub min_val: f32,
    /// Largest value observed so far (-inf before any update).
    pub max_val: f32,
    /// Running mean over all observed values.
    pub mean: f32,
    /// Running (population) standard deviation over all observed values.
    pub std: f32,
    /// Total number of scalar values folded in so far.
    pub num_samples: usize,
}
189
190impl QuantizationStats {
191 pub fn new() -> Self {
193 Self {
194 min_val: f32::INFINITY,
195 max_val: f32::NEG_INFINITY,
196 mean: 0.0,
197 std: 0.0,
198 num_samples: 0,
199 }
200 }
201
202 pub fn update(&mut self, tensor: &Tensor) -> CandleResult<()> {
204 let flat_tensor = tensor.flatten_all()?;
205 let values: Vec<f32> = flat_tensor.to_vec1()?;
206
207 for &val in &values {
208 self.min_val = self.min_val.min(val);
209 self.max_val = self.max_val.max(val);
210 }
211
212 let old_count = self.num_samples;
214 self.num_samples += values.len();
215
216 let old_mean = self.mean;
218 let sum: f32 = values.iter().sum();
219 self.mean = (old_mean * old_count as f32 + sum) / self.num_samples as f32;
220
221 let sum_sq_diff: f32 = values.iter().map(|&x| (x - self.mean).powi(2)).sum();
223 let old_sum_sq = self.std.powi(2) * old_count as f32;
224 self.std = ((old_sum_sq + sum_sq_diff) / self.num_samples as f32).sqrt();
225
226 Ok(())
227 }
228
229 pub fn get_quantization_params(
231 &self,
232 precision: QuantizationPrecision,
233 symmetric: bool,
234 ) -> (f32, i32) {
235 let (min_quant, max_quant) = match precision {
236 QuantizationPrecision::Int4 => {
237 if symmetric {
238 (-8, 7)
239 } else {
240 (0, 15)
241 }
242 }
243 QuantizationPrecision::Int8 => {
244 if symmetric {
245 (-128, 127)
246 } else {
247 (0, 255)
248 }
249 }
250 QuantizationPrecision::Int16 => {
251 if symmetric {
252 (-32768, 32767)
253 } else {
254 (0, 65535)
255 }
256 }
257 _ => (0, 255), };
259
260 if symmetric {
261 let abs_max = self.max_val.abs().max(self.min_val.abs());
262 let scale = abs_max / max_quant as f32;
263 (scale, 0) } else {
265 let scale = (self.max_val - self.min_val) / (max_quant - min_quant) as f32;
266 let zero_point = (min_quant as f32 - self.min_val / scale).round() as i32;
267 (scale, zero_point.clamp(min_quant, max_quant))
268 }
269 }
270}
271
272impl Default for QuantizationStats {
273 fn default() -> Self {
274 Self::new()
275 }
276}
277
/// Orchestrates calibration and tensor quantization for a whole model.
#[derive(Debug)]
pub struct ModelQuantizer {
    // Validated quantization settings.
    config: QuantizationConfig,
    // Per-layer running statistics gathered while calibration is active.
    stats_collector: HashMap<String, QuantizationStats>,
    // Target device. Stored but not read elsewhere in this file —
    // presumably for future tensor placement; TODO confirm.
    device: Device,
    // True between start_calibration() and finish_calibration().
    calibration_active: bool,
}
290
291impl ModelQuantizer {
292 pub fn new(config: QuantizationConfig, device: Device) -> crate::Result<Self> {
294 config.validate()?;
295
296 Ok(Self {
297 config,
298 stats_collector: HashMap::new(),
299 device,
300 calibration_active: false,
301 })
302 }
303
304 pub fn config(&self) -> &QuantizationConfig {
306 &self.config
307 }
308
309 pub fn start_calibration(&mut self) {
311 self.calibration_active = true;
312 self.stats_collector.clear();
313 }
314
315 pub fn finish_calibration(&mut self) {
317 self.calibration_active = false;
318 }
319
320 pub fn calibrate(&mut self, layer_name: &str, tensor: &Tensor) -> CandleResult<()> {
322 if !self.calibration_active {
323 return Ok(());
324 }
325
326 let stats = self
327 .stats_collector
328 .entry(layer_name.to_string())
329 .or_default();
330
331 stats.update(tensor)?;
332 Ok(())
333 }
334
335 pub fn quantize_tensor(
337 &self,
338 tensor: &Tensor,
339 layer_name: &str,
340 precision: QuantizationPrecision,
341 ) -> CandleResult<QuantizedTensor> {
342 let layer_config = self
343 .config
344 .layer_configs
345 .get(layer_name)
346 .cloned()
347 .unwrap_or_default();
348
349 let stats = self.stats_collector.get(layer_name);
350
351 match precision {
352 QuantizationPrecision::Int8 => {
353 self.quantize_int8(tensor, stats, layer_config.symmetric)
354 }
355 QuantizationPrecision::Int4 => {
356 self.quantize_int4(tensor, stats, layer_config.symmetric)
357 }
358 QuantizationPrecision::Float16 => self.quantize_float16(tensor),
359 QuantizationPrecision::Int16 => {
360 self.quantize_int16(tensor, stats, layer_config.symmetric)
361 }
362 QuantizationPrecision::Mixed => {
363 self.quantize_int8(tensor, stats, layer_config.symmetric)
365 }
366 }
367 }
368
369 fn quantize_int8(
371 &self,
372 tensor: &Tensor,
373 stats: Option<&QuantizationStats>,
374 symmetric: bool,
375 ) -> CandleResult<QuantizedTensor> {
376 let (scale, zero_point) = if let Some(stats) = stats {
377 stats.get_quantization_params(QuantizationPrecision::Int8, symmetric)
378 } else {
379 self.compute_dynamic_quantization_params(
381 tensor,
382 QuantizationPrecision::Int8,
383 symmetric,
384 )?
385 };
386
387 let scale_tensor = Tensor::new(&[scale], tensor.device())?.broadcast_as(tensor.shape())?;
388 let quantized = if symmetric {
389 ((tensor / scale_tensor)?.round()?.clamp(-128.0, 127.0)?)
390 .to_dtype(candle_core::DType::I64)?
391 } else {
392 let zero_tensor =
393 Tensor::new(&[zero_point as f64], tensor.device())?.broadcast_as(tensor.shape())?;
394 (((tensor / scale_tensor)? + zero_tensor)?
395 .round()?
396 .clamp(0.0, 255.0)?)
397 .to_dtype(candle_core::DType::I64)?
398 };
399
400 Ok(QuantizedTensor {
401 data: quantized,
402 scale,
403 zero_point,
404 precision: QuantizationPrecision::Int8,
405 symmetric,
406 original_shape: tensor.shape().clone(),
407 })
408 }
409
410 fn quantize_int4(
412 &self,
413 tensor: &Tensor,
414 stats: Option<&QuantizationStats>,
415 symmetric: bool,
416 ) -> CandleResult<QuantizedTensor> {
417 let (scale, zero_point) = if let Some(stats) = stats {
418 stats.get_quantization_params(QuantizationPrecision::Int4, symmetric)
419 } else {
420 self.compute_dynamic_quantization_params(
421 tensor,
422 QuantizationPrecision::Int4,
423 symmetric,
424 )?
425 };
426
427 let scale_tensor = Tensor::new(&[scale], tensor.device())?.broadcast_as(tensor.shape())?;
428 let quantized = if symmetric {
429 ((tensor / scale_tensor)?.round()?.clamp(-8.0, 7.0)?)
430 .to_dtype(candle_core::DType::I64)?
431 } else {
432 let zero_tensor =
433 Tensor::new(&[zero_point as f64], tensor.device())?.broadcast_as(tensor.shape())?;
434 (((tensor / scale_tensor)? + zero_tensor)?
435 .round()?
436 .clamp(0.0, 15.0)?)
437 .to_dtype(candle_core::DType::I64)?
438 };
439
440 Ok(QuantizedTensor {
441 data: quantized,
442 scale,
443 zero_point,
444 precision: QuantizationPrecision::Int4,
445 symmetric,
446 original_shape: tensor.shape().clone(),
447 })
448 }
449
450 fn quantize_int16(
452 &self,
453 tensor: &Tensor,
454 stats: Option<&QuantizationStats>,
455 symmetric: bool,
456 ) -> CandleResult<QuantizedTensor> {
457 let (scale, zero_point) = if let Some(stats) = stats {
458 stats.get_quantization_params(QuantizationPrecision::Int16, symmetric)
459 } else {
460 self.compute_dynamic_quantization_params(
461 tensor,
462 QuantizationPrecision::Int16,
463 symmetric,
464 )?
465 };
466
467 let scale_tensor = Tensor::new(&[scale], tensor.device())?.broadcast_as(tensor.shape())?;
468 let quantized = if symmetric {
469 ((tensor / scale_tensor)?.round()?.clamp(-32768.0, 32767.0)?)
470 .to_dtype(candle_core::DType::I64)?
471 } else {
472 let zero_tensor =
473 Tensor::new(&[zero_point as f64], tensor.device())?.broadcast_as(tensor.shape())?;
474 (((tensor / scale_tensor)? + zero_tensor)?
475 .round()?
476 .clamp(0.0, 65535.0)?)
477 .to_dtype(candle_core::DType::I64)?
478 };
479
480 Ok(QuantizedTensor {
481 data: quantized,
482 scale,
483 zero_point,
484 precision: QuantizationPrecision::Int16,
485 symmetric,
486 original_shape: tensor.shape().clone(),
487 })
488 }
489
490 fn quantize_float16(&self, tensor: &Tensor) -> CandleResult<QuantizedTensor> {
492 let quantized = tensor.to_dtype(candle_core::DType::F16)?;
493
494 Ok(QuantizedTensor {
495 data: quantized,
496 scale: 1.0,
497 zero_point: 0,
498 precision: QuantizationPrecision::Float16,
499 symmetric: true,
500 original_shape: tensor.shape().clone(),
501 })
502 }
503
504 fn compute_dynamic_quantization_params(
506 &self,
507 tensor: &Tensor,
508 precision: QuantizationPrecision,
509 symmetric: bool,
510 ) -> CandleResult<(f32, i32)> {
511 let min_val = tensor.min(0)?.to_vec0::<f32>()?;
512 let max_val = tensor.max(0)?.to_vec0::<f32>()?;
513
514 let mut temp_stats = QuantizationStats::new();
515 temp_stats.min_val = min_val;
516 temp_stats.max_val = max_val;
517
518 Ok(temp_stats.get_quantization_params(precision, symmetric))
519 }
520
521 pub fn get_stats_summary(&self) -> HashMap<String, QuantizationStatsSummary> {
523 self.stats_collector
524 .iter()
525 .map(|(layer, stats)| {
526 let summary = QuantizationStatsSummary {
527 layer_name: layer.clone(),
528 min_val: stats.min_val,
529 max_val: stats.max_val,
530 mean: stats.mean,
531 std: stats.std,
532 dynamic_range: stats.max_val - stats.min_val,
533 num_samples: stats.num_samples,
534 };
535 (layer.clone(), summary)
536 })
537 .collect()
538 }
539
540 pub fn estimate_memory_savings(
542 &self,
543 original_model_size_mb: f32,
544 ) -> QuantizationMemoryAnalysis {
545 let reduction_ratio = self.config.precision.memory_reduction_ratio();
546 let quantized_size_mb = original_model_size_mb / reduction_ratio;
547 let savings_mb = original_model_size_mb - quantized_size_mb;
548 let savings_percent = (savings_mb / original_model_size_mb) * 100.0;
549
550 QuantizationMemoryAnalysis {
551 original_size_mb: original_model_size_mb,
552 quantized_size_mb,
553 savings_mb,
554 savings_percent,
555 compression_ratio: reduction_ratio,
556 precision: self.config.precision,
557 }
558 }
559}
560
/// A quantized tensor together with the parameters needed to dequantize it.
#[derive(Debug, Clone)]
pub struct QuantizedTensor {
    /// Quantized payload (I64 for integer schemes, F16 for Float16).
    pub data: Tensor,
    /// Multiplicative step between adjacent quantized levels.
    pub scale: f32,
    /// Integer offset for asymmetric schemes (0 when symmetric).
    pub zero_point: i32,
    /// Precision the data was quantized at.
    pub precision: QuantizationPrecision,
    /// Whether a symmetric (zero-centred) scheme was used.
    pub symmetric: bool,
    /// Shape of the tensor before quantization.
    pub original_shape: candle_core::Shape,
}
577
578impl QuantizedTensor {
579 pub fn dequantize(&self) -> CandleResult<Tensor> {
581 match self.precision {
582 QuantizationPrecision::Float16 => {
583 self.data.to_dtype(candle_core::DType::F32)
585 }
586 _ => {
587 let float_data = self.data.to_dtype(candle_core::DType::F32)?;
589 let scale_tensor = Tensor::new(&[self.scale], self.data.device())?
590 .broadcast_as(float_data.shape())?;
591 if self.symmetric {
592 Ok((&float_data * scale_tensor)?)
593 } else {
594 let zero_tensor = Tensor::new(&[self.zero_point as f64], self.data.device())?
595 .broadcast_as(float_data.shape())?;
596 Ok(((&float_data - zero_tensor)? * scale_tensor)?)
597 }
598 }
599 }
600 }
601
602 pub fn memory_usage_bytes(&self) -> usize {
604 let num_elements = self.data.elem_count();
605 let bytes_per_element = match self.precision {
606 QuantizationPrecision::Int4 => 1, QuantizationPrecision::Int8 => 1,
608 QuantizationPrecision::Int16 | QuantizationPrecision::Float16 => 2,
609 QuantizationPrecision::Mixed => 1, };
611
612 if self.precision == QuantizationPrecision::Int4 {
613 num_elements.div_ceil(2) } else {
615 num_elements * bytes_per_element
616 }
617 }
618}
619
/// Serializable per-layer snapshot of calibration statistics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationStatsSummary {
    /// Name of the layer these statistics were collected for.
    pub layer_name: String,
    /// Smallest observed value.
    pub min_val: f32,
    /// Largest observed value.
    pub max_val: f32,
    /// Mean of the observed values.
    pub mean: f32,
    /// Standard deviation of the observed values.
    pub std: f32,
    /// max_val - min_val.
    pub dynamic_range: f32,
    /// Number of scalar values observed.
    pub num_samples: usize,
}
638
/// Projected memory impact of quantizing a model at a given precision.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationMemoryAnalysis {
    /// Size of the unquantized (f32) model in MB.
    pub original_size_mb: f32,
    /// Projected size after quantization in MB.
    pub quantized_size_mb: f32,
    /// Absolute savings in MB.
    pub savings_mb: f32,
    /// Savings as a percentage of the original size.
    pub savings_percent: f32,
    /// original / quantized size ratio.
    pub compression_ratio: f32,
    /// Precision the projection was computed for.
    pub precision: QuantizationPrecision,
}
655
/// Aggregate output of a full model quantization run.
/// Only defined here; no code in this file constructs it — presumably
/// populated by a driver elsewhere.
#[derive(Debug, Clone)]
pub struct QuantizationResult {
    /// Quantized tensors keyed by layer name.
    pub quantized_tensors: HashMap<String, QuantizedTensor>,
    /// Projected memory savings for the run.
    pub memory_analysis: QuantizationMemoryAnalysis,
    /// Calibration statistics per layer.
    pub stats_summary: HashMap<String, QuantizationStatsSummary>,
    /// Configuration the run was performed with.
    pub config: QuantizationConfig,
    /// Wall-clock processing time in milliseconds.
    pub processing_time_ms: u64,
}
670
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::{DType, Device, Shape, Tensor};

    // Default config: Int8, 100 calibration samples, passes validation.
    #[test]
    fn test_quantization_config_default() {
        let config = QuantizationConfig::default();
        assert_eq!(config.precision, QuantizationPrecision::Int8);
        assert_eq!(config.calibration_samples, 100);
        assert!(config.validate().is_ok());
    }

    // Mobile preset: halved calibration budget, dynamic quantization on.
    #[test]
    fn test_quantization_config_mobile() {
        let config = QuantizationConfig::mobile_optimized();
        assert_eq!(config.precision, QuantizationPrecision::Int8);
        assert_eq!(config.calibration_samples, 50);
        assert!(config.dynamic_quantization);
    }

    // Edge preset: leanest budget plus an Int4 override for "embedding".
    #[test]
    fn test_quantization_config_edge() {
        let config = QuantizationConfig::edge_optimized();
        assert_eq!(config.precision, QuantizationPrecision::Int8);
        assert_eq!(config.calibration_samples, 25);
        assert!(config.layer_configs.contains_key("embedding"));
    }

    // Bit widths per precision variant.
    #[test]
    fn test_quantization_precision_bits() {
        assert_eq!(QuantizationPrecision::Int4.bits_per_param(), 4);
        assert_eq!(QuantizationPrecision::Int8.bits_per_param(), 8);
        assert_eq!(QuantizationPrecision::Int16.bits_per_param(), 16);
        assert_eq!(QuantizationPrecision::Float16.bits_per_param(), 16);
    }

    // Reduction ratio is 32 / bits_per_param.
    #[test]
    fn test_quantization_precision_memory_reduction() {
        assert_eq!(QuantizationPrecision::Int8.memory_reduction_ratio(), 4.0);
        assert_eq!(QuantizationPrecision::Int4.memory_reduction_ratio(), 8.0);
        assert_eq!(QuantizationPrecision::Float16.memory_reduction_ratio(), 2.0);
    }

    // Running stats track exact min/max/count; derived asymmetric Int8
    // parameters must have a positive scale and an in-range zero point.
    #[test]
    fn test_quantization_stats() {
        let device = Device::Cpu;
        let data = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0], (5,), &device).unwrap();

        let mut stats = QuantizationStats::new();
        stats.update(&data).unwrap();

        assert_eq!(stats.min_val, 1.0);
        assert_eq!(stats.max_val, 5.0);
        assert_eq!(stats.num_samples, 5);

        let (scale, zero_point) = stats.get_quantization_params(QuantizationPrecision::Int8, false);
        assert!(scale > 0.0);
        assert!(zero_point >= 0 && zero_point <= 255);
    }

    // Construction succeeds for a valid config.
    #[test]
    fn test_model_quantizer_creation() {
        let config = QuantizationConfig::default();
        let device = Device::Cpu;

        let quantizer = ModelQuantizer::new(config, device);
        assert!(quantizer.is_ok());
    }

    // Int8: one byte per element.
    #[test]
    fn test_quantized_tensor_memory_usage() {
        let device = Device::Cpu;
        let data = Tensor::zeros((100,), DType::I64, &device).unwrap();

        let quantized = QuantizedTensor {
            data,
            scale: 1.0,
            zero_point: 0,
            precision: QuantizationPrecision::Int8,
            symmetric: true,
            original_shape: Shape::from_dims(&[100]),
        };

        assert_eq!(quantized.memory_usage_bytes(), 100);
    }

    // Int4: two elements packed per byte.
    #[test]
    fn test_quantized_tensor_memory_usage_int4() {
        let device = Device::Cpu;
        let data = Tensor::zeros((100,), DType::I64, &device).unwrap();

        let quantized = QuantizedTensor {
            data,
            scale: 1.0,
            zero_point: 0,
            precision: QuantizationPrecision::Int4,
            symmetric: true,
            original_shape: Shape::from_dims(&[100]),
        };

        assert_eq!(quantized.memory_usage_bytes(), 50);
    }

    // Per-layer defaults: asymmetric Int8 on weights and activations.
    #[test]
    fn test_layer_quantization_config_default() {
        let config = LayerQuantizationConfig::default();
        assert_eq!(config.precision, QuantizationPrecision::Int8);
        assert!(config.quantize_weights);
        assert!(config.quantize_activations);
        assert!(!config.symmetric);
    }
}