Skip to main content

spg_storage/
halfvec.rs

1#![allow(
2    clippy::cast_possible_truncation,
3    clippy::cast_possible_wrap,
4    clippy::cast_sign_loss,
5    clippy::cast_lossless,
6    clippy::doc_markdown
7)]
8// All bit-twiddle casts in this file (i32 ↔ u32 ↔ u16) are
9// arithmetically bounded by the IEEE-754 binary16 field widths;
10// the lints would force an unsigned-bit-pattern detour that
11// obscures the algorithm shape.
12
13//! v6.0.3 — halfvec: IEEE-754 binary16 (`F16`) per-element storage.
14//!
15//! Stable Rust 1.96 (this workspace) does not yet expose a stable
16//! `f16` primitive or stable `core::arch::aarch64` f16 intrinsics
17//! (rust-lang/rust#116909, #125606). v6.0.3 ships with a hand-
18//! rolled IEEE-754 binary16 codec on top of `Vec<u8>` carrying
19//! raw little-endian u16 bits. NEON f16 SIMD lands as v6.0.6 or
20//! whenever the stable toolchain catches up.
21//!
22//! Layout per cell: `[u16 LE × dim]`. Dim = `bytes.len() / 2`.
23//!
24//! Codec rounding: round-to-nearest-even on overflow / underflow
25//! (matches `f32 as f16` semantics on hosts that do have the
26//! primitive). Special values:
27//!
28//! - `±0.0` → bit-exact `±0.0` half.
29//! - `±∞`  → bit-exact `±∞` half.
30//! - `NaN` → quiet NaN half (sign + payload preserved as far as
31//!   the 10-bit mantissa allows; signalling/quiet bit is forced
32//!   set so the value can't decode back as inf).
33//! - Underflow → half subnormal when representable, else `±0` (f32
34//!   subnormals always → `±0`); overflow → `±∞` (IEEE 754-2008 §7.4).
35
36use alloc::vec::Vec;
37
/// Half-precision vector cell: one IEEE-754 binary16 value per element,
/// stored as raw little-endian `u16` bit patterns in a byte buffer.
///
/// SQ8 / SQ4 / SQ16 share an `Sq*Vector`-shaped struct; halfvec
/// follows the same pattern. `bytes` always has even length; the
/// invariant is enforced by every constructor in this module.
///
/// NOTE(review): `bytes` is `pub`, so code outside this module can build
/// a value with an odd-length buffer. `dim()` truncates (`len / 2`) and
/// the decode loops in this file only read complete byte pairs, so a
/// trailing byte is ignored rather than rejected.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HalfVector {
    /// Raw element storage: two bytes (one little-endian `u16`) per element.
    pub bytes: Vec<u8>,
}
45
46impl HalfVector {
47    /// Dimension = bytes.len() / 2. Returns 0 for empty input.
48    #[must_use]
49    pub fn dim(&self) -> usize {
50        self.bytes.len() / 2
51    }
52
53    /// Encode `v` into raw u16 LE bits via per-element
54    /// `f32 → f16` round-to-nearest-even.
55    #[must_use]
56    pub fn from_f32_slice(v: &[f32]) -> Self {
57        let mut bytes = Vec::with_capacity(v.len() * 2);
58        for &x in v {
59            let bits = f16_from_f32_bits(x.to_bits());
60            bytes.extend_from_slice(&bits.to_le_bytes());
61        }
62        Self { bytes }
63    }
64
65    /// Decode every u16 LE in `bytes` to f32. Inverse of
66    /// `from_f32_slice` modulo half-precision round-trip error
67    /// (≤ 2^-10 × |x| for finite normals).
68    #[must_use]
69    pub fn to_f32_vec(&self) -> Vec<f32> {
70        let mut out = Vec::with_capacity(self.dim());
71        let mut i = 0;
72        while i + 2 <= self.bytes.len() {
73            let bits = u16::from_le_bytes([self.bytes[i], self.bytes[i + 1]]);
74            out.push(f32::from_bits(f16_to_f32_bits(bits)));
75            i += 2;
76        }
77        out
78    }
79}
80
/// Convert one f32 (passed as raw bits) to f16 (raw bits).
///
/// Rounding is IEEE 754 round-to-nearest, ties-to-even throughout:
///
/// - NaN → quiet NaN: the sign and the top 10 payload bits are kept and
///   the quiet bit (mantissa bit 9) is forced on so the result cannot
///   collapse into ±∞.
/// - ±∞ and finite `|x| ≥ 65520` → ±∞.
/// - f32 zeros and f32 subnormals → ±0.
/// - `2^-14 ≤ |x|` → normal half (a rounding carry may bump the exponent).
/// - `2^-25 < |x| < 2^-14` → subnormal half (gradual underflow); values
///   just below `2^-14` may round up to the smallest normal, `0x0400`.
/// - `|x| ≤ 2^-25` → ±0 (`2^-25` itself is an exact tie between 0 and the
///   smallest subnormal and resolves to the even neighbour, 0).
#[must_use]
pub fn f16_from_f32_bits(bits: u32) -> u16 {
    let sign = ((bits >> 31) & 0x1) as u16;
    let exp32 = (bits >> 23) & 0xff;
    let mant32 = bits & 0x7f_ffff;

    // 1. NaN / ±∞
    if exp32 == 0xff {
        if mant32 == 0 {
            // ±∞: half is sign | 0x7c00
            return (sign << 15) | 0x7c00;
        }
        // NaN: keep the high payload bits that fit in 10 mantissa bits
        // and force the quiet bit (bit 9) so the mantissa is never 0
        // (which would decode as ∞).
        let mant16 = ((mant32 >> 13) | 0x200) as u16;
        return (sign << 15) | 0x7c00 | mant16;
    }

    // 2. ±0.0 and f32 subnormals (< 2^-126, far below anything a half
    //    can represent).
    if exp32 == 0 {
        return sign << 15;
    }

    // 3. Remove the f32 bias (127) to get the true power of two.
    let exp_unbiased: i32 = exp32 as i32 - 127;

    // 3a. Overflow: anything at or above 2^16 saturates to ±∞.
    if exp_unbiased > 15 {
        return (sign << 15) | 0x7c00;
    }

    // 3b. Below the smallest normal half (2^-14): encode as a subnormal.
    if exp_unbiased < -14 {
        // Values below 2^-25 are less than half of the smallest
        // subnormal (2^-24), so they can only round to zero. The
        // exponent -25 itself must go through the rounding logic:
        // everything above the exact tie rounds up to 0x0001.
        if exp_unbiased < -25 {
            return sign << 15;
        }
        // A subnormal half stores `value / 2^-24` as an integer. With
        // the implicit leading 1 made explicit, the 24-bit significand
        // is `value / 2^(e - 23)`, so it must be shifted right by
        // `-e - 1` bits in total: the usual 13-bit mantissa drop plus
        // `-14 - e` extra positions.
        let shift = (-14 - exp_unbiased) as u32; // 1..=11
        let mant_with_lead = mant32 | 0x80_0000;
        let drop_bits = 13 + shift; // 14..=24
        let mant16_pre = mant_with_lead >> drop_bits;
        // Round-to-nearest-even on the bits we just dropped.
        let half = 1u32 << (drop_bits - 1);
        let mask = (1u32 << drop_bits) - 1;
        let dropped = mant_with_lead & mask;
        let round_up = dropped > half || (dropped == half && (mant16_pre & 1) == 1);
        // A carry out of the 10-bit field yields 0x400, which is exactly
        // the bit pattern of the smallest normal half.
        let mant16 = mant16_pre + u32::from(round_up);
        return (sign << 15) | (mant16 as u16);
    }

    // 4. Normal range.
    let exp16 = (exp_unbiased + 15) as u16;
    let mant16_pre = mant32 >> 13;
    // Round-to-nearest-even on the 13 low bits we just dropped.
    let drop_mask = 0x1fffu32;
    let half = 0x1000u32;
    let dropped = mant32 & drop_mask;
    let round_up = dropped > half || (dropped == half && (mant16_pre & 1) == 1);
    let mant16 = mant16_pre + u32::from(round_up);
    // Adding (rather than OR-ing) lets a mantissa carry (0x400)
    // propagate into the exponent field.
    let packed = (u32::from(exp16) << 10) + mant16;
    if packed >= 0x7c00 {
        // Overflow into infinity (e.g. 65520 → rounds to ±∞).
        return (sign << 15) | 0x7c00;
    }
    #[allow(clippy::cast_possible_truncation)]
    let packed_u16 = packed as u16;
    (sign << 15) | packed_u16
}
166
/// Widen one f16 (raw bits) to f32 (raw bits).
///
/// Every finite half is exactly representable in f32, so the result is
/// exact; the sign is always carried over and a NaN stays a NaN with its
/// mantissa bits moved to the top of the f32 mantissa.
#[must_use]
pub fn f16_to_f32_bits(bits: u16) -> u32 {
    let wide = u32::from(bits);
    let sign32 = (wide & 0x8000) << 16;
    let exp_field = (wide >> 10) & 0x1f;
    let frac = wide & 0x3ff;

    match (exp_field, frac) {
        // ±∞: all-ones exponent, empty mantissa.
        (0x1f, 0) => sign32 | 0x7f80_0000,
        // NaN: lift the payload so the quiet bit moves from bit 9 to bit 22.
        (0x1f, _) => sign32 | 0x7f80_0000 | (frac << 13),
        // ±0.0
        (0, 0) => sign32,
        // Subnormal half: value is frac × 2^-24. Move the highest set bit
        // of `frac` up to bit 10 (where the implicit 1 lives), lowering
        // the exponent from -14 by one for every position shifted.
        (0, _) => {
            let shift = frac.leading_zeros() - 21; // 1..=10
            let normalised = (frac << shift) & 0x3ff; // implicit 1 dropped
            let exp32 = 113 - shift; // (-14 - shift) + 127
            sign32 | (exp32 << 23) | (normalised << 13)
        }
        // Normal half: swap bias 15 for bias 127 (+112), widen mantissa.
        _ => sign32 | ((exp_field + 112) << 23) | (frac << 13),
    }
}
209
210// ===========================================================================
211// v6.0.6 — NEON SIMD f16 → f32 conversion (fused into distance kernels).
212//
213// stable Rust 1.96 still gates `core::arch::aarch64::vcvt_f32_f16` behind
214// the unstable `stdarch_neon_f16` feature. Bit-manipulation in SIMD lanes
215// stays on stable NEON (`vshl`, `vand`, `vceq`, `vbsl`) and produces the
216// same f32 bit pattern for the cases ML embeddings actually hit:
217//
218//   * NaN / ±∞ (exp_h == 0x1f) — bit-exact.
219//   * ±0       (exp_h == 0 && mant_h == 0) — bit-exact.
220//   * Normal   (exp_h ∈ 1..=30) — bit-exact.
221//   * Subnormal (exp_h == 0 && mant_h != 0) — **flushed to ±0**. The
222//     scalar codec renormalises subnormals into f32 normals; SIMD path
223//     flushes them. f16 subnormals are |value| < 2^-14 ≈ 6.1e-5, which
224//     is below the precision of natural embeddings anyway. Distance
225//     queries on those values are dominated by other lanes.
226//
227// Public distance functions below dispatch the SIMD kernel under
228// `#[cfg(target_arch = "aarch64")]` when `dim ≥ 8 && dim % 8 == 0`,
229// falling back to scalar otherwise — same pre-condition shape as the
230// SQ8 / f32 NEON paths.
231// ===========================================================================
232
233/// L2² distance between an f16 cell and an f32 query, fused so no
234/// `Vec<f32>` ever materialises. Dispatches to NEON for production-
235/// shaped dims (multiples of 8); scalar fallback otherwise.
236#[must_use]
237pub fn half_l2_distance_sq_asymmetric(a: &HalfVector, q: &[f32]) -> f32 {
238    if a.dim() != q.len() {
239        return f32::INFINITY;
240    }
241    #[cfg(target_arch = "aarch64")]
242    {
243        let n = a.dim();
244        if n >= 8 && n.is_multiple_of(8) {
245            // SAFETY: NEON is baseline aarch64; preconditions
246            // (matching lengths, ≥ 1 full 8-lane chunk) checked above.
247            return unsafe { half_l2_distance_sq_asymmetric_neon(a, q) };
248        }
249    }
250    half_l2_distance_sq_asymmetric_scalar(a, q)
251}
252
253/// Negated dot product (pgvector `<#>` convention). Fused SIMD path.
254#[must_use]
255pub fn half_inner_product_asymmetric(a: &HalfVector, q: &[f32]) -> f32 {
256    if a.dim() != q.len() {
257        return f32::INFINITY;
258    }
259    #[cfg(target_arch = "aarch64")]
260    {
261        let n = a.dim();
262        if n >= 8 && n.is_multiple_of(8) {
263            // SAFETY: see `half_l2_distance_sq_asymmetric_neon`.
264            return -unsafe { half_dot_asymmetric_neon(a, q) };
265        }
266    }
267    -half_dot_asymmetric_scalar(a, q)
268}
269
270/// Cosine distance `1 - dot / (||a|| * ||q||)`. Fused SIMD path;
271/// norm-sqrt + zero-guard live in the safe wrapper.
272#[must_use]
273pub fn half_cosine_distance_asymmetric(a: &HalfVector, q: &[f32]) -> f32 {
274    if a.dim() != q.len() {
275        return f32::INFINITY;
276    }
277    let (dot, na, nq);
278    #[cfg(target_arch = "aarch64")]
279    {
280        let n = a.dim();
281        if n >= 8 && n.is_multiple_of(8) {
282            // SAFETY: see `half_l2_distance_sq_asymmetric_neon`.
283            let (d, a2, q2) = unsafe { half_cosine_accumulators_asymmetric_neon(a, q) };
284            dot = d;
285            na = a2;
286            nq = q2;
287        } else {
288            let (d, a2, q2) = half_cosine_accumulators_asymmetric_scalar(a, q);
289            dot = d;
290            na = a2;
291            nq = q2;
292        }
293    }
294    #[cfg(not(target_arch = "aarch64"))]
295    {
296        let (d, a2, q2) = half_cosine_accumulators_asymmetric_scalar(a, q);
297        dot = d;
298        na = a2;
299        nq = q2;
300    }
301    if na == 0.0 || nq == 0.0 {
302        return f32::INFINITY;
303    }
304    1.0 - dot / (sqrt_finite(na) * sqrt_finite(nq))
305}
306
307/// Symmetric L2² between two f16 cells. Used during HNSW build.
308#[must_use]
309pub fn half_l2_distance_sq(a: &HalfVector, b: &HalfVector) -> f32 {
310    if a.dim() != b.dim() {
311        return f32::INFINITY;
312    }
313    #[cfg(target_arch = "aarch64")]
314    {
315        let n = a.dim();
316        if n >= 8 && n.is_multiple_of(8) {
317            // SAFETY: see `half_l2_distance_sq_asymmetric_neon`.
318            return unsafe { half_l2_distance_sq_symmetric_neon(a, b) };
319        }
320    }
321    half_l2_distance_sq_symmetric_scalar(a, b)
322}
323
324// ---------------------------------------------------------------------------
325// Scalar references — used as fallback + for SIMD parity tests below.
326
327fn half_l2_distance_sq_asymmetric_scalar(a: &HalfVector, q: &[f32]) -> f32 {
328    let mut acc: f32 = 0.0;
329    let mut i = 0usize;
330    while i + 2 <= a.bytes.len() {
331        let bits = u16::from_le_bytes([a.bytes[i], a.bytes[i + 1]]);
332        let xa = f32::from_bits(f16_to_f32_bits(bits));
333        let d = xa - q[i / 2];
334        acc += d * d;
335        i += 2;
336    }
337    acc
338}
339
340fn half_dot_asymmetric_scalar(a: &HalfVector, q: &[f32]) -> f32 {
341    let mut dot: f32 = 0.0;
342    let mut i = 0usize;
343    while i + 2 <= a.bytes.len() {
344        let bits = u16::from_le_bytes([a.bytes[i], a.bytes[i + 1]]);
345        let xa = f32::from_bits(f16_to_f32_bits(bits));
346        dot += xa * q[i / 2];
347        i += 2;
348    }
349    dot
350}
351
352fn half_cosine_accumulators_asymmetric_scalar(a: &HalfVector, q: &[f32]) -> (f32, f32, f32) {
353    let (mut dot, mut na, mut nq) = (0.0_f32, 0.0_f32, 0.0_f32);
354    let mut i = 0usize;
355    while i + 2 <= a.bytes.len() {
356        let bits = u16::from_le_bytes([a.bytes[i], a.bytes[i + 1]]);
357        let xa = f32::from_bits(f16_to_f32_bits(bits));
358        let qx = q[i / 2];
359        dot += xa * qx;
360        na += xa * xa;
361        nq += qx * qx;
362        i += 2;
363    }
364    (dot, na, nq)
365}
366
367fn half_l2_distance_sq_symmetric_scalar(a: &HalfVector, b: &HalfVector) -> f32 {
368    let mut acc: f32 = 0.0;
369    let mut i = 0usize;
370    while i + 2 <= a.bytes.len() {
371        let av = u16::from_le_bytes([a.bytes[i], a.bytes[i + 1]]);
372        let bv = u16::from_le_bytes([b.bytes[i], b.bytes[i + 1]]);
373        let xa = f32::from_bits(f16_to_f32_bits(av));
374        let xb = f32::from_bits(f16_to_f32_bits(bv));
375        let d = xa - xb;
376        acc += d * d;
377        i += 2;
378    }
379    acc
380}
381
/// Square root built from core-only operations (bit casts plus basic
/// arithmetic).
///
/// - `x <= 0.0` (including `-0.0`) → `0.0`.
/// - NaN → NaN, `+∞` → `+∞`.
/// - Any other input → Newton–Raphson iterated to a fixed point, which
///   lands within about one ulp of the true root over the whole
///   positive f32 range, subnormals included.
///
/// A fixed small number of Newton steps from a seed such as `x / 2` is
/// not enough: while the estimate is far above the root each step only
/// halves it, so large or tiny inputs come out wrong by large factors.
/// Seeding from the halved exponent and iterating until the estimate
/// stops changing avoids that.
fn sqrt_finite(x: f32) -> f32 {
    // Upper bound on Newton steps. Normal inputs settle in ~4; the
    // smallest subnormals need ~16 because their seed is far too large.
    // The cap also stops a one-ulp oscillation from looping forever.
    const MAX_STEPS: u32 = 32;

    if x <= 0.0 {
        return 0.0;
    }
    if !x.is_finite() {
        // NaN fails the comparison above and +∞ has no finite root;
        // hand both back unchanged.
        return x;
    }

    // Seed: halve the biased exponent. Shifting the bit pattern right by
    // one halves exponent and mantissa together; adding (127 << 23) / 2
    // restores the bias. Within ~6 % of the root for normal inputs.
    let mut y = f32::from_bits((x.to_bits() >> 1) + 0x1fc0_0000);
    for _ in 0..MAX_STEPS {
        let next = 0.5 * (y + x / y);
        if next == y {
            break;
        }
        y = next;
    }
    y
}
392
393// ---------------------------------------------------------------------------
394// NEON kernels — bit-manipulation f16 → f32 in u32 lanes, then standard
395// f32 SIMD arithmetic (subtract / FMA / dot / norm).
396
/// Convert eight half-precision lanes into two `float32x4_t` registers
/// (element 0 of the result holds the low four lanes, element 1 the
/// high four). The callers in this file load the lanes with `vld1q_u8`
/// from the raw `bytes` and reinterpret them as `u16` lanes. Bit-exact
/// for normal / zero / inf / nan; subnormals flush to `±0` (see module
/// docstring).
///
/// # Safety
///
/// Compiled with `#[target_feature(enable = "neon")]`, so the caller
/// must be running on a CPU that implements NEON. The function takes
/// its input by value and performs no memory access of its own.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(clippy::many_single_char_names)]
#[inline]
unsafe fn half_to_f32x8_neon(
    h: core::arch::aarch64::uint16x8_t,
) -> [core::arch::aarch64::float32x4_t; 2] {
    use core::arch::aarch64::{
        vaddq_u32, vandq_u32, vbslq_u32, vceqq_u32, vdupq_n_u32, vget_high_u16, vget_low_u16,
        vmovl_u16, vorrq_u32, vreinterpretq_f32_u32, vshlq_n_u32, vshrq_n_u32,
    };
    // Widen u16x8 → 2× u32x4.
    let lo = vmovl_u16(vget_low_u16(h));
    let hi = vmovl_u16(vget_high_u16(h));

    // Helper: convert one u32x4 of raw f16 bits → f32x4.
    // Bit-exact for normal / zero / inf / nan; subnormals → ±0 via
    // the `exp == 0` mask. ML embeddings never trip the latter.
    let convert = |w: core::arch::aarch64::uint32x4_t| -> core::arch::aarch64::float32x4_t {
        // Sign: half bit 15 → f32 bit 31.
        let sign = vshlq_n_u32::<16>(vandq_u32(w, vdupq_n_u32(0x8000)));
        // 10-bit mantissa and 5-bit exponent fields of the half.
        let mant = vandq_u32(w, vdupq_n_u32(0x3ff));
        let exp = vandq_u32(vshrq_n_u32::<10>(w), vdupq_n_u32(0x1f));
        // Mantissa moves to the top 10 bits of the 23-bit f32 mantissa.
        let mant_f32 = vshlq_n_u32::<13>(mant);
        // Re-bias: f32 bias 127 minus half bias 15 = 112.
        let exp_plus_bias = vaddq_u32(exp, vdupq_n_u32(112));
        let exp_f32_shifted = vshlq_n_u32::<23>(exp_plus_bias);
        // Candidate results for the normal and the inf/NaN cases.
        let normal = vorrq_u32(vorrq_u32(sign, exp_f32_shifted), mant_f32);
        let inf_nan = vorrq_u32(vorrq_u32(sign, vdupq_n_u32(0x7f80_0000)), mant_f32);
        // Per-lane all-ones masks for the two special exponent values.
        let is_inf_nan = vceqq_u32(exp, vdupq_n_u32(0x1f));
        let is_zero_or_subnormal = vceqq_u32(exp, vdupq_n_u32(0));
        // First pick inf/NaN over normal, then let exp == 0 lanes
        // override with a bare sign bit (±0). The mantissa is dropped
        // there, which is what flushes subnormals.
        let result = vbslq_u32(is_inf_nan, inf_nan, normal);
        let result = vbslq_u32(is_zero_or_subnormal, sign, result);
        vreinterpretq_f32_u32(result)
    };

    [convert(lo), convert(hi)]
}
436
/// NEON kernel for the asymmetric L2² distance. Processes `a` in chunks
/// of eight elements, keeping one accumulator for the low four lanes of
/// every chunk and one for the high four; elements past the last full
/// chunk are not visited.
///
/// # Safety
///
/// The CPU must implement NEON, and `q` must hold at least as many
/// elements as the full 8-element chunks of `a` cover
/// (`a.dim() / 8 * 8`); the loads from `q` are unchecked.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(clippy::many_single_char_names)]
unsafe fn half_l2_distance_sq_asymmetric_neon(a: &HalfVector, q: &[f32]) -> f32 {
    use core::arch::aarch64::{
        vaddq_f32, vaddvq_f32, vdupq_n_f32, vfmaq_f32, vld1q_f32, vld1q_u8, vreinterpretq_u16_u8,
        vsubq_f32,
    };
    // SAFETY: the caller upholds the contract in `# Safety`. Each chunk
    // reads 16 bytes of `a.bytes` starting at `base * 2` with
    // `base + 8 <= a.dim()`, which is within the buffer.
    unsafe {
        let cell_ptr = a.bytes.as_ptr();
        let query_ptr = q.as_ptr();
        let mut sum_lo = vdupq_n_f32(0.0);
        let mut sum_hi = vdupq_n_f32(0.0);
        let chunks = a.dim() / 8;
        for chunk in 0..chunks {
            let base = chunk * 8;
            // Byte load + reinterpret: eight raw halves without assuming
            // the byte buffer is 2-byte aligned.
            let halves = vreinterpretq_u16_u8(vld1q_u8(cell_ptr.add(base * 2)));
            let [cell_lo, cell_hi] = half_to_f32x8_neon(halves);
            let diff_lo = vsubq_f32(cell_lo, vld1q_f32(query_ptr.add(base)));
            let diff_hi = vsubq_f32(cell_hi, vld1q_f32(query_ptr.add(base + 4)));
            sum_lo = vfmaq_f32(sum_lo, diff_lo, diff_lo);
            sum_hi = vfmaq_f32(sum_hi, diff_hi, diff_hi);
        }
        vaddvq_f32(vaddq_f32(sum_lo, sum_hi))
    }
}
471
/// NEON kernel for the asymmetric dot product (sign not flipped).
/// Processes `a` in chunks of eight elements, keeping one accumulator
/// for the low four lanes of every chunk and one for the high four;
/// elements past the last full chunk are not visited.
///
/// # Safety
///
/// The CPU must implement NEON, and `q` must hold at least as many
/// elements as the full 8-element chunks of `a` cover
/// (`a.dim() / 8 * 8`); the loads from `q` are unchecked.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(clippy::many_single_char_names)]
unsafe fn half_dot_asymmetric_neon(a: &HalfVector, q: &[f32]) -> f32 {
    use core::arch::aarch64::{
        vaddq_f32, vaddvq_f32, vdupq_n_f32, vfmaq_f32, vld1q_f32, vld1q_u8, vreinterpretq_u16_u8,
    };
    // SAFETY: the caller upholds the contract in `# Safety`. Each chunk
    // reads 16 bytes of `a.bytes` starting at `base * 2` with
    // `base + 8 <= a.dim()`, which is within the buffer.
    unsafe {
        let cell_ptr = a.bytes.as_ptr();
        let query_ptr = q.as_ptr();
        let mut sum_lo = vdupq_n_f32(0.0);
        let mut sum_hi = vdupq_n_f32(0.0);
        let chunks = a.dim() / 8;
        for chunk in 0..chunks {
            let base = chunk * 8;
            // Byte load + reinterpret: eight raw halves without assuming
            // the byte buffer is 2-byte aligned.
            let halves = vreinterpretq_u16_u8(vld1q_u8(cell_ptr.add(base * 2)));
            let [cell_lo, cell_hi] = half_to_f32x8_neon(halves);
            sum_lo = vfmaq_f32(sum_lo, cell_lo, vld1q_f32(query_ptr.add(base)));
            sum_hi = vfmaq_f32(sum_hi, cell_hi, vld1q_f32(query_ptr.add(base + 4)));
        }
        vaddvq_f32(vaddq_f32(sum_lo, sum_hi))
    }
}
501
/// NEON kernel for the cosine accumulators. Returns
/// `(dot, ‖a‖², ‖q‖²)` with one vector accumulator per quantity; both
/// halves of every 8-element chunk feed the same accumulator, low half
/// first. Elements past the last full chunk are not visited, and no
/// square root or zero guard is applied.
///
/// # Safety
///
/// The CPU must implement NEON, and `q` must hold at least as many
/// elements as the full 8-element chunks of `a` cover
/// (`a.dim() / 8 * 8`); the loads from `q` are unchecked.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(clippy::many_single_char_names, clippy::similar_names)]
unsafe fn half_cosine_accumulators_asymmetric_neon(a: &HalfVector, q: &[f32]) -> (f32, f32, f32) {
    use core::arch::aarch64::{
        vaddvq_f32, vdupq_n_f32, vfmaq_f32, vld1q_f32, vld1q_u8, vreinterpretq_u16_u8,
    };
    // SAFETY: the caller upholds the contract in `# Safety`. Each chunk
    // reads 16 bytes of `a.bytes` starting at `base * 2` with
    // `base + 8 <= a.dim()`, which is within the buffer.
    unsafe {
        let cell_ptr = a.bytes.as_ptr();
        let query_ptr = q.as_ptr();
        let mut dot_sum = vdupq_n_f32(0.0);
        let mut cell_norm = vdupq_n_f32(0.0);
        let mut query_norm = vdupq_n_f32(0.0);
        let chunks = a.dim() / 8;
        for chunk in 0..chunks {
            let base = chunk * 8;
            // Byte load + reinterpret: eight raw halves without assuming
            // the byte buffer is 2-byte aligned.
            let halves = vreinterpretq_u16_u8(vld1q_u8(cell_ptr.add(base * 2)));
            let [cell_lo, cell_hi] = half_to_f32x8_neon(halves);
            let query_lo = vld1q_f32(query_ptr.add(base));
            let query_hi = vld1q_f32(query_ptr.add(base + 4));

            dot_sum = vfmaq_f32(dot_sum, cell_lo, query_lo);
            dot_sum = vfmaq_f32(dot_sum, cell_hi, query_hi);

            cell_norm = vfmaq_f32(cell_norm, cell_lo, cell_lo);
            cell_norm = vfmaq_f32(cell_norm, cell_hi, cell_hi);

            query_norm = vfmaq_f32(query_norm, query_lo, query_lo);
            query_norm = vfmaq_f32(query_norm, query_hi, query_hi);
        }
        (
            vaddvq_f32(dot_sum),
            vaddvq_f32(cell_norm),
            vaddvq_f32(query_norm),
        )
    }
}
537
/// NEON kernel for the symmetric L2² distance between two f16 cells.
/// Processes `a` in chunks of eight elements, keeping one accumulator
/// for the low four lanes of every chunk and one for the high four;
/// elements past the last full chunk are not visited.
///
/// # Safety
///
/// The CPU must implement NEON, and `b.bytes` must be at least as long
/// as the bytes covered by the full 8-element chunks of `a`
/// (`a.dim() / 8 * 16`); the loads from `b` are unchecked.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(clippy::many_single_char_names)]
unsafe fn half_l2_distance_sq_symmetric_neon(a: &HalfVector, b: &HalfVector) -> f32 {
    use core::arch::aarch64::{
        vaddq_f32, vaddvq_f32, vdupq_n_f32, vfmaq_f32, vld1q_u8, vreinterpretq_u16_u8, vsubq_f32,
    };
    // SAFETY: the caller upholds the contract in `# Safety`. Each chunk
    // reads 16 bytes of `a.bytes` starting at `base * 2` with
    // `base + 8 <= a.dim()`, which is within the buffer.
    unsafe {
        let left_ptr = a.bytes.as_ptr();
        let right_ptr = b.bytes.as_ptr();
        let mut sum_lo = vdupq_n_f32(0.0);
        let mut sum_hi = vdupq_n_f32(0.0);
        let chunks = a.dim() / 8;
        for chunk in 0..chunks {
            let byte_offset = chunk * 16;
            let left_halves = vreinterpretq_u16_u8(vld1q_u8(left_ptr.add(byte_offset)));
            let right_halves = vreinterpretq_u16_u8(vld1q_u8(right_ptr.add(byte_offset)));
            let [left_lo, left_hi] = half_to_f32x8_neon(left_halves);
            let [right_lo, right_hi] = half_to_f32x8_neon(right_halves);
            let diff_lo = vsubq_f32(left_lo, right_lo);
            let diff_hi = vsubq_f32(left_hi, right_hi);
            sum_lo = vfmaq_f32(sum_lo, diff_lo, diff_lo);
            sum_hi = vfmaq_f32(sum_hi, diff_hi, diff_hi);
        }
        vaddvq_f32(vaddq_f32(sum_lo, sum_hi))
    }
}
566
// Unit tests for the scalar codec, the `HalfVector` container and (on
// aarch64 only) parity between the NEON kernels and their scalar
// references.
#[cfg(test)]
#[allow(
    clippy::float_cmp,
    clippy::approx_constant,
    clippy::suboptimal_flops,
    clippy::unreadable_literal
)]
mod tests {
    use super::*;

