// ruvector_cnn/simd/quantize.rs
//! INT8 Quantization with π-Based Calibration
//!
//! Implements efficient INT8 quantization for CNN inference using π-derived
//! constants to avoid quantization boundary resonance artifacts.
//!
//! # Why π?
//!
//! In low-precision quantization, values tend to collapse into repeating buckets
//! when scale factors align with powers of two. Using π-derived constants
//! breaks this symmetry:
//!
//! - π is irrational (non-repeating, infinite structure)
//! - Avoids power-of-2 boundary alignment
//! - Provides deterministic anti-resonance offsets
//!
//! # Quantization Schemes
//!
//! - **Symmetric**: For weights (zero-centered distributions)
//! - **Asymmetric**: For activations (ReLU outputs are non-negative)
//! - **Per-channel**: Different scale per output channel (higher accuracy)
//! - **Per-tensor**: Single scale for entire tensor (faster)
//!
//! # Performance
//!
//! INT8 inference provides:
//! - 4x memory reduction vs FP32
//! - 2-3x speedup on AVX2/AVX-512 (VNNI)
//! - 2-4x speedup on ARM NEON (SDOT)
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

/// π-derived constants used to keep quantization scales off 2^n boundaries.
pub mod pi_constants {
    use std::f32::consts::PI;

    /// Fractional part of π (≈ 0.14159), the base anti-resonance offset.
    pub const PI_FRAC: f32 = PI - 3.0;

    /// Scale factor that stays clear of power-of-two boundaries (π/4 ≈ 0.785).
    pub const PI_SCALE: f32 = PI / 4.0;

    /// Golden-ratio-like constant derived from π (2 / (π − 1) ≈ 0.934).
    pub const PHI_APPROX: f32 = 2.0 / (PI - 1.0);

    /// First 16 decimal digits of π, used for deterministic pseudo-jitter.
    pub const PI_DIGITS: [u8; 16] = [3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 9, 3];

    /// Anti-resonance offset for a `bits`-bit quantizer: `PI_FRAC / 2^bits`.
    #[inline]
    pub fn anti_resonance(bits: u8) -> f32 {
        let bucket_count = 1u32 << bits;
        PI_FRAC / bucket_count as f32
    }

