// torsh_quantization/analysis/benchmarking.rs
1//! Performance benchmarking utilities for quantization analysis
2
3use crate::{QScheme, TorshResult};
4use std::time::{Duration, Instant};
5
/// Performance benchmarking utilities for quantization analysis.
///
/// Times repeated executions of quantization operations and accumulates
/// one [`BenchmarkResult`] per benchmarked scheme for later comparison
/// and report generation.
#[derive(Debug, Clone)]
pub struct QuantizationBenchmarker {
    /// Configuration for benchmarking (iteration counts, batch size, flags)
    pub config: BenchmarkConfig,
    /// Benchmark results collected so far, in the order they were run
    pub metrics: Vec<BenchmarkResult>,
}
14
/// Configuration for benchmarking
#[derive(Debug, Clone)]
pub struct BenchmarkConfig {
    /// Number of untimed warmup iterations run before measurement
    pub warmup_iterations: usize,
    /// Number of timed iterations used to compute the averaged metrics
    pub measurement_iterations: usize,
    /// Batch size used when converting average latency into throughput
    pub batch_size: usize,
    /// Include memory usage measurements
    // NOTE(review): not currently consulted by `QuantizationBenchmarker`;
    // memory figures in results are static per-scheme estimates.
    pub measure_memory: bool,
    /// Include accuracy measurements
    // NOTE(review): not currently consulted by `QuantizationBenchmarker`;
    // accuracy figures in results are static per-scheme estimates.
    pub measure_accuracy: bool,
}
29
30impl Default for BenchmarkConfig {
31    fn default() -> Self {
32        Self {
33            warmup_iterations: 10,
34            measurement_iterations: 100,
35            batch_size: 32,
36            measure_memory: true,
37            measure_accuracy: true,
38        }
39    }
40}
41
42impl Default for QuantizationBenchmarker {
43    fn default() -> Self {
44        Self::new(BenchmarkConfig::default())
45    }
46}
47
48impl QuantizationBenchmarker {
49    /// Create a new benchmarker with configuration
50    pub fn new(config: BenchmarkConfig) -> Self {
51        Self {
52            config,
53            metrics: Vec::new(),
54        }
55    }
56
57    /// Benchmark a quantization scheme
58    pub fn benchmark_scheme(
59        &mut self,
60        scheme: QScheme,
61        operation: impl Fn() -> TorshResult<()>,
62    ) -> TorshResult<BenchmarkResult> {
63        // Warmup phase
64        for _ in 0..self.config.warmup_iterations {
65            operation()?;
66        }
67
68        // Measurement phase
69        let start = Instant::now();
70        for _ in 0..self.config.measurement_iterations {
71            operation()?;
72        }
73        let duration = start.elapsed();
74
75        let avg_duration = duration / self.config.measurement_iterations as u32;
76        let throughput = self.calculate_throughput(avg_duration);
77
78        let result = BenchmarkResult {
79            scheme,
80            avg_latency_ms: avg_duration.as_millis() as f32,
81            throughput_ops_per_sec: throughput,
82            memory_usage_mb: self.estimate_memory_usage(scheme),
83            accuracy_preservation: self.estimate_accuracy_preservation(scheme),
84            compression_ratio: self.estimate_compression_ratio(scheme),
85        };
86
87        self.metrics.push(result.clone());
88        Ok(result)
89    }
90
91    /// Benchmark multiple schemes and compare
92    pub fn benchmark_comparison(
93        &mut self,
94        schemes: &[QScheme],
95        operation_factory: impl Fn(QScheme) -> Box<dyn Fn() -> TorshResult<()>>,
96    ) -> TorshResult<Vec<BenchmarkResult>> {
97        let mut results = Vec::new();
98
99        for &scheme in schemes {
100            let operation = operation_factory(scheme);
101            let result = self.benchmark_scheme(scheme, || operation())?;
102            results.push(result);
103        }
104
105        Ok(results)
106    }
107
108    /// Generate benchmark report
109    pub fn generate_report(&self) -> String {
110        let mut report = String::new();
111        report.push_str("Quantization Benchmarking Report\n");
112        report.push_str(&"=".repeat(80));
113        report.push('\n');
114
115        report.push_str(&format!(
116            "{:<20} | {:>12} | {:>12} | {:>10} | {:>10}\n",
117            "Scheme", "Latency (ms)", "Throughput", "Memory", "Accuracy"
118        ));
119        report.push_str(&"-".repeat(80));
120        report.push('\n');
121
122        for metric in &self.metrics {
123            report.push_str(&format!(
124                "{:<20} | {:>10.2} | {:>10.0} | {:>8.1}MB | {:>8.3}\n",
125                format!("{:?}", metric.scheme),
126                metric.avg_latency_ms,
127                metric.throughput_ops_per_sec,
128                metric.memory_usage_mb,
129                metric.accuracy_preservation
130            ));
131        }
132
133        report.push('\n');
134        report.push_str(&format!(
135            "Benchmark Configuration:\n\
136             - Warmup iterations: {}\n\
137             - Measurement iterations: {}\n\
138             - Batch size: {}",
139            self.config.warmup_iterations,
140            self.config.measurement_iterations,
141            self.config.batch_size
142        ));
143
144        report
145    }
146
147    /// Find the best scheme based on criteria
148    pub fn find_best_scheme(&self, criteria: OptimizationCriteria) -> Option<QScheme> {
149        if self.metrics.is_empty() {
150            return None;
151        }
152
153        let mut best_score = f32::NEG_INFINITY;
154        let mut best_scheme = None;
155
156        for metric in &self.metrics {
157            let score = criteria.calculate_score(metric);
158            if score > best_score {
159                best_score = score;
160                best_scheme = Some(metric.scheme);
161            }
162        }
163
164        best_scheme
165    }
166
167    // Helper methods
168    fn calculate_throughput(&self, avg_duration: Duration) -> f32 {
169        self.config.batch_size as f32 / avg_duration.as_secs_f32()
170    }
171
172    fn estimate_memory_usage(&self, scheme: QScheme) -> f32 {
173        // Simplified memory estimation
174        match scheme {
175            QScheme::Binary => 0.5,
176            QScheme::Ternary => 1.0,
177            QScheme::Int4PerTensor | QScheme::Int4PerChannel => 2.0,
178            QScheme::PerTensorAffine | QScheme::PerChannelAffine => 4.0,
179            QScheme::PerTensorSymmetric | QScheme::PerChannelSymmetric => 4.0,
180            QScheme::MixedPrecision => 8.0,
181            QScheme::GroupWise => 3.0,
182        }
183    }
184
185    fn estimate_accuracy_preservation(&self, scheme: QScheme) -> f32 {
186        match scheme {
187            QScheme::PerTensorAffine => 0.98,
188            QScheme::PerChannelAffine => 0.99,
189            QScheme::PerTensorSymmetric => 0.97,
190            QScheme::PerChannelSymmetric => 0.98,
191            QScheme::Int4PerTensor => 0.93,
192            QScheme::Int4PerChannel => 0.95,
193            QScheme::MixedPrecision => 0.99,
194            QScheme::Binary => 0.75,
195            QScheme::Ternary => 0.85,
196            QScheme::GroupWise => 0.96,
197        }
198    }
199
200    fn estimate_compression_ratio(&self, scheme: QScheme) -> f32 {
201        match scheme {
202            QScheme::PerTensorAffine => 4.0,
203            QScheme::PerChannelAffine => 3.8,
204            QScheme::PerTensorSymmetric => 4.0,
205            QScheme::PerChannelSymmetric => 3.8,
206            QScheme::Int4PerTensor => 8.0,
207            QScheme::Int4PerChannel => 7.5,
208            QScheme::MixedPrecision => 5.0,
209            QScheme::Binary => 32.0,
210            QScheme::Ternary => 16.0,
211            QScheme::GroupWise => 6.0,
212        }
213    }
214
215    /// Clear all collected metrics
216    pub fn clear_metrics(&mut self) {
217        self.metrics.clear();
218    }
219
220    /// Get all collected metrics
221    pub fn get_metrics(&self) -> &[BenchmarkResult] {
222        &self.metrics
223    }
224}
225
/// Benchmark result for a specific quantization scheme.
///
/// Produced by [`QuantizationBenchmarker::benchmark_scheme`] and consumed
/// by [`OptimizationCriteria::calculate_score`].
#[derive(Debug, Clone)]
pub struct BenchmarkResult {
    /// Quantization scheme tested
    pub scheme: QScheme,
    /// Average latency in milliseconds per iteration
    pub avg_latency_ms: f32,
    /// Throughput in operations per second (batch size / average latency)
    pub throughput_ops_per_sec: f32,
    /// Memory usage in megabytes
    pub memory_usage_mb: f32,
    /// Accuracy preservation ratio (0.0 to 1.0)
    pub accuracy_preservation: f32,
    /// Compression ratio compared to FP32
    pub compression_ratio: f32,
}
242
/// Optimization criteria for selecting the best quantization scheme.
///
/// Each weight scales the corresponding (normalized or inverted) metric in
/// [`OptimizationCriteria::calculate_score`]; higher total score wins.
#[derive(Debug, Clone)]
pub struct OptimizationCriteria {
    /// Weight for latency optimization (lower latency is better)
    pub latency_weight: f32,
    /// Weight for throughput optimization (higher is better)
    pub throughput_weight: f32,
    /// Weight for memory optimization (lower usage is better)
    pub memory_weight: f32,
    /// Weight for accuracy preservation (higher is better)
    pub accuracy_weight: f32,
    /// Weight for compression ratio (higher is better)
    pub compression_weight: f32,
}
257
258impl OptimizationCriteria {
259    /// Create criteria optimized for speed
260    pub fn optimize_for_speed() -> Self {
261        Self {
262            latency_weight: 0.4,
263            throughput_weight: 0.4,
264            memory_weight: 0.1,
265            accuracy_weight: 0.1,
266            compression_weight: 0.0,
267        }
268    }
269
270    /// Create criteria optimized for accuracy
271    pub fn optimize_for_accuracy() -> Self {
272        Self {
273            latency_weight: 0.1,
274            throughput_weight: 0.1,
275            memory_weight: 0.1,
276            accuracy_weight: 0.7,
277            compression_weight: 0.0,
278        }
279    }
280
281    /// Create criteria optimized for size
282    pub fn optimize_for_size() -> Self {
283        Self {
284            latency_weight: 0.1,
285            throughput_weight: 0.1,
286            memory_weight: 0.3,
287            accuracy_weight: 0.2,
288            compression_weight: 0.3,
289        }
290    }
291
292    /// Calculate weighted score for a benchmark result
293    pub fn calculate_score(&self, result: &BenchmarkResult) -> f32 {
294        // Normalize metrics to 0-1 range and apply weights
295        let latency_score = (1.0 / result.avg_latency_ms.max(0.001)) * self.latency_weight;
296        let throughput_score =
297            (result.throughput_ops_per_sec / 10000.0).min(1.0) * self.throughput_weight;
298        let memory_score = (1.0 / result.memory_usage_mb.max(0.1)) * self.memory_weight;
299        let accuracy_score = result.accuracy_preservation * self.accuracy_weight;
300        let compression_score =
301            (result.compression_ratio / 32.0).min(1.0) * self.compression_weight;
302
303        latency_score + throughput_score + memory_score + accuracy_score + compression_score
304    }
305}