    /// Bit-pattern equality (so `+0.0` and `-0.0` differ), except that
    /// any two NaNs compare equal.
    fn f32_eq_bits(a: f32, b: f32) -> bool {
        if a.is_nan() && b.is_nan() {
            return true;
        }
        a.to_bits() == b.to_bits()
    }

    #[test]
    fn f16_roundtrip_representable_values() {
        // Values that fall on f16 grid points round-trip exactly.
        let cases: &[f32] = &[
            0.0,
            -0.0,
            1.0,
            -1.0,
            0.5,
            -0.5,
            0.25,
            2.0,
            4.0,
            1.5,
            -1.5,
            65504.0, // f16 max
            -65504.0,
            1.0 / 16384.0, // = 2^-14 (smallest normal)
        ];
        for &x in cases {
            let bits = f16_from_f32_bits(x.to_bits());
            let y = f32::from_bits(f16_to_f32_bits(bits));
            assert!(f32_eq_bits(x, y), "expected {x} == {y} (bits {bits:#x})");
        }
    }

    #[test]
    fn f16_roundtrip_inf_and_nan() {
        let inf = f32::INFINITY;
        let neg_inf = f32::NEG_INFINITY;
        assert_eq!(
            f16_to_f32_bits(f16_from_f32_bits(inf.to_bits())),
            inf.to_bits()
        );
        assert_eq!(
            f16_to_f32_bits(f16_from_f32_bits(neg_inf.to_bits())),
            neg_inf.to_bits()
        );
        let nan = f32::NAN;
        let nan_back = f32::from_bits(f16_to_f32_bits(f16_from_f32_bits(nan.to_bits())));
        assert!(nan_back.is_nan(), "NaN should round-trip as NaN");
    }

