Skip to main content

selene_core/
vector.rs

1//! Native dense-vector metric kernels and exact-search helpers.
2//!
3//! The ANN index layer builds on these primitives so approximate indexes and
4//! exhaustive recall oracles share one definition of distance, tie-breaking,
5//! and vector validity.
6
7use std::cmp::Ordering;
8use std::collections::BinaryHeap;
9
10use serde::{Deserialize, Serialize};
11
12use crate::{CoreError, CoreResult, VectorValue};
13
14mod kernels;
15mod turbo_quant;
16
17use kernels::{
18    cosine_distance, cosine_distance_with_lhs_norm, cosine_distance_with_norms, dot,
19    squared_euclidean, validate_precomputed_squared_norm,
20};
21pub use turbo_quant::{
22    TURBO_QUANT_BLOCK_ROWS, TurboQuantBitWidth, TurboQuantBlockedCodes, TurboQuantCodebook,
23    TurboQuantCodebookKind, TurboQuantCodecError, TurboQuantCodecResult, TurboQuantPackedCodes,
24};
25
26/// Distance metric for native dense vectors.
27///
28/// All metrics return a score where **lower is better**. `NegativeInnerProduct`
29/// is the max-inner-product-search adapter: vectors with larger dot products
30/// produce smaller, more favorable scores.
31#[derive(
32    Clone,
33    Copy,
34    Debug,
35    Deserialize,
36    Eq,
37    Hash,
38    PartialEq,
39    rkyv::Archive,
40    rkyv::Deserialize,
41    rkyv::Serialize,
42    Serialize,
43)]
44pub enum VectorMetric {
45    /// Squared Euclidean distance (`sum((lhs_i - rhs_i)^2)`).
46    SquaredEuclidean,
47    /// Cosine distance (`1 - cosine_similarity`).
48    Cosine,
49    /// Negated dot product (`-sum(lhs_i * rhs_i)`), lower-is-better MIPS.
50    NegativeInnerProduct,
51}
52
53impl VectorMetric {
54    /// Bind this metric to one query vector for repeated candidate scoring.
55    ///
56    /// This precomputes metric-specific query state, such as cosine query norm,
57    /// so exact scans and ANN traversals do not redo invariant work for every
58    /// candidate.
59    ///
60    /// # Errors
61    ///
62    /// [`VectorMetric::Cosine`] returns [`CoreError::VectorZeroNorm`] when the
63    /// query has zero magnitude.
64    pub fn bind_query(self, query: &VectorValue) -> CoreResult<VectorMetricQuery<'_>> {
65        VectorMetricQuery::new(self, query)
66    }
67
68    /// Bind this metric to one query vector with a precomputed query squared norm.
69    ///
70    /// When `query_squared_norm` is the query's actual squared norm, this is
71    /// equivalent to [`Self::bind_query`]. It lets ANN indexes cache entry
72    /// norms for cosine centroid assignment while preserving the canonical
73    /// metric kernels and error behavior. Non-cosine metrics ignore
74    /// `query_squared_norm`.
75    ///
76    /// # Errors
77    ///
78    /// [`VectorMetric::Cosine`] returns [`CoreError::VectorZeroNorm`] when the
79    /// supplied query squared norm is not positive and finite.
80    pub fn bind_query_with_squared_norm(
81        self,
82        query: &VectorValue,
83        query_squared_norm: f64,
84    ) -> CoreResult<VectorMetricQuery<'_>> {
85        VectorMetricQuery::new_with_squared_norm(self, query, query_squared_norm)
86    }
87
88    /// Compute this metric for two vectors.
89    ///
90    /// # Errors
91    ///
92    /// Returns [`CoreError::VectorDimensionMismatch`] if dimensions differ.
93    /// [`VectorMetric::Cosine`] also returns [`CoreError::VectorZeroNorm`]
94    /// when either vector has zero magnitude.
95    pub fn distance(self, lhs: &VectorValue, rhs: &VectorValue) -> CoreResult<f64> {
96        let lhs = lhs.as_slice();
97        let rhs = rhs.as_slice();
98        check_same_dimension(lhs.len(), rhs.len())?;
99        Ok(canonical_score(match self {
100            Self::SquaredEuclidean => squared_euclidean(lhs, rhs),
101            Self::Cosine => cosine_distance(lhs, rhs)?,
102            Self::NegativeInnerProduct => -dot(lhs, rhs),
103        }))
104    }
105}
106
107/// Metric scorer bound to a single query vector.
108///
109/// Use this when ranking many candidates against one query. It preserves the
110/// same scores and error contract as [`VectorMetric::distance`], but avoids
111/// recomputing query-only metric state for every candidate.
112#[derive(Clone, Copy, Debug)]
113pub struct VectorMetricQuery<'a> {
114    metric: VectorMetric,
115    query: &'a VectorValue,
116    query_norm: Option<f64>,
117}
118
119impl<'a> VectorMetricQuery<'a> {
120    fn new(metric: VectorMetric, query: &'a VectorValue) -> CoreResult<Self> {
121        let query_norm = match metric {
122            VectorMetric::SquaredEuclidean | VectorMetric::NegativeInnerProduct => None,
123            VectorMetric::Cosine => {
124                let norm = dot(query.as_slice(), query.as_slice());
125                if norm == 0.0 {
126                    return Err(CoreError::VectorZeroNorm { side: "lhs" });
127                }
128                Some(norm)
129            }
130        };
131        Ok(Self {
132            metric,
133            query,
134            query_norm,
135        })
136    }
137
138    fn new_with_squared_norm(
139        metric: VectorMetric,
140        query: &'a VectorValue,
141        query_squared_norm: f64,
142    ) -> CoreResult<Self> {
143        let query_norm = match metric {
144            VectorMetric::SquaredEuclidean | VectorMetric::NegativeInnerProduct => None,
145            VectorMetric::Cosine => Some(validate_precomputed_squared_norm(
146                query_squared_norm,
147                "lhs",
148            )?),
149        };
150        Ok(Self {
151            metric,
152            query,
153            query_norm,
154        })
155    }
156
157    /// Return the metric this scorer uses.
158    #[must_use]
159    pub const fn metric(&self) -> VectorMetric {
160        self.metric
161    }
162
163    /// Return the bound query vector.
164    #[must_use]
165    pub const fn query(&self) -> &'a VectorValue {
166        self.query
167    }
168
169    /// Compute this bound query's lower-is-better distance to `candidate`.
170    ///
171    /// # Errors
172    ///
173    /// Returns [`CoreError::VectorDimensionMismatch`] if dimensions differ.
174    /// [`VectorMetric::Cosine`] also returns [`CoreError::VectorZeroNorm`] when
175    /// `candidate` has zero magnitude.
176    pub fn distance(&self, candidate: &VectorValue) -> CoreResult<f64> {
177        let query = self.query.as_slice();
178        let candidate = candidate.as_slice();
179        check_same_dimension(query.len(), candidate.len())?;
180        Ok(canonical_score(match self.metric {
181            VectorMetric::SquaredEuclidean => squared_euclidean(query, candidate),
182            VectorMetric::Cosine => cosine_distance_with_lhs_norm(
183                query,
184                candidate,
185                self.query_norm
186                    .expect("cosine query scorer stores query norm"),
187            )?,
188            VectorMetric::NegativeInnerProduct => -dot(query, candidate),
189        }))
190    }
191
192    /// Compute distance using a precomputed candidate squared norm.
193    ///
194    /// When `candidate_squared_norm` is the candidate's actual squared norm,
195    /// this is equivalent to [`Self::distance`]. It lets ANN indexes cache
196    /// centroid norms for cosine scoring while still using the canonical metric
197    /// kernels and error behavior. Non-cosine metrics ignore
198    /// `candidate_squared_norm`.
199    ///
200    /// # Errors
201    ///
202    /// Returns [`CoreError::VectorDimensionMismatch`] if dimensions differ.
203    /// [`VectorMetric::Cosine`] returns [`CoreError::VectorZeroNorm`] when the
204    /// supplied candidate squared norm is not positive and finite.
205    pub fn distance_with_candidate_squared_norm(
206        &self,
207        candidate: &VectorValue,
208        candidate_squared_norm: f64,
209    ) -> CoreResult<f64> {
210        let query = self.query.as_slice();
211        let candidate = candidate.as_slice();
212        check_same_dimension(query.len(), candidate.len())?;
213        Ok(canonical_score(match self.metric {
214            VectorMetric::SquaredEuclidean => squared_euclidean(query, candidate),
215            VectorMetric::Cosine => cosine_distance_with_norms(
216                query,
217                candidate,
218                self.query_norm
219                    .expect("cosine query scorer stores query norm"),
220                candidate_squared_norm,
221            )?,
222            VectorMetric::NegativeInnerProduct => -dot(query, candidate),
223        }))
224    }
225}
226
/// One ranked result from an exact or streaming top-`k` vector search.
#[derive(Debug, Clone, PartialEq)]
pub struct VectorSearchHit<K> {
    /// Caller-supplied identifier of the candidate, such as a node id or row ordinal.
    pub key: K,
    /// Score under the requested [`VectorMetric`]; lower is better.
    pub distance: f64,
}
235
236/// Bounded deterministic lower-is-better vector hit accumulator.
237///
238/// This is the streaming form of [`exact_vector_top_k`]. It keeps only the
239/// current best `k` hits in memory, so graph-layer exact scans can avoid
240/// materializing every candidate before ranking.
241#[derive(Debug)]
242pub struct VectorTopK<K> {
243    k: usize,
244    heap: BinaryHeap<HeapEntry<K>>,
245}
246
247impl<K: Ord> VectorTopK<K> {
248    /// Construct an empty accumulator that will retain at most `k` hits.
249    #[must_use]
250    pub fn new(k: usize) -> Self {
251        Self {
252            k,
253            heap: BinaryHeap::with_capacity(k),
254        }
255    }
256
257    /// Return the configured result cap.
258    #[must_use]
259    pub const fn k(&self) -> usize {
260        self.k
261    }
262
263    /// Return the number of retained hits.
264    #[must_use]
265    pub fn len(&self) -> usize {
266        self.heap.len()
267    }
268
269    /// Return true when no hits are retained.
270    #[must_use]
271    pub fn is_empty(&self) -> bool {
272        self.heap.is_empty()
273    }
274
275    /// Push one candidate distance into the accumulator.
276    ///
277    /// `distance` must be a finite lower-is-better score produced by
278    /// [`VectorMetric::distance`] or an equivalent metric kernel. Ties are
279    /// deterministic: lower distance wins, then lower `key` wins.
280    pub fn push_distance(&mut self, key: K, distance: f64) {
281        debug_assert!(distance.is_finite(), "VectorTopK distances must be finite");
282        if self.k == 0 {
283            return;
284        }
285        let entry = HeapEntry { distance, key };
286        if self.heap.len() < self.k {
287            self.heap.push(entry);
288            return;
289        }
290        let Some(mut worst) = self.heap.peek_mut() else {
291            return;
292        };
293        if entry.cmp(&*worst).is_lt() {
294            *worst = entry;
295        }
296    }
297
298    /// Return retained hits sorted best-first.
299    #[must_use]
300    pub fn into_hits(self) -> Vec<VectorSearchHit<K>> {
301        let mut hits: Vec<_> = self
302            .heap
303            .into_iter()
304            .map(|entry| VectorSearchHit {
305                key: entry.key,
306                distance: entry.distance,
307            })
308            .collect();
309        hits.sort_by(compare_hit);
310        hits
311    }
312}
313
314/// Return the exact top-`k` nearest vector candidates.
315///
316/// This is intentionally a small exhaustive oracle, not an ANN index. Future
317/// HNSW/IVF/PQ implementations should use it for recall validation and for
318/// small result sets where index build cost cannot amortize.
319///
320/// Ties are deterministic: lower distance wins, then lower `key` wins.
321///
322/// # Errors
323///
324/// Returns a vector metric error if any candidate cannot be compared to
325/// `query` under `metric`.
326pub fn exact_vector_top_k<'a, K, I>(
327    metric: VectorMetric,
328    query: &VectorValue,
329    candidates: I,
330    k: usize,
331) -> CoreResult<Vec<VectorSearchHit<K>>>
332where
333    K: Ord,
334    I: IntoIterator<Item = (K, &'a VectorValue)>,
335{
336    if k == 0 {
337        return Ok(Vec::new());
338    }
339
340    let mut top_k = VectorTopK::new(k);
341    let scorer = metric.bind_query(query)?;
342    for (key, vector) in candidates {
343        let distance = scorer.distance(vector)?;
344        top_k.push_distance(key, distance);
345    }
346
347    Ok(top_k.into_hits())
348}
349
350/// Return `sum(component * component)` for a validated vector.
351///
352/// This is the shared squared-norm helper for ANN indexes that cache cosine
353/// query or candidate norms. It intentionally uses the same chunked dot-product
354/// kernel as the vector metric scorer.
355#[must_use]
356pub fn vector_squared_norm(vector: &VectorValue) -> f64 {
357    dot(vector.as_slice(), vector.as_slice())
358}
359
/// Heap slot used by `VectorTopK`: one candidate key and its score.
///
/// Ordering is "worse is greater": by distance via `f64::total_cmp`, then by
/// key. A max-`BinaryHeap` of these therefore keeps the worst retained hit at
/// its root, ready for eviction.
#[derive(Debug)]
struct HeapEntry<K> {
    distance: f64,
    key: K,
}

impl<K: Eq> Eq for HeapEntry<K> {}

impl<K: Eq> PartialEq for HeapEntry<K> {
    // Bitwise equality on `distance` agrees with `total_cmp` in `Ord`, so
    // `-0.0` and `+0.0` (or differing NaN payloads) are *not* equal here.
    fn eq(&self, rhs: &Self) -> bool {
        (self.distance.to_bits(), &self.key) == (rhs.distance.to_bits(), &rhs.key)
    }
}

impl<K: Ord> Ord for HeapEntry<K> {
    fn cmp(&self, rhs: &Self) -> Ordering {
        match self.distance.total_cmp(&rhs.distance) {
            Ordering::Equal => self.key.cmp(&rhs.key),
            decided => decided,
        }
    }
}

impl<K: Ord> PartialOrd for HeapEntry<K> {
    fn partial_cmp(&self, rhs: &Self) -> Option<Ordering> {
        Some(self.cmp(rhs))
    }
}
387
388fn compare_hit<K: Ord>(lhs: &VectorSearchHit<K>, rhs: &VectorSearchHit<K>) -> Ordering {
389    lhs.distance
390        .total_cmp(&rhs.distance)
391        .then_with(|| lhs.key.cmp(&rhs.key))
392}
393
394fn check_same_dimension(lhs: usize, rhs: usize) -> CoreResult<()> {
395    if lhs == rhs {
396        Ok(())
397    } else {
398        Err(CoreError::VectorDimensionMismatch { lhs, rhs })
399    }
400}
401
/// Normalize a raw metric score so that negative zero is reported as `+0.0`.
///
/// Kernels such as the negated dot product can produce `-0.0`. Folding it
/// into `+0.0` keeps bit-level comparisons (e.g. `f64::total_cmp`, which
/// orders `-0.0` before `+0.0`) consistent for equal distances. Every other
/// value, including NaN and infinities, passes through unchanged.
fn canonical_score(score: f64) -> f64 {
    match score {
        zero if zero == 0.0 => 0.0,
        other => other,
    }
}
405
#[cfg(test)]
mod tests {
    //! Unit tests for the metric kernels, bound query scorers, and top-k helpers.
    use super::*;

    fn vector(components: &[f32]) -> VectorValue {
        VectorValue::new(components.to_vec()).expect("test vector is valid")
    }

    #[test]
    fn squared_euclidean_uses_f64_accumulation() {
        let lhs = vector(&[1.0, 2.0, 3.0]);
        let rhs = vector(&[1.0, 4.0, -1.0]);
        let distance = VectorMetric::SquaredEuclidean
            .distance(&lhs, &rhs)
            .expect("dimensions match");
        assert_eq!(distance, 20.0);
    }

    #[test]
    fn negative_inner_product_is_lower_for_larger_dot_product() {
        let query = vector(&[1.0, 2.0]);
        let low_dot = vector(&[1.0, 0.0]);
        let high_dot = vector(&[2.0, 2.0]);

        let low_score = VectorMetric::NegativeInnerProduct
            .distance(&query, &low_dot)
            .expect("dimensions match");
        let high_score = VectorMetric::NegativeInnerProduct
            .distance(&query, &high_dot)
            .expect("dimensions match");

        assert!(high_score < low_score);
        assert_eq!(low_score, -1.0);
        assert_eq!(high_score, -6.0);
    }

    #[test]
    fn metric_distance_canonicalizes_signed_zero_scores() {
        let lhs = vector(&[0.0, -0.0]);
        let rhs = vector(&[1.0, -1.0]);

        let distance = VectorMetric::NegativeInnerProduct
            .distance(&lhs, &rhs)
            .expect("dimensions match");

        assert_eq!(distance.to_bits(), 0.0_f64.to_bits());
    }

    #[test]
    fn cosine_distance_handles_identical_and_opposite_vectors() {
        let lhs = vector(&[1.0, 0.0]);
        let same = vector(&[2.0, 0.0]);
        let opposite = vector(&[-1.0, 0.0]);

        assert_eq!(VectorMetric::Cosine.distance(&lhs, &same).unwrap(), 0.0);
        assert_eq!(VectorMetric::Cosine.distance(&lhs, &opposite).unwrap(), 2.0);
    }

    #[test]
    fn bound_query_scores_match_one_off_distance() {
        let query = vector(&[1.0, 2.0, 3.0]);
        let candidate = vector(&[4.0, 5.0, 6.0]);

        for metric in [
            VectorMetric::SquaredEuclidean,
            VectorMetric::Cosine,
            VectorMetric::NegativeInnerProduct,
        ] {
            let scorer = metric.bind_query(&query).unwrap();
            assert_eq!(scorer.metric(), metric);
            assert_eq!(scorer.query(), &query);
            assert_eq!(
                scorer.distance(&candidate).unwrap(),
                metric.distance(&query, &candidate).unwrap()
            );
        }
    }

    #[test]
    fn bound_query_accepts_precomputed_candidate_norm() {
        let query = vector(&[1.0, 2.0, 3.0]);
        let candidate = vector(&[4.0, 5.0, 6.0]);
        let candidate_norm = dot(candidate.as_slice(), candidate.as_slice());

        let scorer = VectorMetric::Cosine.bind_query(&query).unwrap();

        assert_eq!(
            scorer
                .distance_with_candidate_squared_norm(&candidate, candidate_norm)
                .unwrap(),
            scorer.distance(&candidate).unwrap()
        );
    }

    #[test]
    fn bind_query_accepts_precomputed_query_norm() {
        let query = vector(&[1.0, 2.0, 3.0]);
        let candidate = vector(&[4.0, 5.0, 6.0]);
        let query_norm = dot(query.as_slice(), query.as_slice());

        let scorer = VectorMetric::Cosine
            .bind_query_with_squared_norm(&query, query_norm)
            .unwrap();

        assert_eq!(
            scorer.distance(&candidate).unwrap(),
            VectorMetric::Cosine
                .bind_query(&query)
                .unwrap()
                .distance(&candidate)
                .unwrap()
        );
    }

    #[test]
    fn non_cosine_metrics_ignore_precomputed_norms() {
        let query = vector(&[1.0, 2.0, 3.0]);
        let candidate = vector(&[4.0, 5.0, 6.0]);

        for metric in [
            VectorMetric::SquaredEuclidean,
            VectorMetric::NegativeInnerProduct,
        ] {
            let scorer = metric
                .bind_query_with_squared_norm(&query, f64::NAN)
                .expect("non-cosine metrics ignore the query norm");
            let expected = metric.distance(&query, &candidate).unwrap();

            assert_eq!(scorer.distance(&candidate).unwrap(), expected);
            assert_eq!(
                scorer
                    .distance_with_candidate_squared_norm(&candidate, -1.0)
                    .unwrap(),
                expected
            );
        }
    }

    #[test]
    fn non_cosine_metrics_accept_zero_query() {
        let zero = vector(&[0.0, 0.0]);
        let candidate = vector(&[3.0, 4.0]);

        let scorer = VectorMetric::SquaredEuclidean.bind_query(&zero).unwrap();
        assert_eq!(scorer.distance(&candidate).unwrap(), 25.0);

        let scorer = VectorMetric::NegativeInnerProduct.bind_query(&zero).unwrap();
        assert_eq!(
            scorer.distance(&candidate).unwrap().to_bits(),
            0.0_f64.to_bits()
        );
    }

    #[test]
    fn vector_squared_norm_matches_component_sum() {
        let vector = vector(&[1.0, -2.0, 3.5]);

        assert_eq!(vector_squared_norm(&vector), 17.25);
    }

    #[test]
    fn bound_cosine_query_preserves_zero_norm_error_sides() {
        let zero = vector(&[0.0, 0.0]);
        let rhs = vector(&[1.0, 0.0]);

        let error = VectorMetric::Cosine.bind_query(&zero).unwrap_err();
        assert!(matches!(error, CoreError::VectorZeroNorm { side: "lhs" }));
        let error = VectorMetric::Cosine
            .bind_query_with_squared_norm(&rhs, 0.0)
            .unwrap_err();
        assert!(matches!(error, CoreError::VectorZeroNorm { side: "lhs" }));
        let error = VectorMetric::Cosine
            .bind_query_with_squared_norm(&rhs, f64::NAN)
            .unwrap_err();
        assert!(matches!(error, CoreError::VectorZeroNorm { side: "lhs" }));

        let scorer = VectorMetric::Cosine.bind_query(&rhs).unwrap();
        let error = scorer.distance(&zero).unwrap_err();
        assert!(matches!(error, CoreError::VectorZeroNorm { side: "rhs" }));

        let error = scorer
            .distance_with_candidate_squared_norm(&rhs, 0.0)
            .unwrap_err();
        assert!(matches!(error, CoreError::VectorZeroNorm { side: "rhs" }));
        let error = scorer
            .distance_with_candidate_squared_norm(&rhs, -1.0)
            .unwrap_err();
        assert!(matches!(error, CoreError::VectorZeroNorm { side: "rhs" }));
    }

    #[test]
    fn cosine_rejects_zero_norm_vectors() {
        let zero = vector(&[0.0, 0.0]);
        let rhs = vector(&[1.0, 0.0]);

        let error = VectorMetric::Cosine.distance(&zero, &rhs).unwrap_err();
        assert!(matches!(error, CoreError::VectorZeroNorm { side: "lhs" }));

        let error = VectorMetric::Cosine.distance(&rhs, &zero).unwrap_err();
        assert!(matches!(error, CoreError::VectorZeroNorm { side: "rhs" }));
    }

    #[test]
    fn distance_rejects_dimension_mismatch() {
        let lhs = vector(&[1.0, 2.0]);
        let rhs = vector(&[1.0, 2.0, 3.0]);

        let error = VectorMetric::SquaredEuclidean
            .distance(&lhs, &rhs)
            .unwrap_err();
        assert!(matches!(
            error,
            CoreError::VectorDimensionMismatch { lhs: 2, rhs: 3 }
        ));
    }

    #[test]
    fn exact_top_k_returns_empty_for_zero_k() {
        let query = vector(&[0.0]);
        let candidate = vector(&[1.0]);
        let candidates = [(7_u64, &candidate)];

        let hits = exact_vector_top_k(VectorMetric::Cosine, &query, candidates, 0)
            .expect("zero k does not inspect candidates");

        assert!(hits.is_empty());
    }

    #[test]
    fn exact_top_k_rejects_zero_norm_cosine_query_when_k_is_positive() {
        let query = vector(&[0.0]);
        let candidate = vector(&[1.0]);
        let candidates = [(7_u64, &candidate)];

        let error = exact_vector_top_k(VectorMetric::Cosine, &query, candidates, 1).unwrap_err();

        assert!(matches!(error, CoreError::VectorZeroNorm { side: "lhs" }));
    }

    #[test]
    fn vector_top_k_streams_and_orders_hits() {
        let mut top_k = VectorTopK::new(2);
        top_k.push_distance(3_u64, 0.25);
        top_k.push_distance(1, 0.25);
        top_k.push_distance(2, 0.5);
        top_k.push_distance(4, 0.1);

        assert_eq!(top_k.k(), 2);
        assert_eq!(top_k.len(), 2);
        assert_eq!(
            top_k.into_hits(),
            vec![
                VectorSearchHit {
                    key: 4,
                    distance: 0.1
                },
                VectorSearchHit {
                    key: 1,
                    distance: 0.25
                }
            ]
        );
    }

    #[test]
    fn vector_top_k_breaks_distance_ties_by_key_at_capacity() {
        let mut top_k = VectorTopK::new(1);
        top_k.push_distance(5_u64, 1.0);
        // Same distance, lower key: evicts key 5.
        top_k.push_distance(2, 1.0);
        // Same distance, higher key: rejected.
        top_k.push_distance(9, 1.0);

        assert_eq!(
            top_k.into_hits(),
            vec![VectorSearchHit {
                key: 2,
                distance: 1.0
            }]
        );
    }

    #[test]
    fn vector_top_k_zero_k_retains_nothing() {
        let mut top_k = VectorTopK::new(0);
        top_k.push_distance(1_u64, 0.0);

        assert!(top_k.is_empty());
        assert!(top_k.into_hits().is_empty());
    }

    #[test]
    fn exact_top_k_is_distance_then_key_ordered() {
        let query = vector(&[0.0]);
        let one = vector(&[1.0]);
        let two = vector(&[2.0]);
        let candidates = [(3_u64, &two), (2, &one), (1, &one)];

        let hits = exact_vector_top_k(VectorMetric::SquaredEuclidean, &query, candidates, 2)
            .expect("all dimensions match");

        assert_eq!(
            hits,
            vec![
                VectorSearchHit {
                    key: 1,
                    distance: 1.0
                },
                VectorSearchHit {
                    key: 2,
                    distance: 1.0
                }
            ]
        );
    }

    #[test]
    fn exact_top_k_ranks_larger_dot_products_first_for_mips() {
        let query = vector(&[1.0, 1.0]);
        let small = vector(&[1.0, 0.0]);
        let large = vector(&[2.0, 3.0]);
        let candidates = [(1_u64, &small), (2, &large)];

        let hits = exact_vector_top_k(VectorMetric::NegativeInnerProduct, &query, candidates, 5)
            .expect("all dimensions match");

        assert_eq!(
            hits,
            vec![
                VectorSearchHit {
                    key: 2,
                    distance: -5.0
                },
                VectorSearchHit {
                    key: 1,
                    distance: -1.0
                }
            ]
        );
    }

    #[test]
    fn exact_top_k_surfaces_candidate_metric_errors() {
        let query = vector(&[0.0]);
        let candidate = vector(&[1.0, 2.0]);
        let candidates = [(1_u64, &candidate)];

        let error =
            exact_vector_top_k(VectorMetric::SquaredEuclidean, &query, candidates, 10).unwrap_err();

        assert!(matches!(
            error,
            CoreError::VectorDimensionMismatch { lhs: 1, rhs: 2 }
        ));
    }
}
668}