Skip to main content

sqlite_vector_rs/
distance.rs

1use std::fmt;
2
3use bytemuck::cast_slice;
4use half::f16;
5
6use crate::types::VectorType;
7
/// Distance function used to compare two vectors.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DistanceMetric {
    /// Squared Euclidean distance (the square root is never taken).
    L2,
    /// Cosine distance: `1 - cos(a, b)`.
    Cosine,
    /// Negated dot product, so a larger similarity yields a smaller distance.
    InnerProduct,
}
14
15impl DistanceMetric {
16    pub fn from_name(name: &str) -> Result<Self, DistanceError> {
17        match name {
18            "l2" => Ok(Self::L2),
19            "cosine" => Ok(Self::Cosine),
20            "ip" => Ok(Self::InnerProduct),
21            other => Err(DistanceError::UnknownMetric(other.to_string())),
22        }
23    }
24
25    pub fn name(&self) -> &'static str {
26        match self {
27            Self::L2 => "l2",
28            Self::Cosine => "cosine",
29            Self::InnerProduct => "ip",
30        }
31    }
32
33    /// Convert to usearch MetricKind.
34    pub fn to_usearch(&self) -> usearch::MetricKind {
35        match self {
36            Self::L2 => usearch::MetricKind::L2sq,
37            Self::Cosine => usearch::MetricKind::Cos,
38            Self::InnerProduct => usearch::MetricKind::IP,
39        }
40    }
41}
42
/// Errors produced while parsing metric names or computing distances.
#[derive(Debug)]
pub enum DistanceError {
    /// The metric name was not recognised; carries the rejected name.
    UnknownMetric(String),
    /// A vector blob's byte length did not match the size expected for its
    /// type and dimension.
    DimensionMismatch,
    /// An error originating from usearch; the string is presumably its
    /// message (not constructed in this file — confirm at the call site).
    Usearch(String),
}
49
50impl fmt::Display for DistanceError {
51    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
52        match self {
53            Self::UnknownMetric(name) => write!(f, "unknown metric: {name}"),
54            Self::DimensionMismatch => write!(f, "vector dimensions do not match"),
55            Self::Usearch(e) => write!(f, "usearch error: {e}"),
56        }
57    }
58}
59
60impl std::error::Error for DistanceError {}
61
62/// Compute distance between two vector blobs.
63///
64/// Both blobs must be the same type and dimension. For int2/int4 types,
65/// values are cast to f32 before computation since usearch only supports
66/// f32, f64, f16, and i8 natively.
67pub fn compute_distance(
68    a: &[u8],
69    b: &[u8],
70    vtype: VectorType,
71    metric: DistanceMetric,
72    dim: usize,
73) -> Result<f64, DistanceError> {
74    let expected_size = vtype.blob_size(dim);
75    if a.len() != expected_size || b.len() != expected_size {
76        return Err(DistanceError::DimensionMismatch);
77    }
78
79    match vtype {
80        VectorType::Float4 => {
81            let va: &[f32] = cast_slice(a);
82            let vb: &[f32] = cast_slice(b);
83            Ok(scalar_distance(va, vb, metric))
84        }
85        VectorType::Float8 => {
86            let va: &[f64] = cast_slice(a);
87            let vb: &[f64] = cast_slice(b);
88            Ok(scalar_distance_f64(va, vb, metric))
89        }
90        VectorType::Float2 => {
91            let va: &[f16] = cast_slice(a);
92            let vb: &[f16] = cast_slice(b);
93            let fa: Vec<f32> = va.iter().map(|v| v.to_f32()).collect();
94            let fb: Vec<f32> = vb.iter().map(|v| v.to_f32()).collect();
95            Ok(scalar_distance(&fa, &fb, metric))
96        }
97        VectorType::Int1 => {
98            let va: &[i8] = cast_slice(a);
99            let vb: &[i8] = cast_slice(b);
100            let fa: Vec<f32> = va.iter().map(|v| *v as f32).collect();
101            let fb: Vec<f32> = vb.iter().map(|v| *v as f32).collect();
102            Ok(scalar_distance(&fa, &fb, metric))
103        }
104        VectorType::Int2 => {
105            let va: &[i16] = cast_slice(a);
106            let vb: &[i16] = cast_slice(b);
107            let fa: Vec<f32> = va.iter().map(|v| *v as f32).collect();
108            let fb: Vec<f32> = vb.iter().map(|v| *v as f32).collect();
109            Ok(scalar_distance(&fa, &fb, metric))
110        }
111        VectorType::Int4 => {
112            let va: &[i32] = cast_slice(a);
113            let vb: &[i32] = cast_slice(b);
114            let fa: Vec<f32> = va.iter().map(|v| *v as f32).collect();
115            let fb: Vec<f32> = vb.iter().map(|v| *v as f32).collect();
116            Ok(scalar_distance(&fa, &fb, metric))
117        }
118    }
119}
120
121fn scalar_distance(a: &[f32], b: &[f32], metric: DistanceMetric) -> f64 {
122    match metric {
123        DistanceMetric::L2 => a.iter().zip(b).map(|(x, y)| (x - y) * (x - y)).sum::<f32>() as f64,
124        DistanceMetric::Cosine => {
125            let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
126            let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
127            let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
128            let denom = norm_a * norm_b;
129            if denom == 0.0 {
130                1.0
131            } else {
132                1.0 - (dot / denom) as f64
133            }
134        }
135        DistanceMetric::InnerProduct => {
136            let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
137            -(dot as f64)
138        }
139    }
140}
141
142fn scalar_distance_f64(a: &[f64], b: &[f64], metric: DistanceMetric) -> f64 {
143    match metric {
144        DistanceMetric::L2 => a.iter().zip(b).map(|(x, y)| (x - y) * (x - y)).sum(),
145        DistanceMetric::Cosine => {
146            let dot: f64 = a.iter().zip(b).map(|(x, y)| x * y).sum();
147            let norm_a: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
148            let norm_b: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();
149            let denom = norm_a * norm_b;
150            if denom == 0.0 {
151                1.0
152            } else {
153                1.0 - (dot / denom)
154            }
155        }
156        DistanceMetric::InnerProduct => {
157            let dot: f64 = a.iter().zip(b).map(|(x, y)| x * y).sum();
158            -dot
159        }
160    }
161}
162
163/// Map VectorType to usearch ScalarKind for index creation.
164pub fn vtype_to_scalar_kind(vtype: VectorType) -> usearch::ScalarKind {
165    match vtype {
166        VectorType::Float2 => usearch::ScalarKind::F16,
167        VectorType::Float4 => usearch::ScalarKind::F32,
168        VectorType::Float8 => usearch::ScalarKind::F64,
169        VectorType::Int1 => usearch::ScalarKind::I8,
170        // i16/i32 not natively supported by usearch, quantize to f32
171        VectorType::Int2 | VectorType::Int4 => usearch::ScalarKind::F32,
172    }
173}
174
#[cfg(test)]
mod tests {
    use super::*;
    use bytemuck::cast_slice;

