1#![allow(
2 clippy::cast_possible_truncation,
3 clippy::cast_possible_wrap,
4 clippy::cast_sign_loss,
5 clippy::cast_lossless,
6 clippy::doc_markdown
7)]
8use alloc::vec::Vec;
37
38#[derive(Debug, Clone, PartialEq, Eq)]
42pub struct HalfVector {
43 pub bytes: Vec<u8>,
44}
45
46impl HalfVector {
47 #[must_use]
49 pub fn dim(&self) -> usize {
50 self.bytes.len() / 2
51 }
52
53 #[must_use]
56 pub fn from_f32_slice(v: &[f32]) -> Self {
57 let mut bytes = Vec::with_capacity(v.len() * 2);
58 for &x in v {
59 let bits = f16_from_f32_bits(x.to_bits());
60 bytes.extend_from_slice(&bits.to_le_bytes());
61 }
62 Self { bytes }
63 }
64
65 #[must_use]
69 pub fn to_f32_vec(&self) -> Vec<f32> {
70 let mut out = Vec::with_capacity(self.dim());
71 let mut i = 0;
72 while i + 2 <= self.bytes.len() {
73 let bits = u16::from_le_bytes([self.bytes[i], self.bytes[i + 1]]);
74 out.push(f32::from_bits(f16_to_f32_bits(bits)));
75 i += 2;
76 }
77 out
78 }
79}
80
/// Converts the bit pattern of an IEEE 754 binary32 (`f32`) value to the bit
/// pattern of a binary16 (half-precision) value.
///
/// Finite values are rounded to nearest, ties to even, including into the
/// binary16 subnormal range (`mant16 * 2^-24`):
///
/// * magnitudes that round above 65504 (the largest finite binary16) become ±inf;
/// * magnitudes at or below 2^-25 (half the smallest subnormal) become ±0;
/// * `f32` subnormals always become ±0.
///
/// ±inf maps to ±inf. NaN maps to a NaN with the same sign, the top ten
/// payload bits, and the quiet bit (`0x200`) forced on.
#[must_use]
pub fn f16_from_f32_bits(bits: u32) -> u16 {
    let sign = ((bits >> 31) & 0x1) as u16;
    let exp32 = (bits >> 23) & 0xff;
    let mant32 = bits & 0x7f_ffff;

    if exp32 == 0xff {
        if mant32 == 0 {
            return (sign << 15) | 0x7c00;
        }
        // Forcing the quiet bit also guarantees a non-zero mantissa, so the
        // NaN cannot collapse into an infinity when low payload bits are dropped.
        let mant16 = ((mant32 >> 13) | 0x200) as u16;
        return (sign << 15) | 0x7c00 | mant16;
    }

    if exp32 == 0 {
        // f32 zero or subnormal: |x| < 2^-126, far below 2^-25, so it rounds
        // to signed zero.
        return sign << 15;
    }

    let exp_unbiased: i32 = exp32 as i32 - 127;

    if exp_unbiased > 15 {
        // |x| >= 2^16 is outside the binary16 range.
        return (sign << 15) | 0x7c00;
    }

    if exp_unbiased < -14 {
        // Below the smallest normal (2^-14): encode as a subnormal.
        if exp_unbiased < -25 {
            // |x| < 2^-25, strictly less than half of the smallest subnormal
            // (2^-24), so it rounds to zero.
            return sign << 15;
        }
        // x = mant_with_lead * 2^(exp_unbiased - 23) and a subnormal encodes
        // mant16 * 2^-24, so mant16 = mant_with_lead / 2^drop_bits with
        // drop_bits = -1 - exp_unbiased, which lies in 14..=24 here.
        let mant_with_lead = mant32 | 0x80_0000;
        let drop_bits = (-1 - exp_unbiased) as u32;
        let mant16_pre = mant_with_lead >> drop_bits;
        let half = 1u32 << (drop_bits - 1);
        let mask = (1u32 << drop_bits) - 1;
        let dropped = mant_with_lead & mask;
        let round_up = dropped > half || (dropped == half && (mant16_pre & 1) == 1);
        // Rounding 0x3ff up yields 0x400, the encoding of the smallest normal,
        // which is exactly the correctly rounded result.
        let mant16 = mant16_pre + u32::from(round_up);
        return (sign << 15) | (mant16 as u16);
    }

    let exp16 = (exp_unbiased + 15) as u16;
    let mant16_pre = mant32 >> 13;
    let drop_mask = 0x1fffu32;
    let half = 0x1000u32;
    let dropped = mant32 & drop_mask;
    let round_up = dropped > half || (dropped == half && (mant16_pre & 1) == 1);
    let mant16 = mant16_pre + u32::from(round_up);
    // Adding (rather than OR-ing) lets a mantissa carry bump the exponent.
    let packed = (u32::from(exp16) << 10) + mant16;
    if packed >= 0x7c00 {
        // Rounded past 65504 into the exponent-31 encoding: saturate to infinity.
        return (sign << 15) | 0x7c00;
    }
    let packed_u16 = packed as u16;
    (sign << 15) | packed_u16
}
166
/// Widens the bit pattern of an IEEE 754 binary16 value to the bit pattern of
/// the binary32 value with the same numeric value.
///
/// Every binary16 value (subnormals included) is representable in binary32, so
/// this is lossless. Signed zeros and infinities keep their sign, and NaN
/// payload bits are shifted into the top of the `f32` mantissa.
#[must_use]
pub fn f16_to_f32_bits(bits: u16) -> u32 {
    let sign = u32::from(bits >> 15) << 31;
    let exponent = u32::from(bits >> 10) & 0x1f;
    let fraction = u32::from(bits) & 0x3ff;

    match (exponent, fraction) {
        // Infinity (fraction == 0) or NaN (payload carried over).
        (0x1f, _) => sign | 0x7f80_0000 | (fraction << 13),
        // Signed zero.
        (0, 0) => sign,
        // Subnormal: shift the fraction left until its leading one reaches
        // bit 10 (the implicit-one position). Each shift lowers the exponent
        // by one from the subnormal exponent -14 (biased: 127 - 14 = 113).
        (0, _) => {
            let shift = fraction.leading_zeros() - 21;
            let exp32 = 113 - shift;
            let normalized = (fraction << shift) & 0x3ff;
            sign | (exp32 << 23) | (normalized << 13)
        }
        // Normal: rebias the exponent from 15 to 127 (difference 112).
        _ => sign | ((exponent + 112) << 23) | (fraction << 13),
    }
}
209
210#[must_use]
237pub fn half_l2_distance_sq_asymmetric(a: &HalfVector, q: &[f32]) -> f32 {
238 if a.dim() != q.len() {
239 return f32::INFINITY;
240 }
241 #[cfg(target_arch = "aarch64")]
242 {
243 let n = a.dim();
244 if n >= 8 && n.is_multiple_of(8) {
245 return unsafe { half_l2_distance_sq_asymmetric_neon(a, q) };
248 }
249 }
250 half_l2_distance_sq_asymmetric_scalar(a, q)
251}
252
253#[must_use]
255pub fn half_inner_product_asymmetric(a: &HalfVector, q: &[f32]) -> f32 {
256 if a.dim() != q.len() {
257 return f32::INFINITY;
258 }
259 #[cfg(target_arch = "aarch64")]
260 {
261 let n = a.dim();
262 if n >= 8 && n.is_multiple_of(8) {
263 return -unsafe { half_dot_asymmetric_neon(a, q) };
265 }
266 }
267 -half_dot_asymmetric_scalar(a, q)
268}
269
270#[must_use]
273pub fn half_cosine_distance_asymmetric(a: &HalfVector, q: &[f32]) -> f32 {
274 if a.dim() != q.len() {
275 return f32::INFINITY;
276 }
277 let (dot, na, nq);
278 #[cfg(target_arch = "aarch64")]
279 {
280 let n = a.dim();
281 if n >= 8 && n.is_multiple_of(8) {
282 let (d, a2, q2) = unsafe { half_cosine_accumulators_asymmetric_neon(a, q) };
284 dot = d;
285 na = a2;
286 nq = q2;
287 } else {
288 let (d, a2, q2) = half_cosine_accumulators_asymmetric_scalar(a, q);
289 dot = d;
290 na = a2;
291 nq = q2;
292 }
293 }
294 #[cfg(not(target_arch = "aarch64"))]
295 {
296 let (d, a2, q2) = half_cosine_accumulators_asymmetric_scalar(a, q);
297 dot = d;
298 na = a2;
299 nq = q2;
300 }
301 if na == 0.0 || nq == 0.0 {
302 return f32::INFINITY;
303 }
304 1.0 - dot / (sqrt_finite(na) * sqrt_finite(nq))
305}
306
307#[must_use]
309pub fn half_l2_distance_sq(a: &HalfVector, b: &HalfVector) -> f32 {
310 if a.dim() != b.dim() {
311 return f32::INFINITY;
312 }
313 #[cfg(target_arch = "aarch64")]
314 {
315 let n = a.dim();
316 if n >= 8 && n.is_multiple_of(8) {
317 return unsafe { half_l2_distance_sq_symmetric_neon(a, b) };
319 }
320 }
321 half_l2_distance_sq_symmetric_scalar(a, b)
322}
323
324fn half_l2_distance_sq_asymmetric_scalar(a: &HalfVector, q: &[f32]) -> f32 {
328 let mut acc: f32 = 0.0;
329 let mut i = 0usize;
330 while i + 2 <= a.bytes.len() {
331 let bits = u16::from_le_bytes([a.bytes[i], a.bytes[i + 1]]);
332 let xa = f32::from_bits(f16_to_f32_bits(bits));
333 let d = xa - q[i / 2];
334 acc += d * d;
335 i += 2;
336 }
337 acc
338}
339
340fn half_dot_asymmetric_scalar(a: &HalfVector, q: &[f32]) -> f32 {
341 let mut dot: f32 = 0.0;
342 let mut i = 0usize;
343 while i + 2 <= a.bytes.len() {
344 let bits = u16::from_le_bytes([a.bytes[i], a.bytes[i + 1]]);
345 let xa = f32::from_bits(f16_to_f32_bits(bits));
346 dot += xa * q[i / 2];
347 i += 2;
348 }
349 dot
350}
351
352fn half_cosine_accumulators_asymmetric_scalar(a: &HalfVector, q: &[f32]) -> (f32, f32, f32) {
353 let (mut dot, mut na, mut nq) = (0.0_f32, 0.0_f32, 0.0_f32);
354 let mut i = 0usize;
355 while i + 2 <= a.bytes.len() {
356 let bits = u16::from_le_bytes([a.bytes[i], a.bytes[i + 1]]);
357 let xa = f32::from_bits(f16_to_f32_bits(bits));
358 let qx = q[i / 2];
359 dot += xa * qx;
360 na += xa * xa;
361 nq += qx * qx;
362 i += 2;
363 }
364 (dot, na, nq)
365}
366
367fn half_l2_distance_sq_symmetric_scalar(a: &HalfVector, b: &HalfVector) -> f32 {
368 let mut acc: f32 = 0.0;
369 let mut i = 0usize;
370 while i + 2 <= a.bytes.len() {
371 let av = u16::from_le_bytes([a.bytes[i], a.bytes[i + 1]]);
372 let bv = u16::from_le_bytes([b.bytes[i], b.bytes[i + 1]]);
373 let xa = f32::from_bits(f16_to_f32_bits(av));
374 let xb = f32::from_bits(f16_to_f32_bits(bv));
375 let d = xa - xb;
376 acc += d * d;
377 i += 2;
378 }
379 acc
380}
381
/// Square root that needs only `core`, for builds where `f32::sqrt` is
/// unavailable.
///
/// Returns `0.0` for zero and negative inputs and passes NaN and `+inf`
/// through unchanged. For every other input it returns `sqrt(x)` to within
/// about one ulp.
fn sqrt_finite(x: f32) -> f32 {
    if x <= 0.0 {
        return 0.0;
    }
    if !x.is_finite() {
        // NaN or +inf: a Newton step on +inf would produce inf / inf = NaN.
        return x;
    }
    if x < f32::MIN_POSITIVE {
        // Subnormal: rescale by 2^24 (exact) so the exponent-based seed below
        // is accurate, then undo it with the exact factor 2^-12.
        return sqrt_finite(x * 16_777_216.0) * (1.0 / 4096.0);
    }
    // Halving the biased exponent (shift right, then add back half the bias,
    // 0x3f80_0000 / 2) gives a seed within about 6% of the true root.
    let mut y = f32::from_bits((x.to_bits() >> 1) + 0x1fc0_0000);
    // Newton-Raphson converges quadratically from that seed; four steps
    // reach full f32 precision.
    for _ in 0..4 {
        y = 0.5 * (y + x / y);
    }
    y
}
392
/// Decodes 8 little-endian binary16 lanes into two `f32x4` vectors (lanes 0-3
/// and lanes 4-7), matching the scalar `f16_to_f32_bits` bit for bit.
///
/// Normals are rebiased, and infinities and NaNs keep their payload. Zeros and
/// subnormals decode exactly as `mant * 2^-24` with the sign re-applied.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(clippy::many_single_char_names)]
#[inline]
unsafe fn half_to_f32x8_neon(
    h: core::arch::aarch64::uint16x8_t,
) -> [core::arch::aarch64::float32x4_t; 2] {
    use core::arch::aarch64::{
        vaddq_u32, vandq_u32, vbslq_u32, vceqq_u32, vcvtq_f32_u32, vdupq_n_u32, vget_high_u16,
        vget_low_u16, vmovl_u16, vmulq_n_f32, vorrq_u32, vreinterpretq_f32_u32,
        vreinterpretq_u32_f32, vshlq_n_u32, vshrq_n_u32,
    };
    // Zero-extend each half of the 8 lanes to u32 so the bit fields can be
    // re-packed in place.
    let lo = vmovl_u16(vget_low_u16(h));
    let hi = vmovl_u16(vget_high_u16(h));

    let convert = |w: core::arch::aarch64::uint32x4_t| -> core::arch::aarch64::float32x4_t {
        let sign = vshlq_n_u32::<16>(vandq_u32(w, vdupq_n_u32(0x8000)));
        let mant = vandq_u32(w, vdupq_n_u32(0x3ff));
        let exp = vandq_u32(vshrq_n_u32::<10>(w), vdupq_n_u32(0x1f));
        let mant_f32 = vshlq_n_u32::<13>(mant);
        // Rebias the exponent: 127 (f32 bias) - 15 (f16 bias) = 112.
        let exp_plus_bias = vaddq_u32(exp, vdupq_n_u32(112));
        let exp_f32_shifted = vshlq_n_u32::<23>(exp_plus_bias);
        let normal = vorrq_u32(vorrq_u32(sign, exp_f32_shifted), mant_f32);
        let inf_nan = vorrq_u32(vorrq_u32(sign, vdupq_n_u32(0x7f80_0000)), mant_f32);
        // Zero or subnormal: value = mant * 2^-24. The result is exact because
        // mant <= 1023 fits the f32 mantissa and the scale is a power of two.
        // mant == 0 gives +0.0, so OR-ing in the sign yields a signed zero.
        let tiny = vreinterpretq_u32_f32(vmulq_n_f32(vcvtq_f32_u32(mant), 1.0 / 16_777_216.0));
        let tiny = vorrq_u32(sign, tiny);
        let is_inf_nan = vceqq_u32(exp, vdupq_n_u32(0x1f));
        let is_zero_or_subnormal = vceqq_u32(exp, vdupq_n_u32(0));
        let result = vbslq_u32(is_inf_nan, inf_nan, normal);
        let result = vbslq_u32(is_zero_or_subnormal, tiny, result);
        vreinterpretq_f32_u32(result)
    };

    [convert(lo), convert(hi)]
}
436
/// NEON squared-L2 kernel: sums `(decode(a[k]) - q[k])^2` over full groups of
/// 8 components.
///
/// Each iteration decodes 8 f16 lanes and accumulates lanes 0-3 into `acc0`
/// and lanes 4-7 into `acc1` with fused multiply-add. The two accumulators are
/// added and reduced horizontally at the end. The loop stops at the last full
/// group of 8, so any `a.dim() % 8` tail is not included in the result.
///
/// # Safety
///
/// The NEON target feature must be available. `q` must hold at least
/// `a.dim()` rounded down to a multiple of 8 elements, because `q` is read
/// through raw pointers without bounds checks.
/// NOTE(review): the byte-to-u16 reinterpretation assumes a little-endian
/// target — confirm if big-endian aarch64 matters.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(clippy::many_single_char_names)]
unsafe fn half_l2_distance_sq_asymmetric_neon(a: &HalfVector, q: &[f32]) -> f32 {
    use core::arch::aarch64::{
        float32x4_t, vaddq_f32, vaddvq_f32, vdupq_n_f32, vfmaq_f32, vld1q_f32, vld1q_u8,
        vreinterpretq_u16_u8, vsubq_f32,
    };
    // SAFETY: `i + 8 <= n == a.bytes.len() / 2` keeps the 16-byte load from
    // `a.bytes` in bounds; the reads of `q[i..i + 8]` rely on the caller
    // contract above.
    unsafe {
        let zero: float32x4_t = vdupq_n_f32(0.0);
        let mut acc0 = zero;
        let mut acc1 = zero;
        let n = a.dim();
        let mut i = 0usize;
        while i + 8 <= n {
            // 16 bytes = 8 packed f16 lanes starting at component `i`.
            let h = vreinterpretq_u16_u8(vld1q_u8(a.bytes.as_ptr().add(i * 2)));
            let [xa0, xa1] = half_to_f32x8_neon(h);
            let q0 = vld1q_f32(q.as_ptr().add(i));
            let q1 = vld1q_f32(q.as_ptr().add(i + 4));
            let d0 = vsubq_f32(xa0, q0);
            let d1 = vsubq_f32(xa1, q1);
            acc0 = vfmaq_f32(acc0, d0, d0);
            acc1 = vfmaq_f32(acc1, d1, d1);
            i += 8;
        }
        vaddvq_f32(vaddq_f32(acc0, acc1))
    }
}
471
/// NEON dot-product kernel: sums `decode(a[k]) * q[k]` over full groups of 8
/// components.
///
/// Lanes 0-3 accumulate into `acc0` and lanes 4-7 into `acc1` with fused
/// multiply-add. The accumulators are added and reduced horizontally at the
/// end. Any `a.dim() % 8` tail is not included in the result.
///
/// # Safety
///
/// The NEON target feature must be available. `q` must hold at least
/// `a.dim()` rounded down to a multiple of 8 elements, because `q` is read
/// through raw pointers without bounds checks.
/// NOTE(review): the byte-to-u16 reinterpretation assumes a little-endian
/// target — confirm if big-endian aarch64 matters.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(clippy::many_single_char_names)]
unsafe fn half_dot_asymmetric_neon(a: &HalfVector, q: &[f32]) -> f32 {
    use core::arch::aarch64::{
        float32x4_t, vaddq_f32, vaddvq_f32, vdupq_n_f32, vfmaq_f32, vld1q_f32, vld1q_u8,
        vreinterpretq_u16_u8,
    };
    // SAFETY: `i + 8 <= n == a.bytes.len() / 2` keeps the 16-byte load from
    // `a.bytes` in bounds; the reads of `q[i..i + 8]` rely on the caller
    // contract above.
    unsafe {
        let zero: float32x4_t = vdupq_n_f32(0.0);
        let mut acc0 = zero;
        let mut acc1 = zero;
        let n = a.dim();
        let mut i = 0usize;
        while i + 8 <= n {
            // 16 bytes = 8 packed f16 lanes starting at component `i`.
            let h = vreinterpretq_u16_u8(vld1q_u8(a.bytes.as_ptr().add(i * 2)));
            let [xa0, xa1] = half_to_f32x8_neon(h);
            acc0 = vfmaq_f32(acc0, xa0, vld1q_f32(q.as_ptr().add(i)));
            acc1 = vfmaq_f32(acc1, xa1, vld1q_f32(q.as_ptr().add(i + 4)));
            i += 8;
        }
        vaddvq_f32(vaddq_f32(acc0, acc1))
    }
}
501
/// NEON cosine accumulators: returns `(dot(a, q), |a|^2, |q|^2)` over full
/// groups of 8 components.
///
/// Each quantity has a single `f32x4` accumulator fed by two fused
/// multiply-adds per iteration (lanes 0-3, then lanes 4-7). Each accumulator is
/// reduced horizontally at the end. Any `a.dim() % 8` tail is not included.
///
/// # Safety
///
/// The NEON target feature must be available. `q` must hold at least
/// `a.dim()` rounded down to a multiple of 8 elements, because `q` is read
/// through raw pointers without bounds checks.
/// NOTE(review): the byte-to-u16 reinterpretation assumes a little-endian
/// target — confirm if big-endian aarch64 matters.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(clippy::many_single_char_names, clippy::similar_names)]
unsafe fn half_cosine_accumulators_asymmetric_neon(a: &HalfVector, q: &[f32]) -> (f32, f32, f32) {
    use core::arch::aarch64::{
        float32x4_t, vaddvq_f32, vdupq_n_f32, vfmaq_f32, vld1q_f32, vld1q_u8, vreinterpretq_u16_u8,
    };
    // SAFETY: `i + 8 <= n == a.bytes.len() / 2` keeps the 16-byte load from
    // `a.bytes` in bounds; the reads of `q[i..i + 8]` rely on the caller
    // contract above.
    unsafe {
        let zero: float32x4_t = vdupq_n_f32(0.0);
        let mut acc_dot = zero;
        let mut acc_na = zero;
        let mut acc_nq = zero;
        let n = a.dim();
        let mut i = 0usize;
        while i + 8 <= n {
            // 16 bytes = 8 packed f16 lanes starting at component `i`.
            let h = vreinterpretq_u16_u8(vld1q_u8(a.bytes.as_ptr().add(i * 2)));
            let [xa0, xa1] = half_to_f32x8_neon(h);
            let q0 = vld1q_f32(q.as_ptr().add(i));
            let q1 = vld1q_f32(q.as_ptr().add(i + 4));
            acc_dot = vfmaq_f32(acc_dot, xa0, q0);
            acc_dot = vfmaq_f32(acc_dot, xa1, q1);
            acc_na = vfmaq_f32(acc_na, xa0, xa0);
            acc_na = vfmaq_f32(acc_na, xa1, xa1);
            acc_nq = vfmaq_f32(acc_nq, q0, q0);
            acc_nq = vfmaq_f32(acc_nq, q1, q1);
            i += 8;
        }
        (vaddvq_f32(acc_dot), vaddvq_f32(acc_na), vaddvq_f32(acc_nq))
    }
}
537
/// NEON squared-L2 kernel between two half-precision vectors over full groups
/// of 8 components.
///
/// Both inputs are decoded 8 lanes at a time. Lanes 0-3 accumulate into
/// `acc0` and lanes 4-7 into `acc1` with fused multiply-add, then everything is
/// reduced horizontally. The loop bound comes from `a.dim()`, and any
/// `a.dim() % 8` tail is not included.
///
/// # Safety
///
/// The NEON target feature must be available. `b.bytes` must hold at least
/// `2 * (a.dim() rounded down to a multiple of 8)` bytes, because it is read
/// through raw pointers without bounds checks.
/// NOTE(review): the byte-to-u16 reinterpretation assumes a little-endian
/// target — confirm if big-endian aarch64 matters.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(clippy::many_single_char_names)]
unsafe fn half_l2_distance_sq_symmetric_neon(a: &HalfVector, b: &HalfVector) -> f32 {
    use core::arch::aarch64::{
        float32x4_t, vaddq_f32, vaddvq_f32, vdupq_n_f32, vfmaq_f32, vld1q_u8, vreinterpretq_u16_u8,
        vsubq_f32,
    };
    // SAFETY: `i + 8 <= n == a.bytes.len() / 2` keeps the load from `a.bytes`
    // in bounds; the same-offset load from `b.bytes` relies on the caller
    // contract above.
    unsafe {
        let zero: float32x4_t = vdupq_n_f32(0.0);
        let mut acc0 = zero;
        let mut acc1 = zero;
        let n = a.dim();
        let mut i = 0usize;
        while i + 8 <= n {
            // 16 bytes = 8 packed f16 lanes starting at component `i`.
            let ha = vreinterpretq_u16_u8(vld1q_u8(a.bytes.as_ptr().add(i * 2)));
            let hb = vreinterpretq_u16_u8(vld1q_u8(b.bytes.as_ptr().add(i * 2)));
            let [xa0, xa1] = half_to_f32x8_neon(ha);
            let [xb0, xb1] = half_to_f32x8_neon(hb);
            let d0 = vsubq_f32(xa0, xb0);
            let d1 = vsubq_f32(xa1, xb1);
            acc0 = vfmaq_f32(acc0, d0, d0);
            acc1 = vfmaq_f32(acc1, d1, d1);
            i += 8;
        }
        vaddvq_f32(vaddq_f32(acc0, acc1))
    }
}
566
#[cfg(test)]
#[allow(
    clippy::float_cmp,
    clippy::approx_constant,
    clippy::suboptimal_flops,
    clippy::unreadable_literal
)]
mod tests {
    use super::*;

