// rustyhdf5_accel/lib.rs

//! SIMD-accelerated operations for rustyhdf5.
//!
//! This crate provides runtime-dispatched SIMD acceleration for common
//! vector operations used in HDF5 processing: dot products, cosine similarity,
//! L2 distance, f16 conversion, and checksums.
//!
//! All public functions automatically select the best available SIMD backend
//! at runtime. Every operation has a portable scalar fallback.
9
// Portable scalar implementations — the fallback every dispatcher can use.
pub mod scalar;

// Architecture-specific kernels, compiled only for their target.
#[cfg(target_arch = "aarch64")]
pub mod neon;

#[cfg(target_arch = "x86_64")]
pub mod avx2;

// AVX-512 is opt-in via the `avx512` crate feature.
#[cfg(all(target_arch = "x86_64", feature = "avx512"))]
pub mod avx512;

pub mod checksum;
pub mod convert;
23
// ---------------------------------------------------------------------------
// Cache-line size detection (TVL — Tensor Virtualization Layout)
// ---------------------------------------------------------------------------

/// Cache line size in bytes for the target architecture.
///
/// ARM64 (Apple M-series, Cortex-A76+) uses 128-byte cache lines.
/// x86_64 uses 64-byte cache lines. Other architectures default to 64.
#[cfg(target_arch = "aarch64")]
pub const CACHE_LINE_SIZE: usize = 128;

#[cfg(target_arch = "x86_64")]
pub const CACHE_LINE_SIZE: usize = 64;

#[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
pub const CACHE_LINE_SIZE: usize = 64;

