1use std::sync::OnceLock;
15
/// Signature shared by the floating-point kernels in this module
/// (dot product, Euclidean distance, cosine similarity).
type DistanceFn = fn(&[f32], &[f32]) -> f32;

/// Signature of the Hamming kernels, which return an integer count.
type BinaryDistanceFn = fn(&[f32], &[f32]) -> u32;

/// Cached dot-product kernel; populated with the result of
/// `select_dot_product` on the first call to `dot_product_dispatched`.
static DOT_PRODUCT_FN: OnceLock<DistanceFn> = OnceLock::new();

/// Cached Euclidean kernel; populated with the result of
/// `select_euclidean` on the first call to `euclidean_dispatched`.
static EUCLIDEAN_FN: OnceLock<DistanceFn> = OnceLock::new();

/// Cached cosine kernel; populated with the result of `select_cosine`
/// on the first call to `cosine_dispatched`.
static COSINE_FN: OnceLock<DistanceFn> = OnceLock::new();

/// Cached kernel for the "normalized" cosine variant; populated with the
/// result of `select_cosine_normalized` on the first call to
/// `cosine_normalized_dispatched`.
static COSINE_NORMALIZED_FN: OnceLock<DistanceFn> = OnceLock::new();

/// Cached Hamming kernel; populated with the result of `select_hamming`
/// on the first call to `hamming_dispatched`.
static HAMMING_FN: OnceLock<BinaryDistanceFn> = OnceLock::new();
40
41fn select_dot_product() -> DistanceFn {
47 #[cfg(target_arch = "x86_64")]
48 {
49 if is_x86_feature_detected!("avx512f") {
50 return dot_product_avx512;
51 }
52 if is_x86_feature_detected!("avx2") {
53 return dot_product_avx2;
54 }
55 }
56 dot_product_scalar
57}
58
59fn select_euclidean() -> DistanceFn {
61 #[cfg(target_arch = "x86_64")]
62 {
63 if is_x86_feature_detected!("avx512f") {
64 return euclidean_avx512;
65 }
66 if is_x86_feature_detected!("avx2") {
67 return euclidean_avx2;
68 }
69 }
70 euclidean_scalar
71}
72
73fn select_cosine() -> DistanceFn {
75 #[cfg(target_arch = "x86_64")]
76 {
77 if is_x86_feature_detected!("avx512f") {
78 return cosine_avx512;
79 }
80 if is_x86_feature_detected!("avx2") {
81 return cosine_avx2;
82 }
83 }
84 cosine_scalar
85}
86
87fn select_cosine_normalized() -> DistanceFn {
89 #[cfg(target_arch = "x86_64")]
90 {
91 if is_x86_feature_detected!("avx512f") {
92 return cosine_normalized_avx512;
93 }
94 if is_x86_feature_detected!("avx2") {
95 return cosine_normalized_avx2;
96 }
97 }
98 cosine_normalized_scalar
99}
100
101fn select_hamming() -> BinaryDistanceFn {
103 #[cfg(target_arch = "x86_64")]
104 {
105 if is_x86_feature_detected!("avx512vpopcntdq") {
106 return hamming_avx512_popcnt;
107 }
108 if is_x86_feature_detected!("popcnt") {
109 return hamming_popcnt;
110 }
111 }
112 hamming_scalar
113}
114
115#[inline]
127#[must_use]
128pub fn dot_product_dispatched(a: &[f32], b: &[f32]) -> f32 {
129 let f = DOT_PRODUCT_FN.get_or_init(select_dot_product);
130 f(a, b)
131}
132
133#[inline]
135#[must_use]
136pub fn euclidean_dispatched(a: &[f32], b: &[f32]) -> f32 {
137 let f = EUCLIDEAN_FN.get_or_init(select_euclidean);
138 f(a, b)
139}
140
141#[inline]
143#[must_use]
144pub fn cosine_dispatched(a: &[f32], b: &[f32]) -> f32 {
145 let f = COSINE_FN.get_or_init(select_cosine);
146 f(a, b)
147}
148
149#[inline]
151#[must_use]
152pub fn cosine_normalized_dispatched(a: &[f32], b: &[f32]) -> f32 {
153 let f = COSINE_NORMALIZED_FN.get_or_init(select_cosine_normalized);
154 f(a, b)
155}
156
157#[inline]
159#[must_use]
160pub fn hamming_dispatched(a: &[f32], b: &[f32]) -> u32 {
161 let f = HAMMING_FN.get_or_init(select_hamming);
162 f(a, b)
163}
164
/// Returns the SIMD capabilities of the current CPU.
///
/// Convenience wrapper around [`SimdFeatures::detect`]; detection is re-run
/// on every call rather than cached here.
#[must_use]
pub fn simd_features_info() -> SimdFeatures {
    SimdFeatures::detect()
}
170
/// Snapshot of the CPU features this module's kernel selection looks at.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[allow(clippy::struct_excessive_bools)]
pub struct SimdFeatures {
    /// `true` when the `avx512f` feature was detected.
    pub avx512f: bool,
    /// `true` when the `avx512vpopcntdq` feature was detected.
    pub avx512_popcnt: bool,
    /// `true` when the `avx2` feature was detected.
    pub avx2: bool,
    /// `true` when the `popcnt` feature was detected.
    pub popcnt: bool,
}

