Skip to main content

torsh_backend/quantization/
calibration.rs

1//! Quantization calibration methods and utilities
2//!
3//! This module provides sophisticated calibration techniques for determining
4//! optimal quantization parameters from sample data. It supports various
5//! calibration methods including statistical approaches, entropy-based methods,
6//! and error minimization techniques.
7
8// Framework infrastructure - components designed for future use
9#![allow(dead_code)]
10use crate::quantization::{QuantizationParams, QuantizationScheme, QuantizedDType};
11use crate::{BackendResult, Device};
12use std::collections::HashMap;
13use std::sync::Arc;
14use torsh_core::error::TorshError;
15
16#[cfg(not(feature = "std"))]
17use alloc::{boxed::Box, string::String, vec::Vec};
18
/// Quantization calibration utility
///
/// The calibrator analyzes sample data to determine optimal quantization
/// parameters that balance accuracy and performance. It supports multiple
/// calibration methods to suit different use cases and accuracy requirements.
///
/// Typical use: construct with [`QuantizationCalibrator::new`], feed data via
/// `add_sample` / `add_samples`, then call `calibrate` with the target dtype.
#[derive(Debug, Clone)]
pub struct QuantizationCalibrator {
    /// Sample data for calibration
    ///
    /// Each inner vector is one sample. The built-in calibration methods pool
    /// the values of all samples and skip NaN/infinite entries; a `Custom`
    /// calibration function receives this list unfiltered.
    samples: Vec<Vec<f32>>,
    /// Calibration method to use
    method: CalibrationMethod,
    /// Device for calibration computations
    ///
    /// NOTE(review): stored by `new`, but no calibration routine in this file
    /// reads it — the math appears to run on the host. Confirm intent.
    device: Device,
    /// Cache for previously computed parameters
    ///
    /// Looked up in `calibrate` under a key built from the `Debug` output of
    /// the dtype and the method; cleared by `clear_samples` and `set_method`.
    /// NOTE(review): no code visible here inserts into it, so lookups
    /// currently always miss.
    parameter_cache: HashMap<String, QuantizationParams>,
}
35
36/// Calibration methods for quantization parameter optimization
37///
38/// Different calibration methods offer trade-offs between computational cost,
39/// robustness to outliers, and final quantization accuracy.
40#[derive(Debug)]
41pub enum CalibrationMethod {
42    /// Simple min-max calibration
43    ///
44    /// Uses the minimum and maximum values in the data to set quantization
45    /// range. Fast but sensitive to outliers.
46    MinMax,
47
48    /// Percentile-based calibration
49    ///
50    /// Uses a specified percentile to clip outliers before determining range.
51    /// More robust than min-max but requires tuning the percentile parameter.
52    Percentile(f32),
53
54    /// Entropy-based calibration (KL divergence minimization)
55    ///
56    /// Minimizes the KL divergence between original and quantized distributions.
57    /// Provides good accuracy but is computationally expensive.
58    Entropy,
59
60    /// Mean squared error minimization
61    ///
62    /// Finds parameters that minimize MSE between original and quantized values.
63    /// Good balance between accuracy and computational cost.
64    MSE,
65
66    /// Adaptive method selection
67    ///
68    /// Automatically selects the best method based on data characteristics.
69    /// Uses multiple methods and picks the one with best validation score.
70    Adaptive,
71
72    /// Custom calibration with user-defined function
73    ///
74    /// Allows users to provide their own calibration logic for specialized
75    /// use cases or domain-specific optimization.
76    Custom(Arc<dyn CalibrationFunction>),
77}
78
79impl Clone for CalibrationMethod {
80    fn clone(&self) -> Self {
81        match self {
82            CalibrationMethod::MinMax => CalibrationMethod::MinMax,
83            CalibrationMethod::Percentile(percentile) => CalibrationMethod::Percentile(*percentile),
84            CalibrationMethod::Entropy => CalibrationMethod::Entropy,
85            CalibrationMethod::MSE => CalibrationMethod::MSE,
86            CalibrationMethod::Adaptive => CalibrationMethod::Adaptive,
87            CalibrationMethod::Custom(func) => CalibrationMethod::Custom(Arc::clone(func)),
88        }
89    }
90}
91
/// Trait for custom calibration functions
///
/// Implementations are stored as `Arc<dyn CalibrationFunction>` in
/// `CalibrationMethod::Custom`. The `Debug` supertrait is what allows
/// `CalibrationMethod` to derive `Debug`.
pub trait CalibrationFunction: Send + Sync + std::fmt::Debug {
    /// Compute quantization parameters from sample data
    ///
    /// `QuantizationCalibrator::calibrate` passes its raw sample list (at
    /// least one sample). Unlike the built-in methods, non-finite values are
    /// NOT filtered out beforehand; implementations must handle them.
    fn calibrate(
        &self,
        samples: &[Vec<f32>],
        dtype: QuantizedDType,
    ) -> BackendResult<QuantizationParams>;
}
101
102impl QuantizationCalibrator {
103    /// Create a new calibrator with the specified method
104    ///
105    /// # Arguments
106    ///
107    /// * `method` - The calibration method to use
108    /// * `device` - Device for performing calibration computations
109    ///
110    /// # Examples
111    ///
112    /// ```ignore
113    /// use torsh_backend::quantization::calibration::{QuantizationCalibrator, CalibrationMethod};
114    /// use torsh_core::DeviceType;
115    ///
116    /// let device = DeviceType::Cpu;
117    /// let calibrator = QuantizationCalibrator::new(CalibrationMethod::MinMax, device);
118    /// ```
119    pub fn new(method: CalibrationMethod, device: Device) -> Self {
120        Self {
121            samples: Vec::new(),
122            method,
123            device,
124            parameter_cache: HashMap::new(),
125        }
126    }
127
128    /// Add calibration sample
129    ///
130    /// Adds a sample of data that will be used to determine optimal
131    /// quantization parameters. More samples generally lead to better
132    /// parameter estimation.
133    ///
134    /// # Arguments
135    ///
136    /// * `data` - Sample data vector
137    pub fn add_sample(&mut self, data: Vec<f32>) {
138        self.samples.push(data);
139    }
140
141    /// Add multiple calibration samples at once
142    pub fn add_samples(&mut self, samples: Vec<Vec<f32>>) {
143        self.samples.extend(samples);
144    }
145
146    /// Clear all calibration samples
147    pub fn clear_samples(&mut self) {
148        self.samples.clear();
149        self.parameter_cache.clear();
150    }
151
152    /// Get the number of calibration samples
153    pub fn num_samples(&self) -> usize {
154        self.samples.len()
155    }
156
157    /// Set the calibration method
158    pub fn set_method(&mut self, method: CalibrationMethod) {
159        self.method = method;
160        self.parameter_cache.clear(); // Clear cache when method changes
161    }
162
163    /// Compute optimal quantization parameters
164    ///
165    /// Analyzes all collected samples to determine the best quantization
166    /// parameters for the specified data type using the configured method.
167    ///
168    /// # Arguments
169    ///
170    /// * `dtype` - Target quantization data type
171    ///
172    /// # Returns
173    ///
174    /// Optimized quantization parameters
175    ///
176    /// # Errors
177    ///
178    /// Returns an error if no samples have been added or if calibration fails
179    pub fn calibrate(&self, dtype: QuantizedDType) -> BackendResult<QuantizationParams> {
180        if self.samples.is_empty() {
181            return Err(TorshError::BackendError(
182                "No samples available for calibration".to_string(),
183            ));
184        }
185
186        // Check cache first
187        let cache_key = format!("{:?}_{:?}", dtype, self.method);
188        if let Some(cached_params) = self.parameter_cache.get(&cache_key) {
189            return Ok(cached_params.clone());
190        }
191
192        // Perform calibration based on method
193        let params = match &self.method {
194            CalibrationMethod::MinMax => self.calibrate_minmax(dtype),
195            CalibrationMethod::Percentile(percentile) => {
196                self.calibrate_percentile(dtype, *percentile)
197            }
198            CalibrationMethod::Entropy => self.calibrate_entropy(dtype),
199            CalibrationMethod::MSE => self.calibrate_mse(dtype),
200            CalibrationMethod::Adaptive => self.calibrate_adaptive(dtype),
201            CalibrationMethod::Custom(func) => func.calibrate(&self.samples, dtype),
202        };
203
204        params
205    }
206
207    /// Min-max calibration implementation
208    fn calibrate_minmax(&self, dtype: QuantizedDType) -> BackendResult<QuantizationParams> {
209        let mut min_val = f32::INFINITY;
210        let mut max_val = f32::NEG_INFINITY;
211
212        // Find global min and max across all samples
213        for sample in &self.samples {
214            for &val in sample {
215                if val.is_finite() {
216                    min_val = min_val.min(val);
217                    max_val = max_val.max(val);
218                }
219            }
220        }
221
222        if min_val.is_infinite() || max_val.is_infinite() {
223            return Err(TorshError::BackendError(
224                "No finite values found in calibration data".to_string(),
225            ));
226        }
227
228        let mut params = QuantizationParams {
229            dtype,
230            scheme: QuantizationScheme::Asymmetric,
231            scale: vec![1.0],
232            zero_point: vec![0],
233            block_size: None,
234            min_val: Some(min_val),
235            max_val: Some(max_val),
236        };
237
238        params.from_statistics(min_val, max_val)?;
239        Ok(params)
240    }
241
242    /// Percentile-based calibration implementation
243    fn calibrate_percentile(
244        &self,
245        dtype: QuantizedDType,
246        percentile: f32,
247    ) -> BackendResult<QuantizationParams> {
248        if !(0.0..=100.0).contains(&percentile) {
249            return Err(TorshError::BackendError(
250                "Percentile must be between 0 and 100".to_string(),
251            ));
252        }
253
254        // Collect all values
255        let mut all_values = Vec::new();
256        for sample in &self.samples {
257            for &val in sample {
258                if val.is_finite() {
259                    all_values.push(val);
260                }
261            }
262        }
263
264        if all_values.is_empty() {
265            return Err(TorshError::BackendError(
266                "No finite values found in calibration data".to_string(),
267            ));
268        }
269
270        all_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
271
272        // Calculate percentile bounds
273        let lower_percentile = (100.0 - percentile) / 2.0;
274        let upper_percentile = (100.0 + percentile) / 2.0;
275
276        let lower_idx = ((lower_percentile / 100.0) * (all_values.len() - 1) as f32) as usize;
277        let upper_idx = ((upper_percentile / 100.0) * (all_values.len() - 1) as f32) as usize;
278
279        let min_val = all_values[lower_idx];
280        let max_val = all_values[upper_idx];
281
282        let mut params = QuantizationParams {
283            dtype,
284            scheme: if min_val >= 0.0 {
285                QuantizationScheme::Asymmetric
286            } else {
287                QuantizationScheme::Symmetric
288            },
289            scale: vec![1.0],
290            zero_point: vec![0],
291            block_size: None,
292            min_val: Some(min_val),
293            max_val: Some(max_val),
294        };
295
296        params.from_statistics(min_val, max_val)?;
297        Ok(params)
298    }
299
300    /// Entropy-based calibration (KL divergence minimization)
301    fn calibrate_entropy(&self, dtype: QuantizedDType) -> BackendResult<QuantizationParams> {
302        // Collect all values for histogram computation
303        let mut all_values = Vec::new();
304        for sample in &self.samples {
305            for &val in sample {
306                if val.is_finite() {
307                    all_values.push(val);
308                }
309            }
310        }
311
312        if all_values.is_empty() {
313            return Err(TorshError::BackendError(
314                "No finite values found for entropy calibration".to_string(),
315            ));
316        }
317
318        // Find reasonable initial bounds
319        all_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
320        let global_min = all_values[0];
321        let global_max = all_values[all_values.len() - 1];
322
323        // Try different clipping thresholds and find the one with minimum KL divergence
324        let mut best_kl_div = f64::INFINITY;
325        let mut best_min = global_min;
326        let mut best_max = global_max;
327
328        // Search over different percentile thresholds
329        for percentile in [90.0, 95.0, 97.0, 99.0, 99.5, 99.9, 100.0] {
330            let threshold_idx = ((percentile / 100.0) * (all_values.len() - 1) as f32) as usize;
331            let threshold_max = all_values[threshold_idx];
332            let threshold_min = -threshold_max; // Symmetric for simplicity
333
334            // Compute KL divergence for this threshold
335            if let Ok(kl_div) =
336                self.compute_kl_divergence(&all_values, threshold_min, threshold_max, &dtype)
337            {
338                if kl_div < best_kl_div {
339                    best_kl_div = kl_div;
340                    best_min = threshold_min;
341                    best_max = threshold_max;
342                }
343            }
344        }
345
346        let mut params = QuantizationParams {
347            dtype,
348            scheme: QuantizationScheme::Symmetric,
349            scale: vec![1.0],
350            zero_point: vec![0],
351            block_size: None,
352            min_val: Some(best_min),
353            max_val: Some(best_max),
354        };
355
356        params.from_statistics(best_min, best_max)?;
357        Ok(params)
358    }
359
360    /// MSE-based calibration implementation
361    fn calibrate_mse(&self, dtype: QuantizedDType) -> BackendResult<QuantizationParams> {
362        // Collect all values
363        let mut all_values = Vec::new();
364        for sample in &self.samples {
365            for &val in sample {
366                if val.is_finite() {
367                    all_values.push(val);
368                }
369            }
370        }
371
372        if all_values.is_empty() {
373            return Err(TorshError::BackendError(
374                "No finite values found for MSE calibration".to_string(),
375            ));
376        }
377
378        all_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
379        let global_min = all_values[0];
380        let global_max = all_values[all_values.len() - 1];
381
382        let mut best_mse = f64::INFINITY;
383        let mut best_min = global_min;
384        let mut best_max = global_max;
385
386        // Grid search over different clipping thresholds
387        for percentile in [95.0, 97.0, 99.0, 99.5, 99.9, 100.0] {
388            let threshold_idx = ((percentile / 100.0) * (all_values.len() - 1) as f32) as usize;
389            let threshold_max = all_values[threshold_idx];
390            let threshold_min = if global_min >= 0.0 {
391                0.0
392            } else {
393                -threshold_max
394            };
395
396            // Compute MSE for this threshold
397            if let Ok(mse) = self.compute_mse(&all_values, threshold_min, threshold_max, &dtype) {
398                if mse < best_mse {
399                    best_mse = mse;
400                    best_min = threshold_min;
401                    best_max = threshold_max;
402                }
403            }
404        }
405
406        let mut params = QuantizationParams {
407            dtype,
408            scheme: if best_min >= 0.0 {
409                QuantizationScheme::Asymmetric
410            } else {
411                QuantizationScheme::Symmetric
412            },
413            scale: vec![1.0],
414            zero_point: vec![0],
415            block_size: None,
416            min_val: Some(best_min),
417            max_val: Some(best_max),
418        };
419
420        params.from_statistics(best_min, best_max)?;
421        Ok(params)
422    }
423
424    /// Adaptive calibration that tries multiple methods
425    fn calibrate_adaptive(&self, dtype: QuantizedDType) -> BackendResult<QuantizationParams> {
426        // Try different methods and evaluate their quality
427        let methods = vec![
428            CalibrationMethod::MinMax,
429            CalibrationMethod::Percentile(99.0),
430            CalibrationMethod::Percentile(95.0),
431            CalibrationMethod::MSE,
432        ];
433
434        let mut best_score = f64::INFINITY;
435        let mut best_params = None;
436
437        for method in methods {
438            // Create temporary calibrator with this method
439            let mut temp_calibrator = self.clone();
440            temp_calibrator.set_method(method);
441
442            if let Ok(params) = temp_calibrator.calibrate(dtype.clone()) {
443                // Evaluate quality of these parameters
444                if let Ok(score) = self.evaluate_quantization_quality(&params) {
445                    if score < best_score {
446                        best_score = score;
447                        best_params = Some(params);
448                    }
449                }
450            }
451        }
452
453        best_params.ok_or_else(|| {
454            TorshError::BackendError(
455                "No suitable quantization parameters found in adaptive mode".to_string(),
456            )
457        })
458    }
459
460    /// Compute KL divergence between original and quantized distributions
461    fn compute_kl_divergence(
462        &self,
463        values: &[f32],
464        min_val: f32,
465        max_val: f32,
466        dtype: &QuantizedDType,
467    ) -> BackendResult<f64> {
468        const NUM_BINS: usize = 256;
469
470        // Create histogram of original values
471        let mut original_hist = vec![0usize; NUM_BINS];
472        let range = max_val - min_val;
473
474        if range <= 0.0 {
475            return Ok(f64::INFINITY);
476        }
477
478        for &val in values {
479            let clipped_val = val.clamp(min_val, max_val);
480            let bin = ((clipped_val - min_val) / range * (NUM_BINS - 1) as f32) as usize;
481            let bin = bin.min(NUM_BINS - 1);
482            original_hist[bin] += 1;
483        }
484
485        // Simulate quantization and create quantized histogram
486        let mut quantized_hist = vec![0usize; NUM_BINS];
487        let (qmin, qmax) = dtype.value_range();
488        let scale = range / (qmax - qmin) as f32;
489
490        for &val in values {
491            let clipped_val = val.clamp(min_val, max_val);
492            // Simulate quantization
493            let quantized = ((clipped_val - min_val) / scale)
494                .round()
495                .clamp(qmin as f32, qmax as f32);
496            let dequantized = quantized * scale + min_val;
497
498            let bin = ((dequantized - min_val) / range * (NUM_BINS - 1) as f32) as usize;
499            let bin = bin.min(NUM_BINS - 1);
500            quantized_hist[bin] += 1;
501        }
502
503        // Compute KL divergence
504        let total_samples = values.len() as f64;
505        let mut kl_div = 0.0;
506
507        for i in 0..NUM_BINS {
508            let p = (original_hist[i] as f64 + 1e-10) / total_samples; // Add small epsilon
509            let q = (quantized_hist[i] as f64 + 1e-10) / total_samples;
510
511            if p > 0.0 && q > 0.0 {
512                kl_div += p * (p / q).ln();
513            }
514        }
515
516        Ok(kl_div)
517    }
518
519    /// Compute MSE between original and quantized values
520    fn compute_mse(
521        &self,
522        values: &[f32],
523        min_val: f32,
524        max_val: f32,
525        dtype: &QuantizedDType,
526    ) -> BackendResult<f64> {
527        let (qmin, qmax) = dtype.value_range();
528        let range = max_val - min_val;
529
530        if range <= 0.0 {
531            return Ok(f64::INFINITY);
532        }
533
534        let scale = range / (qmax - qmin) as f32;
535        let mut total_error = 0.0;
536
537        for &val in values {
538            let clipped_val = val.clamp(min_val, max_val);
539            // Simulate quantization
540            let quantized = ((clipped_val - min_val) / scale)
541                .round()
542                .clamp(qmin as f32, qmax as f32);
543            let dequantized = quantized * scale + min_val;
544
545            let error = (val - dequantized).powi(2);
546            total_error += error as f64;
547        }
548
549        Ok(total_error / values.len() as f64)
550    }
551
552    /// Evaluate the quality of quantization parameters
553    fn evaluate_quantization_quality(&self, params: &QuantizationParams) -> BackendResult<f64> {
554        // Use a subset of samples for evaluation to avoid overfitting
555        let eval_samples = if self.samples.len() > 1000 {
556            &self.samples[..1000]
557        } else {
558            &self.samples
559        };
560
561        let mut total_error = 0.0;
562        let mut total_count = 0;
563
564        for sample in eval_samples {
565            for &val in sample {
566                if !val.is_finite() {
567                    continue;
568                }
569
570                // Simulate quantization
571                let scale = params.scale[0];
572                let zero_point = params.zero_point[0] as f32;
573                let (qmin, qmax) = params.dtype.value_range();
574
575                let quantized = ((val / scale + zero_point)
576                    .round()
577                    .clamp(qmin as f32, qmax as f32)) as i32;
578                let dequantized = (quantized - params.zero_point[0]) as f32 * scale;
579
580                let error = (val - dequantized).powi(2);
581                total_error += error as f64;
582                total_count += 1;
583            }
584        }
585
586        if total_count == 0 {
587            Ok(f64::INFINITY)
588        } else {
589            Ok(total_error / total_count as f64)
590        }
591    }
592}
593
/// Percentile-based calibration method
///
/// A specialized calibrator that focuses on percentile-based methods
/// with additional features for robust outlier handling.
#[derive(Debug, Clone)]
pub struct PercentileCalibrator {
    /// Percentile threshold for calibration
    ///
    /// `new` rejects values outside `[0, 100]`, but as a public field it can
    /// be reassigned afterwards.
    pub percentile: f32,
    /// Whether to use symmetric clipping
    ///
    /// When `true` the clipping range is `[-m, m]` and the scheme is
    /// `Symmetric`; otherwise the range comes straight from the lower/upper
    /// percentiles and the scheme is `Asymmetric`.
    pub symmetric: bool,
    /// Device for calibration computations
    ///
    /// NOTE(review): stored by `new` but not read by any method visible in
    /// this file.
    device: Device,
}
607
608impl PercentileCalibrator {
609    /// Create a new percentile calibrator
610    ///
611    /// # Arguments
612    ///
613    /// * `percentile` - Percentile threshold (0-100)
614    /// * `symmetric` - Whether to use symmetric clipping around zero
615    /// * `device` - Device for computations
616    pub fn new(percentile: f32, symmetric: bool, device: Device) -> BackendResult<Self> {
617        if !(0.0..=100.0).contains(&percentile) {
618            return Err(TorshError::BackendError(
619                "Percentile must be between 0 and 100".to_string(),
620            ));
621        }
622
623        Ok(Self {
624            percentile,
625            symmetric,
626            device,
627        })
628    }
629
630    /// Calibrate using percentile method with enhanced outlier detection
631    pub fn calibrate_percentile(
632        &self,
633        samples: &[Vec<f32>],
634        dtype: QuantizedDType,
635    ) -> BackendResult<QuantizationParams> {
636        // Collect all values from samples
637        let mut all_values = Vec::new();
638        for sample in samples {
639            for &val in sample {
640                if val.is_finite() {
641                    all_values.push(val);
642                }
643            }
644        }
645
646        if all_values.is_empty() {
647            return Err(TorshError::BackendError(
648                "No finite values found in calibration data".to_string(),
649            ));
650        }
651
652        all_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
653
654        let (min_val, max_val) = if self.symmetric {
655            // Symmetric percentile clipping
656            let threshold_idx =
657                ((self.percentile / 100.0) * (all_values.len() - 1) as f32) as usize;
658            let max_abs = all_values[threshold_idx]
659                .abs()
660                .max(all_values[all_values.len() - 1 - threshold_idx].abs());
661            (-max_abs, max_abs)
662        } else {
663            // Asymmetric percentile clipping
664            let lower_percentile = (100.0 - self.percentile) / 2.0;
665            let upper_percentile = (100.0 + self.percentile) / 2.0;
666
667            let lower_idx = ((lower_percentile / 100.0) * (all_values.len() - 1) as f32) as usize;
668            let upper_idx = ((upper_percentile / 100.0) * (all_values.len() - 1) as f32) as usize;
669
670            (all_values[lower_idx], all_values[upper_idx])
671        };
672
673        let mut params = QuantizationParams {
674            dtype,
675            scheme: if self.symmetric {
676                QuantizationScheme::Symmetric
677            } else {
678                QuantizationScheme::Asymmetric
679            },
680            scale: vec![1.0],
681            zero_point: vec![0],
682            block_size: None,
683            min_val: Some(min_val),
684            max_val: Some(max_val),
685        };
686
687        params.from_statistics(min_val, max_val)?;
688        Ok(params)
689    }
690
691    /// Calibrate with entropy validation
692    ///
693    /// Uses percentile clipping but validates the result using entropy measures
694    /// to ensure the clipping doesn't lose too much information.
695    pub fn calibrate_entropy_validated(
696        &self,
697        samples: &[Vec<f32>],
698        dtype: QuantizedDType,
699        max_entropy_loss: f64,
700    ) -> BackendResult<QuantizationParams> {
701        // Try different percentile values and pick the highest one that
702        // doesn't exceed the entropy loss threshold
703        let mut best_params = None;
704        let mut _best_percentile = 0.0;
705
706        for test_percentile in [50.0, 70.0, 80.0, 90.0, 95.0, 97.0, 99.0, 99.5] {
707            if test_percentile > self.percentile {
708                break;
709            }
710
711            let mut temp_calibrator = self.clone();
712            temp_calibrator.percentile = test_percentile;
713
714            if let Ok(params) = temp_calibrator.calibrate_percentile(samples, dtype.clone()) {
715                // Estimate entropy loss (simplified)
716                let entropy_loss = self.estimate_entropy_loss(samples, &params)?;
717
718                if entropy_loss <= max_entropy_loss {
719                    best_params = Some(params);
720                    _best_percentile = test_percentile;
721                }
722            }
723        }
724
725        best_params.ok_or_else(|| {
726            TorshError::BackendError(format!(
727                "No percentile found that meets entropy loss requirement of {}",
728                max_entropy_loss
729            ))
730        })
731    }
732
733    /// Estimate entropy loss from quantization
734    fn estimate_entropy_loss(
735        &self,
736        samples: &[Vec<f32>],
737        params: &QuantizationParams,
738    ) -> BackendResult<f64> {
739        // Simplified entropy loss estimation
740        // In practice, would compute actual entropy of original vs quantized data
741        let min_val = params.min_val.expect("min_val should be set in params");
742        let max_val = params.max_val.expect("max_val should be set in params");
743
744        let mut clipped_count = 0;
745        let mut total_count = 0;
746
747        for sample in samples {
748            for &val in sample {
749                if val.is_finite() {
750                    total_count += 1;
751                    if val < min_val || val > max_val {
752                        clipped_count += 1;
753                    }
754                }
755            }
756        }
757
758        if total_count == 0 {
759            return Ok(0.0);
760        }
761
762        // Simple approximation: entropy loss ≈ fraction of clipped values
763        Ok(clipped_count as f64 / total_count as f64)
764    }
765}
766
/// Calibration statistics and analysis
///
/// Summary statistics over the finite values of a set of calibration samples,
/// plus heuristic calibration-method recommendations. Built by
/// `CalibrationStatistics::from_samples`.
#[derive(Debug, Clone)]
pub struct CalibrationStatistics {
    /// Total number of samples processed (every input sample vector, including
    /// ones that contained no finite values)
    pub num_samples: usize,
    /// Total number of values processed (finite values only)
    pub num_values: usize,
    /// Minimum value encountered (among finite values)
    pub min_value: f32,
    /// Maximum value encountered (among finite values)
    pub max_value: f32,
    /// Mean value
    pub mean_value: f32,
    /// Standard deviation (population form: divides by the value count)
    pub std_dev: f32,
    /// Percentage of outliers (beyond 3 standard deviations), in the range 0-100
    pub outlier_percentage: f32,
    /// Recommended calibration methods, in the order produced by the
    /// heuristic; `Adaptive` is always present
    pub recommended_methods: Vec<CalibrationMethod>,
}
787
788impl CalibrationStatistics {
789    /// Compute statistics from calibration samples
790    pub fn from_samples(samples: &[Vec<f32>]) -> BackendResult<Self> {
791        let mut all_values = Vec::new();
792        for sample in samples {
793            for &val in sample {
794                if val.is_finite() {
795                    all_values.push(val);
796                }
797            }
798        }
799
800        if all_values.is_empty() {
801            return Err(TorshError::BackendError(
802                "No finite values found in samples".to_string(),
803            ));
804        }
805
806        all_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
807
808        let num_values = all_values.len();
809        let min_value = all_values[0];
810        let max_value = all_values[num_values - 1];
811
812        // Compute mean
813        let sum: f64 = all_values.iter().map(|&x| x as f64).sum();
814        let mean_value = (sum / num_values as f64) as f32;
815
816        // Compute standard deviation
817        let variance: f64 = all_values
818            .iter()
819            .map(|&x| (x as f64 - mean_value as f64).powi(2))
820            .sum::<f64>()
821            / num_values as f64;
822        let std_dev = variance.sqrt() as f32;
823
824        // Count outliers (beyond 3 standard deviations)
825        let outlier_threshold = 3.0 * std_dev;
826        let outlier_count = all_values
827            .iter()
828            .filter(|&&x| (x - mean_value).abs() > outlier_threshold)
829            .count();
830        let outlier_percentage = (outlier_count as f32 / num_values as f32) * 100.0;
831
832        // Recommend calibration methods based on statistics
833        let recommended_methods =
834            Self::recommend_methods(outlier_percentage, std_dev, min_value, max_value);
835
836        Ok(Self {
837            num_samples: samples.len(),
838            num_values,
839            min_value,
840            max_value,
841            mean_value,
842            std_dev,
843            outlier_percentage,
844            recommended_methods,
845        })
846    }
847
848    /// Recommend calibration methods based on data characteristics
849    fn recommend_methods(
850        outlier_percentage: f32,
851        std_dev: f32,
852        min_value: f32,
853        max_value: f32,
854    ) -> Vec<CalibrationMethod> {
855        let mut recommendations = Vec::new();
856
857        // If many outliers, recommend percentile methods
858        if outlier_percentage > 5.0 {
859            recommendations.push(CalibrationMethod::Percentile(99.0));
860            recommendations.push(CalibrationMethod::Percentile(95.0));
861        }
862
863        // If high variance, recommend entropy-based methods
864        if std_dev > (max_value - min_value) * 0.2 {
865            recommendations.push(CalibrationMethod::Entropy);
866            recommendations.push(CalibrationMethod::MSE);
867        }
868
869        // Always include adaptive as a fallback
870        recommendations.push(CalibrationMethod::Adaptive);
871
872        // If no outliers and low variance, min-max is fine
873        if outlier_percentage < 1.0 && std_dev < (max_value - min_value) * 0.1 {
874            recommendations.push(CalibrationMethod::MinMax);
875        }
876
877        recommendations
878    }
879}
880
#[cfg(test)]
mod tests {
    use super::*;

