// torsh_fx/quantization/context.rs
1//! Quantization context for managing quantization state
2
3use super::types::{
4    CalibrationData, QuantizationAnnotation, QuantizationParams, QuantizationScheme,
5};
6use crate::{FxGraph, Node, TorshResult};
7use petgraph::graph::NodeIndex;
8use std::collections::HashMap;
9use torsh_core::error::TorshError;
10
/// Quantization context for managing quantization state
pub struct QuantizationContext {
    // Per-node quantization annotations, keyed by graph node index.
    annotations: HashMap<NodeIndex, QuantizationAnnotation>,
    // In-progress calibration statistics per node; entries are created by
    // `start_calibration` and consumed (removed) by `finalize_calibration`.
    calibration_data: HashMap<NodeIndex, CalibrationData>,
    // Quantization scheme applied to every node handled by this context.
    global_scheme: QuantizationScheme,
}
17
18impl QuantizationContext {
19    /// Create new quantization context
20    pub fn new(scheme: QuantizationScheme) -> Self {
21        Self {
22            annotations: HashMap::new(),
23            calibration_data: HashMap::new(),
24            global_scheme: scheme,
25        }
26    }
27
28    /// Add quantization annotation for a node
29    pub fn annotate_node(&mut self, node: NodeIndex, annotation: QuantizationAnnotation) {
30        self.annotations.insert(node, annotation);
31    }
32
33    /// Get all annotations (for testing)
34    #[cfg(test)]
35    pub fn annotations(&self) -> &HashMap<NodeIndex, QuantizationAnnotation> {
36        &self.annotations
37    }
38
39    /// Get quantization annotation for a node
40    pub fn get_annotation(&self, node: NodeIndex) -> Option<&QuantizationAnnotation> {
41        self.annotations.get(&node)
42    }
43
44    /// Start calibration for a node
45    pub fn start_calibration(&mut self, node: NodeIndex) {
46        self.calibration_data.insert(node, CalibrationData::new());
47    }
48
49    /// Update calibration data for a node
50    pub fn update_calibration(&mut self, node: NodeIndex, values: &[f32]) -> TorshResult<()> {
51        if let Some(data) = self.calibration_data.get_mut(&node) {
52            data.update(values);
53            Ok(())
54        } else {
55            Err(TorshError::InvalidArgument(format!(
56                "Calibration not started for node {:?}",
57                node
58            )))
59        }
60    }
61
62    /// Finalize calibration and compute quantization parameters
63    pub fn finalize_calibration(&mut self, node: NodeIndex) -> TorshResult<QuantizationParams> {
64        if let Some(data) = self.calibration_data.remove(&node) {
65            Ok(data.compute_params(self.global_scheme))
66        } else {
67            Err(TorshError::InvalidArgument(format!(
68                "No calibration data for node {:?}",
69                node
70            )))
71        }
72    }
73
74    /// Prepare graph for quantization-aware training (QAT)
75    pub fn prepare_qat(&mut self, graph: &mut FxGraph) -> TorshResult<()> {
76        // Insert fake quantize nodes before operations that benefit from quantization
77        let mut insertions = Vec::new();
78
79        for (idx, node) in graph.nodes() {
80            if let Node::Call(op_name, _) = node {
81                if self.should_quantize_operation(op_name) {
82                    insertions.push(idx);
83                }
84            }
85        }
86
87        // Insert fake quantize nodes
88        for node_idx in insertions {
89            self.insert_fake_quantize_node(graph, node_idx)?;
90        }
91
92        Ok(())
93    }
94
95    /// Convert graph to quantized format
96    pub fn quantize_graph(&self, graph: &mut FxGraph) -> TorshResult<()> {
97        // Replace operations with quantized versions
98        let mut replacements = Vec::new();
99
100        for (idx, node) in graph.nodes() {
101            if let Node::Call(op_name, args) = node {
102                if self.should_quantize_operation(op_name) {
103                    let quantized_op = self.get_quantized_operation(op_name);
104                    replacements.push((idx, quantized_op, args.clone()));
105                }
106            }
107        }
108
109        // Apply replacements
110        for (idx, quantized_op, args) in replacements {
111            if let Some(node) = graph.graph.node_weight_mut(idx) {
112                *node = Node::Call(quantized_op, args);
113            }
114        }
115
116        Ok(())
117    }
118
119    /// Check if an operation should be quantized
120    fn should_quantize_operation(&self, op_name: &str) -> bool {
121        matches!(op_name, "linear" | "conv2d" | "matmul" | "add" | "mul")
122    }
123
124    /// Get quantized version of an operation
125    fn get_quantized_operation(&self, op_name: &str) -> String {
126        match op_name {
127            "linear" => "quantized_linear".to_string(),
128            "conv2d" => "quantized_conv2d".to_string(),
129            "matmul" => "quantized_matmul".to_string(),
130            "add" => "quantized_add".to_string(),
131            "mul" => "quantized_mul".to_string(),
132            _ => format!("quantized_{op_name}"),
133        }
134    }
135
136    /// Insert fake quantize node for QAT
137    fn insert_fake_quantize_node(
138        &mut self,
139        _graph: &mut FxGraph,
140        target_idx: NodeIndex,
141    ) -> TorshResult<()> {
142        // In a full implementation, this would insert a FakeQuantize node
143        // before the target operation. For now, we'll just annotate the node.
144
145        let annotation = QuantizationAnnotation {
146            input_params: vec![Some(QuantizationParams::symmetric(self.global_scheme, 1.0))],
147            output_params: Some(QuantizationParams::symmetric(self.global_scheme, 1.0)),
148            calibration_data: None,
149        };
150
151        self.annotate_node(target_idx, annotation);
152        Ok(())
153    }
154}