// torsh_quantization/utils.rs
use crate::{config::QuantConfig, observers::Observer};
use torsh_core::{error::Result as TorshResult, TorshError};
use torsh_tensor::Tensor;
10fn quantize_per_tensor(
14 tensor: &Tensor,
15 scale: f32,
16 zero_point: i32,
17 _dtype: torsh_core::DType,
18) -> TorshResult<Tensor> {
19 let (quantized, _, _) =
20 crate::algorithms::quantize_per_tensor_affine(tensor, scale, zero_point)?;
21 Ok(quantized)
22}
23
24#[allow(dead_code)]
26fn dequantize(tensor: &Tensor, scale: f32, zero_point: i32) -> TorshResult<Tensor> {
27 crate::algorithms::dequantize_per_tensor_affine(tensor, scale, zero_point)
28}
29
30pub fn validate_config_with_suggestions(config: &QuantConfig) -> TorshResult<Vec<String>> {
41 use crate::config::{ObserverType, QScheme, QuantBackend};
42
43 let mut suggestions = Vec::new();
44
45 config.validate()?;
47
48 match config.scheme {
50 QScheme::PerChannelAffine | QScheme::PerChannelSymmetric => {
51 if config.observer_type == ObserverType::MinMax {
52 suggestions.push("Consider using Histogram observer for per-channel quantization for better accuracy".to_string());
53 }
54 }
55 QScheme::GroupWise => {
56 if let Some(group_size) = config.group_size {
57 if group_size < 8 {
58 suggestions.push("Very small group sizes may not provide significant benefits over per-channel quantization".to_string());
59 } else if group_size > 128 {
60 suggestions.push(
61 "Large group sizes may reduce the benefits of group-wise quantization"
62 .to_string(),
63 );
64 }
65 }
66 }
67 QScheme::Int4PerTensor | QScheme::Int4PerChannel => {
68 if config.observer_type == ObserverType::MinMax {
69 suggestions.push("Consider using Histogram observer for INT4 quantization to handle outliers better".to_string());
70 }
71 }
72 QScheme::Binary | QScheme::Ternary => {
73 if config.observer_type != ObserverType::MinMax {
74 suggestions.push(
75 "MinMax observer is typically sufficient for binary/ternary quantization"
76 .to_string(),
77 );
78 }
79 }
80 _ => {}
81 }
82
83 if config.backend == QuantBackend::Native {
85 suggestions.push(
86 "Consider using FBGEMM or QNNPACK backends for better performance in production"
87 .to_string(),
88 );
89 }
90
91 if config.enable_fake_quant && config.observer_type != ObserverType::MovingAverage {
93 suggestions
94 .push("MovingAverage observer is recommended for QAT (fake quantization)".to_string());
95 }
96
97 Ok(suggestions)
98}
99
100pub fn create_optimized_config(use_case: &str, target_platform: &str) -> TorshResult<QuantConfig> {
111 use crate::config::{ObserverType, QuantBackend, ReduceRange};
112
113 let base_config = match use_case.to_lowercase().as_str() {
114 "inference_cpu" => QuantConfig::int8()
115 .with_backend(QuantBackend::Fbgemm)
116 .with_observer(ObserverType::Histogram),
117 "inference_mobile" => QuantConfig::int8()
118 .with_backend(QuantBackend::Qnnpack)
119 .with_observer(ObserverType::MinMax)
120 .with_reduce_range(ReduceRange::Reduce),
121 "training" => QuantConfig::qat().with_observer(ObserverType::MovingAverage),
122 "extreme_compression" => QuantConfig::int4().with_observer(ObserverType::Histogram),
123 "transformers" => QuantConfig::group_wise(0, 64).with_observer(ObserverType::Histogram),
124 "edge_device" => QuantConfig::binary().with_observer(ObserverType::MinMax),
125 _ => {
126 return Err(TorshError::InvalidArgument(format!(
127 "Unknown use case: {use_case}"
128 )))
129 }
130 };
131
132 let optimized_config = match target_platform.to_lowercase().as_str() {
133 "x86" | "x64" => base_config.with_backend(QuantBackend::Fbgemm),
134 "arm" | "mobile" => base_config.with_backend(QuantBackend::Qnnpack),
135 "gpu" => base_config.with_backend(QuantBackend::Native),
136 _ => base_config,
137 };
138
139 Ok(optimized_config)
140}
141
142pub fn quantize_batch_consistent(
154 tensors: &[&Tensor],
155 config: &QuantConfig,
156) -> TorshResult<Vec<(Tensor, f32, i32)>> {
157 if tensors.is_empty() {
158 return Ok(Vec::new());
159 }
160
161 let mut global_observer = Observer::new(config.observer_type);
163
164 for tensor in tensors {
165 global_observer.update(tensor)?;
166 }
167
168 let (global_scale, global_zero_point) = global_observer.calculate_qparams(config.dtype)?;
169
170 let mut results = Vec::new();
172 for tensor in tensors {
173 let quantized = quantize_per_tensor(tensor, global_scale, global_zero_point, config.dtype)?;
174 results.push((quantized, global_scale, global_zero_point));
175 }
176
177 Ok(results)
178}
179
180pub fn diagnose_quantization_failure(
193 tensor: &Tensor,
194 config: &QuantConfig,
195 error: &TorshError,
196) -> String {
197 let mut diagnosis = format!("Quantization failed with error: {error}\n\n");
198
199 let shape = tensor.shape();
201 let data_result = tensor.data();
202
203 diagnosis.push_str("Tensor Analysis:\n");
204 diagnosis.push_str(&format!(" Shape: {:?}\n", shape.dims()));
205 diagnosis.push_str(&format!(" Total elements: {}\n", shape.numel()));
206
207 if let Ok(data) = data_result {
208 let min_val = data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
209 let max_val = data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
210 let has_nan = data.iter().any(|&x| x.is_nan());
211 let has_inf = data.iter().any(|&x| x.is_infinite());
212
213 diagnosis.push_str(&format!(" Data range: [{min_val:.6}, {max_val:.6}]\n"));
214 diagnosis.push_str(&format!(" Contains NaN: {has_nan}\n"));
215 diagnosis.push_str(&format!(" Contains Inf: {has_inf}\n"));
216
217 if has_nan || has_inf {
218 diagnosis.push_str(
219 "\nSuggestion: Clean tensor data to remove NaN/Inf values before quantization.\n",
220 );
221 }
222
223 if max_val - min_val < 1e-6 {
224 diagnosis.push_str("\nSuggestion: Tensor has very small dynamic range. Consider using a different tensor or adjusting the quantization scheme.\n");
225 }
226 }
227
228 diagnosis.push_str("\nConfiguration Analysis:\n");
230 diagnosis.push_str(&format!(" Scheme: {:?}\n", config.scheme));
231 diagnosis.push_str(&format!(" Observer: {:?}\n", config.observer_type));
232 diagnosis.push_str(&format!(" Backend: {:?}\n", config.backend));
233
234 match config.validate() {
235 Ok(_) => diagnosis.push_str(" Configuration is valid\n"),
236 Err(e) => diagnosis.push_str(&format!(" Configuration error: {e}\n")),
237 }
238
239 diagnosis.push_str("\nRecovery Suggestions:\n");
241 diagnosis.push_str(
242 "1. Try a simpler quantization scheme (e.g., PerTensorAffine with MinMax observer)\n",
243 );
244 diagnosis.push_str("2. Use quantize_with_fallback() for automatic error recovery\n");
245 diagnosis.push_str("3. Check tensor data for NaN/Inf values\n");
246 diagnosis.push_str("4. Ensure tensor has sufficient dynamic range\n");
247 diagnosis
248 .push_str("5. Try a different observer type (Histogram for outlier-robust quantization)\n");
249
250 diagnosis
251}
252
253pub fn get_optimization_hints(tensor: &Tensor, config: &QuantConfig) -> Vec<String> {
264 use crate::config::{ObserverType, QScheme};
265
266 let mut hints = Vec::new();
267 let shape = tensor.shape();
268 let numel = shape.numel();
269
270 if numel > 1_000_000 {
272 hints.push("Large tensor detected. Consider using parallel processing with Rayon for better performance.".to_string());
273 if config.observer_type == ObserverType::Percentile {
274 hints.push("For large tensors, Histogram observer may be more memory-efficient than Percentile observer.".to_string());
275 }
276 }
277
278 if shape.dims().len() >= 2 && shape.dims().iter().any(|&dim| dim > 16) {
280 hints.push("Multi-channel tensor detected. Per-channel or group-wise quantization may provide better accuracy.".to_string());
281 }
282
283 match config.scheme {
285 QScheme::PerTensorAffine | QScheme::PerTensorSymmetric => {
286 if shape.dims().len() > 2 {
287 hints.push("Consider per-channel quantization for better accuracy with multi-dimensional tensors.".to_string());
288 }
289 }
290 QScheme::GroupWise => {
291 if let Some(group_size) = config.group_size {
292 let total_elements = shape.dims().iter().product::<usize>();
293 if total_elements / group_size < 4 {
294 hints.push("Too few groups for group-wise quantization. Consider per-tensor quantization instead.".to_string());
295 }
296 }
297 }
298 QScheme::Int4PerTensor | QScheme::Int4PerChannel => {
299 hints.push("INT4 quantization detected. Ensure your inference backend supports INT4 operations.".to_string());
300 }
301 QScheme::Binary | QScheme::Ternary => {
302 hints.push(
303 "Extreme quantization scheme detected. Verify accuracy requirements are met."
304 .to_string(),
305 );
306 }
307 _ => {}
308 }
309
310 hints
311}
312
313pub fn export_config_to_json(config: &QuantConfig) -> TorshResult<String> {
323 match serde_json::to_string_pretty(config) {
324 Ok(json) => Ok(json),
325 Err(e) => Err(TorshError::InvalidArgument(format!(
326 "Failed to serialize config: {e}"
327 ))),
328 }
329}
330
331pub fn import_config_from_json(json: &str) -> TorshResult<QuantConfig> {
341 match serde_json::from_str(json) {
342 Ok(config) => Ok(config),
343 Err(e) => Err(TorshError::InvalidArgument(format!(
344 "Failed to deserialize config: {e}"
345 ))),
346 }
347}