/// Round `size` up to the next multiple of [`CACHE_LINE_SIZE`].
///
/// Values that are already aligned (including `0`) are returned unchanged.
#[inline]
pub fn align_to_cache_line(size: usize) -> usize {
    // How far past the previous cache-line boundary are we?
    let overshoot = size % CACHE_LINE_SIZE;
    if overshoot == 0 {
        size
    } else {
        size + (CACHE_LINE_SIZE - overshoot)
    }
}
46
/// Available SIMD backends.
///
/// Returned by [`detect_backend`]. Which variants can actually be produced
/// depends on the compile target and enabled crate features; dispatchers
/// fall back to [`Backend::Scalar`]-equivalent code paths for any variant
/// they have no kernel for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Backend {
    /// ARM NEON (always available on aarch64)
    Neon,
    /// x86_64 AVX2 + FMA
    Avx2,
    /// x86_64 AVX-512F
    Avx512,
    /// x86_64 SSE4.1
    Sse4,
    /// WebAssembly SIMD128
    WasmSimd128,
    /// Portable scalar fallback
    Scalar,
}
63
64/// Detect the best available SIMD backend at runtime.
65pub fn detect_backend() -> Backend {
66    #[cfg(target_arch = "aarch64")]
67    {
68        return Backend::Neon; // Always available on aarch64
69    }
70
71    #[cfg(target_arch = "x86_64")]
72    {
73        #[cfg(feature = "avx512")]
74        {
75            if is_x86_feature_detected!("avx512f") {
76                return Backend::Avx512;
77            }
78        }
79        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
80            return Backend::Avx2;
81        }
82        if is_x86_feature_detected!("sse4.1") {
83            return Backend::Sse4;
84        }
85    }
86
87    #[cfg(target_arch = "wasm32")]
88    {
89        return Backend::WasmSimd128;
90    }
91
92    #[allow(unreachable_code)]
93    Backend::Scalar
94}
95
// ---------------------------------------------------------------------------
// Public API — auto-dispatched
// ---------------------------------------------------------------------------
99
/// Compute the dot product of two f32 slices.
///
/// Dispatches to the best backend reported by [`detect_backend`]. Backends
/// without a dedicated kernel here (e.g. SSE4.1, WASM) fall through to the
/// portable scalar implementation.
///
/// NOTE(review): behavior when `a.len() != b.len()` is determined by the
/// backend kernels — confirm they agree before relying on mismatched lengths.
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    match detect_backend() {
        #[cfg(target_arch = "aarch64")]
        // SAFETY: this arm is only taken when detect_backend() reported NEON,
        // which is always present on aarch64.
        Backend::Neon => unsafe { neon::dot_product(a, b) },

        #[cfg(all(target_arch = "x86_64", feature = "avx512"))]
        // SAFETY: Avx512 is only returned after a successful runtime CPUID
        // check for avx512f.
        Backend::Avx512 => unsafe { avx512::dot_product(a, b) },

        #[cfg(target_arch = "x86_64")]
        // SAFETY: Avx2 is only returned after runtime checks for avx2 + fma.
        Backend::Avx2 => unsafe { avx2::dot_product(a, b) },

        _ => scalar::dot_product(a, b),
    }
}
115
116/// Compute the L2 norm (magnitude) of a vector.
117pub fn vector_norm(v: &[f32]) -> f32 {
118    dot_product(v, v).sqrt()
119}
120
/// Compute cosine similarity between two vectors (fused single-pass).
///
/// Dispatches to the best backend reported by [`detect_backend`]; backends
/// without a fused kernel fall through to the scalar implementation.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    match detect_backend() {
        #[cfg(target_arch = "aarch64")]
        // SAFETY: NEON is always present on aarch64.
        Backend::Neon => unsafe { neon::cosine_similarity(a, b) },

        #[cfg(all(target_arch = "x86_64", feature = "avx512"))]
        // SAFETY: Avx512 is only returned after a runtime avx512f check.
        Backend::Avx512 => unsafe { avx512::cosine_similarity(a, b) },

        #[cfg(target_arch = "x86_64")]
        // SAFETY: Avx2 is only returned after runtime checks for avx2 + fma.
        Backend::Avx2 => unsafe { avx2::cosine_similarity(a, b) },

        _ => scalar::cosine_similarity(a, b),
    }
}
136
137/// Compute cosine similarity between a query and multiple vectors.
138///
139/// Results are stored as `(index, similarity)` pairs.
140pub fn batch_cosine(query: &[f32], vectors: &[&[f32]], results: &mut [(usize, f32)]) {
141    assert!(results.len() >= vectors.len());
142    for (i, v) in vectors.iter().enumerate() {
143        results[i] = (i, cosine_similarity(query, v));
144    }
145}
146
147/// Compute cosine similarity with pre-normalized query vector.
148///
149/// `query_normed` must already be unit-length. `norms` contains the L2 norms
150/// of each vector in `vectors`.
151pub fn batch_cosine_prenorm(
152    query_normed: &[f32],
153    vectors: &[&[f32]],
154    norms: &[f32],
155    results: &mut [(usize, f32)],
156) {
157    assert!(results.len() >= vectors.len());
158    assert!(norms.len() >= vectors.len());
159    for (i, v) in vectors.iter().enumerate() {
160        let dot = dot_product(query_normed, v);
161        let sim = if norms[i] == 0.0 { 0.0 } else { dot / norms[i] };
162        results[i] = (i, sim);
163    }
164}
165
/// Compute L2 (Euclidean) distance between two vectors.
///
/// Dispatches to the best backend reported by [`detect_backend`]; backends
/// without a dedicated kernel fall through to the scalar implementation.
pub fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
    match detect_backend() {
        #[cfg(target_arch = "aarch64")]
        // SAFETY: NEON is always present on aarch64.
        Backend::Neon => unsafe { neon::l2_distance(a, b) },

        #[cfg(all(target_arch = "x86_64", feature = "avx512"))]
        // SAFETY: Avx512 is only returned after a runtime avx512f check.
        Backend::Avx512 => unsafe { avx512::l2_distance(a, b) },

        #[cfg(target_arch = "x86_64")]
        // SAFETY: Avx2 is only returned after runtime checks for avx2 + fma.
        Backend::Avx2 => unsafe { avx2::l2_distance(a, b) },

        _ => scalar::l2_distance(a, b),
    }
}
181
182/// Compute L2 norms for a batch of vectors.
183pub fn batch_norms(vectors: &[&[f32]], norms: &mut [f32]) {
184    assert!(norms.len() >= vectors.len());
185    for (i, v) in vectors.iter().enumerate() {
186        norms[i] = vector_norm(v);
187    }
188}
189
/// Convert a batch of f16 values (as raw u16 bits) to f32.
///
/// Thin delegating wrapper around [`convert::f16_to_f32_batch`].
///
/// NOTE(review): any length requirement of `output` relative to `input` is
/// enforced (or not) inside `convert` — confirm there.
pub fn f16_to_f32_batch(input: &[u16], output: &mut [f32]) {
    convert::f16_to_f32_batch(input, output);
}
194
/// Compute Fletcher-32 checksum.
///
/// Thin delegating wrapper around [`checksum::checksum_fletcher32`].
pub fn checksum_fletcher32(data: &[u8]) -> u32 {
    checksum::checksum_fletcher32(data)
}
199
// Unit tests. Where possible, SIMD-dispatched results are cross-checked
// against the portable scalar implementations.
#[cfg(test)]
mod tests {
    use super::*;