    // ----------------------------------------------------------------
    // Helpers
    // ----------------------------------------------------------------
    // Each helper encodes values as a native-endian blob, the layout
    // `compute_distance` reads.

    fn f32_blob(values: &[f32]) -> Vec<u8> {
        cast_slice(values).to_vec()
    }

    fn f64_blob(values: &[f64]) -> Vec<u8> {
        cast_slice(values).to_vec()
    }

    fn i32_blob(values: &[i32]) -> Vec<u8> {
        cast_slice(values).to_vec()
    }

    fn i16_blob(values: &[i16]) -> Vec<u8> {
        cast_slice(values).to_vec()
    }

    fn i8_blob(values: &[i8]) -> Vec<u8> {
        cast_slice(values).to_vec()
    }

    fn f16_blob(values: &[half::f16]) -> Vec<u8> {
        cast_slice(values).to_vec()
    }

    /// Assert two f64 values are within `eps` of each other.
    fn assert_approx(actual: f64, expected: f64, eps: f64) {
        assert!(
            (actual - expected).abs() < eps,
            "expected {expected} ± {eps}, got {actual}"
        );
    }

    // ----------------------------------------------------------------
    // DistanceMetric::from_name
    // ----------------------------------------------------------------

    #[test]
    fn from_name_valid_l2() {
        assert_eq!(DistanceMetric::from_name("l2").unwrap(), DistanceMetric::L2);
    }

    #[test]
    fn from_name_valid_cosine() {
        assert_eq!(
            DistanceMetric::from_name("cosine").unwrap(),
            DistanceMetric::Cosine
        );
    }

