Skip to main content

simd_popcnt/
lib.rs

1//! # simd-popcnt
2//!
3//! Count the number of 1 bits (bit population count, a.k.a. Hamming weight) in
4//! an array as quickly as possible using specialized CPU instructions: POPCNT,
5//! AVX2 and AVX512 on x86/x86-64, and NEON and SVE on AArch64. The fastest
6//! instruction set the CPU supports is detected once at runtime and cached; on
7//! every other architecture the count falls back to [`u64::count_ones`], which
8//! the compiler lowers to a hardware popcount instruction wherever one exists.
9//!
10//! The crate is portable by default and thread-safe. It has no external crate
11//! dependencies and needs the Rust standard library only for runtime SIMD
12//! dispatch (CPU feature detection); it is otherwise `no_std`.
13//!
14//! This is an AI-assisted Rust port of the [libpopcnt C/C++ library](https://github.com/kimwalisch/libpopcnt).
15//!
16//! ## Usage
17//!
18//! [`popcnt`] counts the 1 bits in a byte slice; the [`PopcntExt`] trait adds a
19//! `.popcnt()` method to slices, arrays and `Vec`s of every built-in integer
20//! type.
21//!
22//! ```
23//! use simd_popcnt::{popcnt, PopcntExt};
24//!
25//! assert_eq!(popcnt(&[0xFF, 0x0F]), 12);
26//! assert_eq!([u64::MAX, 0x0F0F_0F0F_0F0F_0F0F].popcnt(), 96);
27//! ```
28//!
29//! ## Performance
30//!
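//! For the fastest possible code, compile with `RUSTFLAGS="-C target-cpu=native"`.
//! When the build machine's CPU supports AVX2 or AVX-512 VPOPCNTDQ (x86/x86-64),
//! or SVE on an AArch64 build where the SVE intrinsics are available, this
//! selects that SIMD path at compile time and removes the runtime dispatch
//! entirely; on other CPUs the runtime dispatch (or the scalar path) remains.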
34
35// Enable the SVE intrinsics only when the build probe confirmed they compile and
36// the SVE code is actually built (compile-time SVE path or the `std` dispatcher).
37#![cfg_attr(
38    all(simd_popcnt_have_sve, any(target_feature = "sve", feature = "std")),
39    feature(stdarch_aarch64_sve)
40)]
41// The crate uses `std` only for runtime CPU feature detection, which is compiled
42// only when the `std` feature is on and no SIMD path was already selected at
43// compile time. Whenever that code is absent — feature off, `-C target-cpu=native`,
44// or a non-x86/AArch64 target — the crate is `no_std`. `not(test)` keeps `std` for
45// the unit tests.
46#![cfg_attr(
47    not(any(
48        test,
49        all(
50            feature = "std",
51            any(target_arch = "x86", target_arch = "x86_64"),
52            not(any(target_feature = "avx2", target_feature = "avx512vpopcntdq")),
53        ),
54        all(
55            feature = "std",
56            target_arch = "aarch64",
57            simd_popcnt_have_sve,
58            not(target_feature = "sve"),
59        ),
60    )),
61    no_std
62)]
63
64#[cfg(target_arch = "aarch64")]
65use core::arch::aarch64::*;
66// A scalar-only `no_std` build uses none of these x86 intrinsics.
67#[cfg(target_arch = "x86")]
68#[allow(unused_imports)]
69use core::arch::x86::*;
70#[cfg(target_arch = "x86_64")]
71#[allow(unused_imports)]
72use core::arch::x86_64::*;
73#[cfg(all(
74    target_arch = "aarch64",
75    simd_popcnt_have_sve,
76    feature = "std",
77    not(target_feature = "sve")
78))]
79use std::arch::is_aarch64_feature_detected;
80
81/// Counts the number of one bits (population count) in `bytes`.
82///
83/// Dispatches to the fastest implementation for the running CPU: SIMD where
84/// available, a scalar fallback otherwise.
85///
86/// To count the bits in a slice of a wider integer type (`&[u64]`, `&[u32]`, …),
87/// use the [`PopcntExt::popcnt`] method rather than converting to bytes by hand.
88///
89/// # Examples
90///
91/// ```
92/// assert_eq!(simd_popcnt::popcnt(&[]), 0);
93/// assert_eq!(simd_popcnt::popcnt(&[0xFF, 0x0F]), 12);
94/// ```
95#[must_use]
96#[inline]
97pub fn popcnt(bytes: &[u8]) -> u64 {
98    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
99    {
100        popcnt_x86(bytes)
101    }
102
103    #[cfg(target_arch = "aarch64")]
104    {
105        popcnt_aarch64(bytes)
106    }
107
108    #[cfg(not(any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64")))]
109    {
110        popcnt_scalar(bytes)
111    }
112}
113
114// ────────────────────────────────────────────────────────────────────────────
115// Ergonomic extension trait for integer slices
116// ────────────────────────────────────────────────────────────────────────────
117
/// Extension trait that gives integer slices a
/// [`popcnt`](PopcntExt::popcnt) method, so their bits can be counted without
/// casting to bytes by hand.
///
/// The crate implements it for slices of `u8`/`u16`/`u32`/`u64`/`u128`/`usize`
/// and their signed counterparts. Arrays and `Vec`s of those types can call the
/// method too, because method-call syntax coerces them to a slice. Import it
/// with `use simd_popcnt::PopcntExt;`.
///
/// ```
/// use simd_popcnt::PopcntExt;
///
/// let words: &[u64] = &[u64::MAX, 0x0F0F_0F0F_0F0F_0F0F];
/// assert_eq!(words.popcnt(), 64 + 32);
/// assert_eq!(vec![1u32, 2, 3].popcnt(), 4);
/// ```
pub trait PopcntExt {
    /// Returns how many bits are set to 1, summed over every element.
    #[must_use]
    fn popcnt(&self) -> u64;
}
135
136/// Implement [`PopcntExt`] for `[$t]` by viewing the slice as bytes and
137/// delegating to [`popcnt`]. Population count is byte-order independent, so this
138/// is correct on both little- and big-endian targets.
139macro_rules! impl_popcnt_ext {
140    ($($t:ty),+ $(,)?) => {$(
141        impl PopcntExt for [$t] {
142            #[inline]
143            fn popcnt(&self) -> u64 {
144                // SAFETY: `$t` is a plain integer (no padding, every bit pattern
145                // valid) and `u8` is always 1-aligned, so the slice is a valid
146                // `&[u8]` of `size_of_val` bytes.
147                let bytes = unsafe {
148                    core::slice::from_raw_parts(
149                        self.as_ptr().cast::<u8>(),
150                        core::mem::size_of_val(self),
151                    )
152                };
153                popcnt(bytes)
154            }
155        }
156    )+};
157}
158
159impl_popcnt_ext!(
160    u8, u16, u32, u64, u128, usize, i8, i16, i32, i64, i128, isize
161);
162
163// ────────────────────────────────────────────────────────────────────────────
164// Portable scalar fallbacks (available on every architecture)
165// ────────────────────────────────────────────────────────────────────────────
166
/// Copies the trailing bytes `rem` (expected length 0..=7) into the front of a
/// zeroed 8-byte buffer and returns that buffer as a native-endian `u64`.
///
/// The zero padding adds no set bits, and because a population count ignores
/// byte order, native order is used to avoid a byte swap on big-endian targets.
///
/// Panics if `rem` is longer than 8 bytes.
#[inline]
fn tail_u64(rem: &[u8]) -> u64 {
    let mut padded = [0u8; 8];
    let (front, _zero_fill) = padded.split_at_mut(rem.len());
    front.copy_from_slice(rem);
    u64::from_ne_bytes(padded)
}
176
177/// Scalar population count loop, summing `count_ones()` over 8-byte chunks.
178/// `count_ones()` is inlined in release and lowers to the target's hardware
179/// popcount where available (x86 POPCNT, PowerPC popcntd, WebAssembly
180/// `i64.popcnt`, …), otherwise to an inline bit-twiddling sequence — never a
181/// library call. Shared with the POPCNT-`target_feature` variant below.
182macro_rules! popcnt_scalar_loop {
183    ($bytes:expr) => {{
184        let mut cnt = 0u64;
185        let (chunks, rem) = $bytes.as_chunks::<8>();
186        for chunk in chunks {
187            cnt += u64::from_ne_bytes(*chunk).count_ones() as u64;
188        }
189        if !rem.is_empty() {
190            cnt += tail_u64(rem).count_ones() as u64;
191        }
192        cnt
193    }};
194}
195
196/// Portable scalar population count via [`u64::count_ones`].
197#[allow(dead_code)]
198#[inline]
199fn popcnt_scalar(bytes: &[u8]) -> u64 {
200    popcnt_scalar_loop!(bytes)
201}
202
203// ════════════════════════════════════════════════════════════════════════════
204// x86 / x86-64
205// ════════════════════════════════════════════════════════════════════════════
206
207#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
208#[inline]
209fn popcnt_x86(bytes: &[u8]) -> u64 {
210    // Compile-time AVX512 path (e.g. with `-C target-cpu=native`).
211    #[cfg(target_feature = "avx512vpopcntdq")]
212    {
213        // AVX512 isn't worth its setup cost for tiny arrays.
214        if bytes.len() >= 40 {
215            unsafe { popcnt_avx512(bytes) }
216        } else {
217            popcnt_scalar_static(bytes)
218        }
219    }
220
221    // Compile-time AVX2 path.
222    #[cfg(all(target_feature = "avx2", not(target_feature = "avx512vpopcntdq")))]
223    {
224        let mut cnt = 0u64;
225        let mut rest = bytes;
226        // AVX2 only wins for arrays >= 512 bytes.
227        if bytes.len() >= 512 {
228            let n = bytes.len() / 32 * 32;
229            cnt += unsafe { popcnt_avx2(&bytes[..n]) };
230            rest = &bytes[n..];
231        }
232        cnt + popcnt_scalar_static(rest)
233    }
234
235    // No SIMD enabled at compile time: detect at runtime (needs `std`).
236    #[cfg(all(
237        not(any(target_feature = "avx2", target_feature = "avx512vpopcntdq")),
238        feature = "std"
239    ))]
240    {
241        popcnt_x86_runtime(bytes)
242    }
243
244    // No SIMD and no `std` for runtime detection: use the compile-time scalar path.
245    #[cfg(all(
246        not(any(target_feature = "avx2", target_feature = "avx512vpopcntdq")),
247        not(feature = "std")
248    ))]
249    {
250        popcnt_scalar_static(bytes)
251    }
252}
253
254/// Scalar count for the compile-time SIMD paths' small arrays and tails:
255/// hardware POPCNT when statically enabled, otherwise the integer fallback.
256#[cfg(all(
257    any(target_arch = "x86", target_arch = "x86_64"),
258    any(
259        target_feature = "avx2",
260        target_feature = "avx512vpopcntdq",
261        not(feature = "std")
262    )
263))]
264#[inline]
265fn popcnt_scalar_static(bytes: &[u8]) -> u64 {
266    #[cfg(target_feature = "popcnt")]
267    {
268        // SAFETY: `popcnt` is statically enabled for the whole crate.
269        unsafe { popcnt_scalar_hw(bytes) }
270    }
271    #[cfg(not(target_feature = "popcnt"))]
272    {
273        popcnt_scalar(bytes)
274    }
275}
276
/// Reports whether the running CPU supports AVX-512F, AVX-512BW and
/// AVX-512VPOPCNTDQ together, caching the combined answer.
///
/// The first call evaluates the (short-circuiting) `is_x86_feature_detected!`
/// checks and stores the outcome in a static atomic; later calls only load
/// that atomic. Accesses are `Relaxed` and unsynchronised, so threads racing
/// on the first call may each run the detection and store the result.
#[cfg(all(
    any(target_arch = "x86", target_arch = "x86_64"),
    not(any(target_feature = "avx2", target_feature = "avx512vpopcntdq")),
    feature = "std"
))]
#[inline]
fn has_avx512() -> bool {
    use core::sync::atomic::{AtomicI32, Ordering};

    // Cache states: `UNKNOWN` = not probed yet, 0 = unsupported, 1 = supported.
    const UNKNOWN: i32 = -1;
    static HAS_AVX512: AtomicI32 = AtomicI32::new(UNKNOWN);

    match HAS_AVX512.load(Ordering::Relaxed) {
        UNKNOWN => {
            let supported = is_x86_feature_detected!("avx512f")
                && is_x86_feature_detected!("avx512bw")
                && is_x86_feature_detected!("avx512vpopcntdq");
            HAS_AVX512.store(i32::from(supported), Ordering::Relaxed);
            supported
        }
        cached => cached != 0,
    }
}
299
/// Chooses an x86 implementation from the features the running CPU reports.
/// Compiled only when no SIMD feature is statically enabled; otherwise the
/// compile-time paths are used instead.
///
/// Order of preference: AVX-512 for inputs of at least 40 bytes, then AVX2 for
/// the leading whole 32-byte vectors of inputs of at least 512 bytes, and
/// finally a scalar count (hardware POPCNT if detected) for everything left.
#[cfg(all(
    any(target_arch = "x86", target_arch = "x86_64"),
    not(any(target_feature = "avx2", target_feature = "avx512vpopcntdq")),
    feature = "std"
))]
#[inline]
fn popcnt_x86_runtime(bytes: &[u8]) -> u64 {
    let len = bytes.len();

    // AVX512: not worth its setup cost below ~40 bytes, handles any length.
    if len >= 40 && has_avx512() {
        // SAFETY: `has_avx512()` confirmed avx512f, avx512bw and
        // avx512vpopcntdq on the running CPU.
        return unsafe { popcnt_avx512(bytes) };
    }

    // AVX2 only wins for arrays >= 512 bytes; it consumes whole 32-byte
    // vectors and leaves the remainder to the scalar code below.
    let (simd_count, rest) = if len >= 512 && is_x86_feature_detected!("avx2") {
        let (vectors, tail) = bytes.split_at(len - len % 32);
        // SAFETY: AVX2 support was just detected at runtime.
        (unsafe { popcnt_avx2(vectors) }, tail)
    } else {
        (0, bytes)
    };

    // Scalar tail (or the whole array if AVX2 didn't fire). Dispatching on
    // POPCNT is essential: outside a `#[target_feature(enable = "popcnt")]`
    // function, `count_ones()` compiles to a software fallback even on
    // POPCNT-capable CPUs.
    let scalar_count = if is_x86_feature_detected!("popcnt") {
        // SAFETY: POPCNT support was just detected at runtime.
        unsafe { popcnt_scalar_hw(rest) }
    } else {
        popcnt_scalar(rest)
    };

    simd_count + scalar_count
}
336
337/// Scalar population count via the hardware POPCNT instruction. The
338/// `#[target_feature(enable = "popcnt")]` attribute is what lets `count_ones()`
339/// lower to a single `popcnt`; only call it once POPCNT support is confirmed.
340#[allow(dead_code)]
341#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
342#[target_feature(enable = "popcnt")]
343#[inline]
344fn popcnt_scalar_hw(bytes: &[u8]) -> u64 {
345    popcnt_scalar_loop!(bytes) // count_ones() lowers to popcntq here
346}
347
348// ── AVX2 ────────────────────────────────────────────────────────────────────
349
/// Bitwise full adder (carry-save adder) over three 256-bit vectors.
///
/// For every bit position it adds the three input bits and returns
/// `(carry, sum)`: `sum` is the low bit of the result (`a ^ b ^ c`) and
/// `carry` is the high bit (set when at least two inputs are set).
#[cfg(all(
    any(target_arch = "x86", target_arch = "x86_64"),
    not(target_feature = "avx512vpopcntdq"),
    any(target_feature = "avx2", feature = "std")
))]
#[target_feature(enable = "avx2")]
#[inline]
fn csa256(a: __m256i, b: __m256i, c: __m256i) -> (__m256i, __m256i) {
    let a_xor_b = _mm256_xor_si256(a, b);
    let sum = _mm256_xor_si256(a_xor_b, c);
    let both_ab = _mm256_and_si256(a, b);
    let c_with_one_of_ab = _mm256_and_si256(a_xor_b, c);
    let carry = _mm256_or_si256(both_ab, c_with_one_of_ab);
    (carry, sum)
}
365
/// Population count of a 256-bit vector, returned as four `u64` lanes that
/// each hold the bit count of the corresponding 8 input bytes.
///
/// Each byte is split into nibbles that index two 16-entry tables via
/// `_mm256_shuffle_epi8`: the low-nibble table stores `4 + popcount` and the
/// high-nibble table stores `4 - popcount`. `_mm256_sad_epu8` then sums the
/// absolute differences over each group of 8 bytes, which equals the sum of
/// both nibble popcounts.
#[cfg(all(
    any(target_arch = "x86", target_arch = "x86_64"),
    not(target_feature = "avx512vpopcntdq"),
    any(target_feature = "avx2", feature = "std")
))]
#[target_feature(enable = "avx2")]
#[inline]
#[rustfmt::skip]
fn popcnt256(v: __m256i) -> __m256i {
    // `4 + popcount(nibble)`, repeated for both 128-bit halves.
    let four_plus = _mm256_setr_epi8(
        4, 5, 5, 6, 5, 6, 6, 7,
        5, 6, 6, 7, 6, 7, 7, 8,
        4, 5, 5, 6, 5, 6, 6, 7,
        5, 6, 6, 7, 6, 7, 7, 8,
    );
    // `4 - popcount(nibble)`, repeated for both 128-bit halves.
    let four_minus = _mm256_setr_epi8(
        4, 3, 3, 2, 3, 2, 2, 1,
        3, 2, 2, 1, 2, 1, 1, 0,
        4, 3, 3, 2, 3, 2, 2, 1,
        3, 2, 2, 1, 2, 1, 1, 0,
    );
    let nibble_mask = _mm256_set1_epi8(0x0f);

    let low_nibbles = _mm256_and_si256(v, nibble_mask);
    let high_nibbles = _mm256_and_si256(_mm256_srli_epi16(v, 4), nibble_mask);

    _mm256_sad_epu8(
        _mm256_shuffle_epi8(four_plus, low_nibbles),
        _mm256_shuffle_epi8(four_minus, high_nibbles),
    )
}
391
/// AVX2 Harley-Seal population count (4th iteration), from "Faster Population
/// Counts using AVX2 Instructions" by Lemire, Kurz and Muła (2016),
/// <https://arxiv.org/abs/1611.07612>.
///
/// Each 512-byte block (16 vectors) is compressed by a tree of carry-save
/// adders into running `ones`/`twos`/`fours`/`eights` bit-planes plus one
/// `sixteens` carry, so only one per-byte count is needed per block. The
/// planes are counted and weighted after the loop.
///
/// `bytes.len()` must be a multiple of 32; bytes beyond the last whole 32-byte
/// vector are not counted.
#[cfg(all(
    any(target_arch = "x86", target_arch = "x86_64"),
    not(target_feature = "avx512vpopcntdq"),
    any(target_feature = "avx2", feature = "std")
))]
#[target_feature(enable = "avx2")]
#[inline]
// Hand-aligned: keep the 16-way CSA tree readable.
#[rustfmt::skip]
fn popcnt_avx2(bytes: &[u8]) -> u64 {
    // Unaligned load of the `$k`-th 32-byte vector starting at `$base`.
    macro_rules! lane {
        ($base:ident, $k:literal) => { _mm256_loadu_si256($base.add($k)) };
    }

    let zero = _mm256_setzero_si256();
    let (mut ones, mut twos, mut fours, mut eights) = (zero, zero, zero, zero);
    // Sum of per-lane counts of every `sixteens` carry (weight 16 applied later).
    let mut acc = zero;

    // 16 vectors (512 bytes) per iteration.
    let (blocks, tail) = bytes.as_chunks::<512>();
    for block in blocks {
        let src = block.as_ptr().cast::<__m256i>();
        // SAFETY: `block` is 512 bytes, so all 16 unaligned 32-byte loads
        // (vector indices 0..=15) are in bounds.
        unsafe {
            let (pair_a, o)    = csa256(ones,   lane!(src, 0),  lane!(src, 1));
            let (pair_b, o)    = csa256(o,      lane!(src, 2),  lane!(src, 3));
            let (quad_a, t)    = csa256(twos,   pair_a,         pair_b);
            let (pair_a, o)    = csa256(o,      lane!(src, 4),  lane!(src, 5));
            let (pair_b, o)    = csa256(o,      lane!(src, 6),  lane!(src, 7));
            let (quad_b, t)    = csa256(t,      pair_a,         pair_b);
            let (oct_a, f)     = csa256(fours,  quad_a,         quad_b);
            let (pair_a, o)    = csa256(o,      lane!(src, 8),  lane!(src, 9));
            let (pair_b, o)    = csa256(o,      lane!(src, 10), lane!(src, 11));
            let (quad_a, t)    = csa256(t,      pair_a,         pair_b);
            let (pair_a, o)    = csa256(o,      lane!(src, 12), lane!(src, 13));
            let (pair_b, o)    = csa256(o,      lane!(src, 14), lane!(src, 15));
            let (quad_b, t)    = csa256(t,      pair_a,         pair_b);
            let (oct_b, f)     = csa256(f,      quad_a,         quad_b);
            let (sixteens, e)  = csa256(eights, oct_a,          oct_b);

            // Carry the bit-planes into the next block.
            ones = o;
            twos = t;
            fours = f;
            eights = e;
            acc = _mm256_add_epi64(acc, popcnt256(sixteens));
        }
    }

    // Apply each plane's weight: sixteens ×16, eights ×8, fours ×4, twos ×2, ones ×1.
    acc = _mm256_slli_epi64(acc, 4);
    acc = _mm256_add_epi64(acc, _mm256_slli_epi64(popcnt256(eights), 3));
    acc = _mm256_add_epi64(acc, _mm256_slli_epi64(popcnt256(fours), 2));
    acc = _mm256_add_epi64(acc, _mm256_slli_epi64(popcnt256(twos), 1));
    acc = _mm256_add_epi64(acc, popcnt256(ones));

    // Remaining whole 32-byte vectors are counted directly.
    let (vectors, _ignored) = tail.as_chunks::<32>();
    for vector in vectors {
        // SAFETY: `vector` is exactly 32 bytes, matching the unaligned load.
        let v = unsafe { _mm256_loadu_si256(vector.as_ptr().cast::<__m256i>()) };
        acc = _mm256_add_epi64(acc, popcnt256(v));
    }

    // Sum the four 64-bit lanes.
    // SAFETY: `__m256i` and `[u64; 4]` are both 32 bytes with no invalid bit patterns.
    let lanes: [u64; 4] = unsafe { core::mem::transmute(acc) };
    lanes.iter().sum()
}
464
465// ── AVX512 ──────────────────────────────────────────────────────────────────
466
/// AVX512-VPOPCNTDQ population count for a slice of any length.
///
/// Three stages share one accumulator: a 4×-unrolled loop over 256-byte
/// blocks, a loop over the remaining whole 64-byte vectors, and a single
/// masked load covering the final 1..=63 bytes.
#[cfg(all(
    any(target_arch = "x86", target_arch = "x86_64"),
    any(
        all(not(target_feature = "avx2"), feature = "std"),
        target_feature = "avx512vpopcntdq"
    )
))]
#[target_feature(enable = "avx512f,avx512bw,avx512vpopcntdq")]
#[inline]
fn popcnt_avx512(bytes: &[u8]) -> u64 {
    let zero = _mm512_setzero_si512();
    let (blocks, after_blocks) = bytes.as_chunks::<256>();

    // 4× unrolled 64-byte loop (256 bytes per iteration). Four independent
    // accumulators keep the popcount+add chains parallel (higher ILP).
    let mut total = if blocks.is_empty() {
        zero
    } else {
        let (mut acc_a, mut acc_b, mut acc_c, mut acc_d) = (zero, zero, zero, zero);
        for block in blocks {
            let base = block.as_ptr();
            // SAFETY: `block` is 256 bytes, so the four 64-byte unaligned loads
            // at offsets 0, 64, 128 and 192 are in bounds.
            let (v_a, v_b, v_c, v_d) = unsafe {
                (
                    _mm512_loadu_si512(base.add(0).cast()),
                    _mm512_loadu_si512(base.add(64).cast()),
                    _mm512_loadu_si512(base.add(128).cast()),
                    _mm512_loadu_si512(base.add(192).cast()),
                )
            };
            acc_a = _mm512_add_epi64(acc_a, _mm512_popcnt_epi64(v_a));
            acc_b = _mm512_add_epi64(acc_b, _mm512_popcnt_epi64(v_b));
            acc_c = _mm512_add_epi64(acc_c, _mm512_popcnt_epi64(v_c));
            acc_d = _mm512_add_epi64(acc_d, _mm512_popcnt_epi64(v_d));
        }
        _mm512_add_epi64(
            _mm512_add_epi64(acc_a, acc_b),
            _mm512_add_epi64(acc_c, acc_d),
        )
    };

    // Remaining complete 64-byte blocks.
    let (vectors, last) = after_blocks.as_chunks::<64>();
    for vector in vectors {
        // SAFETY: `vector` is exactly 64 bytes, matching the unaligned load.
        let v = unsafe { _mm512_loadu_si512(vector.as_ptr().cast()) };
        total = _mm512_add_epi64(total, _mm512_popcnt_epi64(v));
    }

    // Masked load for the final 1..=63 bytes.
    if !last.is_empty() {
        // One mask bit per valid byte: the low `last.len()` bits are set.
        let mask: __mmask64 = (1u64 << last.len()) - 1;
        // SAFETY: the mask selects only the `last.len()` valid bytes;
        // masked-off lanes are zero-filled and not accessed.
        let v = unsafe { _mm512_maskz_loadu_epi8(mask, last.as_ptr().cast()) };
        total = _mm512_add_epi64(total, _mm512_popcnt_epi64(v));
    }

    _mm512_reduce_add_epi64(total) as u64
}
529
530// ════════════════════════════════════════════════════════════════════════════
531// AArch64
532// ════════════════════════════════════════════════════════════════════════════
533
/// AArch64 entry point. Exactly one of the two `cfg` blocks is compiled:
/// NEON by default, or SVE directly when it is statically enabled and the
/// build provides the SVE intrinsics.
#[cfg(target_arch = "aarch64")]
#[inline]
fn popcnt_aarch64(bytes: &[u8]) -> u64 {
    // NEON baseline; `popcnt_neon` dispatches to SVE at runtime when available.
    #[cfg(not(all(target_feature = "sve", simd_popcnt_have_sve)))]
    {
        popcnt_neon(bytes)
    }

    // Compile-time SVE path.
    #[cfg(all(target_feature = "sve", simd_popcnt_have_sve))]
    {
        // SAFETY: `sve` is statically enabled for the whole build (enclosing `cfg`).
        unsafe { popcnt_arm_sve(bytes) }
    }
}
549
/// Widens the sixteen `u8` lanes of `t` by pairwise addition (u8 → u16 → u32)
/// and then pairwise-accumulates the result into the two `u64` lanes of `sum`.
/// Each `u64` lane of `sum` therefore grows by the total of 8 bytes of `t`.
#[cfg(all(
    target_arch = "aarch64",
    not(all(target_feature = "sve", simd_popcnt_have_sve))
))]
#[inline]
fn vpadalq(sum: uint64x2_t, t: uint8x16_t) -> uint64x2_t {
    // The NEON intrinsics operate on plain vector values only.
    unsafe {
        let halfwords = vpaddlq_u8(t);
        let words = vpaddlq_u16(halfwords);
        vpadalq_u32(sum, words)
    }
}
558
/// NEON population count, with an optional runtime hand-off to SVE.
///
/// When the build provides the SVE intrinsics and `std`, the running CPU is
/// queried first and SVE is used if present. Otherwise the input is processed
/// in 64-byte blocks: `vcntq_u8` counts bits per byte into four `u8`
/// accumulators, which are widened into a `u64` pair before they can overflow.
/// Bytes after the last whole block go through the shared scalar loop.
#[cfg(all(
    target_arch = "aarch64",
    not(all(target_feature = "sve", simd_popcnt_have_sve))
))]
#[inline]
fn popcnt_neon(bytes: &[u8]) -> u64 {
    #[cfg(all(simd_popcnt_have_sve, feature = "std"))]
    if is_aarch64_feature_detected!("sve") {
        // SAFETY: SVE support was just confirmed at runtime.
        return unsafe { popcnt_arm_sve(bytes) };
    }

    // A `u8` lane gains at most 8 per block, so at most 31 blocks may be
    // accumulated before draining: 31 × 8 bits = 248 ≤ 255, no lane overflow.
    const BLOCKS_PER_BATCH: usize = 31;

    let (blocks, rest) = bytes.as_chunks::<64>();
    let mut simd_count = 0u64;

    if !blocks.is_empty() {
        // SAFETY: every `block` is a 64-byte array, so each `vld1q_u8_x4`
        // reads exactly its 64 in-bounds bytes; the final store targets a
        // local two-element array.
        unsafe {
            let zero = vdupq_n_u8(0);
            let mut wide = vdupq_n_u64(0);

            for batch in blocks.chunks(BLOCKS_PER_BATCH) {
                let (mut t0, mut t1, mut t2, mut t3) = (zero, zero, zero, zero);

                for block in batch {
                    // Plain contiguous load (`vld1q_u8_x4`), not the deinterleaving
                    // `vld4q_u8`: population count is order-independent, so avoiding
                    // the deinterleave saves the `tbl`/`mov` shuffles it compiles to.
                    let quad = vld1q_u8_x4(block.as_ptr());
                    t0 = vaddq_u8(t0, vcntq_u8(quad.0));
                    t1 = vaddq_u8(t1, vcntq_u8(quad.1));
                    t2 = vaddq_u8(t2, vcntq_u8(quad.2));
                    t3 = vaddq_u8(t3, vcntq_u8(quad.3));
                }

                // Drain the byte accumulators into the 64-bit totals.
                wide = vpadalq(wide, t0);
                wide = vpadalq(wide, t1);
                wide = vpadalq(wide, t2);
                wide = vpadalq(wide, t3);
            }

            let mut halves = [0u64; 2];
            vst1q_u64(halves.as_mut_ptr(), wide);
            simd_count = halves[0] + halves[1];
        }
    }

    // Scalar tail. On AArch64 `count_ones()` always lowers to NEON `cnt`, so no
    // POPCNT runtime check is needed here.
    simd_count + popcnt_scalar_loop!(rest)
}
622
623// ── ARM SVE ─────────────────────────────────────────────────────────────────
624
/// SVE population count for a slice of any length.
///
/// A 4×-unrolled loop consumes four full vectors per iteration; a predicated
/// loop then covers whatever is left, so no scalar remainder code is needed.
#[cfg(all(
    target_arch = "aarch64",
    simd_popcnt_have_sve,
    any(target_feature = "sve", feature = "std")
))]
#[target_feature(enable = "sve")]
#[inline]
fn popcnt_arm_sve(bytes: &[u8]) -> u64 {
    let base = bytes.as_ptr();
    let len = bytes.len();

    // SAFETY: the main loop only runs while four full vectors fit before
    // `len`, and the tail loop's predicate masks off any lanes past the end.
    unsafe {
        let every_byte = svptrue_b8();
        let every_word = svptrue_b64();
        let vl = svcntb() as usize; // SVE vector length in bytes (hardware-defined)
        let stride = vl * 4;
        let mut pos = 0usize;

        // 4× unrolled full-predicate loop. Four independent accumulators keep the
        // count+add chains parallel (higher ILP).
        let mut acc0 = svdup_n_u64(0);
        let mut acc1 = svdup_n_u64(0);
        let mut acc2 = svdup_n_u64(0);
        let mut acc3 = svdup_n_u64(0);
        while pos + stride <= len {
            let v0 = svreinterpret_u64_u8(svld1_u8(every_byte, base.add(pos)));
            let v1 = svreinterpret_u64_u8(svld1_u8(every_byte, base.add(pos + vl)));
            let v2 = svreinterpret_u64_u8(svld1_u8(every_byte, base.add(pos + vl * 2)));
            let v3 = svreinterpret_u64_u8(svld1_u8(every_byte, base.add(pos + vl * 3)));
            acc0 = svadd_u64_x(every_word, acc0, svcnt_u64_x(every_word, v0));
            acc1 = svadd_u64_x(every_word, acc1, svcnt_u64_x(every_word, v1));
            acc2 = svadd_u64_x(every_word, acc2, svcnt_u64_x(every_word, v2));
            acc3 = svadd_u64_x(every_word, acc3, svcnt_u64_x(every_word, v3));
            pos += stride;
        }
        // Fold the four accumulators into one (adds zeros if the loop never ran).
        acc0 = svadd_u64_x(every_word, acc0, acc1);
        acc2 = svadd_u64_x(every_word, acc2, acc3);
        acc0 = svadd_u64_x(every_word, acc0, acc2);

        // Predicated tail: the load zero-fills inactive lanes, so no separate
        // scalar remainder is needed.
        loop {
            let active = svwhilelt_b8_u64(pos as u64, len as u64);
            if !svptest_any(every_byte, active) {
                break;
            }
            let v = svreinterpret_u64_u8(svld1_u8(active, base.add(pos)));
            acc0 = svadd_u64_x(every_word, acc0, svcnt_u64_x(every_word, v));
            pos += vl;
        }

        svaddv_u64(every_word, acc0)
    }
}
682
683// ════════════════════════════════════════════════════════════════════════════
684// Tests
685// ════════════════════════════════════════════════════════════════════════════
686
#[cfg(test)]
mod tests {
    use super::*;

    /// Oracle that counts bits one byte at a time with `u8::count_ones`.
    fn reference(bytes: &[u8]) -> u64 {
        let mut total = 0u64;
        for byte in bytes {
            total += u64::from(byte.count_ones());
        }
        total
    }

    /// Second, independent oracle: an integer-only (SWAR) popcount that does
    /// not use `count_ones`, so the sweeps cross-check the crate against a
    /// different algorithm.
    fn popcnt64_bitwise(x: u64) -> u64 {
        const M1: u64 = 0x5555555555555555;
        const M2: u64 = 0x3333333333333333;
        const M4: u64 = 0x0F0F0F0F0F0F0F0F;
        const H01: u64 = 0x0101010101010101;
        let pairs = x - ((x >> 1) & M1);
        let nibbles = (pairs & M2) + ((pairs >> 2) & M2);
        let per_byte = (nibbles + (nibbles >> 4)) & M4;
        per_byte.wrapping_mul(H01) >> 56
    }

    /// Advances a xorshift64 generator (shifts 13, 7, 17) and returns the new
    /// state. Deterministic, so failures are reproducible.
    fn xorshift(state: &mut u64) -> u64 {
        *state ^= *state << 13;
        *state ^= *state >> 7;
        *state ^= *state << 17;
        *state
    }

    #[test]
    fn empty() {
        assert_eq!(popcnt(&[]), 0);
    }

    #[test]
    fn all_ones() {
        let sizes = [
            0usize, 1, 7, 8, 31, 32, 39, 40, 63, 64, 255, 256, 511, 512, 4095, 4096, 65537,
        ];
        for size in sizes {
            let bytes = vec![0xFFu8; size];
            assert_eq!(popcnt(&bytes), size as u64 * 8, "size={size}");
        }
    }

    #[test]
    fn all_zeros() {
        assert_eq!(popcnt(&vec![0u8; 65536]), 0);
    }

    #[test]
    fn single_bits() {
        for bit in 0u64..64 {
            let encoded = (1u64 << bit).to_le_bytes();
            assert_eq!(popcnt(&encoded), 1, "bit={bit}");
        }
    }

    /// `PopcntExt::popcnt` on each integer width must equal the per-element
    /// `count_ones()` sum (an oracle independent of the byte reinterpretation).
    #[test]
    fn ext_trait_widths() {
        // Compares `.popcnt()` on the named binding with the element-wise sum.
        macro_rules! assert_matches_count_ones {
            ($values:ident) => {{
                let expected: u64 = $values.iter().map(|v| u64::from(v.count_ones())).sum();
                assert_eq!($values.popcnt(), expected);
            }};
        }

        let u8s: &[u8] = &[0xFF, 0x0F, 0x00, 0xAB, 0x01];
        assert_matches_count_ones!(u8s);

        let u16s: &[u16] = &[0xFFFF, 0x0F0F, 0x1234, 0];
        assert_matches_count_ones!(u16s);

        let u32s: &[u32] = &[u32::MAX, 0, 0x8000_0001];
        assert_matches_count_ones!(u32s);

        let u64s: &[u64] = &[u64::MAX, 0x0F0F_0F0F_0F0F_0F0F, 0];
        assert_matches_count_ones!(u64s);

        // Signed types and arrays resolve through the same impls (the doc
        // example covers `Vec`).
        let i32s = [-1i32, 0, 1, i32::MIN];
        assert_matches_count_ones!(i32s);
        assert_eq!([u128::MAX, 0].popcnt(), 128);
    }

    /// Sweep every boundary-relevant size against the byte-wise reference using
    /// a deterministic pseudo-random fill (xorshift). Covers tail handling,
    /// the AVX2/AVX512 thresholds and multiple Harley-Seal outer iterations.
    #[test]
    fn pseudorandom_all_sizes() {
        // Largest size and largest start offset exercised below. 4695 bytes
        // spans several 512-byte Harley-Seal iterations.
        const MAX_SIZE: usize = 4695;
        const MAX_OFF: usize = 7;

        let mut state: u64 = 0x9E37_79B9_7F4A_7C15;
        // One spare byte beyond the largest `off + size` window.
        let bytes: Vec<u8> = (0..MAX_SIZE + MAX_OFF + 1)
            .map(|_| xorshift(&mut state) as u8)
            .collect();

        // Every size from 0 up through the AVX2/AVX512 active range, plus a few
        // larger ones, exercised at multiple start offsets so alignment varies.
        let large = [1023usize, 1024, 1025, 2048, 4095, 4096, 4097, 4608, MAX_SIZE];
        for size in (0usize..=600).chain(large) {
            for off in [0usize, 1, 3, MAX_OFF] {
                let slice = &bytes[off..off + size];
                assert_eq!(popcnt(slice), reference(slice), "size={size} off={off}");
            }
        }
    }

    /// Verify `popcnt()` of every suffix `bytes[i..]` against an independent
    /// byte-wise reference, covering every length and a range of start
    /// alignments in one sweep.
    ///
    /// Size defaults to 20_000 to keep `cargo test` fast — the sweep is O(n²) in
    /// the work `popcnt` performs. Override with `SIMD_POPCNT_TEST_SIZE` for a
    /// heavier run, e.g. `SIMD_POPCNT_TEST_SIZE=100000 cargo test --release suffix_sweep`.
    #[test]
    fn suffix_sweep() {
        const DEFAULT_SIZE: usize = 20_000;
        // An unset, non-Unicode or unparsable variable falls back to the default.
        let size = match std::env::var("SIMD_POPCNT_TEST_SIZE") {
            Ok(text) => text.parse::<usize>().unwrap_or(DEFAULT_SIZE),
            Err(_) => DEFAULT_SIZE,
        };

        // All-ones array.
        check_all_suffixes(&vec![0xFFu8; size]);

        // Deterministic pseudo-random array (fixed seed → reproducible failures).
        let mut state: u64 = 0x2545_F491_4F6C_DD1D;
        let random: Vec<u8> = (0..size).map(|_| xorshift(&mut state) as u8).collect();
        check_all_suffixes(&random);
    }

    /// Assert `popcnt(&bytes[i..])` for every `i` against a running expected
    /// value that is updated in O(1), so only `popcnt` itself does O(n) work
    /// per suffix.
    fn check_all_suffixes(bytes: &[u8]) {
        // Expected popcount of `bytes[i..]`; starts as the count of the whole slice.
        let mut remaining: u64 = bytes.iter().map(|&b| popcnt64_bitwise(u64::from(b))).sum();
        for (i, &byte) in bytes.iter().enumerate() {
            assert_eq!(popcnt(&bytes[i..]), remaining, "suffix at offset {i}");
            remaining -= popcnt64_bitwise(u64::from(byte));
        }
        // Empty suffix.
        assert_eq!(popcnt(&bytes[bytes.len()..]), 0);
    }
}