Skip to main content

ruvector_core/advanced_features/
opq.rs

1//! Optimized Product Quantization (OPQ) with learned rotation matrix.
2//!
3//! OPQ improves upon standard PQ by learning an orthogonal rotation matrix R
4//! that decorrelates vector dimensions before quantization. This reduces
5//! quantization error by 10-30% and yields significant recall improvements,
6//! especially when vector dimensions have unequal variance.
7//!
8//! Training alternates between PQ codebook learning and rotation update via
9//! the Procrustes solution (SVD). ADC precomputes per-subspace distance tables
10//! so each database lookup costs O(num_subspaces) instead of O(d).
11
12use crate::error::{Result, RuvectorError};
13use crate::types::DistanceMetric;
14use serde::{Deserialize, Serialize};
15
/// Configuration for Optimized Product Quantization.
///
/// Checked by [`OPQConfig::validate`], which `OPQIndex::train` calls before
/// doing any work.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OPQConfig {
    /// Number of subspaces to split the (rotated) vector into.
    ///
    /// Training requires the vector dimensionality to be divisible by this
    /// value; each subspace covers `dimensions / num_subspaces` consecutive
    /// coordinates.
    pub num_subspaces: usize,
    /// Codebook size per subspace (max 256 for u8 codes).
    ///
    /// During training the effective number of centroids is capped at the
    /// number of training vectors.
    pub codebook_size: usize,
    /// Number of k-means iterations for codebook training.
    pub num_iterations: usize,
    /// Number of outer OPQ iterations (rotation + PQ alternation).
    pub num_opq_iterations: usize,
    /// Distance metric used for codebook training and search.
    pub metric: DistanceMetric,
}
30
31impl Default for OPQConfig {
32    fn default() -> Self {
33        Self {
34            num_subspaces: 8,
35            codebook_size: 256,
36            num_iterations: 20,
37            num_opq_iterations: 10,
38            metric: DistanceMetric::Euclidean,
39        }
40    }
41}
42
43impl OPQConfig {
44    /// Validate configuration parameters.
45    pub fn validate(&self) -> Result<()> {
46        if self.codebook_size > 256 {
47            return Err(RuvectorError::InvalidParameter(format!(
48                "Codebook size {} exceeds u8 max 256",
49                self.codebook_size
50            )));
51        }
52        if self.num_subspaces == 0 {
53            return Err(RuvectorError::InvalidParameter(
54                "num_subspaces must be > 0".into(),
55            ));
56        }
57        if self.num_opq_iterations == 0 {
58            return Err(RuvectorError::InvalidParameter(
59                "num_opq_iterations must be > 0".into(),
60            ));
61        }
62        Ok(())
63    }
64}
65
66// -- Dense matrix (row-major, internal only) ----------------------------------
67
/// Small dense row-major `f32` matrix used by the OPQ training internals.
#[derive(Debug, Clone)]
struct Mat {
    rows: usize,
    cols: usize,
    /// Row-major storage: element `(r, c)` lives at index `r * cols + c`.
    data: Vec<f32>,
}

impl Mat {
    /// `r x c` matrix filled with zeros.
    fn zeros(r: usize, c: usize) -> Self {
        Self {
            rows: r,
            cols: c,
            data: vec![0.0; r * c],
        }
    }

    /// `n x n` identity matrix.
    fn identity(n: usize) -> Self {
        let mut m = Self::zeros(n, n);
        for i in 0..n {
            m.data[i * n + i] = 1.0;
        }
        m
    }

    #[inline]
    fn get(&self, r: usize, c: usize) -> f32 {
        self.data[r * self.cols + c]
    }

    #[inline]
    fn set(&mut self, r: usize, c: usize, v: f32) {
        self.data[r * self.cols + c] = v;
    }

    /// Transposed copy (`cols x rows`).
    fn transpose(&self) -> Self {
        let mut t = Self::zeros(self.cols, self.rows);
        for r in 0..self.rows {
            for c in 0..self.cols {
                t.set(c, r, self.get(r, c));
            }
        }
        t
    }

