1#[cfg(target_arch = "x86_64")]
31use std::arch::x86_64::*;
32
pub mod pi_constants {
    //! π-derived constants used to de-correlate (dither) quantization error.

    use std::f32::consts::PI;

    /// Fractional part of π (π − 3 ≈ 0.14159).
    pub const PI_FRAC: f32 = PI - 3.0;
    /// π / 4.
    pub const PI_SCALE: f32 = PI / 4.0;
    /// 2 / (π − 1), a golden-ratio-like constant.
    pub const PHI_APPROX: f32 = 2.0 / (PI - 1.0);
    /// First 16 decimal digits of π, used as a cheap deterministic sequence.
    pub const PI_DIGITS: [u8; 16] = [3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 9, 3];

    /// Anti-resonance pad: `PI_FRAC` scaled down by `2^bits`.
    /// NOTE(review): shifts with `bits >= 32` overflow `u32`; callers in this
    /// file only pass 7 or 8.
    #[inline]
    pub fn anti_resonance(bits: u8) -> f32 {
        let divisor = (1u32 << bits) as f32;
        PI_FRAC / divisor
    }

    /// Tiny deterministic jitter keyed on the π digit at `index % 16`.
    #[inline]
    pub fn jitter(index: usize) -> f32 {
        let digit = PI_DIGITS[index % PI_DIGITS.len()] as f32;
        digit * 0.001 * PI_FRAC
    }
}
62
/// Quantization parameters for a whole tensor (per-tensor granularity).
#[derive(Debug, Clone, Copy)]
pub struct QuantParams {
    /// Step size mapping quantized codes back to real values (kept > 0).
    pub scale: f32,
    /// Quantized code representing real 0.0 (always 0 in symmetric mode).
    pub zero_point: i8,
    /// π-derived pad added to the range when deriving `scale`; half of it
    /// is used as a rounding dither in `quantize`.
    pub anti_resonance: f32,
    /// Bit width the anti-resonance pad was derived from (7 or 8 here).
    pub bits: u8,
}
75
76impl QuantParams {
77 pub fn symmetric(min_val: f32, max_val: f32) -> Self {
81 let abs_max = min_val.abs().max(max_val.abs());
82
83 let bits = 7u8;
85 let qmax = 127.0f32;
86
87 let anti_resonance = pi_constants::anti_resonance(bits);
89 let scale = (abs_max + anti_resonance) / qmax;
90
91 Self {
92 scale: scale.max(1e-10), zero_point: 0,
94 anti_resonance,
95 bits,
96 }
97 }
98
99 pub fn asymmetric(min_val: f32, max_val: f32) -> Self {
103 let bits = 8u8;
104 let qmin = -128.0f32;
105 let qmax = 127.0f32;
106
107 let anti_resonance = pi_constants::anti_resonance(bits);
108 let range = (max_val - min_val).max(1e-10) + anti_resonance;
109 let scale = range / (qmax - qmin);
110
111 let zero_point_float = qmin - min_val / scale + pi_constants::jitter(0);
113 let zero_point = zero_point_float.round().clamp(-128.0, 127.0) as i8;
114
115 Self {
116 scale: scale.max(1e-10),
117 zero_point,
118 anti_resonance,
119 bits,
120 }
121 }
122
123 #[inline]
125 pub fn quantize(&self, value: f32) -> i8 {
126 let scaled = value / self.scale + self.zero_point as f32;
127 let rounded = (scaled + self.anti_resonance * 0.5).round();
129 rounded.clamp(-128.0, 127.0) as i8
130 }
131
132 #[inline]
134 pub fn dequantize(&self, quantized: i8) -> f32 {
135 (quantized as f32 - self.zero_point as f32) * self.scale
136 }
137}
138
139impl Default for QuantParams {
140 fn default() -> Self {
141 Self::symmetric(-1.0, 1.0)
142 }
143}
144
/// Per-output-channel quantization parameters (used for conv weights).
#[derive(Debug, Clone)]
pub struct PerChannelQuantParams {
    /// One scale per output channel.
    pub scales: Vec<f32>,
    /// One zero point per output channel (all 0 for symmetric quantization).
    pub zero_points: Vec<i8>,
    /// Number of output channels; equals `scales.len()`.
    pub num_channels: usize,
}
155
156impl PerChannelQuantParams {
157 pub fn symmetric_per_channel(weights: &[f32], out_channels: usize, in_channels: usize) -> Self {
159 let kernel_size = weights.len() / (out_channels * in_channels);
160 let mut scales = Vec::with_capacity(out_channels);
161 let zero_points = vec![0i8; out_channels];
162
163 for oc in 0..out_channels {
164 let start = oc * in_channels * kernel_size;
165 let end = start + in_channels * kernel_size;
166 let channel_weights = &weights[start..end];
167
168 let abs_max = channel_weights
169 .iter()
170 .map(|x| x.abs())
171 .fold(0.0f32, |a, b| a.max(b));
172
173 let anti_res = pi_constants::anti_resonance(7);
174 let scale = (abs_max + anti_res) / 127.0;
175 scales.push(scale.max(1e-10));
176 }
177
178 Self {
179 scales,
180 zero_points,
181 num_channels: out_channels,
182 }
183 }
184
185 #[inline]
187 pub fn channel_params(&self, channel: usize) -> QuantParams {
188 QuantParams {
189 scale: self.scales[channel],
190 zero_point: self.zero_points[channel],
191 anti_resonance: pi_constants::anti_resonance(7),
192 bits: 7,
193 }
194 }
195}
196
/// A quantized tensor: i8 codes plus the parameters needed to dequantize.
#[derive(Debug, Clone)]
pub struct QuantizedTensor {
    /// Quantized element codes, row-major, matching `shape`.
    pub data: Vec<i8>,
    /// Logical tensor shape; `shape[0]` is the output-channel axis when
    /// per-channel parameters are used.
    pub shape: Vec<usize>,
    /// Which quantization scheme `data` was produced with.
    pub params: QuantizationType,
}
207
/// Granularity of the quantization parameters attached to a tensor.
#[derive(Debug, Clone)]
pub enum QuantizationType {
    /// One `QuantParams` shared by every element.
    PerTensor(QuantParams),
    /// Independent parameters per output channel.
    PerChannel(PerChannelQuantParams),
}
216
217impl QuantizedTensor {
218 pub fn from_float_symmetric(data: &[f32], shape: &[usize]) -> Self {
220 let min_val = data.iter().fold(f32::MAX, |a, &b| a.min(b));
221 let max_val = data.iter().fold(f32::MIN, |a, &b| a.max(b));
222 let params = QuantParams::symmetric(min_val, max_val);
223
224 let quantized: Vec<i8> = data.iter().map(|&v| params.quantize(v)).collect();
225
226 Self {
227 data: quantized,
228 shape: shape.to_vec(),
229 params: QuantizationType::PerTensor(params),
230 }
231 }
232
233 pub fn from_weights_per_channel(
235 weights: &[f32],
236 out_channels: usize,
237 in_channels: usize,
238 kernel_h: usize,
239 kernel_w: usize,
240 ) -> Self {
241 let per_channel = PerChannelQuantParams::symmetric_per_channel(weights, out_channels, in_channels);
242 let kernel_size = kernel_h * kernel_w;
243
244 let mut quantized = Vec::with_capacity(weights.len());
245
246 for oc in 0..out_channels {
247 let params = per_channel.channel_params(oc);
248 let start = oc * in_channels * kernel_size;
249 let end = start + in_channels * kernel_size;
250
251 for &w in &weights[start..end] {
252 quantized.push(params.quantize(w));
253 }
254 }
255
256 Self {
257 data: quantized,
258 shape: vec![out_channels, in_channels, kernel_h, kernel_w],
259 params: QuantizationType::PerChannel(per_channel),
260 }
261 }
262
263 pub fn dequantize(&self) -> Vec<f32> {
265 match &self.params {
266 QuantizationType::PerTensor(params) => {
267 self.data.iter().map(|&q| params.dequantize(q)).collect()
268 }
269 QuantizationType::PerChannel(per_channel) => {
270 let out_channels = self.shape[0];
271 let channel_size = self.data.len() / out_channels;
272 let mut output = Vec::with_capacity(self.data.len());
273
274 for oc in 0..out_channels {
275 let params = per_channel.channel_params(oc);
276 let start = oc * channel_size;
277 let end = start + channel_size;
278
279 for &q in &self.data[start..end] {
280 output.push(params.dequantize(q));
281 }
282 }
283 output
284 }
285 }
286 }
287
288 pub fn len(&self) -> usize {
290 self.data.len()
291 }
292
293 pub fn is_empty(&self) -> bool {
295 self.data.is_empty()
296 }
297}
298
299pub fn quantize_batch(input: &[f32], output: &mut [i8], params: &QuantParams) {
303 debug_assert_eq!(input.len(), output.len());
304
305 let inv_scale = 1.0 / params.scale;
306 let zp = params.zero_point as f32;
307 let anti_res = params.anti_resonance * 0.5;
308
309 for (i, &val) in input.iter().enumerate() {
310 let scaled = val * inv_scale + zp + anti_res;
311 output[i] = scaled.round().clamp(-128.0, 127.0) as i8;
312 }
313}
314
315pub fn dequantize_batch(input: &[i8], output: &mut [f32], params: &QuantParams) {
317 debug_assert_eq!(input.len(), output.len());
318
319 let zp = params.zero_point as f32;
320
321 for (i, &q) in input.iter().enumerate() {
322 output[i] = (q as f32 - zp) * params.scale;
323 }
324}
325
326#[cfg(target_arch = "x86_64")]
328#[target_feature(enable = "avx2")]
329pub unsafe fn quantize_batch_avx2(input: &[f32], output: &mut [i8], params: &QuantParams) {
330 let len = input.len();
331 let chunks = len / 8;
332
333 let inv_scale = _mm256_set1_ps(1.0 / params.scale);
334 let zp = _mm256_set1_ps(params.zero_point as f32);
335 let anti_res = _mm256_set1_ps(params.anti_resonance * 0.5);
336 let half = _mm256_set1_ps(0.5);
337 let min_val = _mm256_set1_ps(-128.0);
338 let max_val = _mm256_set1_ps(127.0);
339
340 for i in 0..chunks {
341 let offset = i * 8;
342
343 let v = _mm256_loadu_ps(input.as_ptr().add(offset));
345
346 let scaled = _mm256_add_ps(_mm256_mul_ps(v, inv_scale), zp);
348 let adjusted = _mm256_add_ps(scaled, anti_res);
349
350 let rounded = _mm256_round_ps(adjusted, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
352
353 let clamped = _mm256_min_ps(_mm256_max_ps(rounded, min_val), max_val);
355
356 let i32_vals = _mm256_cvtps_epi32(clamped);
358
359 let i32_array: [i32; 8] = std::mem::transmute(i32_vals);
361 for j in 0..8 {
362 output[offset + j] = i32_array[j] as i8;
363 }
364 }
365
366 let remainder_start = chunks * 8;
368 for i in remainder_start..len {
369 let scaled = input[i] / params.scale + params.zero_point as f32 + params.anti_resonance * 0.5;
370 output[i] = scaled.round().clamp(-128.0, 127.0) as i8;
371 }
372}
373
374#[cfg(target_arch = "x86_64")]
376#[target_feature(enable = "avx2")]
377pub unsafe fn dequantize_batch_avx2(input: &[i8], output: &mut [f32], params: &QuantParams) {
378 let len = input.len();
379 let chunks = len / 8;
380
381 let scale = _mm256_set1_ps(params.scale);
382 let zp = _mm256_set1_ps(params.zero_point as f32);
383
384 for i in 0..chunks {
385 let offset = i * 8;
386
387 let mut i32_array = [0i32; 8];
389 for j in 0..8 {
390 i32_array[j] = input[offset + j] as i32;
391 }
392 let i32_vals: __m256i = std::mem::transmute(i32_array);
393 let f32_vals = _mm256_cvtepi32_ps(i32_vals);
394
395 let shifted = _mm256_sub_ps(f32_vals, zp);
397 let result = _mm256_mul_ps(shifted, scale);
398
399 _mm256_storeu_ps(output.as_mut_ptr().add(offset), result);
400 }
401
402 let remainder_start = chunks * 8;
404 for i in remainder_start..len {
405 output[i] = (input[i] as f32 - params.zero_point as f32) * params.scale;
406 }
407}
408
/// Non-x86_64 fallback: delegates to the scalar implementation instead of
/// silently leaving `output` untouched (the previous empty body was a
/// footgun for anyone calling this function directly on other targets).
///
/// # Safety
///
/// Performs no unsafe operations; the `unsafe` marker only mirrors the
/// x86_64 signature so call sites compile unchanged on every target.
#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn quantize_batch_avx2(input: &[f32], output: &mut [i8], params: &QuantParams) {
    quantize_batch(input, output, params);
}
412
/// Non-x86_64 fallback: delegates to the scalar implementation instead of
/// silently leaving `output` untouched (the previous empty body was a
/// footgun for anyone calling this function directly on other targets).
///
/// # Safety
///
/// Performs no unsafe operations; the `unsafe` marker only mirrors the
/// x86_64 signature so call sites compile unchanged on every target.
#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn dequantize_batch_avx2(input: &[i8], output: &mut [f32], params: &QuantParams) {
    dequantize_batch(input, output, params);
}
415
/// Quantizes `input` into `output`, dispatching to the AVX2 kernel when the
/// running CPU supports it and falling back to the portable scalar loop
/// otherwise. Slices must be the same length.
#[inline(always)]
pub fn quantize_simd(input: &[f32], output: &mut [i8], params: &QuantParams) {
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            // SAFETY: the runtime feature check above guarantees AVX2 is
            // available, which is the kernel's only precondition.
            unsafe {
                quantize_batch_avx2(input, output, params);
            }
            return;
        }
    }
    quantize_batch(input, output, params);
}
430
/// Dequantizes `input` into `output`, dispatching to the AVX2 kernel when
/// the running CPU supports it and falling back to the portable scalar loop
/// otherwise. Slices must be the same length.
#[inline(always)]
pub fn dequantize_simd(input: &[i8], output: &mut [f32], params: &QuantParams) {
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            // SAFETY: the runtime feature check above guarantees AVX2 is
            // available, which is the kernel's only precondition.
            unsafe {
                dequantize_batch_avx2(input, output, params);
            }
            return;
        }
    }
    dequantize_batch(input, output, params);
}
445
#[cfg(test)]
mod tests {
    use super::*;
    // Fix: `PI` was unresolved in this module — `use super::*` does not
    // re-export the private `use std::f32::consts::PI` inside
    // `pi_constants`, so `test_pi_anti_resonance` failed to compile.
    use std::f32::consts::PI;

