Skip to main content

vector_index/
metric.rs

1//! The [`Metric`] trait and built-in implementations.
2//!
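//! HNSW does not require a true mathematical metric — only a consistent
//! ordering on pairwise distances. In practice we still call the trait
//! `Metric` because every interesting impl (L², cosine-as-distance, SW₁)
//! is a pseudometric or a monotone transform of one (the built-in squared
//! L² and `1 − cos` distances preserve the ordering of a pseudometric but
//! do not themselves satisfy the triangle inequality), and the name is
//! what users expect.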
7//!
8//! # Implementing your own metric
9//!
10//! ```
11//! use vector_index::Metric;
12//!
13//! struct Hamming;
14//!
15//! impl Metric for Hamming {
16//!     type Point = Vec<u8>;
17//!
18//!     fn distance(&self, a: &Self::Point, b: &Self::Point) -> f32 {
19//!         a.iter().zip(b).filter(|(x, y)| x != y).count() as f32
20//!     }
21//!
22//!     fn dim(&self, point: &Self::Point) -> usize {
23//!         point.len()
24//!     }
25//! }
26//! ```
27//!
28//! Implementations should be deterministic and total: `distance(a, b)` must
29//! return the same finite, non-negative `f32` for the same inputs. NaN or
30//! negative outputs will produce undefined ordering behavior in HNSW.
31
/// A distance function over points of an associated type.
///
/// Metrics are passed by reference into search and insert paths, so they
/// should be cheap to dereference. Stateful metrics (e.g. SW₁ with cached
/// random projections) typically hold their state behind an `Arc` or
/// inside the metric struct itself.
///
/// # Required properties
///
/// - **Determinism**: `distance(a, b)` returns the same value for the same
///   inputs across calls.
/// - **Non-negativity**: `distance(a, b) >= 0.0` for all valid inputs.
/// - **Finiteness**: the returned value is never NaN or infinite.
///
/// HNSW assumes but does not verify these properties. Violations cause
/// silent index corruption, not panics.
///
/// Symmetry (`d(a,b) == d(b,a)`) and the triangle inequality are *not*
/// required by HNSW, though most useful metrics satisfy them.
pub trait Metric: Send + Sync {
    /// The point type this metric measures distances between.
    type Point: Send + Sync;

    /// Compute the distance between the points `a` and `b`.
    ///
    /// The result must satisfy the properties listed on the trait:
    /// deterministic, non-negative, and finite.
    fn distance(&self, a: &Self::Point, b: &Self::Point) -> f32;

    /// Return the number of dimensions / components of `point`.
    ///
    /// Used by [`HnswIndex`](crate::HnswIndex) to enforce dimension
    /// consistency: the first inserted point sets the expected dimension,
    /// and subsequent inserts with a mismatched dimension are rejected.
    fn dim(&self, point: &Self::Point) -> usize;
}
65
66// -----------------------------------------------------------------------------
67// Built-in metrics
68// -----------------------------------------------------------------------------
69
70/// Squared Euclidean distance over `Vec<f32>` / `[f32]` points.
71///
72/// Returns Σᵢ (aᵢ − bᵢ)² — *not* the square root. For HNSW this is
73/// equivalent (monotonic in the true L² distance) and avoids a sqrt
74/// per comparison.
75///
76/// Panics in debug builds if input lengths differ; in release builds,
77/// trailing elements of the longer slice are ignored.
78#[derive(Debug, Clone, Copy, Default)]
79pub struct L2;
80
81impl Metric for L2 {
82    type Point = Vec<f32>;
83
84    fn distance(&self, a: &Self::Point, b: &Self::Point) -> f32 {
85        a.iter()
86            .zip(b.iter())
87            .map(|(x, y)| {
88                let d = x - y;
89                d * d
90            })
91            .sum()
92    }
93
94    fn dim(&self, point: &Self::Point) -> usize {
95        point.len()
96    }
97}
98
99/// Cosine distance, defined as `1 − cosine_similarity(a, b)`.
100///
101/// Range is `[0, 2]`. Returns `1.0` if either vector has zero norm
102/// (chosen over NaN for HNSW compatibility).
103#[derive(Debug, Clone, Copy, Default)]
104pub struct Cosine;
105
106impl Metric for Cosine {
107    type Point = Vec<f32>;
108
109    fn distance(&self, a: &Self::Point, b: &Self::Point) -> f32 {
110        let mut dot = 0.0_f32;
111        let mut na = 0.0_f32;
112        let mut nb = 0.0_f32;
113        for (&x, &y) in a.iter().zip(b.iter()) {
114            dot += x * y;
115            na += x * x;
116            nb += y * y;
117        }
118        if na == 0.0 || nb == 0.0 {
119            return 1.0;
120        }
121        1.0 - dot / (na.sqrt() * nb.sqrt())
122    }
123
124    fn dim(&self, point: &Self::Point) -> usize {
125        point.len()
126    }
127}
128
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// `L2` returns the *squared* Euclidean distance.
    #[test]
    fn l2_known_values() {
        let m = L2;
        assert_relative_eq!(m.distance(&vec![0.0; 4], &vec![0.0; 4]), 0.0);
        assert_relative_eq!(
            m.distance(&vec![1.0, 0.0, 0.0], &vec![0.0, 1.0, 0.0]),
            2.0 // squared distance: 1² + 1² + 0²
        );
    }

    /// `Cosine` spans 0 (same direction) through 1 (orthogonal) to 2 (opposite).
    #[test]
    fn cosine_known_values() {
        let m = Cosine;
        // identical direction → distance 0
        assert_relative_eq!(
            m.distance(&vec![1.0, 0.0], &vec![2.0, 0.0]),
            0.0,
            epsilon = 1e-6
        );
        // orthogonal → distance 1
        assert_relative_eq!(
            m.distance(&vec![1.0, 0.0], &vec![0.0, 1.0]),
            1.0,
            epsilon = 1e-6
        );
        // opposite direction → distance 2
        assert_relative_eq!(
            m.distance(&vec![1.0, 0.0], &vec![-1.0, 0.0]),
            2.0,
            epsilon = 1e-6
        );
    }

    /// A zero-norm input must yield the documented fallback of exactly 1.0,
    /// regardless of which argument is the zero vector.
    #[test]
    fn cosine_zero_vector_does_not_nan() {
        let m = Cosine;
        let zero = vec![0.0, 0.0];
        let ones = vec![1.0, 1.0];

        let d = m.distance(&zero, &ones);
        assert!(
            d.is_finite(),
            "cosine of zero-vector should not produce NaN"
        );
        assert_relative_eq!(d, 1.0);

        assert_relative_eq!(m.distance(&ones, &zero), 1.0);
        assert_relative_eq!(m.distance(&zero, &zero), 1.0);
    }
}