1use crate::error::{Result, RuvectorError};
13use crate::types::DistanceMetric;
14use serde::{Deserialize, Serialize};
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct OPQConfig {
19 pub num_subspaces: usize,
21 pub codebook_size: usize,
23 pub num_iterations: usize,
25 pub num_opq_iterations: usize,
27 pub metric: DistanceMetric,
29}
30
31impl Default for OPQConfig {
32 fn default() -> Self {
33 Self {
34 num_subspaces: 8, codebook_size: 256, num_iterations: 20,
35 num_opq_iterations: 10, metric: DistanceMetric::Euclidean,
36 }
37 }
38}
39
40impl OPQConfig {
41 pub fn validate(&self) -> Result<()> {
43 if self.codebook_size > 256 {
44 return Err(RuvectorError::InvalidParameter(format!(
45 "Codebook size {} exceeds u8 max 256", self.codebook_size)));
46 }
47 if self.num_subspaces == 0 {
48 return Err(RuvectorError::InvalidParameter("num_subspaces must be > 0".into()));
49 }
50 if self.num_opq_iterations == 0 {
51 return Err(RuvectorError::InvalidParameter("num_opq_iterations must be > 0".into()));
52 }
53 Ok(())
54 }
55}
56
/// Minimal dense, row-major `f32` matrix used by the OPQ rotation solver.
#[derive(Debug, Clone)]
struct Mat {
    rows: usize,
    cols: usize,
    /// Element `(r, c)` is stored at `data[r * cols + c]`.
    data: Vec<f32>,
}

impl Mat {
    /// An `r × c` matrix of zeros.
    fn zeros(r: usize, c: usize) -> Self {
        Self { rows: r, cols: c, data: vec![0.0; r * c] }
    }

    /// The `n × n` identity matrix.
    fn identity(n: usize) -> Self {
        let mut m = Self::zeros(n, n);
        for i in 0..n {
            m.data[i * n + i] = 1.0;
        }
        m
    }

    #[inline]
    fn get(&self, r: usize, c: usize) -> f32 {
        self.data[r * self.cols + c]
    }

    #[inline]
    fn set(&mut self, r: usize, c: usize, v: f32) {
        self.data[r * self.cols + c] = v;
    }

    /// The transpose `selfᵀ`.
    fn transpose(&self) -> Self {
        let mut t = Self::zeros(self.cols, self.rows);
        for r in 0..self.rows {
            for c in 0..self.cols {
                t.set(c, r, self.get(r, c));
            }
        }
        t
    }

    /// Matrix product `self · b`.
    ///
    /// # Panics
    /// If `self.cols != b.rows`.
    fn mul(&self, b: &Mat) -> Mat {
        assert_eq!(self.cols, b.rows);
        let mut out = Mat::zeros(self.rows, b.cols);
        for i in 0..self.rows {
            // i-k-j order: the innermost loop walks contiguous rows of `b` and `out`.
            for k in 0..self.cols {
                let a = self.get(i, k);
                let b_row = &b.data[k * b.cols..(k + 1) * b.cols];
                let out_row = &mut out.data[i * b.cols..(i + 1) * b.cols];
                for (o, &bv) in out_row.iter_mut().zip(b_row) {
                    *o += a * bv;
                }
            }
        }
        out
    }

    /// Stacks equal-length vectors as the rows of a matrix. An empty slice gives a `0 × 0` matrix.
    ///
    /// # Panics
    /// If the rows differ in length.
    fn from_rows(vecs: &[Vec<f32>]) -> Self {
        let cols = vecs.first().map_or(0, Vec::len);
        let mut data = Vec::with_capacity(vecs.len() * cols);
        for v in vecs {
            assert_eq!(v.len(), cols, "Mat::from_rows: rows have different lengths");
            data.extend_from_slice(v);
        }
        Self { rows: vecs.len(), cols, data }
    }

    /// A copy of row `i`.
    fn row(&self, i: usize) -> Vec<f32> {
        self.data[i * self.cols..(i + 1) * self.cols].to_vec()
    }
}

/// Inner product of two slices (over their common length).
fn vdot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// Euclidean norm of a slice.
fn vnorm(a: &[f32]) -> f32 {
    vdot(a, a).sqrt()
}