    /// Bitwise float equality that treats any two NaNs as equal.
    fn f32_eq_bits(a: f32, b: f32) -> bool {
        if a.is_nan() && b.is_nan() {
            return true;
        }
        a.to_bits() == b.to_bits()
    }

    #[test]
    fn f16_roundtrip_representable_values() {
        // Every value here is exactly representable in binary16.
        let cases: &[f32] = &[
            0.0,
            -0.0,
            1.0,
            -1.0,
            0.5,
            -0.5,
            0.25,
            2.0,
            4.0,
            1.5,
            -1.5,
            65504.0,
            -65504.0,
            1.0 / 16384.0,
        ];
        for &x in cases {
            let bits = f16_from_f32_bits(x.to_bits());
            let y = f32::from_bits(f16_to_f32_bits(bits));
            assert!(f32_eq_bits(x, y), "expected {x} == {y} (bits {bits:#x})");
        }
    }

    #[test]
    fn f16_roundtrip_inf_and_nan() {
        let inf = f32::INFINITY;
        let neg_inf = f32::NEG_INFINITY;
        assert_eq!(
            f16_to_f32_bits(f16_from_f32_bits(inf.to_bits())),
            inf.to_bits()
        );
        assert_eq!(
            f16_to_f32_bits(f16_from_f32_bits(neg_inf.to_bits())),
            neg_inf.to_bits()
        );
        let nan = f32::NAN;
        let nan_back = f32::from_bits(f16_to_f32_bits(f16_from_f32_bits(nan.to_bits())));
        assert!(nan_back.is_nan(), "NaN should round-trip as NaN");
    }

