/// Snapshot of the SIMD instruction-set extensions relevant to the kernels in
/// this file.
///
/// Each flag is `true` when the corresponding extension was reported as
/// available. `Default` yields a value with every flag cleared, and the type
/// is comparable so feature sets can be checked in tests and caches.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct CpuFeatures {
    /// x86 AVX-512 Foundation.
    pub avx512f: bool,
    /// x86 AVX-512 Byte and Word instructions.
    pub avx512bw: bool,
    /// x86 AVX-512 Vector Length extensions.
    pub avx512vl: bool,
    /// x86 AVX-512 Vector Byte Manipulation Instructions.
    pub avx512vbmi: bool,
    /// x86 AVX2.
    pub avx2: bool,
    /// x86 SSE4.1.
    pub sse41: bool,
    /// AArch64 NEON (Advanced SIMD).
    pub neon: bool,
    /// AArch64 Scalable Vector Extension.
    pub sve: bool,
}
59
60impl CpuFeatures {
61 pub fn detect() -> Self {
63 #[cfg(target_arch = "x86_64")]
64 {
65 Self {
66 avx512f: is_x86_feature_detected!("avx512f"),
67 avx512bw: is_x86_feature_detected!("avx512bw"),
68 avx512vl: is_x86_feature_detected!("avx512vl"),
69 avx512vbmi: is_x86_feature_detected!("avx512vbmi"),
70 avx2: is_x86_feature_detected!("avx2"),
71 sse41: is_x86_feature_detected!("sse4.1"),
72 neon: false,
73 sve: false,
74 }
75 }
76 #[cfg(target_arch = "aarch64")]
77 {
78 Self {
79 avx512f: false,
80 avx512bw: false,
81 avx512vl: false,
82 avx512vbmi: false,
83 avx2: false,
84 sse41: false,
85 neon: true, sve: false, }
88 }
89 #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
90 {
91 Self {
92 avx512f: false,
93 avx512bw: false,
94 avx512vl: false,
95 avx512vbmi: false,
96 avx2: false,
97 sse41: false,
98 neon: false,
99 sve: false,
100 }
101 }
102 }
103
104 pub fn best_simd_level(&self) -> SimdLevel {
106 if self.avx512f && self.avx512bw {
107 SimdLevel::Avx512
108 } else if self.avx2 {
109 SimdLevel::Avx2
110 } else if self.sse41 {
111 SimdLevel::Sse41
112 } else if self.neon {
113 SimdLevel::Neon
114 } else {
115 SimdLevel::Scalar
116 }
117 }
118}
119
/// Instruction-set tier a distance kernel can be built for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SimdLevel {
    /// 512-bit AVX-512 registers.
    Avx512,
    /// 256-bit AVX2 registers.
    Avx2,
    /// 128-bit SSE4.1 registers.
    Sse41,
    /// 128-bit NEON registers.
    Neon,
    /// No vector unit: one element at a time.
    Scalar,
}

impl SimdLevel {
    /// Register width in bytes for this tier.
    ///
    /// `Scalar` reports `1`, i.e. a single byte-sized element per step.
    pub fn width_bytes(&self) -> usize {
        match self {
            SimdLevel::Avx512 => 64,
            SimdLevel::Avx2 => 32,
            SimdLevel::Sse41 | SimdLevel::Neon => 16,
            SimdLevel::Scalar => 1,
        }
    }

    /// Number of `f32` values processed per step at this tier.
    ///
    /// The result is never zero: the scalar tier still handles one `f32` at a
    /// time even though its nominal width is smaller than an `f32`, so callers
    /// can safely use the value as a loop stride or divisor.
    pub fn f32_elements(&self) -> usize {
        (self.width_bytes() / std::mem::size_of::<f32>()).max(1)
    }

    /// Number of `i8` values processed per step at this tier.
    pub fn i8_elements(&self) -> usize {
        self.width_bytes()
    }
}
152
/// Vector distance primitives that can be backed by different instruction
/// sets.
///
/// Implementors must be `Send + Sync` so a kernel can be shared across
/// threads. The pairwise methods expect `a` and `b` to have the same length;
/// the implementations in this file check that with `debug_assert_eq!`.
pub trait DistanceKernel: Send + Sync {
    /// Returns the squared Euclidean distance between `a` and `b`, i.e. the
    /// sum of the squared element-wise differences (no square root applied).
    fn l2_squared_f32(&self, a: &[f32], b: &[f32]) -> f32;

