Skip to main content

warp_types/
shuffle.rs

1//! Shuffle operations and permutation algebra.
2//!
3//! Shuffles let lanes exchange values within a warp. This module provides:
4//!
5//! 1. **Type-safe shuffle traits** — enforce correct return types
//!    (shuffle → PerLane, ballot → Uniform, reduce → `Uniform<T>`)
7//! 2. **`Warp<All>`-restricted shuffles** — shuffle methods only on full warps
8//! 3. **Permutation algebra** — XOR/Rotate/Compose with group-theoretic properties
9
10use crate::active_set::All;
11use crate::data::{LaneId, PerLane, Uniform};
12use crate::warp::Warp;
13use crate::GpuValue;
14use core::marker::PhantomData;
15
16/// Result of a warp ballot operation.
17///
18/// A ballot collects a predicate from all lanes into a bitmask.
19/// The result is Uniform because every lane gets the same bitmask.
20///
21/// **Note:** The mask is `u32`, matching NVIDIA's `__ballot_sync` return type.
22/// For AMD 64-lane wavefronts, a `BallotResult64` variant would be needed.
23/// Lanes >= 32 always return `false` from `lane_voted()`.
24#[derive(Clone, Copy, Debug)]
25pub struct BallotResult {
26    mask: Uniform<u32>,
27}
28
29impl BallotResult {
30    /// Create a ballot result from a uniform mask.
31    pub fn from_mask(mask: Uniform<u32>) -> Self {
32        BallotResult { mask }
33    }
34
35    pub fn mask(self) -> Uniform<u32> {
36        self.mask
37    }
38
39    pub fn lane_voted(self, lane: LaneId) -> Uniform<bool> {
40        let id = lane.get();
41        // Guard: u32 mask only covers lanes 0..31. Lanes >= 32 are outside
42        // the ballot scope (would need BallotResult64 for AMD wavefronts).
43        if id >= 32 {
44            return Uniform::from_const(false);
45        }
46        Uniform::from_const((self.mask.get() & (1u32 << id)) != 0)
47    }
48
49    pub fn popcount(self) -> Uniform<u32> {
50        Uniform::from_const(self.mask.get().count_ones())
51    }
52
53    pub fn first_lane(self) -> Option<LaneId> {
54        let tz = self.mask.get().trailing_zeros();
55        if tz < 32 {
56            Some(LaneId::new(tz as u8))
57        } else {
58            None
59        }
60    }
61}
62
63// ============================================================================
64// Shuffle safety marker (for error message improvement)
65// ============================================================================
66
/// Marker trait for warp types that support shuffle operations.
///
/// The trait has no methods. It exists so the `on_unimplemented` diagnostic
/// attached below can replace the compiler's generic "trait not implemented"
/// error with shuffle-specific guidance.
///
/// Currently only `Warp<All>` implements this. `Tile<N>` has its own shuffle
/// methods (all tile lanes are active by construction) but does not implement
/// this trait. If you get an error mentioning this trait, it means you're
/// trying to shuffle on a diverged warp — merge back to `Warp<All>` first.
#[diagnostic::on_unimplemented(
    message = "shuffle requires all lanes active, but `{Self}` may have inactive lanes",
    label = "this warp may be diverged — shuffle needs Warp<All>",
    note = "after diverge_even_odd(), call merge(evens, odds) to get Warp<All> back, then shuffle"
)]
pub trait ShuffleSafe {}

