Skip to main content

torsh_fx/quantization/
precision.rs

1//! Automatic precision selection for quantization
2
3use super::context::QuantizationContext;
4use super::types::{QuantizationAnnotation, QuantizationParams, QuantizationScheme};
5use crate::{FxGraph, Node, TorshResult};
6use petgraph::graph::NodeIndex;
7use petgraph::visit::IntoNodeReferences;
8use std::collections::HashMap;
9
10/// Automatic precision selection criteria
11///
12/// Controls how the selector weighs expected accuracy loss against expected
13/// speedup when scoring candidate quantization schemes.
14#[derive(Debug, Clone, Copy)]
15pub enum PrecisionCriteria {
16    /// Maximize performance, minimal accuracy loss
17    Performance,
18    /// Balance performance and accuracy
19    Balanced,
20    /// Maximize accuracy, minimal performance loss
21    Accuracy,
22    /// Custom criteria with specified thresholds
23    Custom {
24        /// Maximum tolerated accuracy loss (0.0 = none); schemes above it are rejected
25        max_accuracy_loss: f32,
26        /// Minimum required speedup ratio (1.0 = no speedup); slower schemes are rejected
27        min_speedup: f32,
28    },
29}
25
26/// Precision selection result for a node
27///
28/// Produced per call-node by `AutomaticPrecisionSelector::analyze_graph`.
29#[derive(Debug, Clone)]
30pub struct PrecisionRecommendation {
31    /// Recommended quantization scheme
32    pub scheme: QuantizationScheme,
33    /// Expected accuracy loss (0.0 = no loss, 1.0 = total loss)
34    pub accuracy_loss: f32,
35    /// Expected speedup ratio (1.0 = no speedup, 2.0 = 2x faster)
36    pub speedup_ratio: f32,
37    /// Confidence in recommendation (0.0 = low, 1.0 = high)
38    pub confidence: f32,
39    /// Human-readable reasoning for the recommendation
40    pub reasoning: String,
41}
40
41/// Precision selection strategy
42///
43/// Per-scheme priority multipliers and accuracy/performance weights used
44/// when scoring candidate schemes.
45#[derive(Debug, Clone)]
46pub struct PrecisionStrategy {
47    /// Priority for int8 quantization
48    pub int8_priority: f32,
49    /// Priority for int16 (usually lower than int8)
50    pub int16_priority: f32,
51    /// Priority for dynamic quantization
52    pub dynamic_priority: f32,
53    /// Priority for keeping full precision
54    pub fp32_priority: f32,
55    /// Performance sensitivity factor
56    pub performance_weight: f32,
57    /// Accuracy sensitivity factor
58    pub accuracy_weight: f32,
59}
57
58impl Default for PrecisionStrategy {
59    fn default() -> Self {
60        Self {
61            int8_priority: 0.8,
62            int16_priority: 0.6,
63            dynamic_priority: 0.4,
64            fp32_priority: 0.2,
65            performance_weight: 0.5,
66            accuracy_weight: 0.5,
67        }
68    }
69}
70
71/// Automatic precision selector
72///
73/// Scores candidate quantization schemes per operation using the configured
74/// criteria, strategy weights, and per-operation profiles.
75pub struct AutomaticPrecisionSelector {
76    /// Selection criteria
77    pub criteria: PrecisionCriteria,
78    /// Selection strategy (weights and per-scheme priorities)
79    pub strategy: PrecisionStrategy,
80    /// Operation-specific precision profiles, keyed by operation name
81    pub operation_profiles: HashMap<String, PrecisionProfile>,
82}
80
81/// Precision profile for specific operations
82///
83/// Stores expected accuracy/performance impact per quantization scheme,
84/// used as lookup tables during scoring.
85#[derive(Debug, Clone)]
86pub struct PrecisionProfile {
87    /// Recommended scheme for this operation
88    pub recommended_scheme: QuantizationScheme,
89    /// Expected accuracy loss for each scheme (0.0 = lossless)
90    pub accuracy_impact: HashMap<QuantizationScheme, f32>,
91    /// Expected speedup ratio for each scheme (1.0 = no speedup)
92    pub performance_gain: HashMap<QuantizationScheme, f32>,
93    /// Whether this operation is quantization-sensitive
94    pub quantization_sensitive: bool,
95}
93
94impl AutomaticPrecisionSelector {
95    /// Create new precision selector
96    pub fn new(criteria: PrecisionCriteria) -> Self {
97        Self {
98            criteria,
99            strategy: PrecisionStrategy::default(),
100            operation_profiles: Self::create_default_profiles(),
101        }
102    }
103
104    /// Create precision selector with custom strategy
105    pub fn with_strategy(criteria: PrecisionCriteria, strategy: PrecisionStrategy) -> Self {
106        Self {
107            criteria,
108            strategy,
109            operation_profiles: Self::create_default_profiles(),
110        }
111    }
112
113    /// Analyze graph and recommend precision for each operation
114    pub fn analyze_graph(
115        &self,
116        graph: &FxGraph,
117    ) -> TorshResult<HashMap<NodeIndex, PrecisionRecommendation>> {
118        let mut recommendations = HashMap::new();
119
120        // Analyze each node in the graph
121        for (node_idx, node) in graph.graph.node_references() {
122            if let Node::Call(op_name, _args) = node {
123                let recommendation = self.analyze_operation(&op_name, node_idx, graph)?;
124                recommendations.insert(node_idx, recommendation);
125            }
126        }
127
128        // Post-process recommendations to ensure graph-level consistency
129        self.optimize_graph_precision(&mut recommendations, graph)?;
130
131        Ok(recommendations)
132    }
133
134    /// Analyze a specific operation and recommend precision
135    fn analyze_operation(
136        &self,
137        op_name: &str,
138        node_idx: NodeIndex,
139        graph: &FxGraph,
140    ) -> TorshResult<PrecisionRecommendation> {
141        let profile = self
142            .operation_profiles
143            .get(op_name)
144            .cloned()
145            .unwrap_or_else(|| self.create_generic_profile(op_name));
146
147        // Calculate scores for different precision schemes
148        let mut best_score = f32::NEG_INFINITY;
149        let mut best_scheme = None;
150        let mut best_reasoning = String::new();
151
152        for &scheme in &[
153            QuantizationScheme::Int8,
154            QuantizationScheme::Int16,
155            QuantizationScheme::Dynamic,
156        ] {
157            let score = self.calculate_precision_score(&profile, scheme, node_idx, graph)?;
158
159            if score > best_score && score != f32::NEG_INFINITY {
160                best_score = score;
161                best_scheme = Some(scheme);
162                best_reasoning = self.generate_reasoning(op_name, scheme, &profile);
163            }
164        }
165
166        // Use the best scheme or fallback to a conservative default
167        let selected_scheme = best_scheme.unwrap_or_else(|| {
168            // If no scheme meets the criteria, use the most conservative option
169            if matches!(self.criteria, PrecisionCriteria::Custom { .. }) {
170                // For custom criteria, try to find a scheme that at least meets accuracy requirements
171                for &scheme in &[
172                    QuantizationScheme::Int16,
173                    QuantizationScheme::Dynamic,
174                    QuantizationScheme::Int8,
175                ] {
176                    let accuracy_loss =
177                        profile.accuracy_impact.get(&scheme).copied().unwrap_or(0.1);
178                    if let PrecisionCriteria::Custom {
179                        max_accuracy_loss, ..
180                    } = self.criteria
181                    {
182                        if accuracy_loss <= max_accuracy_loss {
183                            return scheme;
184                        }
185                    }
186                }
187            }
188            QuantizationScheme::Int16 // Most conservative fallback
189        });
190
191        // Get metrics for the selected scheme
192        let accuracy_loss = profile
193            .accuracy_impact
194            .get(&selected_scheme)
195            .copied()
196            .unwrap_or(0.1);
197        let speedup_ratio = profile
198            .performance_gain
199            .get(&selected_scheme)
200            .copied()
201            .unwrap_or(1.2);
202        let confidence = self.calculate_confidence(&profile, selected_scheme);
203
204        Ok(PrecisionRecommendation {
205            scheme: selected_scheme,
206            accuracy_loss,
207            speedup_ratio,
208            confidence,
209            reasoning: if best_scheme.is_some() {
210                best_reasoning
211            } else {
212                format!(
213                    "Fallback selection for '{}' due to constraint violations",
214                    op_name
215                )
216            },
217        })
218    }
219
220    /// Calculate precision score for a specific scheme
221    fn calculate_precision_score(
222        &self,
223        profile: &PrecisionProfile,
224        scheme: QuantizationScheme,
225        _node_idx: NodeIndex,
226        _graph: &FxGraph,
227    ) -> TorshResult<f32> {
228        let accuracy_loss = profile.accuracy_impact.get(&scheme).copied().unwrap_or(0.1);
229        let performance_gain = profile
230            .performance_gain
231            .get(&scheme)
232            .copied()
233            .unwrap_or(1.1);
234
235        // Calculate base score
236        let accuracy_score = (1.0 - accuracy_loss) * self.strategy.accuracy_weight;
237        let performance_score = (performance_gain - 1.0) * self.strategy.performance_weight;
238
239        // Apply criteria-specific adjustments
240        let adjusted_score = match self.criteria {
241            PrecisionCriteria::Performance => performance_score * 2.0 + accuracy_score,
242            PrecisionCriteria::Accuracy => {
243                // For accuracy-focused selection, heavily penalize accuracy loss
244                // and favor schemes that preserve accuracy, especially for sensitive operations
245                if profile.quantization_sensitive {
246                    // For sensitive operations, strongly prefer Int16 over Int8
247                    let sensitivity_bonus = match scheme {
248                        QuantizationScheme::Int16 => 2.0,
249                        QuantizationScheme::Int8 => -1.0,
250                        _ => 0.0,
251                    };
252                    accuracy_score * 3.0 + performance_score * 0.5 + sensitivity_bonus
253                } else {
254                    accuracy_score * 2.0 + performance_score
255                }
256            }
257            PrecisionCriteria::Balanced => {
258                // For balanced selection, also consider operation sensitivity
259                if profile.quantization_sensitive {
260                    // For sensitive operations, prefer Int16 over Int8
261                    let sensitivity_bonus = match scheme {
262                        QuantizationScheme::Int16 => 1.0,
263                        QuantizationScheme::Int8 => -0.5,
264                        _ => 0.0,
265                    };
266                    accuracy_score + performance_score + sensitivity_bonus
267                } else {
268                    accuracy_score + performance_score
269                }
270            }
271            PrecisionCriteria::Custom {
272                max_accuracy_loss,
273                min_speedup,
274            } => {
275                if accuracy_loss > max_accuracy_loss || performance_gain < min_speedup {
276                    return Ok(f32::NEG_INFINITY);
277                }
278                accuracy_score + performance_score
279            }
280        };
281
282        // Apply scheme-specific priority
283        let priority = match scheme {
284            QuantizationScheme::Int8 => self.strategy.int8_priority,
285            QuantizationScheme::Int16 => self.strategy.int16_priority,
286            QuantizationScheme::Dynamic => self.strategy.dynamic_priority,
287            QuantizationScheme::Fake => self.strategy.fp32_priority,
288        };
289
290        Ok(adjusted_score * priority)
291    }
292
293    /// Generate reasoning for precision recommendation
294    fn generate_reasoning(
295        &self,
296        op_name: &str,
297        scheme: QuantizationScheme,
298        profile: &PrecisionProfile,
299    ) -> String {
300        let scheme_name = match scheme {
301            QuantizationScheme::Int8 => "INT8",
302            QuantizationScheme::Int16 => "INT16",
303            QuantizationScheme::Dynamic => "Dynamic",
304            QuantizationScheme::Fake => "FP32",
305        };
306
307        if profile.quantization_sensitive {
308            format!("Operation '{op_name}' is quantization-sensitive. {scheme_name} provides good balance of performance and accuracy.")
309        } else {
310            format!("Operation '{op_name}' is quantization-friendly. {scheme_name} offers optimal performance with minimal accuracy loss.")
311        }
312    }
313
314    /// Calculate confidence in recommendation
315    fn calculate_confidence(&self, profile: &PrecisionProfile, scheme: QuantizationScheme) -> f32 {
316        let base_confidence = if profile.quantization_sensitive {
317            0.75
318        } else {
319            0.9
320        };
321
322        // Higher confidence for well-supported schemes
323        let scheme_confidence = match scheme {
324            QuantizationScheme::Int8 => 0.9,
325            QuantizationScheme::Int16 => 0.85,
326            QuantizationScheme::Dynamic => 0.7,
327            QuantizationScheme::Fake => 0.6,
328        };
329
330        // Additional bonus for schemes that match the operation's recommended scheme
331        let recommendation_bonus = if scheme == profile.recommended_scheme {
332            1.1
333        } else {
334            1.0
335        };
336
337        let confidence: f32 = base_confidence * scheme_confidence * recommendation_bonus;
338        confidence.min(1.0)
339    }
340
341    /// Optimize precision recommendations at graph level
342    fn optimize_graph_precision(
343        &self,
344        recommendations: &mut HashMap<NodeIndex, PrecisionRecommendation>,
345        _graph: &FxGraph,
346    ) -> TorshResult<()> {
347        // For now, just ensure consistency - in practice, this would:
348        // 1. Minimize precision mismatches between connected operations
349        // 2. Avoid unnecessary conversions
350        // 3. Consider memory bandwidth and cache efficiency
351
352        // Simple optimization: prefer consistent precision in chains
353        for recommendation in recommendations.values_mut() {
354            if recommendation.confidence < 0.5 {
355                // Low confidence - use conservative INT16
356                recommendation.scheme = QuantizationScheme::Int16;
357                recommendation.reasoning = format!(
358                    "Conservative choice due to low confidence: {}",
359                    recommendation.reasoning
360                );
361            }
362        }
363
364        Ok(())
365    }
366
367    /// Create default precision profiles for common operations
368    fn create_default_profiles() -> HashMap<String, PrecisionProfile> {
369        let mut profiles = HashMap::new();
370
371        // Matrix operations - generally quantization-friendly
372        profiles.insert(
373            "matmul".to_string(),
374            PrecisionProfile {
375                recommended_scheme: QuantizationScheme::Int8,
376                accuracy_impact: [
377                    (QuantizationScheme::Int8, 0.015),
378                    (QuantizationScheme::Int16, 0.005),
379                    (QuantizationScheme::Dynamic, 0.008),
380                    (QuantizationScheme::Fake, 0.0),
381                ]
382                .iter()
383                .cloned()
384                .collect(),
385                performance_gain: [
386                    (QuantizationScheme::Int8, 2.5),
387                    (QuantizationScheme::Int16, 2.2),
388                    (QuantizationScheme::Dynamic, 2.1),
389                    (QuantizationScheme::Fake, 1.0),
390                ]
391                .iter()
392                .cloned()
393                .collect(),
394                quantization_sensitive: false,
395            },
396        );
397
398        // Convolution operations - quantization-friendly
399        profiles.insert(
400            "conv2d".to_string(),
401            PrecisionProfile {
402                recommended_scheme: QuantizationScheme::Int8,
403                accuracy_impact: [
404                    (QuantizationScheme::Int8, 0.03),
405                    (QuantizationScheme::Int16, 0.008),
406                    (QuantizationScheme::Dynamic, 0.015),
407                    (QuantizationScheme::Fake, 0.0),
408                ]
409                .iter()
410                .cloned()
411                .collect(),
412                performance_gain: [
413                    (QuantizationScheme::Int8, 3.0),
414                    (QuantizationScheme::Int16, 2.0),
415                    (QuantizationScheme::Dynamic, 1.5),
416                    (QuantizationScheme::Fake, 1.0),
417                ]
418                .iter()
419                .cloned()
420                .collect(),
421                quantization_sensitive: false,
422            },
423        );
424
425        // Attention operations - more quantization-sensitive
426        profiles.insert(
427            "attention".to_string(),
428            PrecisionProfile {
429                recommended_scheme: QuantizationScheme::Int16,
430                accuracy_impact: [
431                    (QuantizationScheme::Int8, 0.08),
432                    (QuantizationScheme::Int16, 0.02),
433                    (QuantizationScheme::Dynamic, 0.04),
434                    (QuantizationScheme::Fake, 0.0),
435                ]
436                .iter()
437                .cloned()
438                .collect(),
439                performance_gain: [
440                    (QuantizationScheme::Int8, 2.0),
441                    (QuantizationScheme::Int16, 1.6),
442                    (QuantizationScheme::Dynamic, 1.3),
443                    (QuantizationScheme::Fake, 1.0),
444                ]
445                .iter()
446                .cloned()
447                .collect(),
448                quantization_sensitive: true,
449            },
450        );
451
452        // Activation functions - generally quantization-friendly
453        profiles.insert(
454            "relu".to_string(),
455            PrecisionProfile {
456                recommended_scheme: QuantizationScheme::Int8,
457                accuracy_impact: [
458                    (QuantizationScheme::Int8, 0.001),
459                    (QuantizationScheme::Int16, 0.0005),
460                    (QuantizationScheme::Dynamic, 0.0008),
461                    (QuantizationScheme::Fake, 0.0),
462                ]
463                .iter()
464                .cloned()
465                .collect(),
466                performance_gain: [
467                    (QuantizationScheme::Int8, 1.8),
468                    (QuantizationScheme::Int16, 1.4),
469                    (QuantizationScheme::Dynamic, 1.2),
470                    (QuantizationScheme::Fake, 1.0),
471                ]
472                .iter()
473                .cloned()
474                .collect(),
475                quantization_sensitive: false,
476            },
477        );
478
479        profiles
480    }
481
482    /// Create generic profile for unknown operations
483    fn create_generic_profile(&self, _op_name: &str) -> PrecisionProfile {
484        PrecisionProfile {
485            recommended_scheme: QuantizationScheme::Int16, // Conservative default
486            accuracy_impact: [
487                (QuantizationScheme::Int8, 0.015),
488                (QuantizationScheme::Int16, 0.005),
489                (QuantizationScheme::Dynamic, 0.01),
490                (QuantizationScheme::Fake, 0.0),
491            ]
492            .iter()
493            .cloned()
494            .collect(),
495            performance_gain: [
496                (QuantizationScheme::Int8, 2.0),
497                (QuantizationScheme::Int16, 1.5),
498                (QuantizationScheme::Dynamic, 1.3),
499                (QuantizationScheme::Fake, 1.0),
500            ]
501            .iter()
502            .cloned()
503            .collect(),
504            quantization_sensitive: true, // Conservative default
505        }
506    }
507}
508
509/// Convenience function to perform automatic precision selection
510pub fn select_automatic_precision(
511    graph: &FxGraph,
512    criteria: PrecisionCriteria,
513) -> TorshResult<HashMap<NodeIndex, PrecisionRecommendation>> {
514    let selector = AutomaticPrecisionSelector::new(criteria);
515    selector.analyze_graph(graph)
516}
517
518/// Apply automatic precision selection to a graph
519pub fn apply_automatic_precision(
520    graph: &mut FxGraph,
521    criteria: PrecisionCriteria,
522) -> TorshResult<QuantizationContext> {
523    let recommendations = select_automatic_precision(graph, criteria)?;
524
525    let mut context = QuantizationContext::new(QuantizationScheme::Int8);
526
527    // Apply recommendations to the graph
528    for (node_idx, recommendation) in recommendations {
529        let params = QuantizationParams::symmetric(recommendation.scheme, 0.1);
530        let annotation = QuantizationAnnotation {
531            input_params: vec![Some(params.clone())],
532            output_params: Some(params),
533            calibration_data: None,
534        };
535
536        context.annotate_node(node_idx, annotation);
537    }
538
539    Ok(context)
540}