    #[test]
    fn from_name_valid_ip() {
        assert_eq!(
            DistanceMetric::from_name("ip").unwrap(),
            DistanceMetric::InnerProduct
        );
    }

    #[test]
    fn from_name_unknown_returns_error() {
        let err = DistanceMetric::from_name("manhattan").unwrap_err();
        assert!(
            matches!(err, DistanceError::UnknownMetric(ref s) if s == "manhattan"),
            "unexpected error variant: {err}"
        );
    }

    #[test]
    fn from_name_empty_string_returns_error() {
        assert!(DistanceMetric::from_name("").is_err());
    }

    #[test]
    fn from_name_case_sensitive() {
        // The matcher is lowercase-only; uppercase variants must be rejected.
        assert!(DistanceMetric::from_name("L2").is_err());
        assert!(DistanceMetric::from_name("Cosine").is_err());
        assert!(DistanceMetric::from_name("IP").is_err());
    }

    // ----------------------------------------------------------------
    // DistanceMetric::name — round-trip with from_name
    // ----------------------------------------------------------------

    #[test]
    fn name_round_trips_with_from_name() {
        let variants = [
            DistanceMetric::L2,
            DistanceMetric::Cosine,
            DistanceMetric::InnerProduct,
        ];
        for metric in variants {
            assert_eq!(
                DistanceMetric::from_name(metric.name()).unwrap(),
                metric,
                "round-trip failed for {:?}",
                metric
            );
        }
    }

    // ----------------------------------------------------------------
    // DistanceMetric::to_usearch — MetricKind mapping
    // ----------------------------------------------------------------

    #[test]
    fn to_usearch_l2_maps_to_l2sq() {
        assert_eq!(DistanceMetric::L2.to_usearch(), usearch::MetricKind::L2sq);
    }

    #[test]
    fn to_usearch_cosine_maps_to_cos() {
        assert_eq!(
            DistanceMetric::Cosine.to_usearch(),
            usearch::MetricKind::Cos
        );
    }

    #[test]
    fn to_usearch_ip_maps_to_ip() {
        assert_eq!(
            DistanceMetric::InnerProduct.to_usearch(),
            usearch::MetricKind::IP
        );
    }

    // ----------------------------------------------------------------
    // vtype_to_scalar_kind — all six types
    // ----------------------------------------------------------------

    #[test]
    fn vtype_to_scalar_kind_float2_is_f16() {
        assert_eq!(
            vtype_to_scalar_kind(VectorType::Float2),
            usearch::ScalarKind::F16
        );
    }

    #[test]
    fn vtype_to_scalar_kind_float4_is_f32() {
        assert_eq!(
            vtype_to_scalar_kind(VectorType::Float4),
            usearch::ScalarKind::F32
        );
    }

    #[test]
    fn vtype_to_scalar_kind_float8_is_f64() {
        assert_eq!(
            vtype_to_scalar_kind(VectorType::Float8),
            usearch::ScalarKind::F64
        );
    }

    #[test]
    fn vtype_to_scalar_kind_int1_is_i8() {
        assert_eq!(
            vtype_to_scalar_kind(VectorType::Int1),
            usearch::ScalarKind::I8
        );
    }

    #[test]
    fn vtype_to_scalar_kind_int2_quantizes_to_f32() {
        // i16 is not natively supported by usearch; it is quantized to f32.
        assert_eq!(
            vtype_to_scalar_kind(VectorType::Int2),
            usearch::ScalarKind::F32
        );
    }

    #[test]
    fn vtype_to_scalar_kind_int4_quantizes_to_f32() {
        // i32 is not natively supported by usearch; it is quantized to f32.
        assert_eq!(
            vtype_to_scalar_kind(VectorType::Int4),
            usearch::ScalarKind::F32
        );
    }

    // ----------------------------------------------------------------
    // compute_distance — dimension mismatch guard
    // ----------------------------------------------------------------

    #[test]
    fn compute_distance_dimension_mismatch_returns_error() {
        let a = f32_blob(&[1.0, 0.0, 0.0]);
        let b = f32_blob(&[1.0, 0.0]); // wrong length
        let err = compute_distance(&a, &b, VectorType::Float4, DistanceMetric::L2, 3).unwrap_err();
        assert!(
            matches!(err, DistanceError::DimensionMismatch),
            "expected DimensionMismatch, got {err}"
        );
    }