// The only implementor in this module: a warp whose active set is `All`.
impl ShuffleSafe for Warp<All> {}
81
82// ============================================================================
83// Shuffle operations restricted to Warp<All>
84// ============================================================================
85
86impl Warp<All> {
87    /// Shuffle XOR: each lane exchanges with lane (id ^ mask).
88    ///
89    /// **ONLY AVAILABLE ON `Warp<All>`** — diverged warps cannot shuffle.
90    ///
91    /// On GPU: emits `shfl.sync.bfly.b32` via inline assembly.
92    /// On CPU: returns the input value (single-thread identity).
93    pub fn shuffle_xor<T: GpuValue + crate::gpu::GpuShuffle>(
94        &self,
95        data: PerLane<T>,
96        mask: u32,
97    ) -> PerLane<T> {
98        PerLane::new(data.get().gpu_shfl_xor(mask))
99    }
100
101    /// Shuffle down: lane\[i\] reads from lane\[i+delta\].
102    ///
103    /// On GPU: emits `shfl.sync.down.b32`.
104    /// On CPU: returns input (identity).
105    pub fn shuffle_down<T: GpuValue + crate::gpu::GpuShuffle>(
106        &self,
107        data: PerLane<T>,
108        delta: u32,
109    ) -> PerLane<T> {
110        PerLane::new(data.get().gpu_shfl_down(delta))
111    }
112
113    /// Sum reduction across all lanes.
114    ///
115    /// Returns `Uniform<T>` because a full-warp reduction produces the same
116    /// result in every lane.
117    ///
118    /// On GPU: butterfly reduction using 5 shuffle-XOR + add steps.
119    /// On CPU: returns val × 32 (butterfly doubling via identity shuffle).
120    pub fn reduce_sum<T: GpuValue + crate::gpu::GpuShuffle + core::ops::Add<Output = T>>(
121        &self,
122        data: PerLane<T>,
123    ) -> Uniform<T> {
124        let mut val = data.get();
125        val = val + val.gpu_shfl_xor(16);
126        val = val + val.gpu_shfl_xor(8);
127        val = val + val.gpu_shfl_xor(4);
128        val = val + val.gpu_shfl_xor(2);
129        val = val + val.gpu_shfl_xor(1);
130        Uniform::from_const(val)
131    }
132
133    /// Warp ballot: collect a predicate from all lanes into a bitmask.
134    ///
135    /// Every lane gets the same bitmask — the result is `Uniform<u32>`.
136    /// Requires `Warp<All>` because reading predicates from inactive lanes
137    /// is undefined behavior.
138    ///
139    /// **Note:** Currently CPU-emulation only on all targets. A GPU codepath
140    /// using `gpu::ballot_sync` requires `#[cfg(target_arch = "nvptx64")]` gating.
141    /// On CPU: returns mask with bit 0 set if predicate is true (single-thread identity).
142    pub fn ballot(&self, predicate: PerLane<bool>) -> BallotResult {
143        // CPU emulation: single thread, so ballot = predicate in lane 0
144        let mask = if predicate.get() { 1u32 } else { 0u32 };
145        BallotResult::from_mask(Uniform::from_const(mask))
146    }
147
148    /// Broadcast: all lanes get the same value.
149    pub fn broadcast<T: GpuValue>(&self, value: T) -> PerLane<T> {
150        PerLane::new(value)
151    }
152
153    /// Shuffle XOR on a raw scalar — convenience that skips PerLane wrapping.
154    ///
155    /// Equivalent to `self.shuffle_xor(PerLane::new(val), mask).get()` but
156    /// avoids the verbosity of wrapping/unwrapping for the common case.
157    pub fn shuffle_xor_raw<T: GpuValue + crate::gpu::GpuShuffle>(&self, val: T, mask: u32) -> T {
158        val.gpu_shfl_xor(mask)
159    }
160
161    /// Shuffle down on a raw scalar — convenience that skips PerLane wrapping.
162    pub fn shuffle_down_raw<T: GpuValue + crate::gpu::GpuShuffle>(&self, val: T, delta: u32) -> T {
163        val.gpu_shfl_down(delta)
164    }
165}
166
167// ============================================================================
168// Permutation algebra (from shuffle duality research)
169// ============================================================================
170
/// A permutation of the lane indices `[0, 32)`.
///
/// Implementations are stateless: both directions are associated functions,
/// so a permutation is identified purely by its type.
pub trait Permutation: Copy + Clone {
    /// Destination lane: where lane `i` sends its value.
    fn forward(i: u32) -> u32;
    /// Source lane: where lane `i` receives its value from.
    ///
    /// Invariant: `inverse(forward(i)) == i`.
    fn inverse(i: u32) -> u32;
    /// Whether the permutation is its own inverse (an involution).
    ///
    /// The default checks every lane in `[0, 32)` and stops at the first one
    /// whose `forward` and `inverse` disagree.
    fn is_self_dual() -> bool {
        for lane in 0..32u32 {
            if Self::forward(lane) != Self::inverse(lane) {
                return false;
            }
        }
        true
    }
}
182
183/// The dual (inverse) of a permutation.
184pub trait HasDual: Permutation {
185    type Dual: Permutation;
186}
187
188/// XOR shuffle: lane i exchanges with lane i ⊕ mask.
189///
190/// XOR shuffles are involutions (self-dual) and form the group (Z₂)⁵.
191#[derive(Copy, Clone, Debug)]
192pub struct Xor<const MASK: u32>;
193
194impl<const MASK: u32> Permutation for Xor<MASK> {
195    fn forward(i: u32) -> u32 {
196        (i ^ MASK) & 0x1F
197    }
198    fn inverse(i: u32) -> u32 {
199        (i ^ MASK) & 0x1F
200    }
201    fn is_self_dual() -> bool {
202        true
203    }
204}
205
206impl<const MASK: u32> HasDual for Xor<MASK> {
207    type Dual = Xor<MASK>;
208}
209
210/// Rotate down: lane i receives from lane (i + delta) mod 32.
211///
212/// Consistent with CUDA `__shfl_down_sync`: data flows from higher-numbered
213/// lanes to lower. `forward(i)` returns the *destination* of lane i's value
214/// (lane i - delta), while `inverse(i)` returns lane i's *source* (lane i + delta).
215#[derive(Copy, Clone, Debug)]
216pub struct RotateDown<const DELTA: u32>;
217
218/// Rotate up: lane i receives from lane (i - delta) mod 32.
219///
220/// Dual of `RotateDown`. Data flows from lower-numbered lanes to higher.
221#[derive(Copy, Clone, Debug)]
222pub struct RotateUp<const DELTA: u32>;
223
224impl<const DELTA: u32> Permutation for RotateDown<DELTA> {
225    fn forward(i: u32) -> u32 {
226        (i + 32 - (DELTA & 0x1F)) & 0x1F
227    }
228    fn inverse(i: u32) -> u32 {
229        (i + (DELTA & 0x1F)) & 0x1F
230    }
231    fn is_self_dual() -> bool {
232        (DELTA & 0x1F) == 0 || (DELTA & 0x1F) == 16
233    }
234}
235
236impl<const DELTA: u32> Permutation for RotateUp<DELTA> {
237    fn forward(i: u32) -> u32 {
238        (i + (DELTA & 0x1F)) & 0x1F
239    }
240    fn inverse(i: u32) -> u32 {
241        (i + 32 - (DELTA & 0x1F)) & 0x1F
242    }
243    fn is_self_dual() -> bool {
244        (DELTA & 0x1F) == 0 || (DELTA & 0x1F) == 16
245    }
246}
247
248impl<const DELTA: u32> HasDual for RotateDown<DELTA> {
249    type Dual = RotateUp<DELTA>;
250}
251
252impl<const DELTA: u32> HasDual for RotateUp<DELTA> {
253    type Dual = RotateDown<DELTA>;
254}
255
256/// Identity permutation.
257#[derive(Copy, Clone, Debug)]
258pub struct Identity;
259
260impl Permutation for Identity {
261    fn forward(i: u32) -> u32 {
262        i & 0x1F
263    }
264    fn inverse(i: u32) -> u32 {
265        i & 0x1F
266    }
267    fn is_self_dual() -> bool {
268        true
269    }
270}
271
272impl HasDual for Identity {
273    type Dual = Identity;
274}
275
276/// Composition of two permutations: apply P1 then P2.
277#[derive(Copy, Clone, Debug)]
278pub struct Compose<P1: Permutation, P2: Permutation>(PhantomData<(P1, P2)>);
279
280impl<P1: Permutation, P2: Permutation> Permutation for Compose<P1, P2> {
281    fn forward(i: u32) -> u32 {
282        P2::forward(P1::forward(i))
283    }
284    fn inverse(i: u32) -> u32 {
285        P1::inverse(P2::inverse(i))
286    }
287}
288
289impl<P1: Permutation + HasDual, P2: Permutation + HasDual> HasDual for Compose<P1, P2> {
290    type Dual = Compose<P2::Dual, P1::Dual>;
291}
292
// Butterfly network type aliases: stage `k` is the XOR shuffle with mask `1 << k`.