    #[test]
    fn f16_overflow_saturates_to_inf() {
        // > 65504 saturates to +∞.
        let huge = 1e30_f32;
        let half_bits = f16_from_f32_bits(huge.to_bits());
        assert_eq!(half_bits, 0x7c00, "huge positive → +∞");
        let half_back = f32::from_bits(f16_to_f32_bits(half_bits));
        assert_eq!(half_back, f32::INFINITY);
    }

    #[test]
    fn f16_underflow_flushes_to_zero() {
        // 1e-30 (≈ 2^-100) is way below the f16 subnormal range,
        // flushes to 0.
        let tiny = 1.0e-30_f32;
        let half_bits = f16_from_f32_bits(tiny.to_bits());
        assert_eq!(
            half_bits & 0x7fff,
            0,
            "tiny positive → +0 (got {half_bits:#x})"
        );
    }

    #[test]
    fn f16_codec_roundtrip_finite_normals_bounded_error() {
        // Smooth-sweep test: half-precision has ~10 bits of
        // mantissa, so the relative error after roundtrip is
        // ≤ 2^-10 ≈ 9.77e-4 for finite normals. Allow a touch
        // more for the rounding boundary case.
        let cases: &[f32] = &[
            0.1,
            0.333,
            1.0 / 7.0,
            3.14159,
            100.0,
            12345.0,
            -0.1,
            -3.14159,
        ];
        for &x in cases {
            let bits = f16_from_f32_bits(x.to_bits());
            let y = f32::from_bits(f16_to_f32_bits(bits));
            let rel = (x - y).abs() / x.abs();
            assert!(rel < 1e-3, "x={x} y={y} rel_err={rel} (bits {bits:#x})");
        }
    }