    /// Deterministic π-digit jitter used for tie-breaking during rounding.
    /// Cycles through `PI_DIGITS` by index.
    #[inline]
    pub fn jitter(index: usize) -> f32 {
        let digit = PI_DIGITS[index % PI_DIGITS.len()];
        f32::from(digit) * 0.001 * PI_FRAC
    }
}
62
/// Quantization parameters for a tensor or channel.
///
/// Dequantization contract (see `dequantize`):
/// `float = (quant - zero_point) * scale`.
#[derive(Debug, Clone, Copy)]
pub struct QuantParams {
    /// Scale factor; float = (quant - zero_point) * scale
    pub scale: f32,
    /// Zero point offset (non-zero only for asymmetric quantization)
    pub zero_point: i8,
    /// π-derived anti-resonance offset folded into the calibration
    pub anti_resonance: f32,
    /// Effective quantization bits (7 for symmetric int8, 8 for asymmetric)
    pub bits: u8,
}
75
76impl QuantParams {
77    /// Create symmetric quantization params (for weights)
78    ///
79    /// Uses π-based anti-resonance to avoid boundary collapse.
80    pub fn symmetric(min_val: f32, max_val: f32) -> Self {
81        let abs_max = min_val.abs().max(max_val.abs());
82
83        // 7 bits for signed int8 (-127 to 127)
84        let bits = 7u8;
85        let qmax = 127.0f32;
86
87        // π-based scale with anti-resonance
88        let anti_resonance = pi_constants::anti_resonance(bits);
89        let scale = (abs_max + anti_resonance) / qmax;
90
91        Self {
92            scale: scale.max(1e-10), // Avoid division by zero
93            zero_point: 0,
94            anti_resonance,
95            bits,
96        }
97    }
98
99    /// Create asymmetric quantization params (for activations)
100    ///
101    /// Maps [min_val, max_val] to [-128, 127] with π-based calibration.
102    pub fn asymmetric(min_val: f32, max_val: f32) -> Self {
103        let bits = 8u8;
104        let qmin = -128.0f32;
105        let qmax = 127.0f32;
106
107        let anti_resonance = pi_constants::anti_resonance(bits);
108        let range = (max_val - min_val).max(1e-10) + anti_resonance;
109        let scale = range / (qmax - qmin);
110
111        // Compute zero point with π-jitter for tie-breaking
112        let zero_point_float = qmin - min_val / scale + pi_constants::jitter(0);
113        let zero_point = zero_point_float.round().clamp(-128.0, 127.0) as i8;
114
115        Self {
116            scale: scale.max(1e-10),
117            zero_point,
118            anti_resonance,
119            bits,
120        }
121    }
122
123    /// Quantize a single f32 value to i8
124    #[inline]
125    pub fn quantize(&self, value: f32) -> i8 {
126        let scaled = value / self.scale + self.zero_point as f32;
127        // Add small π-based offset for better rounding distribution
128        let rounded = (scaled + self.anti_resonance * 0.5).round();
129        rounded.clamp(-128.0, 127.0) as i8
130    }
131
132    /// Dequantize a single i8 value to f32
133    #[inline]
134    pub fn dequantize(&self, quantized: i8) -> f32 {
135        (quantized as f32 - self.zero_point as f32) * self.scale
136    }
137}
138
139impl Default for QuantParams {
140    fn default() -> Self {
141        Self::symmetric(-1.0, 1.0)
142    }
143}
144
/// Per-channel quantization parameters: one scale/zero-point per output channel.
#[derive(Debug, Clone)]
pub struct PerChannelQuantParams {
    /// Scale for each output channel (length == `num_channels`)
    pub scales: Vec<f32>,
    /// Zero point for each output channel (all zero under symmetric quantization)
    pub zero_points: Vec<i8>,
    /// Number of output channels; mirrors `scales.len()`
    pub num_channels: usize,
}
155
156impl PerChannelQuantParams {
157    /// Compute per-channel symmetric quantization params
158    pub fn symmetric_per_channel(weights: &[f32], out_channels: usize, in_channels: usize) -> Self {
159        let kernel_size = weights.len() / (out_channels * in_channels);
160        let mut scales = Vec::with_capacity(out_channels);
161        let zero_points = vec![0i8; out_channels];
162
163        for oc in 0..out_channels {
164            let start = oc * in_channels * kernel_size;
165            let end = start + in_channels * kernel_size;
166            let channel_weights = &weights[start..end];
167
168            let abs_max = channel_weights
169                .iter()
170                .map(|x| x.abs())
171                .fold(0.0f32, |a, b| a.max(b));
172
173            let anti_res = pi_constants::anti_resonance(7);
174            let scale = (abs_max + anti_res) / 127.0;
175            scales.push(scale.max(1e-10));
176        }
177
178        Self {
179            scales,
180            zero_points,
181            num_channels: out_channels,
182        }
183    }
184
185    /// Get params for a specific channel
186    #[inline]
187    pub fn channel_params(&self, channel: usize) -> QuantParams {
188        QuantParams {
189            scale: self.scales[channel],
190            zero_point: self.zero_points[channel],
191            anti_resonance: pi_constants::anti_resonance(7),
192            bits: 7,
193        }
194    }
195}
196
/// Quantized INT8 tensor storage: the INT8 payload plus everything needed to
/// dequantize it back to f32.
#[derive(Debug, Clone)]
pub struct QuantizedTensor {
    /// Quantized INT8 payload (contiguous, built channel-major by the constructors)
    pub data: Vec<i8>,
    /// Logical tensor shape; product of dims should equal `data.len()`
    pub shape: Vec<usize>,
    /// Per-tensor or per-channel quantization parameters
    pub params: QuantizationType,
}
207
/// How a tensor's quantization parameters are stored.
#[derive(Debug, Clone)]
pub enum QuantizationType {
    /// One scale/zero-point shared by the entire tensor (fastest)
    PerTensor(QuantParams),
    /// Independent scale per output channel (higher accuracy for weights)
    PerChannel(PerChannelQuantParams),
}
216
217impl QuantizedTensor {
218    /// Quantize a float tensor with per-tensor symmetric quantization
219    pub fn from_float_symmetric(data: &[f32], shape: &[usize]) -> Self {
220        let min_val = data.iter().fold(f32::MAX, |a, &b| a.min(b));
221        let max_val = data.iter().fold(f32::MIN, |a, &b| a.max(b));
222        let params = QuantParams::symmetric(min_val, max_val);
223
224        let quantized: Vec<i8> = data.iter().map(|&v| params.quantize(v)).collect();
225
226        Self {
227            data: quantized,
228            shape: shape.to_vec(),
229            params: QuantizationType::PerTensor(params),
230        }
231    }
232
233    /// Quantize weights with per-channel quantization
234    pub fn from_weights_per_channel(
235        weights: &[f32],
236        out_channels: usize,
237        in_channels: usize,
238        kernel_h: usize,
239        kernel_w: usize,
240    ) -> Self {
241        let per_channel = PerChannelQuantParams::symmetric_per_channel(weights, out_channels, in_channels);
242        let kernel_size = kernel_h * kernel_w;
243
244        let mut quantized = Vec::with_capacity(weights.len());
245
246        for oc in 0..out_channels {
247            let params = per_channel.channel_params(oc);
248            let start = oc * in_channels * kernel_size;
249            let end = start + in_channels * kernel_size;
250
251            for &w in &weights[start..end] {
252                quantized.push(params.quantize(w));
253            }
254        }
255
256        Self {
257            data: quantized,
258            shape: vec![out_channels, in_channels, kernel_h, kernel_w],
259            params: QuantizationType::PerChannel(per_channel),
260        }
261    }
262
263    /// Dequantize back to float32
264    pub fn dequantize(&self) -> Vec<f32> {
265        match &self.params {
266            QuantizationType::PerTensor(params) => {
267                self.data.iter().map(|&q| params.dequantize(q)).collect()
268            }
269            QuantizationType::PerChannel(per_channel) => {
270                let out_channels = self.shape[0];
271                let channel_size = self.data.len() / out_channels;
272                let mut output = Vec::with_capacity(self.data.len());
273
274                for oc in 0..out_channels {
275                    let params = per_channel.channel_params(oc);
276                    let start = oc * channel_size;
277                    let end = start + channel_size;
278
279                    for &q in &self.data[start..end] {
280                        output.push(params.dequantize(q));
281                    }
282                }
283                output
284            }
285        }
286    }
287
288    /// Get the number of elements
289    pub fn len(&self) -> usize {
290        self.data.len()
291    }
292
293    /// Check if empty
294    pub fn is_empty(&self) -> bool {
295        self.data.is_empty()
296    }
297}
298
299/// Batch quantize f32 to i8 using π-calibration
300///
301/// Faster than per-element quantization using SIMD.
302pub fn quantize_batch(input: &[f32], output: &mut [i8], params: &QuantParams) {
303    debug_assert_eq!(input.len(), output.len());
304
305    let inv_scale = 1.0 / params.scale;
306    let zp = params.zero_point as f32;
307    let anti_res = params.anti_resonance * 0.5;
308
309    for (i, &val) in input.iter().enumerate() {
310        let scaled = val * inv_scale + zp + anti_res;
311        output[i] = scaled.round().clamp(-128.0, 127.0) as i8;
312    }
313}
314
315/// Batch dequantize i8 to f32
316pub fn dequantize_batch(input: &[i8], output: &mut [f32], params: &QuantParams) {
317    debug_assert_eq!(input.len(), output.len());
318
319    let zp = params.zero_point as f32;
320
321    for (i, &q) in input.iter().enumerate() {
322        output[i] = (q as f32 - zp) * params.scale;
323    }
324}
325
326/// AVX2 batch quantization (8 values at a time)
327#[cfg(target_arch = "x86_64")]
328#[target_feature(enable = "avx2")]
329pub unsafe fn quantize_batch_avx2(input: &[f32], output: &mut [i8], params: &QuantParams) {
330    let len = input.len();
331    let chunks = len / 8;
332
333    let inv_scale = _mm256_set1_ps(1.0 / params.scale);
334    let zp = _mm256_set1_ps(params.zero_point as f32);
335    let anti_res = _mm256_set1_ps(params.anti_resonance * 0.5);
336    let half = _mm256_set1_ps(0.5);
337    let min_val = _mm256_set1_ps(-128.0);
338    let max_val = _mm256_set1_ps(127.0);
339
340    for i in 0..chunks {
341        let offset = i * 8;
342
343        // Load 8 floats
344        let v = _mm256_loadu_ps(input.as_ptr().add(offset));
345
346        // Scale and offset: v * inv_scale + zp + anti_res
347        let scaled = _mm256_add_ps(_mm256_mul_ps(v, inv_scale), zp);
348        let adjusted = _mm256_add_ps(scaled, anti_res);
349
350        // Round (add 0.5 and floor for positive, subtract 0.5 for negative)
351        let rounded = _mm256_round_ps(adjusted, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
352
353        // Clamp to [-128, 127]
354        let clamped = _mm256_min_ps(_mm256_max_ps(rounded, min_val), max_val);
355
356        // Convert to i32 then pack to i8
357        let i32_vals = _mm256_cvtps_epi32(clamped);
358
359        // Extract and pack to i8 (need to do this manually for AVX2)
360        let i32_array: [i32; 8] = std::mem::transmute(i32_vals);
361        for j in 0..8 {
362            output[offset + j] = i32_array[j] as i8;
363        }
364    }
365
366    // Handle remainder
367    let remainder_start = chunks * 8;
368    for i in remainder_start..len {
369        let scaled = input[i] / params.scale + params.zero_point as f32 + params.anti_resonance * 0.5;
370        output[i] = scaled.round().clamp(-128.0, 127.0) as i8;
371    }
372}
373
374/// AVX2 batch dequantization
375#[cfg(target_arch = "x86_64")]
376#[target_feature(enable = "avx2")]
377pub unsafe fn dequantize_batch_avx2(input: &[i8], output: &mut [f32], params: &QuantParams) {
378    let len = input.len();
379    let chunks = len / 8;
380
381    let scale = _mm256_set1_ps(params.scale);
382    let zp = _mm256_set1_ps(params.zero_point as f32);
383
384    for i in 0..chunks {
385        let offset = i * 8;
386
387        // Load 8 i8 values and convert to f32
388        let mut i32_array = [0i32; 8];
389        for j in 0..8 {
390            i32_array[j] = input[offset + j] as i32;
391        }
392        let i32_vals: __m256i = std::mem::transmute(i32_array);
393        let f32_vals = _mm256_cvtepi32_ps(i32_vals);
394
395        // Dequantize: (val - zp) * scale
396        let shifted = _mm256_sub_ps(f32_vals, zp);
397        let result = _mm256_mul_ps(shifted, scale);
398
399        _mm256_storeu_ps(output.as_mut_ptr().add(offset), result);
400    }
401
402    // Handle remainder
403    let remainder_start = chunks * 8;
404    for i in remainder_start..len {
405        output[i] = (input[i] as f32 - params.zero_point as f32) * params.scale;
406    }
407}
408
// Non-x86_64 fallbacks: same signatures, scalar implementations.

/// Portable fallback for non-x86_64 targets.
///
/// The previous stub was a silent no-op, leaving `output` untouched if called
/// directly; this version performs the same computation as the scalar path.
///
/// # Safety
/// Safe to call on any target; the `unsafe` qualifier only mirrors the
/// x86_64 signature so call sites compile unchanged.
#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn quantize_batch_avx2(input: &[f32], output: &mut [i8], params: &QuantParams) {
    // Mirrors `quantize_batch` so results match the scalar reference path.
    debug_assert_eq!(input.len(), output.len());
    let inv_scale = 1.0 / params.scale;
    let zp = params.zero_point as f32;
    let bias = params.anti_resonance * 0.5;
    for (i, &val) in input.iter().enumerate() {
        output[i] = (val * inv_scale + zp + bias).round().clamp(-128.0, 127.0) as i8;
    }
}
412
/// Portable fallback for non-x86_64 targets.
///
/// The previous stub was a silent no-op, leaving `output` untouched if called
/// directly; this version performs the same computation as the scalar path.
///
/// # Safety
/// Safe to call on any target; the `unsafe` qualifier only mirrors the
/// x86_64 signature so call sites compile unchanged.
#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn dequantize_batch_avx2(input: &[i8], output: &mut [f32], params: &QuantParams) {
    // Mirrors `dequantize_batch` so results match the scalar reference path.
    debug_assert_eq!(input.len(), output.len());
    let zp = params.zero_point as f32;
    for (i, &q) in input.iter().enumerate() {
        output[i] = (q as f32 - zp) * params.scale;
    }
}
415
416/// SIMD-dispatched quantization
417#[inline(always)]
418pub fn quantize_simd(input: &[f32], output: &mut [i8], params: &QuantParams) {
419    #[cfg(target_arch = "x86_64")]
420    {
421        if is_x86_feature_detected!("avx2") {
422            unsafe {
423                quantize_batch_avx2(input, output, params);
424            }
425            return;
426        }
427    }
428    quantize_batch(input, output, params);
429}
430
431/// SIMD-dispatched dequantization
432#[inline(always)]
433pub fn dequantize_simd(input: &[i8], output: &mut [f32], params: &QuantParams) {
434    #[cfg(target_arch = "x86_64")]
435    {
436        if is_x86_feature_detected!("avx2") {
437            unsafe {
438                dequantize_batch_avx2(input, output, params);
439            }
440            return;
441        }
442    }
443    dequantize_batch(input, output, params);
444}
445
#[cfg(test)]
mod tests {
    use super::*;
    // BUG FIX: `PI` is referenced below but `use super::*` does not bring it
    // into scope — the `use std::f32::consts::PI` inside `pi_constants` is a
    // private import, so the tests module needs its own.
    use std::f32::consts::PI;

