vector_index/metric.rs
1//! The [`Metric`] trait and built-in implementations.
2//!
3//! HNSW does not require a true mathematical metric — only a consistent
4//! ordering on pairwise distances. In practice we still call the trait
5//! `Metric` because every interesting impl (L², cosine-as-distance, SW₁)
6//! is at least a pseudometric, and the name is what users expect.
7//!
8//! # Implementing your own metric
9//!
10//! ```
11//! use vector_index::Metric;
12//!
13//! struct Hamming;
14//!
15//! impl Metric for Hamming {
16//! type Point = Vec<u8>;
17//!
18//! fn distance(&self, a: &Self::Point, b: &Self::Point) -> f32 {
19//! a.iter().zip(b).filter(|(x, y)| x != y).count() as f32
20//! }
21//!
22//! fn dim(&self, point: &Self::Point) -> usize {
23//! point.len()
24//! }
25//! }
26//! ```
27//!
//! Implementations must be deterministic and total: `distance(a, b)` must
//! return the same finite, non-negative `f32` for the same inputs. HNSW does
//! not check this; NaN or negative outputs break the distance ordering it
//! relies on and silently corrupt the index rather than panicking.
31
/// Distance function between two values of an associated point type.
///
/// Search and insert paths receive the metric by reference, so it should be
/// cheap to dereference. A metric that carries state (for example SW₁ with
/// cached random projections) usually keeps it inside the metric struct or
/// behind an `Arc`.
///
/// # Required properties
///
/// - **Determinism**: calling `distance(a, b)` repeatedly with the same
///   inputs yields the same value.
/// - **Non-negativity**: `distance(a, b) >= 0.0` for every valid input.
/// - **Finiteness**: the result is never NaN and never infinite.
///
/// HNSW relies on these properties without checking them. Breaking one
/// does not panic; it silently corrupts the index.
///
/// Neither symmetry (`d(a,b) == d(b,a)`) nor the triangle inequality is
/// required by HNSW, although most useful metrics provide both.
pub trait Metric: Send + Sync {
    /// Type of the points whose pairwise distances this metric measures.
    type Point: Send + Sync;

    /// Returns the distance from `a` to `b`.
    fn distance(&self, a: &Self::Point, b: &Self::Point) -> f32;

    /// Returns how many dimensions / components `point` has.
    ///
    /// [`HnswIndex`](crate::HnswIndex) uses this to keep dimensions
    /// consistent: the first inserted point fixes the expected dimension and
    /// later inserts whose dimension differs are rejected.
    fn dim(&self, point: &Self::Point) -> usize;
}
65
66// -----------------------------------------------------------------------------
67// Built-in metrics
68// -----------------------------------------------------------------------------
69
70/// Squared Euclidean distance over `Vec<f32>` / `[f32]` points.
71///
72/// Returns Σᵢ (aᵢ − bᵢ)² — *not* the square root. For HNSW this is
73/// equivalent (monotonic in the true L² distance) and avoids a sqrt
74/// per comparison.
75///
76/// Panics in debug builds if input lengths differ; in release builds,
77/// trailing elements of the longer slice are ignored.
78#[derive(Debug, Clone, Copy, Default)]
79pub struct L2;
80
81impl Metric for L2 {
82 type Point = Vec<f32>;
83
84 fn distance(&self, a: &Self::Point, b: &Self::Point) -> f32 {
85 a.iter()
86 .zip(b.iter())
87 .map(|(x, y)| {
88 let d = x - y;
89 d * d
90 })
91 .sum()
92 }
93
94 fn dim(&self, point: &Self::Point) -> usize {
95 point.len()
96 }
97}
98
99/// Cosine distance, defined as `1 − cosine_similarity(a, b)`.
100///
101/// Range is `[0, 2]`. Returns `1.0` if either vector has zero norm
102/// (chosen over NaN for HNSW compatibility).
103#[derive(Debug, Clone, Copy, Default)]
104pub struct Cosine;
105
106impl Metric for Cosine {
107 type Point = Vec<f32>;
108
109 fn distance(&self, a: &Self::Point, b: &Self::Point) -> f32 {
110 let mut dot = 0.0_f32;
111 let mut na = 0.0_f32;
112 let mut nb = 0.0_f32;
113 for (&x, &y) in a.iter().zip(b.iter()) {
114 dot += x * y;
115 na += x * x;
116 nb += y * y;
117 }
118 if na == 0.0 || nb == 0.0 {
119 return 1.0;
120 }
121 1.0 - dot / (na.sqrt() * nb.sqrt())
122 }
123
124 fn dim(&self, point: &Self::Point) -> usize {
125 point.len()
126 }
127}
128
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// `L2` returns the squared distance, not its square root.
    #[test]
    fn l2_known_values() {
        let m = L2;
        assert_relative_eq!(m.distance(&vec![0.0; 4], &vec![0.0; 4]), 0.0);
        assert_relative_eq!(
            m.distance(&vec![1.0, 0.0, 0.0], &vec![0.0, 1.0, 0.0]),
            2.0 // squared distance: 1² + 1² + 0²
        );
    }

    /// Same, orthogonal and opposite directions map to 0, 1 and 2.
    #[test]
    fn cosine_known_values() {
        let m = Cosine;
        // identical direction → distance 0
        assert_relative_eq!(
            m.distance(&vec![1.0, 0.0], &vec![2.0, 0.0]),
            0.0,
            epsilon = 1e-6
        );
        // orthogonal → distance 1
        assert_relative_eq!(
            m.distance(&vec![1.0, 0.0], &vec![0.0, 1.0]),
            1.0,
            epsilon = 1e-6
        );
        // opposite direction → distance 2
        assert_relative_eq!(
            m.distance(&vec![1.0, 0.0], &vec![-1.0, 0.0]),
            2.0,
            epsilon = 1e-6
        );
    }

    /// A zero-norm operand in either position yields the documented `1.0`,
    /// never NaN.
    #[test]
    fn cosine_zero_vector_does_not_nan() {
        let m = Cosine;
        let zero = vec![0.0, 0.0];
        let ones = vec![1.0, 1.0];

        let d = m.distance(&zero, &ones);
        assert!(
            d.is_finite(),
            "cosine of zero-vector should not produce NaN"
        );
        assert_relative_eq!(d, 1.0);

        // The zero vector as second argument, and as both arguments.
        assert_relative_eq!(m.distance(&ones, &zero), 1.0);
        assert_relative_eq!(m.distance(&zero, &zero), 1.0);
    }

    /// `dim` reports the vector length for both built-in metrics.
    #[test]
    fn dim_reports_vector_length() {
        assert_eq!(L2.dim(&vec![0.0; 4]), 4);
        assert_eq!(L2.dim(&Vec::new()), 0);
        assert_eq!(Cosine.dim(&vec![0.0; 7]), 7);
        assert_eq!(Cosine.dim(&Vec::new()), 0);
    }
}