ruvector_scipix/optimize/quantize.rs

//! Model quantization utilities
//!
//! Provides INT8 quantization for model weights and activations to reduce
//! memory usage and improve inference speed.
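//!
//! # Example
//!
//! A minimal round-trip sketch; the fence is `ignore`d because the crate
//! path `ruvector_scipix` is assumed here:
//!
//! ```ignore
//! use ruvector_scipix::optimize::quantize::{dequantize, quantize_weights};
//!
//! let weights = vec![-1.0f32, -0.5, 0.0, 0.5, 1.0];
//! let (quantized, params) = quantize_weights(&weights);
//! let restored = dequantize(&quantized, params);
//! assert!(weights
//!     .iter()
//!     .zip(&restored)
//!     .all(|(w, r)| (w - r).abs() < 0.02));
//! ```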
/// Quantization parameters
#[derive(Debug, Clone, Copy)]
pub struct QuantParams {
    pub scale: f32,
    pub zero_point: i8,
}

impl QuantParams {
    /// Calculate quantization parameters from min/max values
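    ///
    /// # Example
    ///
    /// A sketch of the affine mapping over `[-1.0, 1.0]` (`ignore`d doctest;
    /// the exact `zero_point` depends on rounding):
    ///
    /// ```ignore
    /// let params = QuantParams::from_range(-1.0, 1.0);
    /// assert!((params.scale - 2.0 / 255.0).abs() < 1e-6);
    /// assert!(params.zero_point.abs() <= 1); // roughly centered on zero
    /// ```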
    pub fn from_range(min: f32, max: f32) -> Self {
        let qmin = i8::MIN as f32;
        let qmax = i8::MAX as f32;

        // Extend the range to include zero so the computed zero_point always
        // fits in i8 (a range like [1.0, 4.0] would otherwise overflow it),
        // and guard against a degenerate zero-width range.
        let min = min.min(0.0);
        let max = max.max(0.0);
        let range = (max - min).max(f32::EPSILON);

        let scale = range / (qmax - qmin);
        let zero_point = (qmin - min / scale).round() as i8;

        Self { scale, zero_point }
    }

    /// Calculate from data statistics
    pub fn from_data(data: &[f32]) -> Self {
        let min = data.iter().copied().fold(f32::INFINITY, f32::min);
        let max = data.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        Self::from_range(min, max)
    }

    /// Symmetric quantization (zero_point = 0)
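    ///
    /// # Example
    ///
    /// A sketch (`ignore`d doctest):
    ///
    /// ```ignore
    /// let params = QuantParams::symmetric(2.0);
    /// assert_eq!(params.zero_point, 0);
    /// assert!((params.scale - 2.0 / 127.0).abs() < 1e-6);
    /// ```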
    pub fn symmetric(abs_max: f32) -> Self {
        // Guard against an all-zero tensor, which would produce a zero scale
        // and division by zero in quantize_value.
        let scale = abs_max.max(f32::EPSILON) / 127.0;
        Self {
            scale,
            zero_point: 0,
        }
    }
}

/// Quantize f32 weights to i8
pub fn quantize_weights(weights: &[f32]) -> (Vec<i8>, QuantParams) {
    let params = QuantParams::from_data(weights);
    let quantized = quantize_with_params(weights, params);
    (quantized, params)
}

/// Quantize with given parameters
pub fn quantize_with_params(weights: &[f32], params: QuantParams) -> Vec<i8> {
    weights
        .iter()
        .map(|&w| quantize_value(w, params))
        .collect()
}

/// Quantize single value
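///
/// # Example
///
/// A worked sketch (`ignore`d doctest): with `scale = 0.1` and
/// `zero_point = 0`, the value `1.27` maps to `round(1.27 / 0.1) = 13`.
///
/// ```ignore
/// let params = QuantParams { scale: 0.1, zero_point: 0 };
/// assert_eq!(quantize_value(1.27, params), 13);
/// ```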
#[inline]
pub fn quantize_value(value: f32, params: QuantParams) -> i8 {
    let scaled = value / params.scale + params.zero_point as f32;
    scaled.round().clamp(i8::MIN as f32, i8::MAX as f32) as i8
}

/// Dequantize i8 to f32
pub fn dequantize(quantized: &[i8], params: QuantParams) -> Vec<f32> {
    quantized
        .iter()
        .map(|&q| dequantize_value(q, params))
        .collect()
}

/// Dequantize single value
#[inline]
pub fn dequantize_value(quantized: i8, params: QuantParams) -> f32 {
    (quantized as f32 - params.zero_point as f32) * params.scale
}

/// Quantized tensor representation
pub struct QuantizedTensor {
    pub data: Vec<i8>,
    pub params: QuantParams,
    pub shape: Vec<usize>,
}

impl QuantizedTensor {
    /// Create from f32 tensor
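    ///
    /// # Example
    ///
    /// A sketch (`ignore`d doctest):
    ///
    /// ```ignore
    /// let data: Vec<f32> = (0..1000).map(|i| i as f32 / 1000.0).collect();
    /// let tensor = QuantizedTensor::from_f32(&data, vec![1000]);
    /// // i8 storage is ~4x smaller than f32, minus a small params overhead.
    /// assert!(tensor.compression_ratio() > 3.5);
    /// ```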
    pub fn from_f32(data: &[f32], shape: Vec<usize>) -> Self {
        let (quantized, params) = quantize_weights(data);
        Self {
            data: quantized,
            params,
            shape,
        }
    }

    /// Create with symmetric quantization
    pub fn from_f32_symmetric(data: &[f32], shape: Vec<usize>) -> Self {
        let abs_max = data.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
        let params = QuantParams::symmetric(abs_max);
        let quantized = quantize_with_params(data, params);

        Self {
            data: quantized,
            params,
            shape,
        }
    }

    /// Dequantize to f32
    pub fn to_f32(&self) -> Vec<f32> {
        dequantize(&self.data, self.params)
    }

    /// Get size in bytes
    pub fn size_bytes(&self) -> usize {
        self.data.len()
            + std::mem::size_of::<QuantParams>()
            + self.shape.len() * std::mem::size_of::<usize>()
    }

    /// Calculate memory savings vs f32
    pub fn compression_ratio(&self) -> f32 {
        let f32_size = self.data.len() * std::mem::size_of::<f32>();
        let quantized_size = self.size_bytes();
        f32_size as f32 / quantized_size as f32
    }
}

/// Per-channel quantization for conv/linear layers
pub struct PerChannelQuant {
    pub data: Vec<i8>,
    pub params: Vec<QuantParams>,
    pub shape: Vec<usize>,
}

impl PerChannelQuant {
    /// Quantize with per-channel parameters.
    ///
    /// For a weight tensor of shape `[out_channels, in_channels, ...]`,
    /// separate parameters are computed for each output channel.
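    ///
    /// # Example
    ///
    /// A sketch (`ignore`d doctest); channels with very different ranges are
    /// where per-channel scales pay off:
    ///
    /// ```ignore
    /// let data = vec![0.01f32, 0.02, 0.03, 10.0, 20.0, 30.0];
    /// let quant = PerChannelQuant::from_f32(&data, vec![2, 3]);
    /// assert_eq!(quant.params.len(), 2);
    /// ```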
    pub fn from_f32(data: &[f32], shape: Vec<usize>) -> Self {
        assert!(!shape.is_empty(), "shape cannot be empty");

        let out_channels = shape[0];
        // Reject lengths that do not divide evenly, which would otherwise
        // silently drop trailing values.
        assert_eq!(
            data.len() % out_channels,
            0,
            "data length must be divisible by the number of output channels"
        );
        let channel_size = data.len() / out_channels;

        let mut all_quantized = Vec::with_capacity(data.len());
        let mut params = Vec::with_capacity(out_channels);

        for ch in 0..out_channels {
            let start = ch * channel_size;
            let end = start + channel_size;
            let channel_data = &data[start..end];

            let ch_params = QuantParams::from_data(channel_data);
            let ch_quantized = quantize_with_params(channel_data, ch_params);

            all_quantized.extend(ch_quantized);
            params.push(ch_params);
        }

        Self {
            data: all_quantized,
            params,
            shape,
        }
    }

    /// Dequantize to f32
    pub fn to_f32(&self) -> Vec<f32> {
        let out_channels = self.shape[0];
        let channel_size = self.data.len() / out_channels;

        let mut result = Vec::with_capacity(self.data.len());

        for ch in 0..out_channels {
            let start = ch * channel_size;
            let end = start + channel_size;
            let channel_data = &self.data[start..end];
            let ch_params = self.params[ch];

            result.extend(dequantize(channel_data, ch_params));
        }

        result
    }
}

/// Dynamic quantization - quantize at runtime
pub struct DynamicQuantizer {
    percentile: f32,
}

impl DynamicQuantizer {
    /// Create quantizer with calibration percentile.
    ///
    /// `percentile`: clip values beyond this percentile (e.g., 99.9).
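    ///
    /// # Example
    ///
    /// A sketch (`ignore`d doctest); clipping at the 99th percentile keeps
    /// one large outlier from inflating the scale:
    ///
    /// ```ignore
    /// let mut data: Vec<f32> = (0..100).map(|i| i as f32).collect();
    /// data.push(1_000.0); // outlier
    ///
    /// let quantizer = DynamicQuantizer::new(99.0);
    /// let (quantized, params) = quantizer.quantize(&data);
    /// assert_eq!(quantized.len(), data.len());
    /// ```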
    pub fn new(percentile: f32) -> Self {
        Self { percentile }
    }

    /// Quantize with calibration
    pub fn quantize(&self, data: &[f32]) -> (Vec<i8>, QuantParams) {
        assert!(!data.is_empty(), "cannot calibrate on empty data");

        let mut sorted = data.to_vec();
        // total_cmp avoids panicking on NaN inputs.
        sorted.sort_by(|a, b| a.total_cmp(b));

        // Clip point for the high tail; the low tail mirrors it so that
        // outliers on both ends are excluded from the calibration range.
        let hi_idx = ((sorted.len() as f32 * self.percentile / 100.0) as usize)
            .min(sorted.len() - 1);
        let lo_idx = sorted.len() - 1 - hi_idx;

        let min = sorted[lo_idx];
        let max = sorted[hi_idx];

        let params = QuantParams::from_range(min, max);
        let quantized = quantize_with_params(data, params);

        (quantized, params)
    }
}

/// Calculate quantization error (MSE)
pub fn quantization_error(original: &[f32], quantized: &[i8], params: QuantParams) -> f32 {
    let dequantized = dequantize(quantized, params);

    original
        .iter()
        .zip(dequantized.iter())
        .map(|(o, d)| (o - d).powi(2))
        .sum::<f32>()
        / original.len() as f32
}

/// Calculate signal-to-quantization-noise ratio (SQNR) in dB
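///
/// Computed as `10 * log10(P_signal / P_noise)`, where `P_signal` is the mean
/// square of the original values and `P_noise` is the mean squared
/// quantization error; higher values mean less quantization noise.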
pub fn sqnr(original: &[f32], quantized: &[i8], params: QuantParams) -> f32 {
    let dequantized = dequantize(quantized, params);

    let signal_power: f32 =
        original.iter().map(|x| x.powi(2)).sum::<f32>() / original.len() as f32;
    let noise_power: f32 = original
        .iter()
        .zip(dequantized.iter())
        .map(|(o, d)| (o - d).powi(2))
        .sum::<f32>()
        / original.len() as f32;

    10.0 * (signal_power / noise_power).log10()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantize_dequantize() {
        let weights = vec![0.0, 0.5, 1.0, -0.5, -1.0];
        let (quantized, params) = quantize_weights(&weights);
        let dequantized = dequantize(&quantized, params);

        // Check approximate equality
        for (orig, deq) in weights.iter().zip(dequantized.iter()) {
            assert!((orig - deq).abs() < 0.01, "orig: {}, deq: {}", orig, deq);
        }
    }

    #[test]
    fn test_symmetric_quantization() {
        let data = vec![-1.0, -0.5, 0.0, 0.5, 1.0];
        let params = QuantParams::symmetric(1.0);

        assert_eq!(params.zero_point, 0);
        assert!((params.scale - 1.0 / 127.0).abs() < 1e-6);

        let quantized = quantize_with_params(&data, params);
        assert_eq!(quantized[2], 0); // 0.0 should map to 0
    }

    #[test]
    fn test_quantized_tensor() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let tensor = QuantizedTensor::from_f32(&data, vec![2, 2]);

        assert_eq!(tensor.shape, vec![2, 2]);
        assert_eq!(tensor.data.len(), 4);

        let dequantized = tensor.to_f32();
        for (orig, deq) in data.iter().zip(dequantized.iter()) {
            assert!((orig - deq).abs() < 0.1);
        }
    }

    #[test]
    fn test_per_channel_quant() {
        // 2 channels, 3 values each
        let data = vec![
            1.0, 2.0, 3.0, // Channel 0
            10.0, 20.0, 30.0, // Channel 1
        ];

        let quant = PerChannelQuant::from_f32(&data, vec![2, 3]);
        assert_eq!(quant.params.len(), 2);

        let dequantized = quant.to_f32();
        for (orig, deq) in data.iter().zip(dequantized.iter()) {
            assert!((orig - deq).abs() < 1.0);
        }
    }

    #[test]
    fn test_quantization_error() {
        let original = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let (quantized, params) = quantize_weights(&original);

        let error = quantization_error(&original, &quantized, params);
        assert!(error < 0.1); // Should be small for simple data

        let snr = sqnr(&original, &quantized, params);
        assert!(snr > 30.0); // Should have good SNR
    }

    #[test]
    fn test_compression_ratio() {
        let data: Vec<f32> = (0..1000).map(|i| i as f32 / 1000.0).collect();
        let tensor = QuantizedTensor::from_f32(&data, vec![1000]);

        let ratio = tensor.compression_ratio();
        assert!(ratio > 3.5); // Should be ~4x compression
    }

    #[test]
    fn test_dynamic_quantizer() {
        let mut data: Vec<f32> = (0..100).map(|i| i as f32).collect();
        data.push(1000.0); // Outlier

        let quantizer = DynamicQuantizer::new(99.0);
        let (quantized, params) = quantizer.quantize(&data);

        assert_eq!(quantized.len(), 101);
        // The outlier should be clipped
        assert!(params.scale > 0.0);
    }
}