    // Shared absolute tolerance for float comparisons.
    const EPSILON: f32 = 1e-5;

    fn approx_eq(a: f32, b: f32, eps: f32) -> bool {
        (a - b).abs() < eps
    }

    // -----------------------------------------------------------------------
    // Backend detection
    // -----------------------------------------------------------------------

    #[test]
    fn test_detect_backend_returns_valid() {
        let backend = detect_backend();
        // Exhaustive match: this test fails to compile if a variant is added
        // without being considered here.
        match backend {
            Backend::Neon
            | Backend::Avx2
            | Backend::Avx512
            | Backend::Sse4
            | Backend::WasmSimd128
            | Backend::Scalar => {}
        }
    }

    #[test]
    fn test_detect_backend_consistent() {
        let b1 = detect_backend();
        let b2 = detect_backend();
        assert_eq!(b1, b2);
    }

    // -----------------------------------------------------------------------
    // Dot product
    // -----------------------------------------------------------------------

    #[test]
    fn test_dot_product_known_values() {
        let a = [1.0, 2.0, 3.0, 4.0];
        let b = [5.0, 6.0, 7.0, 8.0];
        // 1*5 + 2*6 + 3*7 + 4*8 = 5 + 12 + 21 + 32 = 70
        let result = dot_product(&a, &b);
        assert!(approx_eq(result, 70.0, EPSILON), "got {result}");
    }

    #[test]
    fn test_dot_product_zero_vectors() {
        let a = [0.0f32; 16];
        let b = [1.0f32; 16];
        assert!(approx_eq(dot_product(&a, &b), 0.0, EPSILON));
    }

    #[test]
    fn test_dot_product_unit_vectors() {
        let mut a = [0.0f32; 3];
        let mut b = [0.0f32; 3];
        a[0] = 1.0;
        b[0] = 1.0;
        assert!(approx_eq(dot_product(&a, &b), 1.0, EPSILON));
    }

    #[test]
    fn test_dot_product_large_random() {
        let n = 1024;
        let a: Vec<f32> = (0..n).map(|i| (i as f32) * 0.01).collect();
        let b: Vec<f32> = (0..n).map(|i| ((n - i) as f32) * 0.01).collect();
        let scalar_result = scalar::dot_product(&a, &b);
        let simd_result = dot_product(&a, &b);
        // Loose tolerance: SIMD accumulation order differs from scalar.
        assert!(
            approx_eq(scalar_result, simd_result, 0.1),
            "scalar={scalar_result} simd={simd_result}"
        );
    }

    #[test]
    fn test_dot_product_negative_values() {
        let a = [-1.0, -2.0, -3.0];
        let b = [1.0, 2.0, 3.0];
        assert!(approx_eq(dot_product(&a, &b), -14.0, EPSILON));
    }

    #[test]
    fn test_dot_product_single_element() {
        assert!(approx_eq(dot_product(&[3.0], &[4.0]), 12.0, EPSILON));
    }

