1use std::fmt;
2
3use bytemuck::cast_slice;
4use half::f16;
5
6use crate::types::VectorType;
7
/// Distance function used to compare two vectors.
///
/// Each variant has a short lowercase name (see `DistanceMetric::name` and
/// `DistanceMetric::from_name`) and a usearch counterpart (see
/// `DistanceMetric::to_usearch`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DistanceMetric {
    /// Squared Euclidean distance, `sum((a_i - b_i)^2)`; no square root is taken.
    L2,
    /// Cosine distance, `1 - dot(a, b) / (|a| * |b|)`.
    Cosine,
    /// Negated inner product, `-dot(a, b)`: larger dot products give smaller distances.
    InnerProduct,
}
14
15impl DistanceMetric {
16 pub fn from_name(name: &str) -> Result<Self, DistanceError> {
17 match name {
18 "l2" => Ok(Self::L2),
19 "cosine" => Ok(Self::Cosine),
20 "ip" => Ok(Self::InnerProduct),
21 other => Err(DistanceError::UnknownMetric(other.to_string())),
22 }
23 }
24
25 pub fn name(&self) -> &'static str {
26 match self {
27 Self::L2 => "l2",
28 Self::Cosine => "cosine",
29 Self::InnerProduct => "ip",
30 }
31 }
32
33 pub fn to_usearch(&self) -> usearch::MetricKind {
35 match self {
36 Self::L2 => usearch::MetricKind::L2sq,
37 Self::Cosine => usearch::MetricKind::Cos,
38 Self::InnerProduct => usearch::MetricKind::IP,
39 }
40 }
41}
42
/// Errors produced while parsing metric names or computing distances.
#[derive(Debug)]
pub enum DistanceError {
    /// The metric name given to `DistanceMetric::from_name` is not recognised;
    /// carries the offending name.
    UnknownMetric(String),
    /// An input vector blob does not have the byte length expected for the
    /// requested vector type and dimension.
    DimensionMismatch,
    /// A usearch failure, carried as a message string.
    Usearch(String),
}
49
50impl fmt::Display for DistanceError {
51 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
52 match self {
53 Self::UnknownMetric(name) => write!(f, "unknown metric: {name}"),
54 Self::DimensionMismatch => write!(f, "vector dimensions do not match"),
55 Self::Usearch(e) => write!(f, "usearch error: {e}"),
56 }
57 }
58}
59
// Marker impl relying on the trait's default methods: no variant wraps another
// error value (they carry only `String`s), so there is no `source()` to expose.
impl std::error::Error for DistanceError {}
61
62pub fn compute_distance(
68 a: &[u8],
69 b: &[u8],
70 vtype: VectorType,
71 metric: DistanceMetric,
72 dim: usize,
73) -> Result<f64, DistanceError> {
74 let expected_size = vtype.blob_size(dim);
75 if a.len() != expected_size || b.len() != expected_size {
76 return Err(DistanceError::DimensionMismatch);
77 }
78
79 match vtype {
80 VectorType::Float4 => {
81 let va: &[f32] = cast_slice(a);
82 let vb: &[f32] = cast_slice(b);
83 Ok(scalar_distance(va, vb, metric))
84 }
85 VectorType::Float8 => {
86 let va: &[f64] = cast_slice(a);
87 let vb: &[f64] = cast_slice(b);
88 Ok(scalar_distance_f64(va, vb, metric))
89 }
90 VectorType::Float2 => {
91 let va: &[f16] = cast_slice(a);
92 let vb: &[f16] = cast_slice(b);
93 let fa: Vec<f32> = va.iter().map(|v| v.to_f32()).collect();
94 let fb: Vec<f32> = vb.iter().map(|v| v.to_f32()).collect();
95 Ok(scalar_distance(&fa, &fb, metric))
96 }
97 VectorType::Int1 => {
98 let va: &[i8] = cast_slice(a);
99 let vb: &[i8] = cast_slice(b);
100 let fa: Vec<f32> = va.iter().map(|v| *v as f32).collect();
101 let fb: Vec<f32> = vb.iter().map(|v| *v as f32).collect();
102 Ok(scalar_distance(&fa, &fb, metric))
103 }
104 VectorType::Int2 => {
105 let va: &[i16] = cast_slice(a);
106 let vb: &[i16] = cast_slice(b);
107 let fa: Vec<f32> = va.iter().map(|v| *v as f32).collect();
108 let fb: Vec<f32> = vb.iter().map(|v| *v as f32).collect();
109 Ok(scalar_distance(&fa, &fb, metric))
110 }
111 VectorType::Int4 => {
112 let va: &[i32] = cast_slice(a);
113 let vb: &[i32] = cast_slice(b);
114 let fa: Vec<f32> = va.iter().map(|v| *v as f32).collect();
115 let fb: Vec<f32> = vb.iter().map(|v| *v as f32).collect();
116 Ok(scalar_distance(&fa, &fb, metric))
117 }
118 }
119}
120
121fn scalar_distance(a: &[f32], b: &[f32], metric: DistanceMetric) -> f64 {
122 match metric {
123 DistanceMetric::L2 => a.iter().zip(b).map(|(x, y)| (x - y) * (x - y)).sum::<f32>() as f64,
124 DistanceMetric::Cosine => {
125 let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
126 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
127 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
128 let denom = norm_a * norm_b;
129 if denom == 0.0 {
130 1.0
131 } else {
132 1.0 - (dot / denom) as f64
133 }
134 }
135 DistanceMetric::InnerProduct => {
136 let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
137 -(dot as f64)
138 }
139 }
140}
141
142fn scalar_distance_f64(a: &[f64], b: &[f64], metric: DistanceMetric) -> f64 {
143 match metric {
144 DistanceMetric::L2 => a.iter().zip(b).map(|(x, y)| (x - y) * (x - y)).sum(),
145 DistanceMetric::Cosine => {
146 let dot: f64 = a.iter().zip(b).map(|(x, y)| x * y).sum();
147 let norm_a: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
148 let norm_b: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();
149 let denom = norm_a * norm_b;
150 if denom == 0.0 {
151 1.0
152 } else {
153 1.0 - (dot / denom)
154 }
155 }
156 DistanceMetric::InnerProduct => {
157 let dot: f64 = a.iter().zip(b).map(|(x, y)| x * y).sum();
158 -dot
159 }
160 }
161}
162
163pub fn vtype_to_scalar_kind(vtype: VectorType) -> usearch::ScalarKind {
165 match vtype {
166 VectorType::Float2 => usearch::ScalarKind::F16,
167 VectorType::Float4 => usearch::ScalarKind::F32,
168 VectorType::Float8 => usearch::ScalarKind::F64,
169 VectorType::Int1 => usearch::ScalarKind::I8,
170 VectorType::Int2 | VectorType::Int4 => usearch::ScalarKind::F32,
172 }
173}
174
#[cfg(test)]
mod tests {
    use super::*;
    use bytemuck::cast_slice;