    /// Returns the dot product of `a` and `b`.
    fn dot_f32(&self, a: &[f32], b: &[f32]) -> f32;

    /// Returns the dot product of two `i8` vectors, accumulated as `i32`.
    fn dot_i8(&self, a: &[i8], b: &[i8]) -> i32;

    /// Computes the squared Euclidean distance from `query` to each vector in
    /// `vectors`.
    ///
    /// `vectors` is a flat buffer holding `vectors.len() / dim` consecutive
    /// vectors of `dim` elements each; the result for the `i`-th vector is
    /// written to `out[i]`.
    fn l2_squared_batch_f32(&self, query: &[f32], vectors: &[f32], dim: usize, out: &mut [f32]);

    /// Computes the dot product of `query` with each vector in `vectors`.
    ///
    /// `vectors` is a flat buffer holding `vectors.len() / dim` consecutive
    /// vectors of `dim` elements each; the result for the `i`-th vector is
    /// written to `out[i]`.
    fn dot_batch_f32(&self, query: &[f32], vectors: &[f32], dim: usize, out: &mut [f32]);

    /// Reports which [`SimdLevel`] this kernel implements.
    fn simd_level(&self) -> SimdLevel;
}
177
178pub struct ScalarKernel;
184
185impl DistanceKernel for ScalarKernel {
186 fn l2_squared_f32(&self, a: &[f32], b: &[f32]) -> f32 {
187 debug_assert_eq!(a.len(), b.len());
188 a.iter()
189 .zip(b.iter())
190 .map(|(x, y)| {
191 let diff = x - y;
192 diff * diff
193 })
194 .sum()
195 }
196
197 fn dot_f32(&self, a: &[f32], b: &[f32]) -> f32 {
198 debug_assert_eq!(a.len(), b.len());
199 a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
200 }
201
202 fn dot_i8(&self, a: &[i8], b: &[i8]) -> i32 {
203 debug_assert_eq!(a.len(), b.len());
204 a.iter()
205 .zip(b.iter())
206 .map(|(&x, &y)| x as i32 * y as i32)
207 .sum()
208 }
209
210 fn l2_squared_batch_f32(&self, query: &[f32], vectors: &[f32], dim: usize, out: &mut [f32]) {
211 let n = vectors.len() / dim;
212 debug_assert!(out.len() >= n);
213
214 for i in 0..n {
215 let vec = &vectors[i * dim..(i + 1) * dim];
216 out[i] = self.l2_squared_f32(query, vec);
217 }
218 }
219
220 fn dot_batch_f32(&self, query: &[f32], vectors: &[f32], dim: usize, out: &mut [f32]) {
221 let n = vectors.len() / dim;
222 debug_assert!(out.len() >= n);
223
224 for i in 0..n {
225 let vec = &vectors[i * dim..(i + 1) * dim];
226 out[i] = self.dot_f32(query, vec);
227 }
228 }
229
230 fn simd_level(&self) -> SimdLevel {
231 SimdLevel::Scalar
232 }
233}
234
235#[cfg(target_arch = "x86_64")]
240pub struct Avx2Kernel;
241
242#[cfg(target_arch = "x86_64")]
243impl DistanceKernel for Avx2Kernel {
244 fn l2_squared_f32(&self, a: &[f32], b: &[f32]) -> f32 {
245 debug_assert_eq!(a.len(), b.len());
246
247 #[target_feature(enable = "avx2")]
248 unsafe fn inner(a: &[f32], b: &[f32]) -> f32 {
249 use std::arch::x86_64::*;
250 unsafe {
251 let n = a.len();
252 let chunks = n / 8;
253 let mut sum = _mm256_setzero_ps();
254
255 for i in 0..chunks {
256 let va = _mm256_loadu_ps(a.as_ptr().add(i * 8));
257 let vb = _mm256_loadu_ps(b.as_ptr().add(i * 8));
258 let diff = _mm256_sub_ps(va, vb);
259 sum = _mm256_fmadd_ps(diff, diff, sum);
260 }
261
262 let sum128 =
264 _mm_add_ps(_mm256_extractf128_ps(sum, 0), _mm256_extractf128_ps(sum, 1));
265 let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
266 let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
267 let mut result = _mm_cvtss_f32(sum32);
268
269 for i in (chunks * 8)..n {
271 let diff = a[i] - b[i];
272 result += diff * diff;
273 }
274
275 result
276 }
277 }
278
279 if is_x86_feature_detected!("avx2") {
280 unsafe { inner(a, b) }
281 } else {
282 ScalarKernel.l2_squared_f32(a, b)
283 }
284 }
285
286 fn dot_f32(&self, a: &[f32], b: &[f32]) -> f32 {
287 debug_assert_eq!(a.len(), b.len());
288
289 #[target_feature(enable = "avx2")]
290 unsafe fn inner(a: &[f32], b: &[f32]) -> f32 {
291 use std::arch::x86_64::*;
292 unsafe {
293 let n = a.len();
294 let chunks = n / 8;
295 let mut sum = _mm256_setzero_ps();
296
297 for i in 0..chunks {
298 let va = _mm256_loadu_ps(a.as_ptr().add(i * 8));
299 let vb = _mm256_loadu_ps(b.as_ptr().add(i * 8));
300 sum = _mm256_fmadd_ps(va, vb, sum);
301 }
302
303 let sum128 =
305 _mm_add_ps(_mm256_extractf128_ps(sum, 0), _mm256_extractf128_ps(sum, 1));
306 let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
307 let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
308 let mut result = _mm_cvtss_f32(sum32);
309
310 for i in (chunks * 8)..n {
312 result += a[i] * b[i];
313 }
314
315 result
316 }
317 }
318
319 if is_x86_feature_detected!("avx2") {
320 unsafe { inner(a, b) }
321 } else {
322 ScalarKernel.dot_f32(a, b)
323 }
324 }
325
326 fn dot_i8(&self, a: &[i8], b: &[i8]) -> i32 {
327 debug_assert_eq!(a.len(), b.len());
328
329 #[target_feature(enable = "avx2")]
330 unsafe fn inner(a: &[i8], b: &[i8]) -> i32 {
331 use std::arch::x86_64::*;
332 unsafe {
333 let n = a.len();
334 let chunks = n / 32;
335 let mut sum = _mm256_setzero_si256();
336
337 for i in 0..chunks {
338 let va = _mm256_loadu_si256(a.as_ptr().add(i * 32) as *const __m256i);
339 let vb = _mm256_loadu_si256(b.as_ptr().add(i * 32) as *const __m256i);
340
341 let a_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(va, 0));
343 let b_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vb, 0));
344 let a_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(va, 1));
345 let b_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vb, 1));
346
347 let prod_lo = _mm256_madd_epi16(a_lo, b_lo);
348 let prod_hi = _mm256_madd_epi16(a_hi, b_hi);
349
350 sum = _mm256_add_epi32(sum, prod_lo);
351 sum = _mm256_add_epi32(sum, prod_hi);
352 }
353
354 let sum128 = _mm_add_epi32(
356 _mm256_extracti128_si256(sum, 0),
357 _mm256_extracti128_si256(sum, 1),
358 );
359 let sum64 = _mm_add_epi32(sum128, _mm_srli_si128(sum128, 8));
360 let sum32 = _mm_add_epi32(sum64, _mm_srli_si128(sum64, 4));
361 let mut result = _mm_cvtsi128_si32(sum32);
362
363 for i in (chunks * 32)..n {
365 result += a[i] as i32 * b[i] as i32;
366 }
367
368 result
369 }
370 }
371
372 if is_x86_feature_detected!("avx2") {
373 unsafe { inner(a, b) }
374 } else {
375 ScalarKernel.dot_i8(a, b)
376 }
377 }
378
379 fn l2_squared_batch_f32(&self, query: &[f32], vectors: &[f32], dim: usize, out: &mut [f32]) {
380 let n = vectors.len() / dim;
381 for i in 0..n {
382 let vec = &vectors[i * dim..(i + 1) * dim];
383 out[i] = self.l2_squared_f32(query, vec);
384 }
385 }
386
387 fn dot_batch_f32(&self, query: &[f32], vectors: &[f32], dim: usize, out: &mut [f32]) {
388 let n = vectors.len() / dim;
389 for i in 0..n {
390 let vec = &vectors[i * dim..(i + 1) * dim];
391 out[i] = self.dot_f32(query, vec);
392 }
393 }
394
395 fn simd_level(&self) -> SimdLevel {
396 SimdLevel::Avx2
397 }
398}
399
/// Distance kernel built on 128-bit NEON instructions (aarch64 only).
///
/// No run-time feature check is performed: the code assumes NEON is available
/// on every aarch64 target it is compiled for.
#[cfg(target_arch = "aarch64")]
pub struct NeonKernel;