    #[test]
    fn test_dot_product_empty() {
        assert!(approx_eq(dot_product(&[], &[]), 0.0, EPSILON));
    }

    #[test]
    fn test_dot_product_scalar_vs_dispatch() {
        let a: Vec<f32> = (0..384).map(|i| (i as f32).sin()).collect();
        let b: Vec<f32> = (0..384).map(|i| (i as f32).cos()).collect();
        let s = scalar::dot_product(&a, &b);
        let d = dot_product(&a, &b);
        assert!(
            approx_eq(s, d, 0.01),
            "scalar={s} dispatched={d}"
        );
    }

    // -----------------------------------------------------------------------
    // Vector norm
    // -----------------------------------------------------------------------

    #[test]
    fn test_vector_norm_unit() {
        let v = [1.0, 0.0, 0.0];
        assert!(approx_eq(vector_norm(&v), 1.0, EPSILON));
    }

    #[test]
    fn test_vector_norm_345() {
        // 3-4-5 right triangle: ||(3, 4)|| = 5.
        let v = [3.0, 4.0];
        assert!(approx_eq(vector_norm(&v), 5.0, EPSILON));
    }

    #[test]
    fn test_vector_norm_zero() {
        let v = [0.0f32; 10];
        assert!(approx_eq(vector_norm(&v), 0.0, EPSILON));
    }

    // -----------------------------------------------------------------------
    // Cosine similarity
    // -----------------------------------------------------------------------

    #[test]
    fn test_cosine_identical_is_one() {
        let v = [1.0, 2.0, 3.0, 4.0, 5.0];
        assert!(approx_eq(cosine_similarity(&v, &v), 1.0, EPSILON));
    }

    #[test]
    fn test_cosine_opposite_is_neg_one() {
        let a = [1.0, 2.0, 3.0];
        let b = [-1.0, -2.0, -3.0];
        assert!(approx_eq(cosine_similarity(&a, &b), -1.0, EPSILON));
    }

    #[test]
    fn test_cosine_orthogonal_is_zero() {
        let a = [1.0, 0.0, 0.0, 0.0];
        let b = [0.0, 1.0, 0.0, 0.0];
        assert!(approx_eq(cosine_similarity(&a, &b), 0.0, EPSILON));
    }

    #[test]
    fn test_cosine_zero_vector() {
        // Degenerate input: similarity against a zero vector is defined as 0.
        let a = [0.0f32; 4];
        let b = [1.0, 2.0, 3.0, 4.0];
        assert!(approx_eq(cosine_similarity(&a, &b), 0.0, EPSILON));
    }

    #[test]
    fn test_cosine_scalar_vs_dispatch() {
        let a: Vec<f32> = (0..384).map(|i| (i as f32).sin()).collect();
        let b: Vec<f32> = (0..384).map(|i| (i as f32 * 0.7).cos()).collect();
        let s = scalar::cosine_similarity(&a, &b);
        let d = cosine_similarity(&a, &b);
        assert!(
            approx_eq(s, d, 1e-4),
            "scalar={s} dispatched={d}"
        );
    }

    // -----------------------------------------------------------------------
    // Batch cosine
    // -----------------------------------------------------------------------

    #[test]
    fn test_batch_cosine_ranking_order() {
        let query = [1.0, 0.0, 0.0];
        let v0: Vec<f32> = vec![0.0, 1.0, 0.0]; // orthogonal = 0
        let v1: Vec<f32> = vec![1.0, 0.0, 0.0]; // identical = 1
        let v2: Vec<f32> = vec![0.5, 0.5, 0.0]; // in between
        let vectors: Vec<&[f32]> = vec![&v0, &v1, &v2];
        let mut results = vec![(0usize, 0.0f32); 3];
        batch_cosine(&query, &vectors, &mut results);

        // v1 should have highest similarity
        assert!(results[1].1 > results[2].1);
        assert!(results[2].1 > results[0].1);
    }