impl SimdFeatures {
    /// Probes the running CPU.
    ///
    /// On `x86_64` each flag comes from `is_x86_feature_detected!`; on every
    /// other architecture all flags are `false`.
    #[must_use]
    pub fn detect() -> Self {
        #[cfg(target_arch = "x86_64")]
        let (avx512f, avx512_popcnt, avx2, popcnt) = (
            is_x86_feature_detected!("avx512f"),
            is_x86_feature_detected!("avx512vpopcntdq"),
            is_x86_feature_detected!("avx2"),
            is_x86_feature_detected!("popcnt"),
        );

        #[cfg(not(target_arch = "x86_64"))]
        let (avx512f, avx512_popcnt, avx2, popcnt) = (false, false, false, false);

        Self {
            avx512f,
            avx512_popcnt,
            avx2,
            popcnt,
        }
    }

    /// Names the widest float instruction set available: `"AVX-512"` if
    /// `avx512f` is set, otherwise `"AVX2"` if `avx2` is set, otherwise
    /// `"Scalar"`. The popcount flags do not influence the answer.
    #[must_use]
    pub const fn best_instruction_set(&self) -> &'static str {
        match (self.avx512f, self.avx2) {
            (true, _) => "AVX-512",
            (false, true) => "AVX2",
            (false, false) => "Scalar",
        }
    }
}
222
/// Portable dot product: the sum of element-wise products of `a` and `b`.
///
/// # Panics
///
/// Panics with "Vector length mismatch" when the slices differ in length.
fn dot_product_scalar(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Vector length mismatch");
    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum()
}
233
/// Dot-product kernel returned by `select_dot_product` when `avx2` is
/// detected; forwards to `crate::simd_explicit::dot_product_simd`.
// NOTE(review): the callee is defined outside this file — confirm it is
// safe to call on any CPU that reports AVX2.
#[cfg(target_arch = "x86_64")]
fn dot_product_avx2(a: &[f32], b: &[f32]) -> f32 {
    crate::simd_explicit::dot_product_simd(a, b)
}
238
/// Dot-product kernel returned by `select_dot_product` when `avx512f` is
/// detected; forwards to `crate::simd_avx512::dot_product_auto`.
// NOTE(review): the `_auto` suffix suggests the callee does its own feature
// detection — it is defined outside this file, so confirm.
#[cfg(target_arch = "x86_64")]
fn dot_product_avx512(a: &[f32], b: &[f32]) -> f32 {
    crate::simd_avx512::dot_product_auto(a, b)
}
243
/// Portable Euclidean (L2) distance: the square root of the summed squared
/// element-wise differences between `a` and `b`.
///
/// # Panics
///
/// Panics with "Vector length mismatch" when the slices differ in length.
fn euclidean_scalar(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Vector length mismatch");
    let squared_distance: f32 = a
        .iter()
        .zip(b)
        .map(|(&p, &q)| (p - q) * (p - q))
        .sum();
    squared_distance.sqrt()
}
257
/// Euclidean kernel returned by `select_euclidean` when `avx2` is detected;
/// forwards to `crate::simd_explicit::euclidean_distance_simd`.
// NOTE(review): the callee is defined outside this file — confirm it is
// safe to call on any CPU that reports AVX2.
#[cfg(target_arch = "x86_64")]
fn euclidean_avx2(a: &[f32], b: &[f32]) -> f32 {
    crate::simd_explicit::euclidean_distance_simd(a, b)
}
262
/// Euclidean kernel returned by `select_euclidean` when `avx512f` is
/// detected; forwards to `crate::simd_avx512::euclidean_auto`.
// NOTE(review): the `_auto` suffix suggests the callee does its own feature
// detection — it is defined outside this file, so confirm.
#[cfg(target_arch = "x86_64")]
fn euclidean_avx512(a: &[f32], b: &[f32]) -> f32 {
    crate::simd_avx512::euclidean_auto(a, b)
}
267
/// Portable cosine similarity: `dot(a, b) / (‖a‖ · ‖b‖)`.
///
/// Returns `0.0` when the denominator is not strictly positive, i.e. when
/// either vector has zero norm (including two empty slices) or the
/// denominator is NaN.
///
/// # Panics
///
/// Panics with "Vector length mismatch" when the slices differ in length.
fn cosine_scalar(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Vector length mismatch");

    let mut dot = 0.0f32;
    let mut norm_a_sq = 0.0f32;
    let mut norm_b_sq = 0.0f32;

    for (&x, &y) in a.iter().zip(b) {
        dot += x * y;
        norm_a_sq += x * x;
        norm_b_sq += y * y;
    }

    // Take the square roots *before* multiplying. The product of the two
    // squared norms can overflow to infinity (or underflow to zero) in f32
    // even when each norm is representable, which would wrongly report a
    // similarity of 0.0 for perfectly aligned vectors.
    let denom = norm_a_sq.sqrt() * norm_b_sq.sqrt();
    if denom > 0.0 {
        dot / denom
    } else {
        0.0
    }
}
289
/// Cosine kernel returned by `select_cosine` when `avx2` is detected;
/// forwards to `crate::simd_explicit::cosine_similarity_simd`.
// NOTE(review): the callee is defined outside this file — confirm its
// zero-norm handling matches `cosine_scalar` (which returns 0.0).
#[cfg(target_arch = "x86_64")]
fn cosine_avx2(a: &[f32], b: &[f32]) -> f32 {
    crate::simd_explicit::cosine_similarity_simd(a, b)
}
294
/// Cosine kernel returned by `select_cosine` when `avx512f` is detected;
/// forwards to `crate::simd_avx512::cosine_similarity_auto`.
// NOTE(review): the callee is defined outside this file — confirm its
// zero-norm handling matches `cosine_scalar` (which returns 0.0).
#[cfg(target_arch = "x86_64")]
fn cosine_avx512(a: &[f32], b: &[f32]) -> f32 {
    crate::simd_avx512::cosine_similarity_auto(a, b)
}
299
/// Scalar kernel for the "normalized" cosine variant: simply the dot
/// product of `a` and `b`.
///
/// This equals cosine similarity only when both inputs already have unit
/// length; the function does not verify or enforce that.
fn cosine_normalized_scalar(a: &[f32], b: &[f32]) -> f32 {
    dot_product_scalar(a, b)
}
306
/// "Normalized" cosine kernel returned by `select_cosine_normalized` when
/// `avx2` is detected; forwards to the same
/// `crate::simd_explicit::dot_product_simd` used by `dot_product_avx2`.
///
/// No normalization is performed or checked here.
#[cfg(target_arch = "x86_64")]
fn cosine_normalized_avx2(a: &[f32], b: &[f32]) -> f32 {
    crate::simd_explicit::dot_product_simd(a, b)
}
311
/// "Normalized" cosine kernel returned by `select_cosine_normalized` when
/// `avx512f` is detected; forwards to the same
/// `crate::simd_avx512::dot_product_auto` used by `dot_product_avx512`.
///
/// No normalization is performed or checked here.
#[cfg(target_arch = "x86_64")]
fn cosine_normalized_avx512(a: &[f32], b: &[f32]) -> f32 {
    crate::simd_avx512::dot_product_auto(a, b)
}
316
/// Portable Hamming distance over float-encoded bits.
///
/// An element counts as a set bit when it is strictly greater than `0.5`;
/// the result is the number of positions where `a` and `b` disagree.
///
/// # Panics
///
/// Panics with "Vector length mismatch" when the slices differ in length.
#[allow(clippy::cast_possible_truncation)]
fn hamming_scalar(a: &[f32], b: &[f32]) -> u32 {
    assert_eq!(a.len(), b.len(), "Vector length mismatch");
    let differing = a
        .iter()
        .zip(b)
        .filter(|&(&lhs, &rhs)| (lhs > 0.5) ^ (rhs > 0.5))
        .count();
    differing as u32
}
329
/// Hamming kernel returned by `select_hamming` when `popcnt` is detected;
/// forwards to `crate::simd_explicit::hamming_distance_simd` and converts
/// the result to `u32` with an `as` cast.
// NOTE(review): the callee's return type is not visible here. The lint
// allowances below suggest a float or signed value, so the cast may
// truncate or drop a sign — confirm against `simd_explicit`.
#[cfg(target_arch = "x86_64")]
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
fn hamming_popcnt(a: &[f32], b: &[f32]) -> u32 {
    crate::simd_explicit::hamming_distance_simd(a, b) as u32
}
336
/// Hamming kernel returned by `select_hamming` when `avx512vpopcntdq` is
/// detected.
///
/// It currently just forwards to `hamming_popcnt`, so it produces the same
/// results as the `popcnt` path.
#[cfg(target_arch = "x86_64")]
fn hamming_avx512_popcnt(a: &[f32], b: &[f32]) -> u32 {
    hamming_popcnt(a, b)
}
343
/// Cache-line size, in bytes, assumed by the prefetch-distance helpers.
pub const CACHE_LINE_SIZE: usize = 64;