/// Removes from `y` its components along each vector of the orthonormal set `basis`.
/// Runs two Gram–Schmidt sweeps, because a single sweep loses orthogonality in `f32`.
fn project_out(y: &mut [f32], basis: &[Vec<f32>]) {
    for _ in 0..2 {
        for b in basis {
            let p = vdot(y, b);
            for (yi, bi) in y.iter_mut().zip(b) {
                *yi -= p * bi;
            }
        }
    }
}

/// Returns `x` made orthogonal to `basis` and normalized.
///
/// Returns `None` when `x` is zero/NaN, or when (almost) all of it lies in the span of `basis`.
fn orthonormalize(x: &[f32], basis: &[Vec<f32>]) -> Option<Vec<f32>> {
    let scale = vnorm(x);
    if !(scale > 0.0) {
        return None;
    }
    let mut y = x.to_vec();
    project_out(&mut y, basis);
    let norm = vnorm(&y);
    if norm > 1e-4 * scale {
        for c in &mut y {
            *c /= norm;
        }
        Some(y)
    } else {
        None
    }
}

/// A unit vector orthogonal to every vector of the orthonormal set `basis`.
///
/// Requires `basis.len() < n`. It picks the standard basis vector with the largest component
/// outside `span(basis)`; that component has squared norm at least `(n - k) / n`, so the
/// result is well conditioned.
fn complete_basis(basis: &[Vec<f32>], n: usize) -> Vec<f32> {
    let mut best = vec![0.0f32; n];
    let mut best_norm = 0.0f32;
    for i in 0..n {
        let mut e = vec![0.0f32; n];
        e[i] = 1.0;
        project_out(&mut e, basis);
        let norm = vnorm(&e);
        if norm > best_norm {
            best_norm = norm;
            best = e;
        }
    }
    if best_norm > 0.0 {
        for c in &mut best {
            *c /= best_norm;
        }
    }
    best
}

/// Dominant singular triplet `(u, σ, v)` of `a` (so `a·v ≈ σ·u`), via power iteration on `aᵀa`.
///
/// Power iteration starts from the largest-norm row of `a`, not from a fixed vector. A fixed
/// start can be orthogonal to every right singular vector with `σ > 0`; the iteration then
/// collapses and reports `σ = 0` for a non-zero matrix. A non-zero row lies in the row space,
/// so that cannot happen. For an all-zero `a` the result is `u = 0`, `σ = 0` and a uniform unit `v`.
fn svd_rank1(a: &Mat, max_iters: usize) -> (Vec<f32>, f32, Vec<f32>) {
    let n = a.cols;
    let ata = a.transpose().mul(a);

    let mut v = vec![0.0f32; n];
    let mut best = 0.0f32;
    for i in 0..a.rows {
        let row = &a.data[i * n..(i + 1) * n];
        let norm = vnorm(row);
        if norm > best {
            best = norm;
            v.copy_from_slice(row);
        }
    }
    if best > 1e-12 {
        for x in &mut v {
            *x /= best;
        }
    } else {
        v = vec![1.0 / (n as f32).sqrt(); n];
    }

    for _ in 0..max_iters {
        let nv: Vec<f32> = (0..n).map(|i| vdot(&ata.data[i * n..(i + 1) * n], &v)).collect();
        let norm = vnorm(&nv);
        if norm < 1e-12 {
            break;
        }
        v = nv.into_iter().map(|x| x / norm).collect();
    }

    let av: Vec<f32> = (0..a.rows).map(|i| vdot(&a.data[i * n..(i + 1) * n], &v)).collect();
    let sigma = vnorm(&av);
    let u = if sigma > 1e-12 { av.iter().map(|x| x / sigma).collect() } else { vec![0.0; a.rows] };
    (u, sigma, v)
}