    /// Acquire the CPU device used by every test; a failure here is an
    /// environment problem, not a calibration bug.
    fn cpu_device() -> Device {
        Device::cpu().expect("CPU device should always be available")
    }

    /// Three small, well-behaved samples: 15 finite values spanning [-2, 10].
    fn create_test_samples() -> Vec<Vec<f32>> {
        vec![
            vec![1.0, 2.0, 3.0, 4.0, 5.0],
            vec![2.0, 4.0, 6.0, 8.0, 10.0],
            vec![-1.0, -2.0, 0.0, 1.0, 2.0],
        ]
    }

    /// Build a calibrator for `method`, pre-loaded with `create_test_samples()`.
    fn loaded_calibrator(method: CalibrationMethod) -> QuantizationCalibrator {
        let mut calibrator = QuantizationCalibrator::new(method, cpu_device());
        calibrator.add_samples(create_test_samples());
        calibrator
    }

    #[test]
    fn test_calibrator_creation() {
        let calibrator = QuantizationCalibrator::new(CalibrationMethod::MinMax, cpu_device());

        assert_eq!(calibrator.num_samples(), 0);
        assert!(matches!(calibrator.method, CalibrationMethod::MinMax));
    }

    #[test]
    fn test_sample_management() {
        let mut calibrator = QuantizationCalibrator::new(CalibrationMethod::MinMax, cpu_device());

        calibrator.add_sample(vec![1.0, 2.0, 3.0]);
        assert_eq!(calibrator.num_samples(), 1);

        calibrator.add_samples(vec![vec![4.0, 5.0], vec![6.0, 7.0]]);
        assert_eq!(calibrator.num_samples(), 3);

        calibrator.clear_samples();
        assert_eq!(calibrator.num_samples(), 0);
    }

