// warp_types/shuffle.rs

//! Shuffle operations and permutation algebra.
//!
//! Shuffles let lanes exchange values within a warp. This module provides:
//!
//! 1. **Type-safe shuffle operations** — enforce correct return types
//!    (shuffle → `PerLane<T>`, ballot → `BallotResult` wrapping a `Uniform<u64>`,
//!    reduce → `Uniform<T>`)
//! 2. **`Warp<All>`-restricted shuffles** — shuffle methods only on full warps
//! 3. **Permutation algebra** — XOR/Rotate/Compose with group-theoretic properties
9
10use crate::active_set::{ActiveSet, All};
11use crate::data::{LaneId, PerLane, Uniform};
12use crate::gpu::GpuShuffle;
13use crate::warp::Warp;
14use crate::GpuValue;
15use core::marker::PhantomData;
16
17/// Result of a warp ballot operation.
18///
19/// A ballot collects a predicate from all lanes into a bitmask.
20/// The result is Uniform because every lane gets the same bitmask.
21///
22/// The mask is `u64` — covers both NVIDIA 32-lane (upper 32 bits zero)
23/// and AMD 64-lane wavefronts.
24#[must_use = "BallotResult carries lane vote data — dropping discards the ballot"]
25#[derive(Clone, Copy, Debug)]
26#[repr(transparent)]
27pub struct BallotResult {
28    mask: Uniform<u64>,
29}
30
31impl BallotResult {
32    /// Create a ballot result from a uniform mask.
33    pub fn from_mask(mask: Uniform<u64>) -> Self {
34        BallotResult { mask }
35    }
36
37    /// Create from a 32-bit mask (NVIDIA compatibility).
38    pub fn from_mask_u32(mask: Uniform<u32>) -> Self {
39        BallotResult {
40            mask: Uniform::from_const(mask.get() as u64),
41        }
42    }
43
44    pub fn mask(self) -> Uniform<u64> {
45        self.mask
46    }
47
48    /// Get the lower 32 bits (NVIDIA compatibility).
49    pub fn mask_u32(self) -> Uniform<u32> {
50        Uniform::from_const(self.mask.get() as u32)
51    }
52
53    pub fn lane_voted(self, lane: LaneId) -> Uniform<bool> {
54        let id = lane.get();
55        if id >= crate::WARP_SIZE as u8 {
56            return Uniform::from_const(false);
57        }
58        Uniform::from_const((self.mask.get() & (1u64 << id)) != 0)
59    }
60
61    pub fn popcount(self) -> Uniform<u32> {
62        Uniform::from_const(self.mask.get().count_ones())
63    }
64
65    pub fn first_lane(self) -> Option<LaneId> {
66        let tz = self.mask.get().trailing_zeros();
67        if tz < crate::WARP_SIZE {
68            Some(LaneId::new(tz as u8))
69        } else {
70            None
71        }
72    }
73}
74
75// ============================================================================
76// Shuffle safety marker (for error message improvement)
77// ============================================================================
78
/// Marker trait for warp types that support shuffle operations.
///
/// Currently only `Warp<All>` implements this. `Tile<N>` has its own shuffle
/// methods (all tile lanes are active by construction) but does not implement
/// this trait. If you get an error mentioning this trait, it means you're
/// trying to shuffle on a diverged warp — merge back to `Warp<All>` first.
///
/// The trait has no methods; it exists so the `on_unimplemented` diagnostic
/// below can replace the default "trait bound not satisfied" message.
#[diagnostic::on_unimplemented(
    message = "shuffle requires all lanes active, but `{Self}` may have inactive lanes",
    label = "this warp may be diverged — shuffle needs Warp<All>",
    note = "after diverge_even_odd(), call merge(evens, odds) to get Warp<All> back, then shuffle"
)]
pub trait ShuffleSafe {}