    #[test]
    fn test_symmetric_quantization() {
        let params = QuantParams::symmetric(-1.0, 1.0);

        let q = params.quantize(0.5);
        let dq = params.dequantize(q);

        // Should round-trip with small error
        assert!((0.5 - dq).abs() < 0.02);
    }

    #[test]
    fn test_asymmetric_quantization() {
        let params = QuantParams::asymmetric(0.0, 1.0);

        let q = params.quantize(0.5);
        let dq = params.dequantize(q);

        assert!((0.5 - dq).abs() < 0.02);
    }

    #[test]
    fn test_pi_anti_resonance() {
        let anti_res = pi_constants::anti_resonance(8);
        assert!(anti_res > 0.0);
        assert!(anti_res < 0.001);

        // Check it's π-derived
        let expected = (PI - 3.0) / 256.0;
        assert!((anti_res - expected).abs() < 1e-10);
    }

    #[test]
    fn test_quantized_tensor_roundtrip() {
        let data = vec![0.1, 0.2, 0.3, 0.4, -0.1, -0.2, -0.3, -0.4];
        let shape = vec![2, 4];

        let quantized = QuantizedTensor::from_float_symmetric(&data, &shape);
        let dequantized = quantized.dequantize();

        // Check all values round-trip within tolerance
        for (original, recovered) in data.iter().zip(dequantized.iter()) {
            assert!((original - recovered).abs() < 0.02);
        }
    }

