// ruvector_scipix/optimize/quantize.rs
use std::f32;
/// Affine quantization parameters mapping f32 values onto the i8 range.
#[derive(Debug, Clone, Copy)]
pub struct QuantParams {
    /// Step size between adjacent quantized levels (always finite and > 0).
    pub scale: f32,
    /// Quantized value that represents real 0.0.
    pub zero_point: i8,
}

impl QuantParams {
    /// Builds asymmetric (affine) parameters covering `[min, max]`.
    ///
    /// A degenerate range (`max <= min`, NaN, or infinities — e.g. from
    /// constant or empty data) falls back to `scale = 1.0` instead of
    /// producing a zero/NaN scale, which would make `quantize_value`
    /// divide by zero.
    pub fn from_range(min: f32, max: f32) -> Self {
        let qmin = i8::MIN as f32;
        let qmax = i8::MAX as f32;

        let range = max - min;
        // Guard: scale must stay finite and strictly positive.
        let scale = if range.is_finite() && range > 0.0 {
            range / (qmax - qmin)
        } else {
            1.0
        };
        // Clamp before the cast to make saturation explicit (float-to-int
        // `as` casts already saturate, but this documents the intent).
        let zero_point = (qmin - min / scale).round().clamp(qmin, qmax) as i8;

        Self { scale, zero_point }
    }

    /// Derives parameters from the observed min/max of `data`.
    ///
    /// Empty or constant input takes the `from_range` degenerate-range
    /// fallback (scale 1.0) rather than producing NaN parameters.
    pub fn from_data(data: &[f32]) -> Self {
        let min = data.iter().copied().fold(f32::INFINITY, f32::min);
        let max = data.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        Self::from_range(min, max)
    }

