1use crate::distance::l2_squared;
7use crate::error::{DiskAnnError, Result};
8use bincode::{Decode, Encode};
9use rand::prelude::*;
10use serde::{Deserialize, Serialize};
11
12#[derive(Clone, Serialize, Deserialize, Encode, Decode)]
14pub struct ProductQuantizer {
15 pub m: usize,
17 pub dsub: usize,
19 pub dim: usize,
21 pub centroids: Vec<Vec<Vec<f32>>>,
23 pub trained: bool,
25}
26
27impl ProductQuantizer {
28 pub fn new(dim: usize, m: usize) -> Result<Self> {
30 if dim % m != 0 {
31 return Err(DiskAnnError::InvalidConfig(format!(
32 "dim ({dim}) must be divisible by m ({m})"
33 )));
34 }
35 let dsub = dim / m;
36 Ok(Self {
37 m,
38 dsub,
39 dim,
40 centroids: Vec::new(),
41 trained: false,
42 })
43 }
44
45 pub fn train(&mut self, vectors: &[Vec<f32>], iterations: usize) -> Result<()> {
47 if vectors.is_empty() {
48 return Err(DiskAnnError::Empty);
49 }
50 if vectors[0].len() != self.dim {
51 return Err(DiskAnnError::DimensionMismatch {
52 expected: self.dim,
53 actual: vectors[0].len(),
54 });
55 }
56
57 let k = 256usize; let n = vectors.len();
59 let mut rng = rand::thread_rng();
60
61 self.centroids = Vec::with_capacity(self.m);
62
63 for sub in 0..self.m {
64 let offset = sub * self.dsub;
65
66 let subvectors: Vec<&[f32]> = vectors
68 .iter()
69 .map(|v| &v[offset..offset + self.dsub])
70 .collect();
71
72 let mut centers = Vec::with_capacity(k);
74 centers.push(subvectors[rng.gen_range(0..n)].to_vec());
75
76 for _ in 1..k.min(n) {
77 let dists: Vec<f32> = subvectors
79 .iter()
80 .map(|sv| {
81 centers
82 .iter()
83 .map(|c| l2_squared(sv, c))
84 .fold(f32::MAX, f32::min)
85 })
86 .collect();
87
88 let total: f32 = dists.iter().sum();
89 if total < 1e-10 {
90 while centers.len() < k {
92 centers.push(centers[0].clone());
93 }
94 break;
95 }
96
97 let mut r = rng.gen::<f32>() * total;
99 for (i, &d) in dists.iter().enumerate() {
100 r -= d;
101 if r <= 0.0 {
102 centers.push(subvectors[i].to_vec());
103 break;
104 }
105 }
106 if centers.len() < k && r > 0.0 {
107 centers.push(subvectors[rng.gen_range(0..n)].to_vec());
108 }
109 }
110
111 while centers.len() < k {
113 centers.push(centers[rng.gen_range(0..centers.len())].clone());
114 }
115
116 let mut assignments = vec![0u8; n];
118 for _ in 0..iterations {
119 for (i, sv) in subvectors.iter().enumerate() {
121 let mut best_dist = f32::MAX;
122 let mut best_c = 0u8;
123 for (c, center) in centers.iter().enumerate() {
124 let d = l2_squared(sv, center);
125 if d < best_dist {
126 best_dist = d;
127 best_c = c as u8;
128 }
129 }
130 assignments[i] = best_c;
131 }
132
133 let mut counts = vec![0usize; k];
135 let mut sums = vec![vec![0.0f32; self.dsub]; k];
136
137 for (i, &a) in assignments.iter().enumerate() {
138 let ci = a as usize;
139 counts[ci] += 1;
140 for d in 0..self.dsub {
141 sums[ci][d] += subvectors[i][d];
142 }
143 }
144
145 for c in 0..k {
146 if counts[c] > 0 {
147 for d in 0..self.dsub {
148 centers[c][d] = sums[c][d] / counts[c] as f32;
149 }
150 }
151 }
152 }
153
154 self.centroids.push(centers);
155 }
156
157 self.trained = true;
158 Ok(())
159 }
160
161 pub fn encode(&self, vector: &[f32]) -> Result<Vec<u8>> {
163 if !self.trained {
164 return Err(DiskAnnError::PqNotTrained);
165 }
166 if vector.len() != self.dim {
167 return Err(DiskAnnError::DimensionMismatch {
168 expected: self.dim,
169 actual: vector.len(),
170 });
171 }
172
173 let mut codes = Vec::with_capacity(self.m);
174 for sub in 0..self.m {
175 let offset = sub * self.dsub;
176 let subvec = &vector[offset..offset + self.dsub];
177
178 let mut best_dist = f32::MAX;
179 let mut best_c = 0u8;
180 for (c, center) in self.centroids[sub].iter().enumerate() {
181 let d = l2_squared(subvec, center);
182 if d < best_dist {
183 best_dist = d;
184 best_c = c as u8;
185 }
186 }
187 codes.push(best_c);
188 }
189 Ok(codes)
190 }
191
192 pub fn build_distance_table(&self, query: &[f32]) -> Result<Vec<f32>> {
195 if !self.trained {
196 return Err(DiskAnnError::PqNotTrained);
197 }
198 if query.len() != self.dim {
199 return Err(DiskAnnError::DimensionMismatch {
200 expected: self.dim,
201 actual: query.len(),
202 });
203 }
204
205 let k = 256;
206 let mut table = vec![0.0f32; self.m * k];
207 for sub in 0..self.m {
208 let offset = sub * self.dsub;
209 let subquery = &query[offset..offset + self.dsub];
210
211 for (c, center) in self.centroids[sub].iter().enumerate() {
212 table[sub * k + c] = l2_squared(subquery, center);
213 }
214 }
215 Ok(table)
216 }
217
218 #[inline]
220 pub fn distance_with_table(&self, codes: &[u8], table: &[f32]) -> f32 {
221 crate::distance::pq_asymmetric_distance(codes, table, 256)
222 }
223}
224
#[cfg(test)]
mod tests {
    use super::*;