    /// Matrix product `self @ b`.
    ///
    /// Accumulates in i-k-j order over row slices, so the innermost loop walks
    /// contiguous memory without per-element index arithmetic.
    ///
    /// # Panics
    ///
    /// If `self.cols != b.rows`.
    fn mul(&self, b: &Mat) -> Mat {
        assert_eq!(self.cols, b.rows, "Mat::mul: inner dimensions differ");
        let mut out = Mat::zeros(self.rows, b.cols);
        // `chunks_exact` rejects a chunk size of 0; the product is all zeros then.
        if self.cols == 0 || b.cols == 0 {
            return out;
        }
        for (a_row, out_row) in self
            .data
            .chunks_exact(self.cols)
            .zip(out.data.chunks_exact_mut(b.cols))
        {
            for (&a, b_row) in a_row.iter().zip(b.data.chunks_exact(b.cols)) {
                for (o, &bv) in out_row.iter_mut().zip(b_row) {
                    *o += a * bv;
                }
            }
        }
        out
    }

    /// Stack equal-length rows into a matrix; an empty slice gives `0 x 0`.
    ///
    /// # Panics
    ///
    /// If the rows do not all have the same length.
    fn from_rows(vecs: &[Vec<f32>]) -> Self {
        let rows = vecs.len();
        let cols = vecs.first().map_or(0, Vec::len);
        let mut data = Vec::with_capacity(rows * cols);
        for v in vecs {
            assert_eq!(v.len(), cols, "Mat::from_rows: ragged rows");
            data.extend_from_slice(v);
        }
        Self { rows, cols, data }
    }

    /// Owned copy of row `i`.
    fn row(&self, i: usize) -> Vec<f32> {
        self.data[i * self.cols..(i + 1) * self.cols].to_vec()
    }
}

// -- SVD via one-sided (Hestenes) Jacobi rotations ----------------------------

/// Full SVD of a square matrix: returns `(U, S, V)` with `A ≈ U · diag(S) · Vᵀ`.
///
/// One-sided Jacobi: column pairs of a working copy of `A` are rotated until
/// they are mutually orthogonal, and `V` accumulates the same rotations. The
/// working matrix therefore converges to `U · diag(S)`. Arithmetic is done in
/// `f64`.
///
/// Unlike power iteration with deflation, the result does not depend on a
/// start vector, and `U` and `V` stay orthogonal for repeated or zero singular
/// values. Columns of `U` for (numerically) zero singular values are completed
/// to an orthonormal basis.
///
/// `iters` caps the number of full Jacobi sweeps. Singular values are returned
/// in descending order, with the columns of `U` and `V` ordered to match.
///
/// # Panics
///
/// If `a` is not square.
fn svd_full(a: &Mat, iters: usize) -> (Mat, Vec<f32>, Mat) {
    // Relative off-orthogonality below which a column pair counts as converged.
    const ORTHO_TOL: f64 = 1e-12;
    // Singular values below this fraction of the largest are treated as zero.
    const RANK_TOL: f64 = 1e-12;

    assert_eq!(a.rows, a.cols, "svd_full expects a square matrix");
    let n = a.rows;

    // Column-major working copies: `w[j]` is column j of A·V, `v[j]` column j of V.
    let mut w: Vec<Vec<f64>> = (0..n)
        .map(|j| (0..n).map(|i| f64::from(a.get(i, j))).collect())
        .collect();
    let mut v: Vec<Vec<f64>> = (0..n)
        .map(|j| (0..n).map(|i| if i == j { 1.0 } else { 0.0 }).collect())
        .collect();

    for _ in 0..iters {
        let mut rotated = false;
        for p in 0..n {
            for q in p + 1..n {
                let (alpha, beta, gamma) = w[p].iter().zip(&w[q]).fold(
                    (0.0f64, 0.0f64, 0.0f64),
                    |(aa, bb, ab), (&x, &y)| (aa + x * x, bb + y * y, ab + x * y),
                );
                if gamma.abs() <= ORTHO_TOL * (alpha * beta).sqrt() {
                    continue;
                }
                rotated = true;
                // Rotation that zeroes the inner product of columns p and q:
                // the smaller root of t² + 2ζt − 1 = 0, for numerical stability.
                let zeta = (beta - alpha) / (2.0 * gamma);
                let t = zeta.signum() / (zeta.abs() + zeta.hypot(1.0));
                let c = 1.0 / t.hypot(1.0);
                let s = c * t;
                rotate_columns(&mut w, p, q, c, s);
                rotate_columns(&mut v, p, q, c, s);
            }
        }
        if !rotated {
            break;
        }
    }

    let sigma: Vec<f64> = w.iter().map(|col| l2_norm_f64(col)).collect();
    let mut order: Vec<usize> = (0..n).collect();
    // Descending singular values.
    order.sort_by(|&i, &j| sigma[j].total_cmp(&sigma[i]));
    let cutoff = order.first().map_or(0.0, |&i| sigma[i]) * RANK_TOL;

    // Normalised columns for the numerically non-zero singular values (a
    // prefix of `order`), then an orthonormal completion for the rest.
    let mut u_cols: Vec<Vec<f64>> = order
        .iter()
        .take_while(|&&j| sigma[j] > cutoff)
        .map(|&j| w[j].iter().map(|x| x / sigma[j]).collect())
        .collect();
    complete_orthonormal_basis(&mut u_cols, n);

    let mut um = Mat::zeros(n, n);
    let mut vm = Mat::zeros(n, n);
    for (col, &j) in order.iter().enumerate() {
        for i in 0..n {
            um.set(i, col, u_cols[col][i] as f32);
            vm.set(i, col, v[j][i] as f32);
        }
    }
    let sv = order.iter().map(|&j| sigma[j] as f32).collect();
    (um, sv, vm)
}

