Skip to main content

torsh_fx/quantization/
benchmark.rs

1//! Quantization benchmarking utilities
2
3use super::metrics::QuantizationMetrics;
4
/// Quantization benchmarking utilities.
///
/// This is a field-less unit struct: it carries no state and serves as a
/// namespace for associated functions such as `measure_accuracy`, which is
/// called as `QuantizationBenchmark::measure_accuracy(..)` without an instance.
pub struct QuantizationBenchmark;
7
8impl QuantizationBenchmark {
9    /// Measure quantization accuracy
10    pub fn measure_accuracy(
11        original_outputs: &[f32],
12        quantized_outputs: &[f32],
13    ) -> QuantizationMetrics {
14        let mut total_error = 0.0f32;
15        let mut max_error = 0.0f32;
16        let mut snr_sum = 0.0f32;
17
18        for (orig, quant) in original_outputs.iter().zip(quantized_outputs.iter()) {
19            let error = (orig - quant).abs();
20            total_error += error;
21            max_error = max_error.max(error);
22
23            if orig.abs() > 1e-8 {
24                let snr = 20.0 * (orig.abs() / error).log10();
25                snr_sum += snr;
26            }
27        }
28
29        let mean_error = total_error / original_outputs.len() as f32;
30        let mean_snr = snr_sum / original_outputs.len() as f32;
31
32        QuantizationMetrics {
33            mean_absolute_error: mean_error,
34            max_absolute_error: max_error,
35            signal_to_noise_ratio: mean_snr,
36            sample_count: original_outputs.len(),
37        }
38    }
39}