Skip to main content

xz_embed/quantize/
scalar.rs

1use crate::quantize::VectorQuantizer;
2
/// Scalar quantization.
///
/// Compresses each `f32` component to a single `u8` code (one byte per
/// dimension, a 4:1 compression ratio versus `f32`), mapping values linearly
/// between a per-dimension `(min, max)` range and integer codes.
#[derive(Debug, Clone, PartialEq)]
pub struct ScalarQuantizer {
    /// Per-dimension `(min, max)` range used for encoding and decoding.
    ranges: Vec<(f32, f32)>,
    /// Number of quantization bits (8 = full `u8`). Codes are stored as `u8`,
    /// so widths above 8 cannot be represented.
    bits: usize,
}
13
14impl ScalarQuantizer {
15    pub fn new(ranges: Vec<(f32, f32)>, bits: usize) -> Self {
16        Self { ranges, bits }
17    }
18
19    /// 从样本计算各维度的 min/max
20    pub fn from_samples(samples: &[Vec<f32>], bits: usize) -> Self {
21        if samples.is_empty() {
22            return Self { ranges: vec![], bits };
23        }
24
25        let dim = samples[0].len();
26        let mut ranges = Vec::with_capacity(dim);
27
28        for d in 0..dim {
29            let mut min = f32::MAX;
30            let mut max = f32::MIN;
31            for sample in samples {
32                let val = sample[d];
33                if val < min { min = val; }
34                if val > max { max = val; }
35            }
36            ranges.push((min, max));
37        }
38
39        Self { ranges, bits }
40    }
41
42    fn quantize_value(&self, value: f32, min: f32, max: f32) -> u8 {
43        let range = max - min;
44        if range < f32::EPSILON {
45            return 0;
46        }
47        let normalized = (value - min) / range;
48        let max_val = (1u32 << self.bits) - 1;
49        (normalized * max_val as f32).round().clamp(0.0, max_val as f32) as u8
50    }
51
52    fn dequantize_value(&self, q: u8, min: f32, max: f32) -> f32 {
53        let max_val = (1u32 << self.bits) - 1;
54        let normalized = q as f32 / max_val as f32;
55        min + normalized * (max - min)
56    }
57}
58
59impl VectorQuantizer for ScalarQuantizer {
60    fn compress(&self, vectors: &[Vec<f32>]) -> Vec<Vec<u8>> {
61        if self.ranges.is_empty() {
62            return vec![];
63        }
64
65        vectors.iter().map(|v| {
66            v.iter().enumerate().map(|(d, &val)| {
67                let (min, max) = self.ranges[d];
68                self.quantize_value(val, min, max)
69            }).collect()
70        }).collect()
71    }
72
73    fn decompress(&self, quantized: &[Vec<u8>]) -> Vec<Vec<f32>> {
74        if self.ranges.is_empty() {
75            return vec![];
76        }
77
78        quantized.iter().map(|q| {
79            q.iter().enumerate().map(|(d, &qval)| {
80                let (min, max) = self.ranges[d];
81                self.dequantize_value(qval, min, max)
82            }).collect()
83        }).collect()
84    }
85}