/// Euclidean norm of an `f64` vector.
fn l2_norm_f64(x: &[f64]) -> f64 {
    x.iter().map(|v| v * v).sum::<f64>().sqrt()
}

/// Apply `(x_p, x_q) <- (c·x_p − s·x_q, s·x_p + c·x_q)` to columns `p < q`.
fn rotate_columns(cols: &mut [Vec<f64>], p: usize, q: usize, c: f64, s: f64) {
    let (head, tail) = cols.split_at_mut(q);
    for (xp, xq) in head[p].iter_mut().zip(tail[0].iter_mut()) {
        let (a, b) = (*xp, *xq);
        *xp = c * a - s * b;
        *xq = s * a + c * b;
    }
}

/// Extend orthonormal length-`n` vectors `cols` to a full orthonormal basis of
/// `n` vectors.
fn complete_orthonormal_basis(cols: &mut Vec<Vec<f64>>, n: usize) {
    if cols.len() >= n {
        return;
    }
    // proj = I − Σ c·cᵀ projects onto the orthogonal complement of span(cols).
    let mut proj = vec![0.0f64; n * n];
    for i in 0..n {
        proj[i * n + i] = 1.0;
    }
    for c in cols.iter() {
        subtract_outer(&mut proj, c);
    }
    while cols.len() < n {
        // ||proj·e_k||² = proj[k][k], and the trace is n − cols.len() ≥ 1, so
        // the basis vector with the largest diagonal has a non-zero residual.
        let k = (0..n)
            .max_by(|&a, &b| proj[a * n + a].total_cmp(&proj[b * n + b]))
            .expect("cols.len() < n implies n > 0");
        let mut r: Vec<f64> = (0..n).map(|i| proj[i * n + k]).collect();
        // Re-orthogonalise against the current basis to wash out rounding drift.
        for c in cols.iter() {
            let d: f64 = c.iter().zip(&r).map(|(a, b)| a * b).sum();
            for (ri, ci) in r.iter_mut().zip(c) {
                *ri -= d * ci;
            }
        }
        let len = l2_norm_f64(&r);
        for x in &mut r {
            *x /= len;
        }
        subtract_outer(&mut proj, &r);
        cols.push(r);
    }
}