    #[test]
    fn test_symmetric_quantization() {
        let params = QuantParams::symmetric(-1.0, 1.0);

        let q = params.quantize(0.5);
        let dq = params.dequantize(q);

        assert!((0.5 - dq).abs() < 0.02);
    }

    #[test]
    fn test_asymmetric_quantization() {
        let params = QuantParams::asymmetric(0.0, 1.0);

        let q = params.quantize(0.5);
        let dq = params.dequantize(q);

        assert!((0.5 - dq).abs() < 0.02);
    }

    #[test]
    fn test_pi_anti_resonance() {
        let anti_res = pi_constants::anti_resonance(8);
        assert!(anti_res > 0.0);
        assert!(anti_res < 0.001);

        let expected = (PI - 3.0) / 256.0;
        assert!((anti_res - expected).abs() < 1e-10);
    }

    #[test]
    fn test_quantized_tensor_roundtrip() {
        let data = vec![0.1, 0.2, 0.3, 0.4, -0.1, -0.2, -0.3, -0.4];
        let shape = vec![2, 4];

        let quantized = QuantizedTensor::from_float_symmetric(&data, &shape);
        let dequantized = quantized.dequantize();

        for (original, recovered) in data.iter().zip(dequantized.iter()) {
            assert!((original - recovered).abs() < 0.02);
        }
    }

