Skip to main content

torsh_quantization/
utils.rs

1//! Quantization utilities and helper functions
2//!
3//! This module provides a comprehensive set of utility functions for quantization operations
4//! including configuration validation, batch processing, error diagnostics, performance
5//! optimization, benchmarking, and reporting tools.
6
7use crate::{config::QuantConfig, observers::Observer};
8use torsh_core::{error::Result as TorshResult, TorshError};
9use torsh_tensor::Tensor;
// Note: the `ops` module does not provide `dequantize` / `quantize_per_tensor`, so thin local
// wrappers around the `algorithms` module are defined below.
11
12/// Implementation of quantize_per_tensor using algorithms module
13fn quantize_per_tensor(
14    tensor: &Tensor,
15    scale: f32,
16    zero_point: i32,
17    _dtype: torsh_core::DType,
18) -> TorshResult<Tensor> {
19    let (quantized, _, _) =
20        crate::algorithms::quantize_per_tensor_affine(tensor, scale, zero_point)?;
21    Ok(quantized)
22}
23
24/// Implementation of dequantize using algorithms module
25#[allow(dead_code)]
26fn dequantize(tensor: &Tensor, scale: f32, zero_point: i32) -> TorshResult<Tensor> {
27    crate::algorithms::dequantize_per_tensor_affine(tensor, scale, zero_point)
28}
29
30/// Enhanced configuration validator with helpful suggestions
31///
32/// Validates a quantization configuration and provides performance and accuracy suggestions
33/// based on the configuration parameters.
34///
35/// # Arguments
36/// * `config` - The quantization configuration to validate
37///
38/// # Returns
39/// A vector of suggestion strings for optimization
40pub fn validate_config_with_suggestions(config: &QuantConfig) -> TorshResult<Vec<String>> {
41    use crate::config::{ObserverType, QScheme, QuantBackend};
42
43    let mut suggestions = Vec::new();
44
45    // Run basic validation first
46    config.validate()?;
47
48    // Add performance suggestions
49    match config.scheme {
50        QScheme::PerChannelAffine | QScheme::PerChannelSymmetric => {
51            if config.observer_type == ObserverType::MinMax {
52                suggestions.push("Consider using Histogram observer for per-channel quantization for better accuracy".to_string());
53            }
54        }
55        QScheme::GroupWise => {
56            if let Some(group_size) = config.group_size {
57                if group_size < 8 {
58                    suggestions.push("Very small group sizes may not provide significant benefits over per-channel quantization".to_string());
59                } else if group_size > 128 {
60                    suggestions.push(
61                        "Large group sizes may reduce the benefits of group-wise quantization"
62                            .to_string(),
63                    );
64                }
65            }
66        }
67        QScheme::Int4PerTensor | QScheme::Int4PerChannel => {
68            if config.observer_type == ObserverType::MinMax {
69                suggestions.push("Consider using Histogram observer for INT4 quantization to handle outliers better".to_string());
70            }
71        }
72        QScheme::Binary | QScheme::Ternary => {
73            if config.observer_type != ObserverType::MinMax {
74                suggestions.push(
75                    "MinMax observer is typically sufficient for binary/ternary quantization"
76                        .to_string(),
77                );
78            }
79        }
80        _ => {}
81    }
82
83    // Backend suggestions
84    if config.backend == QuantBackend::Native {
85        suggestions.push(
86            "Consider using FBGEMM or QNNPACK backends for better performance in production"
87                .to_string(),
88        );
89    }
90
91    // Observer suggestions
92    if config.enable_fake_quant && config.observer_type != ObserverType::MovingAverage {
93        suggestions
94            .push("MovingAverage observer is recommended for QAT (fake quantization)".to_string());
95    }
96
97    Ok(suggestions)
98}
99
100/// Create optimized configuration for common use cases
101///
102/// Generates optimized quantization configurations for specific use cases and target platforms.
103///
104/// # Arguments
105/// * `use_case` - The target use case ("inference_cpu", "inference_mobile", "training", etc.)
106/// * `target_platform` - The target platform ("x86", "arm", "gpu", etc.)
107///
108/// # Returns
109/// An optimized quantization configuration
110pub fn create_optimized_config(use_case: &str, target_platform: &str) -> TorshResult<QuantConfig> {
111    use crate::config::{ObserverType, QuantBackend, ReduceRange};
112
113    let base_config = match use_case.to_lowercase().as_str() {
114        "inference_cpu" => QuantConfig::int8()
115            .with_backend(QuantBackend::Fbgemm)
116            .with_observer(ObserverType::Histogram),
117        "inference_mobile" => QuantConfig::int8()
118            .with_backend(QuantBackend::Qnnpack)
119            .with_observer(ObserverType::MinMax)
120            .with_reduce_range(ReduceRange::Reduce),
121        "training" => QuantConfig::qat().with_observer(ObserverType::MovingAverage),
122        "extreme_compression" => QuantConfig::int4().with_observer(ObserverType::Histogram),
123        "transformers" => QuantConfig::group_wise(0, 64).with_observer(ObserverType::Histogram),
124        "edge_device" => QuantConfig::binary().with_observer(ObserverType::MinMax),
125        _ => {
126            return Err(TorshError::InvalidArgument(format!(
127                "Unknown use case: {use_case}"
128            )))
129        }
130    };
131
132    let optimized_config = match target_platform.to_lowercase().as_str() {
133        "x86" | "x64" => base_config.with_backend(QuantBackend::Fbgemm),
134        "arm" | "mobile" => base_config.with_backend(QuantBackend::Qnnpack),
135        "gpu" => base_config.with_backend(QuantBackend::Native),
136        _ => base_config,
137    };
138
139    Ok(optimized_config)
140}
141
142/// Batch quantization utility for multiple tensors with consistent parameters
143///
144/// Quantizes multiple tensors using globally consistent parameters calculated across all tensors.
145/// This ensures that all tensors use the same scale and zero point for consistent quantization.
146///
147/// # Arguments
148/// * `tensors` - Slice of tensor references to quantize
149/// * `config` - Quantization configuration to use
150///
151/// # Returns
152/// Vector of quantized tensors with their scale and zero point parameters
153pub fn quantize_batch_consistent(
154    tensors: &[&Tensor],
155    config: &QuantConfig,
156) -> TorshResult<Vec<(Tensor, f32, i32)>> {
157    if tensors.is_empty() {
158        return Ok(Vec::new());
159    }
160
161    // Calculate global statistics across all tensors for consistency
162    let mut global_observer = Observer::new(config.observer_type);
163
164    for tensor in tensors {
165        global_observer.update(tensor)?;
166    }
167
168    let (global_scale, global_zero_point) = global_observer.calculate_qparams(config.dtype)?;
169
170    // Quantize all tensors using the same parameters
171    let mut results = Vec::new();
172    for tensor in tensors {
173        let quantized = quantize_per_tensor(tensor, global_scale, global_zero_point, config.dtype)?;
174        results.push((quantized, global_scale, global_zero_point));
175    }
176
177    Ok(results)
178}
179
180/// Error recovery utility that provides detailed diagnostics
181///
182/// Analyzes a failed quantization operation and provides detailed diagnostics
183/// about the tensor properties, configuration issues, and recovery suggestions.
184///
185/// # Arguments
186/// * `tensor` - The tensor that failed to quantize
187/// * `config` - The quantization configuration that was used
188/// * `error` - The error that occurred during quantization
189///
190/// # Returns
191/// A detailed diagnostic string with analysis and recovery suggestions
192pub fn diagnose_quantization_failure(
193    tensor: &Tensor,
194    config: &QuantConfig,
195    error: &TorshError,
196) -> String {
197    let mut diagnosis = format!("Quantization failed with error: {error}\n\n");
198
199    // Analyze tensor properties
200    let shape = tensor.shape();
201    let data_result = tensor.data();
202
203    diagnosis.push_str("Tensor Analysis:\n");
204    diagnosis.push_str(&format!("  Shape: {:?}\n", shape.dims()));
205    diagnosis.push_str(&format!("  Total elements: {}\n", shape.numel()));
206
207    if let Ok(data) = data_result {
208        let min_val = data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
209        let max_val = data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
210        let has_nan = data.iter().any(|&x| x.is_nan());
211        let has_inf = data.iter().any(|&x| x.is_infinite());
212
213        diagnosis.push_str(&format!("  Data range: [{min_val:.6}, {max_val:.6}]\n"));
214        diagnosis.push_str(&format!("  Contains NaN: {has_nan}\n"));
215        diagnosis.push_str(&format!("  Contains Inf: {has_inf}\n"));
216
217        if has_nan || has_inf {
218            diagnosis.push_str(
219                "\nSuggestion: Clean tensor data to remove NaN/Inf values before quantization.\n",
220            );
221        }
222
223        if max_val - min_val < 1e-6 {
224            diagnosis.push_str("\nSuggestion: Tensor has very small dynamic range. Consider using a different tensor or adjusting the quantization scheme.\n");
225        }
226    }
227
228    // Analyze configuration
229    diagnosis.push_str("\nConfiguration Analysis:\n");
230    diagnosis.push_str(&format!("  Scheme: {:?}\n", config.scheme));
231    diagnosis.push_str(&format!("  Observer: {:?}\n", config.observer_type));
232    diagnosis.push_str(&format!("  Backend: {:?}\n", config.backend));
233
234    match config.validate() {
235        Ok(_) => diagnosis.push_str("  Configuration is valid\n"),
236        Err(e) => diagnosis.push_str(&format!("  Configuration error: {e}\n")),
237    }
238
239    // Provide recovery suggestions
240    diagnosis.push_str("\nRecovery Suggestions:\n");
241    diagnosis.push_str(
242        "1. Try a simpler quantization scheme (e.g., PerTensorAffine with MinMax observer)\n",
243    );
244    diagnosis.push_str("2. Use quantize_with_fallback() for automatic error recovery\n");
245    diagnosis.push_str("3. Check tensor data for NaN/Inf values\n");
246    diagnosis.push_str("4. Ensure tensor has sufficient dynamic range\n");
247    diagnosis
248        .push_str("5. Try a different observer type (Histogram for outlier-robust quantization)\n");
249
250    diagnosis
251}
252
253/// Performance optimization hints based on tensor characteristics
254///
255/// Analyzes tensor properties and provides optimization hints for better quantization performance.
256///
257/// # Arguments
258/// * `tensor` - The tensor to analyze
259/// * `config` - The quantization configuration
260///
261/// # Returns
262/// Vector of optimization hint strings
263pub fn get_optimization_hints(tensor: &Tensor, config: &QuantConfig) -> Vec<String> {
264    use crate::config::{ObserverType, QScheme};
265
266    let mut hints = Vec::new();
267    let shape = tensor.shape();
268    let numel = shape.numel();
269
270    // Size-based hints
271    if numel > 1_000_000 {
272        hints.push("Large tensor detected. Consider using parallel processing with Rayon for better performance.".to_string());
273        if config.observer_type == ObserverType::Percentile {
274            hints.push("For large tensors, Histogram observer may be more memory-efficient than Percentile observer.".to_string());
275        }
276    }
277
278    // Shape-based hints
279    if shape.dims().len() >= 2 && shape.dims().iter().any(|&dim| dim > 16) {
280        hints.push("Multi-channel tensor detected. Per-channel or group-wise quantization may provide better accuracy.".to_string());
281    }
282
283    // Scheme-specific hints
284    match config.scheme {
285        QScheme::PerTensorAffine | QScheme::PerTensorSymmetric => {
286            if shape.dims().len() > 2 {
287                hints.push("Consider per-channel quantization for better accuracy with multi-dimensional tensors.".to_string());
288            }
289        }
290        QScheme::GroupWise => {
291            if let Some(group_size) = config.group_size {
292                let total_elements = shape.dims().iter().product::<usize>();
293                if total_elements / group_size < 4 {
294                    hints.push("Too few groups for group-wise quantization. Consider per-tensor quantization instead.".to_string());
295                }
296            }
297        }
298        QScheme::Int4PerTensor | QScheme::Int4PerChannel => {
299            hints.push("INT4 quantization detected. Ensure your inference backend supports INT4 operations.".to_string());
300        }
301        QScheme::Binary | QScheme::Ternary => {
302            hints.push(
303                "Extreme quantization scheme detected. Verify accuracy requirements are met."
304                    .to_string(),
305            );
306        }
307        _ => {}
308    }
309
310    hints
311}
312
313/// Export quantization configuration to JSON string
314///
315/// Serializes a quantization configuration to a JSON string for persistence or transfer.
316///
317/// # Arguments
318/// * `config` - The quantization configuration to export
319///
320/// # Returns
321/// A JSON string representation of the configuration
322pub fn export_config_to_json(config: &QuantConfig) -> TorshResult<String> {
323    match serde_json::to_string_pretty(config) {
324        Ok(json) => Ok(json),
325        Err(e) => Err(TorshError::InvalidArgument(format!(
326            "Failed to serialize config: {e}"
327        ))),
328    }
329}
330
331/// Import quantization configuration from JSON string
332///
333/// Deserializes a quantization configuration from a JSON string.
334///
335/// # Arguments
336/// * `json` - The JSON string to deserialize
337///
338/// # Returns
339/// The deserialized quantization configuration
340pub fn import_config_from_json(json: &str) -> TorshResult<QuantConfig> {
341    match serde_json::from_str(json) {
342        Ok(config) => Ok(config),
343        Err(e) => Err(TorshError::InvalidArgument(format!(
344            "Failed to deserialize config: {e}"
345        ))),
346    }
347}