/// Butterfly stage 0: XOR shuffle with mask 1.
pub type ButterflyStage0 = Xor<1>;
/// Butterfly stage 1: XOR shuffle with mask 2.
pub type ButterflyStage1 = Xor<2>;
/// Butterfly stage 2: XOR shuffle with mask 4.
pub type ButterflyStage2 = Xor<4>;
/// Butterfly stage 3: XOR shuffle with mask 8.
pub type ButterflyStage3 = Xor<8>;
/// Butterfly stage 4: XOR shuffle with mask 16.
pub type ButterflyStage4 = Xor<16>;

/// All five butterfly stages composed in order, stage 0 first and stage 4 last.
///
/// The masks 1, 2, 4, 8 and 16 XOR together to 31, so the combined forward
/// map is `i → i ^ 31`.
pub type FullButterfly = Compose<
    Compose<Compose<Compose<ButterflyStage0, ButterflyStage1>, ButterflyStage2>, ButterflyStage3>,
    ButterflyStage4,
>;
304
305/// Apply a permutation to an array of values.
306pub fn shuffle_by<T: Copy, P: Permutation>(values: [T; 32], _perm: P) -> [T; 32] {
307    let mut result = values;
308    for (i, slot) in result.iter_mut().enumerate() {
309        let src = P::inverse(i as u32) as usize;
310        *slot = values[src];
311    }
312    result
313}
314
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the `Permutation` contract `inverse(forward(i)) == i`, and the
    /// mirrored direction, for every lane index in `[0, 32)`.
    fn assert_inverse_invariant<P: Permutation>() {
        for lane in 0..32u32 {
            assert_eq!(P::inverse(P::forward(lane)), lane);
            assert_eq!(P::forward(P::inverse(lane)), lane);
        }
    }

    /// Checks that `P::Dual::forward` undoes `P::forward` on every lane.
    fn assert_dual_undoes<P: HasDual>() {
        for lane in 0..32u32 {
            assert_eq!(<P::Dual as Permutation>::forward(P::forward(lane)), lane);
        }
    }

    /// Checks that `P` reports itself as self-dual and that applying it twice
    /// really is the identity on `[0, 32)`.
    fn assert_involution<P: Permutation>() {
        assert!(P::is_self_dual());
        for lane in 0..32u32 {
            assert_eq!(P::forward(P::forward(lane)), lane);
        }
    }

    #[test]
    fn test_ballot_result_empty_mask() {
        let result = BallotResult {
            mask: Uniform::from_const(0),
        };
        assert_eq!(result.first_lane(), None);
        assert_eq!(result.popcount().get(), 0);
        assert!(!result.lane_voted(LaneId::new(0)).get());
        assert!(!result.lane_voted(LaneId::new(31)).get());
    }

    #[test]
    fn test_ballot_result() {
        let result = BallotResult {
            mask: Uniform::from_const(0b1010_1010),
        };
        assert!(!result.lane_voted(LaneId::new(0)).get());
        assert!(result.lane_voted(LaneId::new(1)).get());
        assert_eq!(result.popcount().get(), 4);
        assert_eq!(result.first_lane(), Some(LaneId::new(1)));
    }

    #[test]
    fn test_shuffle_only_on_all() {
        let all: Warp<All> = Warp::new();
        let data = PerLane::new(42i32);
        let _shuffled = all.shuffle_xor(data, 1);
        let _reduced = all.reduce_sum(data);
    }

    #[test]
    fn test_shuffle_64bit_types() {
        let all: Warp<All> = Warp::new();

        // i64: two-pass shuffle on GPU, identity on CPU
        let data_i64 = PerLane::new(0x0000_0001_0000_0002_i64);
        let shuffled_i64 = all.shuffle_xor(data_i64, 1);
        assert_eq!(shuffled_i64.get(), 0x0000_0001_0000_0002_i64);

        // u64
        let data_u64 = PerLane::new(u64::MAX);
        let shuffled_u64 = all.shuffle_xor(data_u64, 1);
        assert_eq!(shuffled_u64.get(), u64::MAX);

        // f64: bit-preserving two-pass
        #[allow(clippy::approx_constant)]
        let data_f64 = PerLane::new(3.14159_f64);
        let shuffled_f64 = all.shuffle_xor(data_f64, 1);
        #[allow(clippy::approx_constant)]
        let expected_f64 = 3.14159_f64;
        assert_eq!(shuffled_f64.get(), expected_f64);

        // Reduction works on 64-bit
        let ones_i64 = PerLane::new(1_i64);
        let sum = all.reduce_sum(ones_i64);
        assert_eq!(sum.get(), 32_i64); // 1 + 1 + ... (5 XOR stages)
    }

    #[test]
    fn test_xor_self_dual() {
        // Exercise the real `Xor` implementation rather than re-deriving the
        // formula inline. Const generics cannot be looped over at runtime, so
        // a representative spread of masks is instantiated explicitly.
        assert_involution::<Xor<0>>();
        assert_involution::<Xor<1>>();
        assert_involution::<Xor<5>>();
        assert_involution::<Xor<16>>();
        assert_involution::<Xor<31>>();
        // A mask wider than five bits is truncated to the lane range.
        assert_involution::<Xor<37>>();

        assert_inverse_invariant::<Xor<5>>();
        assert_dual_undoes::<Xor<5>>();
    }

    #[test]
    fn test_rotate_duality() {
        for lane in 0..32u32 {
            let down_then_up = RotateUp::<1>::forward(RotateDown::<1>::forward(lane));
            assert_eq!(down_then_up, lane);
        }

        assert_inverse_invariant::<RotateDown<1>>();
        assert_inverse_invariant::<RotateDown<5>>();
        assert_inverse_invariant::<RotateDown<31>>();
        assert_inverse_invariant::<RotateUp<1>>();
        assert_inverse_invariant::<RotateUp<5>>();
        assert_inverse_invariant::<RotateUp<31>>();

        assert_dual_undoes::<RotateDown<5>>();
        assert_dual_undoes::<RotateUp<5>>();
    }

    #[test]
    fn test_rotate_self_dual() {
        assert_involution::<RotateDown<0>>();
        assert_involution::<RotateDown<16>>();
        assert_involution::<RotateUp<16>>();
        assert!(!RotateDown::<1>::is_self_dual());
        assert!(!RotateUp::<8>::is_self_dual());
    }

    #[test]
    fn test_identity() {
        assert_involution::<Identity>();
        assert_inverse_invariant::<Identity>();
        assert_dual_undoes::<Identity>();
    }

    #[test]
    fn test_shuffle_roundtrip() {
        let original: [i32; 32] = core::array::from_fn(|i| i as i32);
        let shuffled = shuffle_by(original, Xor::<5>);
        let unshuffled = shuffle_by(shuffled, Xor::<5>);
        assert_eq!(unshuffled, original);
    }

    #[test]
    fn test_shuffle_by_rotate_down() {
        // Each lane pulls from the lane one above it, wrapping at the top.
        let original: [i32; 32] = core::array::from_fn(|i| i as i32);
        let rotated = shuffle_by(original, RotateDown::<1>);
        assert_eq!(rotated[0], 1);
        assert_eq!(rotated[30], 31);
        assert_eq!(rotated[31], 0);
    }

    #[test]
    fn test_butterfly_permutation() {
        // Full butterfly: XOR with 1|2|4|8|16 = 31, so maps i → i ^ 31
        for i in 0..32u32 {
            assert_eq!(FullButterfly::forward(i), i ^ 31);
        }
        assert_involution::<FullButterfly>();
    }

    #[test]
    fn test_shuffle_by_full_butterfly_reverses() {
        // i ^ 31 == 31 - i on [0, 32), so the full butterfly reverses the array.
        let original: [i32; 32] = core::array::from_fn(|i| i as i32);
        let perm: FullButterfly = Compose(PhantomData);
        let reversed = shuffle_by(original, perm);
        for (lane, value) in reversed.iter().enumerate() {
            assert_eq!(*value, 31 - lane as i32);
        }
    }

    #[test]
    fn test_compose_associative() {
        for i in 0..32u32 {
            let ab_c = Compose::<Compose<Xor<3>, Xor<5>>, Xor<7>>::forward(i);
            let a_bc = Compose::<Xor<3>, Compose<Xor<5>, Xor<7>>>::forward(i);
            assert_eq!(ab_c, a_bc);
        }
    }

    #[test]
    fn test_compose_inverse_and_dual() {
        // Mixing a XOR with a rotation gives a permutation that is not its
        // own inverse, so the order of unwinding actually matters here.
        assert_inverse_invariant::<Compose<Xor<3>, RotateDown<5>>>();
        assert_dual_undoes::<Compose<Xor<3>, RotateDown<5>>>();
        assert!(!Compose::<Xor<3>, RotateDown<5>>::is_self_dual());
    }
}