use half::f16;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs::File;
use std::io::BufReader;
use std::path::{Path, PathBuf};
use trustformers_core::errors::{invalid_config, runtime_error, tensor_op_error, Result};
use trustformers_core::Tensor;

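/// Quantization schemes supported by the mobile quantizers.
///
/// The `GGUF_*` variants mirror llama.cpp/GGUF block-quantization format
/// names; in this module they currently fall back to INT8 handling (see
/// `DynamicQuantizer`).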
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[allow(non_camel_case_types)]
pub enum QuantizationScheme {
    Int4,
    Int8,
    FP16,
    Dynamic,
    GGUF_Q2_K,
    GGUF_Q3_K,
    GGUF_Q4_K,
    GGUF_Q5_0,
    GGUF_Q6_K,
}

impl std::fmt::Display for QuantizationScheme {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            QuantizationScheme::Int4 => write!(f, "INT4"),
            QuantizationScheme::Int8 => write!(f, "INT8"),
            QuantizationScheme::FP16 => write!(f, "FP16"),
            QuantizationScheme::Dynamic => write!(f, "Dynamic"),
            QuantizationScheme::GGUF_Q2_K => write!(f, "GGUF_Q2_K"),
            QuantizationScheme::GGUF_Q3_K => write!(f, "GGUF_Q3_K"),
            QuantizationScheme::GGUF_Q4_K => write!(f, "GGUF_Q4_K"),
            QuantizationScheme::GGUF_Q5_0 => write!(f, "GGUF_Q5_0"),
            QuantizationScheme::GGUF_Q6_K => write!(f, "GGUF_Q6_K"),
        }
    }
}

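/// Strategy used to derive the quantization range from calibration data.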
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CalibrationMethod {
    MinMax,
    Percentile,
    MovingAverage,
    KLDivergence,
}

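/// Settings that control how calibration statistics are collected.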
#[derive(Debug, Clone)]
pub struct QuantizationContext {
    pub method: CalibrationMethod,
    pub num_calibration_samples: usize,
    pub percentile: f32,
    pub smooth_factor: f32,
}

impl Default for QuantizationContext {
    fn default() -> Self {
        Self {
            method: CalibrationMethod::MinMax,
            num_calibration_samples: 100,
            percentile: 99.9,
            smooth_factor: 0.999,
        }
    }
}

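/// Calibration statistics keyed by tensor identifier (the quantizers in this
/// module currently use the single key "global").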
#[derive(Debug, Clone, Default)]
pub struct QuantizationCalibration {
    pub min_values: HashMap<String, f32>,
    pub max_values: HashMap<String, f32>,
    pub scales: HashMap<String, f32>,
    pub zero_points: HashMap<String, i32>,
    pub histogram_bins: HashMap<String, Vec<f32>>,
}

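/// Scheme assignments at different granularities. Lookup precedence is
/// tensor > layer > model > default (see
/// `QuantizationSchemeStorage::determine_scheme`).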
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationSchemeConfig {
    pub default_scheme: QuantizationScheme,
    pub layer_schemes: HashMap<String, QuantizationScheme>,
    pub tensor_schemes: HashMap<String, QuantizationScheme>,
    pub model_schemes: HashMap<String, QuantizationScheme>,
    pub performance_schemes: HashMap<String, QuantizationScheme>,
}

impl Default for QuantizationSchemeConfig {
    fn default() -> Self {
        Self {
            default_scheme: QuantizationScheme::Int8,
            layer_schemes: HashMap::new(),
            tensor_schemes: HashMap::new(),
            model_schemes: HashMap::new(),
            performance_schemes: HashMap::new(),
        }
    }
}

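/// Persistable store of scheme assignments plus an in-memory cache of
/// resolved per-tensor decisions.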
#[derive(Debug, Clone)]
pub struct QuantizationSchemeStorage {
    pub config_path: Option<PathBuf>,
    pub config: QuantizationSchemeConfig,
    pub scheme_cache: HashMap<String, QuantizationScheme>,
}

impl Default for QuantizationSchemeStorage {
    fn default() -> Self {
        Self::new()
    }
}

impl QuantizationSchemeStorage {
    pub fn new() -> Self {
        Self {
            config_path: None,
            config: QuantizationSchemeConfig::default(),
            scheme_cache: HashMap::new(),
        }
    }

    pub fn with_config_file<P: AsRef<Path>>(path: P) -> Result<Self> {
        let config_path = path.as_ref().to_path_buf();
        let config = Self::load_config(&config_path)?;

        Ok(Self {
            config_path: Some(config_path),
            config,
            scheme_cache: HashMap::new(),
        })
    }

    pub fn load_config<P: AsRef<Path>>(path: P) -> Result<QuantizationSchemeConfig> {
        let file = File::open(path.as_ref())
            .map_err(|e| runtime_error(format!("Failed to open config file: {}", e)))?;
        let reader = BufReader::new(file);

        serde_json::from_reader(reader)
            .map_err(|e| invalid_config("load_config", format!("Failed to parse config: {}", e)))
    }

    pub fn save_config(&self) -> Result<()> {
        if let Some(ref path) = self.config_path {
            let file = File::create(path)
                .map_err(|e| runtime_error(format!("Failed to create config file: {}", e)))?;

            serde_json::to_writer_pretty(file, &self.config)
                .map_err(|e| runtime_error(format!("Failed to write config: {}", e)))?;
        }
        Ok(())
    }

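    /// Resolves the scheme for a tensor, checking (in order) the cache,
    /// per-tensor overrides, per-layer overrides, per-model overrides, and
    /// finally the configured default. The decision is memoized in
    /// `scheme_cache`.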
    pub fn determine_scheme(
        &mut self,
        tensor_id: &str,
        layer_name: Option<&str>,
        model_name: Option<&str>,
    ) -> QuantizationScheme {
        if let Some(&scheme) = self.scheme_cache.get(tensor_id) {
            return scheme;
        }

        if let Some(&scheme) = self.config.tensor_schemes.get(tensor_id) {
            self.scheme_cache.insert(tensor_id.to_string(), scheme);
            return scheme;
        }

        if let Some(layer) = layer_name {
            if let Some(&scheme) = self.config.layer_schemes.get(layer) {
                self.scheme_cache.insert(tensor_id.to_string(), scheme);
                return scheme;
            }
        }

        if let Some(model) = model_name {
            if let Some(&scheme) = self.config.model_schemes.get(model) {
                self.scheme_cache.insert(tensor_id.to_string(), scheme);
                return scheme;
            }
        }

        let default_scheme = self.config.default_scheme;
        self.scheme_cache.insert(tensor_id.to_string(), default_scheme);
        default_scheme
    }

    pub fn set_tensor_scheme(&mut self, tensor_id: String, scheme: QuantizationScheme) {
        self.config.tensor_schemes.insert(tensor_id.clone(), scheme);
        self.scheme_cache.insert(tensor_id, scheme);
    }

    pub fn set_layer_scheme(&mut self, layer_name: String, scheme: QuantizationScheme) {
        self.config.layer_schemes.insert(layer_name, scheme);
    }

    pub fn set_model_scheme(&mut self, model_name: String, scheme: QuantizationScheme) {
        self.config.model_schemes.insert(model_name, scheme);
    }

    pub fn clear_cache(&mut self) {
        self.scheme_cache.clear();
    }

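    /// Builds a stable identifier for a tensor from its (optional) layer
    /// name, its shape, and a cheap sampled hash of its contents.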
    pub fn generate_tensor_id(tensor: &Tensor, layer_name: Option<&str>) -> String {
        let shape_str = tensor.shape().iter().map(|&s| s.to_string()).collect::<Vec<_>>().join("x");

        let data_hash = {
            if let Ok(data) = tensor.data() {
                // Sample roughly 100 elements (step clamped to [1, 1000]) so
                // hashing stays cheap for large tensors.
                let step = (data.len() / 100).clamp(1, 1000);
                let mut hash = 0u64;
                for i in (0..data.len()).step_by(step) {
                    hash = hash.wrapping_mul(31).wrapping_add(data[i].to_bits() as u64);
                }
                hash
            } else {
                0u64
            }
        };

        match layer_name {
            Some(layer) => format!("{}:{}:{:x}", layer, shape_str, data_hash),
            None => format!("tensor:{}:{:x}", shape_str, data_hash),
        }
    }
}

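/// Common interface for the quantizers in this module.
///
/// `Send + Sync` is required so quantizers can be shared across threads;
/// interior mutability (e.g. the `RwLock`-guarded calibration state) keeps
/// `calibrate` callable through `&self`.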
pub trait MobileQuantizer: Send + Sync {
    /// The scheme this quantizer implements.
    fn get_scheme(&self) -> QuantizationScheme;

    /// Whether `calibrate` must be called before quantizing.
    fn requires_calibration(&self) -> bool;

    /// Collects range statistics from representative data.
    fn calibrate(&self, data: &[Tensor]) -> Result<()>;

    /// Quantizes a tensor according to this quantizer's scheme.
    fn quantize_tensor(&self, tensor: &Tensor) -> Result<Tensor>;

    /// Reconstructs an f32 tensor from its quantized representation.
    fn dequantize_tensor(&self, tensor: &Tensor) -> Result<Tensor>;
}

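/// Asymmetric 4-bit affine quantizer: q = clamp(round(x / s) + z, -8, 7),
/// with reconstruction x ≈ (q - z) * s.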
pub struct Int4Quantizer {
    context: QuantizationContext,
    calibration: std::sync::RwLock<QuantizationCalibration>,
}

impl Default for Int4Quantizer {
    fn default() -> Self {
        Self::new()
    }
}

impl Int4Quantizer {
    pub fn new() -> Self {
        Self {
            context: QuantizationContext::default(),
            calibration: std::sync::RwLock::new(QuantizationCalibration::default()),
        }
    }

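    /// Maps the observed range [min_val, max_val] onto the signed 4-bit range
    /// [-8, 7]: s = (max - min) / (qmax - qmin), z = round(qmin - min / s).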
    fn compute_scale_zero_point(&self, min_val: f32, max_val: f32) -> (f32, i32) {
        let qmin = -8.0;
        let qmax = 7.0;

        // Guard against degenerate (constant) tensors, where the naive scale
        // would be zero and the zero-point division would blow up.
        let range = max_val - min_val;
        let scale = if range > 0.0 { range / (qmax - qmin) } else { 1.0 };
        let zero_point = ((qmin - min_val / scale).round() as i32).clamp(-8, 7);

        (scale, zero_point)
    }

    fn quantize_value(&self, value: f32, scale: f32, zero_point: i32) -> i8 {
        let quantized = (value / scale).round() as i32 + zero_point;
        quantized.clamp(-8, 7) as i8
    }

    fn dequantize_value(&self, quantized: i8, scale: f32, zero_point: i32) -> f32 {
        (quantized as i32 - zero_point) as f32 * scale
    }
}

impl MobileQuantizer for Int4Quantizer {
    fn get_scheme(&self) -> QuantizationScheme {
        QuantizationScheme::Int4
    }

    fn requires_calibration(&self) -> bool {
        true
    }

    fn calibrate(&self, data: &[Tensor]) -> Result<()> {
        let mut calibration = self.calibration.write().expect("RwLock poisoned");

        for tensor in data {
            let tensor_data = tensor.data()?;
            let min_val = tensor_data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
            let max_val = tensor_data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));

            let (scale, zero_point) = self.compute_scale_zero_point(min_val, max_val);

            // Note: a single "global" entry is kept, so the statistics from
            // the last calibration tensor win.
            calibration.min_values.insert("global".to_string(), min_val);
            calibration.max_values.insert("global".to_string(), max_val);
            calibration.scales.insert("global".to_string(), scale);
            calibration.zero_points.insert("global".to_string(), zero_point);
        }

        Ok(())
    }

    fn quantize_tensor(&self, tensor: &Tensor) -> Result<Tensor> {
        let calibration = self.calibration.read().expect("RwLock poisoned");
        let tensor_data = tensor.data()?;

        let (scale, zero_point) = if let Some(&scale) = calibration.scales.get("global") {
            (
                scale,
                *calibration.zero_points.get("global").expect("No global zero point"),
            )
        } else {
            // No calibration available: fall back to this tensor's own range.
            let min_val = tensor_data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
            let max_val = tensor_data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
            self.compute_scale_zero_point(min_val, max_val)
        };

        let quantized_data: Vec<i8> =
            tensor_data.iter().map(|&x| self.quantize_value(x, scale, zero_point)).collect();

        // The 4-bit codes are materialized as f32 (one value per element)
        // rather than packed, matching the f32-only `Tensor` usage here.
        let quantized_f32: Vec<f32> = quantized_data.iter().map(|&x| x as f32).collect();

        let quantized_tensor = Tensor::from_vec(quantized_f32, &tensor.shape())?;

        Ok(quantized_tensor)
    }

    fn dequantize_tensor(&self, tensor: &Tensor) -> Result<Tensor> {
        let calibration = self.calibration.read().expect("RwLock poisoned");
        let tensor_data = tensor.data()?;

        let (scale, zero_point) = if let Some(&scale) = calibration.scales.get("global") {
            (
                scale,
                *calibration.zero_points.get("global").expect("No global zero point"),
            )
        } else {
            // No calibration available: estimate a scale from the span of the
            // stored 4-bit codes (15 steps across [-8, 7]).
            let min_q = tensor_data.iter().fold(f32::INFINITY, |a, &b| a.min(b)) as i8;
            let max_q = tensor_data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b)) as i8;
            let range = (max_q - min_q) as f32;
            let scale = if range > 0.0 { 15.0 / range } else { 1.0 };
            (scale, 0)
        };

        let dequantized_data: Vec<f32> = tensor_data
            .iter()
            .map(|&x| self.dequantize_value(x as i8, scale, zero_point))
            .collect();

        Tensor::from_vec(dequantized_data, &tensor.shape())
    }
}

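/// 8-bit quantizer supporting symmetric (zero_point = 0) and asymmetric
/// affine modes; symmetric is the default.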
pub struct Int8Quantizer {
    context: QuantizationContext,
    calibration: std::sync::RwLock<QuantizationCalibration>,
    symmetric: bool,
}

impl Default for Int8Quantizer {
    fn default() -> Self {
        Self::new()
    }
}

impl Int8Quantizer {
    pub fn new() -> Self {
        Self {
            context: QuantizationContext::default(),
            calibration: std::sync::RwLock::new(QuantizationCalibration::default()),
            symmetric: true,
        }
    }

    fn compute_scale_zero_point(&self, min_val: f32, max_val: f32) -> (f32, i32) {
        if self.symmetric {
            let abs_max = min_val.abs().max(max_val.abs());
            // Guard against all-zero tensors, which would otherwise yield a
            // zero scale.
            let scale = if abs_max > 0.0 { abs_max / 127.0 } else { 1.0 };
            (scale, 0)
        } else {
            let qmin = -128.0;
            let qmax = 127.0;
            let range = max_val - min_val;
            let scale = if range > 0.0 { range / (qmax - qmin) } else { 1.0 };
            let zero_point = ((qmin - min_val / scale).round() as i32).clamp(-128, 127);
            (scale, zero_point)
        }
    }
}

impl MobileQuantizer for Int8Quantizer {
    fn get_scheme(&self) -> QuantizationScheme {
        QuantizationScheme::Int8
    }

    fn requires_calibration(&self) -> bool {
        true
    }

    fn calibrate(&self, data: &[Tensor]) -> Result<()> {
        let mut calibration = self.calibration.write().expect("RwLock poisoned");

        for tensor in data {
            let tensor_data = tensor.data()?;

            let (min_val, max_val) = match self.context.method {
                CalibrationMethod::MinMax => {
                    let min = tensor_data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
                    let max = tensor_data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
                    (min, max)
                },
                CalibrationMethod::Percentile => {
                    // Clip the range to the configured percentile to reduce
                    // the influence of outliers.
                    let mut sorted = tensor_data.to_vec();
                    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
                    let percentile_idx =
                        (sorted.len() as f32 * self.context.percentile / 100.0) as usize;
                    let min_idx =
                        (sorted.len() as f32 * (100.0 - self.context.percentile) / 100.0) as usize;
                    (
                        sorted[min_idx],
                        sorted[percentile_idx.min(sorted.len() - 1)],
                    )
                },
                // MovingAverage and KLDivergence are not implemented yet and
                // fall back to min/max.
                _ => {
                    let min = tensor_data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
                    let max = tensor_data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
                    (min, max)
                },
            };

            let (scale, zero_point) = self.compute_scale_zero_point(min_val, max_val);

            // As in `Int4Quantizer::calibrate`, the single "global" entry
            // means the last tensor's statistics win.
            calibration.min_values.insert("global".to_string(), min_val);
            calibration.max_values.insert("global".to_string(), max_val);
            calibration.scales.insert("global".to_string(), scale);
            calibration.zero_points.insert("global".to_string(), zero_point);
        }

        Ok(())
    }

    fn quantize_tensor(&self, tensor: &Tensor) -> Result<Tensor> {
        let calibration = self.calibration.read().expect("RwLock poisoned");
        let tensor_data = tensor.data()?;

        let (scale, zero_point) = if let Some(&scale) = calibration.scales.get("global") {
            (
                scale,
                *calibration.zero_points.get("global").expect("No global zero point"),
            )
        } else {
            let min_val = tensor_data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
            let max_val = tensor_data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
            self.compute_scale_zero_point(min_val, max_val)
        };

        let quantized_data: Vec<i8> = tensor_data
            .iter()
            .map(|&x| {
                let q = (x / scale).round() as i32 + zero_point;
                q.clamp(-128, 127) as i8
            })
            .collect();

        let quantized_f32: Vec<f32> = quantized_data.iter().map(|&x| x as f32).collect();

        let quantized_tensor = Tensor::from_vec(quantized_f32, &tensor.shape())?;
        Ok(quantized_tensor)
    }

    fn dequantize_tensor(&self, tensor: &Tensor) -> Result<Tensor> {
        let calibration = self.calibration.read().expect("RwLock poisoned");
        let tensor_data = tensor.data()?;

        let (scale, zero_point) = if let Some(&scale) = calibration.scales.get("global") {
            (
                scale,
                *calibration.zero_points.get("global").expect("No global zero point"),
            )
        } else {
            // No calibration available: estimate a scale from the span of the
            // stored 8-bit codes (255 steps across [-128, 127]).
            let min_q = tensor_data.iter().fold(f32::INFINITY, |a, &b| a.min(b)) as i32;
            let max_q = tensor_data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b)) as i32;
            let range = (max_q - min_q) as f32;
            let scale = if range > 0.0 { 255.0 / range } else { 1.0 };
            (scale, 0)
        };

        let dequantized_data: Vec<f32> =
            tensor_data.iter().map(|&x| ((x as i32) - zero_point) as f32 * scale).collect();

        Tensor::from_vec(dequantized_data, &tensor.shape())
    }
}

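/// Half-precision "quantizer": values round-trip through `f16`, halving
/// nominal storage while the tensor itself stays in f32 form here.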
pub struct FP16Quantizer;

impl Default for FP16Quantizer {
    fn default() -> Self {
        Self::new()
    }
}

impl FP16Quantizer {
    pub fn new() -> Self {
        Self
    }
}

impl MobileQuantizer for FP16Quantizer {
    fn get_scheme(&self) -> QuantizationScheme {
        QuantizationScheme::FP16
    }

    fn requires_calibration(&self) -> bool {
        false
    }

    fn calibrate(&self, _data: &[Tensor]) -> Result<()> {
        Ok(())
    }

    fn quantize_tensor(&self, tensor: &Tensor) -> Result<Tensor> {
        let tensor_data = tensor.data()?;

        // Round-trip through f16 so the precision loss is applied even though
        // the result is stored back as f32.
        let fp16_data: Vec<f16> = tensor_data.iter().map(|&x| f16::from_f32(x)).collect();
        let quantized_data: Vec<f32> = fp16_data.iter().map(|&x| f32::from(x)).collect();

        let quantized_tensor = Tensor::from_vec(quantized_data, &tensor.shape())?;
        Ok(quantized_tensor)
    }

    fn dequantize_tensor(&self, tensor: &Tensor) -> Result<Tensor> {
        // Values are already f32 after the f16 round trip; nothing to undo.
        Ok(tensor.clone())
    }
}

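/// Chooses between INT8 and FP16 per tensor, honouring any persisted scheme
/// overrides first and falling back to a simple statistics-based heuristic.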
pub struct DynamicQuantizer {
    int8_quantizer: Int8Quantizer,
    fp16_quantizer: FP16Quantizer,
    selection_threshold: f32,
    scheme_storage: QuantizationSchemeStorage,
    layer_context: Option<String>,
    model_context: Option<String>,
}

impl Default for DynamicQuantizer {
    fn default() -> Self {
        Self::new()
    }
}

impl DynamicQuantizer {
    pub fn new() -> Self {
        Self {
            int8_quantizer: Int8Quantizer::new(),
            fp16_quantizer: FP16Quantizer::new(),
            selection_threshold: 0.1,
            scheme_storage: QuantizationSchemeStorage::new(),
            layer_context: None,
            model_context: None,
        }
    }

    pub fn with_config_file<P: AsRef<Path>>(path: P) -> Result<Self> {
        let scheme_storage = QuantizationSchemeStorage::with_config_file(path)?;
        Ok(Self {
            int8_quantizer: Int8Quantizer::new(),
            fp16_quantizer: FP16Quantizer::new(),
            selection_threshold: 0.1,
            scheme_storage,
            layer_context: None,
            model_context: None,
        })
    }

    pub fn set_layer_context(&mut self, layer_name: String) {
        self.layer_context = Some(layer_name);
    }

    pub fn set_model_context(&mut self, model_name: String) {
        self.model_context = Some(model_name);
    }

    pub fn scheme_storage_mut(&mut self) -> &mut QuantizationSchemeStorage {
        &mut self.scheme_storage
    }

    pub fn scheme_storage(&self) -> &QuantizationSchemeStorage {
        &self.scheme_storage
    }

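    /// Heuristic: narrow, low-variance tensors tolerate INT8 well; anything
    /// else keeps FP16 precision.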
    fn select_quantization_scheme(&self, tensor: &Tensor) -> Result<QuantizationScheme> {
        let tensor_data = tensor.data()?;

        let min_val = tensor_data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
        let max_val = tensor_data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
        let range = max_val - min_val;

        let mean = tensor_data.iter().sum::<f32>() / tensor_data.len() as f32;
        let variance =
            tensor_data.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / tensor_data.len() as f32;

        if range < 1.0 && variance < 0.01 {
            Ok(QuantizationScheme::Int8)
        } else {
            Ok(QuantizationScheme::FP16)
        }
    }
}

impl MobileQuantizer for DynamicQuantizer {
    fn get_scheme(&self) -> QuantizationScheme {
        QuantizationScheme::Dynamic
    }

    fn requires_calibration(&self) -> bool {
        true
    }

    fn calibrate(&self, data: &[Tensor]) -> Result<()> {
        // Only the INT8 path needs calibration; FP16 is calibration-free.
        self.int8_quantizer.calibrate(data)
    }

    fn quantize_tensor(&self, tensor: &Tensor) -> Result<Tensor> {
        let tensor_id =
            QuantizationSchemeStorage::generate_tensor_id(tensor, self.layer_context.as_deref());

        // `determine_scheme` needs `&mut` for its cache, but this method only
        // has `&self`, so the lookup runs on a clone and any cache update is
        // discarded.
        let mut storage = self.scheme_storage.clone();
        let scheme = storage.determine_scheme(
            &tensor_id,
            self.layer_context.as_deref(),
            self.model_context.as_deref(),
        );

        let final_scheme = if scheme == QuantizationScheme::Dynamic {
            self.select_quantization_scheme(tensor)?
        } else {
            scheme
        };

        match final_scheme {
            QuantizationScheme::Int4 => {
                let int4_quantizer = Int4Quantizer::new();
                int4_quantizer.quantize_tensor(tensor)
            },
            QuantizationScheme::Int8 => self.int8_quantizer.quantize_tensor(tensor),
            QuantizationScheme::FP16 => self.fp16_quantizer.quantize_tensor(tensor),
            // GGUF block quantization is not implemented here; fall back to
            // INT8 handling.
            QuantizationScheme::GGUF_Q2_K
            | QuantizationScheme::GGUF_Q3_K
            | QuantizationScheme::GGUF_Q4_K
            | QuantizationScheme::GGUF_Q5_0
            | QuantizationScheme::GGUF_Q6_K => self.int8_quantizer.quantize_tensor(tensor),
            // `Dynamic` was already resolved to INT8 or FP16 above, so this
            // arm is unreachable; keep INT8 as a defensive fallback.
            QuantizationScheme::Dynamic => self.int8_quantizer.quantize_tensor(tensor),
        }
    }

    fn dequantize_tensor(&self, tensor: &Tensor) -> Result<Tensor> {
        let tensor_id =
            QuantizationSchemeStorage::generate_tensor_id(tensor, self.layer_context.as_deref());

        // Same cloned-storage lookup as in `quantize_tensor`.
        let mut storage = self.scheme_storage.clone();
        let scheme = storage.determine_scheme(
            &tensor_id,
            self.layer_context.as_deref(),
            self.model_context.as_deref(),
        );

        match scheme {
            QuantizationScheme::Int8 => self.int8_quantizer.dequantize_tensor(tensor),
            QuantizationScheme::FP16 => self.fp16_quantizer.dequantize_tensor(tensor),
            QuantizationScheme::Int4 => {
                let int4_quantizer = Int4Quantizer::new();
                int4_quantizer.dequantize_tensor(tensor)
            },
            // GGUF block quantization is not implemented here; fall back to
            // INT8 handling.
            QuantizationScheme::GGUF_Q2_K
            | QuantizationScheme::GGUF_Q3_K
            | QuantizationScheme::GGUF_Q4_K
            | QuantizationScheme::GGUF_Q5_0
            | QuantizationScheme::GGUF_Q6_K => self.int8_quantizer.dequantize_tensor(tensor),
            QuantizationScheme::Dynamic => {
                let selected_scheme = self.select_quantization_scheme(tensor)?;
                match selected_scheme {
                    QuantizationScheme::Int8 => self.int8_quantizer.dequantize_tensor(tensor),
                    QuantizationScheme::FP16 => self.fp16_quantizer.dequantize_tensor(tensor),
                    _ => self.int8_quantizer.dequantize_tensor(tensor),
                }
            },
        }
    }
}

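/// Helpers for assessing quantization quality and footprint.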
pub struct QuantizationUtils;

impl QuantizationUtils {
    /// Root-mean-square error between an original tensor and its
    /// (de)quantized counterpart.
    pub fn compute_error(original: &Tensor, quantized: &Tensor) -> Result<f32> {
        let orig_data = original.data()?;
        let quant_data = quantized.data()?;

        if orig_data.len() != quant_data.len() {
            return Err(tensor_op_error(
                "compute_error",
                "Tensors must have same size for error computation",
            ));
        }

        let mse = orig_data
            .iter()
            .zip(quant_data.iter())
            .map(|(&o, &q)| (o - q).powi(2))
            .sum::<f32>()
            / orig_data.len() as f32;

        Ok(mse.sqrt())
    }

    /// Nominal compression ratio versus f32 storage (32 bits divided by the
    /// scheme's bits per weight; the GGUF figures include block scale/min
    /// overhead).
    pub fn compression_ratio(scheme: QuantizationScheme) -> f32 {
        match scheme {
            QuantizationScheme::Int4 => 8.0,
            QuantizationScheme::Int8 => 4.0,
            QuantizationScheme::FP16 => 2.0,
            QuantizationScheme::Dynamic => 3.0, // rough midpoint of INT8/FP16
            QuantizationScheme::GGUF_Q2_K => 32.0 / 2.5625,
            QuantizationScheme::GGUF_Q3_K => 32.0 / 3.4375,
            QuantizationScheme::GGUF_Q4_K => 32.0 / 4.5,
            QuantizationScheme::GGUF_Q5_0 => 32.0 / 5.5,
            QuantizationScheme::GGUF_Q6_K => 32.0 / 6.5,
        }
    }

    pub fn memory_savings_percent(scheme: QuantizationScheme) -> f32 {
        let ratio = Self::compression_ratio(scheme);
        (1.0 - 1.0 / ratio) * 100.0
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_int4_quantization() {
        let quantizer = Int4Quantizer::new();
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[2, 4])
            .expect("Failed to create tensor");

        quantizer.calibrate(std::slice::from_ref(&tensor)).expect("Calibration failed");

        let quantized = quantizer.quantize_tensor(&tensor).expect("Quantization failed");
        assert_eq!(quantized.shape(), tensor.shape());

        let dequantized = quantizer.dequantize_tensor(&quantized).expect("Dequantization failed");
        assert_eq!(dequantized.shape(), tensor.shape());

        let error = QuantizationUtils::compute_error(&tensor, &dequantized)
            .expect("Error computation failed");
        assert!(error < 1.0);
    }

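    // Added sanity check (a minimal sketch using only APIs already exercised
    // above): RMSE of a tensor against itself must be zero, and mismatched
    // sizes must return an error.
    #[test]
    fn test_compute_error_edge_cases() {
        let a = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("Failed to create tensor");
        let error = QuantizationUtils::compute_error(&a, &a).expect("Error computation failed");
        assert_eq!(error, 0.0);

        let b = Tensor::from_vec(vec![1.0, 2.0], &[2]).expect("Failed to create tensor");
        assert!(QuantizationUtils::compute_error(&a, &b).is_err());
    }
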
    #[test]
    fn test_int8_quantization() {
        let quantizer = Int8Quantizer::new();
        let tensor = Tensor::from_vec(vec![-10.0, -5.0, 0.0, 5.0, 10.0], &[5])
            .expect("Failed to create tensor");

        quantizer.calibrate(std::slice::from_ref(&tensor)).expect("Calibration failed");

        let quantized = quantizer.quantize_tensor(&tensor).expect("Quantization failed");
        let dequantized = quantizer.dequantize_tensor(&quantized).expect("Dequantization failed");

        let error = QuantizationUtils::compute_error(&tensor, &dequantized)
            .expect("Error computation failed");
        assert!(error < 0.1);
    }

    #[test]
    fn test_fp16_quantization() {
        let quantizer = FP16Quantizer::new();
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).expect("Operation failed");

        assert!(!quantizer.requires_calibration());

        let quantized = quantizer.quantize_tensor(&tensor).expect("Quantization failed");
        let dequantized = quantizer.dequantize_tensor(&quantized).expect("Dequantization failed");

        let error = QuantizationUtils::compute_error(&tensor, &dequantized)
            .expect("Error computation failed");
        assert!(error < 0.001);
    }

    #[test]
    fn test_dynamic_quantization() {
        let mut quantizer = DynamicQuantizer::new();

        let small_range =
            Tensor::from_vec(vec![0.1, 0.2, 0.3, 0.4], &[4]).expect("Operation failed");

        quantizer
            .calibrate(std::slice::from_ref(&small_range))
            .expect("Operation failed");
        let _quantized = quantizer.quantize_tensor(&small_range).expect("Operation failed");

        let tensor_id = QuantizationSchemeStorage::generate_tensor_id(&small_range, None);

        quantizer
            .scheme_storage_mut()
            .set_tensor_scheme(tensor_id.clone(), QuantizationScheme::FP16);

        let _quantized_fp16 = quantizer.quantize_tensor(&small_range).expect("Operation failed");

        let stored_scheme = quantizer.scheme_storage_mut().determine_scheme(&tensor_id, None, None);
        assert_eq!(stored_scheme, QuantizationScheme::FP16);

        let unknown_tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("Operation failed");
        let unknown_id = QuantizationSchemeStorage::generate_tensor_id(&unknown_tensor, None);
        let default_scheme =
            quantizer.scheme_storage_mut().determine_scheme(&unknown_id, None, None);
        assert_eq!(default_scheme, QuantizationScheme::Int8);
    }

    #[test]
    fn test_compression_ratios() {
        assert_eq!(
            QuantizationUtils::compression_ratio(QuantizationScheme::Int4),
            8.0
        );
        assert_eq!(
            QuantizationUtils::compression_ratio(QuantizationScheme::Int8),
            4.0
        );
        assert_eq!(
            QuantizationUtils::compression_ratio(QuantizationScheme::FP16),
            2.0
        );

        assert_eq!(
            QuantizationUtils::memory_savings_percent(QuantizationScheme::Int4),
            87.5
        );
        assert_eq!(
            QuantizationUtils::memory_savings_percent(QuantizationScheme::Int8),
            75.0
        );
        assert_eq!(
            QuantizationUtils::memory_savings_percent(QuantizationScheme::FP16),
            50.0
        );
    }
}