    // ---- Blob builders: encode typed slices as native-endian byte blobs. ----

    fn f32_blob(values: &[f32]) -> Vec<u8> {
        cast_slice(values).to_vec()
    }

    fn f64_blob(values: &[f64]) -> Vec<u8> {
        cast_slice(values).to_vec()
    }

    fn i32_blob(values: &[i32]) -> Vec<u8> {
        cast_slice(values).to_vec()
    }

    fn i16_blob(values: &[i16]) -> Vec<u8> {
        cast_slice(values).to_vec()
    }

    fn i8_blob(values: &[i8]) -> Vec<u8> {
        cast_slice(values).to_vec()
    }

    fn f16_blob(values: &[half::f16]) -> Vec<u8> {
        cast_slice(values).to_vec()
    }

    /// Asserts that `actual` lies strictly within `eps` of `expected`.
    fn assert_approx(actual: f64, expected: f64, eps: f64) {
        assert!(
            (actual - expected).abs() < eps,
            "expected {expected} ± {eps}, got {actual}"
        );
    }

    // ---- DistanceMetric::from_name / name ----

    #[test]
    fn from_name_valid_l2() {
        assert_eq!(DistanceMetric::from_name("l2").unwrap(), DistanceMetric::L2);
    }

    #[test]
    fn from_name_valid_cosine() {
        assert_eq!(
            DistanceMetric::from_name("cosine").unwrap(),
            DistanceMetric::Cosine
        );
    }

    #[test]
    fn from_name_valid_ip() {
        assert_eq!(
            DistanceMetric::from_name("ip").unwrap(),
            DistanceMetric::InnerProduct
        );
    }

    #[test]
    fn from_name_unknown_returns_error() {
        let err = DistanceMetric::from_name("manhattan").unwrap_err();
        assert!(
            matches!(err, DistanceError::UnknownMetric(ref s) if s == "manhattan"),
            "unexpected error variant: {err}"
        );
    }

    #[test]
    fn from_name_empty_string_returns_error() {
        assert!(DistanceMetric::from_name("").is_err());
    }