    /// Symmetric parameters (`zero_point = 0`) for values in
    /// `[-abs_max, abs_max]`.
    ///
    /// A zero or non-finite `abs_max` falls back to `scale = 1.0` so the
    /// scale never becomes 0 (which would break quantization).
    pub fn symmetric(abs_max: f32) -> Self {
        let scale = if abs_max.is_finite() && abs_max > 0.0 {
            abs_max / 127.0
        } else {
            1.0
        };
        Self {
            scale,
            zero_point: 0,
        }
    }
}
43
44pub fn quantize_weights(weights: &[f32]) -> (Vec<i8>, QuantParams) {
46 let params = QuantParams::from_data(weights);
47 let quantized = quantize_with_params(weights, params);
48 (quantized, params)
49}
50
51pub fn quantize_with_params(weights: &[f32], params: QuantParams) -> Vec<i8> {
53 weights
54 .iter()
55 .map(|&w| quantize_value(w, params))
56 .collect()
57}
58
59#[inline]
61pub fn quantize_value(value: f32, params: QuantParams) -> i8 {
62 let scaled = value / params.scale + params.zero_point as f32;
63 scaled.round().clamp(i8::MIN as f32, i8::MAX as f32) as i8
64}
65
66pub fn dequantize(quantized: &[i8], params: QuantParams) -> Vec<f32> {
68 quantized
69 .iter()
70 .map(|&q| dequantize_value(q, params))
71 .collect()
72}
73
74#[inline]
76pub fn dequantize_value(quantized: i8, params: QuantParams) -> f32 {
77 (quantized as f32 - params.zero_point as f32) * params.scale
78}
79
/// A tensor stored as quantized i8 values plus the parameters needed to
/// reconstruct approximate f32 values.
pub struct QuantizedTensor {
    /// Quantized payload, one i8 per original element.
    pub data: Vec<i8>,
    /// Scale / zero-point shared by all elements.
    pub params: QuantParams,
    /// Logical dimensions of the tensor.
    pub shape: Vec<usize>,
}
86
87impl QuantizedTensor {
88 pub fn from_f32(data: &[f32], shape: Vec<usize>) -> Self {
90 let (quantized, params) = quantize_weights(data);
91 Self {
92 data: quantized,
93 params,
94 shape,
95 }
96 }
97
98 pub fn from_f32_symmetric(data: &[f32], shape: Vec<usize>) -> Self {
100 let abs_max = data.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
101 let params = QuantParams::symmetric(abs_max);
102 let quantized = quantize_with_params(data, params);
103
104 Self {
105 data: quantized,
106 params,
107 shape,
108 }
109 }
110
111 pub fn to_f32(&self) -> Vec<f32> {
113 dequantize(&self.data, self.params)
114 }
115
116 pub fn size_bytes(&self) -> usize {
118 self.data.len() + std::mem::size_of::<QuantParams>() + self.shape.len() * std::mem::size_of::<usize>()
119 }
120
121 pub fn compression_ratio(&self) -> f32 {
123 let f32_size = self.data.len() * std::mem::size_of::<f32>();
124 let quantized_size = self.size_bytes();
125 f32_size as f32 / quantized_size as f32
126 }
127}
128
/// Per-channel quantized tensor: each output channel (first dimension of
/// `shape`) carries its own scale / zero-point for better accuracy when
/// channel magnitudes differ.
pub struct PerChannelQuant {
    /// Quantized payload, channels stored contiguously.
    pub data: Vec<i8>,
    /// One `QuantParams` per output channel (length == shape[0]).
    pub params: Vec<QuantParams>,
    /// Logical dimensions; shape[0] is the channel count.
    pub shape: Vec<usize>,
}
135
136impl PerChannelQuant {
137 pub fn from_f32(data: &[f32], shape: Vec<usize>) -> Self {
141 if shape.is_empty() {
142 panic!("Shape cannot be empty");
143 }
144
145 let out_channels = shape[0];
146 let channel_size = data.len() / out_channels;
147
148 let mut all_quantized = Vec::with_capacity(data.len());
149 let mut params = Vec::with_capacity(out_channels);
150
151 for ch in 0..out_channels {
152 let start = ch * channel_size;
153 let end = start + channel_size;
154 let channel_data = &data[start..end];
155
156 let ch_params = QuantParams::from_data(channel_data);
157 let ch_quantized = quantize_with_params(channel_data, ch_params);
158
159 all_quantized.extend(ch_quantized);
160 params.push(ch_params);
161 }
162
163 Self {
164 data: all_quantized,
165 params,
166 shape,
167 }
168 }
169
170 pub fn to_f32(&self) -> Vec<f32> {
172 let out_channels = self.shape[0];
173 let channel_size = self.data.len() / out_channels;
174
175 let mut result = Vec::with_capacity(self.data.len());
176
177 for ch in 0..out_channels {
178 let start = ch * channel_size;
179 let end = start + channel_size;
180 let channel_data = &self.data[start..end];
181 let ch_params = self.params[ch];
182
183 result.extend(dequantize(channel_data, ch_params));
184 }
185
186 result
187 }
188}
189
/// Quantizer that derives its range from a percentile of the sorted data
/// instead of the raw min/max, making the scale robust to outliers.
pub struct DynamicQuantizer {
    // Percentile in 0..=100 used to pick the clipping range.
    percentile: f32,
}
194
195impl DynamicQuantizer {
196 pub fn new(percentile: f32) -> Self {
199 Self { percentile }
200 }
201
202 pub fn quantize(&self, data: &[f32]) -> (Vec<i8>, QuantParams) {
204 let mut sorted: Vec<f32> = data.iter().copied().collect();
205 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
206
207 let idx = ((sorted.len() as f32 * self.percentile / 100.0) as usize)
208 .min(sorted.len() - 1);
209
210 let min = -sorted[sorted.len() - idx];
211 let max = sorted[idx];
212
213 let params = QuantParams::from_range(min, max);
214 let quantized = quantize_with_params(data, params);
215
216 (quantized, params)
217 }
218}
219
220pub fn quantization_error(original: &[f32], quantized: &[i8], params: QuantParams) -> f32 {
222 let dequantized = dequantize(quantized, params);
223
224 let mse: f32 = original
225 .iter()
226 .zip(dequantized.iter())
227 .map(|(o, d)| (o - d).powi(2))
228 .sum::<f32>() / original.len() as f32;
229
230 mse
231}
232
233pub fn sqnr(original: &[f32], quantized: &[i8], params: QuantParams) -> f32 {
235 let dequantized = dequantize(quantized, params);
236
237 let signal_power: f32 = original.iter().map(|x| x.powi(2)).sum::<f32>() / original.len() as f32;
238 let noise_power: f32 = original
239 .iter()
240 .zip(dequantized.iter())
241 .map(|(o, d)| (o - d).powi(2))
242 .sum::<f32>() / original.len() as f32;
243
244 10.0 * (signal_power / noise_power).log10()
245}
246
#[cfg(test)]
mod tests {
    use super::*;

