1pub mod kernels;
7
8use serde::{Deserialize, Serialize};
9
10#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
12pub struct QuantParams {
13 pub scale: f32,
15 pub zero_point: i8,
17}
18
19impl QuantParams {
20 pub fn from_tensor(tensor: &[f32]) -> Self {
22 if tensor.is_empty() {
23 return Self {
24 scale: 1.0,
25 zero_point: 0,
26 };
27 }
28
29 let min_val = tensor.iter().copied().fold(f32::INFINITY, f32::min);
30 let max_val = tensor.iter().copied().fold(f32::NEG_INFINITY, f32::max);
31
32 if (max_val - min_val).abs() < 1e-10 {
34 return Self {
35 scale: 1.0,
36 zero_point: 0,
37 };
38 }
39
40 let scale = (max_val - min_val) / 255.0;
43 let zero_point = (-128.0 - min_val / scale).round() as i8;
44
45 Self { scale, zero_point }
46 }
47}
48
49pub fn quantize_tensor(fp32: &[f32], params: &QuantParams) -> Vec<i8> {
51 fp32.iter()
52 .map(|&x| {
53 let quantized = (x / params.scale + params.zero_point as f32).round();
54 quantized.clamp(-128.0, 127.0) as i8
55 })
56 .collect()
57}
58
59pub fn dequantize_tensor(int8: &[i8], params: &QuantParams) -> Vec<f32> {
61 int8.iter()
62 .map(|&x| (x as f32 - params.zero_point as f32) * params.scale)
63 .collect()
64}
65
66#[cfg(test)]
67mod tests {
68 use super::*;
69
70 #[test]
71 fn test_quant_params_from_tensor() {
72 let tensor = vec![-1.0, 0.0, 1.0, 2.0];
73 let params = QuantParams::from_tensor(&tensor);
74
75 assert!(params.scale > 0.0);
76 assert!(params.scale.is_finite());
77 assert!(params.zero_point >= -128 && params.zero_point <= 127);
78 }
79
80 #[test]
81 fn test_quantize_dequantize_roundtrip() {
82 let fp32 = vec![0.5, -0.3, 0.8, -0.1];
83 let params = QuantParams::from_tensor(&fp32);
84 let int8 = quantize_tensor(&fp32, ¶ms);
85 let dequant = dequantize_tensor(&int8, ¶ms);
86
87 for (orig, recovered) in fp32.iter().zip(dequant.iter()) {
88 let error = (orig - recovered).abs();
89 assert!(error < 0.1, "Roundtrip error too large: {}", error);
90 }
91 }
92
93 #[test]
94 fn test_empty_tensor() {
95 let empty: Vec<f32> = vec![];
96 let params = QuantParams::from_tensor(&empty);
97 assert_eq!(params.scale, 1.0);
98 assert_eq!(params.zero_point, 0);
99 }
100
101 #[test]
102 fn test_constant_tensor() {
103 let constant = vec![0.5; 100];
104 let params = QuantParams::from_tensor(&constant);
105 assert_eq!(params.scale, 1.0);
106 assert_eq!(params.zero_point, 0);
107 }
108}