    #[test]
    fn compute_distance_first_blob_wrong_size_returns_error() {
        // The size guard must reject a bad `a` even when `b` is well-formed.
        let a = f32_blob(&[1.0, 0.0]); // wrong length
        let b = f32_blob(&[1.0, 0.0, 0.0]);
        let err = compute_distance(&a, &b, VectorType::Float4, DistanceMetric::L2, 3).unwrap_err();
        assert!(
            matches!(err, DistanceError::DimensionMismatch),
            "expected DimensionMismatch, got {err}"
        );
    }

    // ----------------------------------------------------------------
    // compute_distance — Float4 / L2
    // ----------------------------------------------------------------

    #[test]
    fn float4_l2_identical_vectors_is_zero() {
        let v = f32_blob(&[1.0, 2.0, 3.0]);
        let d = compute_distance(&v, &v, VectorType::Float4, DistanceMetric::L2, 3).unwrap();
        assert_approx(d, 0.0, 1e-10);
    }

    #[test]
    fn float4_l2_orthogonal_unit_vectors_is_two() {
        // [1,0,0] vs [0,1,0]: squared L2 = (1-0)² + (0-1)² + (0-0)² = 2.0
        let a = f32_blob(&[1.0, 0.0, 0.0]);
        let b = f32_blob(&[0.0, 1.0, 0.0]);
        let d = compute_distance(&a, &b, VectorType::Float4, DistanceMetric::L2, 3).unwrap();
        assert_approx(d, 2.0, 1e-6);
    }

    #[test]
    fn float4_l2_known_distance() {
        // [3, 4] vs [0, 0]: squared L2 = 9 + 16 = 25
        let a = f32_blob(&[3.0, 4.0]);
        let b = f32_blob(&[0.0, 0.0]);
        let d = compute_distance(&a, &b, VectorType::Float4, DistanceMetric::L2, 2).unwrap();
        assert_approx(d, 25.0, 1e-5);
    }

    // ----------------------------------------------------------------
    // compute_distance — Float4 / Cosine
    // ----------------------------------------------------------------

    #[test]
    fn float4_cosine_identical_vectors_is_zero() {
        let v = f32_blob(&[1.0, 2.0, 3.0]);
        let d = compute_distance(&v, &v, VectorType::Float4, DistanceMetric::Cosine, 3).unwrap();
        assert_approx(d, 0.0, 1e-6);
    }

    #[test]
    fn float4_cosine_orthogonal_vectors_is_one() {
        // [1,0] and [0,1] have zero dot product → cosine distance = 1.
        let a = f32_blob(&[1.0, 0.0]);
        let b = f32_blob(&[0.0, 1.0]);
        let d = compute_distance(&a, &b, VectorType::Float4, DistanceMetric::Cosine, 2).unwrap();
        assert_approx(d, 1.0, 1e-6);
    }

    #[test]
    fn float4_cosine_antiparallel_vectors_is_two() {
        // [1,0] and [-1,0] are antiparallel → cosine distance = 2.
        let a = f32_blob(&[1.0, 0.0]);
        let b = f32_blob(&[-1.0, 0.0]);
        let d = compute_distance(&a, &b, VectorType::Float4, DistanceMetric::Cosine, 2).unwrap();
        assert_approx(d, 2.0, 1e-6);
    }

    #[test]
    fn float4_cosine_zero_vector_returns_one() {
        // When the denominator is zero the implementation returns 1.0.
        let a = f32_blob(&[0.0, 0.0, 0.0]);
        let b = f32_blob(&[0.0, 0.0, 0.0]);
        let d = compute_distance(&a, &b, VectorType::Float4, DistanceMetric::Cosine, 3).unwrap();
        assert_approx(d, 1.0, 1e-10);
    }

    // ----------------------------------------------------------------
    // compute_distance — Float4 / InnerProduct
    // ----------------------------------------------------------------

    #[test]
    fn float4_ip_unit_vectors_dot_product() {
        // [1,0,0]·[0,0,1] = 0  →  -(0) = 0.0
        let a = f32_blob(&[1.0, 0.0, 0.0]);
        let b = f32_blob(&[0.0, 0.0, 1.0]);
        let d =
            compute_distance(&a, &b, VectorType::Float4, DistanceMetric::InnerProduct, 3).unwrap();
        assert_approx(d, 0.0, 1e-6);
    }

