Skip to main content

ruvector_cnn/quantize/
calibration.rs

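//! Calibration and Quantization Parameters (ADR-091 Phase 2)
//!
//! This module provides a fixed-range calibration histogram, symmetric INT8
//! quantization parameters derived from that histogram's range, and a
//! per-tensor quantizer for converting between FP32 and INT8.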
4
5use std::collections::HashMap;
6
/// Fixed-range histogram for collecting calibration samples.
///
/// The closed interval `[min_val, max_val]` is divided into equal-width bins;
/// `bins[i]` holds the number of recorded samples that landed in bin `i`.
#[derive(Debug, Clone)]
pub struct CalibrationHistogram {
    /// Lower bound of the tracked range (inclusive).
    pub min_val: f32,
    /// Upper bound of the tracked range (inclusive).
    pub max_val: f32,
    /// Number of bins requested at construction.
    pub num_bins: usize,
    /// Per-bin sample counts; created with `num_bins` zeroed entries.
    pub bins: Vec<u64>,
}

impl CalibrationHistogram {
    /// Creates an empty histogram over `[min_val, max_val]` with `num_bins` bins.
    ///
    /// `num_bins == 0` is accepted; such a histogram never records a sample.
    pub fn new(min_val: f32, max_val: f32, num_bins: usize) -> Self {
        Self {
            min_val,
            max_val,
            num_bins,
            bins: vec![0; num_bins],
        }
    }

    /// Records `value` in the bin that covers it.
    ///
    /// NaN and values outside `[min_val, max_val]` are ignored. `max_val` itself
    /// is counted in the last bin.
    pub fn add(&mut self, value: f32) {
        // NaN compares false against both bounds, so it needs its own check;
        // otherwise the cast below would turn it into index 0.
        if value.is_nan() || value < self.min_val || value > self.max_val {
            return;
        }

        // `bins` is the storage that gets indexed, so its length is the
        // authoritative bin count. An empty histogram has nowhere to record.
        let bin_count = self.bins.len();
        if bin_count == 0 {
            return;
        }

        let bin_width = (self.max_val - self.min_val) / bin_count as f32;
        // The float-to-integer cast saturates and maps NaN (zero-width range)
        // to 0; the `min` keeps `value == max_val` inside the last bin.
        let bin_idx = (((value - self.min_val) / bin_width) as usize).min(bin_count - 1);
        self.bins[bin_idx] += 1;
    }

    /// Derives symmetric INT8 quantization parameters from the histogram range.
    ///
    /// Only `min_val` and `max_val` are consulted; the bin counts are not. The
    /// scale maps the larger absolute bound to 127 and the zero point is always
    /// 0. When that scale would not be strictly positive (an all-zero or NaN
    /// range), `1.0` is returned instead so the scale can always be divided by.
    pub fn compute_quantization_params(&self) -> QuantizationParams {
        let abs_max = self.max_val.abs().max(self.min_val.abs());
        let raw_scale = abs_max / 127.0;
        // `raw_scale > 0.0` is false for both 0.0 and NaN.
        let scale = if raw_scale > 0.0 { raw_scale } else { 1.0 };

        QuantizationParams {
            scale,
            zero_point: 0, // Symmetric quantization
            min_val: self.min_val,
            max_val: self.max_val,
            num_bins: self.num_bins,
        }
    }
}

/// Quantization parameters for a tensor
#[derive(Debug, Clone, Copy)]
pub struct QuantizationParams {
    /// Real-valued size of one quantization step.
    pub scale: f32,
    /// Quantized value that corresponds to real `0.0`.
    pub zero_point: i32,
    /// Lower bound of the range the parameters were derived from.
    pub min_val: f32,
    /// Upper bound of the range the parameters were derived from.
    pub max_val: f32,
    /// Bin count of the histogram the parameters were derived from.
    pub num_bins: usize,
}
64
65/// Quantizer for converting FP32 to INT8 and back
66pub struct Quantizer {
67    params: HashMap<String, QuantizationParams>,
68}
69
70impl Quantizer {
71    pub fn new() -> Self {
72        Self {
73            params: HashMap::new(),
74        }
75    }
76
77    pub fn add_params(&mut self, name: String, params: QuantizationParams) {
78        self.params.insert(name, params);
79    }
80
81    pub fn quantize(&self, name: &str, values: &[f32]) -> Vec<i8> {
82        let params = self.params.get(name).expect("No params for tensor");
83        values
84            .iter()
85            .map(|&v| {
86                let q = (v / params.scale).round() as i32 + params.zero_point;
87                q.clamp(-128, 127) as i8
88            })
89            .collect()
90    }
91
92    pub fn dequantize(&self, name: &str, values: &[i8]) -> Vec<f32> {
93        let params = self.params.get(name).expect("No params for tensor");
94        values
95            .iter()
96            .map(|&v| (v as i32 - params.zero_point) as f32 * params.scale)
97            .collect()
98    }
99}
100
101impl Default for Quantizer {
102    fn default() -> Self {
103        Self::new()
104    }
105}
106
#[cfg(test)]
mod tests {
    use super::*;

    /// The symmetric scale follows the configured range and the zero point is 0.
    #[test]
    fn test_histogram_calibration() {
        let mut histogram = CalibrationHistogram::new(-10.0, 10.0, 100);

        // 100 samples at 5.0, then 50 samples at -5.0.
        (0..100).for_each(|_| histogram.add(5.0));
        (0..50).for_each(|_| histogram.add(-5.0));

        let QuantizationParams {
            scale, zero_point, ..
        } = histogram.compute_quantization_params();
        let expected_scale = 10.0 / 127.0;
        assert!((scale - expected_scale).abs() < 0.01);
        assert_eq!(zero_point, 0);
    }

    /// A quantize/dequantize round trip stays within the allowed error.
    #[test]
    fn test_quantizer() {
        let tensor_params = QuantizationParams {
            scale: 0.1,
            zero_point: 0,
            min_val: -12.8,
            max_val: 12.7,
            num_bins: 256,
        };
        let mut quantizer = Quantizer::new();
        quantizer.add_params("test".to_string(), tensor_params);

        let values = vec![0.0, 1.0, -1.0, 12.7, -12.8];
        let round_tripped = quantizer.dequantize("test", &quantizer.quantize("test", &values));

        // Allow some quantization error on every element.
        for (original, restored) in values.iter().zip(&round_tripped) {
            let error = (original - restored).abs();
            assert!(error < 0.2);
        }
    }
}