1use crate::error::{Result, RuvectorError};
13use crate::types::DistanceMetric;
14use serde::{Deserialize, Serialize};
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct OPQConfig {
19 pub num_subspaces: usize,
21 pub codebook_size: usize,
23 pub num_iterations: usize,
25 pub num_opq_iterations: usize,
27 pub metric: DistanceMetric,
29}
30
31impl Default for OPQConfig {
32 fn default() -> Self {
33 Self {
34 num_subspaces: 8,
35 codebook_size: 256,
36 num_iterations: 20,
37 num_opq_iterations: 10,
38 metric: DistanceMetric::Euclidean,
39 }
40 }
41}
42
43impl OPQConfig {
44 pub fn validate(&self) -> Result<()> {
46 if self.codebook_size > 256 {
47 return Err(RuvectorError::InvalidParameter(format!(
48 "Codebook size {} exceeds u8 max 256",
49 self.codebook_size
50 )));
51 }
52 if self.num_subspaces == 0 {
53 return Err(RuvectorError::InvalidParameter(
54 "num_subspaces must be > 0".into(),
55 ));
56 }
57 if self.num_opq_iterations == 0 {
58 return Err(RuvectorError::InvalidParameter(
59 "num_opq_iterations must be > 0".into(),
60 ));
61 }
62 Ok(())
63 }
64}
65
/// Minimal dense, row-major `f32` matrix used internally for the rotation and
/// SVD computations. Element `(r, c)` is stored at `data[r * cols + c]`.
#[derive(Debug, Clone)]
struct Mat {
    /// Number of rows.
    rows: usize,
    /// Number of columns.
    cols: usize,
    /// Row-major storage of length `rows * cols`.
    data: Vec<f32>,
}

impl Mat {
    /// Returns an `r × c` matrix of zeros.
    fn zeros(r: usize, c: usize) -> Self {
        let data = vec![0.0; r * c];
        Mat {
            rows: r,
            cols: c,
            data,
        }
    }

    /// Returns the `n × n` identity matrix.
    fn identity(n: usize) -> Self {
        let mut eye = Mat::zeros(n, n);
        // Consecutive diagonal entries are `n + 1` apart in row-major storage.
        eye.data
            .iter_mut()
            .step_by(n + 1)
            .for_each(|slot| *slot = 1.0);
        eye
    }

    /// Reads element `(r, c)`.
    #[inline]
    fn get(&self, r: usize, c: usize) -> f32 {
        let offset = r * self.cols + c;
        self.data[offset]
    }

    /// Writes `v` to element `(r, c)`.
    #[inline]
    fn set(&mut self, r: usize, c: usize, v: f32) {
        let offset = r * self.cols + c;
        self.data[offset] = v;
    }

    /// Returns the transpose as a new matrix.
    fn transpose(&self) -> Self {
        let mut flipped = Mat::zeros(self.cols, self.rows);
        for (offset, &value) in self.data.iter().enumerate() {
            let (r, c) = (offset / self.cols, offset % self.cols);
            flipped.set(c, r, value);
        }
        flipped
    }

    /// Matrix product `self · b`.
    ///
    /// # Panics
    /// Panics if `self.cols != b.rows`.
    fn mul(&self, b: &Mat) -> Mat {
        assert_eq!(self.cols, b.rows);
        let width = b.cols;
        let mut product = Mat::zeros(self.rows, width);
        for i in 0..self.rows {
            for k in 0..self.cols {
                let scale = self.get(i, k);
                let b_row = &b.data[k * width..(k + 1) * width];
                let out_row = &mut product.data[i * width..(i + 1) * width];
                // i-k-j order: stream a row of `b` into a row of the output.
                for (out, &b_val) in out_row.iter_mut().zip(b_row) {
                    *out += scale * b_val;
                }
            }
        }
        product
    }