    #[test]
    fn float4_ip_parallel_unit_vectors() {
        // [1,0]·[1,0] = 1  →  -(1) = -1.0  (higher similarity = more negative)
        let a = f32_blob(&[1.0, 0.0]);
        let b = f32_blob(&[1.0, 0.0]);
        let d =
            compute_distance(&a, &b, VectorType::Float4, DistanceMetric::InnerProduct, 2).unwrap();
        assert_approx(d, -1.0, 1e-6);
    }

    #[test]
    fn float4_ip_known_value() {
        // [1,2]·[3,4] = 11  →  -(11) = -11.0
        let a = f32_blob(&[1.0, 2.0]);
        let b = f32_blob(&[3.0, 4.0]);
        let d =
            compute_distance(&a, &b, VectorType::Float4, DistanceMetric::InnerProduct, 2).unwrap();
        assert_approx(d, -11.0, 1e-5);
    }

    // ----------------------------------------------------------------
    // compute_distance — Float8
    // ----------------------------------------------------------------

    #[test]
    fn float8_l2_identical_vectors_is_zero() {
        let v = f64_blob(&[1.0, 2.0, 3.0]);
        let d = compute_distance(&v, &v, VectorType::Float8, DistanceMetric::L2, 3).unwrap();
        assert_approx(d, 0.0, 1e-15);
    }

    #[test]
    fn float8_l2_known_distance() {
        // [1, 1] vs [4, 5]: (1-4)² + (1-5)² = 9 + 16 = 25
        let a = f64_blob(&[1.0, 1.0]);
        let b = f64_blob(&[4.0, 5.0]);
        let d = compute_distance(&a, &b, VectorType::Float8, DistanceMetric::L2, 2).unwrap();
        assert_approx(d, 25.0, 1e-12);
    }

    #[test]
    fn float8_cosine_orthogonal_is_one() {
        let a = f64_blob(&[1.0, 0.0]);
        let b = f64_blob(&[0.0, 1.0]);
        let d = compute_distance(&a, &b, VectorType::Float8, DistanceMetric::Cosine, 2).unwrap();
        assert_approx(d, 1.0, 1e-14);
    }

    #[test]
    fn float8_cosine_zero_vector_returns_one() {
        let a = f64_blob(&[0.0, 0.0]);
        let b = f64_blob(&[0.0, 0.0]);
        let d = compute_distance(&a, &b, VectorType::Float8, DistanceMetric::Cosine, 2).unwrap();
        assert_approx(d, 1.0, 1e-15);
    }

    #[test]
    fn float8_ip_known_value() {
        // [2,3]·[4,5] = 23  →  -(23) = -23.0
        let a = f64_blob(&[2.0, 3.0]);
        let b = f64_blob(&[4.0, 5.0]);
        let d =
            compute_distance(&a, &b, VectorType::Float8, DistanceMetric::InnerProduct, 2).unwrap();
        assert_approx(d, -23.0, 1e-12);
    }

    // ----------------------------------------------------------------
    // compute_distance — Int4
    // ----------------------------------------------------------------

    #[test]
    fn int4_l2_identical_vectors_is_zero() {
        let v = i32_blob(&[10, -5, 3]);
        let d = compute_distance(&v, &v, VectorType::Int4, DistanceMetric::L2, 3).unwrap();
        assert_approx(d, 0.0, 1e-10);
    }

    #[test]
    fn int4_l2_known_distance() {
        // [0, 0] vs [3, 4]: (3-0)² + (4-0)² = 9 + 16 = 25, cast to f32 first
        let a = i32_blob(&[0, 0]);
        let b = i32_blob(&[3, 4]);
        let d = compute_distance(&a, &b, VectorType::Int4, DistanceMetric::L2, 2).unwrap();
        assert_approx(d, 25.0, 1e-5);
    }

    #[test]
    fn int4_cosine_orthogonal_is_one() {
        let a = i32_blob(&[1, 0]);
        let b = i32_blob(&[0, 1]);
        let d = compute_distance(&a, &b, VectorType::Int4, DistanceMetric::Cosine, 2).unwrap();
        assert_approx(d, 1.0, 1e-6);
    }