    #[test]
    fn f16_overflow_saturates_to_inf() {
        let huge = 1e30_f32;
        let half_bits = f16_from_f32_bits(huge.to_bits());
        assert_eq!(half_bits, 0x7c00, "huge positive → +∞");
        let half_back = f32::from_bits(f16_to_f32_bits(half_bits));
        assert_eq!(half_back, f32::INFINITY);
    }

    #[test]
    fn f16_underflow_flushes_to_zero() {
        let tiny = 1.0e-30_f32;
        let half_bits = f16_from_f32_bits(tiny.to_bits());
        assert_eq!(
            half_bits & 0x7fff,
            0,
            "tiny positive → +0 (got {half_bits:#x})"
        );
    }

    #[test]
    fn f16_to_f32_decodes_subnormals() {
        // Smallest positive binary16 subnormal is 2^-24.
        assert_eq!(f16_to_f32_bits(0x0001), 0x3380_0000);
        // Largest binary16 subnormal is 1023 * 2^-24.
        assert_eq!(f16_to_f32_bits(0x03ff), 0x387f_c000);
        // -2^-15, with the sign preserved.
        assert_eq!(f16_to_f32_bits(0x8200), 0xb800_0000);
    }

    #[test]
    fn f16_codec_roundtrip_finite_normals_bounded_error() {
        // binary16 has an 11-bit significand, so the relative error of a
        // round trip through normal values stays below 2^-11 < 1e-3.
        let cases: &[f32] = &[
            0.1,
            0.333,
            1.0 / 7.0,
            3.14159,
            100.0,
            12345.0,
            -0.1,
            -3.14159,
        ];
        for &x in cases {
            let bits = f16_from_f32_bits(x.to_bits());
            let y = f32::from_bits(f16_to_f32_bits(bits));
            let rel = (x - y).abs() / x.abs();
            assert!(rel < 1e-3, "x={x} y={y} rel_err={rel} (bits {bits:#x})");
        }
    }