    #[test]
    fn from_name_case_sensitive() {
        assert!(DistanceMetric::from_name("L2").is_err());
        assert!(DistanceMetric::from_name("Cosine").is_err());
        assert!(DistanceMetric::from_name("IP").is_err());
    }

    #[test]
    fn name_round_trips_with_from_name() {
        let variants = [
            DistanceMetric::L2,
            DistanceMetric::Cosine,
            DistanceMetric::InnerProduct,
        ];
        for metric in variants {
            assert_eq!(
                DistanceMetric::from_name(metric.name()).unwrap(),
                metric,
                "round-trip failed for {:?}",
                metric
            );
        }
    }

    // ---- DistanceMetric::to_usearch ----

    #[test]
    fn to_usearch_l2_maps_to_l2sq() {
        assert_eq!(DistanceMetric::L2.to_usearch(), usearch::MetricKind::L2sq);
    }

    #[test]
    fn to_usearch_cosine_maps_to_cos() {
        assert_eq!(
            DistanceMetric::Cosine.to_usearch(),
            usearch::MetricKind::Cos
        );
    }

    #[test]
    fn to_usearch_ip_maps_to_ip() {
        assert_eq!(
            DistanceMetric::InnerProduct.to_usearch(),
            usearch::MetricKind::IP
        );
    }

    // ---- vtype_to_scalar_kind ----

    #[test]
    fn vtype_to_scalar_kind_float2_is_f16() {
        assert_eq!(
            vtype_to_scalar_kind(VectorType::Float2),
            usearch::ScalarKind::F16
        );
    }

    #[test]
    fn vtype_to_scalar_kind_float4_is_f32() {
        assert_eq!(
            vtype_to_scalar_kind(VectorType::Float4),
            usearch::ScalarKind::F32
        );
    }

    #[test]
    fn vtype_to_scalar_kind_float8_is_f64() {
        assert_eq!(
            vtype_to_scalar_kind(VectorType::Float8),
            usearch::ScalarKind::F64
        );
    }

    #[test]
    fn vtype_to_scalar_kind_int1_is_i8() {
        assert_eq!(
            vtype_to_scalar_kind(VectorType::Int1),
            usearch::ScalarKind::I8
        );
    }

    #[test]
    fn vtype_to_scalar_kind_int2_quantizes_to_f32() {
        assert_eq!(
            vtype_to_scalar_kind(VectorType::Int2),
            usearch::ScalarKind::F32
        );
    }

    #[test]
    fn vtype_to_scalar_kind_int4_quantizes_to_f32() {
        assert_eq!(
            vtype_to_scalar_kind(VectorType::Int4),
            usearch::ScalarKind::F32
        );
    }

    // ---- compute_distance: blob length validation ----

    #[test]
    fn compute_distance_dimension_mismatch_returns_error() {
        let a = f32_blob(&[1.0, 0.0, 0.0]);
        let b = f32_blob(&[1.0, 0.0]);
        let err = compute_distance(&a, &b, VectorType::Float4, DistanceMetric::L2, 3).unwrap_err();
        assert!(
            matches!(err, DistanceError::DimensionMismatch),
            "expected DimensionMismatch, got {err}"
        );
    }

    #[test]
    fn compute_distance_second_blob_wrong_length_returns_error() {
        let a = f64_blob(&[1.0, 2.0]);
        let b = f64_blob(&[1.0]);
        let err = compute_distance(&a, &b, VectorType::Float8, DistanceMetric::L2, 2).unwrap_err();
        assert!(
            matches!(err, DistanceError::DimensionMismatch),
            "expected DimensionMismatch, got {err}"
        );
    }

    #[test]
    fn compute_distance_empty_blobs_with_nonzero_dim_returns_error() {
        let err = compute_distance(&[], &[], VectorType::Float4, DistanceMetric::Cosine, 3)
            .unwrap_err();
        assert!(
            matches!(err, DistanceError::DimensionMismatch),
            "expected DimensionMismatch, got {err}"
        );
    }

    // ---- Float4 ----

    #[test]
    fn float4_l2_identical_vectors_is_zero() {
        let v = f32_blob(&[1.0, 2.0, 3.0]);
        let d = compute_distance(&v, &v, VectorType::Float4, DistanceMetric::L2, 3).unwrap();
        assert_approx(d, 0.0, 1e-10);
    }