    #[test]
    fn test_minmax_calibration() {
        let params = loaded_calibrator(CalibrationMethod::MinMax)
            .calibrate(QuantizedDType::Int8)
            .expect("min-max calibration should succeed on finite samples");

        assert_eq!(params.dtype, QuantizedDType::Int8);
        assert!(params.scale[0] > 0.0);
        assert!(params.min_val.is_some());
        assert!(params.max_val.is_some());
    }

    #[test]
    fn test_percentile_calibration() {
        let params = loaded_calibrator(CalibrationMethod::Percentile(95.0))
            .calibrate(QuantizedDType::UInt8)
            .expect("percentile calibration should succeed on finite samples");

        assert_eq!(params.dtype, QuantizedDType::UInt8);
    }

    #[test]
    fn test_mse_calibration() {
        let params = loaded_calibrator(CalibrationMethod::MSE)
            .calibrate(QuantizedDType::Int8)
            .expect("MSE calibration should succeed on finite samples");

        assert!(params.scale[0] > 0.0);
    }

    #[test]
    fn test_adaptive_calibration() {
        loaded_calibrator(CalibrationMethod::Adaptive)
            .calibrate(QuantizedDType::Int8)
            .expect("adaptive calibration should succeed on finite samples");
    }

    #[test]
    fn test_calibration_with_outliers() {
        let mut calibrator =
            QuantizationCalibrator::new(CalibrationMethod::Percentile(90.0), cpu_device());

        // Two extreme values (±1000) among 37 otherwise small values; enough
        // data that a 90th-percentile cut excludes both of them.
        let samples = vec![
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 1000.0], // 1000.0 is an outlier
            vec![2.0, 4.0, 6.0, 8.0, 10.0],
            vec![-1.0, -2.0, 0.0, 1.0, 2.0, -1000.0], // -1000.0 is an outlier
            vec![1.5, 2.5, 3.5, 4.5, 5.5],
            vec![0.5, 1.0, 1.5, 2.0, 2.5],
            vec![3.0, 3.5, 4.0, 4.5, 5.0],
            vec![-0.5, -1.0, 0.5, 1.0, 1.5],
        ];
        calibrator.add_samples(samples);