    #[test]
    fn half_vector_from_to_f32_slice() {
        let v = alloc::vec![0.0_f32, 0.25, 0.5, 1.0, -1.0];
        let h = HalfVector::from_f32_slice(&v);
        assert_eq!(h.dim(), 5);
        let back = h.to_f32_vec();
        assert_eq!(back, v);
    }

    #[test]
    fn half_vector_empty() {
        let h = HalfVector::from_f32_slice(&[]);
        assert_eq!(h.dim(), 0);
        assert!(h.bytes.is_empty());
        let back = h.to_f32_vec();
        assert!(back.is_empty());
    }

    /// Deterministic pseudo-random vector for kernel-parity tests.
    ///
    /// Despite the name, components are uniform in `[-1, 1)` (not Gaussian).
    /// A 64-bit LCG supplies 24 random bits per component.
    #[allow(clippy::cast_precision_loss)]
    fn random_normal_vec(seed: u64, dim: usize) -> alloc::vec::Vec<f32> {
        let mut state = seed | 1;
        let mut out = alloc::vec::Vec::with_capacity(dim);
        for _ in 0..dim {
            state = state
                .wrapping_mul(6_364_136_223_846_793_005)
                .wrapping_add(1);
            // 24 random bits divided by 2^24 give u in [0, 1), exactly
            // representable in f32, so 2u - 1 lands in [-1, 1).
            let u = ((state >> 32) & 0x00FF_FFFF) as f32 / (0x100_0000_u32 as f32);
            out.push(2.0 * u - 1.0);
        }
        out
    }