    #[test]
    fn half_vector_from_to_f32_slice() {
        let v = alloc::vec![0.0_f32, 0.25, 0.5, 1.0, -1.0];
        let h = HalfVector::from_f32_slice(&v);
        assert_eq!(h.dim(), 5);
        let back = h.to_f32_vec();
        assert_eq!(back, v);
    }

    #[test]
    fn half_vector_empty() {
        let h = HalfVector::from_f32_slice(&[]);
        assert_eq!(h.dim(), 0);
        assert!(h.bytes.is_empty());
        let back = h.to_f32_vec();
        assert!(back.is_empty());
    }

    // ------------------------------------------------------------------
    // v6.0.6 — NEON SIMD f16 → f32 + fused distance kernels.

    /// Generate a deterministic dim-N f32 vector of values in
    /// `[-1, 1)` from a 64-bit LCG, so the inputs stay well inside the
    /// f16 range. A draw may still land within 2^-14 of zero and encode
    /// as a half subnormal (which the SIMD path flushes); the parity
    /// tolerances below absorb that.
    #[allow(clippy::cast_precision_loss)]
    fn random_normal_vec(seed: u64, dim: usize) -> alloc::vec::Vec<f32> {
        let mut state = seed | 1;
        let mut out = alloc::vec::Vec::with_capacity(dim);
        for _ in 0..dim {
            state = state
                .wrapping_mul(6_364_136_223_846_793_005)
                .wrapping_add(1);
            // Bits 32..=55 of state → [0, 2^24) → /2^24 → [0, 1).
            // The cast is safe (lossless) because the mask leaves
            // at most 24 bits, which fits f32's mantissa.
            let u = ((state >> 32) & 0x00FF_FFFF) as f32 / (0x0100_0000_u32 as f32);
            // Range [-1, 1).
            out.push(2.0 * u - 1.0);
        }
        out
    }