    #[test]
    fn test_batch_cosine_scalar_vs_dispatch() {
        let query: Vec<f32> = (0..32).map(|i| (i as f32).sin()).collect();
        let v0: Vec<f32> = (0..32).map(|i| (i as f32).cos()).collect();
        let v1: Vec<f32> = (0..32).map(|i| (i as f32 * 2.0).sin()).collect();
        let vectors: Vec<&[f32]> = vec![&v0, &v1];

        let mut scalar_results = vec![(0usize, 0.0f32); 2];
        scalar::batch_cosine(&query, &vectors, &mut scalar_results);

        let mut simd_results = vec![(0usize, 0.0f32); 2];
        batch_cosine(&query, &vectors, &mut simd_results);

        for i in 0..2 {
            assert!(
                approx_eq(scalar_results[i].1, simd_results[i].1, 1e-4),
                "mismatch at {i}: scalar={} simd={}",
                scalar_results[i].1,
                simd_results[i].1
            );
        }
    }

    // -----------------------------------------------------------------------
    // Batch cosine prenorm
    // -----------------------------------------------------------------------

    #[test]
    fn test_batch_cosine_prenorm() {
        let query = [1.0, 0.0, 0.0]; // already unit-length
        let v0: Vec<f32> = vec![3.0, 4.0, 0.0];
        let v1: Vec<f32> = vec![0.0, 0.0, 5.0];
        let vectors: Vec<&[f32]> = vec![&v0, &v1];
        let norms = [5.0, 5.0];
        let mut results = vec![(0usize, 0.0f32); 2];
        batch_cosine_prenorm(&query, &vectors, &norms, &mut results);
        // dot(query, v0) = 3.0, sim = 3.0/5.0 = 0.6
        assert!(approx_eq(results[0].1, 0.6, EPSILON));
        // dot(query, v1) = 0.0, sim = 0.0
        assert!(approx_eq(results[1].1, 0.0, EPSILON));
    }

    // -----------------------------------------------------------------------
    // L2 distance
    // -----------------------------------------------------------------------

    #[test]
    fn test_l2_distance_same_is_zero() {
        let v = [1.0, 2.0, 3.0, 4.0];
        assert!(approx_eq(l2_distance(&v, &v), 0.0, EPSILON));
    }

    #[test]
    fn test_l2_distance_known_triangle() {
        let a = [0.0, 0.0];
        let b = [3.0, 4.0];
        assert!(approx_eq(l2_distance(&a, &b), 5.0, EPSILON));
    }

    #[test]
    fn test_l2_distance_unit_axes() {
        let a = [1.0, 0.0, 0.0];
        let b = [0.0, 1.0, 0.0];
        assert!(approx_eq(l2_distance(&a, &b), 2.0f32.sqrt(), EPSILON));
    }

    #[test]
    fn test_l2_distance_scalar_vs_dispatch() {
        let a: Vec<f32> = (0..384).map(|i| (i as f32).sin()).collect();
        let b: Vec<f32> = (0..384).map(|i| (i as f32).cos()).collect();
        let s = scalar::l2_distance(&a, &b);
        let d = l2_distance(&a, &b);
        assert!(
            approx_eq(s, d, 0.01),
            "scalar={s} dispatched={d}"
        );
    }

    // -----------------------------------------------------------------------
    // Batch norms
    // -----------------------------------------------------------------------

    #[test]
    fn test_batch_norms() {
        let v0: Vec<f32> = vec![3.0, 4.0];
        let v1: Vec<f32> = vec![0.0, 0.0];
        let v2: Vec<f32> = vec![1.0, 0.0, 0.0];
        let vectors: Vec<&[f32]> = vec![&v0, &v1, &v2];
        let mut norms = vec![0.0f32; 3];
        batch_norms(&vectors, &mut norms);
        assert!(approx_eq(norms[0], 5.0, EPSILON));
        assert!(approx_eq(norms[1], 0.0, EPSILON));
        assert!(approx_eq(norms[2], 1.0, EPSILON));
    }

    // -----------------------------------------------------------------------
    // f16 conversion
    // -----------------------------------------------------------------------