    #[test]
    fn float4_l2_orthogonal_unit_vectors_is_two() {
        let a = f32_blob(&[1.0, 0.0, 0.0]);
        let b = f32_blob(&[0.0, 1.0, 0.0]);
        let d = compute_distance(&a, &b, VectorType::Float4, DistanceMetric::L2, 3).unwrap();
        assert_approx(d, 2.0, 1e-6);
    }

    #[test]
    fn float4_l2_known_distance() {
        let a = f32_blob(&[3.0, 4.0]);
        let b = f32_blob(&[0.0, 0.0]);
        let d = compute_distance(&a, &b, VectorType::Float4, DistanceMetric::L2, 2).unwrap();
        assert_approx(d, 25.0, 1e-5);
    }

    #[test]
    fn float4_cosine_identical_vectors_is_zero() {
        let v = f32_blob(&[1.0, 2.0, 3.0]);
        let d = compute_distance(&v, &v, VectorType::Float4, DistanceMetric::Cosine, 3).unwrap();
        assert_approx(d, 0.0, 1e-6);
    }

    #[test]
    fn float4_cosine_orthogonal_vectors_is_one() {
        let a = f32_blob(&[1.0, 0.0]);
        let b = f32_blob(&[0.0, 1.0]);
        let d = compute_distance(&a, &b, VectorType::Float4, DistanceMetric::Cosine, 2).unwrap();
        assert_approx(d, 1.0, 1e-6);
    }

    #[test]
    fn float4_cosine_antiparallel_vectors_is_two() {
        let a = f32_blob(&[1.0, 0.0]);
        let b = f32_blob(&[-1.0, 0.0]);
        let d = compute_distance(&a, &b, VectorType::Float4, DistanceMetric::Cosine, 2).unwrap();
        assert_approx(d, 2.0, 1e-6);
    }

    #[test]
    fn float4_cosine_zero_vector_returns_one() {
        let a = f32_blob(&[0.0, 0.0, 0.0]);
        let b = f32_blob(&[0.0, 0.0, 0.0]);
        let d = compute_distance(&a, &b, VectorType::Float4, DistanceMetric::Cosine, 3).unwrap();
        assert_approx(d, 1.0, 1e-10);
    }

    #[test]
    fn float4_ip_unit_vectors_dot_product() {
        let a = f32_blob(&[1.0, 0.0, 0.0]);
        let b = f32_blob(&[0.0, 0.0, 1.0]);
        let d =
            compute_distance(&a, &b, VectorType::Float4, DistanceMetric::InnerProduct, 3).unwrap();
        assert_approx(d, 0.0, 1e-6);
    }

    #[test]
    fn float4_ip_parallel_unit_vectors() {
        let a = f32_blob(&[1.0, 0.0]);
        let b = f32_blob(&[1.0, 0.0]);
        let d =
            compute_distance(&a, &b, VectorType::Float4, DistanceMetric::InnerProduct, 2).unwrap();
        assert_approx(d, -1.0, 1e-6);
    }

    #[test]
    fn float4_ip_known_value() {
        let a = f32_blob(&[1.0, 2.0]);
        let b = f32_blob(&[3.0, 4.0]);
        let d =
            compute_distance(&a, &b, VectorType::Float4, DistanceMetric::InnerProduct, 2).unwrap();
        assert_approx(d, -11.0, 1e-5);
    }

    // ---- Float8 ----

    #[test]
    fn float8_l2_identical_vectors_is_zero() {
        let v = f64_blob(&[1.0, 2.0, 3.0]);
        let d = compute_distance(&v, &v, VectorType::Float8, DistanceMetric::L2, 3).unwrap();
        assert_approx(d, 0.0, 1e-15);
    }

    #[test]
    fn float8_l2_known_distance() {
        let a = f64_blob(&[1.0, 1.0]);
        let b = f64_blob(&[4.0, 5.0]);
        let d = compute_distance(&a, &b, VectorType::Float8, DistanceMetric::L2, 2).unwrap();
        assert_approx(d, 25.0, 1e-12);
    }

    #[test]
    fn float8_cosine_orthogonal_is_one() {
        let a = f64_blob(&[1.0, 0.0]);
        let b = f64_blob(&[0.0, 1.0]);
        let d = compute_distance(&a, &b, VectorType::Float8, DistanceMetric::Cosine, 2).unwrap();
        assert_approx(d, 1.0, 1e-14);
    }

    #[test]
    fn float8_cosine_antiparallel_is_two() {
        let a = f64_blob(&[1.0, 0.0]);
        let b = f64_blob(&[-1.0, 0.0]);
        let d = compute_distance(&a, &b, VectorType::Float8, DistanceMetric::Cosine, 2).unwrap();
        assert_approx(d, 2.0, 1e-14);
    }

