torsh_fx/quantization/
context.rs1use super::types::{
4 CalibrationData, QuantizationAnnotation, QuantizationParams, QuantizationScheme,
5};
6use crate::{FxGraph, Node, TorshResult};
7use petgraph::graph::NodeIndex;
8use std::collections::HashMap;
9use torsh_core::error::TorshError;
10
11pub struct QuantizationContext {
13 annotations: HashMap<NodeIndex, QuantizationAnnotation>,
14 calibration_data: HashMap<NodeIndex, CalibrationData>,
15 global_scheme: QuantizationScheme,
16}
17
18impl QuantizationContext {
19 pub fn new(scheme: QuantizationScheme) -> Self {
21 Self {
22 annotations: HashMap::new(),
23 calibration_data: HashMap::new(),
24 global_scheme: scheme,
25 }
26 }
27
28 pub fn annotate_node(&mut self, node: NodeIndex, annotation: QuantizationAnnotation) {
30 self.annotations.insert(node, annotation);
31 }
32
33 #[cfg(test)]
35 pub fn annotations(&self) -> &HashMap<NodeIndex, QuantizationAnnotation> {
36 &self.annotations
37 }
38
39 pub fn get_annotation(&self, node: NodeIndex) -> Option<&QuantizationAnnotation> {
41 self.annotations.get(&node)
42 }
43
44 pub fn start_calibration(&mut self, node: NodeIndex) {
46 self.calibration_data.insert(node, CalibrationData::new());
47 }
48
49 pub fn update_calibration(&mut self, node: NodeIndex, values: &[f32]) -> TorshResult<()> {
51 if let Some(data) = self.calibration_data.get_mut(&node) {
52 data.update(values);
53 Ok(())
54 } else {
55 Err(TorshError::InvalidArgument(format!(
56 "Calibration not started for node {:?}",
57 node
58 )))
59 }
60 }
61
62 pub fn finalize_calibration(&mut self, node: NodeIndex) -> TorshResult<QuantizationParams> {
64 if let Some(data) = self.calibration_data.remove(&node) {
65 Ok(data.compute_params(self.global_scheme))
66 } else {
67 Err(TorshError::InvalidArgument(format!(
68 "No calibration data for node {:?}",
69 node
70 )))
71 }
72 }
73
74 pub fn prepare_qat(&mut self, graph: &mut FxGraph) -> TorshResult<()> {
76 let mut insertions = Vec::new();
78
79 for (idx, node) in graph.nodes() {
80 if let Node::Call(op_name, _) = node {
81 if self.should_quantize_operation(op_name) {
82 insertions.push(idx);
83 }
84 }
85 }
86
87 for node_idx in insertions {
89 self.insert_fake_quantize_node(graph, node_idx)?;
90 }
91
92 Ok(())
93 }
94
95 pub fn quantize_graph(&self, graph: &mut FxGraph) -> TorshResult<()> {
97 let mut replacements = Vec::new();
99
100 for (idx, node) in graph.nodes() {
101 if let Node::Call(op_name, args) = node {
102 if self.should_quantize_operation(op_name) {
103 let quantized_op = self.get_quantized_operation(op_name);
104 replacements.push((idx, quantized_op, args.clone()));
105 }
106 }
107 }
108
109 for (idx, quantized_op, args) in replacements {
111 if let Some(node) = graph.graph.node_weight_mut(idx) {
112 *node = Node::Call(quantized_op, args);
113 }
114 }
115
116 Ok(())
117 }
118
119 fn should_quantize_operation(&self, op_name: &str) -> bool {
121 matches!(op_name, "linear" | "conv2d" | "matmul" | "add" | "mul")
122 }
123
124 fn get_quantized_operation(&self, op_name: &str) -> String {
126 match op_name {
127 "linear" => "quantized_linear".to_string(),
128 "conv2d" => "quantized_conv2d".to_string(),
129 "matmul" => "quantized_matmul".to_string(),
130 "add" => "quantized_add".to_string(),
131 "mul" => "quantized_mul".to_string(),
132 _ => format!("quantized_{op_name}"),
133 }
134 }
135
136 fn insert_fake_quantize_node(
138 &mut self,
139 _graph: &mut FxGraph,
140 target_idx: NodeIndex,
141 ) -> TorshResult<()> {
142 let annotation = QuantizationAnnotation {
146 input_params: vec![Some(QuantizationParams::symmetric(self.global_scheme, 1.0))],
147 output_params: Some(QuantizationParams::symmetric(self.global_scheme, 1.0)),
148 calibration_data: None,
149 };
150
151 self.annotate_node(target_idx, annotation);
152 Ok(())
153 }
154}