    #[test]
    fn test_f16_to_f32_known_values() {
        // f16 representation of 1.0 = 0x3C00
        let input = [0x3C00u16, 0x4000, 0x0000]; // 1.0, 2.0, 0.0
        let mut output = [0.0f32; 3];
        f16_to_f32_batch(&input, &mut output);
        assert!(approx_eq(output[0], 1.0, EPSILON), "got {}", output[0]);
        assert!(approx_eq(output[1], 2.0, EPSILON), "got {}", output[1]);
        assert!(approx_eq(output[2], 0.0, EPSILON), "got {}", output[2]);
    }

    #[test]
    fn test_f16_to_f32_negative() {
        // f16 -1.0 = 0xBC00
        let input = [0xBC00u16];
        let mut output = [0.0f32; 1];
        f16_to_f32_batch(&input, &mut output);
        assert!(approx_eq(output[0], -1.0, EPSILON), "got {}", output[0]);
    }

    #[test]
    fn test_f16_to_f32_batch_larger() {
        // Test with a larger batch to exercise SIMD paths
        let input: Vec<u16> = (0..32).map(|_| 0x3C00u16).collect(); // all 1.0
        let mut output = vec![0.0f32; 32];
        f16_to_f32_batch(&input, &mut output);
        for (i, &v) in output.iter().enumerate() {
            assert!(approx_eq(v, 1.0, EPSILON), "mismatch at {i}: {v}");
        }
    }

    #[test]
    fn test_f16_to_f32_round_trip_accuracy() {
        // Test several known f16 bit patterns
        let cases: Vec<(u16, f32)> = vec![
            (0x3C00, 1.0),
            (0x4000, 2.0),
            (0x3800, 0.5),
            (0x4200, 3.0),
            (0x4400, 4.0),
            (0x0000, 0.0),
            (0x8000, -0.0),
        ];
        let input: Vec<u16> = cases.iter().map(|(bits, _)| *bits).collect();
        let mut output = vec![0.0f32; cases.len()];
        f16_to_f32_batch(&input, &mut output);
        for (i, (_, expected)) in cases.iter().enumerate() {
            assert!(
                approx_eq(output[i], *expected, EPSILON),
                "f16 0x{:04X}: expected {expected}, got {}",
                input[i],
                output[i]
            );
        }
    }

    // -----------------------------------------------------------------------
    // Fletcher-32 checksum
    // -----------------------------------------------------------------------

    #[test]
    fn test_fletcher32_empty() {
        let result = checksum_fletcher32(&[]);
        // Both sums remain 0xFFFF
        // NOTE(review): assumes the implementation seeds both 16-bit sums
        // with 0xFFFF — confirm against `checksum` module.
        assert_eq!(result, 0xFFFF_FFFF);
    }

    #[test]
    fn test_fletcher32_known() {
        let data = [0x00u8, 0x01, 0x00, 0x02];
        let result = checksum_fletcher32(&data);
        let scalar = scalar::checksum_fletcher32(&data);
        assert_eq!(result, scalar);
    }

    #[test]
    fn test_fletcher32_scalar_vs_dispatch() {
        let data: Vec<u8> = (0..256).map(|i| i as u8).collect();
        let s = scalar::checksum_fletcher32(&data);
        let d = checksum_fletcher32(&data);
        assert_eq!(s, d);
    }

    // -----------------------------------------------------------------------
    // Performance sanity check
    // -----------------------------------------------------------------------

