Skip to main content

ruvector_core/advanced_features/
opq.rs

1//! Optimized Product Quantization (OPQ) with learned rotation matrix.
2//!
3//! OPQ improves upon standard PQ by learning an orthogonal rotation matrix R
4//! that decorrelates vector dimensions before quantization. This reduces
5//! quantization error by 10-30% and yields significant recall improvements,
6//! especially when vector dimensions have unequal variance.
7//!
8//! Training alternates between PQ codebook learning and rotation update via
9//! the Procrustes solution (SVD). ADC precomputes per-subspace distance tables
10//! so each database lookup costs O(num_subspaces) instead of O(d).
11
12use crate::error::{Result, RuvectorError};
13use crate::types::DistanceMetric;
14use serde::{Deserialize, Serialize};
15
16/// Configuration for Optimized Product Quantization.
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct OPQConfig {
19    /// Number of subspaces to split the (rotated) vector into.
20    pub num_subspaces: usize,
21    /// Codebook size per subspace (max 256 for u8 codes).
22    pub codebook_size: usize,
23    /// Number of k-means iterations for codebook training.
24    pub num_iterations: usize,
25    /// Number of outer OPQ iterations (rotation + PQ alternation).
26    pub num_opq_iterations: usize,
27    /// Distance metric used for codebook training and search.
28    pub metric: DistanceMetric,
29}
30
31impl Default for OPQConfig {
32    fn default() -> Self {
33        Self {
34            num_subspaces: 8, codebook_size: 256, num_iterations: 20,
35            num_opq_iterations: 10, metric: DistanceMetric::Euclidean,
36        }
37    }
38}
39
40impl OPQConfig {
41    /// Validate configuration parameters.
42    pub fn validate(&self) -> Result<()> {
43        if self.codebook_size > 256 {
44            return Err(RuvectorError::InvalidParameter(format!(
45                "Codebook size {} exceeds u8 max 256", self.codebook_size)));
46        }
47        if self.num_subspaces == 0 {
48            return Err(RuvectorError::InvalidParameter("num_subspaces must be > 0".into()));
49        }
50        if self.num_opq_iterations == 0 {
51            return Err(RuvectorError::InvalidParameter("num_opq_iterations must be > 0".into()));
52        }
53        Ok(())
54    }
55}
56
57// -- Dense matrix (row-major, internal only) ----------------------------------
58
/// Minimal dense `f32` matrix stored row-major in a flat buffer.
#[derive(Debug, Clone)]
struct Mat { rows: usize, cols: usize, data: Vec<f32> }

impl Mat {
    /// `r x c` matrix filled with zeros.
    fn zeros(r: usize, c: usize) -> Self {
        Self { rows: r, cols: c, data: vec![0.0; r * c] }
    }

    /// `n x n` identity matrix.
    fn identity(n: usize) -> Self {
        let mut m = Self::zeros(n, n);
        for i in 0..n {
            m.data[i * n + i] = 1.0;
        }
        m
    }

    /// Element at row `r`, column `c`.
    #[inline]
    fn get(&self, r: usize, c: usize) -> f32 { self.data[r * self.cols + c] }

    /// Overwrite the element at row `r`, column `c`.
    #[inline]
    fn set(&mut self, r: usize, c: usize, v: f32) { self.data[r * self.cols + c] = v; }

    /// Returns a new matrix that is the transpose of `self`.
    fn transpose(&self) -> Self {
        let mut t = Self::zeros(self.cols, self.rows);
        for r in 0..self.rows {
            for c in 0..self.cols {
                t.set(c, r, self.get(r, c));
            }
        }
        t
    }

    /// Matrix product `self * b`.
    ///
    /// # Panics
    ///
    /// Panics if `self.cols != b.rows`.
    fn mul(&self, b: &Mat) -> Mat {
        assert_eq!(self.cols, b.rows);
        let width = b.cols;
        let mut out = Mat::zeros(self.rows, width);
        for i in 0..self.rows {
            let out_row = &mut out.data[i * width..(i + 1) * width];
            // i-k-j order: each step adds a scaled row of `b` to the output row,
            // so both inner accesses are contiguous.
            for k in 0..self.cols {
                let a = self.get(i, k);
                let b_row = &b.data[k * width..(k + 1) * width];
                for (o, &bv) in out_row.iter_mut().zip(b_row) {
                    *o += a * bv;
                }
            }
        }
        out
    }

