Skip to main content

xz_embed/quantize/
product.rs

1use crate::quantize::VectorQuantizer;
2
/// Product quantizer (PQ).
///
/// A vector of dimension `D` is split into `num_sub_vectors` contiguous
/// sub-vectors of length `D / num_sub_vectors`; each sub-vector is encoded as
/// the index of its nearest centroid in the codebook for that position.
///
/// The fields are private; instances are built with `ProductQuantizer::train`.
#[derive(Debug, Clone, PartialEq)]
pub struct ProductQuantizer {
    /// Number of sub-vectors each input vector is split into.
    num_sub_vectors: usize,
    /// Code width in bits per sub-vector, as requested at training time.
    bits_per_sub_vector: usize,
    /// Trained codebooks, indexed as `[sub_vec_idx][code_idx][dim]`.
    codebooks: Vec<Vec<Vec<f32>>>,
}
13
14impl ProductQuantizer {
15    /// 从样本训练 PQ
16    pub fn train(
17        samples: &[Vec<f32>],
18        num_sub_vectors: usize,
19        bits: usize,
20    ) -> Result<Self, String> {
21        if samples.is_empty() || num_sub_vectors == 0 {
22            return Err("samples 和 num_sub_vectors 必须非空".into());
23        }
24
25        let dim = samples[0].len();
26        let sub_dim = dim / num_sub_vectors;
27        if sub_dim == 0 || dim % num_sub_vectors != 0 {
28            return Err(format!("维度 {dim} 不能被 {num_sub_vectors} 整除"));
29        }
30
31        let num_clusters = 1usize << bits;
32        let mut codebooks = Vec::with_capacity(num_sub_vectors);
33
34        for s in 0..num_sub_vectors {
35            // 提取子向量
36            let start = s * sub_dim;
37            let end = start + sub_dim;
38            let sub_samples: Vec<Vec<f32>> = samples
39                .iter()
40                .map(|v| v[start..end].to_vec())
41                .collect();
42
43            // 简单 K-means 聚类(随机初始中心 + 迭代)
44            let centroids = Self::kmeans_simple(&sub_samples, num_clusters, 10);
45            codebooks.push(centroids);
46        }
47
48        Ok(Self {
49            num_sub_vectors,
50            bits_per_sub_vector: bits,
51            codebooks,
52        })
53    }
54
55    fn kmeans_simple(samples: &[Vec<f32>], k: usize, iterations: usize) -> Vec<Vec<f32>> {
56        if samples.len() <= k {
57            return samples.to_vec();
58        }
59
60        let dim = samples[0].len();
61        // 随机选 k 个样本作为初始中心
62        let mut centroids: Vec<Vec<f32>> = samples.iter().take(k).cloned().collect();
63        let mut assignments = vec![0usize; samples.len()];
64
65        for _ in 0..iterations {
66            // 分配最近中心
67            for (i, sample) in samples.iter().enumerate() {
68                let mut min_dist = f32::MAX;
69                let mut best = 0;
70                for (j, centroid) in centroids.iter().enumerate() {
71                    let dist: f32 = sample
72                        .iter()
73                        .zip(centroid)
74                        .map(|(a, b)| (a - b).powi(2))
75                        .sum();
76                    if dist < min_dist {
77                        min_dist = dist;
78                        best = j;
79                    }
80                }
81                assignments[i] = best;
82            }
83
84            // 更新中心
85            for j in 0..k {
86                let members: Vec<&Vec<f32>> = samples
87                    .iter()
88                    .enumerate()
89                    .filter(|(i, _)| assignments[*i] == j)
90                    .map(|(_, v)| v)
91                    .collect();
92
93                if !members.is_empty() {
94                    let mut new_centroid = vec![0.0f32; dim];
95                    for member in &members {
96                        for d in 0..dim {
97                            new_centroid[d] += member[d];
98                        }
99                    }
100                    for d in 0..dim {
101                        new_centroid[d] /= members.len() as f32;
102                    }
103                    centroids[j] = new_centroid;
104                }
105            }
106        }
107
108        centroids
109    }
110
111    /// 最近码字搜索
112    fn nearest_codebook(&self, sub_vector: &[f32], codebook: &[Vec<f32>]) -> u8 {
113        let mut min_dist = f32::MAX;
114        let mut best = 0u8;
115        for (i, centroid) in codebook.iter().enumerate() {
116            let dist: f32 = sub_vector
117                .iter()
118                .zip(centroid)
119                .map(|(a, b)| (a - b).powi(2))
120                .sum();
121            if dist < min_dist {
122                min_dist = dist;
123                best = i as u8;
124            }
125        }
126        best
127    }
128}
129
130impl VectorQuantizer for ProductQuantizer {
131    fn compress(&self, vectors: &[Vec<f32>]) -> Vec<Vec<u8>> {
132        let dim = vectors.first().map(|v| v.len()).unwrap_or(0);
133        let sub_dim = dim / self.num_sub_vectors;
134        if sub_dim == 0 { return vec![]; }
135
136        vectors.iter().map(|v| {
137            let mut code = vec![0u8; self.num_sub_vectors];
138            for s in 0..self.num_sub_vectors {
139                let start = s * sub_dim;
140                let end = start + sub_dim;
141                code[s] = self.nearest_codebook(&v[start..end], &self.codebooks[s]);
142            }
143            code
144        }).collect()
145    }
146
147    fn decompress(&self, quantized: &[Vec<u8>]) -> Vec<Vec<f32>> {
148        let sub_dim = self.codebooks.first().map(|c| c.first().map(|v| v.len()).unwrap_or(0)).unwrap_or(0);
149        let dim = self.num_sub_vectors * sub_dim;
150
151        quantized.iter().map(|code| {
152            let mut v = vec![0.0f32; dim];
153            for (s, &c) in code.iter().enumerate().take(self.num_sub_vectors) {
154                let start = s * sub_dim;
155                let end = start + sub_dim;
156                if let Some(centroid) = self.codebooks[s].get(c as usize) {
157                    v[start..end].copy_from_slice(centroid);
158                }
159            }
160            v
161        }).collect()
162    }
163}