1use crate::error::{RusTorchError, RusTorchResult};
66use crate::tensor::Tensor;
67use ndarray;
68use num_traits::{Float, Signed, Unsigned};
69use std::fmt;
70use std::marker::PhantomData;
71
72pub use calibration::{HistogramObserver, MinMaxObserver, Observer, StaticQuantizer};
74pub use hardware::optimized_ops;
75pub use operations::{DequantizeOps, QuantizedOps};
76pub use qat::{FakeQuantize, QATConv2d, QATLinear, QATModule};
77pub use schemes::{AsymmetricQuantization, QuantizationScheme, SymmetricQuantization};
78pub use types::{QuantizationType, QuantizedTensor};
79
80pub mod types;
85
86pub mod schemes;
89
90pub mod calibration;
93
94pub mod operations;
97
98pub mod qat;
101
102pub mod hardware;
105
106pub mod observers;
109
110#[derive(Debug, Clone)]
113pub struct QuantParamCalculator;
114
115impl QuantParamCalculator {
116 pub fn symmetric<T: Float>(data: &ndarray::ArrayD<T>, bits: u8) -> RusTorchResult<(f32, i32)> {
119 let abs_max = data.fold(T::zero(), |acc, &x| acc.max(x.abs()));
120
121 if abs_max == T::zero() {
122 return Ok((1.0, 0));
123 }
124
125 let qmax = 2.0_f32.powi(bits as i32 - 1) - 1.0;
126 let scale = abs_max.to_f32().unwrap_or(1.0) / qmax;
127
128 Ok((scale, 0))
129 }
130
131 pub fn asymmetric<T: Float>(data: &ndarray::ArrayD<T>, bits: u8) -> RusTorchResult<(f32, i32)> {
134 let min_val = data.fold(T::infinity(), |acc, &x| acc.min(x));
135 let max_val = data.fold(T::neg_infinity(), |acc, &x| acc.max(x));
136
137 if min_val >= max_val {
138 return Ok((1.0, 0));
139 }
140
141 let qmin = -(2.0_f32.powi(bits as i32 - 1));
142 let qmax = 2.0_f32.powi(bits as i32 - 1) - 1.0;
143
144 let scale = (max_val - min_val).to_f32().unwrap_or(1.0) / (qmax - qmin);
145 let zero_point = (qmin - min_val.to_f32().unwrap_or(0.0) / scale).round() as i32;
146 let zero_point_clamped = zero_point.clamp(qmin as i32, qmax as i32);
147
148 Ok((scale, zero_point_clamped))
149 }
150
151 pub fn per_channel<T: Float>(
154 data: &ndarray::ArrayD<T>,
155 channel_axis: usize,
156 symmetric: bool,
157 bits: u8,
158 ) -> RusTorchResult<(Vec<f32>, Vec<i32>)> {
159 let channels = data.shape()[channel_axis];
160 let mut scales = Vec::with_capacity(channels);
161 let mut zero_points = Vec::with_capacity(channels);
162
163 for c in 0..channels {
164 let channel_slice =
165 data.slice_axis(ndarray::Axis(channel_axis), ndarray::Slice::from(c..=c));
166 let channel_data = channel_slice.to_owned();
167
168 let (scale, zero_point) = if symmetric {
169 Self::symmetric(&channel_data, bits)?
170 } else {
171 Self::asymmetric(&channel_data, bits)?
172 };
173
174 scales.push(scale);
175 zero_points.push(zero_point);
176 }
177
178 Ok((scales, zero_points))
179 }
180}
181
/// Scalar types that can be mapped onto an integer grid and back.
///
/// Both directions take the same affine parameters: a floating-point `scale` and an
/// integer `zero_point`. The `f32`/`f64` implementations in this module compute
/// `round(x / scale) + zero_point` (clamped to `i8`) and `(q - zero_point) * scale`.
pub trait Quantizable
where
    Self: Copy + Clone + Send + Sync + 'static,
{
    /// Integer storage type produced by `quantize` and consumed by `dequantize`.
    type QuantizedType: Copy + Clone + Send + Sync;

    /// Maps `self` onto the quantized grid described by `scale` and `zero_point`.
    fn quantize(&self, scale: f32, zero_point: i32) -> Self::QuantizedType;

    /// Maps a quantized value back to `Self` using the same `scale` and `zero_point`.
    fn dequantize(quantized: Self::QuantizedType, scale: f32, zero_point: i32) -> Self;
}
197
198impl Quantizable for f32 {
199 type QuantizedType = i8;
200
201 fn quantize(&self, scale: f32, zero_point: i32) -> i8 {
202 let quantized = (self / scale).round() as i32 + zero_point;
203 quantized.clamp(i8::MIN as i32, i8::MAX as i32) as i8
204 }
205
206 fn dequantize(quantized: i8, scale: f32, zero_point: i32) -> f32 {
207 (quantized as i32 - zero_point) as f32 * scale
208 }
209}
210
211impl Quantizable for f64 {
212 type QuantizedType = i8;
213
214 fn quantize(&self, scale: f32, zero_point: i32) -> i8 {
215 let quantized = (*self as f32 / scale).round() as i32 + zero_point;
216 quantized.clamp(i8::MIN as i32, i8::MAX as i32) as i8
217 }
218
219 fn dequantize(quantized: i8, scale: f32, zero_point: i32) -> f64 {
220 ((quantized as i32 - zero_point) as f32 * scale) as f64
221 }
222}
223
/// Settings for a quantization pass.
///
/// See the `Default` impl for the default value of each field.
#[derive(Debug, Clone)]
pub struct QuantizationConfig {
    /// Quantization scheme used by default (`QuantizationScheme::Symmetric` in `Default`).
    pub default_scheme: QuantizationScheme,
    /// Whether per-channel (rather than per-tensor) parameters are requested; `false`
    /// by default. NOTE(review): not read anywhere in this file — confirm the consumer.
    pub per_channel: bool,
    /// Calibration budget, `1000` by default. Presumably a sample/batch count for the
    /// calibration observers — TODO confirm the unit.
    pub calibration_size: usize,
    /// Whether hardware-specific kernels may be used; `true` by default.
    /// NOTE(review): meaning inferred from the name — confirm in the `hardware` module.
    pub hardware_acceleration: bool,
}
241
242impl Default for QuantizationConfig {
243 fn default() -> Self {
244 Self {
245 default_scheme: QuantizationScheme::Symmetric,
246 per_channel: false,
247 calibration_size: 1000,
248 hardware_acceleration: true,
249 }
250 }
251}
252
253pub trait TensorQuantization<T: Float> {
256 fn quantize_dynamic(&self, scheme: QuantizationScheme) -> RusTorchResult<QuantizedTensor<i8>>;
259
260 fn quantize_static(&self, scale: f32, zero_point: i32) -> RusTorchResult<QuantizedTensor<i8>>;
263
264 fn can_quantize(&self) -> bool;
267}
268
269impl<T: Float + Quantizable<QuantizedType = i8>> TensorQuantization<T> for Tensor<T> {
270 fn quantize_dynamic(&self, scheme: QuantizationScheme) -> RusTorchResult<QuantizedTensor<i8>> {
271 let (scale, zero_point) = scheme.compute_params(&self.data)?;
272 self.quantize_static(scale, zero_point)
273 }
274
275 fn quantize_static(&self, scale: f32, zero_point: i32) -> RusTorchResult<QuantizedTensor<i8>> {
276 let quantized_data = self.data.mapv(|val| val.quantize(scale, zero_point));
277
278 Ok(QuantizedTensor::new(
279 quantized_data,
280 scale,
281 zero_point,
282 self.device.clone(),
283 ))
284 }
285
286 fn can_quantize(&self) -> bool {
287 let flat_data = self.data.as_slice().unwrap_or(&[]);
289 if flat_data.is_empty() {
290 return false;
291 }
292
293 if flat_data.iter().any(|&x| x.is_nan() || x.is_infinite()) {
295 return false;
296 }
297
298 let min_val = flat_data
299 .iter()
300 .fold(T::infinity(), |a, &b| if a < b { a } else { b });
301 let max_val = flat_data
302 .iter()
303 .fold(T::neg_infinity(), |a, &b| if a > b { a } else { b });
304
305 let range = max_val - min_val;
307 !range.is_zero() && range.is_finite()
308 }
309}
310
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::Tensor;

    /// A single `f32` survives an `i8` round trip to within two quantization steps.
    #[test]
    fn test_f32_quantization() {
        let (value, scale, zero_point) = (3.14f32, 0.1f32, 0i32);

        let round_trip = f32::dequantize(value.quantize(scale, zero_point), scale, zero_point);

        assert!((value - round_trip).abs() < 0.2);
    }

    /// A small, well-spread tensor is accepted and dynamically quantizes without error.
    #[test]
    fn test_tensor_quantization() {
        let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        assert!(tensor.can_quantize());

        let outcome = tensor.quantize_dynamic(QuantizationScheme::Symmetric);
        assert!(outcome.is_ok());
    }

    /// The default configuration uses the symmetric scheme, a calibration size of 1000,
    /// and enables hardware acceleration.
    #[test]
    fn test_quantization_config() {
        let QuantizationConfig {
            default_scheme,
            calibration_size,
            hardware_acceleration,
            ..
        } = QuantizationConfig::default();

        assert!(matches!(default_scheme, QuantizationScheme::Symmetric));
        assert_eq!(calibration_size, 1000);
        assert!(hardware_acceleration);
    }

    #[test]
    fn test_param_calculator_symmetric() {
        let values = vec![1.0f32, -2.0, 3.0, -4.0, 5.0, -6.0];
        let data = ndarray::Array2::from_shape_vec((2, 3), values)
            .unwrap()
            .into_dyn();

        let (scale, zero_point) = QuantParamCalculator::symmetric(&data, 8).unwrap();

        assert!(scale > 0.0);
        assert_eq!(zero_point, 0);
        // The largest magnitude (6.0) must fit at or below the top of the 8-bit grid.
        assert!(scale >= 6.0 / 127.0);
    }

    #[test]
    fn test_param_calculator_asymmetric() {
        let data = ndarray::Array2::from_shape_vec((2, 2), vec![1.0f32, 10.0, 2.0, 8.0])
            .unwrap()
            .into_dyn();

        let (scale, zero_point) = QuantParamCalculator::asymmetric(&data, 8).unwrap();

        assert!(scale > 0.0);
        assert!((-128..=127).contains(&zero_point));
    }

    #[test]
    fn test_param_calculator_per_channel() {
        let values: Vec<f32> = (0..24).map(|x| x as f32).collect();
        let data = ndarray::Array3::from_shape_vec((2, 3, 4), values)
            .unwrap()
            .into_dyn();

        let (scales, zero_points) = QuantParamCalculator::per_channel(&data, 1, true, 8).unwrap();

        assert_eq!(scales.len(), 3);
        assert_eq!(zero_points.len(), 3);
        assert!(scales.iter().copied().all(|s| s > 0.0));
    }

    /// Empty, constant, infinite-valued and NaN-valued tensors are all rejected.
    #[test]
    fn test_quantization_edge_cases() {
        let rejected = [
            Tensor::<f32>::from_vec(vec![], vec![0]),
            Tensor::<f32>::from_vec(vec![5.0; 10], vec![10]),
            Tensor::<f32>::from_vec(vec![f32::INFINITY, 1.0, 2.0], vec![3]),
            Tensor::<f32>::from_vec(vec![f32::NAN, 1.0, 2.0], vec![3]),
        ];

        for tensor in rejected.iter() {
            assert!(!tensor.can_quantize());
        }
    }

    /// Values at the edges of the 8-bit range round-trip within one quantization step.
    #[test]
    fn test_quantization_precision_bounds() {
        let extreme_data = ndarray::Array1::from_vec(vec![-128.0f32, 127.0, 0.0]).into_dyn();
        let (scale, zero_point) = QuantParamCalculator::symmetric(&extreme_data, 8).unwrap();

        extreme_data.iter().copied().for_each(|value| {
            let restored = f32::dequantize(value.quantize(scale, zero_point), scale, zero_point);
            assert!((value - restored).abs() <= scale);
        });
    }

    #[test]
    fn test_different_bit_widths() {
        let data = ndarray::Array1::from_vec(vec![1.0f32, 2.0, 3.0, 4.0]).into_dyn();

        for bits in [4u8, 8, 16].iter().copied() {
            let (scale, zero_point) = QuantParamCalculator::symmetric(&data, bits).unwrap();
            assert!(scale > 0.0);
            assert_eq!(zero_point, 0);
        }
    }
}