// The only implementor in this module: a warp whose active set is `All`.
impl ShuffleSafe for Warp<All> {}
93
94// ============================================================================
95// Shuffle XOR within a sub-warp (§3.3 SHUFFLE-WITHIN typing rule)
96// ============================================================================
97
98/// Check that an XOR shuffle mask preserves an active set.
99///
100/// Returns `true` if for every active lane `i` (bit set in `active_mask`),
101/// lane `i ^ xor_mask` is also active. This means the XOR permutation
102/// maps the active set to itself — no lane reads from an inactive partner.
103///
104/// The check works by computing the XOR-permuted bitmask and verifying it
105/// equals the original: bit `j` of the permuted mask is set iff bit
106/// `(j ^ xor_mask)` is set in the original.
107#[inline]
108fn xor_mask_preserves_active_set(active_mask: u64, xor_mask: u32) -> bool {
109    let ws = crate::WARP_SIZE;
110    let xor = xor_mask & (ws - 1); // 5 bits for 32-lane, 6 bits for 64-lane
111    let mut permuted = 0u64;
112    let mut j = 0u32;
113    while j < ws {
114        if active_mask & (1u64 << (j ^ xor)) != 0 {
115            permuted |= 1u64 << j;
116        }
117        j += 1;
118    }
119    permuted == active_mask
120}
121
122impl<S: ActiveSet> Warp<S> {
123    /// Shuffle XOR within a sub-warp, when the mask preserves the active set.
124    ///
125    /// This implements the §3.3 SHUFFLE-WITHIN typing rule: an XOR shuffle
126    /// is safe on `Warp<S>` (not just `Warp<All>`) when the XOR mask maps
127    /// every active lane to another active lane. Formally: for every lane `i`
128    /// in `S`, lane `(i ^ mask)` is also in `S`.
129    ///
130    /// # Examples
131    ///
132    /// ```
133    /// use warp_types::*;
134    ///
135    /// let warp: Warp<All> = Warp::kernel_entry();
136    /// let (evens, odds) = warp.diverge_even_odd();
137    ///
138    /// let data = PerLane::new(42i32);
139    ///
140    /// // XOR mask 2 on Even lanes: lane 0↔2, 4↔6, etc. — stays within Even.
141    /// let _shuffled = evens.shuffle_xor_within(data, 2);
142    ///
143    /// // XOR mask 1 would map even→odd — panics!
144    /// // evens.shuffle_xor_within(data, 1); // panic
145    /// #
146    /// # drop(odds); // suppress must_use
147    /// ```
148    ///
149    /// # Panics
150    ///
151    /// Panics if `mask` does not preserve `S`, i.e., there exists an active
152    /// lane `i` where lane `(i ^ mask)` is not active.
153    pub fn shuffle_xor_within<T: GpuValue + GpuShuffle>(
154        &self,
155        data: PerLane<T>,
156        mask: u32,
157    ) -> PerLane<T> {
158        assert!(
159            xor_mask_preserves_active_set(S::MASK, mask),
160            "shuffle_xor_within: XOR mask {} does not preserve active set {} (mask={:#018X})",
161            mask,
162            S::NAME,
163            S::MASK,
164        );
165        PerLane::new(data.get().gpu_shfl_xor(mask))
166    }
167}
168
169// ============================================================================
170// Shuffle operations restricted to Warp<All>
171// ============================================================================
172
173impl Warp<All> {
174    /// Shuffle XOR: each lane exchanges with lane (id ^ mask).
175    ///
176    /// **ONLY AVAILABLE ON `Warp<All>`** — diverged warps cannot shuffle.
177    ///
178    /// On GPU: emits `shfl.sync.bfly.b32` via inline assembly.
179    /// On CPU: returns the input value (single-thread identity).
180    pub fn shuffle_xor<T: GpuValue + crate::gpu::GpuShuffle>(
181        &self,
182        data: PerLane<T>,
183        mask: u32,
184    ) -> PerLane<T> {
185        PerLane::new(data.get().gpu_shfl_xor(mask))
186    }
187
188    /// Shuffle down: lane\[i\] reads from lane\[i+delta\].
189    ///
190    /// On GPU: emits `shfl.sync.down.b32`.
191    /// On CPU: returns input (identity).
192    pub fn shuffle_down<T: GpuValue + crate::gpu::GpuShuffle>(
193        &self,
194        data: PerLane<T>,
195        delta: u32,
196    ) -> PerLane<T> {
197        PerLane::new(data.get().gpu_shfl_down(delta))
198    }
199
200    /// Sum reduction across all lanes.
201    ///
202    /// Returns `Uniform<T>` because a full-warp reduction produces the same
203    /// result in every lane.
204    ///
205    /// On GPU: butterfly reduction using log2(WARP_SIZE) shuffle-XOR + add steps.
206    /// On CPU: returns val × WARP_SIZE (butterfly doubling via identity shuffle).
207    ///
208    /// **Overflow note:** GPU hardware wraps on integer overflow (two's complement,
209    /// verified on RTX 4000 Ada — see `reproduce/gpu_semantics_test.cu`). This
210    /// function uses `+` (Rust's `Add` trait), which panics in debug on overflow.
211    /// For GPU-faithful wrapping semantics, use [`reduce_sum_wrapping`](Self::reduce_sum_wrapping).
212    pub fn reduce_sum<T: GpuValue + crate::gpu::GpuShuffle + core::ops::Add<Output = T>>(
213        &self,
214        data: PerLane<T>,
215    ) -> Uniform<T> {
216        let mut val = data.get();
217        #[cfg(feature = "warp64")]
218        {
219            val = val + val.gpu_shfl_xor(32);
220        }
221        val = val + val.gpu_shfl_xor(16);
222        val = val + val.gpu_shfl_xor(8);
223        val = val + val.gpu_shfl_xor(4);
224        val = val + val.gpu_shfl_xor(2);
225        val = val + val.gpu_shfl_xor(1);
226        Uniform::from_const(val)
227    }
228
229    /// Wrapping butterfly reduce-sum matching GPU hardware overflow semantics.
230    ///
231    /// GPU integer arithmetic wraps on overflow (two's complement, no trap).
232    /// This variant uses `wrapping_add` to match that behavior exactly.
233    /// Hardware-verified on RTX 4000 Ada — see `reproduce/gpu_semantics_test.cu`.
234    pub fn reduce_sum_wrapping_i32(&self, data: PerLane<i32>) -> Uniform<i32> {
235        let mut val = data.get();
236        #[cfg(feature = "warp64")]
237        {
238            val = val.wrapping_add(val.gpu_shfl_xor(32));
239        }
240        val = val.wrapping_add(val.gpu_shfl_xor(16));
241        val = val.wrapping_add(val.gpu_shfl_xor(8));
242        val = val.wrapping_add(val.gpu_shfl_xor(4));
243        val = val.wrapping_add(val.gpu_shfl_xor(2));
244        val = val.wrapping_add(val.gpu_shfl_xor(1));
245        Uniform::from_const(val)
246    }
247
248    /// Wrapping reduce-sum for `u32` — GPU hardware overflow semantics.
249    pub fn reduce_sum_wrapping_u32(&self, data: PerLane<u32>) -> Uniform<u32> {
250        let mut val = data.get();
251        #[cfg(feature = "warp64")]
252        {
253            val = val.wrapping_add(val.gpu_shfl_xor(32));
254        }
255        val = val.wrapping_add(val.gpu_shfl_xor(16));
256        val = val.wrapping_add(val.gpu_shfl_xor(8));
257        val = val.wrapping_add(val.gpu_shfl_xor(4));
258        val = val.wrapping_add(val.gpu_shfl_xor(2));
259        val = val.wrapping_add(val.gpu_shfl_xor(1));
260        Uniform::from_const(val)
261    }
262
263    /// Wrapping reduce-sum for `i64` — GPU hardware overflow semantics.
264    pub fn reduce_sum_wrapping_i64(&self, data: PerLane<i64>) -> Uniform<i64> {
265        let mut val = data.get();
266        #[cfg(feature = "warp64")]
267        {
268            val = val.wrapping_add(val.gpu_shfl_xor(32));
269        }
270        val = val.wrapping_add(val.gpu_shfl_xor(16));
271        val = val.wrapping_add(val.gpu_shfl_xor(8));
272        val = val.wrapping_add(val.gpu_shfl_xor(4));
273        val = val.wrapping_add(val.gpu_shfl_xor(2));
274        val = val.wrapping_add(val.gpu_shfl_xor(1));
275        Uniform::from_const(val)
276    }
277
278    /// Wrapping reduce-sum for `u64` — GPU hardware overflow semantics.
279    pub fn reduce_sum_wrapping_u64(&self, data: PerLane<u64>) -> Uniform<u64> {
280        let mut val = data.get();
281        #[cfg(feature = "warp64")]
282        {
283            val = val.wrapping_add(val.gpu_shfl_xor(32));
284        }
285        val = val.wrapping_add(val.gpu_shfl_xor(16));
286        val = val.wrapping_add(val.gpu_shfl_xor(8));
287        val = val.wrapping_add(val.gpu_shfl_xor(4));
288        val = val.wrapping_add(val.gpu_shfl_xor(2));
289        val = val.wrapping_add(val.gpu_shfl_xor(1));
290        Uniform::from_const(val)
291    }
292
293    /// Warp ballot: collect a predicate from all lanes into a bitmask.
294    ///
295    /// Every lane gets the same bitmask — the result is `Uniform<u64>`.
296    /// Requires `Warp<All>` because reading predicates from inactive lanes
297    /// is undefined behavior.
298    ///
299    /// On GPU (nvptx64): calls `vote.sync.ballot.b32` via PTX inline asm.
300    /// On CPU: returns mask with bit 0 set if predicate is true (single-thread identity).
301    pub fn ballot(&self, predicate: PerLane<bool>) -> BallotResult {
302        #[cfg(target_arch = "nvptx64")]
303        {
304            let mask = crate::gpu::ballot_sync(0xFFFFFFFF, predicate.get()) as u64;
305            BallotResult::from_mask(Uniform::from_const(mask))
306        }
307        #[cfg(not(target_arch = "nvptx64"))]
308        {
309            // CPU emulation: single thread, so ballot = predicate in lane 0
310            let mask = if predicate.get() { 1u64 } else { 0u64 };
311            BallotResult::from_mask(Uniform::from_const(mask))
312        }
313    }
314
315    /// Broadcast: all lanes get the same value.
316    pub fn broadcast<T: GpuValue>(&self, value: T) -> PerLane<T> {
317        PerLane::new(value)
318    }
319
320    /// Shuffle XOR on a raw scalar — convenience that skips PerLane wrapping.
321    ///
322    /// Equivalent to `self.shuffle_xor(PerLane::new(val), mask).get()` but
323    /// avoids the verbosity of wrapping/unwrapping for the common case.
324    pub fn shuffle_xor_raw<T: GpuValue + crate::gpu::GpuShuffle>(&self, val: T, mask: u32) -> T {
325        val.gpu_shfl_xor(mask)
326    }
327
328    /// Shuffle down on a raw scalar — convenience that skips PerLane wrapping.
329    pub fn shuffle_down_raw<T: GpuValue + crate::gpu::GpuShuffle>(&self, val: T, delta: u32) -> T {
330        val.gpu_shfl_down(delta)
331    }
332}
333
334// ============================================================================
335// Permutation algebra (from shuffle duality research)
336// ============================================================================
337
338/// A permutation on lane indices [0, WARP_SIZE).
339pub trait Permutation: Copy + Clone {
340    /// Where does lane `i` send its value?
341    fn forward(i: u32) -> u32;
342    /// Where does lane `i` receive from? Invariant: `inverse(forward(i)) == i`.
343    fn inverse(i: u32) -> u32;
344    /// Is this permutation its own inverse (involution)?
345    fn is_self_dual() -> bool {
346        (0..crate::WARP_SIZE).all(|i| Self::forward(i) == Self::inverse(i))
347    }
348}
349
350/// The dual (inverse) of a permutation.
351pub trait HasDual: Permutation {
352    type Dual: Permutation;
353}
354
355/// XOR shuffle: lane i exchanges with lane i ⊕ mask.
356///
357/// XOR shuffles are involutions (self-dual) and form the group (Z₂)^log₂(WARP_SIZE).
358#[derive(Copy, Clone, Debug)]
359pub struct Xor<const MASK: u32>;
360
361impl<const MASK: u32> Permutation for Xor<MASK> {
362    fn forward(i: u32) -> u32 {
363        (i ^ MASK) & (crate::WARP_SIZE - 1)
364    }
365    fn inverse(i: u32) -> u32 {
366        (i ^ MASK) & (crate::WARP_SIZE - 1)
367    }
368    fn is_self_dual() -> bool {
369        true
370    }
371}
372
373impl<const MASK: u32> HasDual for Xor<MASK> {
374    type Dual = Xor<MASK>;
375}
376
377/// Rotate down: lane i receives from lane (i + delta) mod WARP_SIZE.
378///
379/// **Not the same as CUDA `__shfl_down_sync`**: this is a modular rotation
380/// (wraps around), whereas CUDA's `__shfl_down_sync` clamps (out-of-range
381/// lanes read their own value). Data flows from higher-numbered lanes to
382/// lower. `forward(i)` returns the *destination* of lane i's value
383/// (lane i - delta), while `inverse(i)` returns lane i's *source* (lane i + delta).
384#[derive(Copy, Clone, Debug)]
385pub struct RotateDown<const DELTA: u32>;
386
387/// Rotate up: lane i receives from lane (i - delta) mod WARP_SIZE.
388///
389/// Dual of `RotateDown`. Data flows from lower-numbered lanes to higher.
390#[derive(Copy, Clone, Debug)]
391pub struct RotateUp<const DELTA: u32>;
392
393impl<const DELTA: u32> Permutation for RotateDown<DELTA> {
394    fn forward(i: u32) -> u32 {
395        let mask = crate::WARP_SIZE - 1;
396        (i + crate::WARP_SIZE - (DELTA & mask)) & mask
397    }
398    fn inverse(i: u32) -> u32 {
399        let mask = crate::WARP_SIZE - 1;
400        (i + (DELTA & mask)) & mask
401    }
402    fn is_self_dual() -> bool {
403        let mask = crate::WARP_SIZE - 1;
404        (DELTA & mask) == 0 || (DELTA & mask) == crate::WARP_SIZE / 2
405    }
406}
407
408impl<const DELTA: u32> Permutation for RotateUp<DELTA> {
409    fn forward(i: u32) -> u32 {
410        let mask = crate::WARP_SIZE - 1;
411        (i + (DELTA & mask)) & mask
412    }
413    fn inverse(i: u32) -> u32 {
414        let mask = crate::WARP_SIZE - 1;
415        (i + crate::WARP_SIZE - (DELTA & mask)) & mask
416    }
417    fn is_self_dual() -> bool {
418        let mask = crate::WARP_SIZE - 1;
419        (DELTA & mask) == 0 || (DELTA & mask) == crate::WARP_SIZE / 2
420    }
421}
422
423impl<const DELTA: u32> HasDual for RotateDown<DELTA> {
424    type Dual = RotateUp<DELTA>;
425}
426
427impl<const DELTA: u32> HasDual for RotateUp<DELTA> {
428    type Dual = RotateDown<DELTA>;
429}
430
431/// Identity permutation.
432#[derive(Copy, Clone, Debug)]
433pub struct Identity;
434
435impl Permutation for Identity {
436    fn forward(i: u32) -> u32 {
437        i & (crate::WARP_SIZE - 1)
438    }
439    fn inverse(i: u32) -> u32 {
440        i & (crate::WARP_SIZE - 1)
441    }
442    fn is_self_dual() -> bool {
443        true
444    }
445}
446
447impl HasDual for Identity {
448    type Dual = Identity;
449}
450
451/// Composition of two permutations: apply P1 then P2.
452#[derive(Copy, Clone, Debug)]
453pub struct Compose<P1: Permutation, P2: Permutation>(PhantomData<(P1, P2)>);
454
455impl<P1: Permutation, P2: Permutation> Permutation for Compose<P1, P2> {
456    fn forward(i: u32) -> u32 {
457        P2::forward(P1::forward(i))
458    }
459    fn inverse(i: u32) -> u32 {
460        P1::inverse(P2::inverse(i))
461    }
462}
463
464impl<P1: Permutation + HasDual, P2: Permutation + HasDual> HasDual for Compose<P1, P2> {
465    type Dual = Compose<P2::Dual, P1::Dual>;
466}
467
// Butterfly network type aliases: stage k is the XOR shuffle with mask 2^k.

