rustorch/quantization/
mod.rs

1//! Quantization support for RusTorch - Phase 11 Implementation
2//! RusTorch用量子化サポート - フェーズ11実装
3//!
4//! This module provides comprehensive quantization support for deep learning models,
5//! enabling efficient inference and training with reduced precision arithmetic.
6//!
7//! このモジュールは深層学習モデルの包括的な量子化サポートを提供し、
8//! 精度を下げた算術演算による効率的な推論と学習を可能にします。
9//!
10//! ## Key Features
11//!
12//! ### Dynamic Quantization
13//! - Runtime quantization of weights and activations
14//! - Automatic calibration using statistical observers
15//! - Per-tensor and per-channel quantization schemes
16//!
17//! ### Static Quantization  
18//! - Pre-calibrated quantization parameters
19//! - Optimal for deployment scenarios
20//! - Hardware-accelerated operations
21//!
22//! ### Quantization-Aware Training (QAT)
23//! - Training with quantization simulation
24//! - Straight-through estimators for gradients
25//! - Fine-tuning of quantized models
26//!
27//! ### Hardware Optimization
28//! - CPU SIMD optimizations for quantized operations
29//! - CUDA kernels for GPU acceleration
30//! - Metal Performance Shaders for Apple Silicon
31//!
32//! ## Quantization Schemes
33//!
//! ### Symmetric Quantization
//! ```text
//! quantized = round(fp32_value / scale)
//! dequantized = quantized * scale
//! ```
//!
//! ### Asymmetric Quantization
//! ```text
//! quantized = round(fp32_value / scale) + zero_point
//! dequantized = (quantized - zero_point) * scale
//! ```
45//!
46//! ## Usage Examples
47//!
48//! ```rust
49//! use rustorch::quantization::{QuantizedTensor, QuantizationScheme, StaticQuantizer, TensorQuantization};
50//! use rustorch::tensor::Tensor;
51//! # use rustorch::error::RusTorchResult;
52//! #
53//! # fn main() -> RusTorchResult<()> {
54//! // Dynamic quantization
55//! let tensor: Tensor<f32> = Tensor::randn(&[128, 256]);
56//! let quantized = tensor.quantize_dynamic(QuantizationScheme::Symmetric)?;
57//!
58//! // Static quantization with calibration
59//! let mut quantizer = StaticQuantizer::<f32>::new();
60//! quantizer.calibrate(QuantizationScheme::Symmetric)?;
61//! # Ok(())
62//! # }
63//! ```
64
65use crate::error::{RusTorchError, RusTorchResult};
66use crate::tensor::Tensor;
67use ndarray;
68use num_traits::{Float, Signed, Unsigned};
69use std::fmt;
70use std::marker::PhantomData;
71
72// Re-export public API
73pub use calibration::{HistogramObserver, MinMaxObserver, Observer, StaticQuantizer};
74pub use hardware::optimized_ops;
75pub use operations::{DequantizeOps, QuantizedOps};
76pub use qat::{FakeQuantize, QATConv2d, QATLinear, QATModule};
77pub use schemes::{AsymmetricQuantization, QuantizationScheme, SymmetricQuantization};
78pub use types::{QuantizationType, QuantizedTensor};
79
80// Export unified utilities - struct defined below
81
82/// Quantized tensor data types and core structures
83/// 量子化テンソルデータ型とコア構造
84pub mod types;
85
86/// Quantization schemes and algorithms
87/// 量子化スキームとアルゴリズム
88pub mod schemes;
89
90/// Calibration and statistical observation
91/// キャリブレーションと統計観測
92pub mod calibration;
93
94/// Quantized tensor operations
95/// 量子化テンソル演算
96pub mod operations;
97
98/// Quantization-aware training support
99/// 量子化認識学習サポート
100pub mod qat;
101
102/// Hardware-specific optimizations
103/// ハードウェア固有最適化
104pub mod hardware;
105
106/// Statistical observers for calibration
107/// キャリブレーション用統計観測器
108pub mod observers;
109
110/// Unified quantization parameter calculator
111/// 統一量子化パラメータ計算器
112#[derive(Debug, Clone)]
113pub struct QuantParamCalculator;
114
115impl QuantParamCalculator {
116    /// Compute symmetric quantization parameters
117    /// 対称量子化パラメータを計算
118    pub fn symmetric<T: Float>(data: &ndarray::ArrayD<T>, bits: u8) -> RusTorchResult<(f32, i32)> {
119        let abs_max = data.fold(T::zero(), |acc, &x| acc.max(x.abs()));
120
121        if abs_max == T::zero() {
122            return Ok((1.0, 0));
123        }
124
125        let qmax = 2.0_f32.powi(bits as i32 - 1) - 1.0;
126        let scale = abs_max.to_f32().unwrap_or(1.0) / qmax;
127
128        Ok((scale, 0))
129    }
130
131    /// Compute asymmetric quantization parameters
132    /// 非対称量子化パラメータを計算
133    pub fn asymmetric<T: Float>(data: &ndarray::ArrayD<T>, bits: u8) -> RusTorchResult<(f32, i32)> {
134        let min_val = data.fold(T::infinity(), |acc, &x| acc.min(x));
135        let max_val = data.fold(T::neg_infinity(), |acc, &x| acc.max(x));
136
137        if min_val >= max_val {
138            return Ok((1.0, 0));
139        }
140
141        let qmin = -(2.0_f32.powi(bits as i32 - 1));
142        let qmax = 2.0_f32.powi(bits as i32 - 1) - 1.0;
143
144        let scale = (max_val - min_val).to_f32().unwrap_or(1.0) / (qmax - qmin);
145        let zero_point = (qmin - min_val.to_f32().unwrap_or(0.0) / scale).round() as i32;
146        let zero_point_clamped = zero_point.clamp(qmin as i32, qmax as i32);
147
148        Ok((scale, zero_point_clamped))
149    }
150
151    /// Compute per-channel quantization parameters
152    /// チャンネル別量子化パラメータを計算
153    pub fn per_channel<T: Float>(
154        data: &ndarray::ArrayD<T>,
155        channel_axis: usize,
156        symmetric: bool,
157        bits: u8,
158    ) -> RusTorchResult<(Vec<f32>, Vec<i32>)> {
159        let channels = data.shape()[channel_axis];
160        let mut scales = Vec::with_capacity(channels);
161        let mut zero_points = Vec::with_capacity(channels);
162
163        for c in 0..channels {
164            let channel_slice =
165                data.slice_axis(ndarray::Axis(channel_axis), ndarray::Slice::from(c..=c));
166            let channel_data = channel_slice.to_owned();
167
168            let (scale, zero_point) = if symmetric {
169                Self::symmetric(&channel_data, bits)?
170            } else {
171                Self::asymmetric(&channel_data, bits)?
172            };
173
174            scales.push(scale);
175            zero_points.push(zero_point);
176        }
177
178        Ok((scales, zero_points))
179    }
180}
181
/// Trait for quantizable data types.
///
/// Implementors convert between their own floating-point representation and
/// an integer representation using an affine mapping defined by `scale` and
/// `zero_point`.
pub trait Quantizable: Copy + Clone + Send + Sync + 'static {
    /// The quantized representation type (e.g. `i8`).
    type QuantizedType: Copy + Clone + Send + Sync;

    /// Convert from floating point to the quantized representation:
    /// `round(self / scale) + zero_point`, limited to the range of
    /// [`Self::QuantizedType`].
    fn quantize(&self, scale: f32, zero_point: i32) -> Self::QuantizedType;

    /// Convert from the quantized representation back to floating point:
    /// `(quantized - zero_point) * scale`.
    fn dequantize(quantized: Self::QuantizedType, scale: f32, zero_point: i32) -> Self;
}

