Skip to main content

simsimd/
lib.rs

1//! # SpatialSimilarity - Hardware-Accelerated Similarity Metrics and Distance Functions
2//!
3//! * Targets ARM NEON, SVE, x86 AVX2, AVX-512 (VNNI, FP16) hardware backends.
4//! * Handles `f64` double- and `f32` single-precision, integral, and binary vectors.
5//! * Exposes half-precision (`f16`) and brain floating point (`bf16`) types.
6//! * Zero-dependency header-only C 99 library with bindings for Rust and other languages.
7//!
8//! ## Implemented distance functions include:
9//!
10//! * Euclidean (L2), inner product, and cosine (angular) spatial distances.
11//! * Hamming (~ Manhattan) and Jaccard (~ Tanimoto) binary distances.
12//! * Kullback-Leibler and Jensen-Shannon divergences for probability distributions.
13//!
14//! ## Example
15//!
16//! ```rust
17//! use simsimd::SpatialSimilarity;
18//!
19//! let a = &[1, 2, 3];
20//! let b = &[4, 5, 6];
21//!
22//! // Compute cosine distance
23//! let cos_dist = i8::cos(a, b);
24//!
25//! // Compute dot product distance
26//! let dot_product = i8::dot(a, b);
27//!
28//! // Compute squared Euclidean distance
29//! let l2sq_dist = i8::l2sq(a, b);
30//!
31//! // Optimize performance by flushing denormals
32//! simsimd::capabilities::flush_denormals();
33//! ```
34//!
35//! ## Mixed Precision Support
36//!
37//! ```rust
38//! use simsimd::{SpatialSimilarity, f16, bf16};
39//!
40//! // Work with half-precision floats
41//! let half_a: Vec<f16> = vec![1.0, 2.0, 3.0].iter().map(|&x| f16::from_f32(x)).collect();
42//! let half_b: Vec<f16> = vec![4.0, 5.0, 6.0].iter().map(|&x| f16::from_f32(x)).collect();
43//! let half_cos_dist = f16::cos(&half_a, &half_b);
44//!
45//! // Work with brain floats
46//! let brain_a: Vec<bf16> = vec![1.0, 2.0, 3.0].iter().map(|&x| bf16::from_f32(x)).collect();
47//! let brain_b: Vec<bf16> = vec![4.0, 5.0, 6.0].iter().map(|&x| bf16::from_f32(x)).collect();
48//! let brain_cos_dist = bf16::cos(&brain_a, &brain_b);
49//!
50//! // Direct bit manipulation
51//! let half = f16::from_f32(3.14);
52//! let bits = half.0; // Access raw u16 representation
53//! let reconstructed = f16(bits);
54//! ```
55//!
56//! ## Traits
57//!
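//! The `SpatialSimilarity` trait covers the following methods:
//!
//! - `cosine(a: &[Self], b: &[Self]) -> Option<Distance>`: Computes the cosine distance (1 - similarity) between two slices.
//! - `dot(a: &[Self], b: &[Self]) -> Option<Distance>`: Computes the dot (inner) product of two slices.
//! - `sqeuclidean(a: &[Self], b: &[Self]) -> Option<Distance>`: Computes the squared Euclidean distance between two slices.
//! - `euclidean(a: &[Self], b: &[Self]) -> Option<Distance>`: Computes the Euclidean distance between two slices.
//!
//! The shorter names `cos`, `l2sq`, `l2`, and `inner` are also available and return the same results.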
63//!
//! The `BinarySimilarity` trait covers the following methods:
65//!
66//! - `hamming(a: &[Self], b: &[Self]) -> Option<Distance>`: Computes Hamming distance between two slices.
67//! - `jaccard(a: &[Self], b: &[Self]) -> Option<Distance>`: Computes Jaccard distance between two slices.
68//!
//! The `ProbabilitySimilarity` trait covers the following methods:
70//!
71//! - `jensenshannon(a: &[Self], b: &[Self]) -> Option<Distance>`: Computes Jensen-Shannon divergence between two slices.
72//! - `kullbackleibler(a: &[Self], b: &[Self]) -> Option<Distance>`: Computes Kullback-Leibler divergence between two slices.
73//!
74#![allow(non_camel_case_types)]
75#![cfg_attr(all(not(test), not(feature = "std")), no_std)]
76
/// Scalar result type returned by the similarity, distance, and divergence
/// traits of this crate (`SpatialSimilarity`, `BinarySimilarity`,
/// `ProbabilitySimilarity`, `Sparse`).
pub type Distance = f64;

/// Result type of the [`ComplexProducts`] methods: the two `f64` components
/// of a complex-valued product.
pub type ComplexProduct = (f64, f64);

/// Size type used in C FFI to match `simsimd_size_t` which is always `uint64_t`.
/// This is aliased to `u64` instead of `usize` to maintain ABI compatibility across
/// all platforms, including 32-bit architectures where `usize` is 32-bit but the
/// C library expects 64-bit size parameters.
///
/// TODO: In v7, change the C library to use `size_t` and this to `usize`.
type u64size = u64;
87
/// Absolute value of an `f32`, computed by clearing the IEEE 754 sign bit.
///
/// Compatibility helper for pre-1.85 Rust: this crate can build as `no_std`,
/// and `f32::abs` cannot be relied upon there on older toolchains. Because only
/// the sign bit is touched, `-0.0` maps to `+0.0`, infinities stay infinite and
/// NaN payloads are preserved.
#[inline(always)]
fn f32_abs_compat(x: f32) -> f32 {
    const SIGN_BIT: u32 = 1 << 31;
    f32::from_bits(x.to_bits() & !SIGN_BIT)
}
93
94#[link(name = "simsimd")]
95extern "C" {
96
97    fn simsimd_dot_i8(a: *const i8, b: *const i8, c: u64size, d: *mut Distance);
98    fn simsimd_dot_f16(a: *const u16, b: *const u16, c: u64size, d: *mut Distance);
99    fn simsimd_dot_bf16(a: *const u16, b: *const u16, c: u64size, d: *mut Distance);
100    fn simsimd_dot_f32(a: *const f32, b: *const f32, c: u64size, d: *mut Distance);
101    fn simsimd_dot_f64(a: *const f64, b: *const f64, c: u64size, d: *mut Distance);
102
103    fn simsimd_dot_f16c(a: *const u16, b: *const u16, c: u64size, d: *mut Distance);
104    fn simsimd_dot_bf16c(a: *const u16, b: *const u16, c: u64size, d: *mut Distance);
105    fn simsimd_dot_f32c(a: *const f32, b: *const f32, c: u64size, d: *mut Distance);
106    fn simsimd_dot_f64c(a: *const f64, b: *const f64, c: u64size, d: *mut Distance);
107
108    fn simsimd_vdot_f16c(a: *const u16, b: *const u16, c: u64size, d: *mut Distance);
109    fn simsimd_vdot_bf16c(a: *const u16, b: *const u16, c: u64size, d: *mut Distance);
110    fn simsimd_vdot_f32c(a: *const f32, b: *const f32, c: u64size, d: *mut Distance);
111    fn simsimd_vdot_f64c(a: *const f64, b: *const f64, c: u64size, d: *mut Distance);
112
113    fn simsimd_cos_i8(a: *const i8, b: *const i8, c: u64size, d: *mut Distance);
114    fn simsimd_cos_f16(a: *const u16, b: *const u16, c: u64size, d: *mut Distance);
115    fn simsimd_cos_bf16(a: *const u16, b: *const u16, c: u64size, d: *mut Distance);
116    fn simsimd_cos_f32(a: *const f32, b: *const f32, c: u64size, d: *mut Distance);
117    fn simsimd_cos_f64(a: *const f64, b: *const f64, c: u64size, d: *mut Distance);
118
119    fn simsimd_l2sq_i8(a: *const i8, b: *const i8, c: u64size, d: *mut Distance);
120    fn simsimd_l2sq_f16(a: *const u16, b: *const u16, c: u64size, d: *mut Distance);
121    fn simsimd_l2sq_bf16(a: *const u16, b: *const u16, c: u64size, d: *mut Distance);
122    fn simsimd_l2sq_f32(a: *const f32, b: *const f32, c: u64size, d: *mut Distance);
123    fn simsimd_l2sq_f64(a: *const f64, b: *const f64, c: u64size, d: *mut Distance);
124
125    fn simsimd_l2_i8(a: *const i8, b: *const i8, c: u64size, d: *mut Distance);
126    fn simsimd_l2_f16(a: *const u16, b: *const u16, c: u64size, d: *mut Distance);
127    fn simsimd_l2_bf16(a: *const u16, b: *const u16, c: u64size, d: *mut Distance);
128    fn simsimd_l2_f32(a: *const f32, b: *const f32, c: u64size, d: *mut Distance);
129    fn simsimd_l2_f64(a: *const f64, b: *const f64, c: u64size, d: *mut Distance);
130
131    fn simsimd_hamming_b8(a: *const u8, b: *const u8, c: u64size, d: *mut Distance);
132    fn simsimd_jaccard_b8(a: *const u8, b: *const u8, c: u64size, d: *mut Distance);
133
134    fn simsimd_js_f16(a: *const u16, b: *const u16, c: u64size, d: *mut Distance);
135    fn simsimd_js_bf16(a: *const u16, b: *const u16, c: u64size, d: *mut Distance);
136    fn simsimd_js_f32(a: *const f32, b: *const f32, c: u64size, d: *mut Distance);
137    fn simsimd_js_f64(a: *const f64, b: *const f64, c: u64size, d: *mut Distance);
138
139    fn simsimd_kl_f16(a: *const u16, b: *const u16, c: u64size, d: *mut Distance);
140    fn simsimd_kl_bf16(a: *const u16, b: *const u16, c: u64size, d: *mut Distance);
141    fn simsimd_kl_f32(a: *const f32, b: *const f32, c: u64size, d: *mut Distance);
142    fn simsimd_kl_f64(a: *const f64, b: *const f64, c: u64size, d: *mut Distance);
143
144    fn simsimd_intersect_u16(
145        a: *const u16,
146        b: *const u16,
147        a_length: u64size,
148        b_length: u64size,
149        d: *mut Distance,
150    );
151    fn simsimd_intersect_u32(
152        a: *const u32,
153        b: *const u32,
154        a_length: u64size,
155        b_length: u64size,
156        d: *mut Distance,
157    );
158
159    fn simsimd_uses_neon() -> i32;
160    fn simsimd_uses_neon_f16() -> i32;
161    fn simsimd_uses_neon_bf16() -> i32;
162    fn simsimd_uses_neon_i8() -> i32;
163    fn simsimd_uses_sve() -> i32;
164    fn simsimd_uses_sve_f16() -> i32;
165    fn simsimd_uses_sve_bf16() -> i32;
166    fn simsimd_uses_sve_i8() -> i32;
167    fn simsimd_uses_haswell() -> i32;
168    fn simsimd_uses_skylake() -> i32;
169    fn simsimd_uses_ice() -> i32;
170    fn simsimd_uses_genoa() -> i32;
171    fn simsimd_uses_sapphire() -> i32;
172    fn simsimd_uses_turin() -> i32;
173    fn simsimd_uses_sierra() -> i32;
174
175    fn simsimd_flush_denormals() -> i32;
176    fn simsimd_uses_dynamic_dispatch() -> i32;
177
178    fn simsimd_f32_to_f16(f32_value: f32, result_ptr: *mut u16);
179    fn simsimd_f16_to_f32(f16_ptr: *const u16) -> f32;
180    fn simsimd_f32_to_bf16(f32_value: f32, result_ptr: *mut u16);
181    fn simsimd_bf16_to_f32(bf16_ptr: *const u16) -> f32;
182}
183
/// A half-precision (16-bit) floating point number.
///
/// This type represents the IEEE 754 half-precision binary floating-point format.
/// It provides conversion methods to and from `f32`, and the underlying `u16`
/// representation is publicly accessible for direct bit manipulation.
///
/// # Equality is bitwise
///
/// `PartialEq` and `Eq` are derived, so `==` compares raw bit patterns rather than
/// numeric values: `f16(0x0000)` (`+0.0`) and `f16(0x8000)` (`-0.0`) are *not*
/// equal, while a NaN compares equal to another NaN with identical bits.
/// `PartialOrd`, by contrast, compares the values after widening them to `f32`.
/// Compare `to_f32()` results when IEEE 754 equality semantics are required.
///
/// # Examples
///
/// ```
/// use simsimd::f16;
///
/// // Create from f32
/// let half = f16::from_f32(3.14);
///
/// // Convert back to f32
/// let float = half.to_f32();
///
/// // Direct access to bits
/// let bits = half.0;
///
/// // `==` is bitwise: positive and negative zero differ.
/// assert_ne!(f16(0x0000), f16(0x8000));
/// ```
#[repr(transparent)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct f16(pub u16);
207
208impl f16 {
209    /// Positive zero.
210    pub const ZERO: Self = f16(0);
211
212    /// Positive one.
213    pub const ONE: Self = f16(0x3C00);
214
215    /// Negative one.
216    pub const NEG_ONE: Self = f16(0xBC00);
217
218    /// Converts an f32 to f16 representation.
219    ///
220    /// # Examples
221    ///
222    /// ```
223    /// use simsimd::f16;
224    /// let half = f16::from_f32(3.14159);
225    /// ```
226    #[inline(always)]
227    pub fn from_f32(value: f32) -> Self {
228        let mut result: u16 = 0;
229        unsafe { simsimd_f32_to_f16(value, &mut result) };
230        f16(result)
231    }
232
233    /// Converts the f16 to an f32.
234    ///
235    /// # Examples
236    ///
237    /// ```
238    /// use simsimd::f16;
239    /// let half = f16::from_f32(3.14159);
240    /// let float = half.to_f32();
241    /// ```
242    #[inline(always)]
243    pub fn to_f32(self) -> f32 {
244        unsafe { simsimd_f16_to_f32(&self.0) }
245    }
246
247    /// Returns true if this value is NaN.
248    #[inline(always)]
249    pub fn is_nan(self) -> bool {
250        self.to_f32().is_nan()
251    }
252
253    /// Returns true if this value is positive or negative infinity.
254    #[inline(always)]
255    pub fn is_infinite(self) -> bool {
256        self.to_f32().is_infinite()
257    }
258
259    /// Returns true if this number is neither infinite nor NaN.
260    #[inline(always)]
261    pub fn is_finite(self) -> bool {
262        self.to_f32().is_finite()
263    }
264
265    /// Returns the absolute value of self.
266    #[inline(always)]
267    pub fn abs(self) -> Self {
268        Self::from_f32(f32_abs_compat(self.to_f32()))
269    }
270
271    /// Returns the largest integer less than or equal to a number.
272    ///
273    /// This method is only available when the `std` feature is enabled.
274    #[cfg(feature = "std")]
275    #[inline(always)]
276    pub fn floor(self) -> Self {
277        Self::from_f32(self.to_f32().floor())
278    }
279
280    /// Returns the smallest integer greater than or equal to a number.
281    ///
282    /// This method is only available when the `std` feature is enabled.
283    #[cfg(feature = "std")]
284    #[inline(always)]
285    pub fn ceil(self) -> Self {
286        Self::from_f32(self.to_f32().ceil())
287    }
288
289    /// Returns the nearest integer to a number. Round half-way cases away from 0.0.
290    ///
291    /// This method is only available when the `std` feature is enabled.
292    #[cfg(feature = "std")]
293    #[inline(always)]
294    pub fn round(self) -> Self {
295        Self::from_f32(self.to_f32().round())
296    }
297}
298
299#[cfg(feature = "std")]
300impl core::fmt::Display for f16 {
301    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
302        write!(f, "{}", self.to_f32())
303    }
304}
305
306impl core::ops::Add for f16 {
307    type Output = Self;
308
309    #[inline(always)]
310    fn add(self, rhs: Self) -> Self::Output {
311        Self::from_f32(self.to_f32() + rhs.to_f32())
312    }
313}
314
315impl core::ops::Sub for f16 {
316    type Output = Self;
317
318    #[inline(always)]
319    fn sub(self, rhs: Self) -> Self::Output {
320        Self::from_f32(self.to_f32() - rhs.to_f32())
321    }
322}
323
324impl core::ops::Mul for f16 {
325    type Output = Self;
326
327    #[inline(always)]
328    fn mul(self, rhs: Self) -> Self::Output {
329        Self::from_f32(self.to_f32() * rhs.to_f32())
330    }
331}
332
333impl core::ops::Div for f16 {
334    type Output = Self;
335
336    #[inline(always)]
337    fn div(self, rhs: Self) -> Self::Output {
338        Self::from_f32(self.to_f32() / rhs.to_f32())
339    }
340}
341
342impl core::ops::Neg for f16 {
343    type Output = Self;
344
345    #[inline(always)]
346    fn neg(self) -> Self::Output {
347        Self::from_f32(-self.to_f32())
348    }
349}
350
351impl core::cmp::PartialOrd for f16 {
352    #[inline(always)]
353    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
354        self.to_f32().partial_cmp(&other.to_f32())
355    }
356}
357
/// A brain floating point (bfloat16) number.
///
/// This type represents Google's bfloat16 format, which keeps the sign bit and all
/// 8 exponent bits of IEEE 754 single precision but only 7 mantissa bits. Its bit
/// pattern is the upper half of the corresponding `f32`: for example, `bf16::ONE` is
/// `0x3F80`, the high 16 bits of `1.0f32` (`0x3F80_0000`). This provides a wider
/// range than `f16` but lower precision.
///
/// # Equality is bitwise
///
/// `PartialEq` and `Eq` are derived, so `==` compares raw bit patterns rather than
/// numeric values: `bf16(0x0000)` (`+0.0`) and `bf16(0x8000)` (`-0.0`) are *not*
/// equal, while a NaN compares equal to another NaN with identical bits.
/// `PartialOrd`, by contrast, compares the values after widening them to `f32`.
/// Compare `to_f32()` results when IEEE 754 equality semantics are required.
///
/// # Examples
///
/// ```
/// use simsimd::bf16;
///
/// // Create from f32
/// let brain_half = bf16::from_f32(3.14);
///
/// // Convert back to f32
/// let float = brain_half.to_f32();
///
/// // Direct access to bits
/// let bits = brain_half.0;
///
/// // `==` is bitwise: positive and negative zero differ.
/// assert_ne!(bf16(0x0000), bf16(0x8000));
/// ```
#[repr(transparent)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct bf16(pub u16);
381
382impl bf16 {
383    /// Positive zero.
384    pub const ZERO: Self = bf16(0);
385
386    /// Positive one.
387    pub const ONE: Self = bf16(0x3F80);
388
389    /// Negative one.
390    pub const NEG_ONE: Self = bf16(0xBF80);
391
392    /// Converts an f32 to bf16 representation.
393    ///
394    /// # Examples
395    ///
396    /// ```
397    /// use simsimd::bf16;
398    /// let brain_half = bf16::from_f32(3.14159);
399    /// ```
400    #[inline(always)]
401    pub fn from_f32(value: f32) -> Self {
402        let mut result: u16 = 0;
403        unsafe { simsimd_f32_to_bf16(value, &mut result) };
404        bf16(result)
405    }
406
407    /// Converts the bf16 to an f32.
408    ///
409    /// # Examples
410    ///
411    /// ```
412    /// use simsimd::bf16;
413    /// let brain_half = bf16::from_f32(3.14159);
414    /// let float = brain_half.to_f32();
415    /// ```
416    #[inline(always)]
417    pub fn to_f32(self) -> f32 {
418        unsafe { simsimd_bf16_to_f32(&self.0) }
419    }
420
421    /// Returns true if this value is NaN.
422    #[inline(always)]
423    pub fn is_nan(self) -> bool {
424        self.to_f32().is_nan()
425    }
426
427    /// Returns true if this value is positive or negative infinity.
428    #[inline(always)]
429    pub fn is_infinite(self) -> bool {
430        self.to_f32().is_infinite()
431    }
432
433    /// Returns true if this number is neither infinite nor NaN.
434    #[inline(always)]
435    pub fn is_finite(self) -> bool {
436        self.to_f32().is_finite()
437    }
438
439    /// Returns the absolute value of self.
440    #[inline(always)]
441    pub fn abs(self) -> Self {
442        Self::from_f32(f32_abs_compat(self.to_f32()))
443    }
444
445    /// Returns the largest integer less than or equal to a number.
446    ///
447    /// This method is only available when the `std` feature is enabled.
448    #[cfg(feature = "std")]
449    #[inline(always)]
450    pub fn floor(self) -> Self {
451        Self::from_f32(self.to_f32().floor())
452    }
453
454    /// Returns the smallest integer greater than or equal to a number.
455    ///
456    /// This method is only available when the `std` feature is enabled.
457    #[cfg(feature = "std")]
458    #[inline(always)]
459    pub fn ceil(self) -> Self {
460        Self::from_f32(self.to_f32().ceil())
461    }
462
463    /// Returns the nearest integer to a number. Round half-way cases away from 0.0.
464    ///
465    /// This method is only available when the `std` feature is enabled.
466    #[cfg(feature = "std")]
467    #[inline(always)]
468    pub fn round(self) -> Self {
469        Self::from_f32(self.to_f32().round())
470    }
471}
472
473#[cfg(feature = "std")]
474impl core::fmt::Display for bf16 {
475    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
476        write!(f, "{}", self.to_f32())
477    }
478}
479
480impl core::ops::Add for bf16 {
481    type Output = Self;
482
483    #[inline(always)]
484    fn add(self, rhs: Self) -> Self::Output {
485        Self::from_f32(self.to_f32() + rhs.to_f32())
486    }
487}
488
489impl core::ops::Sub for bf16 {
490    type Output = Self;
491
492    #[inline(always)]
493    fn sub(self, rhs: Self) -> Self::Output {
494        Self::from_f32(self.to_f32() - rhs.to_f32())
495    }
496}
497
498impl core::ops::Mul for bf16 {
499    type Output = Self;
500
501    #[inline(always)]
502    fn mul(self, rhs: Self) -> Self::Output {
503        Self::from_f32(self.to_f32() * rhs.to_f32())
504    }
505}
506
507impl core::ops::Div for bf16 {
508    type Output = Self;
509
510    #[inline(always)]
511    fn div(self, rhs: Self) -> Self::Output {
512        Self::from_f32(self.to_f32() / rhs.to_f32())
513    }
514}
515
516impl core::ops::Neg for bf16 {
517    type Output = Self;
518
519    #[inline(always)]
520    fn neg(self) -> Self::Output {
521        Self::from_f32(-self.to_f32())
522    }
523}
524
525impl core::cmp::PartialOrd for bf16 {
526    #[inline(always)]
527    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
528        self.to_f32().partial_cmp(&other.to_f32())
529    }
530}
531
532/// The `capabilities` module provides functions for detecting the hardware features
533/// available on the current system.
534pub mod capabilities {
535
536    pub fn uses_neon() -> bool {
537        unsafe { crate::simsimd_uses_neon() != 0 }
538    }
539
540    pub fn uses_neon_f16() -> bool {
541        unsafe { crate::simsimd_uses_neon_f16() != 0 }
542    }
543
544    pub fn uses_neon_bf16() -> bool {
545        unsafe { crate::simsimd_uses_neon_bf16() != 0 }
546    }
547
548    pub fn uses_neon_i8() -> bool {
549        unsafe { crate::simsimd_uses_neon_i8() != 0 }
550    }
551
552    pub fn uses_sve() -> bool {
553        unsafe { crate::simsimd_uses_sve() != 0 }
554    }
555
556    pub fn uses_sve_f16() -> bool {
557        unsafe { crate::simsimd_uses_sve_f16() != 0 }
558    }
559
560    pub fn uses_sve_bf16() -> bool {
561        unsafe { crate::simsimd_uses_sve_bf16() != 0 }
562    }
563
564    pub fn uses_sve_i8() -> bool {
565        unsafe { crate::simsimd_uses_sve_i8() != 0 }
566    }
567
568    pub fn uses_haswell() -> bool {
569        unsafe { crate::simsimd_uses_haswell() != 0 }
570    }
571
572    pub fn uses_skylake() -> bool {
573        unsafe { crate::simsimd_uses_skylake() != 0 }
574    }
575
576    pub fn uses_ice() -> bool {
577        unsafe { crate::simsimd_uses_ice() != 0 }
578    }
579
580    pub fn uses_genoa() -> bool {
581        unsafe { crate::simsimd_uses_genoa() != 0 }
582    }
583
584    pub fn uses_sapphire() -> bool {
585        unsafe { crate::simsimd_uses_sapphire() != 0 }
586    }
587
588    pub fn uses_turin() -> bool {
589        unsafe { crate::simsimd_uses_turin() != 0 }
590    }
591
592    pub fn uses_sierra() -> bool {
593        unsafe { crate::simsimd_uses_sierra() != 0 }
594    }
595
596    /// Flushes denormalized numbers to zero on the current CPU architecture.
597    ///
598    /// This function should be called on each thread before any SIMD operations
599    /// to avoid performance penalties. When facing denormalized values,
600    /// Fused-Multiply-Add (FMA) operations can be up to 30x slower.
601    ///
602    /// # Returns
603    ///
604    /// Returns `true` if the operation was successful, `false` otherwise.
605    pub fn flush_denormals() -> bool {
606        unsafe { crate::simsimd_flush_denormals() != 0 }
607    }
608
609    /// Checks if the library is using dynamic dispatch for function selection.
610    ///
611    /// # Returns
612    ///
613    /// Returns `true` when the C backend is compiled with dynamic dispatch
614    /// (default for this crate via `build.rs`), otherwise `false`.
615    pub fn uses_dynamic_dispatch() -> bool {
616        unsafe { crate::simsimd_uses_dynamic_dispatch() != 0 }
617    }
618}
619
620/// `SpatialSimilarity` provides a set of trait methods for computing similarity
621/// or distance between spatial data vectors in SIMD (Single Instruction, Multiple Data) context.
622/// These methods can be used to calculate metrics like cosine distance, dot product,
623/// and squared Euclidean distance between two slices of data.
624///
625/// Each method takes two slices of data (a and b) and returns an Option<Distance>.
626/// The result is `None` if the slices are not of the same length, as these operations
627/// require one-to-one correspondence between the elements of the slices.
628/// Otherwise, it returns the computed similarity or distance as `Some(f64)`.
629/// Convenience methods like `cosine`/`sqeuclidean` delegate to the core methods
630/// `cos`/`l2sq` implemented by this trait.
631pub trait SpatialSimilarity
632where
633    Self: Sized,
634{
635    /// Computes the cosine distance between two slices.
636    /// The cosine distance is 1 minus the cosine similarity between two non-zero vectors
637    /// of an dot product space that measures the cosine of the angle between them.
638    fn cos(a: &[Self], b: &[Self]) -> Option<Distance>;
639
640    /// Computes the inner product (also known as dot product) between two slices.
641    /// The dot product is the sum of the products of the corresponding entries
642    /// of the two sequences of numbers.
643    fn dot(a: &[Self], b: &[Self]) -> Option<Distance>;
644
645    /// Computes the squared Euclidean distance between two slices.
646    /// The squared Euclidean distance is the sum of the squared differences
647    /// between corresponding elements of the two slices.
648    fn l2sq(a: &[Self], b: &[Self]) -> Option<Distance>;
649
650    /// Computes the Euclidean distance between two slices.
651    /// The Euclidean distance is the square root of
652    //  sum of the squared differences between corresponding
653    /// elements of the two slices.
654    fn l2(a: &[Self], b: &[Self]) -> Option<Distance>;
655
656    /// Computes the squared Euclidean distance between two slices.
657    /// The squared Euclidean distance is the sum of the squared differences
658    /// between corresponding elements of the two slices.
659    fn sqeuclidean(a: &[Self], b: &[Self]) -> Option<Distance> {
660        SpatialSimilarity::l2sq(a, b)
661    }
662
663    /// Computes the Euclidean distance between two slices.
664    /// The Euclidean distance is the square root of the
665    /// sum of the squared differences between corresponding
666    /// elements of the two slices.
667    fn euclidean(a: &[Self], b: &[Self]) -> Option<Distance> {
668        SpatialSimilarity::l2(a, b)
669    }
670
671    /// Computes the squared Euclidean distance between two slices.
672    /// The squared Euclidean distance is the sum of the squared differences
673    /// between corresponding elements of the two slices.
674    fn inner(a: &[Self], b: &[Self]) -> Option<Distance> {
675        SpatialSimilarity::dot(a, b)
676    }
677
678    /// Computes the cosine distance between two slices.
679    /// The cosine distance is 1 minus the cosine similarity between two non-zero vectors
680    /// of an dot product space that measures the cosine of the angle between them.
681    fn cosine(a: &[Self], b: &[Self]) -> Option<Distance> {
682        SpatialSimilarity::cos(a, b)
683    }
684}
685
686/// `BinarySimilarity` provides trait methods for computing similarity metrics
687/// that are commonly used with binary data vectors, such as Hamming distance
688/// and Jaccard index.
689///
690/// The methods accept two slices of binary data and return an Option<Distance>
691/// indicating the computed similarity or distance, with `None` returned if the
692/// slices differ in length.
693pub trait BinarySimilarity
694where
695    Self: Sized,
696{
697    /// Computes the Hamming distance between two binary data slices.
698    /// The Hamming distance between two strings of equal length is the number of
699    /// bits at which the corresponding values are different.
700    fn hamming(a: &[Self], b: &[Self]) -> Option<Distance>;
701
702    /// Computes the Jaccard index between two bitsets represented by binary data slices.
703    /// The Jaccard index, also known as the Jaccard similarity coefficient, is a statistic
704    /// used for gauging the similarity and diversity of sample sets.
705    fn jaccard(a: &[Self], b: &[Self]) -> Option<Distance>;
706}
707
708/// `ProbabilitySimilarity` provides trait methods for computing similarity or divergence
709/// measures between probability distributions, such as the Jensen-Shannon divergence
710/// and the Kullback-Leibler divergence.
711///
712/// These methods are particularly useful in contexts such as information theory and
713/// machine learning, where one often needs to measure how one probability distribution
714/// differs from a second, reference probability distribution.
715pub trait ProbabilitySimilarity
716where
717    Self: Sized,
718{
719    /// Computes the Jensen-Shannon divergence between two probability distributions.
720    /// The Jensen-Shannon divergence is a method of measuring the similarity between
721    /// two probability distributions. It is based on the Kullback-Leibler divergence,
722    /// but is symmetric and always has a finite value.
723    fn jensenshannon(a: &[Self], b: &[Self]) -> Option<Distance>;
724
725    /// Computes the Kullback-Leibler divergence between two probability distributions.
726    /// The Kullback-Leibler divergence is a measure of how one probability distribution
727    /// diverges from a second, expected probability distribution.
728    fn kullbackleibler(a: &[Self], b: &[Self]) -> Option<Distance>;
729}
730
731/// `ComplexProducts` provides trait methods for computing products between
732/// complex number vectors. This includes standard and Hermitian dot products.
733pub trait ComplexProducts
734where
735    Self: Sized,
736{
737    /// Computes the dot product between two complex number vectors.
738    fn dot(a: &[Self], b: &[Self]) -> Option<ComplexProduct>;
739
740    /// Computes the Hermitian dot product (conjugate dot product) between two complex number vectors.
741    fn vdot(a: &[Self], b: &[Self]) -> Option<ComplexProduct>;
742}
743
744/// `Sparse` provides trait methods for sparse vectors.
745pub trait Sparse
746where
747    Self: Sized,
748{
749    /// Computes the number of common elements between two sparse vectors.
750    /// both vectors must be sorted in ascending order.
751    fn intersect(a: &[Self], b: &[Self]) -> Option<Distance>;
752}
753
754impl BinarySimilarity for u8 {
755    fn hamming(a: &[Self], b: &[Self]) -> Option<Distance> {
756        if a.len() != b.len() {
757            return None;
758        }
759        let mut distance_value: Distance = 0.0;
760        let distance_ptr: *mut Distance = &mut distance_value as *mut Distance;
761        unsafe { simsimd_hamming_b8(a.as_ptr(), b.as_ptr(), a.len() as u64size, distance_ptr) };
762        Some(distance_value)
763    }
764
765    fn jaccard(a: &[Self], b: &[Self]) -> Option<Distance> {
766        if a.len() != b.len() {
767            return None;
768        }
769        let mut distance_value: Distance = 0.0;
770        let distance_ptr: *mut Distance = &mut distance_value as *mut Distance;
771        unsafe { simsimd_jaccard_b8(a.as_ptr(), b.as_ptr(), a.len() as u64size, distance_ptr) };
772        Some(distance_value)
773    }
774}
775
776impl SpatialSimilarity for i8 {
777    fn cos(a: &[Self], b: &[Self]) -> Option<Distance> {
778        if a.len() != b.len() {
779            return None;
780        }
781        let mut distance_value: Distance = 0.0;
782        let distance_ptr: *mut Distance = &mut distance_value as *mut Distance;
783        unsafe { simsimd_cos_i8(a.as_ptr(), b.as_ptr(), a.len() as u64size, distance_ptr) };
784        Some(distance_value)
785    }
786
787    fn dot(a: &[Self], b: &[Self]) -> Option<Distance> {
788        if a.len() != b.len() {
789            return None;
790        }
791        let mut distance_value: Distance = 0.0;
792        let distance_ptr: *mut Distance = &mut distance_value as *mut Distance;
793        unsafe { simsimd_dot_i8(a.as_ptr(), b.as_ptr(), a.len() as u64size, distance_ptr) };
794        Some(distance_value)
795    }
796
797    fn l2sq(a: &[Self], b: &[Self]) -> Option<Distance> {
798        if a.len() != b.len() {
799            return None;
800        }
801        let mut distance_value: Distance = 0.0;
802        let distance_ptr: *mut Distance = &mut distance_value as *mut Distance;
803        unsafe { simsimd_l2sq_i8(a.as_ptr(), b.as_ptr(), a.len() as u64size, distance_ptr) };
804        Some(distance_value)
805    }
806
807    fn l2(a: &[Self], b: &[Self]) -> Option<Distance> {
808        if a.len() != b.len() {
809            return None;
810        }
811        let mut distance_value: Distance = 0.0;
812        let distance_ptr: *mut Distance = &mut distance_value as *mut Distance;
813        unsafe { simsimd_l2_i8(a.as_ptr(), b.as_ptr(), a.len() as u64size, distance_ptr) };
814        Some(distance_value)
815    }
816}
817
818impl Sparse for u16 {
819    fn intersect(a: &[Self], b: &[Self]) -> Option<Distance> {
820        let mut distance_value: Distance = 0.0;
821        let distance_ptr: *mut Distance = &mut distance_value as *mut Distance;
822        unsafe {
823            simsimd_intersect_u16(
824                a.as_ptr(),
825                b.as_ptr(),
826                a.len() as u64size,
827                b.len() as u64size,
828                distance_ptr,
829            )
830        };
831        Some(distance_value)
832    }
833}
834
835impl Sparse for u32 {
836    fn intersect(a: &[Self], b: &[Self]) -> Option<Distance> {
837        let mut distance_value: Distance = 0.0;
838        let distance_ptr: *mut Distance = &mut distance_value as *mut Distance;
839        unsafe {
840            simsimd_intersect_u32(
841                a.as_ptr(),
842                b.as_ptr(),
843                a.len() as u64size,
844                b.len() as u64size,
845                distance_ptr,
846            )
847        };
848        Some(distance_value)
849    }
850}
851
852impl SpatialSimilarity for f16 {
853    fn cos(a: &[Self], b: &[Self]) -> Option<Distance> {
854        if a.len() != b.len() {
855            return None;
856        }
857
858        // Explicitly cast `*const f16` to `*const u16`
859        let a_ptr = a.as_ptr() as *const u16;
860        let b_ptr = b.as_ptr() as *const u16;
861        let mut distance_value: Distance = 0.0;
862        let distance_ptr: *mut Distance = &mut distance_value as *mut Distance;
863        unsafe { simsimd_cos_f16(a_ptr, b_ptr, a.len() as u64size, distance_ptr) };
864        Some(distance_value)
865    }
866
867    fn dot(a: &[Self], b: &[Self]) -> Option<Distance> {
868        if a.len() != b.len() {
869            return None;
870        }
871
872        // Explicitly cast `*const f16` to `*const u16`
873        let a_ptr = a.as_ptr() as *const u16;
874        let b_ptr = b.as_ptr() as *const u16;
875        let mut distance_value: Distance = 0.0;
876        let distance_ptr: *mut Distance = &mut distance_value as *mut Distance;
877        unsafe { simsimd_dot_f16(a_ptr, b_ptr, a.len() as u64size, distance_ptr) };
878        Some(distance_value)
879    }
880
881    fn l2sq(a: &[Self], b: &[Self]) -> Option<Distance> {
882        if a.len() != b.len() {
883            return None;
884        }
885
886        // Explicitly cast `*const f16` to `*const u16`
887        let a_ptr = a.as_ptr() as *const u16;
888        let b_ptr = b.as_ptr() as *const u16;
889        let mut distance_value: Distance = 0.0;
890        let distance_ptr: *mut Distance = &mut distance_value as *mut Distance;
891        unsafe { simsimd_l2sq_f16(a_ptr, b_ptr, a.len() as u64size, distance_ptr) };
892        Some(distance_value)
893    }
894
895    fn l2(a: &[Self], b: &[Self]) -> Option<Distance> {
896        if a.len() != b.len() {
897            return None;
898        }
899        // Explicitly cast `*const f16` to `*const u16`
900        let a_ptr = a.as_ptr() as *const u16;
901        let b_ptr = b.as_ptr() as *const u16;
902        let mut distance_value: Distance = 0.0;
903        let distance_ptr: *mut Distance = &mut distance_value as *mut Distance;
904        unsafe { simsimd_l2_f16(a_ptr, b_ptr, a.len() as u64size, distance_ptr) };
905        Some(distance_value)
906    }
907}
908
909impl SpatialSimilarity for bf16 {
910    fn cos(a: &[Self], b: &[Self]) -> Option<Distance> {
911        if a.len() != b.len() {
912            return None;
913        }
914
915        // Explicitly cast `*const bf16` to `*const u16`
916        let a_ptr = a.as_ptr() as *const u16;
917        let b_ptr = b.as_ptr() as *const u16;
918        let mut distance_value: Distance = 0.0;
919        let distance_ptr: *mut Distance = &mut distance_value as *mut Distance;
920        unsafe { simsimd_cos_bf16(a_ptr, b_ptr, a.len() as u64size, distance_ptr) };
921        Some(distance_value)
922    }
923
924    fn dot(a: &[Self], b: &[Self]) -> Option<Distance> {
925        if a.len() != b.len() {
926            return None;
927        }
928
929        // Explicitly cast `*const bf16` to `*const u16`
930        let a_ptr = a.as_ptr() as *const u16;
931        let b_ptr = b.as_ptr() as *const u16;
932        let mut distance_value: Distance = 0.0;
933        let distance_ptr: *mut Distance = &mut distance_value as *mut Distance;
934        unsafe { simsimd_dot_bf16(a_ptr, b_ptr, a.len() as u64size, distance_ptr) };
935        Some(distance_value)
936    }
937
938    fn l2sq(a: &[Self], b: &[Self]) -> Option<Distance> {
939        if a.len() != b.len() {
940            return None;
941        }
942
943        // Explicitly cast `*const bf16` to `*const u16`
944        let a_ptr = a.as_ptr() as *const u16;
945        let b_ptr = b.as_ptr() as *const u16;
946        let mut distance_value: Distance = 0.0;
947        let distance_ptr: *mut Distance = &mut distance_value as *mut Distance;
948        unsafe { simsimd_l2sq_bf16(a_ptr, b_ptr, a.len() as u64size, distance_ptr) };
949        Some(distance_value)
950    }
951
952    fn l2(a: &[Self], b: &[Self]) -> Option<Distance> {
953        if a.len() != b.len() {
954            return None;
955        }
956        // Explicitly cast `*const bf16` to `*const u16`
957        let a_ptr = a.as_ptr() as *const u16;
958        let b_ptr = b.as_ptr() as *const u16;
959        let mut distance_value: Distance = 0.0;
960        let distance_ptr: *mut Distance = &mut distance_value as *mut Distance;
961        unsafe { simsimd_l2_bf16(a_ptr, b_ptr, a.len() as u64size, distance_ptr) };
962        Some(distance_value)
963    }
964}
965
966impl SpatialSimilarity for f32 {
967    fn cos(a: &[Self], b: &[Self]) -> Option<Distance> {
968        if a.len() != b.len() {
969            return None;
970        }
971        let mut distance_value: Distance = 0.0;
972        let distance_ptr: *mut Distance = &mut distance_value as *mut Distance;
973        unsafe { simsimd_cos_f32(a.as_ptr(), b.as_ptr(), a.len() as u64size, distance_ptr) };
974        Some(distance_value)
975    }
976
977    fn dot(a: &[Self], b: &[Self]) -> Option<Distance> {
978        if a.len() != b.len() {
979            return None;
980        }
981        let mut distance_value: Distance = 0.0;
982        let distance_ptr: *mut Distance = &mut distance_value as *mut Distance;
983        unsafe { simsimd_dot_f32(a.as_ptr(), b.as_ptr(), a.len() as u64size, distance_ptr) };
984        Some(distance_value)
985    }
986
987    fn l2sq(a: &[Self], b: &[Self]) -> Option<Distance> {
988        if a.len() != b.len() {
989            return None;
990        }
991        let mut distance_value: Distance = 0.0;
992        let distance_ptr: *mut Distance = &mut distance_value as *mut Distance;
993        unsafe { simsimd_l2sq_f32(a.as_ptr(), b.as_ptr(), a.len() as u64size, distance_ptr) };
994        Some(distance_value)
995    }
996
997    fn l2(a: &[Self], b: &[Self]) -> Option<Distance> {
998        if a.len() != b.len() {
999            return None;
1000        }
1001        let mut distance_value: Distance = 0.0;
1002        let distance_ptr: *mut Distance = &mut distance_value as *mut Distance;
1003        unsafe { simsimd_l2_f32(a.as_ptr(), b.as_ptr(), a.len() as u64size, distance_ptr) };
1004        Some(distance_value)
1005    }
1006}
1007
1008impl SpatialSimilarity for f64 {
1009    fn cos(a: &[Self], b: &[Self]) -> Option<Distance> {
1010        if a.len() != b.len() {
1011            return None;
1012        }
1013        let mut distance_value: Distance = 0.0;
1014        let distance_ptr: *mut Distance = &mut distance_value as *mut Distance;
1015        unsafe { simsimd_cos_f64(a.as_ptr(), b.as_ptr(), a.len() as u64size, distance_ptr) };
1016        Some(distance_value)
1017    }
1018
1019    fn dot(a: &[Self], b: &[Self]) -> Option<Distance> {
1020        if a.len() != b.len() {
1021            return None;
1022        }
1023        let mut distance_value: Distance = 0.0;
1024        let distance_ptr: *mut Distance = &mut distance_value as *mut Distance;
1025        unsafe { simsimd_dot_f64(a.as_ptr(), b.as_ptr(), a.len() as u64size, distance_ptr) };
1026        Some(distance_value)
1027    }
1028
1029    fn l2sq(a: &[Self], b: &[Self]) -> Option<Distance> {
1030        if a.len() != b.len() {
1031            return None;
1032        }
1033        let mut distance_value: Distance = 0.0;
1034        let distance_ptr: *mut Distance = &mut distance_value as *mut Distance;
1035        unsafe { simsimd_l2sq_f64(a.as_ptr(), b.as_ptr(), a.len() as u64size, distance_ptr) };
1036        Some(distance_value)
1037    }
1038
1039    fn l2(a: &[Self], b: &[Self]) -> Option<Distance> {
1040        if a.len() != b.len() {
1041            return None;
1042        }
1043        let mut distance_value: Distance = 0.0;
1044        let distance_ptr: *mut Distance = &mut distance_value as *mut Distance;
1045        unsafe { simsimd_l2_f64(a.as_ptr(), b.as_ptr(), a.len() as u64size, distance_ptr) };
1046        Some(distance_value)
1047    }
1048}
1049
1050impl ProbabilitySimilarity for f16 {
1051    fn jensenshannon(a: &[Self], b: &[Self]) -> Option<Distance> {
1052        if a.len() != b.len() {
1053            return None;
1054        }
1055
1056        // Explicitly cast `*const f16` to `*const u16`
1057        let a_ptr = a.as_ptr() as *const u16;
1058        let b_ptr = b.as_ptr() as *const u16;
1059        let mut distance_value: Distance = 0.0;
1060        let distance_ptr: *mut Distance = &mut distance_value as *mut Distance;
1061        unsafe { simsimd_js_f16(a_ptr, b_ptr, a.len() as u64size, distance_ptr) };
1062        Some(distance_value)
1063    }
1064
1065    fn kullbackleibler(a: &[Self], b: &[Self]) -> Option<Distance> {
1066        if a.len() != b.len() {
1067            return None;
1068        }
1069
1070        // Explicitly cast `*const f16` to `*const u16`
1071        let a_ptr = a.as_ptr() as *const u16;
1072        let b_ptr = b.as_ptr() as *const u16;
1073        let mut distance_value: Distance = 0.0;
1074        let distance_ptr: *mut Distance = &mut distance_value as *mut Distance;
1075        unsafe { simsimd_kl_f16(a_ptr, b_ptr, a.len() as u64size, distance_ptr) };
1076        Some(distance_value)
1077    }
1078}
1079
1080impl ProbabilitySimilarity for bf16 {
1081    fn jensenshannon(a: &[Self], b: &[Self]) -> Option<Distance> {
1082        if a.len() != b.len() {
1083            return None;
1084        }
1085
1086        // Explicitly cast `*const bf16` to `*const u16`
1087        let a_ptr = a.as_ptr() as *const u16;
1088        let b_ptr = b.as_ptr() as *const u16;
1089        let mut distance_value: Distance = 0.0;
1090        let distance_ptr: *mut Distance = &mut distance_value as *mut Distance;
1091        unsafe { simsimd_js_bf16(a_ptr, b_ptr, a.len() as u64size, distance_ptr) };
1092        Some(distance_value)
1093    }
1094
1095    fn kullbackleibler(a: &[Self], b: &[Self]) -> Option<Distance> {
1096        if a.len() != b.len() {
1097            return None;
1098        }
1099
1100        // Explicitly cast `*const bf16` to `*const u16`
1101        let a_ptr = a.as_ptr() as *const u16;
1102        let b_ptr = b.as_ptr() as *const u16;
1103        let mut distance_value: Distance = 0.0;
1104        let distance_ptr: *mut Distance = &mut distance_value as *mut Distance;
1105        unsafe { simsimd_kl_bf16(a_ptr, b_ptr, a.len() as u64size, distance_ptr) };
1106        Some(distance_value)
1107    }
1108}
1109
1110impl ProbabilitySimilarity for f32 {
1111    fn jensenshannon(a: &[Self], b: &[Self]) -> Option<Distance> {
1112        if a.len() != b.len() {
1113            return None;
1114        }
1115        let mut distance_value: Distance = 0.0;
1116        let distance_ptr: *mut Distance = &mut distance_value as *mut Distance;
1117        unsafe { simsimd_js_f32(a.as_ptr(), b.as_ptr(), a.len() as u64size, distance_ptr) };
1118        Some(distance_value)
1119    }
1120
1121    fn kullbackleibler(a: &[Self], b: &[Self]) -> Option<Distance> {
1122        if a.len() != b.len() {
1123            return None;
1124        }
1125        let mut distance_value: Distance = 0.0;
1126        let distance_ptr: *mut Distance = &mut distance_value as *mut Distance;
1127        unsafe { simsimd_kl_f32(a.as_ptr(), b.as_ptr(), a.len() as u64size, distance_ptr) };
1128        Some(distance_value)
1129    }
1130}
1131
1132impl ProbabilitySimilarity for f64 {
1133    fn jensenshannon(a: &[Self], b: &[Self]) -> Option<Distance> {
1134        if a.len() != b.len() {
1135            return None;
1136        }
1137        let mut distance_value: Distance = 0.0;
1138        let distance_ptr: *mut Distance = &mut distance_value as *mut Distance;
1139        unsafe { simsimd_js_f64(a.as_ptr(), b.as_ptr(), a.len() as u64size, distance_ptr) };
1140        Some(distance_value)
1141    }
1142
1143    fn kullbackleibler(a: &[Self], b: &[Self]) -> Option<Distance> {
1144        if a.len() != b.len() {
1145            return None;
1146        }
1147        let mut distance_value: Distance = 0.0;
1148        let distance_ptr: *mut Distance = &mut distance_value as *mut Distance;
1149        unsafe { simsimd_kl_f64(a.as_ptr(), b.as_ptr(), a.len() as u64size, distance_ptr) };
1150        Some(distance_value)
1151    }
1152}
1153
1154impl ComplexProducts for f16 {
1155    fn dot(a: &[Self], b: &[Self]) -> Option<ComplexProduct> {
1156        if a.len() != b.len() || a.len() % 2 != 0 {
1157            return None;
1158        }
1159        // Prepare the output array where the real and imaginary parts will be stored
1160        let mut product: [Distance; 2] = [0.0, 0.0];
1161        let product_ptr: *mut Distance = &mut product[0] as *mut _;
1162        // Explicitly cast `*const f16` to `*const u16`
1163        let a_ptr = a.as_ptr() as *const u16;
1164        let b_ptr = b.as_ptr() as *const u16;
1165        // The C function expects the number of complex pairs, not the total number of f16 elements
1166        unsafe { simsimd_dot_f16c(a_ptr, b_ptr, a.len() as u64size / 2, product_ptr) };
1167        Some((product[0], product[1]))
1168    }
1169
1170    fn vdot(a: &[Self], b: &[Self]) -> Option<ComplexProduct> {
1171        if a.len() != b.len() || a.len() % 2 != 0 {
1172            return None;
1173        }
1174        let mut product: [Distance; 2] = [0.0, 0.0];
1175        let product_ptr: *mut Distance = &mut product[0] as *mut _;
1176        let a_ptr = a.as_ptr() as *const u16;
1177        let b_ptr = b.as_ptr() as *const u16;
1178        // The C function expects the number of complex pairs, not the total number of f16 elements
1179        unsafe { simsimd_vdot_f16c(a_ptr, b_ptr, a.len() as u64size / 2, product_ptr) };
1180        Some((product[0], product[1]))
1181    }
1182}
1183
1184impl ComplexProducts for bf16 {
1185    fn dot(a: &[Self], b: &[Self]) -> Option<ComplexProduct> {
1186        if a.len() != b.len() || a.len() % 2 != 0 {
1187            return None;
1188        }
1189        // Prepare the output array where the real and imaginary parts will be stored
1190        let mut product: [Distance; 2] = [0.0, 0.0];
1191        let product_ptr: *mut Distance = &mut product[0] as *mut _;
1192        // Explicitly cast `*const bf16` to `*const u16`
1193        let a_ptr = a.as_ptr() as *const u16;
1194        let b_ptr = b.as_ptr() as *const u16;
1195        // The C function expects the number of complex pairs, not the total number of bf16 elements
1196        unsafe { simsimd_dot_bf16c(a_ptr, b_ptr, a.len() as u64size / 2, product_ptr) };
1197        Some((product[0], product[1]))
1198    }
1199
1200    fn vdot(a: &[Self], b: &[Self]) -> Option<ComplexProduct> {
1201        if a.len() != b.len() || a.len() % 2 != 0 {
1202            return None;
1203        }
1204        // Prepare the output array where the real and imaginary parts will be stored
1205        let mut product: [Distance; 2] = [0.0, 0.0];
1206        let product_ptr: *mut Distance = &mut product[0] as *mut _;
1207        // Explicitly cast `*const bf16` to `*const u16`
1208        let a_ptr = a.as_ptr() as *const u16;
1209        let b_ptr = b.as_ptr() as *const u16;
1210        // The C function expects the number of complex pairs, not the total number of bf16 elements
1211        unsafe { simsimd_vdot_bf16c(a_ptr, b_ptr, a.len() as u64size / 2, product_ptr) };
1212        Some((product[0], product[1]))
1213    }
1214}
1215
1216impl ComplexProducts for f32 {
1217    fn dot(a: &[Self], b: &[Self]) -> Option<ComplexProduct> {
1218        if a.len() != b.len() || a.len() % 2 != 0 {
1219            return None;
1220        }
1221        let mut product: [Distance; 2] = [0.0, 0.0];
1222        let product_ptr: *mut Distance = &mut product[0] as *mut _;
1223        // The C function expects the number of complex pairs, not the total number of floats
1224        unsafe { simsimd_dot_f32c(a.as_ptr(), b.as_ptr(), a.len() as u64size / 2, product_ptr) };
1225        Some((product[0], product[1]))
1226    }
1227
1228    fn vdot(a: &[Self], b: &[Self]) -> Option<ComplexProduct> {
1229        if a.len() != b.len() || a.len() % 2 != 0 {
1230            return None;
1231        }
1232        let mut product: [Distance; 2] = [0.0, 0.0];
1233        let product_ptr: *mut Distance = &mut product[0] as *mut _;
1234        // The C function expects the number of complex pairs, not the total number of floats
1235        unsafe { simsimd_vdot_f32c(a.as_ptr(), b.as_ptr(), a.len() as u64size / 2, product_ptr) };
1236        Some((product[0], product[1]))
1237    }
1238}
1239
1240impl ComplexProducts for f64 {
1241    fn dot(a: &[Self], b: &[Self]) -> Option<ComplexProduct> {
1242        if a.len() != b.len() || a.len() % 2 != 0 {
1243            return None;
1244        }
1245        let mut product: [Distance; 2] = [0.0, 0.0];
1246        let product_ptr: *mut Distance = &mut product[0] as *mut _;
1247        // The C function expects the number of complex pairs, not the total number of floats
1248        unsafe { simsimd_dot_f64c(a.as_ptr(), b.as_ptr(), a.len() as u64size / 2, product_ptr) };
1249        Some((product[0], product[1]))
1250    }
1251
1252    fn vdot(a: &[Self], b: &[Self]) -> Option<ComplexProduct> {
1253        if a.len() != b.len() || a.len() % 2 != 0 {
1254            return None;
1255        }
1256        let mut product: [Distance; 2] = [0.0, 0.0];
1257        let product_ptr: *mut Distance = &mut product[0] as *mut _;
1258        // The C function expects the number of complex pairs, not the total number of floats
1259        unsafe { simsimd_vdot_f64c(a.as_ptr(), b.as_ptr(), a.len() as u64size / 2, product_ptr) };
1260        Some((product[0], product[1]))
1261    }
1262}
1263
1264#[cfg(test)]
1265mod tests {
1266    use super::*;
1267    use half::bf16 as HalfBF16;
1268    use half::f16 as HalfF16;
1269
1270    #[test]
1271    fn hardware_features_detection() {
1272        let uses_arm = capabilities::uses_neon() || capabilities::uses_sve();
1273        let uses_x86 = capabilities::uses_haswell()
1274            || capabilities::uses_skylake()
1275            || capabilities::uses_ice()
1276            || capabilities::uses_genoa()
1277            || capabilities::uses_sapphire()
1278            || capabilities::uses_turin();
1279
1280        // The CPU can't simultaneously support ARM and x86 SIMD extensions
1281        if uses_arm {
1282            assert!(!uses_x86);
1283        }
1284        if uses_x86 {
1285            assert!(!uses_arm);
1286        }
1287
1288        println!("- uses_neon: {}", capabilities::uses_neon());
1289        println!("- uses_neon_f16: {}", capabilities::uses_neon_f16());
1290        println!("- uses_neon_bf16: {}", capabilities::uses_neon_bf16());
1291        println!("- uses_neon_i8: {}", capabilities::uses_neon_i8());
1292        println!("- uses_sve: {}", capabilities::uses_sve());
1293        println!("- uses_sve_f16: {}", capabilities::uses_sve_f16());
1294        println!("- uses_sve_bf16: {}", capabilities::uses_sve_bf16());
1295        println!("- uses_sve_i8: {}", capabilities::uses_sve_i8());
1296        println!("- uses_haswell: {}", capabilities::uses_haswell());
1297        println!("- uses_skylake: {}", capabilities::uses_skylake());
1298        println!("- uses_ice: {}", capabilities::uses_ice());
1299        println!("- uses_genoa: {}", capabilities::uses_genoa());
1300        println!("- uses_sapphire: {}", capabilities::uses_sapphire());
1301        println!("- uses_turin: {}", capabilities::uses_turin());
1302        println!("- uses_sierra: {}", capabilities::uses_sierra());
1303    }
1304
1305    //
1306    fn assert_almost_equal(left: Distance, right: Distance, tolerance: Distance) {
1307        let lower = right - tolerance;
1308        let upper = right + tolerance;
1309
1310        assert!(left >= lower && left <= upper);
1311    }
1312
1313    #[test]
1314    fn cos_i8() {
1315        let a = &[3, 97, 127];
1316        let b = &[3, 97, 127];
1317
1318        if let Some(result) = SpatialSimilarity::cosine(a, b) {
1319            assert_almost_equal(0.00012027938, result, 0.01);
1320        }
1321    }
1322
1323    #[test]
1324    fn cos_f32() {
1325        let a = &[1.0, 2.0, 3.0];
1326        let b = &[4.0, 5.0, 6.0];
1327
1328        if let Some(result) = SpatialSimilarity::cosine(a, b) {
1329            assert_almost_equal(0.025, result, 0.01);
1330        }
1331    }
1332
1333    #[test]
1334    fn dot_i8() {
1335        let a = &[1, 2, 3];
1336        let b = &[4, 5, 6];
1337
1338        if let Some(result) = SpatialSimilarity::dot(a, b) {
1339            assert_almost_equal(32.0, result, 0.01);
1340        }
1341    }
1342
1343    #[test]
1344    fn dot_f32() {
1345        let a = &[1.0, 2.0, 3.0];
1346        let b = &[4.0, 5.0, 6.0];
1347
1348        if let Some(result) = SpatialSimilarity::dot(a, b) {
1349            assert_almost_equal(32.0, result, 0.01);
1350        }
1351    }
1352
1353    #[test]
1354    fn dot_f32_complex() {
1355        // Let's consider these as complex numbers where every pair is (real, imaginary)
1356        let a: &[f32; 4] = &[1.0, 2.0, 3.0, 4.0]; // Represents two complex numbers: 1+2i, 3+4i
1357        let b: &[f32; 4] = &[5.0, 6.0, 7.0, 8.0]; // Represents two complex numbers: 5+6i, 7+8i
1358
1359        if let Some((real, imag)) = ComplexProducts::dot(a, b) {
1360            assert_almost_equal(-18.0, real, 0.01);
1361            assert_almost_equal(68.0, imag, 0.01);
1362        }
1363    }
1364
1365    #[test]
1366    fn vdot_f32_complex() {
1367        // Here we're assuming a similar setup to the previous test, but for the Hermitian (conjugate) dot product
1368        let a: &[f32; 4] = &[1.0, 2.0, 3.0, 4.0]; // Represents two complex numbers: 1+2i, 3+4i
1369        let b: &[f32; 4] = &[5.0, 6.0, 7.0, 8.0]; // Represents two complex numbers: 5+6i, 7+8i
1370
1371        if let Some((real, imag)) = ComplexProducts::vdot(a, b) {
1372            assert_almost_equal(70.0, real, 0.01);
1373            assert_almost_equal(-8.0, imag, 0.01);
1374        }
1375    }
1376
1377    #[test]
1378    fn l2sq_i8() {
1379        let a = &[1, 2, 3];
1380        let b = &[4, 5, 6];
1381
1382        if let Some(result) = SpatialSimilarity::sqeuclidean(a, b) {
1383            assert_almost_equal(27.0, result, 0.01);
1384        }
1385    }
1386
1387    #[test]
1388    fn l2sq_f32() {
1389        let a = &[1.0, 2.0, 3.0];
1390        let b = &[4.0, 5.0, 6.0];
1391
1392        if let Some(result) = SpatialSimilarity::sqeuclidean(a, b) {
1393            assert_almost_equal(27.0, result, 0.01);
1394        }
1395    }
1396
1397    #[test]
1398    fn l2_f32() {
1399        let a: &[f32; 3] = &[1.0, 2.0, 3.0];
1400        let b: &[f32; 3] = &[4.0, 5.0, 6.0];
1401        if let Some(result) = SpatialSimilarity::euclidean(a, b) {
1402            assert_almost_equal(5.2, result, 0.01);
1403        }
1404    }
1405
1406    #[test]
1407    fn l2_f64() {
1408        let a: &[f64; 3] = &[1.0, 2.0, 3.0];
1409        let b: &[f64; 3] = &[4.0, 5.0, 6.0];
1410        if let Some(result) = SpatialSimilarity::euclidean(a, b) {
1411            assert_almost_equal(5.2, result, 0.01);
1412        }
1413    }
1414
1415    #[test]
1416    fn l2_f16() {
1417        let a_half: Vec<HalfF16> = vec![1.0, 2.0, 3.0]
1418            .iter()
1419            .map(|&x| HalfF16::from_f32(x))
1420            .collect();
1421        let b_half: Vec<HalfF16> = vec![4.0, 5.0, 6.0]
1422            .iter()
1423            .map(|&x| HalfF16::from_f32(x))
1424            .collect();
1425
1426        let a_simsimd: &[f16] =
1427            unsafe { std::slice::from_raw_parts(a_half.as_ptr() as *const f16, a_half.len()) };
1428        let b_simsimd: &[f16] =
1429            unsafe { std::slice::from_raw_parts(b_half.as_ptr() as *const f16, b_half.len()) };
1430
1431        if let Some(result) = SpatialSimilarity::euclidean(&a_simsimd, &b_simsimd) {
1432            assert_almost_equal(5.2, result, 0.01);
1433        }
1434    }
1435
1436    #[test]
1437    fn l2_i8() {
1438        let a = &[1, 2, 3];
1439        let b = &[4, 5, 6];
1440
1441        if let Some(result) = SpatialSimilarity::euclidean(a, b) {
1442            assert_almost_equal(5.2, result, 0.01);
1443        }
1444    }
1445    // Adding new tests for bit-level distances
1446    #[test]
1447    fn hamming_u8() {
1448        let a = &[0b01010101, 0b11110000, 0b10101010];
1449        let b = &[0b01010101, 0b11110000, 0b10101010];
1450
1451        if let Some(result) = BinarySimilarity::hamming(a, b) {
1452            assert_almost_equal(0.0, result, 0.01);
1453        }
1454    }
1455
1456    #[test]
1457    fn jaccard_u8() {
1458        // For binary data, treat each byte as a set of bits
1459        let a = &[0b11110000, 0b00001111, 0b10101010];
1460        let b = &[0b11110000, 0b00001111, 0b01010101];
1461
1462        if let Some(result) = BinarySimilarity::jaccard(a, b) {
1463            assert_almost_equal(0.5, result, 0.01);
1464        }
1465    }
1466
1467    // Adding new tests for probability similarities
1468    #[test]
1469    fn js_f32() {
1470        let a: &[f32; 3] = &[0.1, 0.9, 0.0];
1471        let b: &[f32; 3] = &[0.2, 0.8, 0.0];
1472
1473        if let Some(result) = ProbabilitySimilarity::jensenshannon(a, b) {
1474            assert_almost_equal(0.099, result, 0.01);
1475        }
1476    }
1477
1478    #[test]
1479    fn kl_f32() {
1480        let a: &[f32; 3] = &[0.1, 0.9, 0.0];
1481        let b: &[f32; 3] = &[0.2, 0.8, 0.0];
1482
1483        if let Some(result) = ProbabilitySimilarity::kullbackleibler(a, b) {
1484            assert_almost_equal(0.036, result, 0.01);
1485        }
1486    }
1487
1488    #[test]
1489    fn cos_f16_same() {
1490        // Assuming these u16 values represent f16 bit patterns, and they are identical
1491        let a_u16: &[u16] = &[15360, 16384, 17408]; // Corresponding to some f16 values
1492        let b_u16: &[u16] = &[15360, 16384, 17408]; // Same as above for simplicity
1493
1494        // Reinterpret cast from &[u16] to &[f16]
1495        let a_f16: &[f16] =
1496            unsafe { std::slice::from_raw_parts(a_u16.as_ptr() as *const f16, a_u16.len()) };
1497        let b_f16: &[f16] =
1498            unsafe { std::slice::from_raw_parts(b_u16.as_ptr() as *const f16, b_u16.len()) };
1499
1500        if let Some(result) = SpatialSimilarity::cosine(a_f16, b_f16) {
1501            assert_almost_equal(0.0, result, 0.01);
1502        }
1503    }
1504
1505    #[test]
1506    fn cos_bf16_same() {
1507        // Assuming these u16 values represent bf16 bit patterns, and they are identical
1508        let a_u16: &[u16] = &[15360, 16384, 17408]; // Corresponding to some bf16 values
1509        let b_u16: &[u16] = &[15360, 16384, 17408]; // Same as above for simplicity
1510
1511        // Reinterpret cast from &[u16] to &[bf16]
1512        let a_bf16: &[bf16] =
1513            unsafe { std::slice::from_raw_parts(a_u16.as_ptr() as *const bf16, a_u16.len()) };
1514        let b_bf16: &[bf16] =
1515            unsafe { std::slice::from_raw_parts(b_u16.as_ptr() as *const bf16, b_u16.len()) };
1516
1517        if let Some(result) = SpatialSimilarity::cosine(a_bf16, b_bf16) {
1518            assert_almost_equal(0.0, result, 0.01);
1519        }
1520    }
1521
1522    #[test]
1523    fn cos_f16_interop() {
1524        let a_half: Vec<HalfF16> = vec![1.0, 2.0, 3.0]
1525            .iter()
1526            .map(|&x| HalfF16::from_f32(x))
1527            .collect();
1528        let b_half: Vec<HalfF16> = vec![4.0, 5.0, 6.0]
1529            .iter()
1530            .map(|&x| HalfF16::from_f32(x))
1531            .collect();
1532
1533        // SAFETY: This is safe as long as the memory representations are guaranteed to be identical,
1534        // which they are due to both being #[repr(transparent)] wrappers around u16.
1535        let a_simsimd: &[f16] =
1536            unsafe { std::slice::from_raw_parts(a_half.as_ptr() as *const f16, a_half.len()) };
1537        let b_simsimd: &[f16] =
1538            unsafe { std::slice::from_raw_parts(b_half.as_ptr() as *const f16, b_half.len()) };
1539
1540        // Use the reinterpret-casted slices with your SpatialSimilarity implementation
1541        if let Some(result) = SpatialSimilarity::cosine(a_simsimd, b_simsimd) {
1542            assert_almost_equal(0.025, result, 0.01);
1543        }
1544    }
1545
1546    #[test]
1547    fn cos_bf16_interop() {
1548        let a_half: Vec<HalfBF16> = vec![1.0, 2.0, 3.0]
1549            .iter()
1550            .map(|&x| HalfBF16::from_f32(x))
1551            .collect();
1552        let b_half: Vec<HalfBF16> = vec![4.0, 5.0, 6.0]
1553            .iter()
1554            .map(|&x| HalfBF16::from_f32(x))
1555            .collect();
1556
1557        // SAFETY: This is safe as long as the memory representations are guaranteed to be identical,
1558        // which they are due to both being #[repr(transparent)] wrappers around u16.
1559        let a_simsimd: &[bf16] =
1560            unsafe { std::slice::from_raw_parts(a_half.as_ptr() as *const bf16, a_half.len()) };
1561        let b_simsimd: &[bf16] =
1562            unsafe { std::slice::from_raw_parts(b_half.as_ptr() as *const bf16, b_half.len()) };
1563
1564        // Use the reinterpret-casted slices with your SpatialSimilarity implementation
1565        if let Some(result) = SpatialSimilarity::cosine(a_simsimd, b_simsimd) {
1566            assert_almost_equal(0.025, result, 0.01);
1567        }
1568    }
1569
1570    #[test]
1571    fn intersect_u16() {
1572        {
1573            let a_u16: &[u16] = &[153, 16384, 17408];
1574            let b_u16: &[u16] = &[7408, 15360, 16384];
1575
1576            if let Some(result) = Sparse::intersect(a_u16, b_u16) {
1577                assert_almost_equal(1.0, result, 0.0001);
1578            }
1579        }
1580
1581        {
1582            let a_u16: &[u16] = &[8, 153, 11638];
1583            let b_u16: &[u16] = &[7408, 15360, 16384];
1584
1585            if let Some(result) = Sparse::intersect(a_u16, b_u16) {
1586                assert_almost_equal(0.0, result, 0.0001);
1587            }
1588        }
1589    }
1590
1591    #[test]
1592    fn intersect_u32() {
1593        {
1594            let a_u32: &[u32] = &[11, 153];
1595            let b_u32: &[u32] = &[11, 153, 7408, 16384];
1596
1597            if let Some(result) = Sparse::intersect(a_u32, b_u32) {
1598                assert_almost_equal(2.0, result, 0.0001);
1599            }
1600        }
1601
1602        {
1603            let a_u32: &[u32] = &[153, 7408, 11638];
1604            let b_u32: &[u32] = &[153, 7408, 11638];
1605
1606            if let Some(result) = Sparse::intersect(a_u32, b_u32) {
1607                assert_almost_equal(3.0, result, 0.0001);
1608            }
1609        }
1610    }
1611
1612    /// Reference implementation of set intersection using Rust's standard library
1613    fn reference_intersect<T: Ord>(a: &[T], b: &[T]) -> usize {
1614        let mut a_iter = a.iter();
1615        let mut b_iter = b.iter();
1616        let mut a_current = a_iter.next();
1617        let mut b_current = b_iter.next();
1618        let mut count = 0;
1619
1620        while let (Some(a_val), Some(b_val)) = (a_current, b_current) {
1621            match a_val.cmp(b_val) {
1622                core::cmp::Ordering::Less => a_current = a_iter.next(),
1623                core::cmp::Ordering::Greater => b_current = b_iter.next(),
1624                core::cmp::Ordering::Equal => {
1625                    count += 1;
1626                    a_current = a_iter.next();
1627                    b_current = b_iter.next();
1628                }
1629            }
1630        }
1631        count
1632    }
1633
1634    /// Generate test arrays with various sizes and patterns for intersection testing
1635    /// Includes empty, small, medium, large arrays with different overlap characteristics
1636    fn generate_intersection_test_arrays<T>() -> Vec<Vec<T>>
1637    where
1638        T: core::convert::TryFrom<u32> + Copy,
1639        <T as core::convert::TryFrom<u32>>::Error: core::fmt::Debug,
1640    {
1641        vec![
1642            // Empty array
1643            vec![],
1644            // Single element
1645            vec![T::try_from(42).unwrap()],
1646            // Very small arrays (< 16 elements) - tests serial fallback
1647            vec![
1648                T::try_from(1).unwrap(),
1649                T::try_from(5).unwrap(),
1650                T::try_from(10).unwrap(),
1651            ],
1652            vec![
1653                T::try_from(2).unwrap(),
1654                T::try_from(4).unwrap(),
1655                T::try_from(6).unwrap(),
1656                T::try_from(8).unwrap(),
1657                T::try_from(10).unwrap(),
1658                T::try_from(12).unwrap(),
1659                T::try_from(14).unwrap(),
1660            ],
1661            // Small arrays (< 32 elements) - boundary case for Turin
1662            (0..14).map(|x| T::try_from(x * 10).unwrap()).collect(),
1663            (5..20).map(|x| T::try_from(x * 10).unwrap()).collect(),
1664            // Medium arrays (32-64 elements) - tests one or two SIMD iterations
1665            (0..40).map(|x| T::try_from(x * 2).unwrap()).collect(),
1666            (10..50).map(|x| T::try_from(x * 2).unwrap()).collect(), // 50% overlap with previous
1667            (0..45).map(|x| T::try_from(x * 3).unwrap()).collect(),  // Different stride
1668            // Large arrays (> 64 elements) - tests main SIMD loop
1669            (0..100).map(|x| T::try_from(x * 2).unwrap()).collect(),
1670            (50..150).map(|x| T::try_from(x * 2).unwrap()).collect(), // 50% overlap
1671            (0..100).map(|x| T::try_from(x * 5).unwrap()).collect(),  // Sparse overlap
1672            (0..150)
1673                .filter(|x| x % 7 == 0)
1674                .map(|x| T::try_from(x).unwrap())
1675                .collect(),
1676            // Very large arrays (> 256 elements) - stress test
1677            (0..500).map(|x| T::try_from(x * 3).unwrap()).collect(),
1678            (100..600).map(|x| T::try_from(x * 3).unwrap()).collect(), // Large overlap
1679            (0..600).map(|x| T::try_from(x * 7).unwrap()).collect(),   // Minimal overlap
1680            // Edge cases: no overlap at all
1681            (0..50).map(|x| T::try_from(x * 2).unwrap()).collect(),
1682            (1000..1050).map(|x| T::try_from(x * 2).unwrap()).collect(), // Completely disjoint
1683            // Dense arrays at boundaries
1684            (0..16).map(|x| T::try_from(x).unwrap()).collect(), // Exactly 16 elements
1685            (0..32).map(|x| T::try_from(x).unwrap()).collect(), // Exactly 32 elements
1686            (0..64).map(|x| T::try_from(x).unwrap()).collect(), // Exactly 64 elements
1687        ]
1688    }
1689
1690    #[test]
1691    fn intersect_u32_comprehensive() {
1692        let test_arrays: Vec<Vec<u32>> = generate_intersection_test_arrays();
1693
1694        for (i, array_a) in test_arrays.iter().enumerate() {
1695            for (j, array_b) in test_arrays.iter().enumerate() {
1696                let expected = reference_intersect(array_a, array_b);
1697                let result =
1698                    Sparse::intersect(array_a.as_slice(), array_b.as_slice()).unwrap() as usize;
1699
1700                assert_eq!(
1701                    expected,
1702                    result,
1703                    "Intersection mismatch for arrays[{}] (len={}) and arrays[{}] (len={})",
1704                    i,
1705                    array_a.len(),
1706                    j,
1707                    array_b.len()
1708                );
1709            }
1710        }
1711    }
1712
1713    #[test]
1714    fn intersect_u16_comprehensive() {
1715        let test_arrays: Vec<Vec<u16>> = generate_intersection_test_arrays();
1716
1717        for (i, array_a) in test_arrays.iter().enumerate() {
1718            for (j, array_b) in test_arrays.iter().enumerate() {
1719                let expected = reference_intersect(array_a, array_b);
1720                let result =
1721                    Sparse::intersect(array_a.as_slice(), array_b.as_slice()).unwrap() as usize;
1722
1723                assert_eq!(
1724                    expected,
1725                    result,
1726                    "Intersection mismatch for arrays[{}] (len={}) and arrays[{}] (len={})",
1727                    i,
1728                    array_a.len(),
1729                    j,
1730                    array_b.len()
1731                );
1732            }
1733        }
1734    }
1735
1736    #[test]
1737    fn intersect_edge_cases() {
1738        // Test empty arrays
1739        let empty: &[u32] = &[];
1740        let non_empty: &[u32] = &[1, 2, 3];
1741        assert_eq!(Sparse::intersect(empty, empty), Some(0.0));
1742        assert_eq!(Sparse::intersect(empty, non_empty), Some(0.0));
1743        assert_eq!(Sparse::intersect(non_empty, empty), Some(0.0));
1744
1745        // Test single element matches
1746        assert_eq!(Sparse::intersect(&[42u32], &[42u32]), Some(1.0));
1747        assert_eq!(Sparse::intersect(&[42u32], &[43u32]), Some(0.0));
1748
1749        // Test no overlap
1750        let a: &[u32] = &[1, 2, 3, 4, 5];
1751        let b: &[u32] = &[10, 20, 30, 40, 50];
1752        assert_eq!(Sparse::intersect(a, b), Some(0.0));
1753
1754        // Test complete overlap
1755        let c: &[u32] = &[10, 20, 30, 40, 50];
1756        assert_eq!(Sparse::intersect(c, c), Some(5.0));
1757
1758        // Test one element at boundary (exactly at 16, 32, 64 element boundaries)
1759        let boundary_16: Vec<u32> = (0..16).collect();
1760        let boundary_32: Vec<u32> = (0..32).collect();
1761        let boundary_64: Vec<u32> = (0..64).collect();
1762
1763        assert_eq!(Sparse::intersect(&boundary_16, &boundary_16), Some(16.0));
1764        assert_eq!(Sparse::intersect(&boundary_32, &boundary_32), Some(32.0));
1765        assert_eq!(Sparse::intersect(&boundary_64, &boundary_64), Some(64.0));
1766
1767        // Test partial overlap at boundaries
1768        let first_half: Vec<u32> = (0..32).collect();
1769        let second_half: Vec<u32> = (16..48).collect();
1770        assert_eq!(Sparse::intersect(&first_half, &second_half), Some(16.0));
1771    }
1772
1773    #[test]
1774    fn f16_arithmetic() {
1775        let a = f16::from_f32(3.5);
1776        let b = f16::from_f32(2.0);
1777
1778        // Test basic arithmetic
1779        assert!((a + b).to_f32() - 5.5 < 0.01);
1780        assert!((a - b).to_f32() - 1.5 < 0.01);
1781        assert!((a * b).to_f32() - 7.0 < 0.01);
1782        assert!((a / b).to_f32() - 1.75 < 0.01);
1783        assert!((-a).to_f32() + 3.5 < 0.01);
1784
1785        // Test constants
1786        assert!(f16::ZERO.to_f32() == 0.0);
1787        assert!((f16::ONE.to_f32() - 1.0).abs() < 0.01);
1788        assert!((f16::NEG_ONE.to_f32() + 1.0).abs() < 0.01);
1789
1790        // Test comparisons
1791        assert!(a > b);
1792        assert!(!(a < b));
1793        assert!(a == a);
1794
1795        // Test utility methods
1796        assert!((-a).abs().to_f32() - 3.5 < 0.01);
1797        assert!(a.is_finite());
1798        assert!(!a.is_nan());
1799        assert!(!a.is_infinite());
1800    }
1801
1802    #[test]
1803    fn bf16_arithmetic() {
1804        let a = bf16::from_f32(3.5);
1805        let b = bf16::from_f32(2.0);
1806
1807        // Test basic arithmetic
1808        assert!((a + b).to_f32() - 5.5 < 0.1);
1809        assert!((a - b).to_f32() - 1.5 < 0.1);
1810        assert!((a * b).to_f32() - 7.0 < 0.1);
1811        assert!((a / b).to_f32() - 1.75 < 0.1);
1812        assert!((-a).to_f32() + 3.5 < 0.1);
1813
1814        // Test constants
1815        assert!(bf16::ZERO.to_f32() == 0.0);
1816        assert!((bf16::ONE.to_f32() - 1.0).abs() < 0.01);
1817        assert!((bf16::NEG_ONE.to_f32() + 1.0).abs() < 0.01);
1818
1819        // Test comparisons
1820        assert!(a > b);
1821        assert!(!(a < b));
1822        assert!(a == a);
1823
1824        // Test utility methods
1825        assert!((-a).abs().to_f32() - 3.5 < 0.1);
1826        assert!(a.is_finite());
1827        assert!(!a.is_nan());
1828        assert!(!a.is_infinite());
1829    }
1830
1831    #[test]
1832    fn bf16_dot() {
1833        let brain_a: Vec<bf16> = vec![1.0, 2.0, 3.0, 1.0, 2.0]
1834            .iter()
1835            .map(|&x| bf16::from_f32(x))
1836            .collect();
1837        let brain_b: Vec<bf16> = vec![4.0, 5.0, 6.0, 4.0, 5.0]
1838            .iter()
1839            .map(|&x| bf16::from_f32(x))
1840            .collect();
1841        if let Some(result) = <bf16 as SpatialSimilarity>::dot(&brain_a, &brain_b) {
1842            assert_eq!(46.0, result);
1843        }
1844    }
1845}