/// Butterfly stage 0: XOR with mask 1.
pub type ButterflyStage0 = Xor<1>;
/// Butterfly stage 1: XOR with mask 2.
pub type ButterflyStage1 = Xor<2>;
/// Butterfly stage 2: XOR with mask 4.
pub type ButterflyStage2 = Xor<4>;
/// Butterfly stage 3: XOR with mask 8.
pub type ButterflyStage3 = Xor<8>;
/// Butterfly stage 4: XOR with mask 16.
pub type ButterflyStage4 = Xor<16>;

/// 32-lane full butterfly: XOR stages 1|2|4|8|16.
///
/// Stages are composed left to right (stage 0 is applied first).
#[cfg(not(feature = "warp64"))]
pub type FullButterfly = Compose<
    Compose<Compose<Compose<ButterflyStage0, ButterflyStage1>, ButterflyStage2>, ButterflyStage3>,
    ButterflyStage4,
>;

/// 64-lane butterfly adds XOR<32> as the 6th stage.
#[cfg(feature = "warp64")]
pub type ButterflyStage5 = Xor<32>;

/// 64-lane full butterfly: XOR stages 1|2|4|8|16|32.
///
/// Stages are composed left to right (stage 0 is applied first).
#[cfg(feature = "warp64")]
pub type FullButterfly = Compose<
    Compose<
        Compose<
            Compose<Compose<ButterflyStage0, ButterflyStage1>, ButterflyStage2>,
            ButterflyStage3,
        >,
        ButterflyStage4,
    >,
    ButterflyStage5,
