torsh_fx/quantization/
benchmark.rs1use super::metrics::QuantizationMetrics;
/// Stateless entry point for measuring how closely quantized model outputs track
/// their full-precision reference outputs.
///
/// The type carries no data; its functionality is exposed through associated
/// functions such as `QuantizationBenchmark::measure_accuracy`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct QuantizationBenchmark;
8impl QuantizationBenchmark {
9 pub fn measure_accuracy(
11 original_outputs: &[f32],
12 quantized_outputs: &[f32],
13 ) -> QuantizationMetrics {
14 let mut total_error = 0.0f32;
15 let mut max_error = 0.0f32;
16 let mut snr_sum = 0.0f32;
17
18 for (orig, quant) in original_outputs.iter().zip(quantized_outputs.iter()) {
19 let error = (orig - quant).abs();
20 total_error += error;
21 max_error = max_error.max(error);
22
23 if orig.abs() > 1e-8 {
24 let snr = 20.0 * (orig.abs() / error).log10();
25 snr_sum += snr;
26 }
27 }
28
29 let mean_error = total_error / original_outputs.len() as f32;
30 let mean_snr = snr_sum / original_outputs.len() as f32;
31
32 QuantizationMetrics {
33 mean_absolute_error: mean_error,
34 max_absolute_error: max_error,
35 signal_to_noise_ratio: mean_snr,
36 sample_count: original_outputs.len(),
37 }
38 }
39}