    #[cfg(target_arch = "aarch64")]
    #[test]
    #[allow(clippy::cast_precision_loss)]
    fn half_l2_asymmetric_neon_matches_scalar() {
        // NEON SIMD path must agree with the scalar reference on
        // every production-shaped dim. Tolerance reflects FMA
        // rounding, the different summation order and the rare lane
        // the SIMD path flushes as a half subnormal.
        for &d in &[8_usize, 16, 32, 64, 128, 256, 512, 1024] {
            for trial in 0..8_u64 {
                let v = random_normal_vec(0xA5A5_F160_F160_0001 ^ trial ^ (d as u64), d);
                let q = random_normal_vec(0xC0FE_F160_F160_0002 ^ trial ^ (d as u64), d);
                let h = HalfVector::from_f32_slice(&v);
                let scalar = half_l2_distance_sq_asymmetric_scalar(&h, &q);
                let neon = unsafe { half_l2_distance_sq_asymmetric_neon(&h, &q) };
                let tol = (scalar.abs().max(1e-6)) * 1e-4 + (d as f32) * 1e-5;
                assert!(
                    (scalar - neon).abs() <= tol,
                    "L2 asym dim={d} trial={trial}: scalar={scalar} neon={neon}"
                );
            }
        }
    }

    #[cfg(target_arch = "aarch64")]
    #[test]
    #[allow(clippy::cast_precision_loss)]
    fn half_dot_asymmetric_neon_matches_scalar() {
        for &d in &[8_usize, 16, 32, 64, 128, 256, 512, 1024] {
            for trial in 0..8_u64 {
                let v = random_normal_vec(0xBEEF_F160_F160_0003 ^ trial ^ (d as u64), d);
                let q = random_normal_vec(0xDEAD_F160_F160_0004 ^ trial ^ (d as u64), d);
                let h = HalfVector::from_f32_slice(&v);
                let scalar = half_dot_asymmetric_scalar(&h, &q);
                let neon = unsafe { half_dot_asymmetric_neon(&h, &q) };
                let tol = (scalar.abs().max(1e-6)) * 1e-4 + (d as f32) * 1e-5;
                assert!(
                    (scalar - neon).abs() <= tol,
                    "dot dim={d} trial={trial}: scalar={scalar} neon={neon}"
                );
            }
        }
    }