>;
498
499/// Apply a permutation to an array of values.
500#[cfg(not(feature = "warp64"))]
501pub fn shuffle_by<T: Copy, P: Permutation>(values: [T; 32], _perm: P) -> [T; 32] {
502    let mut result = values;
503    for (i, slot) in result.iter_mut().enumerate() {
504        let src = (P::inverse(i as u32) & (crate::WARP_SIZE - 1)) as usize;
505        *slot = values[src];
506    }
507    result
508}
509
/// Apply a permutation to a 64-lane array of values.
///
/// Output slot `i` receives `values[P::inverse(i)]`, i.e. each lane pulls
/// from its source lane. The permutation is chosen by the type `P`; the
/// `_perm` value itself is not inspected.
#[cfg(feature = "warp64")]
pub fn shuffle_by<T: Copy, P: Permutation>(values: [T; 64], _perm: P) -> [T; 64] {
    core::array::from_fn(|lane| {
        // Mask the source index so it can never index outside the array.
        let source = (P::inverse(lane as u32) & (crate::WARP_SIZE - 1)) as usize;
        values[source]
    })
}
520
#[cfg(test)]
mod tests {
    use super::*;

    // ========================================================================
    // BallotResult
    // ========================================================================

    #[test]
    fn test_ballot_result_empty_mask() {
        let result = BallotResult {
            mask: Uniform::from_const(0),
        };
        assert_eq!(result.first_lane(), None);
        assert_eq!(result.popcount().get(), 0);
        assert!(!result.lane_voted(LaneId::new(0)).get());
        assert!(!result.lane_voted(LaneId::new(31)).get());
    }