    // NOTE(review): wall-clock assertions can be flaky on loaded CI hosts;
    // consider moving this to a criterion benchmark if it ever flakes.
    #[test]
    fn test_dot_product_384_dim_perf() {
        use std::time::Instant;
        let a: Vec<f32> = (0..384).map(|i| (i as f32).sin()).collect();
        let b: Vec<f32> = (0..384).map(|i| (i as f32).cos()).collect();

        // Warm up
        for _ in 0..100 {
            let _ = dot_product(&a, &b);
        }

        let start = Instant::now();
        let iterations = 10_000;
        let mut sum = 0.0f32;
        for _ in 0..iterations {
            sum += dot_product(&a, &b);
        }
        let elapsed = start.elapsed();
        let per_call = elapsed / iterations;
        // Prevent optimization
        assert!(sum.abs() >= 0.0);

        // In release mode, 384-dim dot product should be < 1µs.
        // In debug mode, allow up to 20µs (no optimizations).
        let limit_ns = if cfg!(debug_assertions) { 20_000 } else { 1_000 };
        assert!(
            per_call.as_nanos() < limit_ns,
            "dot product too slow: {per_call:?} per call (limit {limit_ns}ns)"
        );
    }

    // -----------------------------------------------------------------------
    // Cache-line alignment (TVL)
    // -----------------------------------------------------------------------

    #[test]
    fn test_cache_line_size_is_power_of_two() {
        assert!(CACHE_LINE_SIZE.is_power_of_two());
    }

    #[test]
    fn test_cache_line_size_platform() {
        #[cfg(target_arch = "aarch64")]
        assert_eq!(CACHE_LINE_SIZE, 128);
        #[cfg(target_arch = "x86_64")]
        assert_eq!(CACHE_LINE_SIZE, 64);
    }

    #[test]
    fn test_align_to_cache_line() {
        assert_eq!(align_to_cache_line(0), 0);
        assert_eq!(align_to_cache_line(1), CACHE_LINE_SIZE);
        assert_eq!(align_to_cache_line(CACHE_LINE_SIZE), CACHE_LINE_SIZE);
        assert_eq!(align_to_cache_line(CACHE_LINE_SIZE + 1), CACHE_LINE_SIZE * 2);
        assert_eq!(align_to_cache_line(CACHE_LINE_SIZE * 3), CACHE_LINE_SIZE * 3);
    }

    #[test]
    fn test_align_to_cache_line_64_and_128() {
        // Both 64 and 128 alignment scenarios
        let val = align_to_cache_line(100);
        assert_eq!(val % CACHE_LINE_SIZE, 0);
        assert!(val >= 100);
        assert!(val < 100 + CACHE_LINE_SIZE);
    }

    // -----------------------------------------------------------------------
    // Edge cases / additional coverage
    // -----------------------------------------------------------------------

    #[test]
    fn test_dot_product_non_aligned_length() {
        // Test with lengths that don't align to SIMD widths (not multiple of 4, 8, 16)
        for len in [1, 3, 5, 7, 9, 13, 17, 31, 33] {
            let a: Vec<f32> = (0..len).map(|i| i as f32).collect();
            let b: Vec<f32> = (0..len).map(|i| (i as f32) * 0.5).collect();
            let s = scalar::dot_product(&a, &b);
            let d = dot_product(&a, &b);
            assert!(
                approx_eq(s, d, 0.01),
                "len={len}: scalar={s} dispatched={d}"
            );
        }
    }

    #[test]
    fn test_cosine_non_aligned_length() {
        for len in [1, 3, 5, 7, 9, 13, 17, 31, 33] {
            let a: Vec<f32> = (0..len).map(|i| i as f32 + 1.0).collect();
            let b: Vec<f32> = (0..len).map(|i| (i as f32 + 1.0) * 2.0).collect();
            let s = scalar::cosine_similarity(&a, &b);
            let d = cosine_similarity(&a, &b);
            assert!(
                approx_eq(s, d, 1e-4),
                "len={len}: scalar={s} dispatched={d}"
            );
        }
    }

    #[test]
    fn test_l2_distance_non_aligned_length() {
        for len in [1, 3, 5, 7, 9, 13, 17, 31, 33] {
            let a: Vec<f32> = (0..len).map(|i| i as f32).collect();
            let b: Vec<f32> = (0..len).map(|i| (i as f32) + 1.0).collect();
            let s = scalar::l2_distance(&a, &b);
            let d = l2_distance(&a, &b);
            assert!(
                approx_eq(s, d, 0.01),
                "len={len}: scalar={s} dispatched={d}"
            );
        }
    }
}