#[cfg(target_arch = "aarch64")]
impl DistanceKernel for NeonKernel {
    /// Squared Euclidean distance, four `f32` lanes per step.
    ///
    /// If the slices differ in length only their common prefix is used, which
    /// matches the scalar kernel and keeps every vector load in bounds.
    fn l2_squared_f32(&self, a: &[f32], b: &[f32]) -> f32 {
        debug_assert_eq!(a.len(), b.len());

        /// # Safety
        /// The CPU must support NEON.
        unsafe fn inner(a: &[f32], b: &[f32]) -> f32 {
            use std::arch::aarch64::*;

            // Clamp to the shared prefix so that a length mismatch can never
            // turn into an out-of-bounds read through the raw-pointer loads.
            let len = a.len().min(b.len());
            let mut a_blocks = a[..len].chunks_exact(4);
            let mut b_blocks = b[..len].chunks_exact(4);

            // SAFETY: the caller guarantees NEON; every chunk produced by
            // `chunks_exact(4)` holds exactly four `f32`s, so each 128-bit
            // load reads only initialised slice memory.
            unsafe {
                let mut acc = vdupq_n_f32(0.0);
                for (block_a, block_b) in a_blocks.by_ref().zip(b_blocks.by_ref()) {
                    let va = vld1q_f32(block_a.as_ptr());
                    let vb = vld1q_f32(block_b.as_ptr());
                    let diff = vsubq_f32(va, vb);
                    acc = vfmaq_f32(acc, diff, diff);
                }

                // Add the four lanes together.
                let mut total = vaddvq_f32(acc);

                // Elements that did not fill a whole register.
                for (x, y) in a_blocks.remainder().iter().zip(b_blocks.remainder()) {
                    let diff = x - y;
                    total += diff * diff;
                }

                total
            }
        }

        // SAFETY: this kernel is only compiled for aarch64, where the code
        // assumes NEON is present.
        unsafe { inner(a, b) }
    }