    #[test]
    fn test_ballot_result() {
        let result = BallotResult {
            mask: Uniform::from_const(0b1010_1010),
        };
        assert!(!result.lane_voted(LaneId::new(0)).get());
        assert!(result.lane_voted(LaneId::new(1)).get());
        assert_eq!(result.popcount().get(), 4);
        assert_eq!(result.first_lane(), Some(LaneId::new(1)));
    }

    // ========================================================================
    // Warp<All> shuffles and reductions (CPU: shuffles are the identity)
    // ========================================================================

    #[test]
    fn test_shuffle_only_on_all() {
        let all: Warp<All> = Warp::new();
        let data = PerLane::new(42i32);
        // CPU identity: the shuffle hands back the input value.
        let shuffled = all.shuffle_xor(data, 1);
        assert_eq!(shuffled.get(), 42);
        // CPU identity turns the butterfly into log2(WARP_SIZE) doublings.
        let reduced = all.reduce_sum(data);
        assert_eq!(reduced.get(), 42 * crate::WARP_SIZE as i32);
    }

    #[test]
    fn test_shuffle_64bit_types() {
        let all: Warp<All> = Warp::new();

        // i64: two-pass shuffle on GPU, identity on CPU
        let data_i64 = PerLane::new(0x0000_0001_0000_0002_i64);
        let shuffled_i64 = all.shuffle_xor(data_i64, 1);
        assert_eq!(shuffled_i64.get(), 0x0000_0001_0000_0002_i64);

        // u64
        let data_u64 = PerLane::new(u64::MAX);
        let shuffled_u64 = all.shuffle_xor(data_u64, 1);
        assert_eq!(shuffled_u64.get(), u64::MAX);

        // f64: bit-preserving two-pass
        #[allow(clippy::approx_constant)]
        let data_f64 = PerLane::new(3.14159_f64);
        let shuffled_f64 = all.shuffle_xor(data_f64, 1);
        #[allow(clippy::approx_constant)]
        let expected_f64 = 3.14159_f64;
        assert_eq!(shuffled_f64.get(), expected_f64);

        // Reduction works on 64-bit
        let ones_i64 = PerLane::new(1_i64);
        let sum = all.reduce_sum(ones_i64);
        assert_eq!(sum.get(), crate::WARP_SIZE as i64);
    }