/// `m -= x·xᵀ` for a row-major `n x n` matrix `m`, where `n = x.len()`.
fn subtract_outer(m: &mut [f64], x: &[f64]) {
    let n = x.len();
    for (i, &xi) in x.iter().enumerate() {
        for (mij, &xj) in m[i * n..(i + 1) * n].iter_mut().zip(x) {
            *mij -= xi * xj;
        }
    }
}

/// Orthogonal Procrustes: the orthogonal `R` minimising `||Y - X @ R||_F`.
///
/// Expanding the objective leaves `max tr(Rᵀ M)` with `M = Xᵀ Y`. For the SVD
/// `M = U S Vᵀ` the maximiser is `R = U Vᵀ`; `V Uᵀ` would be its inverse.
fn procrustes(x: &Mat, y: &Mat) -> Mat {
    let m = x.transpose().mul(y);
    let (u, _s, v) = svd_full(&m, 100);
    u.mul(&v.transpose())
}
208
209// -- Rotation matrix ----------------------------------------------------------
210
/// Orthogonal rotation matrix R (d x d) that decorrelates dimensions before PQ.
///
/// `rotate` applies `x @ R` and `inverse_rotate` applies `y @ Rᵀ`; the two
/// undo each other only when `R` is orthogonal.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RotationMatrix {
    /// Side length `d` of the square matrix.
    pub dim: usize,
    /// Row-major entries: `R[i][j]` is stored at `i * dim + j`, so the length
    /// is expected to be `dim * dim`.
    pub data: Vec<f32>,
}
217
218impl RotationMatrix {
219    /// Identity rotation (no-op).
220    pub fn identity(dim: usize) -> Self {
221        let mut data = vec![0.0; dim * dim];
222        for i in 0..dim {
223            data[i * dim + i] = 1.0;
224        }
225        Self { dim, data }
226    }
227    /// Rotate vector: y = x @ R.
228    pub fn rotate(&self, v: &[f32]) -> Vec<f32> {
229        let d = self.dim;
230        (0..d)
231            .map(|j| (0..d).map(|i| v[i] * self.data[i * d + j]).sum())
232            .collect()
233    }
234    /// Inverse rotate: x = y @ R^T.
235    pub fn inverse_rotate(&self, v: &[f32]) -> Vec<f32> {
236        let d = self.dim;
237        (0..d)
238            .map(|j| (0..d).map(|i| v[i] * self.data[j * d + i]).sum())
239            .collect()
240    }
241    fn from_mat(m: &Mat) -> Self {
242        Self {
243            dim: m.rows,
244            data: m.data.clone(),
245        }
246    }
247}
248
249// -- OPQ Index ----------------------------------------------------------------
250
/// OPQ index: learns rotation R + PQ codebooks, supports ADC search.
///
/// Built with `OPQIndex::train`. It stores only the learned model, not an
/// encoded database: callers keep the codes returned by `encode` and pass them
/// to `search_adc`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OPQIndex {
    /// Configuration the index was trained with.
    pub config: OPQConfig,
    /// Learned rotation applied to vectors before quantization.
    pub rotation: RotationMatrix,
    /// Codebooks: `[subspace][centroid][subspace_dim]`.
    pub codebooks: Vec<Vec<Vec<f32>>>,
    /// Length of the (unrotated) input vectors.
    pub dimensions: usize,
}
260
261impl OPQIndex {
262    /// Train OPQ via alternating rotation update and PQ codebook learning.
263    pub fn train(vectors: &[Vec<f32>], config: OPQConfig) -> Result<Self> {
264        config.validate()?;
265        if vectors.is_empty() {
266            return Err(RuvectorError::InvalidParameter(
267                "Training set cannot be empty".into(),
268            ));
269        }
270        let d = vectors[0].len();
271        if d % config.num_subspaces != 0 {
272            return Err(RuvectorError::InvalidParameter(format!(
273                "Dimensions {} not divisible by num_subspaces {}",
274                d, config.num_subspaces
275            )));
276        }
277        for v in vectors {
278            if v.len() != d {
279                return Err(RuvectorError::DimensionMismatch {
280                    expected: d,
281                    actual: v.len(),
282                });
283            }
284        }
285        let x_mat = Mat::from_rows(vectors);
286        let mut r = Mat::identity(d);
287        let mut codebooks: Vec<Vec<Vec<f32>>> = Vec::new();
288        let sub_dim = d / config.num_subspaces;
289        for _ in 0..config.num_opq_iterations {
290            let x_rot = x_mat.mul(&r);
291            let rotated: Vec<Vec<f32>> = (0..vectors.len()).map(|i| x_rot.row(i)).collect();
292            codebooks = train_pq_codebooks(
293                &rotated,
294                config.num_subspaces,
295                config.codebook_size,
296                config.num_iterations,
297                config.metric,
298            )?;
299            let mut x_hat = Mat::zeros(vectors.len(), d);
300            for (i, rv) in rotated.iter().enumerate() {
301                let codes = encode_vec(rv, &codebooks, sub_dim, config.metric)?;
302                let recon = decode_vec(&codes, &codebooks);
303                for (j, &val) in recon.iter().enumerate() {
304                    x_hat.set(i, j, val);
305                }
306            }
307            r = procrustes(&x_mat, &x_hat);
308        }
309        Ok(Self {
310            config,
311            rotation: RotationMatrix::from_mat(&r),
312            codebooks,
313            dimensions: d,
314        })
315    }
316
317    /// Encode a vector: rotate then PQ-quantize.
318    pub fn encode(&self, vector: &[f32]) -> Result<Vec<u8>> {
319        self.check_dim(vector.len())?;
320        let rotated = self.rotation.rotate(vector);
321        encode_vec(
322            &rotated,
323            &self.codebooks,
324            self.dimensions / self.config.num_subspaces,
325            self.config.metric,
326        )
327    }
328
329    /// Decode PQ codes back to approximate vector (with inverse rotation).
330    pub fn decode(&self, codes: &[u8]) -> Result<Vec<f32>> {
331        if codes.len() != self.config.num_subspaces {
332            return Err(RuvectorError::InvalidParameter(format!(
333                "Expected {} codes, got {}",
334                self.config.num_subspaces,
335                codes.len()
336            )));
337        }
338        Ok(self
339            .rotation
340            .inverse_rotate(&decode_vec(codes, &self.codebooks)))
341    }
342
343    /// ADC search: precompute distance tables then sum lookups per database vector.
344    pub fn search_adc(
345        &self,
346        query: &[f32],
347        codes_db: &[Vec<u8>],
348        top_k: usize,
349    ) -> Result<Vec<(usize, f32)>> {
350        self.check_dim(query.len())?;
351        let rq = self.rotation.rotate(query);
352        let sub_dim = self.dimensions / self.config.num_subspaces;
353        let tables: Vec<Vec<f32>> = (0..self.config.num_subspaces)
354            .map(|s| {
355                let q_sub = &rq[s * sub_dim..(s + 1) * sub_dim];
356                self.codebooks[s]
357                    .iter()
358                    .map(|c| dist(q_sub, c, self.config.metric))
359                    .collect()
360            })
361            .collect();
362        let mut dists: Vec<(usize, f32)> = codes_db
363            .iter()
364            .enumerate()
365            .map(|(idx, codes)| {
366                let d: f32 = codes
367                    .iter()
368                    .enumerate()
369                    .map(|(s, &c)| tables[s][c as usize])
370                    .sum();
371                (idx, d)
372            })
373            .collect();
374        dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
375        dists.truncate(top_k);
376        Ok(dists)
377    }
378
379    /// Mean squared quantization error over a set of vectors.
380    pub fn quantization_error(&self, vectors: &[Vec<f32>]) -> Result<f32> {
381        if vectors.is_empty() {
382            return Ok(0.0);
383        }
384        let mut total = 0.0f64;
385        for v in vectors {
386            let recon = self.decode(&self.encode(v)?)?;
387            total += v
388                .iter()
389                .zip(&recon)
390                .map(|(a, b)| ((a - b) as f64).powi(2))
391                .sum::<f64>();
392        }
393        Ok((total / vectors.len() as f64) as f32)
394    }
395
396    fn check_dim(&self, len: usize) -> Result<()> {
397        if len != self.dimensions {
398            Err(RuvectorError::DimensionMismatch {
399                expected: self.dimensions,
400                actual: len,
401            })
402        } else {
403            Ok(())
404        }
405    }
406}
407
408// -- PQ helpers ---------------------------------------------------------------
409
410fn dist(a: &[f32], b: &[f32], m: DistanceMetric) -> f32 {
411    match m {
412        DistanceMetric::Euclidean => a
413            .iter()
414            .zip(b)
415            .map(|(x, y)| {
416                let d = x - y;
417                d * d
418            })
419            .sum::<f32>()
420            .sqrt(),
421        DistanceMetric::Cosine => {
422            let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
423            let na = a.iter().map(|x| x * x).sum::<f32>().sqrt();
424            let nb = b.iter().map(|x| x * x).sum::<f32>().sqrt();
425            if na == 0.0 || nb == 0.0 {
426                1.0
427            } else {
428                1.0 - dot / (na * nb)
429            }
430        }
431        DistanceMetric::DotProduct => -a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>(),
432        DistanceMetric::Manhattan => a.iter().zip(b).map(|(x, y)| (x - y).abs()).sum(),
433    }
434}
435
436fn train_pq_codebooks(
437    vecs: &[Vec<f32>],
438    nsub: usize,
439    k: usize,
440    iters: usize,
441    metric: DistanceMetric,
442) -> Result<Vec<Vec<Vec<f32>>>> {
443    let sub_dim = vecs[0].len() / nsub;
444    (0..nsub)
445        .map(|s| {
446            let sv: Vec<Vec<f32>> = vecs
447                .iter()
448                .map(|v| v[s * sub_dim..(s + 1) * sub_dim].to_vec())
449                .collect();
450            kmeans(&sv, k.min(sv.len()), iters, metric)
451        })
452        .collect()
453}
454
455fn encode_vec(
456    v: &[f32],
457    cbs: &[Vec<Vec<f32>>],
458    sub_dim: usize,
459    m: DistanceMetric,
460) -> Result<Vec<u8>> {
461    cbs.iter()
462        .enumerate()
463        .map(|(s, cb)| {
464            let sub = &v[s * sub_dim..(s + 1) * sub_dim];
465            cb.iter()
466                .enumerate()
467                .min_by(|(_, a), (_, b)| dist(sub, a, m).partial_cmp(&dist(sub, b, m)).unwrap())
468                .map(|(i, _)| i as u8)
469                .ok_or_else(|| RuvectorError::Internal("Empty codebook".into()))
470        })
471        .collect()
472}
473
/// Concatenate the centroids selected by `codes` (code `s` indexes codebook `s`).
///
/// # Panics
///
/// If `codes` has more entries than `cbs` has codebooks, or a code indexes
/// past the end of its codebook.
fn decode_vec(codes: &[u8], cbs: &[Vec<Vec<f32>>]) -> Vec<f32> {
    let mut out = Vec::new();
    for (s, &code) in codes.iter().enumerate() {
        out.extend_from_slice(&cbs[s][usize::from(code)]);
    }
    out
}
481
482fn kmeans(
483    vecs: &[Vec<f32>],
484    k: usize,
485    iters: usize,
486    metric: DistanceMetric,
487) -> Result<Vec<Vec<f32>>> {
488    use rand::seq::SliceRandom;
489    if vecs.is_empty() || k == 0 {
490        return Err(RuvectorError::InvalidParameter(
491            "Cannot cluster empty set or k=0".into(),
492        ));
493    }
494    let dim = vecs[0].len();
495    let mut rng = rand::thread_rng();
496    let mut cents: Vec<Vec<f32>> = vecs.choose_multiple(&mut rng, k).cloned().collect();
497    for _ in 0..iters {
498        let (mut sums, mut counts) = (vec![vec![0.0f32; dim]; k], vec![0usize; k]);
499        for v in vecs {
500            let b = cents
501                .iter()
502                .enumerate()
503                .min_by(|(_, a), (_, b)| {
504                    dist(v, a, metric).partial_cmp(&dist(v, b, metric)).unwrap()
505                })
506                .map(|(i, _)| i)
507                .unwrap_or(0);
508            counts[b] += 1;
509            for (j, &val) in v.iter().enumerate() {
510                sums[b][j] += val;
511            }
512        }
513        for (i, c) in cents.iter_mut().enumerate() {
514            if counts[i] > 0 {
515                for j in 0..dim {
516                    c[j] = sums[i][j] / counts[i] as f32;
517                }
518            }
519        }
520    }
521    Ok(cents)
522}
523
// Unit tests for config validation, the rotation/SVD helpers, and the OPQ
// train / encode / decode / ADC-search paths.
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random `n x d` data from a 64-bit LCG seeded with 42.
    ///
    /// NOTE(review): `seed >> 33` is below 2^31, so dividing by `u32::MAX`
    /// yields roughly [0, 0.5] and the outputs land in about [-1, 0] rather
    /// than [-1, 1] — confirm whether that range is intended.
    fn make_data(n: usize, d: usize) -> Vec<Vec<f32>> {
        let mut seed: u64 = 42;
        (0..n)
            .map(|_| {
                (0..d)
                    .map(|_| {
                        seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
                        ((seed >> 33) as f32) / (u32::MAX as f32) * 2.0 - 1.0
                    })
                    .collect()
            })
            .collect()
    }
    /// Small config (2 subspaces x 4 centroids) that trains quickly on the 4-d test data.
    fn cfg() -> OPQConfig {
        OPQConfig {
            num_subspaces: 2,
            codebook_size: 4,
            num_iterations: 5,
            num_opq_iterations: 3,
            metric: DistanceMetric::Euclidean,
        }
    }

    // Identity rotation: rotate followed by inverse_rotate returns the input.
    #[test]
    fn test_rotation_orthogonality() {
        let r = RotationMatrix::identity(4);
        let v = vec![1.0, 2.0, 3.0, 4.0];
        let back = r.inverse_rotate(&r.rotate(&v));
        for i in 0..4 {
            assert!((v[i] - back[i]).abs() < 1e-6);
        }
    }
    // A learned rotation should (approximately) preserve vector norms.
    #[test]
    fn test_rotation_preserves_norm() {
        let data = make_data(30, 4);
        let idx = OPQIndex::train(&data, cfg()).unwrap();
        let v = vec![1.0, 2.0, 3.0, 4.0];
        let n1: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        let n2: f32 = idx
            .rotation
            .rotate(&v)
            .iter()
            .map(|x| x * x)
            .sum::<f32>()
            .sqrt();
        assert!((n1 - n2).abs() < 0.1, "norms: {} vs {}", n1, n2);
    }
    // Encoding yields one code per subspace; decoding restores full length.
    #[test]
    fn test_pq_encoding_roundtrip() {
        let data = make_data(30, 4);
        let idx = OPQIndex::train(&data, cfg()).unwrap();
        let codes = idx.encode(&data[0]).unwrap();
        assert_eq!(codes.len(), 2);
        assert_eq!(idx.decode(&codes).unwrap().len(), 4);
    }
    #[test]
    fn test_opq_training_convergence() {
        // Verify that OPQ training produces finite, non-negative quantization
        // error and that trained index can encode/decode without degradation.
        // Note: with small data and few centroids, more OPQ iterations do not
        // guarantee monotone error decrease due to stochastic k-means.
        let data = make_data(100, 4);
        let idx = OPQIndex::train(&data, cfg()).unwrap();
        let err = idx.quantization_error(&data).unwrap();
        assert!(
            err.is_finite() && err >= 0.0,
            "error must be finite non-negative: {}",
            err
        );
        // Verify round-trip through encode/decode does not explode.
        for v in &data {
            let codes = idx.encode(v).unwrap();
            let decoded = idx.decode(&codes).unwrap();
            assert_eq!(decoded.len(), v.len());
            for x in &decoded {
                assert!(x.is_finite());
            }
        }
    }
    // ADC returns exactly top_k results in non-decreasing distance order.
    #[test]
    fn test_adc_correctness() {
        let data = make_data(30, 4);
        let idx = OPQIndex::train(&data, cfg()).unwrap();
        let db: Vec<Vec<u8>> = data.iter().map(|v| idx.encode(v).unwrap()).collect();
        let res = idx.search_adc(&[0.5, -0.5, 0.5, -0.5], &db, 3).unwrap();
        assert_eq!(res.len(), 3);
        for w in res.windows(2) {
            assert!(w[0].1 <= w[1].1 + 1e-6);
        }
    }
    // Sanity bound on the mean squared reconstruction error of the training set.
    #[test]
    fn test_quantization_error_reduction() {
        let data = make_data(50, 4);
        let err = OPQIndex::train(&data, cfg())
            .unwrap()
            .quantization_error(&data)
            .unwrap();
        assert!(err >= 0.0 && err.is_finite() && err < 10.0, "err={}", err);
    }
    // U * diag(S) * V^T must reconstruct a simple diagonal matrix.
    #[test]
    fn test_svd_correctness() {
        let a = Mat {
            rows: 2,
            cols: 2,
            data: vec![3.0, 0.0, 0.0, 2.0],
        };
        let (u, s, v) = svd_full(&a, 200);
        for i in 0..2 {
            for j in 0..2 {
                let r: f32 = (0..2).map(|k| u.get(i, k) * s[k] * v.get(j, k)).sum();
                assert!(
                    (a.get(i, j) - r).abs() < 0.1,
                    "SVD fail ({},{}): {} vs {}",
                    i,
                    j,
                    a.get(i, j),
                    r
                );
            }
        }
    }
    // A single OPQ iteration still yields a usable index.
    #[test]
    fn test_identity_rotation_baseline() {
        let data = make_data(30, 4);
        let idx = OPQIndex::train(
            &data,
            OPQConfig {
                num_opq_iterations: 1,
                ..cfg()
            },
        )
        .unwrap();
        let recon = idx.decode(&idx.encode(&data[0]).unwrap()).unwrap();
        assert_eq!(recon.len(), data[0].len());
    }
    // A database vector used as the query should rank near the top.
    #[test]
    fn test_search_accuracy() {
        let data = make_data(40, 4);
        let idx = OPQIndex::train(&data, cfg()).unwrap();
        let db: Vec<Vec<u8>> = data.iter().map(|v| idx.encode(v).unwrap()).collect();
        let ids: Vec<usize> = idx
            .search_adc(&data[0], &db, 5)
            .unwrap()
            .iter()
            .map(|r| r.0)
            .collect();
        assert!(ids.contains(&0), "vector 0 should be in its own top-5");
    }
    // Each invalid field on its own makes validate() fail.
    #[test]
    fn test_config_validation() {
        assert!(OPQConfig {
            codebook_size: 300,
            ..cfg()
        }
        .validate()
        .is_err());
        assert!(OPQConfig {
            num_subspaces: 0,
            ..cfg()
        }
        .validate()
        .is_err());
        assert!(OPQConfig {
            num_opq_iterations: 0,
            ..cfg()
        }
        .validate()
        .is_err());
    }
    // Wrong-length inputs are rejected by encode and search_adc.
    #[test]
    fn test_dimension_mismatch_errors() {
        let idx = OPQIndex::train(&make_data(30, 4), cfg()).unwrap();
        assert!(idx.encode(&[1.0, 2.0]).is_err());
        assert!(idx.search_adc(&[1.0], &[], 1).is_err());
    }
}