    #[cfg(target_arch = "aarch64")]
    #[test]
    #[allow(clippy::similar_names, clippy::cast_precision_loss)]
    fn half_cosine_accumulators_neon_matches_scalar() {
        for &d in &[8_usize, 16, 32, 64, 128, 256, 512, 1024] {
            for trial in 0..8_u64 {
                let v = random_normal_vec(0xC051_F160_F160_0005 ^ trial ^ (d as u64), d);
                let q = random_normal_vec(0xF00D_F160_F160_0006 ^ trial ^ (d as u64), d);
                let h = HalfVector::from_f32_slice(&v);
                let (dot_s, na_s, nq_s) = half_cosine_accumulators_asymmetric_scalar(&h, &q);
                let (dot_n, na_n, nq_n) =
                    unsafe { half_cosine_accumulators_asymmetric_neon(&h, &q) };
                let tol = |x: f32| (x.abs().max(1e-6)) * 1e-4 + (d as f32) * 1e-5;
                assert!(
                    (dot_s - dot_n).abs() <= tol(dot_s),
                    "cos dot dim={d}: scalar={dot_s} neon={dot_n}"
                );
                assert!(
                    (na_s - na_n).abs() <= tol(na_s),
                    "cos na dim={d}: scalar={na_s} neon={na_n}"
                );
                assert!(
                    (nq_s - nq_n).abs() <= tol(nq_s),
                    "cos nq dim={d}: scalar={nq_s} neon={nq_n}"
                );
            }
        }
    }

    #[cfg(target_arch = "aarch64")]
    #[test]
    #[allow(clippy::cast_precision_loss)]
    fn half_l2_symmetric_neon_matches_scalar() {
        for &d in &[8_usize, 16, 32, 64, 128, 256, 512, 1024] {
            for trial in 0..8_u64 {
                let va = random_normal_vec(0x1234_F160_F160_0007 ^ trial ^ (d as u64), d);
                let vb = random_normal_vec(0x5678_F160_F160_0008 ^ trial ^ (d as u64), d);
                let ha = HalfVector::from_f32_slice(&va);
                let hb = HalfVector::from_f32_slice(&vb);
                let scalar = half_l2_distance_sq_symmetric_scalar(&ha, &hb);
                let neon = unsafe { half_l2_distance_sq_symmetric_neon(&ha, &hb) };
                let tol = (scalar.abs().max(1e-6)) * 1e-4 + (d as f32) * 1e-5;
                assert!(
                    (scalar - neon).abs() <= tol,
                    "L2 sym dim={d}: scalar={scalar} neon={neon}"
                );
            }
        }
    }
}