    #[test]
    fn test_reduce_sum_wrapping_i32() {
        let all: Warp<All> = Warp::new();
        let data = PerLane::new(i32::MAX);
        let result = all.reduce_sum_wrapping_i32(data);
        let mut expected = i32::MAX;
        let stages = crate::WARP_SIZE.trailing_zeros();
        for _ in 0..stages {
            expected = expected.wrapping_add(expected);
        }
        assert_eq!(result.get(), expected);
    }

    #[test]
    fn test_reduce_sum_wrapping_u32() {
        let all: Warp<All> = Warp::new();
        let data = PerLane::new(u32::MAX);
        let result = all.reduce_sum_wrapping_u32(data);
        let mut expected = u32::MAX;
        let stages = crate::WARP_SIZE.trailing_zeros();
        for _ in 0..stages {
            expected = expected.wrapping_add(expected);
        }
        assert_eq!(result.get(), expected);
    }

    #[test]
    fn test_reduce_sum_wrapping_i64() {
        let all: Warp<All> = Warp::new();
        let data = PerLane::new(i64::MAX);
        let result = all.reduce_sum_wrapping_i64(data);
        let mut expected = i64::MAX;
        let stages = crate::WARP_SIZE.trailing_zeros();
        for _ in 0..stages {
            expected = expected.wrapping_add(expected);
        }
        assert_eq!(result.get(), expected);
    }

    #[test]
    fn test_reduce_sum_wrapping_u64() {
        let all: Warp<All> = Warp::new();
        let data = PerLane::new(u64::MAX);
        let result = all.reduce_sum_wrapping_u64(data);
        let mut expected = u64::MAX;
        let stages = crate::WARP_SIZE.trailing_zeros();
        for _ in 0..stages {
            expected = expected.wrapping_add(expected);
        }
        assert_eq!(result.get(), expected);
    }

    // ========================================================================
    // Permutation algebra
    // ========================================================================

    #[test]
    fn test_xor_self_dual() {
        assert!(Xor::<5>::is_self_dual());
        // `Xor::is_self_dual` is hard-coded to `true`, so also check the claim
        // against the actual forward/inverse maps.
        for lane in 0..crate::WARP_SIZE {
            assert_eq!(Xor::<5>::forward(lane), Xor::<5>::inverse(lane));
            assert_eq!(Xor::<5>::forward(Xor::<5>::forward(lane)), lane);
        }
        // The underlying arithmetic is an involution for every mask.
        let ws = crate::WARP_SIZE;
        let mask_bits = ws - 1;
        for mask in 0..ws {
            for lane in 0..ws {
                let after_two = (((lane ^ mask) & mask_bits) ^ mask) & mask_bits;
                assert_eq!(after_two, lane);
            }
        }
    }

    #[test]
    fn test_rotate_duality() {
        for lane in 0..crate::WARP_SIZE {
            let down_then_up = RotateUp::<1>::forward(RotateDown::<1>::forward(lane));
            assert_eq!(down_then_up, lane);
        }
    }

    #[test]
    fn test_shuffle_roundtrip() {
        let original: [i32; crate::WARP_SIZE as usize] = core::array::from_fn(|i| i as i32);
        let shuffled = shuffle_by(original, Xor::<5>);
        let unshuffled = shuffle_by(shuffled, Xor::<5>);
        assert_eq!(unshuffled, original);
    }

    #[test]
    fn test_butterfly_permutation() {
        // Full butterfly: XOR with all stages = WARP_SIZE-1, so maps i → i ^ (WARP_SIZE-1)
        let ws = crate::WARP_SIZE;
        for i in 0..ws {
            assert_eq!(FullButterfly::forward(i), i ^ (ws - 1));
        }
    }