    #[test]
    fn float8_cosine_zero_vector_returns_one() {
        let a = f64_blob(&[0.0, 0.0]);
        let b = f64_blob(&[0.0, 0.0]);
        let d = compute_distance(&a, &b, VectorType::Float8, DistanceMetric::Cosine, 2).unwrap();
        assert_approx(d, 1.0, 1e-15);
    }

    #[test]
    fn float8_ip_known_value() {
        let a = f64_blob(&[2.0, 3.0]);
        let b = f64_blob(&[4.0, 5.0]);
        let d =
            compute_distance(&a, &b, VectorType::Float8, DistanceMetric::InnerProduct, 2).unwrap();
        assert_approx(d, -23.0, 1e-12);
    }

    // ---- Int4 ----

    #[test]
    fn int4_l2_identical_vectors_is_zero() {
        let v = i32_blob(&[10, -5, 3]);
        let d = compute_distance(&v, &v, VectorType::Int4, DistanceMetric::L2, 3).unwrap();
        assert_approx(d, 0.0, 1e-10);
    }

    #[test]
    fn int4_l2_known_distance() {
        let a = i32_blob(&[0, 0]);
        let b = i32_blob(&[3, 4]);
        let d = compute_distance(&a, &b, VectorType::Int4, DistanceMetric::L2, 2).unwrap();
        assert_approx(d, 25.0, 1e-5);
    }

    #[test]
    fn int4_cosine_orthogonal_is_one() {
        let a = i32_blob(&[1, 0]);
        let b = i32_blob(&[0, 1]);
        let d = compute_distance(&a, &b, VectorType::Int4, DistanceMetric::Cosine, 2).unwrap();
        assert_approx(d, 1.0, 1e-6);
    }

    #[test]
    fn int4_ip_known_value() {
        let a = i32_blob(&[1, 2]);
        let b = i32_blob(&[3, 4]);
        let d =
            compute_distance(&a, &b, VectorType::Int4, DistanceMetric::InnerProduct, 2).unwrap();
        assert_approx(d, -11.0, 1e-5);
    }

    // ---- Int2 ----

    #[test]
    fn int2_l2_known_distance() {
        let a = i16_blob(&[3, 4]);
        let b = i16_blob(&[0, 0]);
        let d = compute_distance(&a, &b, VectorType::Int2, DistanceMetric::L2, 2).unwrap();
        assert_approx(d, 25.0, 1e-5);
    }

    // ---- Int1 ----

    #[test]
    fn int1_l2_known_distance() {
        let a = i8_blob(&[3, 4]);
        let b = i8_blob(&[0, 0]);
        let d = compute_distance(&a, &b, VectorType::Int1, DistanceMetric::L2, 2).unwrap();
        assert_approx(d, 25.0, 1e-5);
    }

    #[test]
    fn int1_ip_mixed_signs() {
        let a = i8_blob(&[-1, 2]);
        let b = i8_blob(&[3, -4]);
        let d =
            compute_distance(&a, &b, VectorType::Int1, DistanceMetric::InnerProduct, 2).unwrap();
        assert_approx(d, 11.0, 1e-5);
    }

    // ---- Float2 ----

    #[test]
    fn float2_cosine_orthogonal_is_one() {
        let a = f16_blob(&[half::f16::from_f32(1.0), half::f16::from_f32(0.0)]);
        let b = f16_blob(&[half::f16::from_f32(0.0), half::f16::from_f32(1.0)]);
        let d = compute_distance(&a, &b, VectorType::Float2, DistanceMetric::Cosine, 2).unwrap();
        assert_approx(d, 1.0, 1e-3);
    }

    #[test]
    fn float2_l2_identical_vectors_is_zero() {
        let v = f16_blob(&[
            half::f16::from_f32(1.0),
            half::f16::from_f32(-2.0),
            half::f16::from_f32(0.5),
        ]);
        let d = compute_distance(&v, &v, VectorType::Float2, DistanceMetric::L2, 3).unwrap();
        assert_approx(d, 0.0, 1e-6);
    }

    #[test]
    fn float2_ip_known_value() {
        let a = f16_blob(&[half::f16::from_f32(1.0), half::f16::from_f32(2.0)]);
        let b = f16_blob(&[half::f16::from_f32(3.0), half::f16::from_f32(4.0)]);
        let d =
            compute_distance(&a, &b, VectorType::Float2, DistanceMetric::InnerProduct, 2).unwrap();
        assert_approx(d, -11.0, 1e-3);
    }
}