/// Number of whole cache lines spanned by a 768-dimensional `f32` vector.
pub const PREFETCH_DISTANCE_768D: usize = prefetch_distance(768);

/// Number of whole cache lines spanned by a 384-dimensional `f32` vector.
pub const PREFETCH_DISTANCE_384D: usize = prefetch_distance(384);

/// Number of whole cache lines spanned by a 1536-dimensional `f32` vector.
pub const PREFETCH_DISTANCE_1536D: usize = prefetch_distance(1536);

/// Returns how many whole cache lines a vector of `dimension` `f32`
/// elements spans, rounding down (`dimension * 4 / CACHE_LINE_SIZE`).
///
/// The quotient is computed as `dimension / floats_per_cache_line`, which
/// gives the same value without the intermediate `dimension * 4` product,
/// so it cannot overflow for any `usize` input.
#[inline]
#[must_use]
pub const fn prefetch_distance(dimension: usize) -> usize {
    const FLOATS_PER_CACHE_LINE: usize = CACHE_LINE_SIZE / std::mem::size_of::<f32>();
    // The rewrite above is exact only while a cache line holds a whole
    // number of floats; fail the build if that ever stops being true.
    const _: () = assert!(CACHE_LINE_SIZE % std::mem::size_of::<f32>() == 0);

    dimension / FLOATS_PER_CACHE_LINE
}
367
/// Tests for the dispatched entry points, feature detection, and the
/// prefetch-distance helpers.
///
/// The dispatched tests run against whichever kernel the host CPU selects,
/// so they only assert properties that hold for every kernel.
#[cfg(test)]
#[allow(
    clippy::cast_precision_loss,
    clippy::uninlined_format_args,
    clippy::float_cmp
)]
mod tests {
    use super::*;