    #[test]
    fn test_per_channel_quantization() {
        let weights: Vec<f32> = (0..36).map(|i| (i as f32 - 18.0) * 0.1).collect();

        let quantized = QuantizedTensor::from_weights_per_channel(&weights, 2, 2, 3, 3);
        let dequantized = quantized.dequantize();

        let max_error: f32 = weights
            .iter()
            .zip(dequantized.iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f32, |a, b| a.max(b));

        assert!(max_error < 0.05);
    }

    #[test]
    fn test_batch_quantize() {
        let input = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8];
        let mut output = vec![0i8; 8];
        let params = QuantParams::symmetric(-1.0, 1.0);

        quantize_batch(&input, &mut output, &params);

        // All inputs are positive and the zero point is 0, so every code must
        // be strictly positive. (The old `q >= -128 && q <= 127` check was
        // vacuously true for i8 and triggered `unused_comparisons`.)
        for &q in &output {
            assert!(q > 0);
        }
    }

    #[test]
    fn test_batch_dequantize() {
        let input = vec![10i8, 20, 30, 40, -10, -20, -30, -40];
        let mut output = vec![0.0f32; 8];
        let params = QuantParams::symmetric(-1.0, 1.0);

        dequantize_batch(&input, &mut output, &params);

        assert!(output[0] > 0.0);
        assert!(output[4] < 0.0);
    }

    #[test]
    fn test_simd_dispatch() {
        // 16 identical inputs: both the AVX2 path (two full 8-lane chunks)
        // and the scalar fallback must produce a uniform output.
        let input = vec![0.1f32; 16];
        let mut output = vec![0i8; 16];
        let params = QuantParams::symmetric(-1.0, 1.0);

        quantize_simd(&input, &mut output, &params);

        let first = output[0];
        for &q in &output {
            assert_eq!(q, first);
        }
    }
}