    #[test]
    fn test_compose_associative() {
        for i in 0..crate::WARP_SIZE {
            let ab_c = Compose::<Compose<Xor<3>, Xor<5>>, Xor<7>>::forward(i);
            let a_bc = Compose::<Xor<3>, Compose<Xor<5>, Xor<7>>>::forward(i);
            assert_eq!(ab_c, a_bc);
        }
    }

    // ========================================================================
    // shuffle_xor_within tests (§3.3 SHUFFLE-WITHIN typing rule)
    // ========================================================================

    #[test]
    fn test_xor_mask_preserves_active_set_all() {
        // All lanes active: any mask preserves the set.
        for mask in 0..crate::WARP_SIZE {
            assert!(
                xor_mask_preserves_active_set(crate::active_set::All::MASK, mask),
                "All should accept mask {mask}"
            );
        }
    }

    #[test]
    fn test_xor_mask_preserves_even() {
        use crate::active_set::Even;
        // Even: mask 2 maps 0↔2, 4↔6, etc. — stays within Even.
        assert!(xor_mask_preserves_active_set(Even::MASK, 2));
        assert!(xor_mask_preserves_active_set(Even::MASK, 4));
        assert!(xor_mask_preserves_active_set(Even::MASK, 6));
        // Even: mask 0 is identity — always preserves.
        assert!(xor_mask_preserves_active_set(Even::MASK, 0));
        // Even: mask 1 maps even→odd — does NOT preserve.
        assert!(!xor_mask_preserves_active_set(Even::MASK, 1));
        assert!(!xor_mask_preserves_active_set(Even::MASK, 3));
        assert!(!xor_mask_preserves_active_set(Even::MASK, 5));
    }

    #[test]
    fn test_xor_mask_preserves_odd() {
        use crate::active_set::Odd;
        // Odd: even masks preserve (same group structure as Even).
        assert!(xor_mask_preserves_active_set(Odd::MASK, 2));
        assert!(xor_mask_preserves_active_set(Odd::MASK, 4));
        // Odd: odd masks do NOT preserve.
        assert!(!xor_mask_preserves_active_set(Odd::MASK, 1));
        assert!(!xor_mask_preserves_active_set(Odd::MASK, 3));
    }

    #[test]
    fn test_xor_mask_preserves_low_half() {
        use crate::active_set::LowHalf;
        let half = crate::WARP_SIZE / 2;
        // LowHalf: lanes 0..half-1. Masks < half stay within LowHalf.
        for mask in 0..half {
            assert!(
                xor_mask_preserves_active_set(LowHalf::MASK, mask),
                "LowHalf should accept mask {mask}"
            );
        }
        // Mask = half maps lane 0 outside LowHalf — does NOT preserve.
        assert!(!xor_mask_preserves_active_set(LowHalf::MASK, half));
        assert!(!xor_mask_preserves_active_set(LowHalf::MASK, half + 1));
    }

    #[test]
    fn test_xor_mask_preserves_high_half() {
        use crate::active_set::HighHalf;
        let half = crate::WARP_SIZE / 2;
        // HighHalf: lanes half..WARP_SIZE-1. Masks < half stay within HighHalf.
        for mask in 0..half {
            assert!(
                xor_mask_preserves_active_set(HighHalf::MASK, mask),
                "HighHalf should accept mask {mask}"
            );
        }
        // Mask = half maps lane half→0 (outside HighHalf) — does NOT preserve.
        assert!(!xor_mask_preserves_active_set(HighHalf::MASK, half));
    }

    #[test]
    fn test_xor_mask_preserves_even_low() {
        use crate::active_set::EvenLow;
        let half = crate::WARP_SIZE / 2;
        // EvenLow: even lanes in 0..half-1. Must be even AND < half.
        assert!(xor_mask_preserves_active_set(EvenLow::MASK, 2));
        assert!(xor_mask_preserves_active_set(EvenLow::MASK, 4));
        assert!(xor_mask_preserves_active_set(EvenLow::MASK, 6));
        // Mask 1 would go even→odd — fails.
        assert!(!xor_mask_preserves_active_set(EvenLow::MASK, 1));
        // Mask = half would go low→high — fails.
        assert!(!xor_mask_preserves_active_set(EvenLow::MASK, half));
    }

    #[test]
    fn test_shuffle_xor_within_on_warp_all() {
        // Warp<All> should accept any mask via shuffle_xor_within.
        let warp: Warp<All> = Warp::new();
        let data = PerLane::new(42i32);
        for mask in [0, 1, 2, 5, 16, 31] {
            let result = warp.shuffle_xor_within(data, mask);
            // CPU identity: shuffle returns input.
            assert_eq!(result.get(), 42);
        }
    }