/// Quantize an `f32` value to `i8`, saturating instead of overflowing.
///
/// The float-to-int cast saturates at the `i32` bounds (and maps NaN to 0),
/// so the zero point is added with `saturating_add`; a plain `+` would
/// overflow for infinite inputs or a zero `scale`.
fn quantize_f32_to_i8(value: f32, scale: f32, zero_point: i32) -> i8 {
    let steps = (value / scale).round() as i32;
    let shifted = steps.saturating_add(zero_point);
    shifted.clamp(i32::from(i8::MIN), i32::from(i8::MAX)) as i8
}

/// Dequantize an `i8` value to `f32`.
///
/// The subtraction is done in `i64` so that extreme zero points cannot
/// overflow; for every non-overflowing input the result is identical to the
/// `i32` computation.
fn dequantize_i8_to_f32(quantized: i8, scale: f32, zero_point: i32) -> f32 {
    (i64::from(quantized) - i64::from(zero_point)) as f32 * scale
}

impl Quantizable for f32 {
    type QuantizedType = i8;

    fn quantize(&self, scale: f32, zero_point: i32) -> i8 {
        quantize_f32_to_i8(*self, scale, zero_point)
    }

    fn dequantize(quantized: i8, scale: f32, zero_point: i32) -> f32 {
        dequantize_i8_to_f32(quantized, scale, zero_point)
    }
}