    /// Stacks `vecs` as the rows of a matrix; the column count is taken from
    /// the first row.
    ///
    /// # Panics
    /// Panics if `vecs` is empty.
    fn from_rows(vecs: &[Vec<f32>]) -> Self {
        let rows = vecs.len();
        let cols = vecs[0].len();
        let data: Vec<f32> = vecs.iter().flat_map(|v| v.iter().copied()).collect();
        Mat { rows, cols, data }
    }

    /// Copies row `i` into a new vector.
    fn row(&self, i: usize) -> Vec<f32> {
        let start = i * self.cols;
        self.data[start..start + self.cols].to_vec()
    }
}
134
135fn svd_rank1(a: &Mat, max_iters: usize) -> (Vec<f32>, f32, Vec<f32>) {
139 let ata = a.transpose().mul(a);
140 let n = ata.cols;
141 let mut v = vec![1.0 / (n as f32).sqrt(); n];
142 for _ in 0..max_iters {
143 let mut nv = vec![0.0; n];
144 for i in 0..n {
145 for j in 0..n {
146 nv[i] += ata.get(i, j) * v[j];
147 }
148 }
149 let norm: f32 = nv.iter().map(|x| x * x).sum::<f32>().sqrt();
150 if norm < 1e-12 {
151 break;
152 }
153 for x in nv.iter_mut() {
154 *x /= norm;
155 }
156 v = nv;
157 }
158 let mut av = vec![0.0; a.rows];
159 for i in 0..a.rows {
160 for j in 0..a.cols {
161 av[i] += a.get(i, j) * v[j];
162 }
163 }
164 let sigma: f32 = av.iter().map(|x| x * x).sum::<f32>().sqrt();
165 let u = if sigma > 1e-12 {
166 av.iter().map(|x| x / sigma).collect()
167 } else {
168 vec![0.0; a.rows]
169 };
170 (u, sigma, v)
171}
172
173fn svd_full(a: &Mat, iters: usize) -> (Mat, Vec<f32>, Mat) {
175 let n = a.rows;
176 let mut res = a.clone();
177 let (mut uc, mut sv, mut vc) = (Vec::new(), Vec::new(), Vec::new());
178 for _ in 0..n {
179 let (u, s, v) = svd_rank1(&res, iters);
180 if s > 1e-10 {
181 for i in 0..res.rows {
182 for j in 0..res.cols {
183 let c = res.get(i, j);
184 res.set(i, j, c - s * u[i] * v[j]);
185 }
186 }
187 }
188 uc.push(u);
189 sv.push(s);
190 vc.push(v);
191 }
192 let (mut um, mut vm) = (Mat::zeros(n, n), Mat::zeros(n, n));
193 for j in 0..n {
194 for i in 0..n {
195 um.set(i, j, uc[j][i]);
196 vm.set(i, j, vc[j][i]);
197 }
198 }
199 (um, sv, vm)
200}
201
202fn procrustes(x: &Mat, y: &Mat) -> Mat {
204 let m = x.transpose().mul(y);
205 let (u, _s, v) = svd_full(&m, 100);
206 v.mul(&u.transpose())
207}
208
209#[derive(Debug, Clone, Serialize, Deserialize)]
213pub struct RotationMatrix {
214 pub dim: usize,
215 pub data: Vec<f32>,
216}
217
218impl RotationMatrix {
219 pub fn identity(dim: usize) -> Self {
221 let mut data = vec![0.0; dim * dim];
222 for i in 0..dim {
223 data[i * dim + i] = 1.0;
224 }
225 Self { dim, data }
226 }
227 pub fn rotate(&self, v: &[f32]) -> Vec<f32> {
229 let d = self.dim;
230 (0..d)
231 .map(|j| (0..d).map(|i| v[i] * self.data[i * d + j]).sum())
232 .collect()
233 }
234 pub fn inverse_rotate(&self, v: &[f32]) -> Vec<f32> {
236 let d = self.dim;
237 (0..d)
238 .map(|j| (0..d).map(|i| v[i] * self.data[j * d + i]).sum())
239 .collect()
240 }
241 fn from_mat(m: &Mat) -> Self {
242 Self {
243 dim: m.rows,
244 data: m.data.clone(),
245 }
246 }
247}
248
249#[derive(Debug, Clone, Serialize, Deserialize)]
253pub struct OPQIndex {
254 pub config: OPQConfig,
255 pub rotation: RotationMatrix,
256 pub codebooks: Vec<Vec<Vec<f32>>>,
258 pub dimensions: usize,
259}
260
261impl OPQIndex {
262 pub fn train(vectors: &[Vec<f32>], config: OPQConfig) -> Result<Self> {
264 config.validate()?;
265 if vectors.is_empty() {
266 return Err(RuvectorError::InvalidParameter(
267 "Training set cannot be empty".into(),
268 ));
269 }
270 let d = vectors[0].len();
271 if d % config.num_subspaces != 0 {
272 return Err(RuvectorError::InvalidParameter(format!(
273 "Dimensions {} not divisible by num_subspaces {}",
274 d, config.num_subspaces
275 )));
276 }
277 for v in vectors {
278 if v.len() != d {
279 return Err(RuvectorError::DimensionMismatch {
280 expected: d,
281 actual: v.len(),
282 });
283 }
284 }
285 let x_mat = Mat::from_rows(vectors);
286 let mut r = Mat::identity(d);
287 let mut codebooks: Vec<Vec<Vec<f32>>> = Vec::new();
288 let sub_dim = d / config.num_subspaces;
289 for _ in 0..config.num_opq_iterations {
290 let x_rot = x_mat.mul(&r);
291 let rotated: Vec<Vec<f32>> = (0..vectors.len()).map(|i| x_rot.row(i)).collect();
292 codebooks = train_pq_codebooks(
293 &rotated,
294 config.num_subspaces,
295 config.codebook_size,
296 config.num_iterations,
297 config.metric,
298 )?;
299 let mut x_hat = Mat::zeros(vectors.len(), d);
300 for (i, rv) in rotated.iter().enumerate() {
301 let codes = encode_vec(rv, &codebooks, sub_dim, config.metric)?;
302 let recon = decode_vec(&codes, &codebooks);
303 for (j, &val) in recon.iter().enumerate() {
304 x_hat.set(i, j, val);
305 }
306 }
307 r = procrustes(&x_mat, &x_hat);
308 }
309 Ok(Self {
310 config,
311 rotation: RotationMatrix::from_mat(&r),
312 codebooks,
313 dimensions: d,
314 })
315 }
316
317 pub fn encode(&self, vector: &[f32]) -> Result<Vec<u8>> {
319 self.check_dim(vector.len())?;
320 let rotated = self.rotation.rotate(vector);
321 encode_vec(
322 &rotated,
323 &self.codebooks,
324 self.dimensions / self.config.num_subspaces,
325 self.config.metric,
326 )
327 }
328
329 pub fn decode(&self, codes: &[u8]) -> Result<Vec<f32>> {
331 if codes.len() != self.config.num_subspaces {
332 return Err(RuvectorError::InvalidParameter(format!(
333 "Expected {} codes, got {}",
334 self.config.num_subspaces,
335 codes.len()
336 )));
337 }
338 Ok(self
339 .rotation
340 .inverse_rotate(&decode_vec(codes, &self.codebooks)))
341 }
342
343 pub fn search_adc(
345 &self,
346 query: &[f32],
347 codes_db: &[Vec<u8>],
348 top_k: usize,
349 ) -> Result<Vec<(usize, f32)>> {
350 self.check_dim(query.len())?;
351 let rq = self.rotation.rotate(query);
352 let sub_dim = self.dimensions / self.config.num_subspaces;
353 let tables: Vec<Vec<f32>> = (0..self.config.num_subspaces)
354 .map(|s| {
355 let q_sub = &rq[s * sub_dim..(s + 1) * sub_dim];
356 self.codebooks[s]
357 .iter()
358 .map(|c| dist(q_sub, c, self.config.metric))
359 .collect()
360 })
361 .collect();
362 let mut dists: Vec<(usize, f32)> = codes_db
363 .iter()
364 .enumerate()
365 .map(|(idx, codes)| {
366 let d: f32 = codes
367 .iter()
368 .enumerate()
369 .map(|(s, &c)| tables[s][c as usize])
370 .sum();
371 (idx, d)
372 })
373 .collect();
374 dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
375 dists.truncate(top_k);
376 Ok(dists)
377 }
378
379 pub fn quantization_error(&self, vectors: &[Vec<f32>]) -> Result<f32> {
381 if vectors.is_empty() {
382 return Ok(0.0);
383 }
384 let mut total = 0.0f64;
385 for v in vectors {
386 let recon = self.decode(&self.encode(v)?)?;
387 total += v
388 .iter()
389 .zip(&recon)
390 .map(|(a, b)| ((a - b) as f64).powi(2))
391 .sum::<f64>();
392 }
393 Ok((total / vectors.len() as f64) as f32)
394 }
395
396 fn check_dim(&self, len: usize) -> Result<()> {
397 if len != self.dimensions {
398 Err(RuvectorError::DimensionMismatch {
399 expected: self.dimensions,
400 actual: len,
401 })
402 } else {
403 Ok(())
404 }
405 }
406}
407
408fn dist(a: &[f32], b: &[f32], m: DistanceMetric) -> f32 {
411 match m {
412 DistanceMetric::Euclidean => a
413 .iter()
414 .zip(b)
415 .map(|(x, y)| {
416 let d = x - y;
417 d * d
418 })
419 .sum::<f32>()
420 .sqrt(),
421 DistanceMetric::Cosine => {
422 let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
423 let na = a.iter().map(|x| x * x).sum::<f32>().sqrt();
424 let nb = b.iter().map(|x| x * x).sum::<f32>().sqrt();
425 if na == 0.0 || nb == 0.0 {
426 1.0
427 } else {
428 1.0 - dot / (na * nb)
429 }
430 }
431 DistanceMetric::DotProduct => -a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>(),
432 DistanceMetric::Manhattan => a.iter().zip(b).map(|(x, y)| (x - y).abs()).sum(),
433 }
434}
435
436fn train_pq_codebooks(
437 vecs: &[Vec<f32>],
438 nsub: usize,
439 k: usize,
440 iters: usize,
441 metric: DistanceMetric,
442) -> Result<Vec<Vec<Vec<f32>>>> {
443 let sub_dim = vecs[0].len() / nsub;
444 (0..nsub)
445 .map(|s| {
446 let sv: Vec<Vec<f32>> = vecs
447 .iter()
448 .map(|v| v[s * sub_dim..(s + 1) * sub_dim].to_vec())
449 .collect();
450 kmeans(&sv, k.min(sv.len()), iters, metric)
451 })
452 .collect()
453}
454
455fn encode_vec(
456 v: &[f32],
457 cbs: &[Vec<Vec<f32>>],
458 sub_dim: usize,
459 m: DistanceMetric,
460) -> Result<Vec<u8>> {
461 cbs.iter()
462 .enumerate()
463 .map(|(s, cb)| {
464 let sub = &v[s * sub_dim..(s + 1) * sub_dim];
465 cb.iter()
466 .enumerate()
467 .min_by(|(_, a), (_, b)| dist(sub, a, m).partial_cmp(&dist(sub, b, m)).unwrap())
468 .map(|(i, _)| i as u8)
469 .ok_or_else(|| RuvectorError::Internal("Empty codebook".into()))
470 })
471 .collect()
472}
473
/// Reassembles a vector from PQ codes by concatenating, for each subspace `s`,
/// the centroid `cbs[s][codes[s]]`.
///
/// # Panics
/// Panics if there are more codes than codebooks or a code indexes past the
/// end of its codebook.
fn decode_vec(codes: &[u8], cbs: &[Vec<Vec<f32>>]) -> Vec<f32> {
    let mut out = Vec::new();
    for (subspace, &code) in codes.iter().enumerate() {
        let centroid = &cbs[subspace][usize::from(code)];
        out.extend_from_slice(centroid);
    }
    out
}
481
482fn kmeans(
483 vecs: &[Vec<f32>],
484 k: usize,
485 iters: usize,
486 metric: DistanceMetric,
487) -> Result<Vec<Vec<f32>>> {
488 use rand::seq::SliceRandom;
489 if vecs.is_empty() || k == 0 {
490 return Err(RuvectorError::InvalidParameter(
491 "Cannot cluster empty set or k=0".into(),
492 ));
493 }
494 let dim = vecs[0].len();
495 let mut rng = rand::thread_rng();
496 let mut cents: Vec<Vec<f32>> = vecs.choose_multiple(&mut rng, k).cloned().collect();
497 for _ in 0..iters {
498 let (mut sums, mut counts) = (vec![vec![0.0f32; dim]; k], vec![0usize; k]);
499 for v in vecs {
500 let b = cents
501 .iter()
502 .enumerate()
503 .min_by(|(_, a), (_, b)| {
504 dist(v, a, metric).partial_cmp(&dist(v, b, metric)).unwrap()
505 })
506 .map(|(i, _)| i)
507 .unwrap_or(0);
508 counts[b] += 1;
509 for (j, &val) in v.iter().enumerate() {
510 sums[b][j] += val;
511 }
512 }
513 for (i, c) in cents.iter_mut().enumerate() {
514 if counts[i] > 0 {
515 for j in 0..dim {
516 c[j] = sums[i][j] / counts[i] as f32;
517 }
518 }
519 }
520 }
521 Ok(cents)
522}
523
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random data: `n` vectors of length `d` with
    /// components in `[-1, 1]`, produced by a 64-bit LCG so the inputs are
    /// reproducible without a seeded RNG.
    fn make_data(n: usize, d: usize) -> Vec<Vec<f32>> {
        let mut seed: u64 = 42;
        (0..n)
            .map(|_| {
                (0..d)
                    .map(|_| {
                        seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
                        // Take the top 32 bits (the best-mixed bits of an LCG) so
                        // the quotient spans [0, 1]; a 33-bit shift would only reach
                        // [0, 0.5) and make every component negative.
                        ((seed >> 32) as f32) / (u32::MAX as f32) * 2.0 - 1.0
                    })
                    .collect()
            })
            .collect()
    }

    /// Small configuration (2 subspaces x 4 centroids) that trains quickly on
    /// 4-dimensional test data.
    fn cfg() -> OPQConfig {
        OPQConfig {
            num_subspaces: 2,
            codebook_size: 4,
            num_iterations: 5,
            num_opq_iterations: 3,
            metric: DistanceMetric::Euclidean,
        }
    }

    #[test]
    fn test_rotation_orthogonality() {
        let r = RotationMatrix::identity(4);
        let v = vec![1.0, 2.0, 3.0, 4.0];
        let back = r.inverse_rotate(&r.rotate(&v));
        for i in 0..4 {
            assert!((v[i] - back[i]).abs() < 1e-6);
        }
    }
    #[test]
    fn test_rotation_preserves_norm() {
        let data = make_data(30, 4);
        let idx = OPQIndex::train(&data, cfg()).unwrap();
        let v = vec![1.0, 2.0, 3.0, 4.0];
        let n1: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        let n2: f32 = idx
            .rotation
            .rotate(&v)
            .iter()
            .map(|x| x * x)
            .sum::<f32>()
            .sqrt();
        assert!((n1 - n2).abs() < 0.1, "norms: {} vs {}", n1, n2);
    }
    #[test]
    fn test_pq_encoding_roundtrip() {
        let data = make_data(30, 4);
        let idx = OPQIndex::train(&data, cfg()).unwrap();
        let codes = idx.encode(&data[0]).unwrap();
        assert_eq!(codes.len(), 2);
        assert_eq!(idx.decode(&codes).unwrap().len(), 4);
    }
    #[test]
    fn test_opq_training_convergence() {
        let data = make_data(100, 4);
        let idx = OPQIndex::train(&data, cfg()).unwrap();
        let err = idx.quantization_error(&data).unwrap();
        assert!(
            err.is_finite() && err >= 0.0,
            "error must be finite non-negative: {}",
            err
        );
        for v in &data {
            let codes = idx.encode(v).unwrap();
            let decoded = idx.decode(&codes).unwrap();
            assert_eq!(decoded.len(), v.len());
            for x in &decoded {
                assert!(x.is_finite());
            }
        }
    }
    #[test]
    fn test_adc_correctness() {
        let data = make_data(30, 4);
        let idx = OPQIndex::train(&data, cfg()).unwrap();
        let db: Vec<Vec<u8>> = data.iter().map(|v| idx.encode(v).unwrap()).collect();
        let res = idx.search_adc(&[0.5, -0.5, 0.5, -0.5], &db, 3).unwrap();
        assert_eq!(res.len(), 3);
        // Results must come back in ascending distance order.
        for w in res.windows(2) {
            assert!(w[0].1 <= w[1].1 + 1e-6);
        }
    }
    #[test]
    fn test_quantization_error_reduction() {
        let data = make_data(50, 4);
        let err = OPQIndex::train(&data, cfg())
            .unwrap()
            .quantization_error(&data)
            .unwrap();
        assert!(err >= 0.0 && err.is_finite() && err < 10.0, "err={}", err);
    }
    #[test]
    fn test_svd_correctness() {
        let a = Mat {
            rows: 2,
            cols: 2,
            data: vec![3.0, 0.0, 0.0, 2.0],
        };
        let (u, s, v) = svd_full(&a, 200);
        // Rebuild a = U * diag(s) * V^T element by element.
        for i in 0..2 {
            for j in 0..2 {
                let r: f32 = (0..2).map(|k| u.get(i, k) * s[k] * v.get(j, k)).sum();
                assert!(
                    (a.get(i, j) - r).abs() < 0.1,
                    "SVD fail ({},{}): {} vs {}",
                    i,
                    j,
                    a.get(i, j),
                    r
                );
            }
        }
    }
    #[test]
    fn test_identity_rotation_baseline() {
        let data = make_data(30, 4);
        let idx = OPQIndex::train(
            &data,
            OPQConfig {
                num_opq_iterations: 1,
                ..cfg()
            },
        )
        .unwrap();
        let recon = idx.decode(&idx.encode(&data[0]).unwrap()).unwrap();
        assert_eq!(recon.len(), data[0].len());
    }
    #[test]
    fn test_search_accuracy() {
        let data = make_data(40, 4);
        let idx = OPQIndex::train(&data, cfg()).unwrap();
        let db: Vec<Vec<u8>> = data.iter().map(|v| idx.encode(v).unwrap()).collect();
        let ids: Vec<usize> = idx
            .search_adc(&data[0], &db, 5)
            .unwrap()
            .iter()
            .map(|r| r.0)
            .collect();
        assert!(ids.contains(&0), "vector 0 should be in its own top-5");
    }
    #[test]
    fn test_config_validation() {
        assert!(OPQConfig {
            codebook_size: 300,
            ..cfg()
        }
        .validate()
        .is_err());
        assert!(OPQConfig {
            num_subspaces: 0,
            ..cfg()
        }
        .validate()
        .is_err());
        assert!(OPQConfig {
            num_opq_iterations: 0,
            ..cfg()
        }
        .validate()
        .is_err());
    }
    #[test]
    fn test_dimension_mismatch_errors() {
        let idx = OPQIndex::train(&make_data(30, 4), cfg()).unwrap();
        assert!(idx.encode(&[1.0, 2.0]).is_err());
        assert!(idx.search_adc(&[1.0], &[], 1).is_err());
    }
}