    /// Dot product, four `f32` lanes per step.
    ///
    /// If the slices differ in length only their common prefix is used.
    fn dot_f32(&self, a: &[f32], b: &[f32]) -> f32 {
        debug_assert_eq!(a.len(), b.len());

        /// # Safety
        /// The CPU must support NEON.
        unsafe fn inner(a: &[f32], b: &[f32]) -> f32 {
            use std::arch::aarch64::*;

            let len = a.len().min(b.len());
            let mut a_blocks = a[..len].chunks_exact(4);
            let mut b_blocks = b[..len].chunks_exact(4);

            // SAFETY: the caller guarantees NEON; every chunk holds exactly
            // four `f32`s, so each load stays inside its slice.
            unsafe {
                let mut acc = vdupq_n_f32(0.0);
                for (block_a, block_b) in a_blocks.by_ref().zip(b_blocks.by_ref()) {
                    let va = vld1q_f32(block_a.as_ptr());
                    let vb = vld1q_f32(block_b.as_ptr());
                    acc = vfmaq_f32(acc, va, vb);
                }

                let mut total = vaddvq_f32(acc);

                for (x, y) in a_blocks.remainder().iter().zip(b_blocks.remainder()) {
                    total += x * y;
                }

                total
            }
        }

        // SAFETY: this kernel is only compiled for aarch64, where the code
        // assumes NEON is present.
        unsafe { inner(a, b) }
    }

