// SPDX-License-Identifier: AGPL-3.0-or-later
// SochDB - LLM-Optimized Embedded Database
// Copyright (C) 2026 Sushanth Reddy Vanagala (https://github.com/sushanthpy)
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

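//! # Portable SIMD Scan Kernels (Task 6)
//!
//! Provides a family of distance kernels intended to avoid gather pathologies
//! and to work across diverse hardware. The target design per instruction set
//! is:
//!
//! 1. **AVX-512**: gather- or permute-based. No AVX-512 kernel exists in this
//!    module yet; dispatch uses the AVX2 kernel on such CPUs when it is usable.
//! 2. **AVX2**: byte LUT via shuffle + partial sums. The current kernel instead
//!    uses FMA over contiguous `f32` lanes and sign-extended `i16`
//!    multiply-add for `i8`.
//! 3. **NEON**: table-lookup primitives. The current kernel instead uses
//!    `vfmaq_f32` for `f32` and widening `vmull_s16` products for `i8`.
//! 4. **Scalar**: universal fallback.
//!
//! ## Design Principles
//!
//! - Prefer layouts that allow structured loads (SoA).
//! - Use int16/int32 accumulation to reduce bandwidth.
//! - Minimize unpredictable memory references.
//! - Maximize instruction-level parallelism (ILP).
//!
//! ## Math/Algorithm
//!
//! The inner loop is Θ(N_scanned), i.e. linear in the number of scanned
//! elements, so performance is dominated by:
//! - memory access patterns, and
//! - instruction throughput.
//!
//! Kernel design therefore aims to minimize cache misses and maximize ILP.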

// ============================================================================
// CPU Feature Detection
// ============================================================================

47/// Detected CPU features
48#[derive(Debug, Clone, Copy)]
49pub struct CpuFeatures {
50    pub avx512f: bool,
51    pub avx512bw: bool,
52    pub avx512vl: bool,
53    pub avx512vbmi: bool,
54    pub avx2: bool,
55    pub sse41: bool,
56    pub neon: bool,
57    pub sve: bool,
58}
59
60impl CpuFeatures {
61    /// Detect CPU features at runtime
62    pub fn detect() -> Self {
63        #[cfg(target_arch = "x86_64")]
64        {
65            Self {
66                avx512f: is_x86_feature_detected!("avx512f"),
67                avx512bw: is_x86_feature_detected!("avx512bw"),
68                avx512vl: is_x86_feature_detected!("avx512vl"),
69                avx512vbmi: is_x86_feature_detected!("avx512vbmi"),
70                avx2: is_x86_feature_detected!("avx2"),
71                sse41: is_x86_feature_detected!("sse4.1"),
72                neon: false,
73                sve: false,
74            }
75        }
76        #[cfg(target_arch = "aarch64")]
77        {
78            Self {
79                avx512f: false,
80                avx512bw: false,
81                avx512vl: false,
82                avx512vbmi: false,
83                avx2: false,
84                sse41: false,
85                neon: true, // NEON is mandatory on AArch64
86                sve: false, // SVE detection would require runtime check
87            }
88        }
89        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
90        {
91            Self {
92                avx512f: false,
93                avx512bw: false,
94                avx512vl: false,
95                avx512vbmi: false,
96                avx2: false,
97                sse41: false,
98                neon: false,
99                sve: false,
100            }
101        }
102    }
103
104    /// Get best available SIMD level
105    pub fn best_simd_level(&self) -> SimdLevel {
106        if self.avx512f && self.avx512bw {
107            SimdLevel::Avx512
108        } else if self.avx2 {
109            SimdLevel::Avx2
110        } else if self.sse41 {
111            SimdLevel::Sse41
112        } else if self.neon {
113            SimdLevel::Neon
114        } else {
115            SimdLevel::Scalar
116        }
117    }
118}

/// SIMD instruction-set level used for kernel dispatch.
///
/// Variants are listed in non-increasing vector width; [`SimdLevel::Scalar`]
/// is the portable fallback.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SimdLevel {
    /// 512-bit vectors (AVX-512).
    Avx512,
    /// 256-bit vectors (AVX2).
    Avx2,
    /// 128-bit vectors (SSE4.1).
    Sse41,
    /// 128-bit vectors (AArch64 NEON).
    Neon,
    /// No vector unit: one element at a time.
    Scalar,
}

impl SimdLevel {
    /// Vector register width in bytes (`1` for [`SimdLevel::Scalar`]).
    pub fn width_bytes(&self) -> usize {
        match self {
            SimdLevel::Avx512 => 64,
            SimdLevel::Avx2 => 32,
            SimdLevel::Sse41 | SimdLevel::Neon => 16,
            SimdLevel::Scalar => 1,
        }
    }

    /// Number of `f32` elements processed per iteration.
    ///
    /// Always at least 1. The scalar level handles one element per iteration
    /// even though its nominal byte width (1) is smaller than an `f32`. Plain
    /// integer division would report 0 there, which is unusable as a loop
    /// stride or divisor.
    pub fn f32_elements(&self) -> usize {
        (self.width_bytes() / std::mem::size_of::<f32>()).max(1)
    }

    /// Number of `i8` elements processed per iteration (equal to the byte width).
    pub fn i8_elements(&self) -> usize {
        self.width_bytes()
    }
}

// ============================================================================
// Kernel Trait
// ============================================================================

157/// Trait for portable distance kernels
158pub trait DistanceKernel: Send + Sync {
159    /// Compute L2 squared distance between two f32 vectors
160    fn l2_squared_f32(&self, a: &[f32], b: &[f32]) -> f32;
161
162    /// Compute dot product of two f32 vectors
163    fn dot_f32(&self, a: &[f32], b: &[f32]) -> f32;
164
165    /// Compute dot product of two i8 vectors (returns i32)
166    fn dot_i8(&self, a: &[i8], b: &[i8]) -> i32;
167
168    /// Batch L2 squared: query vs multiple vectors
169    fn l2_squared_batch_f32(&self, query: &[f32], vectors: &[f32], dim: usize, out: &mut [f32]);
170
171    /// Batch dot product: query vs multiple vectors
172    fn dot_batch_f32(&self, query: &[f32], vectors: &[f32], dim: usize, out: &mut [f32]);
173
174    /// SIMD level of this kernel
175    fn simd_level(&self) -> SimdLevel;
176}

// ============================================================================
// Scalar Fallback Implementation
// ============================================================================

182/// Scalar fallback implementation (works everywhere)
183pub struct ScalarKernel;
184
185impl DistanceKernel for ScalarKernel {
186    fn l2_squared_f32(&self, a: &[f32], b: &[f32]) -> f32 {
187        debug_assert_eq!(a.len(), b.len());
188        a.iter()
189            .zip(b.iter())
190            .map(|(x, y)| {
191                let diff = x - y;
192                diff * diff
193            })
194            .sum()
195    }
196
197    fn dot_f32(&self, a: &[f32], b: &[f32]) -> f32 {
198        debug_assert_eq!(a.len(), b.len());
199        a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
200    }
201
202    fn dot_i8(&self, a: &[i8], b: &[i8]) -> i32 {
203        debug_assert_eq!(a.len(), b.len());
204        a.iter()
205            .zip(b.iter())
206            .map(|(&x, &y)| x as i32 * y as i32)
207            .sum()
208    }
209
210    fn l2_squared_batch_f32(&self, query: &[f32], vectors: &[f32], dim: usize, out: &mut [f32]) {
211        let n = vectors.len() / dim;
212        debug_assert!(out.len() >= n);
213
214        for i in 0..n {
215            let vec = &vectors[i * dim..(i + 1) * dim];
216            out[i] = self.l2_squared_f32(query, vec);
217        }
218    }
219
220    fn dot_batch_f32(&self, query: &[f32], vectors: &[f32], dim: usize, out: &mut [f32]) {
221        let n = vectors.len() / dim;
222        debug_assert!(out.len() >= n);
223
224        for i in 0..n {
225            let vec = &vectors[i * dim..(i + 1) * dim];
226            out[i] = self.dot_f32(query, vec);
227        }
228    }
229
230    fn simd_level(&self) -> SimdLevel {
231        SimdLevel::Scalar
232    }
233}

// ============================================================================
// AVX2 Implementation
// ============================================================================

239#[cfg(target_arch = "x86_64")]
240pub struct Avx2Kernel;
241
242#[cfg(target_arch = "x86_64")]
243impl DistanceKernel for Avx2Kernel {
244    fn l2_squared_f32(&self, a: &[f32], b: &[f32]) -> f32 {
245        debug_assert_eq!(a.len(), b.len());
246
247        #[target_feature(enable = "avx2")]
248        unsafe fn inner(a: &[f32], b: &[f32]) -> f32 {
249            use std::arch::x86_64::*;
250            unsafe {
251                let n = a.len();
252                let chunks = n / 8;
253                let mut sum = _mm256_setzero_ps();
254
255                for i in 0..chunks {
256                    let va = _mm256_loadu_ps(a.as_ptr().add(i * 8));
257                    let vb = _mm256_loadu_ps(b.as_ptr().add(i * 8));
258                    let diff = _mm256_sub_ps(va, vb);
259                    sum = _mm256_fmadd_ps(diff, diff, sum);
260                }
261
262                // Horizontal sum
263                let sum128 =
264                    _mm_add_ps(_mm256_extractf128_ps(sum, 0), _mm256_extractf128_ps(sum, 1));
265                let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
266                let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
267                let mut result = _mm_cvtss_f32(sum32);
268
269                // Handle remainder
270                for i in (chunks * 8)..n {
271                    let diff = a[i] - b[i];
272                    result += diff * diff;
273                }
274
275                result
276            }
277        }
278
279        if is_x86_feature_detected!("avx2") {
280            unsafe { inner(a, b) }
281        } else {
282            ScalarKernel.l2_squared_f32(a, b)
283        }
284    }
285
286    fn dot_f32(&self, a: &[f32], b: &[f32]) -> f32 {
287        debug_assert_eq!(a.len(), b.len());
288
289        #[target_feature(enable = "avx2")]
290        unsafe fn inner(a: &[f32], b: &[f32]) -> f32 {
291            use std::arch::x86_64::*;
292            unsafe {
293                let n = a.len();
294                let chunks = n / 8;
295                let mut sum = _mm256_setzero_ps();
296
297                for i in 0..chunks {
298                    let va = _mm256_loadu_ps(a.as_ptr().add(i * 8));
299                    let vb = _mm256_loadu_ps(b.as_ptr().add(i * 8));
300                    sum = _mm256_fmadd_ps(va, vb, sum);
301                }
302
303                // Horizontal sum
304                let sum128 =
305                    _mm_add_ps(_mm256_extractf128_ps(sum, 0), _mm256_extractf128_ps(sum, 1));
306                let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
307                let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
308                let mut result = _mm_cvtss_f32(sum32);
309
310                // Handle remainder
311                for i in (chunks * 8)..n {
312                    result += a[i] * b[i];
313                }
314
315                result
316            }
317        }
318
319        if is_x86_feature_detected!("avx2") {
320            unsafe { inner(a, b) }
321        } else {
322            ScalarKernel.dot_f32(a, b)
323        }
324    }
325
326    fn dot_i8(&self, a: &[i8], b: &[i8]) -> i32 {
327        debug_assert_eq!(a.len(), b.len());
328
329        #[target_feature(enable = "avx2")]
330        unsafe fn inner(a: &[i8], b: &[i8]) -> i32 {
331            use std::arch::x86_64::*;
332            unsafe {
333                let n = a.len();
334                let chunks = n / 32;
335                let mut sum = _mm256_setzero_si256();
336
337                for i in 0..chunks {
338                    let va = _mm256_loadu_si256(a.as_ptr().add(i * 32) as *const __m256i);
339                    let vb = _mm256_loadu_si256(b.as_ptr().add(i * 32) as *const __m256i);
340
341                    // Unpack to i16 and multiply
342                    let a_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(va, 0));
343                    let b_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vb, 0));
344                    let a_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(va, 1));
345                    let b_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vb, 1));
346
347                    let prod_lo = _mm256_madd_epi16(a_lo, b_lo);
348                    let prod_hi = _mm256_madd_epi16(a_hi, b_hi);
349
350                    sum = _mm256_add_epi32(sum, prod_lo);
351                    sum = _mm256_add_epi32(sum, prod_hi);
352                }
353
354                // Horizontal sum
355                let sum128 = _mm_add_epi32(
356                    _mm256_extracti128_si256(sum, 0),
357                    _mm256_extracti128_si256(sum, 1),
358                );
359                let sum64 = _mm_add_epi32(sum128, _mm_srli_si128(sum128, 8));
360                let sum32 = _mm_add_epi32(sum64, _mm_srli_si128(sum64, 4));
361                let mut result = _mm_cvtsi128_si32(sum32);
362
363                // Handle remainder
364                for i in (chunks * 32)..n {
365                    result += a[i] as i32 * b[i] as i32;
366                }
367
368                result
369            }
370        }
371
372        if is_x86_feature_detected!("avx2") {
373            unsafe { inner(a, b) }
374        } else {
375            ScalarKernel.dot_i8(a, b)
376        }
377    }
378
379    fn l2_squared_batch_f32(&self, query: &[f32], vectors: &[f32], dim: usize, out: &mut [f32]) {
380        let n = vectors.len() / dim;
381        for i in 0..n {
382            let vec = &vectors[i * dim..(i + 1) * dim];
383            out[i] = self.l2_squared_f32(query, vec);
384        }
385    }
386
387    fn dot_batch_f32(&self, query: &[f32], vectors: &[f32], dim: usize, out: &mut [f32]) {
388        let n = vectors.len() / dim;
389        for i in 0..n {
390            let vec = &vectors[i * dim..(i + 1) * dim];
391            out[i] = self.dot_f32(query, vec);
392        }
393    }
394
395    fn simd_level(&self) -> SimdLevel {
396        SimdLevel::Avx2
397    }
398}

// ============================================================================
// NEON Implementation
// ============================================================================

/// NEON (AArch64 Advanced SIMD) implementation of [`DistanceKernel`].
///
/// NEON is part of the AArch64 baseline, so the vector paths are used
/// unconditionally. Only the common prefix of two input slices is read,
/// matching `ScalarKernel`'s `zip`. Mismatched lengths therefore can never
/// drive the unchecked vector loads past the end of the shorter slice.
#[cfg(target_arch = "aarch64")]
pub struct NeonKernel;

#[cfg(target_arch = "aarch64")]
impl DistanceKernel for NeonKernel {
    fn l2_squared_f32(&self, a: &[f32], b: &[f32]) -> f32 {
        debug_assert_eq!(a.len(), b.len());

        // Safety contract: must run on AArch64, where NEON is always present.
        unsafe fn inner(a: &[f32], b: &[f32]) -> f32 {
            use std::arch::aarch64::*;

            // Clamp both inputs to the common length so every load is in bounds.
            let n = a.len().min(b.len());
            let (a, b) = (&a[..n], &b[..n]);
            let chunks = n / 4;

            // SAFETY: each load reads 4 lanes starting at i * 4, and
            // (i + 1) * 4 <= chunks * 4 <= n, which equals both slice lengths.
            unsafe {
                let mut sum = vdupq_n_f32(0.0);
                for i in 0..chunks {
                    let va = vld1q_f32(a.as_ptr().add(i * 4));
                    let vb = vld1q_f32(b.as_ptr().add(i * 4));
                    let diff = vsubq_f32(va, vb);
                    sum = vfmaq_f32(sum, diff, diff);
                }

                // Horizontal sum of the 4 lanes.
                let mut result = vaddvq_f32(sum);

                // Scalar tail for the last n % 4 elements.
                for (x, y) in a[chunks * 4..].iter().zip(&b[chunks * 4..]) {
                    let diff = x - y;
                    result += diff * diff;
                }

                result
            }
        }

        // SAFETY: this impl only exists on AArch64, where NEON is mandatory.
        unsafe { inner(a, b) }
    }

    fn dot_f32(&self, a: &[f32], b: &[f32]) -> f32 {
        debug_assert_eq!(a.len(), b.len());

        // Safety contract: must run on AArch64, where NEON is always present.
        unsafe fn inner(a: &[f32], b: &[f32]) -> f32 {
            use std::arch::aarch64::*;

            let n = a.len().min(b.len());
            let (a, b) = (&a[..n], &b[..n]);
            let chunks = n / 4;

            // SAFETY: each 4-lane load ends at or before n.
            unsafe {
                let mut sum = vdupq_n_f32(0.0);
                for i in 0..chunks {
                    let va = vld1q_f32(a.as_ptr().add(i * 4));
                    let vb = vld1q_f32(b.as_ptr().add(i * 4));
                    sum = vfmaq_f32(sum, va, vb);
                }

                let mut result = vaddvq_f32(sum);

                // Scalar tail for the last n % 4 elements.
                for (x, y) in a[chunks * 4..].iter().zip(&b[chunks * 4..]) {
                    result += x * y;
                }

                result
            }
        }

        // SAFETY: this impl only exists on AArch64, where NEON is mandatory.
        unsafe { inner(a, b) }
    }

    fn dot_i8(&self, a: &[i8], b: &[i8]) -> i32 {
        debug_assert_eq!(a.len(), b.len());

        // Safety contract: must run on AArch64, where NEON is always present.
        unsafe fn inner(a: &[i8], b: &[i8]) -> i32 {
            use std::arch::aarch64::*;

            let n = a.len().min(b.len());
            let (a, b) = (&a[..n], &b[..n]);
            let chunks = n / 16;

            // SAFETY: each load reads 16 bytes starting at i * 16, and
            // (i + 1) * 16 <= chunks * 16 <= n.
            unsafe {
                let mut sum = vdupq_n_s32(0);
                for i in 0..chunks {
                    let va = vld1q_s8(a.as_ptr().add(i * 16));
                    let vb = vld1q_s8(b.as_ptr().add(i * 16));

                    // No SDOT here. Sign-extend each 8-byte half to i16, then
                    // accumulate widening i16 x i16 -> i32 products.
                    let a_lo = vmovl_s8(vget_low_s8(va));
                    let b_lo = vmovl_s8(vget_low_s8(vb));
                    let a_hi = vmovl_s8(vget_high_s8(va));
                    let b_hi = vmovl_s8(vget_high_s8(vb));

                    sum = vaddq_s32(sum, vmull_s16(vget_low_s16(a_lo), vget_low_s16(b_lo)));
                    sum = vaddq_s32(sum, vmull_s16(vget_high_s16(a_lo), vget_high_s16(b_lo)));
                    sum = vaddq_s32(sum, vmull_s16(vget_low_s16(a_hi), vget_low_s16(b_hi)));
                    sum = vaddq_s32(sum, vmull_s16(vget_high_s16(a_hi), vget_high_s16(b_hi)));
                }

                let mut result = vaddvq_s32(sum);

                // Scalar tail for the last n % 16 elements.
                for (&x, &y) in a[chunks * 16..].iter().zip(&b[chunks * 16..]) {
                    result += x as i32 * y as i32;
                }

                result
            }
        }

        // SAFETY: this impl only exists on AArch64, where NEON is mandatory.
        unsafe { inner(a, b) }
    }

    fn l2_squared_batch_f32(&self, query: &[f32], vectors: &[f32], dim: usize, out: &mut [f32]) {
        // A zero dimension describes no vectors; bail out instead of dividing by zero.
        if dim == 0 {
            return;
        }
        for (i, vector) in vectors.chunks_exact(dim).enumerate() {
            out[i] = self.l2_squared_f32(query, vector);
        }
    }

    fn dot_batch_f32(&self, query: &[f32], vectors: &[f32], dim: usize, out: &mut [f32]) {
        if dim == 0 {
            return;
        }
        for (i, vector) in vectors.chunks_exact(dim).enumerate() {
            out[i] = self.dot_f32(query, vector);
        }
    }

    fn simd_level(&self) -> SimdLevel {
        SimdLevel::Neon
    }
}

// ============================================================================
// Kernel Dispatcher
// ============================================================================

542/// Global kernel dispatcher that selects best implementation
543pub struct KernelDispatcher {
544    features: CpuFeatures,
545}
546
547impl KernelDispatcher {
548    /// Create new dispatcher with runtime feature detection
549    pub fn new() -> Self {
550        Self {
551            features: CpuFeatures::detect(),
552        }
553    }
554
555    /// Get the best available kernel
556    pub fn best_kernel(&self) -> Box<dyn DistanceKernel> {
557        #[cfg(target_arch = "x86_64")]
558        {
559            if self.features.avx2 {
560                return Box::new(Avx2Kernel);
561            }
562        }
563
564        #[cfg(target_arch = "aarch64")]
565        {
566            if self.features.neon {
567                return Box::new(NeonKernel);
568            }
569        }
570
571        Box::new(ScalarKernel)
572    }
573
574    /// Get kernel for specific SIMD level
575    pub fn kernel_for_level(&self, level: SimdLevel) -> Box<dyn DistanceKernel> {
576        match level {
577            #[cfg(target_arch = "x86_64")]
578            SimdLevel::Avx2 if self.features.avx2 => Box::new(Avx2Kernel),
579
580            #[cfg(target_arch = "aarch64")]
581            SimdLevel::Neon if self.features.neon => Box::new(NeonKernel),
582
583            _ => Box::new(ScalarKernel),
584        }
585    }
586
587    /// Get detected features
588    pub fn features(&self) -> CpuFeatures {
589        self.features
590    }
591
592    /// Get description of selected kernel
593    pub fn description(&self) -> String {
594        format!(
595            "SIMD: {:?}, Features: avx2={}, neon={}",
596            self.features.best_simd_level(),
597            self.features.avx2,
598            self.features.neon,
599        )
600    }
601}
602
603impl Default for KernelDispatcher {
604    fn default() -> Self {
605        Self::new()
606    }
607}

// ============================================================================
// Scan Operations
// ============================================================================

613/// High-level scan operations using best available SIMD
614pub struct ScanOps {
615    kernel: Box<dyn DistanceKernel>,
616}
617
618impl ScanOps {
619    /// Create with automatic kernel selection
620    pub fn new() -> Self {
621        Self {
622            kernel: KernelDispatcher::new().best_kernel(),
623        }
624    }
625
626    /// Create with specific kernel
627    pub fn with_kernel(kernel: Box<dyn DistanceKernel>) -> Self {
628        Self { kernel }
629    }
630
631    /// Scan vectors and return top-k by L2 distance
632    pub fn top_k_l2(
633        &self,
634        query: &[f32],
635        vectors: &[f32],
636        dim: usize,
637        k: usize,
638    ) -> Vec<(u32, f32)> {
639        let n = vectors.len() / dim;
640        let mut distances = vec![0.0f32; n];
641
642        self.kernel
643            .l2_squared_batch_f32(query, vectors, dim, &mut distances);
644
645        // Get top-k indices. Use total_cmp, not partial_cmp().unwrap(): a NaN
646        // distance (e.g. from an Inf in the input vector — Inf-Inf=NaN) makes
647        // partial_cmp return None and the unwrap panics. total_cmp is a true
648        // total order that sinks NaN deterministically.
649        let mut indices: Vec<usize> = (0..n).collect();
650        indices.sort_by(|&a, &b| distances[a].total_cmp(&distances[b]));
651
652        indices
653            .into_iter()
654            .take(k)
655            .map(|i| (i as u32, distances[i].sqrt()))
656            .collect()
657    }
658
659    /// Scan vectors and return top-k by dot product (descending)
660    pub fn top_k_dot(
661        &self,
662        query: &[f32],
663        vectors: &[f32],
664        dim: usize,
665        k: usize,
666    ) -> Vec<(u32, f32)> {
667        let n = vectors.len() / dim;
668        let mut scores = vec![0.0f32; n];
669
670        self.kernel.dot_batch_f32(query, vectors, dim, &mut scores);
671
672        // Get top-k indices (descending for dot product). total_cmp avoids the
673        // NaN panic that partial_cmp().unwrap() would hit on Inf/NaN scores.
674        let mut indices: Vec<usize> = (0..n).collect();
675        indices.sort_by(|&a, &b| scores[b].total_cmp(&scores[a]));
676
677        indices
678            .into_iter()
679            .take(k)
680            .map(|i| (i as u32, scores[i]))
681            .collect()
682    }
683
684    /// Get SIMD level being used
685    pub fn simd_level(&self) -> SimdLevel {
686        self.kernel.simd_level()
687    }
688}
689
690impl Default for ScanOps {
691    fn default() -> Self {
692        Self::new()
693    }
694}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_scalar_l2() {
        let kernel = ScalarKernel;
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![1.0, 2.0, 3.0, 5.0];

        let dist = kernel.l2_squared_f32(&a, &b);
        assert!((dist - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_scalar_dot() {
        let kernel = ScalarKernel;
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![1.0, 2.0, 3.0, 4.0];

        let dot = kernel.dot_f32(&a, &b);
        assert!((dot - 30.0).abs() < 1e-6);
    }

    #[test]
    fn test_scalar_dot_i8() {
        let kernel = ScalarKernel;
        let a: Vec<i8> = vec![1, 2, 3, 4];
        let b: Vec<i8> = vec![1, 2, 3, 4];

        let dot = kernel.dot_i8(&a, &b);
        assert_eq!(dot, 30);
    }

    #[test]
    fn test_dispatcher() {
        let dispatcher = KernelDispatcher::new();
        let kernel = dispatcher.best_kernel();

        let a = vec![1.0f32; 128];
        let b = vec![2.0f32; 128];

        let l2 = kernel.l2_squared_f32(&a, &b);
        assert!((l2 - 128.0).abs() < 1e-4);

        let dot = kernel.dot_f32(&a, &b);
        assert!((dot - 256.0).abs() < 1e-4);
    }

    #[test]
    fn test_best_kernel_handles_tails() {
        // 131 and 77 are not multiples of 4, 8, 16 or 32, so every kernel's
        // scalar tail loop runs in addition to its vector loop.
        let kernel = KernelDispatcher::new().best_kernel();

        let a = vec![1.0f32; 131];
        let b = vec![2.0f32; 131];
        assert!((kernel.l2_squared_f32(&a, &b) - 131.0).abs() < 1e-4);
        assert!((kernel.dot_f32(&a, &b) - 262.0).abs() < 1e-4);

        let x: Vec<i8> = (0..77).map(|i: i32| (i * 37 % 256 - 128) as i8).collect();
        let y: Vec<i8> = (0..77).map(|i: i32| (i * 91 % 256 - 128) as i8).collect();
        assert_eq!(kernel.dot_i8(&x, &y), ScalarKernel.dot_i8(&x, &y));
    }

    #[test]
    fn test_batch_matches_single() {
        let kernel = KernelDispatcher::new().best_kernel();
        let query: Vec<f32> = (0..11).map(|i| i as f32 * 0.5).collect();
        let vectors: Vec<f32> = (0..33).map(|i| (i % 7) as f32 - 3.0).collect();

        let mut l2 = vec![0.0f32; 3];
        let mut dot = vec![0.0f32; 3];
        kernel.l2_squared_batch_f32(&query, &vectors, 11, &mut l2);
        kernel.dot_batch_f32(&query, &vectors, 11, &mut dot);

        for (i, row) in vectors.chunks_exact(11).enumerate() {
            assert_eq!(l2[i], kernel.l2_squared_f32(&query, row));
            assert_eq!(dot[i], kernel.dot_f32(&query, row));
        }
    }

    #[test]
    fn test_scan_ops() {
        let ops = ScanOps::new();

        let query = vec![1.0, 0.0, 0.0, 0.0];
        let vectors = vec![
            1.0, 0.0, 0.0, 0.0, // Distance 0
            0.0, 1.0, 0.0, 0.0, // Distance sqrt(2)
            0.5, 0.5, 0.0, 0.0, // Distance ~0.7
        ];

        let top2 = ops.top_k_l2(&query, &vectors, 4, 2);

        assert_eq!(top2.len(), 2);
        assert_eq!(top2[0].0, 0); // First vector is closest
        assert_eq!(top2[1].0, 2); // then the (0.5, 0.5) vector
        assert!(top2[0].1.abs() < 1e-6);
        assert!((top2[1].1 - 0.5f32.sqrt()).abs() < 1e-6);
    }

    #[test]
    fn test_scan_ops_dot() {
        let ops = ScanOps::new();

        let query = vec![1.0, 0.0, 0.0, 0.0];
        let vectors = vec![
            1.0, 0.0, 0.0, 0.0, // Score 1
            0.0, 1.0, 0.0, 0.0, // Score 0
            0.5, 0.5, 0.0, 0.0, // Score 0.5
        ];

        let ranked = ops.top_k_dot(&query, &vectors, 4, 3);
        let order: Vec<u32> = ranked.iter().map(|&(i, _)| i).collect();
        assert_eq!(order, vec![0, 2, 1]);
        assert_eq!(ranked[0].1, 1.0);
    }

    #[test]
    fn test_cpu_features() {
        let features = CpuFeatures::detect();
        let level = features.best_simd_level();

        // Just verify it doesn't crash
        println!("Detected SIMD level: {:?}", level);
        assert!(level.width_bytes() > 0);
    }
}