/// SVD `a ≈ U · diag(σ) · Vᵀ` of a square matrix by repeated rank-1 deflation.
///
/// Returns `(U, σ, V)`, with the singular vectors stored as columns. `U` and `V` are kept
/// orthogonal even when `a` is rank-deficient:
/// - each new right vector is orthonormalized against the earlier ones;
/// - each left vector is orthonormalized against the earlier left vectors;
/// - directions with a zero singular value are filled by completing the basis.
///
/// Singular values are not guaranteed to be sorted.
fn svd_full(a: &Mat, iters: usize) -> (Mat, Vec<f32>, Mat) {
    let n = a.rows;
    debug_assert_eq!(a.rows, a.cols, "svd_full expects a square matrix");
    let mut res = a.clone();
    let mut us: Vec<Vec<f32>> = Vec::with_capacity(n);
    let mut sv: Vec<f32> = Vec::with_capacity(n);
    let mut vs: Vec<Vec<f32>> = Vec::with_capacity(n);
    for _ in 0..n {
        let (_, _, v_est) = svd_rank1(&res, iters);
        let v = orthonormalize(&v_est, &vs).unwrap_or_else(|| complete_basis(&vs, n));
        let av: Vec<f32> = (0..n).map(|i| vdot(&res.data[i * n..(i + 1) * n], &v)).collect();
        let u_dir = if vnorm(&av) > 1e-10 { orthonormalize(&av, &us) } else { None };
        let (u, s) = match u_dir {
            Some(u) => {
                // Best singular value for this (u, v) pair; non-negative by construction.
                let s = vdot(&u, &av);
                (u, s)
            }
            None => (complete_basis(&us, n), 0.0),
        };
        if s > 0.0 {
            for i in 0..n {
                for j in 0..n {
                    let c = res.get(i, j);
                    res.set(i, j, c - s * u[i] * v[j]);
                }
            }
        }
        us.push(u);
        sv.push(s);
        vs.push(v);
    }
    let (mut um, mut vm) = (Mat::zeros(n, n), Mat::zeros(n, n));
    for (j, (u, v)) in us.iter().zip(&vs).enumerate() {
        for i in 0..n {
            um.set(i, j, u[i]);
            vm.set(i, j, v[i]);
        }
    }
    (um, sv, vm)
}

