xz_embed/quantize/
product.rs1use crate::quantize::VectorQuantizer;
2
/// Product quantizer: each vector is split into `num_sub_vectors` contiguous, equal-length
/// slices, and each slice is encoded as the index (one `u8`) of its nearest centroid in a
/// per-slice codebook.
#[derive(Debug, Clone, PartialEq)]
pub struct ProductQuantizer {
    /// Number of contiguous, equal-length slices each vector is split into.
    num_sub_vectors: usize,
    /// The `bits` value used for training; each codebook holds at most `2^bits` centroids.
    bits_per_sub_vector: usize,
    /// `codebooks[s][c]` is centroid `c` of slice `s`; every centroid has the slice length
    /// `dim / num_sub_vectors` of the training data.
    codebooks: Vec<Vec<Vec<f32>>>,
}

14impl ProductQuantizer {
15 pub fn train(
17 samples: &[Vec<f32>],
18 num_sub_vectors: usize,
19 bits: usize,
20 ) -> Result<Self, String> {
21 if samples.is_empty() || num_sub_vectors == 0 {
22 return Err("samples 和 num_sub_vectors 必须非空".into());
23 }
24
25 let dim = samples[0].len();
26 let sub_dim = dim / num_sub_vectors;
27 if sub_dim == 0 || dim % num_sub_vectors != 0 {
28 return Err(format!("维度 {dim} 不能被 {num_sub_vectors} 整除"));
29 }
30
31 let num_clusters = 1usize << bits;
32 let mut codebooks = Vec::with_capacity(num_sub_vectors);
33
34 for s in 0..num_sub_vectors {
35 let start = s * sub_dim;
37 let end = start + sub_dim;
38 let sub_samples: Vec<Vec<f32>> = samples
39 .iter()
40 .map(|v| v[start..end].to_vec())
41 .collect();
42
43 let centroids = Self::kmeans_simple(&sub_samples, num_clusters, 10);
45 codebooks.push(centroids);
46 }
47
48 Ok(Self {
49 num_sub_vectors,
50 bits_per_sub_vector: bits,
51 codebooks,
52 })
53 }
54
55 fn kmeans_simple(samples: &[Vec<f32>], k: usize, iterations: usize) -> Vec<Vec<f32>> {
56 if samples.len() <= k {
57 return samples.to_vec();
58 }
59
60 let dim = samples[0].len();
61 let mut centroids: Vec<Vec<f32>> = samples.iter().take(k).cloned().collect();
63 let mut assignments = vec![0usize; samples.len()];
64
65 for _ in 0..iterations {
66 for (i, sample) in samples.iter().enumerate() {
68 let mut min_dist = f32::MAX;
69 let mut best = 0;
70 for (j, centroid) in centroids.iter().enumerate() {
71 let dist: f32 = sample
72 .iter()
73 .zip(centroid)
74 .map(|(a, b)| (a - b).powi(2))
75 .sum();
76 if dist < min_dist {
77 min_dist = dist;
78 best = j;
79 }
80 }
81 assignments[i] = best;
82 }
83
84 for j in 0..k {
86 let members: Vec<&Vec<f32>> = samples
87 .iter()
88 .enumerate()
89 .filter(|(i, _)| assignments[*i] == j)
90 .map(|(_, v)| v)
91 .collect();
92
93 if !members.is_empty() {
94 let mut new_centroid = vec![0.0f32; dim];
95 for member in &members {
96 for d in 0..dim {
97 new_centroid[d] += member[d];
98 }
99 }
100 for d in 0..dim {
101 new_centroid[d] /= members.len() as f32;
102 }
103 centroids[j] = new_centroid;
104 }
105 }
106 }
107
108 centroids
109 }
110
111 fn nearest_codebook(&self, sub_vector: &[f32], codebook: &[Vec<f32>]) -> u8 {
113 let mut min_dist = f32::MAX;
114 let mut best = 0u8;
115 for (i, centroid) in codebook.iter().enumerate() {
116 let dist: f32 = sub_vector
117 .iter()
118 .zip(centroid)
119 .map(|(a, b)| (a - b).powi(2))
120 .sum();
121 if dist < min_dist {
122 min_dist = dist;
123 best = i as u8;
124 }
125 }
126 best
127 }
128}
129
130impl VectorQuantizer for ProductQuantizer {
131 fn compress(&self, vectors: &[Vec<f32>]) -> Vec<Vec<u8>> {
132 let dim = vectors.first().map(|v| v.len()).unwrap_or(0);
133 let sub_dim = dim / self.num_sub_vectors;
134 if sub_dim == 0 { return vec![]; }
135
136 vectors.iter().map(|v| {
137 let mut code = vec![0u8; self.num_sub_vectors];
138 for s in 0..self.num_sub_vectors {
139 let start = s * sub_dim;
140 let end = start + sub_dim;
141 code[s] = self.nearest_codebook(&v[start..end], &self.codebooks[s]);
142 }
143 code
144 }).collect()
145 }
146
147 fn decompress(&self, quantized: &[Vec<u8>]) -> Vec<Vec<f32>> {
148 let sub_dim = self.codebooks.first().map(|c| c.first().map(|v| v.len()).unwrap_or(0)).unwrap_or(0);
149 let dim = self.num_sub_vectors * sub_dim;
150
151 quantized.iter().map(|code| {
152 let mut v = vec![0.0f32; dim];
153 for (s, &c) in code.iter().enumerate().take(self.num_sub_vectors) {
154 let start = s * sub_dim;
155 let end = start + sub_dim;
156 if let Some(centroid) = self.codebooks[s].get(c as usize) {
157 v[start..end].copy_from_slice(centroid);
158 }
159 }
160 v
161 }).collect()
162 }
163}