Skip to main content

ruvector_cnn/int8/
mod.rs

//! INT8 quantization module for ADR-091
//!
//! This module provides INT8 quantization primitives for testing.
//! Full implementation will be added in subsequent phases.
5
6pub mod kernels;
7
8use serde::{Deserialize, Serialize};
9
/// Quantization parameters for INT8 conversion.
///
/// Within this module, [`quantize_tensor`] maps a real value `x` to
/// `round(x / scale + zero_point)` clamped to `[-128, 127]`, and
/// [`dequantize_tensor`] maps a quantized value `q` back to
/// `(q - zero_point) * scale`.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct QuantParams {
    /// Scale factor for quantization: the real-valued size of one INT8 step.
    pub scale: f32,
    /// Zero point offset: the INT8 level that real `0.0` quantizes to.
    pub zero_point: i8,
}
18
19impl QuantParams {
20    /// Compute quantization parameters from a tensor
21    pub fn from_tensor(tensor: &[f32]) -> Self {
22        if tensor.is_empty() {
23            return Self {
24                scale: 1.0,
25                zero_point: 0,
26            };
27        }
28
29        let min_val = tensor.iter().copied().fold(f32::INFINITY, f32::min);
30        let max_val = tensor.iter().copied().fold(f32::NEG_INFINITY, f32::max);
31
32        // Handle all zeros or constant tensors
33        if (max_val - min_val).abs() < 1e-10 {
34            return Self {
35                scale: 1.0,
36                zero_point: 0,
37            };
38        }
39
40        // Compute scale and zero_point for asymmetric quantization
41        // Map [min_val, max_val] to [-128, 127]
42        let scale = (max_val - min_val) / 255.0;
43        let zero_point = (-128.0 - min_val / scale).round() as i8;
44
45        Self { scale, zero_point }
46    }
47}
48
49/// Quantize FP32 tensor to INT8
50pub fn quantize_tensor(fp32: &[f32], params: &QuantParams) -> Vec<i8> {
51    fp32.iter()
52        .map(|&x| {
53            let quantized = (x / params.scale + params.zero_point as f32).round();
54            quantized.clamp(-128.0, 127.0) as i8
55        })
56        .collect()
57}
58
59/// Dequantize INT8 tensor to FP32
60pub fn dequantize_tensor(int8: &[i8], params: &QuantParams) -> Vec<f32> {
61    int8.iter()
62        .map(|&x| (x as f32 - params.zero_point as f32) * params.scale)
63        .collect()
64}
65
#[cfg(test)]
mod tests {
    use super::*;

    /// Parameters derived from a mixed-sign tensor are well-formed.
    #[test]
    fn test_quant_params_from_tensor() {
        let samples: [f32; 4] = [-1.0, 0.0, 1.0, 2.0];
        let QuantParams { scale, zero_point } = QuantParams::from_tensor(&samples);

        assert!(scale > 0.0);
        assert!(scale.is_finite());
        assert!((i8::MIN..=i8::MAX).contains(&zero_point));
    }

    /// Quantizing and then dequantizing stays within a 0.1 absolute error.
    #[test]
    fn test_quantize_dequantize_roundtrip() {
        let original: [f32; 4] = [0.5, -0.3, 0.8, -0.1];
        let params = QuantParams::from_tensor(&original);
        let restored = dequantize_tensor(&quantize_tensor(&original, &params), &params);

        for (expected, actual) in original.iter().zip(&restored) {
            let error = (expected - actual).abs();
            assert!(error < 0.1, "Roundtrip error too large: {}", error);
        }
    }

    /// An empty tensor falls back to the identity parameters.
    #[test]
    fn test_empty_tensor() {
        let QuantParams { scale, zero_point } = QuantParams::from_tensor(&[]);
        assert_eq!(scale, 1.0);
        assert_eq!(zero_point, 0);
    }

    /// A constant tensor falls back to the identity parameters.
    #[test]
    fn test_constant_tensor() {
        let QuantParams { scale, zero_point } = QuantParams::from_tensor(&[0.5_f32; 100]);
        assert_eq!(scale, 1.0);
        assert_eq!(zero_point, 0);
    }
}