Skip to main content

ruvector_diskann/
pq.rs

1//! Product Quantization for compressed distance computation
2//!
3//! Splits D-dimensional vectors into M subspaces of D/M dimensions each,
4//! then quantizes each subspace independently using k-means (k=256 centroids).
5
6use crate::distance::l2_squared;
7use crate::error::{DiskAnnError, Result};
8use rand::prelude::*;
9use bincode::{Decode, Encode};
10use serde::{Deserialize, Serialize};
11
12/// Product Quantizer with M subspaces, 256 centroids each (1 byte per subspace)
13#[derive(Clone, Serialize, Deserialize, Encode, Decode)]
14pub struct ProductQuantizer {
15    /// Number of subspaces
16    pub m: usize,
17    /// Dimensions per subspace
18    pub dsub: usize,
19    /// Total dimensions
20    pub dim: usize,
21    /// Centroids: [m][256][dsub]
22    pub centroids: Vec<Vec<Vec<f32>>>,
23    /// Whether the PQ has been trained
24    pub trained: bool,
25}
26
27impl ProductQuantizer {
28    /// Create a new PQ with M subspaces for D-dimensional vectors
29    pub fn new(dim: usize, m: usize) -> Result<Self> {
30        if dim % m != 0 {
31            return Err(DiskAnnError::InvalidConfig(format!(
32                "dim ({dim}) must be divisible by m ({m})"
33            )));
34        }
35        let dsub = dim / m;
36        Ok(Self {
37            m,
38            dsub,
39            dim,
40            centroids: Vec::new(),
41            trained: false,
42        })
43    }
44
45    /// Train PQ centroids using k-means on training vectors
46    pub fn train(&mut self, vectors: &[Vec<f32>], iterations: usize) -> Result<()> {
47        if vectors.is_empty() {
48            return Err(DiskAnnError::Empty);
49        }
50        if vectors[0].len() != self.dim {
51            return Err(DiskAnnError::DimensionMismatch {
52                expected: self.dim,
53                actual: vectors[0].len(),
54            });
55        }
56
57        let k = 256usize; // 1 byte per code
58        let n = vectors.len();
59        let mut rng = rand::thread_rng();
60
61        self.centroids = Vec::with_capacity(self.m);
62
63        for sub in 0..self.m {
64            let offset = sub * self.dsub;
65
66            // Extract subvectors for this subspace
67            let subvectors: Vec<&[f32]> = vectors
68                .iter()
69                .map(|v| &v[offset..offset + self.dsub])
70                .collect();
71
72            // Initialize centroids with k-means++ seeding
73            let mut centers = Vec::with_capacity(k);
74            centers.push(subvectors[rng.gen_range(0..n)].to_vec());
75
76            for _ in 1..k.min(n) {
77                // Compute min distance from each point to nearest center
78                let dists: Vec<f32> = subvectors
79                    .iter()
80                    .map(|sv| {
81                        centers
82                            .iter()
83                            .map(|c| l2_squared(sv, c))
84                            .fold(f32::MAX, f32::min)
85                    })
86                    .collect();
87
88                let total: f32 = dists.iter().sum();
89                if total < 1e-10 {
90                    // All points are identical, fill remaining with the same
91                    while centers.len() < k {
92                        centers.push(centers[0].clone());
93                    }
94                    break;
95                }
96
97                // Weighted random selection
98                let mut r = rng.gen::<f32>() * total;
99                for (i, &d) in dists.iter().enumerate() {
100                    r -= d;
101                    if r <= 0.0 {
102                        centers.push(subvectors[i].to_vec());
103                        break;
104                    }
105                }
106                if centers.len() < k && r > 0.0 {
107                    centers.push(subvectors[rng.gen_range(0..n)].to_vec());
108                }
109            }
110
111            // Pad if fewer training points than k
112            while centers.len() < k {
113                centers.push(centers[rng.gen_range(0..centers.len())].clone());
114            }
115
116            // Lloyd's iterations
117            let mut assignments = vec![0u8; n];
118            for _ in 0..iterations {
119                // Assign
120                for (i, sv) in subvectors.iter().enumerate() {
121                    let mut best_dist = f32::MAX;
122                    let mut best_c = 0u8;
123                    for (c, center) in centers.iter().enumerate() {
124                        let d = l2_squared(sv, center);
125                        if d < best_dist {
126                            best_dist = d;
127                            best_c = c as u8;
128                        }
129                    }
130                    assignments[i] = best_c;
131                }
132
133                // Update centroids
134                let mut counts = vec![0usize; k];
135                let mut sums = vec![vec![0.0f32; self.dsub]; k];
136
137                for (i, &a) in assignments.iter().enumerate() {
138                    let ci = a as usize;
139                    counts[ci] += 1;
140                    for d in 0..self.dsub {
141                        sums[ci][d] += subvectors[i][d];
142                    }
143                }
144
145                for c in 0..k {
146                    if counts[c] > 0 {
147                        for d in 0..self.dsub {
148                            centers[c][d] = sums[c][d] / counts[c] as f32;
149                        }
150                    }
151                }
152            }
153
154            self.centroids.push(centers);
155        }
156
157        self.trained = true;
158        Ok(())
159    }
160
161    /// Encode a vector into PQ codes (M bytes)
162    pub fn encode(&self, vector: &[f32]) -> Result<Vec<u8>> {
163        if !self.trained {
164            return Err(DiskAnnError::PqNotTrained);
165        }
166        if vector.len() != self.dim {
167            return Err(DiskAnnError::DimensionMismatch {
168                expected: self.dim,
169                actual: vector.len(),
170            });
171        }
172
173        let mut codes = Vec::with_capacity(self.m);
174        for sub in 0..self.m {
175            let offset = sub * self.dsub;
176            let subvec = &vector[offset..offset + self.dsub];
177
178            let mut best_dist = f32::MAX;
179            let mut best_c = 0u8;
180            for (c, center) in self.centroids[sub].iter().enumerate() {
181                let d = l2_squared(subvec, center);
182                if d < best_dist {
183                    best_dist = d;
184                    best_c = c as u8;
185                }
186            }
187            codes.push(best_c);
188        }
189        Ok(codes)
190    }
191
192    /// Build flat asymmetric distance table for a query vector
193    /// Returns flat table[subspace * 256 + centroid_id] = distance
194    pub fn build_distance_table(&self, query: &[f32]) -> Result<Vec<f32>> {
195        if !self.trained {
196            return Err(DiskAnnError::PqNotTrained);
197        }
198        if query.len() != self.dim {
199            return Err(DiskAnnError::DimensionMismatch {
200                expected: self.dim,
201                actual: query.len(),
202            });
203        }
204
205        let k = 256;
206        let mut table = vec![0.0f32; self.m * k];
207        for sub in 0..self.m {
208            let offset = sub * self.dsub;
209            let subquery = &query[offset..offset + self.dsub];
210
211            for (c, center) in self.centroids[sub].iter().enumerate() {
212                table[sub * k + c] = l2_squared(subquery, center);
213            }
214        }
215        Ok(table)
216    }
217
218    /// Compute approximate distance using flat precomputed table
219    #[inline]
220    pub fn distance_with_table(&self, codes: &[u8], table: &[f32]) -> f32 {
221        crate::distance::pq_asymmetric_distance(codes, table, 256)
222    }
223}
224
#[cfg(test)]
mod tests {
    use super::*;

    /// Trains a small quantizer, then checks the code length and that a
    /// training vector is close to its own encoding.
    #[test]
    fn test_pq_train_encode() {
        const DIM: usize = 8;
        const SUBSPACES: usize = 2;

        // 100 deterministic 8-dimensional training vectors.
        let mut vectors: Vec<Vec<f32>> = Vec::with_capacity(100);
        for i in 0..100 {
            let row = (0..8).map(|d| (i * 7 + d) as f32 / 100.0).collect();
            vectors.push(row);
        }

        let mut pq = ProductQuantizer::new(DIM, SUBSPACES).unwrap();
        pq.train(&vectors, 5).unwrap();

        let probe = &vectors[0];
        let codes = pq.encode(probe).unwrap();
        assert_eq!(codes.len(), SUBSPACES); // M=2 subspaces

        let table = pq.build_distance_table(probe).unwrap();
        let dist = pq.distance_with_table(&codes, &table);
        // Self-distance through PQ should be very small
        assert!(dist < 0.1, "self-distance was {dist}");
    }
}