1use std::cmp::Ordering;
8use std::collections::BinaryHeap;
9
10use serde::{Deserialize, Serialize};
11
12use crate::{CoreError, CoreResult, VectorValue};
13
14mod kernels;
15mod turbo_quant;
16
17use kernels::{
18 cosine_distance, cosine_distance_with_lhs_norm, cosine_distance_with_norms, dot,
19 squared_euclidean, validate_precomputed_squared_norm,
20};
21pub use turbo_quant::{
22 TURBO_QUANT_BLOCK_ROWS, TurboQuantBitWidth, TurboQuantBlockedCodes, TurboQuantCodebook,
23 TurboQuantCodebookKind, TurboQuantCodecError, TurboQuantCodecResult, TurboQuantPackedCodes,
24};
25
26#[derive(
32 Clone,
33 Copy,
34 Debug,
35 Deserialize,
36 Eq,
37 Hash,
38 PartialEq,
39 rkyv::Archive,
40 rkyv::Deserialize,
41 rkyv::Serialize,
42 Serialize,
43)]
44pub enum VectorMetric {
45 SquaredEuclidean,
47 Cosine,
49 NegativeInnerProduct,
51}
52
53impl VectorMetric {
54 pub fn bind_query(self, query: &VectorValue) -> CoreResult<VectorMetricQuery<'_>> {
65 VectorMetricQuery::new(self, query)
66 }
67
68 pub fn bind_query_with_squared_norm(
81 self,
82 query: &VectorValue,
83 query_squared_norm: f64,
84 ) -> CoreResult<VectorMetricQuery<'_>> {
85 VectorMetricQuery::new_with_squared_norm(self, query, query_squared_norm)
86 }
87
88 pub fn distance(self, lhs: &VectorValue, rhs: &VectorValue) -> CoreResult<f64> {
96 let lhs = lhs.as_slice();
97 let rhs = rhs.as_slice();
98 check_same_dimension(lhs.len(), rhs.len())?;
99 Ok(canonical_score(match self {
100 Self::SquaredEuclidean => squared_euclidean(lhs, rhs),
101 Self::Cosine => cosine_distance(lhs, rhs)?,
102 Self::NegativeInnerProduct => -dot(lhs, rhs),
103 }))
104 }
105}
106
107#[derive(Clone, Copy, Debug)]
113pub struct VectorMetricQuery<'a> {
114 metric: VectorMetric,
115 query: &'a VectorValue,
116 query_norm: Option<f64>,
117}
118
119impl<'a> VectorMetricQuery<'a> {
120 fn new(metric: VectorMetric, query: &'a VectorValue) -> CoreResult<Self> {
121 let query_norm = match metric {
122 VectorMetric::SquaredEuclidean | VectorMetric::NegativeInnerProduct => None,
123 VectorMetric::Cosine => {
124 let norm = dot(query.as_slice(), query.as_slice());
125 if norm == 0.0 {
126 return Err(CoreError::VectorZeroNorm { side: "lhs" });
127 }
128 Some(norm)
129 }
130 };
131 Ok(Self {
132 metric,
133 query,
134 query_norm,
135 })
136 }
137
138 fn new_with_squared_norm(
139 metric: VectorMetric,
140 query: &'a VectorValue,
141 query_squared_norm: f64,
142 ) -> CoreResult<Self> {
143 let query_norm = match metric {
144 VectorMetric::SquaredEuclidean | VectorMetric::NegativeInnerProduct => None,
145 VectorMetric::Cosine => Some(validate_precomputed_squared_norm(
146 query_squared_norm,
147 "lhs",
148 )?),
149 };
150 Ok(Self {
151 metric,
152 query,
153 query_norm,
154 })
155 }
156
157 #[must_use]
159 pub const fn metric(&self) -> VectorMetric {
160 self.metric
161 }
162
163 #[must_use]
165 pub const fn query(&self) -> &'a VectorValue {
166 self.query
167 }
168
169 pub fn distance(&self, candidate: &VectorValue) -> CoreResult<f64> {
177 let query = self.query.as_slice();
178 let candidate = candidate.as_slice();
179 check_same_dimension(query.len(), candidate.len())?;
180 Ok(canonical_score(match self.metric {
181 VectorMetric::SquaredEuclidean => squared_euclidean(query, candidate),
182 VectorMetric::Cosine => cosine_distance_with_lhs_norm(
183 query,
184 candidate,
185 self.query_norm
186 .expect("cosine query scorer stores query norm"),
187 )?,
188 VectorMetric::NegativeInnerProduct => -dot(query, candidate),
189 }))
190 }
191
192 pub fn distance_with_candidate_squared_norm(
206 &self,
207 candidate: &VectorValue,
208 candidate_squared_norm: f64,
209 ) -> CoreResult<f64> {
210 let query = self.query.as_slice();
211 let candidate = candidate.as_slice();
212 check_same_dimension(query.len(), candidate.len())?;
213 Ok(canonical_score(match self.metric {
214 VectorMetric::SquaredEuclidean => squared_euclidean(query, candidate),
215 VectorMetric::Cosine => cosine_distance_with_norms(
216 query,
217 candidate,
218 self.query_norm
219 .expect("cosine query scorer stores query norm"),
220 candidate_squared_norm,
221 )?,
222 VectorMetric::NegativeInnerProduct => -dot(query, candidate),
223 }))
224 }
225}
226
/// One search result: a candidate key together with its score under the query's metric.
#[derive(Debug, Clone, PartialEq)]
pub struct VectorSearchHit<K> {
    /// Caller-supplied identifier of the matched candidate.
    pub key: K,
    /// Metric score of the candidate. Lower means closer.
    pub distance: f64,
}
235
236#[derive(Debug)]
242pub struct VectorTopK<K> {
243 k: usize,
244 heap: BinaryHeap<HeapEntry<K>>,
245}
246
247impl<K: Ord> VectorTopK<K> {
248 #[must_use]
250 pub fn new(k: usize) -> Self {
251 Self {
252 k,
253 heap: BinaryHeap::with_capacity(k),
254 }
255 }
256
257 #[must_use]
259 pub const fn k(&self) -> usize {
260 self.k
261 }
262
263 #[must_use]
265 pub fn len(&self) -> usize {
266 self.heap.len()
267 }
268
269 #[must_use]
271 pub fn is_empty(&self) -> bool {
272 self.heap.is_empty()
273 }
274
275 pub fn push_distance(&mut self, key: K, distance: f64) {
281 debug_assert!(distance.is_finite(), "VectorTopK distances must be finite");
282 if self.k == 0 {
283 return;
284 }
285 let entry = HeapEntry { distance, key };
286 if self.heap.len() < self.k {
287 self.heap.push(entry);
288 return;
289 }
290 let Some(mut worst) = self.heap.peek_mut() else {
291 return;
292 };
293 if entry.cmp(&*worst).is_lt() {
294 *worst = entry;
295 }
296 }
297
298 #[must_use]
300 pub fn into_hits(self) -> Vec<VectorSearchHit<K>> {
301 let mut hits: Vec<_> = self
302 .heap
303 .into_iter()
304 .map(|entry| VectorSearchHit {
305 key: entry.key,
306 distance: entry.distance,
307 })
308 .collect();
309 hits.sort_by(compare_hit);
310 hits
311 }
312}
313
314pub fn exact_vector_top_k<'a, K, I>(
327 metric: VectorMetric,
328 query: &VectorValue,
329 candidates: I,
330 k: usize,
331) -> CoreResult<Vec<VectorSearchHit<K>>>
332where
333 K: Ord,
334 I: IntoIterator<Item = (K, &'a VectorValue)>,
335{
336 if k == 0 {
337 return Ok(Vec::new());
338 }
339
340 let mut top_k = VectorTopK::new(k);
341 let scorer = metric.bind_query(query)?;
342 for (key, vector) in candidates {
343 let distance = scorer.distance(vector)?;
344 top_k.push_distance(key, distance);
345 }
346
347 Ok(top_k.into_hits())
348}
349
350#[must_use]
356pub fn vector_squared_norm(vector: &VectorValue) -> f64 {
357 dot(vector.as_slice(), vector.as_slice())
358}
359
/// Internal heap node pairing a candidate key with its score.
///
/// Entries are ordered by `distance` (using `f64::total_cmp`) and then by `key`. Inside a
/// `BinaryHeap`, which is a max-heap, the top entry is therefore the worst retained hit.
#[derive(Debug)]
struct HeapEntry<K> {
    distance: f64,
    key: K,
}

impl<K: Eq> PartialEq for HeapEntry<K> {
    /// Compares distances bitwise, so `-0.0 != 0.0`, which agrees with `total_cmp`. Keys are
    /// compared only when the distances match.
    fn eq(&self, rhs: &Self) -> bool {
        let same_distance_bits = self.distance.to_bits() == rhs.distance.to_bits();
        same_distance_bits && self.key == rhs.key
    }
}

impl<K: Eq> Eq for HeapEntry<K> {}

impl<K: Ord> PartialOrd for HeapEntry<K> {
    fn partial_cmp(&self, rhs: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, rhs))
    }
}