    /// Dot product of `i8` vectors, sixteen lanes per step, accumulated in
    /// `i32`.
    ///
    /// If the slices differ in length only their common prefix is used.
    fn dot_i8(&self, a: &[i8], b: &[i8]) -> i32 {
        debug_assert_eq!(a.len(), b.len());

        /// # Safety
        /// The CPU must support NEON.
        unsafe fn inner(a: &[i8], b: &[i8]) -> i32 {
            use std::arch::aarch64::*;

            let len = a.len().min(b.len());
            let mut a_blocks = a[..len].chunks_exact(16);
            let mut b_blocks = b[..len].chunks_exact(16);

            // SAFETY: the caller guarantees NEON; every chunk holds exactly
            // sixteen bytes, so each 128-bit load stays inside its slice.
            unsafe {
                let mut acc = vdupq_n_s32(0);
                for (block_a, block_b) in a_blocks.by_ref().zip(b_blocks.by_ref()) {
                    let va = vld1q_s8(block_a.as_ptr());
                    let vb = vld1q_s8(block_b.as_ptr());

                    // Sign-extend each 8-byte half to eight i16 lanes.
                    let a_lo = vmovl_s8(vget_low_s8(va));
                    let b_lo = vmovl_s8(vget_low_s8(vb));
                    let a_hi = vmovl_s8(vget_high_s8(va));
                    let b_hi = vmovl_s8(vget_high_s8(vb));

                    // Widening multiplies: four i16 x i16 -> four i32 each.
                    let prod_lo = vmull_s16(vget_low_s16(a_lo), vget_low_s16(b_lo));
                    let prod_hi = vmull_s16(vget_high_s16(a_lo), vget_high_s16(b_lo));

                    acc = vaddq_s32(acc, prod_lo);
                    acc = vaddq_s32(acc, prod_hi);

                    let prod_lo2 = vmull_s16(vget_low_s16(a_hi), vget_low_s16(b_hi));
                    let prod_hi2 = vmull_s16(vget_high_s16(a_hi), vget_high_s16(b_hi));

                    acc = vaddq_s32(acc, prod_lo2);
                    acc = vaddq_s32(acc, prod_hi2);
                }

                let mut total = vaddvq_s32(acc);

                for (x, y) in a_blocks.remainder().iter().zip(b_blocks.remainder()) {
                    total += i32::from(*x) * i32::from(*y);
                }

                total
            }
        }

        // SAFETY: this kernel is only compiled for aarch64, where the code
        // assumes NEON is present.
        unsafe { inner(a, b) }
    }

    /// Writes the squared distance from `query` to the `i`-th `dim`-sized row
    /// of `vectors` into `out[i]`.
    ///
    /// # Panics
    /// Panics if `dim` is zero or if `out` has fewer slots than there are
    /// complete rows in `vectors`.
    fn l2_squared_batch_f32(&self, query: &[f32], vectors: &[f32], dim: usize, out: &mut [f32]) {
        for (slot, row) in vectors.chunks_exact(dim).enumerate() {
            out[slot] = self.l2_squared_f32(query, row);
        }
    }

    /// Writes the dot product of `query` with the `i`-th `dim`-sized row of
    /// `vectors` into `out[i]`.
    ///
    /// # Panics
    /// Panics if `dim` is zero or if `out` has fewer slots than there are
    /// complete rows in `vectors`.
    fn dot_batch_f32(&self, query: &[f32], vectors: &[f32], dim: usize, out: &mut [f32]) {
        for (slot, row) in vectors.chunks_exact(dim).enumerate() {
            out[slot] = self.dot_f32(query, row);
        }
    }

