// ruvector_cnn/quantize/calibration.rs
use std::collections::HashMap;
/// Fixed-range histogram of calibration samples.
///
/// Samples inside `[min_val, max_val]` are counted into `num_bins`
/// equal-width bins; samples outside the range (and NaN) are ignored.
#[derive(Debug, Clone)]
pub struct CalibrationHistogram {
    /// Lower edge of the histogram range (inclusive).
    pub min_val: f32,
    /// Upper edge of the histogram range (inclusive).
    pub max_val: f32,
    /// Number of equal-width bins spanning `[min_val, max_val]`.
    pub num_bins: usize,
    /// Per-bin sample counts; `num_bins` entries after [`CalibrationHistogram::new`].
    pub bins: Vec<u64>,
}

impl CalibrationHistogram {
    /// Creates an empty histogram over `[min_val, max_val]` with `num_bins` bins.
    ///
    /// A histogram with `num_bins == 0` is allowed but counts nothing.
    pub fn new(min_val: f32, max_val: f32, num_bins: usize) -> Self {
        Self {
            min_val,
            max_val,
            num_bins,
            bins: vec![0; num_bins],
        }
    }

    /// Records one sample.
    ///
    /// Values outside `[min_val, max_val]` and NaN are silently dropped; a
    /// sample equal to `max_val` is counted in the last bin.
    pub fn add(&mut self, value: f32) {
        // Written as a negated "inside the range" test so NaN (for which every
        // comparison is false) is rejected instead of landing in bin 0.
        if !(value >= self.min_val && value <= self.max_val) {
            return;
        }

        // Index of the last usable bin; bail out when there are no bins.
        // Bounded by `bins.len()` as well because both fields are public and
        // could have been changed independently.
        let last_bin = match self.num_bins.min(self.bins.len()).checked_sub(1) {
            Some(last) => last,
            None => return,
        };

        let bin_width = (self.max_val - self.min_val) / self.num_bins as f32;
        // For a zero-width range this is 0.0 / 0.0 = NaN, which `as usize`
        // maps to 0, so every sample goes to the first bin.
        let bin_idx = ((value - self.min_val) / bin_width) as usize;
        self.bins[bin_idx.min(last_bin)] += 1;
    }

    /// Derives symmetric (zero-point 0) INT8 parameters from the histogram range.
    ///
    /// The scale maps `max(|min_val|, |max_val|)` to code 127. Only the
    /// configured range is used; the bin counts are not consulted.
    ///
    /// If that magnitude is zero or NaN, the scale falls back to `1.0` so that
    /// downstream `x / scale` never divides by zero (real `0.0` still maps to
    /// code 0 and back).
    pub fn compute_quantization_params(&self) -> QuantizationParams {
        let abs_max = self.max_val.abs().max(self.min_val.abs());
        let scale = if abs_max > 0.0 { abs_max / 127.0 } else { 1.0 };
        QuantizationParams {
            scale,
            zero_point: 0,
            min_val: self.min_val,
            max_val: self.max_val,
            num_bins: self.num_bins,
        }
    }
}

/// Per-tensor affine quantization parameters.
///
/// Quantization computes `round(x / scale) + zero_point` (clamped to the
/// `i8` range) and dequantization computes `(q - zero_point) * scale`.
#[derive(Debug, Clone, Copy)]
pub struct QuantizationParams {
    /// Real-valued step between adjacent integer codes.
    pub scale: f32,
    /// Integer code that represents real `0.0`.
    pub zero_point: i32,
    /// Lower edge of the calibration range these parameters came from.
    pub min_val: f32,
    /// Upper edge of the calibration range these parameters came from.
    pub max_val: f32,
    /// Bin count of the calibration histogram these parameters came from.
    pub num_bins: usize,
}

65pub struct Quantizer {
67 params: HashMap<String, QuantizationParams>,
68}
69
70impl Quantizer {
71 pub fn new() -> Self {
72 Self {
73 params: HashMap::new(),
74 }
75 }
76
77 pub fn add_params(&mut self, name: String, params: QuantizationParams) {
78 self.params.insert(name, params);
79 }
80
81 pub fn quantize(&self, name: &str, values: &[f32]) -> Vec<i8> {
82 let params = self.params.get(name).expect("No params for tensor");
83 values
84 .iter()
85 .map(|&v| {
86 let q = (v / params.scale).round() as i32 + params.zero_point;
87 q.clamp(-128, 127) as i8
88 })
89 .collect()
90 }
91
92 pub fn dequantize(&self, name: &str, values: &[i8]) -> Vec<f32> {
93 let params = self.params.get(name).expect("No params for tensor");
94 values
95 .iter()
96 .map(|&v| (v as i32 - params.zero_point) as f32 * params.scale)
97 .collect()
98 }
99}
100
101impl Default for Quantizer {
102 fn default() -> Self {
103 Self::new()
104 }
105}
106
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_histogram_calibration() {
        let mut histogram = CalibrationHistogram::new(-10.0, 10.0, 100);

        // Skewed sample set: 100 positive samples, then 50 negative ones.
        std::iter::repeat(5.0)
            .take(100)
            .chain(std::iter::repeat(-5.0).take(50))
            .for_each(|sample| histogram.add(sample));

        let params = histogram.compute_quantization_params();
        // Symmetric INT8: the larger range edge (10.0) maps to code 127.
        assert!((params.scale - 10.0 / 127.0).abs() < 0.01);
        assert_eq!(params.zero_point, 0);
    }

    #[test]
    fn test_quantizer() {
        let tensor_params = QuantizationParams {
            scale: 0.1,
            zero_point: 0,
            min_val: -12.8,
            max_val: 12.7,
            num_bins: 256,
        };
        let mut quantizer = Quantizer::new();
        quantizer.add_params(String::from("test"), tensor_params);

        // Includes both ends of the representable range (codes 127 and -128).
        let originals: [f32; 5] = [0.0, 1.0, -1.0, 12.7, -12.8];
        let codes = quantizer.quantize("test", &originals);
        let restored = quantizer.dequantize("test", &codes);

        // The round trip must stay within two quantization steps (2 * 0.1).
        for (expected, actual) in originals.iter().zip(restored.iter()) {
            assert!((expected - actual).abs() < 0.2);
        }
    }
}