impl<K: Ord> Ord for HeapEntry<K> {
    /// Total order: distance first (via `total_cmp`), key as the tie-breaker.
    fn cmp(&self, rhs: &Self) -> Ordering {
        match self.distance.total_cmp(&rhs.distance) {
            Ordering::Equal => self.key.cmp(&rhs.key),
            decided => decided,
        }
    }
}
387
388fn compare_hit<K: Ord>(lhs: &VectorSearchHit<K>, rhs: &VectorSearchHit<K>) -> Ordering {
389 lhs.distance
390 .total_cmp(&rhs.distance)
391 .then_with(|| lhs.key.cmp(&rhs.key))
392}
393
394fn check_same_dimension(lhs: usize, rhs: usize) -> CoreResult<()> {
395 if lhs == rhs {
396 Ok(())
397 } else {
398 Err(CoreError::VectorDimensionMismatch { lhs, rhs })
399 }
400}
401
/// Folds `-0.0` into `+0.0` so that equal scores share one bit pattern. All other values,
/// including NaN, pass through unchanged.
///
/// Under IEEE 754, `-0.0 == 0.0`, so only the zero case is rewritten. NaN compares unequal
/// to everything and is therefore returned as-is.
fn canonical_score(score: f64) -> f64 {
    if score != 0.0 { score } else { 0.0 }
}
405
// Unit tests pinning the metric formulas, the bound-query error contract, and the
// ordering and retention semantics of the top-k helpers.
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a `VectorValue` from literal components. Panics if the project's validation
    // rejects them, which is not expected for the inputs used below.
    fn vector(components: &[f32]) -> VectorValue {
        VectorValue::new(components.to_vec()).expect("test vector is valid")
    }

    // Differences are (0, -2, 4), so the squared distance is 0 + 4 + 16 = 20.
    #[test]
    fn squared_euclidean_uses_f64_accumulation() {
        let lhs = vector(&[1.0, 2.0, 3.0]);
        let rhs = vector(&[1.0, 4.0, -1.0]);
        let distance = VectorMetric::SquaredEuclidean
            .distance(&lhs, &rhs)
            .expect("dimensions match");
        assert_eq!(distance, 20.0);
    }

    // Dot products are 1 and 6. Negating them makes the larger product the smaller score.
    #[test]
    fn negative_inner_product_is_lower_for_larger_dot_product() {
        let query = vector(&[1.0, 2.0]);
        let low_dot = vector(&[1.0, 0.0]);
        let high_dot = vector(&[2.0, 2.0]);

        let low_score = VectorMetric::NegativeInnerProduct
            .distance(&query, &low_dot)
            .expect("dimensions match");
        let high_score = VectorMetric::NegativeInnerProduct
            .distance(&query, &high_dot)
            .expect("dimensions match");

        assert!(high_score < low_score);
        assert_eq!(low_score, -1.0);
        assert_eq!(high_score, -6.0);
    }

    // The dot product here is zero, so its negation is `-0.0`. The public result must be
    // `+0.0` bit-for-bit.
    #[test]
    fn metric_distance_canonicalizes_signed_zero_scores() {
        let lhs = vector(&[0.0, -0.0]);
        let rhs = vector(&[1.0, -1.0]);

        let distance = VectorMetric::NegativeInnerProduct
            .distance(&lhs, &rhs)
            .expect("dimensions match");

        assert_eq!(distance.to_bits(), 0.0_f64.to_bits());
    }

    // Cosine distance is 0 for parallel vectors (regardless of magnitude) and 2 for
    // anti-parallel vectors.
    #[test]
    fn cosine_distance_handles_identical_and_opposite_vectors() {
        let lhs = vector(&[1.0, 0.0]);
        let same = vector(&[2.0, 0.0]);
        let opposite = vector(&[-1.0, 0.0]);

        assert_eq!(VectorMetric::Cosine.distance(&lhs, &same).unwrap(), 0.0);
        assert_eq!(VectorMetric::Cosine.distance(&lhs, &opposite).unwrap(), 2.0);
    }

    // A bound scorer must agree exactly with the one-shot `VectorMetric::distance`
    // for every metric.
    #[test]
    fn bound_query_scores_match_one_off_distance() {
        let query = vector(&[1.0, 2.0, 3.0]);
        let candidate = vector(&[4.0, 5.0, 6.0]);

        for metric in [
            VectorMetric::SquaredEuclidean,
            VectorMetric::Cosine,
            VectorMetric::NegativeInnerProduct,
        ] {
            let scorer = metric.bind_query(&query).unwrap();
            assert_eq!(scorer.metric(), metric);
            assert_eq!(scorer.query(), &query);
            assert_eq!(
                scorer.distance(&candidate).unwrap(),
                metric.distance(&query, &candidate).unwrap()
            );
        }
    }

    // Supplying the candidate's squared norm up front must not change the cosine score.
    #[test]
    fn bound_query_accepts_precomputed_candidate_norm() {
        let query = vector(&[1.0, 2.0, 3.0]);
        let candidate = vector(&[4.0, 5.0, 6.0]);
        let candidate_norm = dot(candidate.as_slice(), candidate.as_slice());

        let scorer = VectorMetric::Cosine.bind_query(&query).unwrap();

        assert_eq!(
            scorer
                .distance_with_candidate_squared_norm(&candidate, candidate_norm)
                .unwrap(),
            scorer.distance(&candidate).unwrap()
        );
    }

    // Supplying the query's squared norm up front must not change the cosine score.
    #[test]
    fn bind_query_accepts_precomputed_query_norm() {
        let query = vector(&[1.0, 2.0, 3.0]);
        let candidate = vector(&[4.0, 5.0, 6.0]);
        let query_norm = dot(query.as_slice(), query.as_slice());

        let scorer = VectorMetric::Cosine
            .bind_query_with_squared_norm(&query, query_norm)
            .unwrap();

        assert_eq!(
            scorer.distance(&candidate).unwrap(),
            VectorMetric::Cosine
                .bind_query(&query)
                .unwrap()
                .distance(&candidate)
                .unwrap()
        );
    }

    // 1 + 4 + 12.25 = 17.25.
    #[test]
    fn vector_squared_norm_matches_component_sum() {
        let vector = vector(&[1.0, -2.0, 3.5]);

        assert_eq!(vector_squared_norm(&vector), 17.25);
    }

    // Zero-norm and invalid-norm errors name the side they came from. The bound query is
    // always "lhs" and the candidate is always "rhs", whether the norm was computed or
    // supplied (zero, NaN and negative supplied norms are all rejected).
    #[test]
    fn bound_cosine_query_preserves_zero_norm_error_sides() {
        let zero = vector(&[0.0, 0.0]);
        let rhs = vector(&[1.0, 0.0]);

        let error = VectorMetric::Cosine.bind_query(&zero).unwrap_err();
        assert!(matches!(error, CoreError::VectorZeroNorm { side: "lhs" }));
        let error = VectorMetric::Cosine
            .bind_query_with_squared_norm(&rhs, 0.0)
            .unwrap_err();
        assert!(matches!(error, CoreError::VectorZeroNorm { side: "lhs" }));
        let error = VectorMetric::Cosine
            .bind_query_with_squared_norm(&rhs, f64::NAN)
            .unwrap_err();
        assert!(matches!(error, CoreError::VectorZeroNorm { side: "lhs" }));

        let scorer = VectorMetric::Cosine.bind_query(&rhs).unwrap();
        let error = scorer.distance(&zero).unwrap_err();
        assert!(matches!(error, CoreError::VectorZeroNorm { side: "rhs" }));

        let error = scorer
            .distance_with_candidate_squared_norm(&rhs, 0.0)
            .unwrap_err();
        assert!(matches!(error, CoreError::VectorZeroNorm { side: "rhs" }));
        let error = scorer
            .distance_with_candidate_squared_norm(&rhs, -1.0)
            .unwrap_err();
        assert!(matches!(error, CoreError::VectorZeroNorm { side: "rhs" }));
    }

    // The one-shot cosine distance reports which argument has zero norm.
    #[test]
    fn cosine_rejects_zero_norm_vectors() {
        let zero = vector(&[0.0, 0.0]);
        let rhs = vector(&[1.0, 0.0]);

        let error = VectorMetric::Cosine.distance(&zero, &rhs).unwrap_err();
        assert!(matches!(error, CoreError::VectorZeroNorm { side: "lhs" }));

        let error = VectorMetric::Cosine.distance(&rhs, &zero).unwrap_err();
        assert!(matches!(error, CoreError::VectorZeroNorm { side: "rhs" }));
    }

    // Mismatched lengths are reported as (lhs_len, rhs_len).
    #[test]
    fn distance_rejects_dimension_mismatch() {
        let lhs = vector(&[1.0, 2.0]);
        let rhs = vector(&[1.0, 2.0, 3.0]);

        let error = VectorMetric::SquaredEuclidean
            .distance(&lhs, &rhs)
            .unwrap_err();
        assert!(matches!(
            error,
            CoreError::VectorDimensionMismatch { lhs: 2, rhs: 3 }
        ));
    }

    // The query here is a zero vector, which cosine binding would reject. Getting `Ok` proves
    // that `k == 0` short-circuits before the query is bound or candidates are scored.
    #[test]
    fn exact_top_k_returns_empty_for_zero_k() {
        let query = vector(&[0.0]);
        let candidate = vector(&[1.0]);
        let candidates = [(7_u64, &candidate)];

        let hits = exact_vector_top_k(VectorMetric::Cosine, &query, candidates, 0)
            .expect("zero k does not inspect candidates");

        assert!(hits.is_empty());
    }

    // With k = 2: key 3 (0.25) is evicted by the equal-distance, smaller key 1. Key 2 (0.5)
    // is rejected. Key 4 (0.1) evicts the worst. Output is ascending by distance.
    #[test]
    fn vector_top_k_streams_and_orders_hits() {
        let mut top_k = VectorTopK::new(2);
        top_k.push_distance(3_u64, 0.25);
        top_k.push_distance(1, 0.25);
        top_k.push_distance(2, 0.5);
        top_k.push_distance(4, 0.1);

        assert_eq!(top_k.k(), 2);
        assert_eq!(top_k.len(), 2);
        assert_eq!(
            top_k.into_hits(),
            vec![
                VectorSearchHit {
                    key: 4,
                    distance: 0.1
                },
                VectorSearchHit {
                    key: 1,
                    distance: 0.25
                }
            ]
        );
    }

    // A zero-capacity collector silently discards every push.
    #[test]
    fn vector_top_k_zero_k_retains_nothing() {
        let mut top_k = VectorTopK::new(0);
        top_k.push_distance(1_u64, 0.0);

        assert!(top_k.is_empty());
        assert!(top_k.into_hits().is_empty());
    }

    // Keys 1 and 2 tie at distance 1 and key 3 is at distance 4. The tie is broken by
    // ascending key, regardless of input order.
    #[test]
    fn exact_top_k_is_distance_then_key_ordered() {
        let query = vector(&[0.0]);
        let one = vector(&[1.0]);
        let two = vector(&[2.0]);
        let candidates = [(3_u64, &two), (2, &one), (1, &one)];

        let hits = exact_vector_top_k(VectorMetric::SquaredEuclidean, &query, candidates, 2)
            .expect("all dimensions match");

        assert_eq!(
            hits,
            vec![
                VectorSearchHit {
                    key: 1,
                    distance: 1.0
                },
                VectorSearchHit {
                    key: 2,
                    distance: 1.0
                }
            ]
        );
    }

    // A per-candidate scoring error (here a dimension mismatch) is propagated, not skipped.
    #[test]
    fn exact_top_k_surfaces_candidate_metric_errors() {
        let query = vector(&[0.0]);
        let candidate = vector(&[1.0, 2.0]);
        let candidates = [(1_u64, &candidate)];

        let error =
            exact_vector_top_k(VectorMetric::SquaredEuclidean, &query, candidates, 10).unwrap_err();

        assert!(matches!(
            error,
            CoreError::VectorDimensionMismatch { lhs: 1, rhs: 2 }
        ));
    }
}