    #[test]
    fn test_shuffle_xor_within_on_even() {
        use crate::active_set::Even;
        let warp: Warp<Even> = Warp::new();
        let data = PerLane::new(99i32);
        // Even masks (2, 4, 6) should succeed.
        let r = warp.shuffle_xor_within(data, 2);
        assert_eq!(r.get(), 99); // CPU identity
        let r = warp.shuffle_xor_within(data, 4);
        assert_eq!(r.get(), 99);
    }

    #[test]
    #[should_panic(expected = "does not preserve active set")]
    fn test_shuffle_xor_within_even_rejects_odd_mask() {
        use crate::active_set::Even;
        let warp: Warp<Even> = Warp::new();
        let data = PerLane::new(99i32);
        // Mask 1 maps even→odd — should panic.
        let _ = warp.shuffle_xor_within(data, 1);
    }

    #[test]
    #[should_panic(expected = "does not preserve active set")]
    fn test_shuffle_xor_within_low_half_rejects_high_mask() {
        use crate::active_set::LowHalf;
        let warp: Warp<LowHalf> = Warp::new();
        let data = PerLane::new(7i32);
        // Mask = half maps low→high — should panic.
        let _ = warp.shuffle_xor_within(data, crate::WARP_SIZE / 2);
    }

    #[test]
    fn test_shuffle_xor_within_simwarp_even_mask2() {
        // Verify real lane exchange using SimWarp.
        // Even lanes: {0, 2, 4, 6, 8, ...}. XOR mask 2: 0↔2, 4↔6, 8↔10, ...
        use crate::simwarp::SimWarp;

        let sw = SimWarp::<i32>::new(|i| i as i32 * 10);
        let shuffled = sw.shuffle_xor(2);

        // Lane 0 gets lane 2's value, lane 2 gets lane 0's value.
        assert_eq!(shuffled.lane(0), 20); // was 0, now 2*10
        assert_eq!(shuffled.lane(2), 0); // was 20, now 0*10
        assert_eq!(shuffled.lane(4), 60); // was 40, now 6*10
        assert_eq!(shuffled.lane(6), 40); // was 60, now 4*10

        // Verify the preservation property: for Even lanes and mask 2,
        // every even lane's partner (lane ^ 2) is also even.
        for lane in (0..crate::WARP_SIZE).step_by(2) {
            let partner = lane ^ 2;
            assert_eq!(
                partner % 2,
                0,
                "lane {lane}'s partner {partner} should be even"
            );
        }
    }

    #[test]
    fn test_shuffle_xor_within_simwarp_odd_mask2() {
        // Odd lanes: {1, 3, 5, 7, ...}. XOR mask 2: 1↔3, 5↔7, 9↔11, ...
        use crate::simwarp::SimWarp;

        let sw = SimWarp::<i32>::new(|i| i as i32);
        let shuffled = sw.shuffle_xor(2);

        // Lane 1 gets lane 3's value, lane 3 gets lane 1's value.
        assert_eq!(shuffled.lane(1), 3);
        assert_eq!(shuffled.lane(3), 1);
        assert_eq!(shuffled.lane(5), 7);
        assert_eq!(shuffled.lane(7), 5);

        // Verify: for Odd lanes and mask 2, every odd partner is odd.
        for lane in (1..crate::WARP_SIZE).step_by(2) {
            let partner = lane ^ 2;
            assert_ne!(
                partner % 2,
                0,
                "lane {lane}'s partner {partner} should be odd"
            );
        }
    }

    #[test]
    fn test_shuffle_xor_within_simwarp_low_half() {
        // Lanes 0..15 (the LowHalf of a 32-lane warp). XOR mask 8:
        // 0↔8, 1↔9, 2↔10, ..., 7↔15.
        use crate::simwarp::SimWarp;

        let sw = SimWarp::<i32>::new(|i| i as i32 * 3);
        let shuffled = sw.shuffle_xor(8);

        // All partners stay within 0..15.
        assert_eq!(shuffled.lane(0), 24); // lane 8's value: 8*3
        assert_eq!(shuffled.lane(8), 0); // lane 0's value: 0*3
        assert_eq!(shuffled.lane(7), 45); // lane 15's value: 15*3
        assert_eq!(shuffled.lane(15), 21); // lane 7's value: 7*3

        for lane in 0..16u32 {
            let partner = lane ^ 8;
            assert!(
                partner < 16,
                "lane {lane}'s partner {partner} should be in LowHalf"
            );
        }
    }

    #[test]
    fn test_shuffle_xor_within_after_diverge() {
        // End-to-end: diverge, shuffle within sub-warp, merge.
        let warp: Warp<All> = Warp::kernel_entry();
        let data = PerLane::new(42i32);

        let (evens, odds) = warp.diverge_even_odd();

        // Both sub-warps can shuffle with even masks.
        let _even_shuffled = evens.shuffle_xor_within(data, 2);
        let _odd_shuffled = odds.shuffle_xor_within(data, 4);

        // Merge back to All.
        let _merged: Warp<All> = crate::merge::merge(evens, odds);
    }
}