    /// 100 distinct 8-dimensional training vectors.
    fn sample_vectors() -> Vec<Vec<f32>> {
        (0..100)
            .map(|i| (0..8).map(|d| (i * 7 + d) as f32 / 100.0).collect())
            .collect()
    }

    #[test]
    fn test_pq_train_encode() {
        let mut pq = ProductQuantizer::new(8, 2).unwrap();
        let vectors = sample_vectors();
        pq.train(&vectors, 5).unwrap();

        let codes = pq.encode(&vectors[0]).unwrap();
        assert_eq!(codes.len(), 2);

        let table = pq.build_distance_table(&vectors[0]).unwrap();
        let dist = pq.distance_with_table(&codes, &table);
        assert!(dist < 0.1, "self-distance was {dist}");
    }

    #[test]
    fn new_rejects_indivisible_dim() {
        assert!(matches!(
            ProductQuantizer::new(10, 3),
            Err(DiskAnnError::InvalidConfig(_))
        ));
    }

    #[test]
    fn train_rejects_empty_input() {
        let mut pq = ProductQuantizer::new(8, 2).unwrap();
        assert!(matches!(pq.train(&[], 5), Err(DiskAnnError::Empty)));
        assert!(!pq.trained);
    }

    #[test]
    fn train_rejects_wrong_dimension() {
        let mut pq = ProductQuantizer::new(8, 2).unwrap();
        let vectors = vec![vec![0.0f32; 3]];
        assert!(matches!(
            pq.train(&vectors, 5),
            Err(DiskAnnError::DimensionMismatch {
                expected: 8,
                actual: 3
            })
        ));
        assert!(!pq.trained);
    }

    #[test]
    fn untrained_quantizer_refuses_queries() {
        let pq = ProductQuantizer::new(8, 2).unwrap();
        let query = vec![0.0f32; 8];
        assert!(matches!(pq.encode(&query), Err(DiskAnnError::PqNotTrained)));
        assert!(matches!(
            pq.build_distance_table(&query),
            Err(DiskAnnError::PqNotTrained)
        ));
    }

    #[test]
    fn trained_quantizer_rejects_wrong_query_length() {
        let mut pq = ProductQuantizer::new(8, 2).unwrap();
        pq.train(&sample_vectors(), 2).unwrap();
        let short = vec![0.0f32; 5];
        assert!(matches!(
            pq.encode(&short),
            Err(DiskAnnError::DimensionMismatch {
                expected: 8,
                actual: 5
            })
        ));
        assert!(matches!(
            pq.build_distance_table(&short),
            Err(DiskAnnError::DimensionMismatch {
                expected: 8,
                actual: 5
            })
        ));
    }
}