/// Orthogonal Procrustes: the orthogonal `R` minimizing `‖x·R − y‖_F` (rows of `x` and `y`
/// are vectors, matching how the index applies rotations).
///
/// With `xᵀy = U Σ Vᵀ`, the minimizer is `R = U Vᵀ`. The transpose `V Uᵀ` is the optimum for
/// the column-vector form `‖R·xᵀ − yᵀ‖_F`; applied to row vectors it rotates the wrong way.
fn procrustes(x: &Mat, y: &Mat) -> Mat {
    let m = x.transpose().mul(y);
    let (u, _s, v) = svd_full(&m, 100);
    u.mul(&v.transpose())
}
144
145#[derive(Debug, Clone, Serialize, Deserialize)]
149pub struct RotationMatrix { pub dim: usize, pub data: Vec<f32> }
150
151impl RotationMatrix {
152 pub fn identity(dim: usize) -> Self {
154 let mut data = vec![0.0; dim * dim];
155 for i in 0..dim { data[i * dim + i] = 1.0; }
156 Self { dim, data }
157 }
158 pub fn rotate(&self, v: &[f32]) -> Vec<f32> {
160 let d = self.dim;
161 (0..d).map(|j| (0..d).map(|i| v[i] * self.data[i * d + j]).sum()).collect()
162 }
163 pub fn inverse_rotate(&self, v: &[f32]) -> Vec<f32> {
165 let d = self.dim;
166 (0..d).map(|j| (0..d).map(|i| v[i] * self.data[j * d + i]).sum()).collect()
167 }
168 fn from_mat(m: &Mat) -> Self { Self { dim: m.rows, data: m.data.clone() } }
169}
170
171#[derive(Debug, Clone, Serialize, Deserialize)]
175pub struct OPQIndex {
176 pub config: OPQConfig,
177 pub rotation: RotationMatrix,
178 pub codebooks: Vec<Vec<Vec<f32>>>,
180 pub dimensions: usize,
181}
182
183impl OPQIndex {
184 pub fn train(vectors: &[Vec<f32>], config: OPQConfig) -> Result<Self> {
186 config.validate()?;
187 if vectors.is_empty() {
188 return Err(RuvectorError::InvalidParameter("Training set cannot be empty".into()));
189 }
190 let d = vectors[0].len();
191 if d % config.num_subspaces != 0 {
192 return Err(RuvectorError::InvalidParameter(format!(
193 "Dimensions {} not divisible by num_subspaces {}", d, config.num_subspaces)));
194 }
195 for v in vectors { if v.len() != d {
196 return Err(RuvectorError::DimensionMismatch { expected: d, actual: v.len() });
197 }}
198 let x_mat = Mat::from_rows(vectors);
199 let mut r = Mat::identity(d);
200 let mut codebooks: Vec<Vec<Vec<f32>>> = Vec::new();
201 let sub_dim = d / config.num_subspaces;
202 for _ in 0..config.num_opq_iterations {
203 let x_rot = x_mat.mul(&r);
204 let rotated: Vec<Vec<f32>> = (0..vectors.len()).map(|i| x_rot.row(i)).collect();
205 codebooks = train_pq_codebooks(&rotated, config.num_subspaces,
206 config.codebook_size, config.num_iterations, config.metric)?;
207 let mut x_hat = Mat::zeros(vectors.len(), d);
208 for (i, rv) in rotated.iter().enumerate() {
209 let codes = encode_vec(rv, &codebooks, sub_dim, config.metric)?;
210 let recon = decode_vec(&codes, &codebooks);
211 for (j, &val) in recon.iter().enumerate() { x_hat.set(i, j, val); }
212 }
213 r = procrustes(&x_mat, &x_hat);
214 }
215 Ok(Self { config, rotation: RotationMatrix::from_mat(&r), codebooks, dimensions: d })
216 }
217
218 pub fn encode(&self, vector: &[f32]) -> Result<Vec<u8>> {
220 self.check_dim(vector.len())?;
221 let rotated = self.rotation.rotate(vector);
222 encode_vec(&rotated, &self.codebooks,
223 self.dimensions / self.config.num_subspaces, self.config.metric)
224 }
225
226 pub fn decode(&self, codes: &[u8]) -> Result<Vec<f32>> {
228 if codes.len() != self.config.num_subspaces {
229 return Err(RuvectorError::InvalidParameter(format!(
230 "Expected {} codes, got {}", self.config.num_subspaces, codes.len())));
231 }
232 Ok(self.rotation.inverse_rotate(&decode_vec(codes, &self.codebooks)))
233 }
234
235 pub fn search_adc(&self, query: &[f32], codes_db: &[Vec<u8>], top_k: usize,
237 ) -> Result<Vec<(usize, f32)>> {
238 self.check_dim(query.len())?;
239 let rq = self.rotation.rotate(query);
240 let sub_dim = self.dimensions / self.config.num_subspaces;
241 let tables: Vec<Vec<f32>> = (0..self.config.num_subspaces).map(|s| {
242 let q_sub = &rq[s * sub_dim..(s + 1) * sub_dim];
243 self.codebooks[s].iter().map(|c| dist(q_sub, c, self.config.metric)).collect()
244 }).collect();
245 let mut dists: Vec<(usize, f32)> = codes_db.iter().enumerate().map(|(idx, codes)| {
246 let d: f32 = codes.iter().enumerate().map(|(s, &c)| tables[s][c as usize]).sum();
247 (idx, d)
248 }).collect();
249 dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
250 dists.truncate(top_k);
251 Ok(dists)
252 }
253
254 pub fn quantization_error(&self, vectors: &[Vec<f32>]) -> Result<f32> {
256 if vectors.is_empty() { return Ok(0.0); }
257 let mut total = 0.0f64;
258 for v in vectors {
259 let recon = self.decode(&self.encode(v)?)?;
260 total += v.iter().zip(&recon).map(|(a, b)| ((a - b) as f64).powi(2)).sum::<f64>();
261 }
262 Ok((total / vectors.len() as f64) as f32)
263 }
264
265 fn check_dim(&self, len: usize) -> Result<()> {
266 if len != self.dimensions {
267 Err(RuvectorError::DimensionMismatch { expected: self.dimensions, actual: len })
268 } else { Ok(()) }
269 }
270}
271
272fn dist(a: &[f32], b: &[f32], m: DistanceMetric) -> f32 {
275 match m {
276 DistanceMetric::Euclidean =>
277 a.iter().zip(b).map(|(x, y)| { let d = x - y; d * d }).sum::<f32>().sqrt(),
278 DistanceMetric::Cosine => {
279 let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
280 let na = a.iter().map(|x| x * x).sum::<f32>().sqrt();
281 let nb = b.iter().map(|x| x * x).sum::<f32>().sqrt();
282 if na == 0.0 || nb == 0.0 { 1.0 } else { 1.0 - dot / (na * nb) }
283 }
284 DistanceMetric::DotProduct => -a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>(),
285 DistanceMetric::Manhattan => a.iter().zip(b).map(|(x, y)| (x - y).abs()).sum(),
286 }
287}
288
289fn train_pq_codebooks(vecs: &[Vec<f32>], nsub: usize, k: usize, iters: usize,
290 metric: DistanceMetric) -> Result<Vec<Vec<Vec<f32>>>> {
291 let sub_dim = vecs[0].len() / nsub;
292 (0..nsub).map(|s| {
293 let sv: Vec<Vec<f32>> = vecs.iter().map(|v| v[s*sub_dim..(s+1)*sub_dim].to_vec()).collect();
294 kmeans(&sv, k.min(sv.len()), iters, metric)
295 }).collect()
296}
297
298fn encode_vec(v: &[f32], cbs: &[Vec<Vec<f32>>], sub_dim: usize, m: DistanceMetric,
299) -> Result<Vec<u8>> {
300 cbs.iter().enumerate().map(|(s, cb)| {
301 let sub = &v[s * sub_dim..(s + 1) * sub_dim];
302 cb.iter().enumerate()
303 .min_by(|(_, a), (_, b)| dist(sub, a, m).partial_cmp(&dist(sub, b, m)).unwrap())
304 .map(|(i, _)| i as u8)
305 .ok_or_else(|| RuvectorError::Internal("Empty codebook".into()))
306 }).collect()
307}
308
/// Concatenates, for each subspace `s`, the centroid `cbs[s][codes[s]]`.
///
/// Panics if there are more codes than codebooks, or if a code has no centroid.
fn decode_vec(codes: &[u8], cbs: &[Vec<Vec<f32>>]) -> Vec<f32> {
    let mut out = Vec::new();
    for (s, &code) in codes.iter().enumerate() {
        out.extend_from_slice(&cbs[s][usize::from(code)]);
    }
    out
}
312
313fn kmeans(vecs: &[Vec<f32>], k: usize, iters: usize, metric: DistanceMetric,
314) -> Result<Vec<Vec<f32>>> {
315 use rand::seq::SliceRandom;
316 if vecs.is_empty() || k == 0 {
317 return Err(RuvectorError::InvalidParameter("Cannot cluster empty set or k=0".into()));
318 }
319 let dim = vecs[0].len();
320 let mut rng = rand::thread_rng();
321 let mut cents: Vec<Vec<f32>> = vecs.choose_multiple(&mut rng, k).cloned().collect();
322 for _ in 0..iters {
323 let (mut sums, mut counts) = (vec![vec![0.0f32; dim]; k], vec![0usize; k]);
324 for v in vecs {
325 let b = cents.iter().enumerate()
326 .min_by(|(_, a), (_, b)| dist(v, a, metric).partial_cmp(&dist(v, b, metric)).unwrap())
327 .map(|(i, _)| i).unwrap_or(0);
328 counts[b] += 1;
329 for (j, &val) in v.iter().enumerate() { sums[b][j] += val; }
330 }
331 for (i, c) in cents.iter_mut().enumerate() {
332 if counts[i] > 0 { for j in 0..dim { c[j] = sums[i][j] / counts[i] as f32; } }
333 }
334 }
335 Ok(cents)
336}
337
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random `n × d` data with components in [-1, 1].
    ///
    /// Uses a 64-bit LCG and keeps the top 31 bits of the state, i.e. values in `[0, 2^31)`.
    /// They are scaled by `2^31`. Dividing by `u32::MAX` would only reach `[0, 0.5)`, which
    /// would make every component negative.
    fn make_data(n: usize, d: usize) -> Vec<Vec<f32>> {
        let mut seed: u64 = 42;
        (0..n).map(|_| (0..d).map(|_| {
            seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
            ((seed >> 33) as f32) / ((1u64 << 31) as f32) * 2.0 - 1.0
        }).collect()).collect()
    }

    /// Small configuration: 2 subspaces × 4 centroids, few iterations.
    fn cfg() -> OPQConfig {
        OPQConfig { num_subspaces: 2, codebook_size: 4, num_iterations: 5,
            num_opq_iterations: 3, metric: DistanceMetric::Euclidean }
    }

    #[test]
    fn test_rotation_orthogonality() {
        let r = RotationMatrix::identity(4);
        let v = vec![1.0, 2.0, 3.0, 4.0];
        let back = r.inverse_rotate(&r.rotate(&v));
        for i in 0..4 { assert!((v[i] - back[i]).abs() < 1e-6); }
    }

    #[test]
    fn test_rotation_preserves_norm() {
        let data = make_data(30, 4);
        let idx = OPQIndex::train(&data, cfg()).unwrap();
        let v = vec![1.0, 2.0, 3.0, 4.0];
        let n1: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        let n2: f32 = idx.rotation.rotate(&v).iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((n1 - n2).abs() < 0.1, "norms: {} vs {}", n1, n2);
    }

    #[test]
    fn test_pq_encoding_roundtrip() {
        let data = make_data(30, 4);
        let idx = OPQIndex::train(&data, cfg()).unwrap();
        let codes = idx.encode(&data[0]).unwrap();
        assert_eq!(codes.len(), 2);
        assert_eq!(idx.decode(&codes).unwrap().len(), 4);
    }

    #[test]
    fn test_opq_training_convergence() {
        let data = make_data(100, 4);
        let idx = OPQIndex::train(&data, cfg()).unwrap();
        let err = idx.quantization_error(&data).unwrap();
        assert!(err.is_finite() && err >= 0.0, "error must be finite non-negative: {}", err);
        // Every training vector must survive an encode/decode round trip with finite output.
        for v in &data {
            let codes = idx.encode(v).unwrap();
            let decoded = idx.decode(&codes).unwrap();
            assert_eq!(decoded.len(), v.len());
            for x in &decoded { assert!(x.is_finite()); }
        }
    }

    #[test]
    fn test_adc_correctness() {
        let data = make_data(30, 4);
        let idx = OPQIndex::train(&data, cfg()).unwrap();
        let db: Vec<Vec<u8>> = data.iter().map(|v| idx.encode(v).unwrap()).collect();
        let res = idx.search_adc(&[0.5, -0.5, 0.5, -0.5], &db, 3).unwrap();
        assert_eq!(res.len(), 3);
        // Results must be in ascending distance order.
        for w in res.windows(2) { assert!(w[0].1 <= w[1].1 + 1e-6); }
    }

    #[test]
    fn test_quantization_error_reduction() {
        let data = make_data(50, 4);
        let err = OPQIndex::train(&data, cfg()).unwrap().quantization_error(&data).unwrap();
        assert!(err >= 0.0 && err.is_finite() && err < 10.0, "err={}", err);
    }

    #[test]
    fn test_svd_correctness() {
        let a = Mat { rows: 2, cols: 2, data: vec![3.0, 0.0, 0.0, 2.0] };
        let (u, s, v) = svd_full(&a, 200);
        // U · diag(s) · Vᵀ must reproduce `a`.
        for i in 0..2 { for j in 0..2 {
            let r: f32 = (0..2).map(|k| u.get(i, k) * s[k] * v.get(j, k)).sum();
            assert!((a.get(i, j) - r).abs() < 0.1, "SVD fail ({},{}): {} vs {}", i, j, a.get(i, j), r);
        }}
    }

    #[test]
    fn test_identity_rotation_baseline() {
        let data = make_data(30, 4);
        let idx = OPQIndex::train(&data, OPQConfig { num_opq_iterations: 1, ..cfg() }).unwrap();
        let recon = idx.decode(&idx.encode(&data[0]).unwrap()).unwrap();
        assert_eq!(recon.len(), data[0].len());
    }

    #[test]
    fn test_search_accuracy() {
        let data = make_data(40, 4);
        let idx = OPQIndex::train(&data, cfg()).unwrap();
        let db: Vec<Vec<u8>> = data.iter().map(|v| idx.encode(v).unwrap()).collect();
        let ids: Vec<usize> = idx.search_adc(&data[0], &db, 5).unwrap().iter().map(|r| r.0).collect();
        assert!(ids.contains(&0), "vector 0 should be in its own top-5");
    }

    #[test]
    fn test_config_validation() {
        assert!(OPQConfig { codebook_size: 300, ..cfg() }.validate().is_err());
        assert!(OPQConfig { num_subspaces: 0, ..cfg() }.validate().is_err());
        assert!(OPQConfig { num_opq_iterations: 0, ..cfg() }.validate().is_err());
    }

    #[test]
    fn test_dimension_mismatch_errors() {
        let idx = OPQIndex::train(&make_data(30, 4), cfg()).unwrap();
        assert!(idx.encode(&[1.0, 2.0]).is_err());
        assert!(idx.search_adc(&[1.0], &[], 1).is_err());
    }
}