    /// Always [`SimdLevel::Neon`].
    fn simd_level(&self) -> SimdLevel {
        SimdLevel::Neon
    }
}
537
538pub struct KernelDispatcher {
544 features: CpuFeatures,
545}
546
547impl KernelDispatcher {
548 pub fn new() -> Self {
550 Self {
551 features: CpuFeatures::detect(),
552 }
553 }
554
555 pub fn best_kernel(&self) -> Box<dyn DistanceKernel> {
557 #[cfg(target_arch = "x86_64")]
558 {
559 if self.features.avx2 {
560 return Box::new(Avx2Kernel);
561 }
562 }
563
564 #[cfg(target_arch = "aarch64")]
565 {
566 if self.features.neon {
567 return Box::new(NeonKernel);
568 }
569 }
570
571 Box::new(ScalarKernel)
572 }
573
574 pub fn kernel_for_level(&self, level: SimdLevel) -> Box<dyn DistanceKernel> {
576 match level {
577 #[cfg(target_arch = "x86_64")]
578 SimdLevel::Avx2 if self.features.avx2 => Box::new(Avx2Kernel),
579
580 #[cfg(target_arch = "aarch64")]
581 SimdLevel::Neon if self.features.neon => Box::new(NeonKernel),
582
583 _ => Box::new(ScalarKernel),
584 }
585 }
586
587 pub fn features(&self) -> CpuFeatures {
589 self.features
590 }
591
592 pub fn description(&self) -> String {
594 format!(
595 "SIMD: {:?}, Features: avx2={}, neon={}",
596 self.features.best_simd_level(),
597 self.features.avx2,
598 self.features.neon,
599 )
600 }
601}
602
603impl Default for KernelDispatcher {
604 fn default() -> Self {
605 Self::new()
606 }
607}
608
609pub struct ScanOps {
615 kernel: Box<dyn DistanceKernel>,
616}
617
618impl ScanOps {
619 pub fn new() -> Self {
621 Self {
622 kernel: KernelDispatcher::new().best_kernel(),
623 }
624 }
625
626 pub fn with_kernel(kernel: Box<dyn DistanceKernel>) -> Self {
628 Self { kernel }
629 }
630
631 pub fn top_k_l2(
633 &self,
634 query: &[f32],
635 vectors: &[f32],
636 dim: usize,
637 k: usize,
638 ) -> Vec<(u32, f32)> {
639 let n = vectors.len() / dim;
640 let mut distances = vec![0.0f32; n];
641
642 self.kernel
643 .l2_squared_batch_f32(query, vectors, dim, &mut distances);
644
645 let mut indices: Vec<usize> = (0..n).collect();
650 indices.sort_by(|&a, &b| distances[a].total_cmp(&distances[b]));
651
652 indices
653 .into_iter()
654 .take(k)
655 .map(|i| (i as u32, distances[i].sqrt()))
656 .collect()
657 }
658
659 pub fn top_k_dot(
661 &self,
662 query: &[f32],
663 vectors: &[f32],
664 dim: usize,
665 k: usize,
666 ) -> Vec<(u32, f32)> {
667 let n = vectors.len() / dim;
668 let mut scores = vec![0.0f32; n];
669
670 self.kernel.dot_batch_f32(query, vectors, dim, &mut scores);
671
672 let mut indices: Vec<usize> = (0..n).collect();
675 indices.sort_by(|&a, &b| scores[b].total_cmp(&scores[a]));
676
677 indices
678 .into_iter()
679 .take(k)
680 .map(|i| (i as u32, scores[i]))
681 .collect()
682 }
683
684 pub fn simd_level(&self) -> SimdLevel {
686 self.kernel.simd_level()
687 }
688}
689
690impl Default for ScanOps {
691 fn default() -> Self {
692 Self::new()
693 }
694}
695
#[cfg(test)]
mod tests {
    use super::*;

    /// Vectors differing by 1.0 in a single coordinate are at squared
    /// distance 1.
    #[test]
    fn test_scalar_l2() {
        let lhs = [1.0f32, 2.0, 3.0, 4.0];
        let rhs = [1.0f32, 2.0, 3.0, 5.0];

        let squared = ScalarKernel.l2_squared_f32(&lhs, &rhs);

        assert!((squared - 1.0).abs() < 1e-6);
    }

    /// 1 + 4 + 9 + 16 = 30.
    #[test]
    fn test_scalar_dot() {
        let values = [1.0f32, 2.0, 3.0, 4.0];

        let product = ScalarKernel.dot_f32(&values, &values);

        assert!((product - 30.0).abs() < 1e-6);
    }

    /// Same as the float case, but through the `i8` path.
    #[test]
    fn test_scalar_dot_i8() {
        let values: [i8; 4] = [1, 2, 3, 4];

        assert_eq!(ScalarKernel.dot_i8(&values, &values), 30);
    }

    /// Whatever kernel the dispatcher picks must agree with the closed-form
    /// results for constant vectors.
    #[test]
    fn test_dispatcher() {
        let kernel = KernelDispatcher::new().best_kernel();

        let ones = vec![1.0f32; 128];
        let twos = vec![2.0f32; 128];

        let squared = kernel.l2_squared_f32(&ones, &twos);
        assert!((squared - 128.0).abs() < 1e-4);

        let product = kernel.dot_f32(&ones, &twos);
        assert!((product - 256.0).abs() < 1e-4);
    }

    /// The vector identical to the query must be ranked first.
    #[test]
    fn test_scan_ops() {
        let scanner = ScanOps::new();

        let query = [1.0f32, 0.0, 0.0, 0.0];
        // Three 4-dimensional rows.
        let rows = [
            1.0f32, 0.0, 0.0, 0.0, //
            0.0, 1.0, 0.0, 0.0, //
            0.5, 0.5, 0.0, 0.0, //
        ];

        let nearest = scanner.top_k_l2(&query, &rows, 4, 2);

        assert_eq!(nearest.len(), 2);
        assert_eq!(nearest[0].0, 0);
    }

    /// Detection must always yield a level with a non-zero width.
    #[test]
    fn test_cpu_features() {
        let level = CpuFeatures::detect().best_simd_level();

        println!("Detected SIMD level: {:?}", level);
        assert!(level.width_bytes() > 0);
    }
}