    #[test]
    fn int4_ip_known_value() {
        // [1, 2]·[3, 4] = 11  →  -(11) = -11.0  (int4 is cast to f32)
        let a = i32_blob(&[1, 2]);
        let b = i32_blob(&[3, 4]);
        let d =
            compute_distance(&a, &b, VectorType::Int4, DistanceMetric::InnerProduct, 2).unwrap();
        assert_approx(d, -11.0, 1e-5);
    }

    // ----------------------------------------------------------------
    // compute_distance — Int2 (i16)
    // ----------------------------------------------------------------

    #[test]
    fn int2_l2_known_distance() {
        // [-3, 0] vs [0, 4] cast to f32: 9 + 16 = 25
        let a = i16_blob(&[-3, 0]);
        let b = i16_blob(&[0, 4]);
        let d = compute_distance(&a, &b, VectorType::Int2, DistanceMetric::L2, 2).unwrap();
        assert_approx(d, 25.0, 1e-5);
    }

    #[test]
    fn int2_cosine_identical_vectors_is_zero() {
        // [3, -4]·[3, -4] = 25, |v|² = 25 → cosine similarity 1 → distance 0.
        let v = i16_blob(&[3, -4]);
        let d = compute_distance(&v, &v, VectorType::Int2, DistanceMetric::Cosine, 2).unwrap();
        assert_approx(d, 0.0, 1e-6);
    }

    #[test]
    fn int2_ip_known_value() {
        // [-1, 2]·[3, 4] = -3 + 8 = 5  →  -(5) = -5.0
        let a = i16_blob(&[-1, 2]);
        let b = i16_blob(&[3, 4]);
        let d =
            compute_distance(&a, &b, VectorType::Int2, DistanceMetric::InnerProduct, 2).unwrap();
        assert_approx(d, -5.0, 1e-5);
    }

    // ----------------------------------------------------------------
    // compute_distance — Int1 (i8)
    // ----------------------------------------------------------------

    #[test]
    fn int1_l2_known_distance() {
        // [3, 4] vs [0, 0] cast to f32: 9 + 16 = 25
        let a = i8_blob(&[3, 4]);
        let b = i8_blob(&[0, 0]);
        let d = compute_distance(&a, &b, VectorType::Int1, DistanceMetric::L2, 2).unwrap();
        assert_approx(d, 25.0, 1e-5);
    }

    #[test]
    fn int1_ip_negative_components() {
        // [-2, 3]·[4, -1] = -8 - 3 = -11  →  -(-11) = 11.0
        let a = i8_blob(&[-2, 3]);
        let b = i8_blob(&[4, -1]);
        let d =
            compute_distance(&a, &b, VectorType::Int1, DistanceMetric::InnerProduct, 2).unwrap();
        assert_approx(d, 11.0, 1e-5);
    }

    // ----------------------------------------------------------------
    // compute_distance — Float2 (f16)
    // ----------------------------------------------------------------

    #[test]
    fn float2_cosine_orthogonal_is_one() {
        let a = f16_blob(&[half::f16::from_f32(1.0), half::f16::from_f32(0.0)]);
        let b = f16_blob(&[half::f16::from_f32(0.0), half::f16::from_f32(1.0)]);
        let d = compute_distance(&a, &b, VectorType::Float2, DistanceMetric::Cosine, 2).unwrap();
        // f16 has limited precision; use a looser tolerance.
        assert_approx(d, 1.0, 1e-3);
    }

    #[test]
    fn float2_l2_identical_vectors_is_zero() {
        let v = f16_blob(&[
            half::f16::from_f32(1.0),
            half::f16::from_f32(-2.0),
            half::f16::from_f32(0.5),
        ]);
        let d = compute_distance(&v, &v, VectorType::Float2, DistanceMetric::L2, 3).unwrap();
        assert_approx(d, 0.0, 1e-6);
    }

    #[test]
    fn float2_ip_known_value() {
        // All values are exact in f16: [0.5, 2]·[4, 1.5] = 2 + 3 = 5  →  -5.0
        let a = f16_blob(&[half::f16::from_f32(0.5), half::f16::from_f32(2.0)]);
        let b = f16_blob(&[half::f16::from_f32(4.0), half::f16::from_f32(1.5)]);
        let d =
            compute_distance(&a, &b, VectorType::Float2, DistanceMetric::InnerProduct, 2).unwrap();
        assert_approx(d, -5.0, 1e-3);
    }
}