// torsh_quantization/analysis/benchmarking.rs
use crate::{QScheme, TorshResult};
use std::time::{Duration, Instant};
5
/// Benchmarks quantization schemes and accumulates per-scheme results.
#[derive(Debug, Clone)]
pub struct QuantizationBenchmarker {
    /// Iteration counts, batch size, and measurement toggles.
    pub config: BenchmarkConfig,
    /// Results recorded by previous `benchmark_scheme` calls, in order.
    pub metrics: Vec<BenchmarkResult>,
}
14
/// Configuration for a benchmarking run.
#[derive(Debug, Clone)]
pub struct BenchmarkConfig {
    /// Untimed iterations executed before measurement begins.
    pub warmup_iterations: usize,
    /// Timed iterations used to compute the average latency.
    pub measurement_iterations: usize,
    /// Items processed per operation; used to derive throughput.
    pub batch_size: usize,
    /// Whether to include the memory-usage estimate in results.
    pub measure_memory: bool,
    /// Whether to include the accuracy-preservation estimate in results.
    pub measure_accuracy: bool,
}
29
30impl Default for BenchmarkConfig {
31 fn default() -> Self {
32 Self {
33 warmup_iterations: 10,
34 measurement_iterations: 100,
35 batch_size: 32,
36 measure_memory: true,
37 measure_accuracy: true,
38 }
39 }
40}
41
42impl Default for QuantizationBenchmarker {
43 fn default() -> Self {
44 Self::new(BenchmarkConfig::default())
45 }
46}
47
48impl QuantizationBenchmarker {
49 pub fn new(config: BenchmarkConfig) -> Self {
51 Self {
52 config,
53 metrics: Vec::new(),
54 }
55 }
56
57 pub fn benchmark_scheme(
59 &mut self,
60 scheme: QScheme,
61 operation: impl Fn() -> TorshResult<()>,
62 ) -> TorshResult<BenchmarkResult> {
63 for _ in 0..self.config.warmup_iterations {
65 operation()?;
66 }
67
68 let start = Instant::now();
70 for _ in 0..self.config.measurement_iterations {
71 operation()?;
72 }
73 let duration = start.elapsed();
74
75 let avg_duration = duration / self.config.measurement_iterations as u32;
76 let throughput = self.calculate_throughput(avg_duration);
77
78 let result = BenchmarkResult {
79 scheme,
80 avg_latency_ms: avg_duration.as_millis() as f32,
81 throughput_ops_per_sec: throughput,
82 memory_usage_mb: self.estimate_memory_usage(scheme),
83 accuracy_preservation: self.estimate_accuracy_preservation(scheme),
84 compression_ratio: self.estimate_compression_ratio(scheme),
85 };
86
87 self.metrics.push(result.clone());
88 Ok(result)
89 }
90
91 pub fn benchmark_comparison(
93 &mut self,
94 schemes: &[QScheme],
95 operation_factory: impl Fn(QScheme) -> Box<dyn Fn() -> TorshResult<()>>,
96 ) -> TorshResult<Vec<BenchmarkResult>> {
97 let mut results = Vec::new();
98
99 for &scheme in schemes {
100 let operation = operation_factory(scheme);
101 let result = self.benchmark_scheme(scheme, || operation())?;
102 results.push(result);
103 }
104
105 Ok(results)
106 }
107
108 pub fn generate_report(&self) -> String {
110 let mut report = String::new();
111 report.push_str("Quantization Benchmarking Report\n");
112 report.push_str(&"=".repeat(80));
113 report.push('\n');
114
115 report.push_str(&format!(
116 "{:<20} | {:>12} | {:>12} | {:>10} | {:>10}\n",
117 "Scheme", "Latency (ms)", "Throughput", "Memory", "Accuracy"
118 ));
119 report.push_str(&"-".repeat(80));
120 report.push('\n');
121
122 for metric in &self.metrics {
123 report.push_str(&format!(
124 "{:<20} | {:>10.2} | {:>10.0} | {:>8.1}MB | {:>8.3}\n",
125 format!("{:?}", metric.scheme),
126 metric.avg_latency_ms,
127 metric.throughput_ops_per_sec,
128 metric.memory_usage_mb,
129 metric.accuracy_preservation
130 ));
131 }
132
133 report.push('\n');
134 report.push_str(&format!(
135 "Benchmark Configuration:\n\
136 - Warmup iterations: {}\n\
137 - Measurement iterations: {}\n\
138 - Batch size: {}",
139 self.config.warmup_iterations,
140 self.config.measurement_iterations,
141 self.config.batch_size
142 ));
143
144 report
145 }
146
147 pub fn find_best_scheme(&self, criteria: OptimizationCriteria) -> Option<QScheme> {
149 if self.metrics.is_empty() {
150 return None;
151 }
152
153 let mut best_score = f32::NEG_INFINITY;
154 let mut best_scheme = None;
155
156 for metric in &self.metrics {
157 let score = criteria.calculate_score(metric);
158 if score > best_score {
159 best_score = score;
160 best_scheme = Some(metric.scheme);
161 }
162 }
163
164 best_scheme
165 }
166
167 fn calculate_throughput(&self, avg_duration: Duration) -> f32 {
169 self.config.batch_size as f32 / avg_duration.as_secs_f32()
170 }
171
172 fn estimate_memory_usage(&self, scheme: QScheme) -> f32 {
173 match scheme {
175 QScheme::Binary => 0.5,
176 QScheme::Ternary => 1.0,
177 QScheme::Int4PerTensor | QScheme::Int4PerChannel => 2.0,
178 QScheme::PerTensorAffine | QScheme::PerChannelAffine => 4.0,
179 QScheme::PerTensorSymmetric | QScheme::PerChannelSymmetric => 4.0,
180 QScheme::MixedPrecision => 8.0,
181 QScheme::GroupWise => 3.0,
182 }
183 }
184
185 fn estimate_accuracy_preservation(&self, scheme: QScheme) -> f32 {
186 match scheme {
187 QScheme::PerTensorAffine => 0.98,
188 QScheme::PerChannelAffine => 0.99,
189 QScheme::PerTensorSymmetric => 0.97,
190 QScheme::PerChannelSymmetric => 0.98,
191 QScheme::Int4PerTensor => 0.93,
192 QScheme::Int4PerChannel => 0.95,
193 QScheme::MixedPrecision => 0.99,
194 QScheme::Binary => 0.75,
195 QScheme::Ternary => 0.85,
196 QScheme::GroupWise => 0.96,
197 }
198 }
199
200 fn estimate_compression_ratio(&self, scheme: QScheme) -> f32 {
201 match scheme {
202 QScheme::PerTensorAffine => 4.0,
203 QScheme::PerChannelAffine => 3.8,
204 QScheme::PerTensorSymmetric => 4.0,
205 QScheme::PerChannelSymmetric => 3.8,
206 QScheme::Int4PerTensor => 8.0,
207 QScheme::Int4PerChannel => 7.5,
208 QScheme::MixedPrecision => 5.0,
209 QScheme::Binary => 32.0,
210 QScheme::Ternary => 16.0,
211 QScheme::GroupWise => 6.0,
212 }
213 }
214
215 pub fn clear_metrics(&mut self) {
217 self.metrics.clear();
218 }
219
220 pub fn get_metrics(&self) -> &[BenchmarkResult] {
222 &self.metrics
223 }
224}
225
/// The measurements and estimates produced for one benchmarked scheme.
#[derive(Debug, Clone)]
pub struct BenchmarkResult {
    /// The quantization scheme that was benchmarked.
    pub scheme: QScheme,
    /// Average latency of one operation, in milliseconds.
    pub avg_latency_ms: f32,
    /// Derived throughput in operations (batch items) per second.
    pub throughput_ops_per_sec: f32,
    /// Estimated memory usage, in megabytes.
    pub memory_usage_mb: f32,
    /// Estimated fraction of model accuracy retained (0.0..=1.0).
    pub accuracy_preservation: f32,
    /// Estimated compression ratio versus unquantized weights.
    pub compression_ratio: f32,
}
242
/// Weights combining benchmark metrics into a single score; used by
/// `QuantizationBenchmarker::find_best_scheme`.
#[derive(Debug, Clone)]
pub struct OptimizationCriteria {
    /// Weight applied to inverse latency (lower latency scores higher).
    pub latency_weight: f32,
    /// Weight applied to normalized throughput.
    pub throughput_weight: f32,
    /// Weight applied to inverse memory usage (less memory scores higher).
    pub memory_weight: f32,
    /// Weight applied to the accuracy-preservation estimate.
    pub accuracy_weight: f32,
    /// Weight applied to the normalized compression ratio.
    pub compression_weight: f32,
}
257
258impl OptimizationCriteria {
259 pub fn optimize_for_speed() -> Self {
261 Self {
262 latency_weight: 0.4,
263 throughput_weight: 0.4,
264 memory_weight: 0.1,
265 accuracy_weight: 0.1,
266 compression_weight: 0.0,
267 }
268 }
269
270 pub fn optimize_for_accuracy() -> Self {
272 Self {
273 latency_weight: 0.1,
274 throughput_weight: 0.1,
275 memory_weight: 0.1,
276 accuracy_weight: 0.7,
277 compression_weight: 0.0,
278 }
279 }
280
281 pub fn optimize_for_size() -> Self {
283 Self {
284 latency_weight: 0.1,
285 throughput_weight: 0.1,
286 memory_weight: 0.3,
287 accuracy_weight: 0.2,
288 compression_weight: 0.3,
289 }
290 }
291
292 pub fn calculate_score(&self, result: &BenchmarkResult) -> f32 {
294 let latency_score = (1.0 / result.avg_latency_ms.max(0.001)) * self.latency_weight;
296 let throughput_score =
297 (result.throughput_ops_per_sec / 10000.0).min(1.0) * self.throughput_weight;
298 let memory_score = (1.0 / result.memory_usage_mb.max(0.1)) * self.memory_weight;
299 let accuracy_score = result.accuracy_preservation * self.accuracy_weight;
300 let compression_score =
301 (result.compression_ratio / 32.0).min(1.0) * self.compression_weight;
302
303 latency_score + throughput_score + memory_score + accuracy_score + compression_score
304 }
305}