        let params = calibrator
            .calibrate(QuantizedDType::Int8)
            .expect("percentile calibration should succeed despite outliers");

        // Unlike min-max, the percentile range must not stretch to ±1000.
        let min_val = params.min_val.expect("calibration should record min_val");
        let max_val = params.max_val.expect("calibration should record max_val");
        assert!(min_val > -100.0);
        assert!(max_val < 100.0);
    }

    #[test]
    fn test_percentile_calibrator() {
        let calibrator = PercentileCalibrator::new(95.0, false, cpu_device())
            .expect("95th percentile is a valid configuration");
        let samples = create_test_samples();

        let params = calibrator
            .calibrate_percentile(&samples, QuantizedDType::Int8)
            .expect("asymmetric percentile calibration should succeed");

        assert_eq!(params.dtype, QuantizedDType::Int8);
        assert_eq!(params.scheme, QuantizationScheme::Asymmetric);
    }

    #[test]
    fn test_symmetric_percentile_calibrator() {
        let calibrator = PercentileCalibrator::new(95.0, true, cpu_device())
            .expect("95th percentile is a valid configuration");
        let samples = create_test_samples();

        let params = calibrator
            .calibrate_percentile(&samples, QuantizedDType::Int8)
            .expect("symmetric percentile calibration should succeed");

        assert_eq!(params.scheme, QuantizationScheme::Symmetric);
    }

    #[test]
    fn test_calibration_statistics() {
        let samples = create_test_samples();
        let stats = CalibrationStatistics::from_samples(&samples)
            .expect("statistics should be computable for finite samples");

        assert_eq!(stats.num_samples, 3);
        assert_eq!(stats.num_values, 15);
        assert!(stats.min_value <= stats.max_value);
        assert!(stats.std_dev >= 0.0);
        assert!(!stats.recommended_methods.is_empty());
    }

    #[test]
    fn test_invalid_percentile() {
        // Percentiles outside [0, 100] must be rejected.
        assert!(PercentileCalibrator::new(101.0, false, cpu_device()).is_err());
        assert!(PercentileCalibrator::new(-1.0, false, cpu_device()).is_err());
    }

    #[test]
    fn test_empty_samples_error() {
        let calibrator = QuantizationCalibrator::new(CalibrationMethod::MinMax, cpu_device());

        assert!(calibrator.calibrate(QuantizedDType::Int8).is_err());
    }

    #[test]
    fn test_method_switching() {
        let mut calibrator = loaded_calibrator(CalibrationMethod::MinMax);

        calibrator.set_method(CalibrationMethod::Percentile(95.0));
        let percentile_params = calibrator
            .calibrate(QuantizedDType::Int8)
            .expect("percentile calibration should succeed after switching");

        calibrator.set_method(CalibrationMethod::MSE);
        let mse_params = calibrator
            .calibrate(QuantizedDType::Int8)
            .expect("MSE calibration should succeed after switching");

        // Different methods may yield different parameters, but both are valid.
        assert!(percentile_params.scale[0] > 0.0);
        assert!(mse_params.scale[0] > 0.0);
    }

    #[test]
    fn test_calibration_with_infinite_values() {
        let mut calibrator = QuantizationCalibrator::new(CalibrationMethod::MinMax, cpu_device());

        // Non-finite values (±inf, NaN) should be filtered out before calibration.
        let samples = vec![
            vec![1.0, 2.0, f32::INFINITY, 4.0, 5.0],
            vec![2.0, f32::NEG_INFINITY, 6.0, 8.0, 10.0],
            vec![-1.0, -2.0, 0.0, 1.0, f32::NAN],
        ];
        calibrator.add_samples(samples);

        let params = calibrator
            .calibrate(QuantizedDType::Int8)
            .expect("calibration should ignore non-finite values");

        let min_val = params.min_val.expect("calibration should record min_val");
        let max_val = params.max_val.expect("calibration should record max_val");
        assert!(min_val.is_finite());
        assert!(max_val.is_finite());
    }
}