    // --- Small-vector correctness -------------------------------------

    #[test]
    fn test_dot_product_dispatched_correctness() {
        let a = vec![1.0f32, 2.0, 3.0, 4.0];
        let b = vec![5.0f32, 6.0, 7.0, 8.0];

        let result = dot_product_dispatched(&a, &b);

        // 1*5 + 2*6 + 3*7 + 4*8 = 70
        assert!((result - 70.0).abs() < 1e-5);
    }

    #[test]
    fn test_euclidean_dispatched_correctness() {
        let a = vec![0.0f32, 0.0, 0.0];
        let b = vec![3.0f32, 4.0, 0.0];

        let result = euclidean_dispatched(&a, &b);

        // 3-4-5 right triangle.
        assert!((result - 5.0).abs() < 1e-5);
    }

    #[test]
    fn test_cosine_dispatched_correctness() {
        let a = vec![1.0f32, 2.0, 3.0];
        let b = vec![1.0f32, 2.0, 3.0];

        let result = cosine_dispatched(&a, &b);

        // Identical vectors have similarity 1.
        assert!((result - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_cosine_dispatched_orthogonal() {
        let a = vec![1.0f32, 0.0, 0.0];
        let b = vec![0.0f32, 1.0, 0.0];

        let result = cosine_dispatched(&a, &b);

        // Orthogonal vectors have similarity 0.
        assert!(result.abs() < 1e-5);
    }

    #[test]
    fn test_cosine_normalized_dispatched() {
        let a = vec![1.0f32, 0.0];
        let b = vec![0.707f32, 0.707];

        let result = cosine_normalized_dispatched(&a, &b);

        assert!((result - 0.707).abs() < 0.01);
    }

    #[test]
    fn test_hamming_dispatched_correctness() {
        let a = vec![1.0f32, 0.0, 1.0, 0.0];
        let b = vec![1.0f32, 1.0, 0.0, 0.0];

        let result = hamming_dispatched(&a, &b);

        // Positions 1 and 2 differ.
        assert_eq!(result, 2);
    }

    // --- 768-dimensional vectors --------------------------------------

    #[test]
    fn test_dot_product_dispatched_768d() {
        let a: Vec<f32> = (0..768).map(|i| (i as f32) * 0.001).collect();
        let b: Vec<f32> = (0..768).map(|i| ((768 - i) as f32) * 0.001).collect();

        let result = dot_product_dispatched(&a, &b);

        assert!(result.is_finite());
        assert!(result > 0.0);
    }

    #[test]
    fn test_euclidean_dispatched_768d() {
        let a: Vec<f32> = vec![0.0; 768];
        let b: Vec<f32> = vec![1.0; 768];

        let result = euclidean_dispatched(&a, &b);

        assert!((result - 768.0_f32.sqrt()).abs() < 0.01);
    }

    #[test]
    fn test_cosine_dispatched_768d() {
        let a: Vec<f32> = (0..768).map(|i| (i as f32).sin()).collect();
        let b = a.clone();

        let result = cosine_dispatched(&a, &b);

        assert!((result - 1.0).abs() < 1e-4);
    }

    // --- Feature detection --------------------------------------------

    #[test]
    fn test_simd_features_detect() {
        let features = SimdFeatures::detect();

        let _name = features.best_instruction_set();
        println!("SIMD features: {:?}", features);
        println!("Best instruction set: {}", features.best_instruction_set());
    }

    #[test]
    fn test_simd_features_info() {
        let features = simd_features_info();

        assert!(!features.best_instruction_set().is_empty());
    }

    // --- Prefetch distances -------------------------------------------

    #[test]
    fn test_prefetch_distance_768d() {
        assert_eq!(PREFETCH_DISTANCE_768D, 48);
    }

    #[test]
    fn test_prefetch_distance_384d() {
        assert_eq!(PREFETCH_DISTANCE_384D, 24);
    }

    #[test]
    fn test_prefetch_distance_1536d() {
        assert_eq!(PREFETCH_DISTANCE_1536D, 96);
    }

    #[test]
    fn test_prefetch_distance_function() {
        assert_eq!(prefetch_distance(768), 48);
        assert_eq!(prefetch_distance(384), 24);
        assert_eq!(prefetch_distance(128), 8);
    }

    // --- Dispatch behavior --------------------------------------------

    #[test]
    fn test_dispatch_initialized_once() {
        let a = vec![1.0f32; 100];
        let b = vec![2.0f32; 100];

        let r1 = dot_product_dispatched(&a, &b);

        let r2 = dot_product_dispatched(&a, &b);

        // Repeated calls go through the same cached kernel.
        assert_eq!(r1, r2);
    }

    #[test]
    fn test_dispatch_thread_safe() {
        use std::sync::Arc;
        use std::thread;

        let a = Arc::new(vec![1.0f32; 768]);
        let b = Arc::new(vec![2.0f32; 768]);

        let handles: Vec<_> = (0..4)
            .map(|_| {
                let a = Arc::clone(&a);
                let b = Arc::clone(&b);
                thread::spawn(move || {
                    for _ in 0..100 {
                        let _ = dot_product_dispatched(&a, &b);
                        let _ = cosine_dispatched(&a, &b);
                        let _ = euclidean_dispatched(&a, &b);
                    }
                })
            })
            .collect();

        for h in handles {
            h.join().expect("Thread should not panic");
        }
    }

    // --- Edge cases ---------------------------------------------------

    // No `expected = "..."` substring here: the panic message depends on
    // which kernel the host selects. The scalar fallback panics with
    // "Vector length mismatch", so pinning a different message made this
    // test fail on any target that dispatches to the scalar kernel.
    #[test]
    #[should_panic]
    fn test_dot_product_dispatched_length_mismatch() {
        let a = vec![1.0f32, 2.0];
        let b = vec![1.0f32, 2.0, 3.0];
        let _ = dot_product_dispatched(&a, &b);
    }

    #[test]
    fn test_empty_vectors() {
        let a: Vec<f32> = vec![];
        let b: Vec<f32> = vec![];

        assert_eq!(dot_product_dispatched(&a, &b), 0.0);
    }

    #[test]
    fn test_single_element() {
        let a = vec![3.0f32];
        let b = vec![4.0f32];

        assert_eq!(dot_product_dispatched(&a, &b), 12.0);
    }
}