1#![allow(dead_code)]
10use crate::quantization::{QuantizationParams, QuantizationScheme, QuantizedDType};
11use crate::{BackendResult, Device};
12use std::collections::HashMap;
13use std::sync::Arc;
14use torsh_core::error::TorshError;
15
16#[cfg(not(feature = "std"))]
17use alloc::{boxed::Box, string::String, vec::Vec};
18
/// Accumulates calibration samples and derives quantization parameters from
/// them using a configurable [`CalibrationMethod`].
#[derive(Debug, Clone)]
pub struct QuantizationCalibrator {
    /// Collected calibration samples, one inner vector per added sample.
    samples: Vec<Vec<f32>>,
    /// Strategy `calibrate` dispatches on to choose the quantization range.
    method: CalibrationMethod,
    /// Device handed to `new`; stored, but not read by the methods visible in
    /// this file.
    device: Device,
    /// Previously computed parameters, looked up by a key that `calibrate`
    /// builds from the `Debug` output of the dtype and the method.
    parameter_cache: HashMap<String, QuantizationParams>,
}
35
36#[derive(Debug)]
41pub enum CalibrationMethod {
42 MinMax,
47
48 Percentile(f32),
53
54 Entropy,
59
60 MSE,
65
66 Adaptive,
71
72 Custom(Arc<dyn CalibrationFunction>),
77}
78
79impl Clone for CalibrationMethod {
80 fn clone(&self) -> Self {
81 match self {
82 CalibrationMethod::MinMax => CalibrationMethod::MinMax,
83 CalibrationMethod::Percentile(percentile) => CalibrationMethod::Percentile(*percentile),
84 CalibrationMethod::Entropy => CalibrationMethod::Entropy,
85 CalibrationMethod::MSE => CalibrationMethod::MSE,
86 CalibrationMethod::Adaptive => CalibrationMethod::Adaptive,
87 CalibrationMethod::Custom(func) => CalibrationMethod::Custom(Arc::clone(func)),
88 }
89 }
90}
91
/// User-supplied calibration strategy, plugged in through
/// `CalibrationMethod::Custom`.
///
/// Implementors must be `Send + Sync + Debug`, as required by the supertrait
/// bounds below.
pub trait CalibrationFunction: Send + Sync + std::fmt::Debug {
    /// Derives quantization parameters for `dtype` from `samples`.
    ///
    /// `QuantizationCalibrator::calibrate` passes all of its collected
    /// samples to this hook and returns the result unchanged.
    fn calibrate(
        &self,
        samples: &[Vec<f32>],
        dtype: QuantizedDType,
    ) -> BackendResult<QuantizationParams>;
}
101
102impl QuantizationCalibrator {
103 pub fn new(method: CalibrationMethod, device: Device) -> Self {
120 Self {
121 samples: Vec::new(),
122 method,
123 device,
124 parameter_cache: HashMap::new(),
125 }
126 }
127
128 pub fn add_sample(&mut self, data: Vec<f32>) {
138 self.samples.push(data);
139 }
140
141 pub fn add_samples(&mut self, samples: Vec<Vec<f32>>) {
143 self.samples.extend(samples);
144 }
145
146 pub fn clear_samples(&mut self) {
148 self.samples.clear();
149 self.parameter_cache.clear();
150 }
151
152 pub fn num_samples(&self) -> usize {
154 self.samples.len()
155 }
156
157 pub fn set_method(&mut self, method: CalibrationMethod) {
159 self.method = method;
160 self.parameter_cache.clear(); }
162
163 pub fn calibrate(&self, dtype: QuantizedDType) -> BackendResult<QuantizationParams> {
180 if self.samples.is_empty() {
181 return Err(TorshError::BackendError(
182 "No samples available for calibration".to_string(),
183 ));
184 }
185
186 let cache_key = format!("{:?}_{:?}", dtype, self.method);
188 if let Some(cached_params) = self.parameter_cache.get(&cache_key) {
189 return Ok(cached_params.clone());
190 }
191
192 let params = match &self.method {
194 CalibrationMethod::MinMax => self.calibrate_minmax(dtype),
195 CalibrationMethod::Percentile(percentile) => {
196 self.calibrate_percentile(dtype, *percentile)
197 }
198 CalibrationMethod::Entropy => self.calibrate_entropy(dtype),
199 CalibrationMethod::MSE => self.calibrate_mse(dtype),
200 CalibrationMethod::Adaptive => self.calibrate_adaptive(dtype),
201 CalibrationMethod::Custom(func) => func.calibrate(&self.samples, dtype),
202 };
203
204 params
205 }
206
207 fn calibrate_minmax(&self, dtype: QuantizedDType) -> BackendResult<QuantizationParams> {
209 let mut min_val = f32::INFINITY;
210 let mut max_val = f32::NEG_INFINITY;
211
212 for sample in &self.samples {
214 for &val in sample {
215 if val.is_finite() {
216 min_val = min_val.min(val);
217 max_val = max_val.max(val);
218 }
219 }
220 }
221
222 if min_val.is_infinite() || max_val.is_infinite() {
223 return Err(TorshError::BackendError(
224 "No finite values found in calibration data".to_string(),
225 ));
226 }
227
228 let mut params = QuantizationParams {
229 dtype,
230 scheme: QuantizationScheme::Asymmetric,
231 scale: vec![1.0],
232 zero_point: vec![0],
233 block_size: None,
234 min_val: Some(min_val),
235 max_val: Some(max_val),
236 };
237
238 params.from_statistics(min_val, max_val)?;
239 Ok(params)
240 }
241
242 fn calibrate_percentile(
244 &self,
245 dtype: QuantizedDType,
246 percentile: f32,
247 ) -> BackendResult<QuantizationParams> {
248 if !(0.0..=100.0).contains(&percentile) {
249 return Err(TorshError::BackendError(
250 "Percentile must be between 0 and 100".to_string(),
251 ));
252 }
253
254 let mut all_values = Vec::new();
256 for sample in &self.samples {
257 for &val in sample {
258 if val.is_finite() {
259 all_values.push(val);
260 }
261 }
262 }
263
264 if all_values.is_empty() {
265 return Err(TorshError::BackendError(
266 "No finite values found in calibration data".to_string(),
267 ));
268 }
269
270 all_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
271
272 let lower_percentile = (100.0 - percentile) / 2.0;
274 let upper_percentile = (100.0 + percentile) / 2.0;
275
276 let lower_idx = ((lower_percentile / 100.0) * (all_values.len() - 1) as f32) as usize;
277 let upper_idx = ((upper_percentile / 100.0) * (all_values.len() - 1) as f32) as usize;
278
279 let min_val = all_values[lower_idx];
280 let max_val = all_values[upper_idx];
281
282 let mut params = QuantizationParams {
283 dtype,
284 scheme: if min_val >= 0.0 {
285 QuantizationScheme::Asymmetric
286 } else {
287 QuantizationScheme::Symmetric
288 },
289 scale: vec![1.0],
290 zero_point: vec![0],
291 block_size: None,
292 min_val: Some(min_val),
293 max_val: Some(max_val),
294 };
295
296 params.from_statistics(min_val, max_val)?;
297 Ok(params)
298 }
299
300 fn calibrate_entropy(&self, dtype: QuantizedDType) -> BackendResult<QuantizationParams> {
302 let mut all_values = Vec::new();
304 for sample in &self.samples {
305 for &val in sample {
306 if val.is_finite() {
307 all_values.push(val);
308 }
309 }
310 }
311
312 if all_values.is_empty() {
313 return Err(TorshError::BackendError(
314 "No finite values found for entropy calibration".to_string(),
315 ));
316 }
317
318 all_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
320 let global_min = all_values[0];
321 let global_max = all_values[all_values.len() - 1];
322
323 let mut best_kl_div = f64::INFINITY;
325 let mut best_min = global_min;
326 let mut best_max = global_max;
327
328 for percentile in [90.0, 95.0, 97.0, 99.0, 99.5, 99.9, 100.0] {
330 let threshold_idx = ((percentile / 100.0) * (all_values.len() - 1) as f32) as usize;
331 let threshold_max = all_values[threshold_idx];
332 let threshold_min = -threshold_max; if let Ok(kl_div) =
336 self.compute_kl_divergence(&all_values, threshold_min, threshold_max, &dtype)
337 {
338 if kl_div < best_kl_div {
339 best_kl_div = kl_div;
340 best_min = threshold_min;
341 best_max = threshold_max;
342 }
343 }
344 }
345
346 let mut params = QuantizationParams {
347 dtype,
348 scheme: QuantizationScheme::Symmetric,
349 scale: vec![1.0],
350 zero_point: vec![0],
351 block_size: None,
352 min_val: Some(best_min),
353 max_val: Some(best_max),
354 };
355
356 params.from_statistics(best_min, best_max)?;
357 Ok(params)
358 }
359
360 fn calibrate_mse(&self, dtype: QuantizedDType) -> BackendResult<QuantizationParams> {
362 let mut all_values = Vec::new();
364 for sample in &self.samples {
365 for &val in sample {
366 if val.is_finite() {
367 all_values.push(val);
368 }
369 }
370 }
371
372 if all_values.is_empty() {
373 return Err(TorshError::BackendError(
374 "No finite values found for MSE calibration".to_string(),
375 ));
376 }
377
378 all_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
379 let global_min = all_values[0];
380 let global_max = all_values[all_values.len() - 1];
381
382 let mut best_mse = f64::INFINITY;
383 let mut best_min = global_min;
384 let mut best_max = global_max;
385
386 for percentile in [95.0, 97.0, 99.0, 99.5, 99.9, 100.0] {
388 let threshold_idx = ((percentile / 100.0) * (all_values.len() - 1) as f32) as usize;
389 let threshold_max = all_values[threshold_idx];
390 let threshold_min = if global_min >= 0.0 {
391 0.0
392 } else {
393 -threshold_max
394 };
395
396 if let Ok(mse) = self.compute_mse(&all_values, threshold_min, threshold_max, &dtype) {
398 if mse < best_mse {
399 best_mse = mse;
400 best_min = threshold_min;
401 best_max = threshold_max;
402 }
403 }
404 }
405
406 let mut params = QuantizationParams {
407 dtype,
408 scheme: if best_min >= 0.0 {
409 QuantizationScheme::Asymmetric
410 } else {
411 QuantizationScheme::Symmetric
412 },
413 scale: vec![1.0],
414 zero_point: vec![0],
415 block_size: None,
416 min_val: Some(best_min),
417 max_val: Some(best_max),
418 };
419
420 params.from_statistics(best_min, best_max)?;
421 Ok(params)
422 }
423
424 fn calibrate_adaptive(&self, dtype: QuantizedDType) -> BackendResult<QuantizationParams> {
426 let methods = vec![
428 CalibrationMethod::MinMax,
429 CalibrationMethod::Percentile(99.0),
430 CalibrationMethod::Percentile(95.0),
431 CalibrationMethod::MSE,
432 ];
433
434 let mut best_score = f64::INFINITY;
435 let mut best_params = None;
436
437 for method in methods {
438 let mut temp_calibrator = self.clone();
440 temp_calibrator.set_method(method);
441
442 if let Ok(params) = temp_calibrator.calibrate(dtype.clone()) {
443 if let Ok(score) = self.evaluate_quantization_quality(¶ms) {
445 if score < best_score {
446 best_score = score;
447 best_params = Some(params);
448 }
449 }
450 }
451 }
452
453 best_params.ok_or_else(|| {
454 TorshError::BackendError(
455 "No suitable quantization parameters found in adaptive mode".to_string(),
456 )
457 })
458 }
459
460 fn compute_kl_divergence(
462 &self,
463 values: &[f32],
464 min_val: f32,
465 max_val: f32,
466 dtype: &QuantizedDType,
467 ) -> BackendResult<f64> {
468 const NUM_BINS: usize = 256;
469
470 let mut original_hist = vec![0usize; NUM_BINS];
472 let range = max_val - min_val;
473
474 if range <= 0.0 {
475 return Ok(f64::INFINITY);
476 }
477
478 for &val in values {
479 let clipped_val = val.clamp(min_val, max_val);
480 let bin = ((clipped_val - min_val) / range * (NUM_BINS - 1) as f32) as usize;
481 let bin = bin.min(NUM_BINS - 1);
482 original_hist[bin] += 1;
483 }
484
485 let mut quantized_hist = vec![0usize; NUM_BINS];
487 let (qmin, qmax) = dtype.value_range();
488 let scale = range / (qmax - qmin) as f32;
489
490 for &val in values {
491 let clipped_val = val.clamp(min_val, max_val);
492 let quantized = ((clipped_val - min_val) / scale)
494 .round()
495 .clamp(qmin as f32, qmax as f32);
496 let dequantized = quantized * scale + min_val;
497
498 let bin = ((dequantized - min_val) / range * (NUM_BINS - 1) as f32) as usize;
499 let bin = bin.min(NUM_BINS - 1);
500 quantized_hist[bin] += 1;
501 }
502
503 let total_samples = values.len() as f64;
505 let mut kl_div = 0.0;
506
507 for i in 0..NUM_BINS {
508 let p = (original_hist[i] as f64 + 1e-10) / total_samples; let q = (quantized_hist[i] as f64 + 1e-10) / total_samples;
510
511 if p > 0.0 && q > 0.0 {
512 kl_div += p * (p / q).ln();
513 }
514 }
515
516 Ok(kl_div)
517 }
518
519 fn compute_mse(
521 &self,
522 values: &[f32],
523 min_val: f32,
524 max_val: f32,
525 dtype: &QuantizedDType,
526 ) -> BackendResult<f64> {
527 let (qmin, qmax) = dtype.value_range();
528 let range = max_val - min_val;
529
530 if range <= 0.0 {
531 return Ok(f64::INFINITY);
532 }
533
534 let scale = range / (qmax - qmin) as f32;
535 let mut total_error = 0.0;
536
537 for &val in values {
538 let clipped_val = val.clamp(min_val, max_val);
539 let quantized = ((clipped_val - min_val) / scale)
541 .round()
542 .clamp(qmin as f32, qmax as f32);
543 let dequantized = quantized * scale + min_val;
544
545 let error = (val - dequantized).powi(2);
546 total_error += error as f64;
547 }
548
549 Ok(total_error / values.len() as f64)
550 }
551
552 fn evaluate_quantization_quality(&self, params: &QuantizationParams) -> BackendResult<f64> {
554 let eval_samples = if self.samples.len() > 1000 {
556 &self.samples[..1000]
557 } else {
558 &self.samples
559 };
560
561 let mut total_error = 0.0;
562 let mut total_count = 0;
563
564 for sample in eval_samples {
565 for &val in sample {
566 if !val.is_finite() {
567 continue;
568 }
569
570 let scale = params.scale[0];
572 let zero_point = params.zero_point[0] as f32;
573 let (qmin, qmax) = params.dtype.value_range();
574
575 let quantized = ((val / scale + zero_point)
576 .round()
577 .clamp(qmin as f32, qmax as f32)) as i32;
578 let dequantized = (quantized - params.zero_point[0]) as f32 * scale;
579
580 let error = (val - dequantized).powi(2);
581 total_error += error as f64;
582 total_count += 1;
583 }
584 }
585
586 if total_count == 0 {
587 Ok(f64::INFINITY)
588 } else {
589 Ok(total_error / total_count as f64)
590 }
591 }
592}
593
/// Stand-alone percentile calibrator that works on caller-provided samples.
#[derive(Debug, Clone)]
pub struct PercentileCalibrator {
    /// Percentile controlling how much of the sorted data is kept; `new`
    /// rejects values outside `0..=100`.
    pub percentile: f32,
    /// When `true`, the range is mirrored around zero (`[-m, m]`); otherwise
    /// lower and upper cut points are picked independently.
    pub symmetric: bool,
    /// Device handed to `new`; stored, but not read by the methods visible in
    /// this file.
    device: Device,
}
607
608impl PercentileCalibrator {
609 pub fn new(percentile: f32, symmetric: bool, device: Device) -> BackendResult<Self> {
617 if !(0.0..=100.0).contains(&percentile) {
618 return Err(TorshError::BackendError(
619 "Percentile must be between 0 and 100".to_string(),
620 ));
621 }
622
623 Ok(Self {
624 percentile,
625 symmetric,
626 device,
627 })
628 }
629
630 pub fn calibrate_percentile(
632 &self,
633 samples: &[Vec<f32>],
634 dtype: QuantizedDType,
635 ) -> BackendResult<QuantizationParams> {
636 let mut all_values = Vec::new();
638 for sample in samples {
639 for &val in sample {
640 if val.is_finite() {
641 all_values.push(val);
642 }
643 }
644 }
645
646 if all_values.is_empty() {
647 return Err(TorshError::BackendError(
648 "No finite values found in calibration data".to_string(),
649 ));
650 }
651
652 all_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
653
654 let (min_val, max_val) = if self.symmetric {
655 let threshold_idx =
657 ((self.percentile / 100.0) * (all_values.len() - 1) as f32) as usize;
658 let max_abs = all_values[threshold_idx]
659 .abs()
660 .max(all_values[all_values.len() - 1 - threshold_idx].abs());
661 (-max_abs, max_abs)
662 } else {
663 let lower_percentile = (100.0 - self.percentile) / 2.0;
665 let upper_percentile = (100.0 + self.percentile) / 2.0;
666
667 let lower_idx = ((lower_percentile / 100.0) * (all_values.len() - 1) as f32) as usize;
668 let upper_idx = ((upper_percentile / 100.0) * (all_values.len() - 1) as f32) as usize;
669
670 (all_values[lower_idx], all_values[upper_idx])
671 };
672
673 let mut params = QuantizationParams {
674 dtype,
675 scheme: if self.symmetric {
676 QuantizationScheme::Symmetric
677 } else {
678 QuantizationScheme::Asymmetric
679 },
680 scale: vec![1.0],
681 zero_point: vec![0],
682 block_size: None,
683 min_val: Some(min_val),
684 max_val: Some(max_val),
685 };
686
687 params.from_statistics(min_val, max_val)?;
688 Ok(params)
689 }
690
691 pub fn calibrate_entropy_validated(
696 &self,
697 samples: &[Vec<f32>],
698 dtype: QuantizedDType,
699 max_entropy_loss: f64,
700 ) -> BackendResult<QuantizationParams> {
701 let mut best_params = None;
704 let mut _best_percentile = 0.0;
705
706 for test_percentile in [50.0, 70.0, 80.0, 90.0, 95.0, 97.0, 99.0, 99.5] {
707 if test_percentile > self.percentile {
708 break;
709 }
710
711 let mut temp_calibrator = self.clone();
712 temp_calibrator.percentile = test_percentile;
713
714 if let Ok(params) = temp_calibrator.calibrate_percentile(samples, dtype.clone()) {
715 let entropy_loss = self.estimate_entropy_loss(samples, ¶ms)?;
717
718 if entropy_loss <= max_entropy_loss {
719 best_params = Some(params);
720 _best_percentile = test_percentile;
721 }
722 }
723 }
724
725 best_params.ok_or_else(|| {
726 TorshError::BackendError(format!(
727 "No percentile found that meets entropy loss requirement of {}",
728 max_entropy_loss
729 ))
730 })
731 }
732
733 fn estimate_entropy_loss(
735 &self,
736 samples: &[Vec<f32>],
737 params: &QuantizationParams,
738 ) -> BackendResult<f64> {
739 let min_val = params.min_val.expect("min_val should be set in params");
742 let max_val = params.max_val.expect("max_val should be set in params");
743
744 let mut clipped_count = 0;
745 let mut total_count = 0;
746
747 for sample in samples {
748 for &val in sample {
749 if val.is_finite() {
750 total_count += 1;
751 if val < min_val || val > max_val {
752 clipped_count += 1;
753 }
754 }
755 }
756 }
757
758 if total_count == 0 {
759 return Ok(0.0);
760 }
761
762 Ok(clipped_count as f64 / total_count as f64)
764 }
765}
766
/// Summary statistics over the finite values of a set of calibration samples,
/// together with calibration methods suggested for that distribution.
#[derive(Debug, Clone)]
pub struct CalibrationStatistics {
    /// Number of sample vectors that were passed in.
    pub num_samples: usize,
    /// Number of finite values across all samples.
    pub num_values: usize,
    /// Smallest finite value.
    pub min_value: f32,
    /// Largest finite value.
    pub max_value: f32,
    /// Arithmetic mean of the finite values.
    pub mean_value: f32,
    /// Population standard deviation (divides by the value count).
    pub std_dev: f32,
    /// Percentage (0–100) of finite values further than three standard
    /// deviations away from the mean.
    pub outlier_percentage: f32,
    /// Calibration methods suggested for this distribution, in the order
    /// `from_samples` produced them.
    pub recommended_methods: Vec<CalibrationMethod>,
}
787
788impl CalibrationStatistics {
789 pub fn from_samples(samples: &[Vec<f32>]) -> BackendResult<Self> {
791 let mut all_values = Vec::new();
792 for sample in samples {
793 for &val in sample {
794 if val.is_finite() {
795 all_values.push(val);
796 }
797 }
798 }
799
800 if all_values.is_empty() {
801 return Err(TorshError::BackendError(
802 "No finite values found in samples".to_string(),
803 ));
804 }
805
806 all_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
807
808 let num_values = all_values.len();
809 let min_value = all_values[0];
810 let max_value = all_values[num_values - 1];
811
812 let sum: f64 = all_values.iter().map(|&x| x as f64).sum();
814 let mean_value = (sum / num_values as f64) as f32;
815
816 let variance: f64 = all_values
818 .iter()
819 .map(|&x| (x as f64 - mean_value as f64).powi(2))
820 .sum::<f64>()
821 / num_values as f64;
822 let std_dev = variance.sqrt() as f32;
823
824 let outlier_threshold = 3.0 * std_dev;
826 let outlier_count = all_values
827 .iter()
828 .filter(|&&x| (x - mean_value).abs() > outlier_threshold)
829 .count();
830 let outlier_percentage = (outlier_count as f32 / num_values as f32) * 100.0;
831
832 let recommended_methods =
834 Self::recommend_methods(outlier_percentage, std_dev, min_value, max_value);
835
836 Ok(Self {
837 num_samples: samples.len(),
838 num_values,
839 min_value,
840 max_value,
841 mean_value,
842 std_dev,
843 outlier_percentage,
844 recommended_methods,
845 })
846 }
847
848 fn recommend_methods(
850 outlier_percentage: f32,
851 std_dev: f32,
852 min_value: f32,
853 max_value: f32,
854 ) -> Vec<CalibrationMethod> {
855 let mut recommendations = Vec::new();
856
857 if outlier_percentage > 5.0 {
859 recommendations.push(CalibrationMethod::Percentile(99.0));
860 recommendations.push(CalibrationMethod::Percentile(95.0));
861 }
862
863 if std_dev > (max_value - min_value) * 0.2 {
865 recommendations.push(CalibrationMethod::Entropy);
866 recommendations.push(CalibrationMethod::MSE);
867 }
868
869 recommendations.push(CalibrationMethod::Adaptive);
871
872 if outlier_percentage < 1.0 && std_dev < (max_value - min_value) * 0.1 {
874 recommendations.push(CalibrationMethod::MinMax);
875 }
876
877 recommendations
878 }
879}
880
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture: two positive ramps plus one sample straddling zero.
    fn create_test_samples() -> Vec<Vec<f32>> {
        let ascending = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let doubled = vec![2.0, 4.0, 6.0, 8.0, 10.0];
        let mixed_sign = vec![-1.0, -2.0, 0.0, 1.0, 2.0];
        vec![ascending, doubled, mixed_sign]
    }

    /// CPU device used by every test.
    fn cpu_device() -> Device {
        Device::cpu().unwrap()
    }

    /// Builds a calibrator for `method` that already holds `samples`.
    fn loaded_calibrator(
        method: CalibrationMethod,
        samples: Vec<Vec<f32>>,
    ) -> QuantizationCalibrator {
        let mut calibrator = QuantizationCalibrator::new(method, cpu_device());
        calibrator.add_samples(samples);
        calibrator
    }

    #[test]
    fn test_calibrator_creation() {
        let calibrator = QuantizationCalibrator::new(CalibrationMethod::MinMax, cpu_device());

        assert_eq!(calibrator.num_samples(), 0);
        assert!(matches!(calibrator.method, CalibrationMethod::MinMax));
    }

    #[test]
    fn test_sample_management() {
        let mut calibrator = QuantizationCalibrator::new(CalibrationMethod::MinMax, cpu_device());

        calibrator.add_sample(vec![1.0, 2.0, 3.0]);
        assert_eq!(calibrator.num_samples(), 1);

        calibrator.add_samples(vec![vec![4.0, 5.0], vec![6.0, 7.0]]);
        assert_eq!(calibrator.num_samples(), 3);

        calibrator.clear_samples();
        assert_eq!(calibrator.num_samples(), 0);
    }

    #[test]
    fn test_minmax_calibration() {
        let calibrator = loaded_calibrator(CalibrationMethod::MinMax, create_test_samples());

        let result = calibrator.calibrate(QuantizedDType::Int8);
        assert!(result.is_ok());

        let params = result.unwrap();
        assert_eq!(params.dtype, QuantizedDType::Int8);
        assert!(params.scale[0] > 0.0);
        assert!(params.min_val.is_some());
        assert!(params.max_val.is_some());
    }

    #[test]
    fn test_percentile_calibration() {
        let calibrator =
            loaded_calibrator(CalibrationMethod::Percentile(95.0), create_test_samples());

        let result = calibrator.calibrate(QuantizedDType::UInt8);
        assert!(result.is_ok());

        let params = result.unwrap();
        assert_eq!(params.dtype, QuantizedDType::UInt8);
    }

    #[test]
    fn test_mse_calibration() {
        let calibrator = loaded_calibrator(CalibrationMethod::MSE, create_test_samples());

        let result = calibrator.calibrate(QuantizedDType::Int8);
        assert!(result.is_ok());
    }

    #[test]
    fn test_adaptive_calibration() {
        let calibrator = loaded_calibrator(CalibrationMethod::Adaptive, create_test_samples());

        let result = calibrator.calibrate(QuantizedDType::Int8);
        assert!(result.is_ok());
    }

    #[test]
    fn test_calibration_with_outliers() {
        // Two extreme values (+1000 and -1000) hidden among small ones.
        let samples = vec![
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 1000.0],
            vec![2.0, 4.0, 6.0, 8.0, 10.0],
            vec![-1.0, -2.0, 0.0, 1.0, 2.0, -1000.0],
            vec![1.5, 2.5, 3.5, 4.5, 5.5],
            vec![0.5, 1.0, 1.5, 2.0, 2.5],
            vec![3.0, 3.5, 4.0, 4.5, 5.0],
            vec![-0.5, -1.0, 0.5, 1.0, 1.5],
        ];
        let calibrator = loaded_calibrator(CalibrationMethod::Percentile(90.0), samples);

        let result = calibrator.calibrate(QuantizedDType::Int8);
        assert!(result.is_ok());

        // The percentile window must exclude both extremes.
        let params = result.unwrap();
        assert!(params.min_val.unwrap() > -100.0);
        assert!(params.max_val.unwrap() < 100.0);
    }

    #[test]
    fn test_percentile_calibrator() {
        let calibrator = PercentileCalibrator::new(95.0, false, cpu_device());
        assert!(calibrator.is_ok());
        let calibrator = calibrator.unwrap();

        let samples = create_test_samples();
        let result = calibrator.calibrate_percentile(&samples, QuantizedDType::Int8);
        assert!(result.is_ok());

        let params = result.unwrap();
        assert_eq!(params.dtype, QuantizedDType::Int8);
        assert_eq!(params.scheme, QuantizationScheme::Asymmetric);
    }

    #[test]
    fn test_symmetric_percentile_calibrator() {
        let calibrator = PercentileCalibrator::new(95.0, true, cpu_device()).unwrap();

        let samples = create_test_samples();
        let result = calibrator.calibrate_percentile(&samples, QuantizedDType::Int8);
        assert!(result.is_ok());

        let params = result.unwrap();
        assert_eq!(params.scheme, QuantizationScheme::Symmetric);
    }

    #[test]
    fn test_calibration_statistics() {
        let stats = CalibrationStatistics::from_samples(&create_test_samples());
        assert!(stats.is_ok());

        let stats = stats.unwrap();
        assert_eq!(stats.num_samples, 3);
        assert_eq!(stats.num_values, 15);
        assert!(stats.min_value <= stats.max_value);
        assert!(stats.std_dev >= 0.0);
        assert!(!stats.recommended_methods.is_empty());
    }

    #[test]
    fn test_invalid_percentile() {
        let device = cpu_device();

        let above_range = PercentileCalibrator::new(101.0, false, device.clone());
        assert!(above_range.is_err());

        let below_range = PercentileCalibrator::new(-1.0, false, device);
        assert!(below_range.is_err());
    }

    #[test]
    fn test_empty_samples_error() {
        let calibrator = QuantizationCalibrator::new(CalibrationMethod::MinMax, cpu_device());

        let result = calibrator.calibrate(QuantizedDType::Int8);
        assert!(result.is_err());
    }

    #[test]
    fn test_method_switching() {
        let mut calibrator = loaded_calibrator(CalibrationMethod::MinMax, create_test_samples());

        calibrator.set_method(CalibrationMethod::Percentile(95.0));
        let result1 = calibrator.calibrate(QuantizedDType::Int8);
        assert!(result1.is_ok());

        calibrator.set_method(CalibrationMethod::MSE);
        let result2 = calibrator.calibrate(QuantizedDType::Int8);
        assert!(result2.is_ok());

        // Both methods must yield a usable (positive) scale.
        let params1 = result1.unwrap();
        let params2 = result2.unwrap();
        assert!(params1.scale[0] > 0.0);
        assert!(params2.scale[0] > 0.0);
    }

    #[test]
    fn test_calibration_with_infinite_values() {
        // Non-finite entries must be ignored instead of poisoning the range.
        let samples = vec![
            vec![1.0, 2.0, f32::INFINITY, 4.0, 5.0],
            vec![2.0, f32::NEG_INFINITY, 6.0, 8.0, 10.0],
            vec![-1.0, -2.0, 0.0, 1.0, f32::NAN],
        ];
        let calibrator = loaded_calibrator(CalibrationMethod::MinMax, samples);

        let result = calibrator.calibrate(QuantizedDType::Int8);
        assert!(result.is_ok());

        let params = result.unwrap();
        assert!(params.min_val.unwrap().is_finite());
        assert!(params.max_val.unwrap().is_finite());
    }
}