    #[cfg(target_arch = "aarch64")]
    #[test]
    #[allow(clippy::cast_precision_loss)]
    fn half_l2_asymmetric_neon_matches_scalar() {
        for &d in &[8_usize, 16, 32, 64, 128, 256, 512, 1024] {
            for trial in 0..8_u64 {
                let v = random_normal_vec(0xA5A5_F160_F160_0001 ^ trial ^ (d as u64), d);
                let q = random_normal_vec(0xC0FE_F160_F160_0002 ^ trial ^ (d as u64), d);
                let h = HalfVector::from_f32_slice(&v);
                let scalar = half_l2_distance_sq_asymmetric_scalar(&h, &q);
                let neon = unsafe { half_l2_distance_sq_asymmetric_neon(&h, &q) };
                let tol = (scalar.abs().max(1e-6)) * 1e-4 + (d as f32) * 1e-5;
                assert!(
                    (scalar - neon).abs() <= tol,
                    "L2 asym dim={d} trial={trial}: scalar={scalar} neon={neon}"
                );
            }
        }
    }

    #[cfg(target_arch = "aarch64")]
    #[test]
    #[allow(clippy::cast_precision_loss)]
    fn half_dot_asymmetric_neon_matches_scalar() {
        for &d in &[8_usize, 16, 32, 64, 128, 256, 512, 1024] {
            for trial in 0..8_u64 {
                let v = random_normal_vec(0xBEEF_F160_F160_0003 ^ trial ^ (d as u64), d);
                let q = random_normal_vec(0xDEAD_F160_F160_0004 ^ trial ^ (d as u64), d);
                let h = HalfVector::from_f32_slice(&v);
                let scalar = half_dot_asymmetric_scalar(&h, &q);
                let neon = unsafe { half_dot_asymmetric_neon(&h, &q) };
                let tol = (scalar.abs().max(1e-6)) * 1e-4 + (d as f32) * 1e-5;
                assert!(
                    (scalar - neon).abs() <= tol,
                    "dot dim={d} trial={trial}: scalar={scalar} neon={neon}"
                );
            }
        }
    }