    // Round-tripping small weights through quantize/dequantize should
    // stay within half a quantization step.
    #[test]
    fn test_quantize_dequantize() {
        let weights = vec![0.0, 0.5, 1.0, -0.5, -1.0];
        let (codes, params) = quantize_weights(&weights);
        let restored = dequantize(&codes, params);

        for (orig, deq) in weights.iter().zip(restored.iter()) {
            assert!((orig - deq).abs() < 0.01, "orig: {}, deq: {}", orig, deq);
        }
    }

    #[test]
    fn test_symmetric_quantization() {
        let data = vec![-1.0, -0.5, 0.0, 0.5, 1.0];
        let params = QuantParams::symmetric(1.0);

        assert_eq!(params.zero_point, 0);
        assert!((params.scale - 1.0 / 127.0).abs() < 1e-6);

        // 0.0 must map exactly onto the zero point.
        let codes = quantize_with_params(&data, params);
        assert_eq!(codes[2], 0);
    }

    #[test]
    fn test_quantized_tensor() {
        let values = vec![1.0, 2.0, 3.0, 4.0];
        let tensor = QuantizedTensor::from_f32(&values, vec![2, 2]);

        assert_eq!(tensor.shape, vec![2, 2]);
        assert_eq!(tensor.data.len(), 4);

        let restored = tensor.to_f32();
        for (orig, deq) in values.iter().zip(restored.iter()) {
            assert!((orig - deq).abs() < 0.1);
        }
    }

    #[test]
    fn test_per_channel_quant() {
        // Two channels with very different magnitudes — per-channel
        // parameters keep the small channel accurate.
        let values = vec![1.0, 2.0, 3.0, 10.0, 20.0, 30.0];

        let quant = PerChannelQuant::from_f32(&values, vec![2, 3]);
        assert_eq!(quant.params.len(), 2);

        let restored = quant.to_f32();
        for (orig, deq) in values.iter().zip(restored.iter()) {
            assert!((orig - deq).abs() < 1.0);
        }
    }

    #[test]
    fn test_quantization_error() {
        let original = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let (codes, params) = quantize_weights(&original);

        let error = quantization_error(&original, &codes, params);
        assert!(error < 0.1);

        let snr = sqnr(&original, &codes, params);
        assert!(snr > 30.0);
    }

    #[test]
    fn test_compression_ratio() {
        let values: Vec<f32> = (0..1000).map(|i| i as f32 / 1000.0).collect();
        let tensor = QuantizedTensor::from_f32(&values, vec![1000]);

        assert!(tensor.compression_ratio() > 3.5);
    }

    #[test]
    fn test_dynamic_quantizer() {
        // 0..99 plus one large outlier that the 99th percentile clips.
        let mut values: Vec<f32> = (0..100).map(|i| i as f32).collect();
        values.push(1000.0);

        let quantizer = DynamicQuantizer::new(99.0);
        let (codes, params) = quantizer.quantize(&values);

        assert_eq!(codes.len(), 101);
        assert!(params.scale > 0.0);
    }
}