    #[test]
    fn test_per_channel_quantization() {
        // 2 output channels, 2 input channels, 3x3 kernel
        let weights: Vec<f32> = (0..36).map(|i| (i as f32 - 18.0) * 0.1).collect();

        let quantized = QuantizedTensor::from_weights_per_channel(&weights, 2, 2, 3, 3);
        let dequantized = quantized.dequantize();

        // Per-channel should have better accuracy than per-tensor for diverse channels
        let max_error: f32 = weights
            .iter()
            .zip(dequantized.iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f32, |a, b| a.max(b));

        assert!(max_error < 0.05);
    }

    #[test]
    fn test_batch_quantize() {
        let input = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8];
        let mut output = vec![0i8; 8];
        let params = QuantParams::symmetric(-1.0, 1.0);

        quantize_batch(&input, &mut output, &params);

        // FIX: the old check `q >= -128 && q <= 127` is a tautology for i8.
        // Strictly positive inputs must quantize to strictly positive codes.
        for &q in &output {
            assert!(q > 0);
        }
    }

    #[test]
    fn test_batch_dequantize() {
        let input = vec![10i8, 20, 30, 40, -10, -20, -30, -40];
        let mut output = vec![0.0f32; 8];
        let params = QuantParams::symmetric(-1.0, 1.0);

        dequantize_batch(&input, &mut output, &params);

        // Positive quantized values should give positive floats
        assert!(output[0] > 0.0);
        assert!(output[4] < 0.0);
    }

    #[test]
    fn test_simd_dispatch() {
        let input = vec![0.1f32; 16];
        let mut output = vec![0i8; 16];
        let params = QuantParams::symmetric(-1.0, 1.0);

        quantize_simd(&input, &mut output, &params);

        // All should be same value
        let first = output[0];
        for &q in &output {
            assert_eq!(q, first);
        }
    }
}
555}