impl Quantizable for f64 {
    type QuantizedType = i8;

    fn quantize(&self, scale: f32, zero_point: i32) -> i8 {
        // The value is narrowed to f32 first because the scale is an f32.
        quantize_f32_to_i8(*self as f32, scale, zero_point)
    }

    fn dequantize(quantized: i8, scale: f32, zero_point: i32) -> f64 {
        f64::from(dequantize_i8_to_f32(quantized, scale, zero_point))
    }
}
223
224/// Global quantization configuration
225/// グローバル量子化設定
226#[derive(Debug, Clone)]
227pub struct QuantizationConfig {
228    /// Default quantization scheme
229    /// デフォルト量子化スキーム
230    pub default_scheme: QuantizationScheme,
231    /// Enable per-channel quantization
232    /// チャンネル別量子化を有効化
233    pub per_channel: bool,
234    /// Calibration dataset size
235    /// キャリブレーションデータセットサイズ
236    pub calibration_size: usize,
237    /// Hardware acceleration preference
238    /// ハードウェア加速設定
239    pub hardware_acceleration: bool,
240}
241
242impl Default for QuantizationConfig {
243    fn default() -> Self {
244        Self {
245            default_scheme: QuantizationScheme::Symmetric,
246            per_channel: false,
247            calibration_size: 1000,
248            hardware_acceleration: true,
249        }
250    }
251}
252
253/// Main quantization API for tensors
254/// テンソル用メイン量子化API
255pub trait TensorQuantization<T: Float> {
256    /// Perform dynamic quantization
257    /// 動的量子化を実行
258    fn quantize_dynamic(&self, scheme: QuantizationScheme) -> RusTorchResult<QuantizedTensor<i8>>;
259
260    /// Perform static quantization with pre-computed parameters
261    /// 事前計算されたパラメータでの静的量子化を実行
262    fn quantize_static(&self, scale: f32, zero_point: i32) -> RusTorchResult<QuantizedTensor<i8>>;
263
264    /// Check if tensor is suitable for quantization
265    /// テンソルが量子化に適しているかチェック
266    fn can_quantize(&self) -> bool;
267}
268
269impl<T: Float + Quantizable<QuantizedType = i8>> TensorQuantization<T> for Tensor<T> {
270    fn quantize_dynamic(&self, scheme: QuantizationScheme) -> RusTorchResult<QuantizedTensor<i8>> {
271        let (scale, zero_point) = scheme.compute_params(&self.data)?;
272        self.quantize_static(scale, zero_point)
273    }
274
275    fn quantize_static(&self, scale: f32, zero_point: i32) -> RusTorchResult<QuantizedTensor<i8>> {
276        let quantized_data = self.data.mapv(|val| val.quantize(scale, zero_point));
277
278        Ok(QuantizedTensor::new(
279            quantized_data,
280            scale,
281            zero_point,
282            self.device.clone(),
283        ))
284    }
285
286    fn can_quantize(&self) -> bool {
287        // Check if tensor has reasonable dynamic range for quantization
288        let flat_data = self.data.as_slice().unwrap_or(&[]);
289        if flat_data.is_empty() {
290            return false;
291        }
292
293        // Check for any NaN or infinite values
294        if flat_data.iter().any(|&x| x.is_nan() || x.is_infinite()) {
295            return false;
296        }
297
298        let min_val = flat_data
299            .iter()
300            .fold(T::infinity(), |a, &b| if a < b { a } else { b });
301        let max_val = flat_data
302            .iter()
303            .fold(T::neg_infinity(), |a, &b| if a > b { a } else { b });
304
305        // Ensure reasonable dynamic range
306        let range = max_val - min_val;
307        !range.is_zero() && range.is_finite()
308    }
309}
310
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::Tensor;

    /// A scalar survives a quantize/dequantize round trip within tolerance.
    #[test]
    fn test_f32_quantization() {
        // Deliberately not a truncated PI: clippy's `approx_constant` lint
        // rejects literals such as 3.14.
        let value = 3.17f32;
        let scale = 0.1f32;
        let zero_point = 0i32;

        let quantized = value.quantize(scale, zero_point);
        let dequantized = f32::dequantize(quantized, scale, zero_point);

        // Should be close to the original value.
        assert!((value - dequantized).abs() < 0.2);
    }

    /// A small, well-behaved tensor is quantizable and quantizes without error.
    #[test]
    fn test_tensor_quantization() {
        let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);

        assert!(tensor.can_quantize());

        let quantized = tensor.quantize_dynamic(QuantizationScheme::Symmetric);
        assert!(quantized.is_ok());
    }

    /// The default configuration carries the expected values.
    #[test]
    fn test_quantization_config() {
        let config = QuantizationConfig::default();
        assert!(matches!(
            config.default_scheme,
            QuantizationScheme::Symmetric
        ));
        assert_eq!(config.calibration_size, 1000);
        assert!(config.hardware_acceleration);
    }

    /// Symmetric parameters: positive scale covering the abs-max, zero point 0.
    #[test]
    fn test_param_calculator_symmetric() {
        let data =
            ndarray::Array2::from_shape_vec((2, 3), vec![1.0f32, -2.0, 3.0, -4.0, 5.0, -6.0])
                .unwrap()
                .into_dyn();
        let (scale, zero_point) = QuantParamCalculator::symmetric(&data, 8).unwrap();

        assert!(scale > 0.0);
        assert_eq!(zero_point, 0);
        assert!(scale >= 6.0 / 127.0); // Must cover the maximum magnitude of 6.0.
    }

    /// Asymmetric parameters: positive scale, zero point inside the i8 range.
    #[test]
    fn test_param_calculator_asymmetric() {
        let data = ndarray::Array2::from_shape_vec((2, 2), vec![1.0f32, 10.0, 2.0, 8.0])
            .unwrap()
            .into_dyn();
        let (scale, zero_point) = QuantParamCalculator::asymmetric(&data, 8).unwrap();

        assert!(scale > 0.0);
        assert!((-128..=127).contains(&zero_point));
    }

    /// Per-channel computation yields one parameter pair per channel.
    #[test]
    fn test_param_calculator_per_channel() {
        let data = ndarray::Array3::from_shape_vec((2, 3, 4), (0..24).map(|x| x as f32).collect())
            .unwrap()
            .into_dyn();
        let (scales, zero_points) = QuantParamCalculator::per_channel(&data, 1, true, 8).unwrap();

        assert_eq!(scales.len(), 3);
        assert_eq!(zero_points.len(), 3);
        assert!(scales.iter().all(|&s| s > 0.0));
    }

    /// Degenerate tensors are reported as not quantizable.
    #[test]
    fn test_quantization_edge_cases() {
        // Empty tensor.
        let empty_tensor = Tensor::<f32>::from_vec(vec![], vec![0]);
        assert!(!empty_tensor.can_quantize());

        // Constant tensor (zero dynamic range).
        let constant_tensor = Tensor::<f32>::from_vec(vec![5.0; 10], vec![10]);
        assert!(!constant_tensor.can_quantize());

        // Infinite and NaN elements.
        let inf_tensor = Tensor::<f32>::from_vec(vec![f32::INFINITY, 1.0, 2.0], vec![3]);
        assert!(!inf_tensor.can_quantize());

        let nan_tensor = Tensor::<f32>::from_vec(vec![f32::NAN, 1.0, 2.0], vec![3]);
        assert!(!nan_tensor.can_quantize());
    }

    /// Round-trip error stays within one quantization step near the bounds.
    #[test]
    fn test_quantization_precision_bounds() {
        // Extreme values near the 8-bit quantization bounds.
        let extreme_data = ndarray::Array1::from_vec(vec![-128.0f32, 127.0, 0.0]).into_dyn();
        let (scale, zero_point) = QuantParamCalculator::symmetric(&extreme_data, 8).unwrap();

        for &value in extreme_data.iter() {
            let quantized = value.quantize(scale, zero_point);
            let dequantized = f32::dequantize(quantized, scale, zero_point);

            // Quantization error should be bounded by one step.
            assert!((value - dequantized).abs() <= scale);
        }
    }

    /// Symmetric parameters are well-formed for several bit widths.
    #[test]
    fn test_different_bit_widths() {
        let data = ndarray::Array1::from_vec(vec![1.0f32, 2.0, 3.0, 4.0]).into_dyn();

        for &bits in &[4u8, 8u8, 16u8] {
            let (scale, zero_point) = QuantParamCalculator::symmetric(&data, bits).unwrap();
            assert!(scale > 0.0);
            assert_eq!(zero_point, 0); // Symmetric quantization always uses zero_point = 0.
        }
    }
}