1use super::context::QuantizationContext;
4use super::types::{QuantizationAnnotation, QuantizationParams, QuantizationScheme};
5use crate::{FxGraph, Node, TorshResult};
6use petgraph::graph::NodeIndex;
7use petgraph::visit::IntoNodeReferences;
8use std::collections::HashMap;
9
/// Criteria that steer automatic precision selection.
///
/// The variant determines how accuracy and performance scores are weighted
/// when ranking candidate quantization schemes for each operation.
#[derive(Debug, Clone, Copy)]
pub enum PrecisionCriteria {
    /// Favor speedup: the performance component of the score is doubled.
    Performance,
    /// Weigh accuracy and performance components equally.
    Balanced,
    /// Favor accuracy: the accuracy component dominates, and aggressive
    /// schemes are penalized on quantization-sensitive operations.
    Accuracy,
    /// Hard user constraints; schemes violating either bound are rejected
    /// outright during scoring.
    Custom {
        /// Maximum tolerated accuracy loss as a fraction (e.g. 0.02 = 2%).
        max_accuracy_loss: f32,
        /// Minimum required speedup ratio relative to FP32 (e.g. 1.5 = 1.5x).
        min_speedup: f32,
    },
}
25
/// Outcome of precision analysis for a single graph node.
#[derive(Debug, Clone)]
pub struct PrecisionRecommendation {
    /// Quantization scheme selected for this node.
    pub scheme: QuantizationScheme,
    /// Estimated accuracy loss under the selected scheme (fraction).
    pub accuracy_loss: f32,
    /// Estimated speedup relative to FP32 (1.0 = no speedup).
    pub speedup_ratio: f32,
    /// Confidence in this recommendation, clamped to at most 1.0.
    pub confidence: f32,
    /// Human-readable explanation of why this scheme was chosen.
    pub reasoning: String,
}
40
/// Tunable weights used when scoring candidate schemes.
///
/// The per-scheme priorities multiply the final adjusted score; the two
/// weights scale the accuracy and performance components before combining.
#[derive(Debug, Clone)]
pub struct PrecisionStrategy {
    /// Priority multiplier applied to INT8 candidates.
    pub int8_priority: f32,
    /// Priority multiplier applied to INT16 candidates.
    pub int16_priority: f32,
    /// Priority multiplier applied to dynamic-quantization candidates.
    pub dynamic_priority: f32,
    /// Priority multiplier applied to the FP32/fake-quantization fallback.
    pub fp32_priority: f32,
    /// Weight of the performance (speedup) component of the score.
    pub performance_weight: f32,
    /// Weight of the accuracy component of the score.
    pub accuracy_weight: f32,
}
57
impl Default for PrecisionStrategy {
    /// Balanced defaults: INT8 is preferred most, plain FP32 least, and
    /// accuracy/performance are weighted equally.
    fn default() -> Self {
        Self {
            // Scheme preference ordering: INT8 > INT16 > Dynamic > FP32.
            int8_priority: 0.8,
            int16_priority: 0.6,
            dynamic_priority: 0.4,
            fp32_priority: 0.2,
            // Equal weighting of the two score components.
            performance_weight: 0.5,
            accuracy_weight: 0.5,
        }
    }
}
70
/// Selects a quantization scheme per graph node by scoring candidates
/// against per-operation profiles, a weighting strategy, and user criteria.
pub struct AutomaticPrecisionSelector {
    /// Selection criteria (performance / balanced / accuracy / custom bounds).
    pub criteria: PrecisionCriteria,
    /// Weights and per-scheme priorities used during scoring.
    pub strategy: PrecisionStrategy,
    /// Known per-operation profiles, keyed by operation name (e.g. "matmul").
    pub operation_profiles: HashMap<String, PrecisionProfile>,
}
80
/// Per-operation quantization characteristics consulted during scoring.
#[derive(Debug, Clone)]
pub struct PrecisionProfile {
    /// Scheme the profile recommends; matching it boosts confidence.
    pub recommended_scheme: QuantizationScheme,
    /// Estimated accuracy loss per scheme (fraction; 0.0 = lossless).
    pub accuracy_impact: HashMap<QuantizationScheme, f32>,
    /// Estimated speedup per scheme relative to FP32 (1.0 = none).
    pub performance_gain: HashMap<QuantizationScheme, f32>,
    /// Whether the operation degrades noticeably under aggressive quantization.
    pub quantization_sensitive: bool,
}
93
94impl AutomaticPrecisionSelector {
95 pub fn new(criteria: PrecisionCriteria) -> Self {
97 Self {
98 criteria,
99 strategy: PrecisionStrategy::default(),
100 operation_profiles: Self::create_default_profiles(),
101 }
102 }
103
    /// Create a selector with explicit criteria and scoring strategy.
    ///
    /// The built-in operation profiles (matmul, conv2d, attention, relu)
    /// are installed; unknown operations fall back to a generic profile.
    pub fn with_strategy(criteria: PrecisionCriteria, strategy: PrecisionStrategy) -> Self {
        Self {
            criteria,
            strategy,
            operation_profiles: Self::create_default_profiles(),
        }
    }
112
113 pub fn analyze_graph(
115 &self,
116 graph: &FxGraph,
117 ) -> TorshResult<HashMap<NodeIndex, PrecisionRecommendation>> {
118 let mut recommendations = HashMap::new();
119
120 for (node_idx, node) in graph.graph.node_references() {
122 if let Node::Call(op_name, _args) = node {
123 let recommendation = self.analyze_operation(&op_name, node_idx, graph)?;
124 recommendations.insert(node_idx, recommendation);
125 }
126 }
127
128 self.optimize_graph_precision(&mut recommendations, graph)?;
130
131 Ok(recommendations)
132 }
133
134 fn analyze_operation(
136 &self,
137 op_name: &str,
138 node_idx: NodeIndex,
139 graph: &FxGraph,
140 ) -> TorshResult<PrecisionRecommendation> {
141 let profile = self
142 .operation_profiles
143 .get(op_name)
144 .cloned()
145 .unwrap_or_else(|| self.create_generic_profile(op_name));
146
147 let mut best_score = f32::NEG_INFINITY;
149 let mut best_scheme = None;
150 let mut best_reasoning = String::new();
151
152 for &scheme in &[
153 QuantizationScheme::Int8,
154 QuantizationScheme::Int16,
155 QuantizationScheme::Dynamic,
156 ] {
157 let score = self.calculate_precision_score(&profile, scheme, node_idx, graph)?;
158
159 if score > best_score && score != f32::NEG_INFINITY {
160 best_score = score;
161 best_scheme = Some(scheme);
162 best_reasoning = self.generate_reasoning(op_name, scheme, &profile);
163 }
164 }
165
166 let selected_scheme = best_scheme.unwrap_or_else(|| {
168 if matches!(self.criteria, PrecisionCriteria::Custom { .. }) {
170 for &scheme in &[
172 QuantizationScheme::Int16,
173 QuantizationScheme::Dynamic,
174 QuantizationScheme::Int8,
175 ] {
176 let accuracy_loss =
177 profile.accuracy_impact.get(&scheme).copied().unwrap_or(0.1);
178 if let PrecisionCriteria::Custom {
179 max_accuracy_loss, ..
180 } = self.criteria
181 {
182 if accuracy_loss <= max_accuracy_loss {
183 return scheme;
184 }
185 }
186 }
187 }
188 QuantizationScheme::Int16 });
190
191 let accuracy_loss = profile
193 .accuracy_impact
194 .get(&selected_scheme)
195 .copied()
196 .unwrap_or(0.1);
197 let speedup_ratio = profile
198 .performance_gain
199 .get(&selected_scheme)
200 .copied()
201 .unwrap_or(1.2);
202 let confidence = self.calculate_confidence(&profile, selected_scheme);
203
204 Ok(PrecisionRecommendation {
205 scheme: selected_scheme,
206 accuracy_loss,
207 speedup_ratio,
208 confidence,
209 reasoning: if best_scheme.is_some() {
210 best_reasoning
211 } else {
212 format!(
213 "Fallback selection for '{}' due to constraint violations",
214 op_name
215 )
216 },
217 })
218 }
219
220 fn calculate_precision_score(
222 &self,
223 profile: &PrecisionProfile,
224 scheme: QuantizationScheme,
225 _node_idx: NodeIndex,
226 _graph: &FxGraph,
227 ) -> TorshResult<f32> {
228 let accuracy_loss = profile.accuracy_impact.get(&scheme).copied().unwrap_or(0.1);
229 let performance_gain = profile
230 .performance_gain
231 .get(&scheme)
232 .copied()
233 .unwrap_or(1.1);
234
235 let accuracy_score = (1.0 - accuracy_loss) * self.strategy.accuracy_weight;
237 let performance_score = (performance_gain - 1.0) * self.strategy.performance_weight;
238
239 let adjusted_score = match self.criteria {
241 PrecisionCriteria::Performance => performance_score * 2.0 + accuracy_score,
242 PrecisionCriteria::Accuracy => {
243 if profile.quantization_sensitive {
246 let sensitivity_bonus = match scheme {
248 QuantizationScheme::Int16 => 2.0,
249 QuantizationScheme::Int8 => -1.0,
250 _ => 0.0,
251 };
252 accuracy_score * 3.0 + performance_score * 0.5 + sensitivity_bonus
253 } else {
254 accuracy_score * 2.0 + performance_score
255 }
256 }
257 PrecisionCriteria::Balanced => {
258 if profile.quantization_sensitive {
260 let sensitivity_bonus = match scheme {
262 QuantizationScheme::Int16 => 1.0,
263 QuantizationScheme::Int8 => -0.5,
264 _ => 0.0,
265 };
266 accuracy_score + performance_score + sensitivity_bonus
267 } else {
268 accuracy_score + performance_score
269 }
270 }
271 PrecisionCriteria::Custom {
272 max_accuracy_loss,
273 min_speedup,
274 } => {
275 if accuracy_loss > max_accuracy_loss || performance_gain < min_speedup {
276 return Ok(f32::NEG_INFINITY);
277 }
278 accuracy_score + performance_score
279 }
280 };
281
282 let priority = match scheme {
284 QuantizationScheme::Int8 => self.strategy.int8_priority,
285 QuantizationScheme::Int16 => self.strategy.int16_priority,
286 QuantizationScheme::Dynamic => self.strategy.dynamic_priority,
287 QuantizationScheme::Fake => self.strategy.fp32_priority,
288 };
289
290 Ok(adjusted_score * priority)
291 }
292
293 fn generate_reasoning(
295 &self,
296 op_name: &str,
297 scheme: QuantizationScheme,
298 profile: &PrecisionProfile,
299 ) -> String {
300 let scheme_name = match scheme {
301 QuantizationScheme::Int8 => "INT8",
302 QuantizationScheme::Int16 => "INT16",
303 QuantizationScheme::Dynamic => "Dynamic",
304 QuantizationScheme::Fake => "FP32",
305 };
306
307 if profile.quantization_sensitive {
308 format!("Operation '{op_name}' is quantization-sensitive. {scheme_name} provides good balance of performance and accuracy.")
309 } else {
310 format!("Operation '{op_name}' is quantization-friendly. {scheme_name} offers optimal performance with minimal accuracy loss.")
311 }
312 }
313
314 fn calculate_confidence(&self, profile: &PrecisionProfile, scheme: QuantizationScheme) -> f32 {
316 let base_confidence = if profile.quantization_sensitive {
317 0.75
318 } else {
319 0.9
320 };
321
322 let scheme_confidence = match scheme {
324 QuantizationScheme::Int8 => 0.9,
325 QuantizationScheme::Int16 => 0.85,
326 QuantizationScheme::Dynamic => 0.7,
327 QuantizationScheme::Fake => 0.6,
328 };
329
330 let recommendation_bonus = if scheme == profile.recommended_scheme {
332 1.1
333 } else {
334 1.0
335 };
336
337 let confidence: f32 = base_confidence * scheme_confidence * recommendation_bonus;
338 confidence.min(1.0)
339 }
340
341 fn optimize_graph_precision(
343 &self,
344 recommendations: &mut HashMap<NodeIndex, PrecisionRecommendation>,
345 _graph: &FxGraph,
346 ) -> TorshResult<()> {
347 for recommendation in recommendations.values_mut() {
354 if recommendation.confidence < 0.5 {
355 recommendation.scheme = QuantizationScheme::Int16;
357 recommendation.reasoning = format!(
358 "Conservative choice due to low confidence: {}",
359 recommendation.reasoning
360 );
361 }
362 }
363
364 Ok(())
365 }
366
367 fn create_default_profiles() -> HashMap<String, PrecisionProfile> {
369 let mut profiles = HashMap::new();
370
371 profiles.insert(
373 "matmul".to_string(),
374 PrecisionProfile {
375 recommended_scheme: QuantizationScheme::Int8,
376 accuracy_impact: [
377 (QuantizationScheme::Int8, 0.015),
378 (QuantizationScheme::Int16, 0.005),
379 (QuantizationScheme::Dynamic, 0.008),
380 (QuantizationScheme::Fake, 0.0),
381 ]
382 .iter()
383 .cloned()
384 .collect(),
385 performance_gain: [
386 (QuantizationScheme::Int8, 2.5),
387 (QuantizationScheme::Int16, 2.2),
388 (QuantizationScheme::Dynamic, 2.1),
389 (QuantizationScheme::Fake, 1.0),
390 ]
391 .iter()
392 .cloned()
393 .collect(),
394 quantization_sensitive: false,
395 },
396 );
397
398 profiles.insert(
400 "conv2d".to_string(),
401 PrecisionProfile {
402 recommended_scheme: QuantizationScheme::Int8,
403 accuracy_impact: [
404 (QuantizationScheme::Int8, 0.03),
405 (QuantizationScheme::Int16, 0.008),
406 (QuantizationScheme::Dynamic, 0.015),
407 (QuantizationScheme::Fake, 0.0),
408 ]
409 .iter()
410 .cloned()
411 .collect(),
412 performance_gain: [
413 (QuantizationScheme::Int8, 3.0),
414 (QuantizationScheme::Int16, 2.0),
415 (QuantizationScheme::Dynamic, 1.5),
416 (QuantizationScheme::Fake, 1.0),
417 ]
418 .iter()
419 .cloned()
420 .collect(),
421 quantization_sensitive: false,
422 },
423 );
424
425 profiles.insert(
427 "attention".to_string(),
428 PrecisionProfile {
429 recommended_scheme: QuantizationScheme::Int16,
430 accuracy_impact: [
431 (QuantizationScheme::Int8, 0.08),
432 (QuantizationScheme::Int16, 0.02),
433 (QuantizationScheme::Dynamic, 0.04),
434 (QuantizationScheme::Fake, 0.0),
435 ]
436 .iter()
437 .cloned()
438 .collect(),
439 performance_gain: [
440 (QuantizationScheme::Int8, 2.0),
441 (QuantizationScheme::Int16, 1.6),
442 (QuantizationScheme::Dynamic, 1.3),
443 (QuantizationScheme::Fake, 1.0),
444 ]
445 .iter()
446 .cloned()
447 .collect(),
448 quantization_sensitive: true,
449 },
450 );
451
452 profiles.insert(
454 "relu".to_string(),
455 PrecisionProfile {
456 recommended_scheme: QuantizationScheme::Int8,
457 accuracy_impact: [
458 (QuantizationScheme::Int8, 0.001),
459 (QuantizationScheme::Int16, 0.0005),
460 (QuantizationScheme::Dynamic, 0.0008),
461 (QuantizationScheme::Fake, 0.0),
462 ]
463 .iter()
464 .cloned()
465 .collect(),
466 performance_gain: [
467 (QuantizationScheme::Int8, 1.8),
468 (QuantizationScheme::Int16, 1.4),
469 (QuantizationScheme::Dynamic, 1.2),
470 (QuantizationScheme::Fake, 1.0),
471 ]
472 .iter()
473 .cloned()
474 .collect(),
475 quantization_sensitive: false,
476 },
477 );
478
479 profiles
480 }
481
482 fn create_generic_profile(&self, _op_name: &str) -> PrecisionProfile {
484 PrecisionProfile {
485 recommended_scheme: QuantizationScheme::Int16, accuracy_impact: [
487 (QuantizationScheme::Int8, 0.015),
488 (QuantizationScheme::Int16, 0.005),
489 (QuantizationScheme::Dynamic, 0.01),
490 (QuantizationScheme::Fake, 0.0),
491 ]
492 .iter()
493 .cloned()
494 .collect(),
495 performance_gain: [
496 (QuantizationScheme::Int8, 2.0),
497 (QuantizationScheme::Int16, 1.5),
498 (QuantizationScheme::Dynamic, 1.3),
499 (QuantizationScheme::Fake, 1.0),
500 ]
501 .iter()
502 .cloned()
503 .collect(),
504 quantization_sensitive: true, }
506 }
507}
508
509pub fn select_automatic_precision(
511 graph: &FxGraph,
512 criteria: PrecisionCriteria,
513) -> TorshResult<HashMap<NodeIndex, PrecisionRecommendation>> {
514 let selector = AutomaticPrecisionSelector::new(criteria);
515 selector.analyze_graph(graph)
516}
517
518pub fn apply_automatic_precision(
520 graph: &mut FxGraph,
521 criteria: PrecisionCriteria,
522) -> TorshResult<QuantizationContext> {
523 let recommendations = select_automatic_precision(graph, criteria)?;
524
525 let mut context = QuantizationContext::new(QuantizationScheme::Int8);
526
527 for (node_idx, recommendation) in recommendations {
529 let params = QuantizationParams::symmetric(recommendation.scheme, 0.1);
530 let annotation = QuantizationAnnotation {
531 input_params: vec![Some(params.clone())],
532 output_params: Some(params),
533 calibration_data: None,
534 };
535
536 context.annotate_node(node_idx, annotation);
537 }
538
539 Ok(context)
540}