    #[cfg(target_arch = "aarch64")]
    #[test]
    #[allow(clippy::similar_names, clippy::cast_precision_loss)]
    fn half_cosine_accumulators_neon_matches_scalar() {
        for &d in &[8_usize, 16, 32, 64, 128, 256, 512, 1024] {
            for trial in 0..8_u64 {
                let v = random_normal_vec(0xC051_F160_F160_0005 ^ trial ^ (d as u64), d);
                let q = random_normal_vec(0xF00D_F160_F160_0006 ^ trial ^ (d as u64), d);
                let h = HalfVector::from_f32_slice(&v);
                let (dot_s, na_s, nq_s) = half_cosine_accumulators_asymmetric_scalar(&h, &q);
                let (dot_n, na_n, nq_n) =
                    unsafe { half_cosine_accumulators_asymmetric_neon(&h, &q) };
                let tol = |x: f32| (x.abs().max(1e-6)) * 1e-4 + (d as f32) * 1e-5;
                assert!(
                    (dot_s - dot_n).abs() <= tol(dot_s),
                    "cos dot dim={d}: scalar={dot_s} neon={dot_n}"
                );
                assert!(
                    (na_s - na_n).abs() <= tol(na_s),
                    "cos na dim={d}: scalar={na_s} neon={na_n}"
                );
                assert!(
                    (nq_s - nq_n).abs() <= tol(nq_s),
                    "cos nq dim={d}: scalar={nq_s} neon={nq_n}"
                );
            }
        }
    }

    #[cfg(target_arch = "aarch64")]
    #[test]
    #[allow(clippy::cast_precision_loss)]
    fn half_l2_symmetric_neon_matches_scalar() {
        for &d in &[8_usize, 16, 32, 64, 128, 256, 512, 1024] {
            for trial in 0..8_u64 {
                let va = random_normal_vec(0x1234_F160_F160_0007 ^ trial ^ (d as u64), d);
                let vb = random_normal_vec(0x5678_F160_F160_0008 ^ trial ^ (d as u64), d);
                let ha = HalfVector::from_f32_slice(&va);
                let hb = HalfVector::from_f32_slice(&vb);
                let scalar = half_l2_distance_sq_symmetric_scalar(&ha, &hb);
                let neon = unsafe { half_l2_distance_sq_symmetric_neon(&ha, &hb) };
                let tol = (scalar.abs().max(1e-6)) * 1e-4 + (d as f32) * 1e-5;
                assert!(
                    (scalar - neon).abs() <= tol,
                    "L2 sym dim={d}: scalar={scalar} neon={neon}"
                );
            }
        }
    }
}