    /// Build a matrix whose rows are the given vectors.
    ///
    /// The column count is taken from the first row; an empty slice yields a
    /// `0 x 0` matrix. All rows are expected to have the same length (checked
    /// in debug builds only).
    fn from_rows(vecs: &[Vec<f32>]) -> Self {
        let rows = vecs.len();
        let cols = vecs.first().map_or(0, Vec::len);
        let mut data = Vec::with_capacity(rows * cols);
        for v in vecs {
            debug_assert_eq!(v.len(), cols, "ragged input rows");
            data.extend_from_slice(v);
        }
        Self { rows, cols, data }
    }

    /// Copy of row `i`.
    fn row(&self, i: usize) -> Vec<f32> {
        self.data[i * self.cols..(i + 1) * self.cols].to_vec()
    }
}
96
97// -- SVD via power iteration + deflation --------------------------------------
98
99/// Rank-1 SVD: returns (u, sigma, v) for the largest singular triplet.
100fn svd_rank1(a: &Mat, max_iters: usize) -> (Vec<f32>, f32, Vec<f32>) {
101    let ata = a.transpose().mul(a);
102    let n = ata.cols;
103    let mut v = vec![1.0 / (n as f32).sqrt(); n];
104    for _ in 0..max_iters {
105        let mut nv = vec![0.0; n];
106        for i in 0..n { for j in 0..n { nv[i] += ata.get(i, j) * v[j]; } }
107        let norm: f32 = nv.iter().map(|x| x * x).sum::<f32>().sqrt();
108        if norm < 1e-12 { break; }
109        for x in nv.iter_mut() { *x /= norm; }
110        v = nv;
111    }
112    let mut av = vec![0.0; a.rows];
113    for i in 0..a.rows { for j in 0..a.cols { av[i] += a.get(i, j) * v[j]; } }
114    let sigma: f32 = av.iter().map(|x| x * x).sum::<f32>().sqrt();
115    let u = if sigma > 1e-12 { av.iter().map(|x| x / sigma).collect() } else { vec![0.0; a.rows] };
116    (u, sigma, v)
117}
118
119/// Full SVD by repeated rank-1 extraction + deflation.
120fn svd_full(a: &Mat, iters: usize) -> (Mat, Vec<f32>, Mat) {
121    let n = a.rows;
122    let mut res = a.clone();
123    let (mut uc, mut sv, mut vc) = (Vec::new(), Vec::new(), Vec::new());
124    for _ in 0..n {
125        let (u, s, v) = svd_rank1(&res, iters);
126        if s > 1e-10 {
127            for i in 0..res.rows { for j in 0..res.cols {
128                let c = res.get(i, j); res.set(i, j, c - s * u[i] * v[j]);
129            }}
130        }
131        uc.push(u); sv.push(s); vc.push(v);
132    }
133    let (mut um, mut vm) = (Mat::zeros(n, n), Mat::zeros(n, n));
134    for j in 0..n { for i in 0..n { um.set(i, j, uc[j][i]); vm.set(i, j, vc[j][i]); } }
135    (um, sv, vm)
136}
137
138/// Procrustes: find orthogonal R minimising ||Y - X @ R||_F.
139fn procrustes(x: &Mat, y: &Mat) -> Mat {
140    let m = x.transpose().mul(y);
141    let (u, _s, v) = svd_full(&m, 100);
142    v.mul(&u.transpose())
143}
144
145// -- Rotation matrix ----------------------------------------------------------
146
147/// Orthogonal rotation matrix R (d x d) that decorrelates dimensions before PQ.
148#[derive(Debug, Clone, Serialize, Deserialize)]
149pub struct RotationMatrix { pub dim: usize, pub data: Vec<f32> }
150
151impl RotationMatrix {
152    /// Identity rotation (no-op).
153    pub fn identity(dim: usize) -> Self {
154        let mut data = vec![0.0; dim * dim];
155        for i in 0..dim { data[i * dim + i] = 1.0; }
156        Self { dim, data }
157    }
158    /// Rotate vector: y = x @ R.
159    pub fn rotate(&self, v: &[f32]) -> Vec<f32> {
160        let d = self.dim;
161        (0..d).map(|j| (0..d).map(|i| v[i] * self.data[i * d + j]).sum()).collect()
162    }
163    /// Inverse rotate: x = y @ R^T.
164    pub fn inverse_rotate(&self, v: &[f32]) -> Vec<f32> {
165        let d = self.dim;
166        (0..d).map(|j| (0..d).map(|i| v[i] * self.data[j * d + i]).sum()).collect()
167    }
168    fn from_mat(m: &Mat) -> Self { Self { dim: m.rows, data: m.data.clone() } }
169}
170
171// -- OPQ Index ----------------------------------------------------------------
172
173/// OPQ index: learns rotation R + PQ codebooks, supports ADC search.
174#[derive(Debug, Clone, Serialize, Deserialize)]
175pub struct OPQIndex {
176    pub config: OPQConfig,
177    pub rotation: RotationMatrix,
178    /// Codebooks: `[subspace][centroid][subspace_dim]`.
179    pub codebooks: Vec<Vec<Vec<f32>>>,
180    pub dimensions: usize,
181}
182
183impl OPQIndex {
184    /// Train OPQ via alternating rotation update and PQ codebook learning.
185    pub fn train(vectors: &[Vec<f32>], config: OPQConfig) -> Result<Self> {
186        config.validate()?;
187        if vectors.is_empty() {
188            return Err(RuvectorError::InvalidParameter("Training set cannot be empty".into()));
189        }
190        let d = vectors[0].len();
191        if d % config.num_subspaces != 0 {
192            return Err(RuvectorError::InvalidParameter(format!(
193                "Dimensions {} not divisible by num_subspaces {}", d, config.num_subspaces)));
194        }
195        for v in vectors { if v.len() != d {
196            return Err(RuvectorError::DimensionMismatch { expected: d, actual: v.len() });
197        }}
198        let x_mat = Mat::from_rows(vectors);
199        let mut r = Mat::identity(d);
200        let mut codebooks: Vec<Vec<Vec<f32>>> = Vec::new();
201        let sub_dim = d / config.num_subspaces;
202        for _ in 0..config.num_opq_iterations {
203            let x_rot = x_mat.mul(&r);
204            let rotated: Vec<Vec<f32>> = (0..vectors.len()).map(|i| x_rot.row(i)).collect();
205            codebooks = train_pq_codebooks(&rotated, config.num_subspaces,
206                config.codebook_size, config.num_iterations, config.metric)?;
207            let mut x_hat = Mat::zeros(vectors.len(), d);
208            for (i, rv) in rotated.iter().enumerate() {
209                let codes = encode_vec(rv, &codebooks, sub_dim, config.metric)?;
210                let recon = decode_vec(&codes, &codebooks);
211                for (j, &val) in recon.iter().enumerate() { x_hat.set(i, j, val); }
212            }
213            r = procrustes(&x_mat, &x_hat);
214        }
215        Ok(Self { config, rotation: RotationMatrix::from_mat(&r), codebooks, dimensions: d })
216    }
217
218    /// Encode a vector: rotate then PQ-quantize.
219    pub fn encode(&self, vector: &[f32]) -> Result<Vec<u8>> {
220        self.check_dim(vector.len())?;
221        let rotated = self.rotation.rotate(vector);
222        encode_vec(&rotated, &self.codebooks,
223            self.dimensions / self.config.num_subspaces, self.config.metric)
224    }
225
226    /// Decode PQ codes back to approximate vector (with inverse rotation).
227    pub fn decode(&self, codes: &[u8]) -> Result<Vec<f32>> {
228        if codes.len() != self.config.num_subspaces {
229            return Err(RuvectorError::InvalidParameter(format!(
230                "Expected {} codes, got {}", self.config.num_subspaces, codes.len())));
231        }
232        Ok(self.rotation.inverse_rotate(&decode_vec(codes, &self.codebooks)))
233    }
234
235    /// ADC search: precompute distance tables then sum lookups per database vector.
236    pub fn search_adc(&self, query: &[f32], codes_db: &[Vec<u8>], top_k: usize,
237    ) -> Result<Vec<(usize, f32)>> {
238        self.check_dim(query.len())?;
239        let rq = self.rotation.rotate(query);
240        let sub_dim = self.dimensions / self.config.num_subspaces;
241        let tables: Vec<Vec<f32>> = (0..self.config.num_subspaces).map(|s| {
242            let q_sub = &rq[s * sub_dim..(s + 1) * sub_dim];
243            self.codebooks[s].iter().map(|c| dist(q_sub, c, self.config.metric)).collect()
244        }).collect();
245        let mut dists: Vec<(usize, f32)> = codes_db.iter().enumerate().map(|(idx, codes)| {
246            let d: f32 = codes.iter().enumerate().map(|(s, &c)| tables[s][c as usize]).sum();
247            (idx, d)
248        }).collect();
249        dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
250        dists.truncate(top_k);
251        Ok(dists)
252    }
253
254    /// Mean squared quantization error over a set of vectors.
255    pub fn quantization_error(&self, vectors: &[Vec<f32>]) -> Result<f32> {
256        if vectors.is_empty() { return Ok(0.0); }
257        let mut total = 0.0f64;
258        for v in vectors {
259            let recon = self.decode(&self.encode(v)?)?;
260            total += v.iter().zip(&recon).map(|(a, b)| ((a - b) as f64).powi(2)).sum::<f64>();
261        }
262        Ok((total / vectors.len() as f64) as f32)
263    }
264
265    fn check_dim(&self, len: usize) -> Result<()> {
266        if len != self.dimensions {
267            Err(RuvectorError::DimensionMismatch { expected: self.dimensions, actual: len })
268        } else { Ok(()) }
269    }
270}
271
272// -- PQ helpers ---------------------------------------------------------------
273
274fn dist(a: &[f32], b: &[f32], m: DistanceMetric) -> f32 {
275    match m {
276        DistanceMetric::Euclidean =>
277            a.iter().zip(b).map(|(x, y)| { let d = x - y; d * d }).sum::<f32>().sqrt(),
278        DistanceMetric::Cosine => {
279            let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
280            let na = a.iter().map(|x| x * x).sum::<f32>().sqrt();
281            let nb = b.iter().map(|x| x * x).sum::<f32>().sqrt();
282            if na == 0.0 || nb == 0.0 { 1.0 } else { 1.0 - dot / (na * nb) }
283        }
284        DistanceMetric::DotProduct => -a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>(),
285        DistanceMetric::Manhattan => a.iter().zip(b).map(|(x, y)| (x - y).abs()).sum(),
286    }
287}
288
289fn train_pq_codebooks(vecs: &[Vec<f32>], nsub: usize, k: usize, iters: usize,
290    metric: DistanceMetric) -> Result<Vec<Vec<Vec<f32>>>> {
291    let sub_dim = vecs[0].len() / nsub;
292    (0..nsub).map(|s| {
293        let sv: Vec<Vec<f32>> = vecs.iter().map(|v| v[s*sub_dim..(s+1)*sub_dim].to_vec()).collect();
294        kmeans(&sv, k.min(sv.len()), iters, metric)
295    }).collect()
296}
297
298fn encode_vec(v: &[f32], cbs: &[Vec<Vec<f32>>], sub_dim: usize, m: DistanceMetric,
299) -> Result<Vec<u8>> {
300    cbs.iter().enumerate().map(|(s, cb)| {
301        let sub = &v[s * sub_dim..(s + 1) * sub_dim];
302        cb.iter().enumerate()
303            .min_by(|(_, a), (_, b)| dist(sub, a, m).partial_cmp(&dist(sub, b, m)).unwrap())
304            .map(|(i, _)| i as u8)
305            .ok_or_else(|| RuvectorError::Internal("Empty codebook".into()))
306    }).collect()
307}
308
/// Reconstruct a vector by concatenating, for each position `s`, the centroid
/// `cbs[s][codes[s]]`.
///
/// Panics if `codes` is longer than `cbs` or a code has no matching centroid.
fn decode_vec(codes: &[u8], cbs: &[Vec<Vec<f32>>]) -> Vec<f32> {
    let mut out = Vec::new();
    for (s, &code) in codes.iter().enumerate() {
        out.extend_from_slice(&cbs[s][usize::from(code)]);
    }
    out
}
312
313fn kmeans(vecs: &[Vec<f32>], k: usize, iters: usize, metric: DistanceMetric,
314) -> Result<Vec<Vec<f32>>> {
315    use rand::seq::SliceRandom;
316    if vecs.is_empty() || k == 0 {
317        return Err(RuvectorError::InvalidParameter("Cannot cluster empty set or k=0".into()));
318    }
319    let dim = vecs[0].len();
320    let mut rng = rand::thread_rng();
321    let mut cents: Vec<Vec<f32>> = vecs.choose_multiple(&mut rng, k).cloned().collect();
322    for _ in 0..iters {
323        let (mut sums, mut counts) = (vec![vec![0.0f32; dim]; k], vec![0usize; k]);
324        for v in vecs {
325            let b = cents.iter().enumerate()
326                .min_by(|(_, a), (_, b)| dist(v, a, metric).partial_cmp(&dist(v, b, metric)).unwrap())
327                .map(|(i, _)| i).unwrap_or(0);
328            counts[b] += 1;
329            for (j, &val) in v.iter().enumerate() { sums[b][j] += val; }
330        }
331        for (i, c) in cents.iter_mut().enumerate() {
332            if counts[i] > 0 { for j in 0..dim { c[j] = sums[i][j] / counts[i] as f32; } }
333        }
334    }
335    Ok(cents)
336}
337
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random `n x d` data set from a fixed-seed LCG.
    ///
    /// Only 31 bits of the state are kept (`>> 33`) before dividing by
    /// `u32::MAX`, so the generated values fall in roughly `[-1, 0)`.
    fn make_data(n: usize, d: usize) -> Vec<Vec<f32>> {
        let mut state: u64 = 42;
        let mut rows = Vec::with_capacity(n);
        for _ in 0..n {
            let mut row = Vec::with_capacity(d);
            for _ in 0..d {
                state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
                row.push(((state >> 33) as f32) / (u32::MAX as f32) * 2.0 - 1.0);
            }
            rows.push(row);
        }
        rows
    }

    /// Small configuration that keeps training fast.
    fn cfg() -> OPQConfig {
        OPQConfig {
            num_subspaces: 2,
            codebook_size: 4,
            num_iterations: 5,
            num_opq_iterations: 3,
            metric: DistanceMetric::Euclidean,
        }
    }

    /// Euclidean norm of a vector.
    fn norm(v: &[f32]) -> f32 {
        v.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    #[test]
    fn test_rotation_orthogonality() {
        let rot = RotationMatrix::identity(4);
        let v = vec![1.0, 2.0, 3.0, 4.0];
        let back = rot.inverse_rotate(&rot.rotate(&v));
        for (orig, restored) in v.iter().zip(&back) {
            assert!((orig - restored).abs() < 1e-6);
        }
    }

    #[test]
    fn test_rotation_preserves_norm() {
        let data = make_data(30, 4);
        let idx = OPQIndex::train(&data, cfg()).unwrap();
        let v = vec![1.0, 2.0, 3.0, 4.0];
        let n1 = norm(&v);
        let n2 = norm(&idx.rotation.rotate(&v));
        assert!((n1 - n2).abs() < 0.1, "norms: {} vs {}", n1, n2);
    }

    #[test]
    fn test_pq_encoding_roundtrip() {
        let data = make_data(30, 4);
        let idx = OPQIndex::train(&data, cfg()).unwrap();
        let codes = idx.encode(&data[0]).unwrap();
        assert_eq!(codes.len(), 2);
        let decoded = idx.decode(&codes).unwrap();
        assert_eq!(decoded.len(), 4);
    }

    #[test]
    fn test_opq_training_convergence() {
        // Training must yield a finite, non-negative quantization error and an
        // index whose encode/decode round trip stays finite. Monotone error
        // decrease across OPQ iterations is not asserted: with small data and
        // few centroids the stochastic k-means does not guarantee it.
        let data = make_data(100, 4);
        let idx = OPQIndex::train(&data, cfg()).unwrap();
        let err = idx.quantization_error(&data).unwrap();
        assert!(err.is_finite() && err >= 0.0, "error must be finite non-negative: {}", err);
        // Round trip through encode/decode must not explode.
        for v in &data {
            let decoded = idx.decode(&idx.encode(v).unwrap()).unwrap();
            assert_eq!(decoded.len(), v.len());
            for x in &decoded {
                assert!(x.is_finite());
            }
        }
    }

    #[test]
    fn test_adc_correctness() {
        let data = make_data(30, 4);
        let idx = OPQIndex::train(&data, cfg()).unwrap();
        let db: Vec<Vec<u8>> = data.iter().map(|v| idx.encode(v).unwrap()).collect();
        let res = idx.search_adc(&[0.5, -0.5, 0.5, -0.5], &db, 3).unwrap();
        assert_eq!(res.len(), 3);
        // Results must come back in ascending distance order.
        for pair in res.windows(2) {
            assert!(pair[0].1 <= pair[1].1 + 1e-6);
        }
    }

    #[test]
    fn test_quantization_error_reduction() {
        let data = make_data(50, 4);
        let idx = OPQIndex::train(&data, cfg()).unwrap();
        let err = idx.quantization_error(&data).unwrap();
        assert!(err >= 0.0 && err.is_finite() && err < 10.0, "err={}", err);
    }

    #[test]
    fn test_svd_correctness() {
        let a = Mat { rows: 2, cols: 2, data: vec![3.0, 0.0, 0.0, 2.0] };
        let (u, s, v) = svd_full(&a, 200);
        // Rebuild U * diag(s) * V^T and compare entry by entry.
        for i in 0..2 {
            for j in 0..2 {
                let r: f32 = (0..2).map(|k| u.get(i, k) * s[k] * v.get(j, k)).sum();
                assert!((a.get(i, j) - r).abs() < 0.1, "SVD fail ({},{}): {} vs {}", i, j, a.get(i, j), r);
            }
        }
    }

    #[test]
    fn test_identity_rotation_baseline() {
        let data = make_data(30, 4);
        let single_pass = OPQConfig { num_opq_iterations: 1, ..cfg() };
        let idx = OPQIndex::train(&data, single_pass).unwrap();
        let codes = idx.encode(&data[0]).unwrap();
        let recon = idx.decode(&codes).unwrap();
        assert_eq!(recon.len(), data[0].len());
    }

    #[test]
    fn test_search_accuracy() {
        let data = make_data(40, 4);
        let idx = OPQIndex::train(&data, cfg()).unwrap();
        let db: Vec<Vec<u8>> = data.iter().map(|v| idx.encode(v).unwrap()).collect();
        let hits = idx.search_adc(&data[0], &db, 5).unwrap();
        let ids: Vec<usize> = hits.iter().map(|&(id, _)| id).collect();
        assert!(ids.contains(&0), "vector 0 should be in its own top-5");
    }

    #[test]
    fn test_config_validation() {
        let invalid = [
            OPQConfig { codebook_size: 300, ..cfg() },
            OPQConfig { num_subspaces: 0, ..cfg() },
            OPQConfig { num_opq_iterations: 0, ..cfg() },
        ];
        for bad in &invalid {
            assert!(bad.validate().is_err());
        }
    }

    #[test]
    fn test_dimension_mismatch_errors() {
        let data = make_data(30, 4);
        let idx = OPQIndex::train(&data, cfg()).unwrap();
        assert!(idx.encode(&[1.0, 2.0]).is_err());
        assert!(idx.search_adc(&[1.0], &[], 1).is_err());
    }
}