1use crate::active_set::All;
11use crate::data::{LaneId, PerLane, Uniform};
12use crate::warp::Warp;
13use crate::GpuValue;
14use core::marker::PhantomData;
15
16#[derive(Clone, Copy, Debug)]
25pub struct BallotResult {
26 mask: Uniform<u32>,
27}
28
29impl BallotResult {
30 pub fn from_mask(mask: Uniform<u32>) -> Self {
32 BallotResult { mask }
33 }
34
35 pub fn mask(self) -> Uniform<u32> {
36 self.mask
37 }
38
39 pub fn lane_voted(self, lane: LaneId) -> Uniform<bool> {
40 let id = lane.get();
41 if id >= 32 {
44 return Uniform::from_const(false);
45 }
46 Uniform::from_const((self.mask.get() & (1u32 << id)) != 0)
47 }
48
49 pub fn popcount(self) -> Uniform<u32> {
50 Uniform::from_const(self.mask.get().count_ones())
51 }
52
53 pub fn first_lane(self) -> Option<LaneId> {
54 let tz = self.mask.get().trailing_zeros();
55 if tz < 32 {
56 Some(LaneId::new(tz as u8))
57 } else {
58 None
59 }
60 }
61}
62
/// Marker trait for warp types on which shuffle operations are allowed.
///
/// The trait has no methods. Its purpose is to be used as a bound so that a
/// warp type without the implementation is rejected at compile time; the
/// `#[diagnostic::on_unimplemented]` attribute below supplies the error
/// message, label and note the compiler prints in that case.
///
/// Within this file the only implementor is `Warp<All>`.
#[diagnostic::on_unimplemented(
    message = "shuffle requires all lanes active, but `{Self}` may have inactive lanes",
    label = "this warp may be diverged — shuffle needs Warp<All>",
    note = "after diverge_even_odd(), call merge(evens, odds) to get Warp<All> back, then shuffle"
)]
pub trait ShuffleSafe {}

// A warp whose active set is `All` is the one case marked as shuffle-safe.
impl ShuffleSafe for Warp<All> {}
81
82impl Warp<All> {
87 pub fn shuffle_xor<T: GpuValue + crate::gpu::GpuShuffle>(
94 &self,
95 data: PerLane<T>,
96 mask: u32,
97 ) -> PerLane<T> {
98 PerLane::new(data.get().gpu_shfl_xor(mask))
99 }
100
101 pub fn shuffle_down<T: GpuValue + crate::gpu::GpuShuffle>(
106 &self,
107 data: PerLane<T>,
108 delta: u32,
109 ) -> PerLane<T> {
110 PerLane::new(data.get().gpu_shfl_down(delta))
111 }
112
113 pub fn reduce_sum<T: GpuValue + crate::gpu::GpuShuffle + core::ops::Add<Output = T>>(
121 &self,
122 data: PerLane<T>,
123 ) -> Uniform<T> {
124 let mut val = data.get();
125 val = val + val.gpu_shfl_xor(16);
126 val = val + val.gpu_shfl_xor(8);
127 val = val + val.gpu_shfl_xor(4);
128 val = val + val.gpu_shfl_xor(2);
129 val = val + val.gpu_shfl_xor(1);
130 Uniform::from_const(val)
131 }
132
133 pub fn ballot(&self, predicate: PerLane<bool>) -> BallotResult {
143 let mask = if predicate.get() { 1u32 } else { 0u32 };
145 BallotResult::from_mask(Uniform::from_const(mask))
146 }
147
148 pub fn broadcast<T: GpuValue>(&self, value: T) -> PerLane<T> {
150 PerLane::new(value)
151 }
152
153 pub fn shuffle_xor_raw<T: GpuValue + crate::gpu::GpuShuffle>(&self, val: T, mask: u32) -> T {
158 val.gpu_shfl_xor(mask)
159 }
160
161 pub fn shuffle_down_raw<T: GpuValue + crate::gpu::GpuShuffle>(&self, val: T, delta: u32) -> T {
163 val.gpu_shfl_down(delta)
164 }
165}
166
/// A permutation of the 32 lane indices, expressed purely at the type level.
///
/// Implementors supply a mapping and its inverse as associated functions, so
/// no value of the implementing type is needed to evaluate them.
pub trait Permutation: Copy + Clone {
    /// Image of lane index `i` under the permutation.
    fn forward(i: u32) -> u32;

    /// Image of lane index `i` under the inverse permutation.
    fn inverse(i: u32) -> u32;

    /// Returns `true` when `forward` and `inverse` agree on every lane index
    /// in `0..32`.
    ///
    /// The default implementation checks all 32 indices; implementors may
    /// override it with a cheaper closed-form answer.
    fn is_self_dual() -> bool {
        // Equivalent to "all lanes agree": stop at the first disagreement.
        !(0..32).any(|lane| Self::forward(lane) != Self::inverse(lane))
    }
}
182
/// Associates a permutation with the type of its dual.
///
/// For every implementation in this file, `Dual::forward` computes the same
/// mapping as `Self::inverse` (and vice versa).
pub trait HasDual: Permutation {
    /// The permutation type that undoes `Self`.
    type Dual: Permutation;
}
187
188#[derive(Copy, Clone, Debug)]
192pub struct Xor<const MASK: u32>;
193
194impl<const MASK: u32> Permutation for Xor<MASK> {
195 fn forward(i: u32) -> u32 {
196 (i ^ MASK) & 0x1F
197 }
198 fn inverse(i: u32) -> u32 {
199 (i ^ MASK) & 0x1F
200 }
201 fn is_self_dual() -> bool {
202 true
203 }
204}
205
206impl<const MASK: u32> HasDual for Xor<MASK> {
207 type Dual = Xor<MASK>;
208}
209
210#[derive(Copy, Clone, Debug)]
216pub struct RotateDown<const DELTA: u32>;
217
218#[derive(Copy, Clone, Debug)]
222pub struct RotateUp<const DELTA: u32>;
223
224impl<const DELTA: u32> Permutation for RotateDown<DELTA> {
225 fn forward(i: u32) -> u32 {
226 (i + 32 - (DELTA & 0x1F)) & 0x1F
227 }
228 fn inverse(i: u32) -> u32 {
229 (i + (DELTA & 0x1F)) & 0x1F
230 }
231 fn is_self_dual() -> bool {
232 (DELTA & 0x1F) == 0 || (DELTA & 0x1F) == 16
233 }
234}
235
236impl<const DELTA: u32> Permutation for RotateUp<DELTA> {
237 fn forward(i: u32) -> u32 {
238 (i + (DELTA & 0x1F)) & 0x1F
239 }
240 fn inverse(i: u32) -> u32 {
241 (i + 32 - (DELTA & 0x1F)) & 0x1F
242 }
243 fn is_self_dual() -> bool {
244 (DELTA & 0x1F) == 0 || (DELTA & 0x1F) == 16
245 }
246}
247
248impl<const DELTA: u32> HasDual for RotateDown<DELTA> {
249 type Dual = RotateUp<DELTA>;
250}
251
252impl<const DELTA: u32> HasDual for RotateUp<DELTA> {
253 type Dual = RotateDown<DELTA>;
254}
255
256#[derive(Copy, Clone, Debug)]
258pub struct Identity;
259
260impl Permutation for Identity {
261 fn forward(i: u32) -> u32 {
262 i & 0x1F
263 }
264 fn inverse(i: u32) -> u32 {
265 i & 0x1F
266 }
267 fn is_self_dual() -> bool {
268 true
269 }
270}
271
272impl HasDual for Identity {
273 type Dual = Identity;
274}
275
276#[derive(Copy, Clone, Debug)]
278pub struct Compose<P1: Permutation, P2: Permutation>(PhantomData<(P1, P2)>);
279
280impl<P1: Permutation, P2: Permutation> Permutation for Compose<P1, P2> {
281 fn forward(i: u32) -> u32 {
282 P2::forward(P1::forward(i))
283 }
284 fn inverse(i: u32) -> u32 {
285 P1::inverse(P2::inverse(i))
286 }
287}
288
289impl<P1: Permutation + HasDual, P2: Permutation + HasDual> HasDual for Compose<P1, P2> {
290 type Dual = Compose<P2::Dual, P1::Dual>;
291}
292
// Butterfly stages: stage `k` exchanges lanes whose indices differ in bit `k`,
// i.e. it is the XOR permutation with mask `1 << k`.

/// Butterfly stage 0: XOR with 1.
pub type ButterflyStage0 = Xor<1>;
/// Butterfly stage 1: XOR with 2.
pub type ButterflyStage1 = Xor<2>;
/// Butterfly stage 2: XOR with 4.
pub type ButterflyStage2 = Xor<4>;
/// Butterfly stage 3: XOR with 8.
pub type ButterflyStage3 = Xor<8>;
/// Butterfly stage 4: XOR with 16.
pub type ButterflyStage4 = Xor<16>;

/// All five butterfly stages composed in order, stage 0 first and stage 4
/// last.
///
/// Because the individual masks combine to `31`, `FullButterfly::forward(i)`
/// equals `i ^ 31` for every lane index in `0..32`.
pub type FullButterfly = Compose<
    Compose<Compose<Compose<ButterflyStage0, ButterflyStage1>, ButterflyStage2>, ButterflyStage3>,
    ButterflyStage4,
>;
304
305pub fn shuffle_by<T: Copy, P: Permutation>(values: [T; 32], _perm: P) -> [T; 32] {
307 let mut result = values;
308 for (i, slot) in result.iter_mut().enumerate() {
309 let src = P::inverse(i as u32) as usize;
310 *slot = values[src];
311 }
312 result
313}
314
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `P` undoes itself on every lane, that its `forward` and
    /// `inverse` maps agree, and that it reports itself as self-dual.
    fn assert_involution<P: Permutation>() {
        for lane in 0..32u32 {
            assert_eq!(P::forward(P::forward(lane)), lane);
            assert_eq!(P::forward(lane), P::inverse(lane));
        }
        assert!(P::is_self_dual());
    }

    /// An empty mask has no first lane, a zero popcount and no votes.
    #[test]
    fn test_ballot_result_empty_mask() {
        let result = BallotResult {
            mask: Uniform::from_const(0),
        };
        assert_eq!(result.first_lane(), None);
        assert_eq!(result.popcount().get(), 0);
        assert!(!result.lane_voted(LaneId::new(0)).get());
        assert!(!result.lane_voted(LaneId::new(31)).get());
    }

    /// Queries on a mask with lanes 1, 3, 5 and 7 set.
    #[test]
    fn test_ballot_result() {
        let result = BallotResult {
            mask: Uniform::from_const(0b1010_1010),
        };
        assert!(!result.lane_voted(LaneId::new(0)).get());
        assert!(result.lane_voted(LaneId::new(1)).get());
        assert_eq!(result.popcount().get(), 4);
        assert_eq!(result.first_lane(), Some(LaneId::new(1)));
    }

    /// Compile-level check: the collectives are callable on `Warp<All>`.
    #[test]
    fn test_shuffle_only_on_all() {
        let all: Warp<All> = Warp::new();
        let data = PerLane::new(42i32);
        let _shuffled = all.shuffle_xor(data, 1);
        let _reduced = all.reduce_sum(data);
    }

    /// 64-bit payloads pass through the shuffle unchanged in this build, and
    /// a reduction over the value 1 yields 32.
    #[test]
    fn test_shuffle_64bit_types() {
        let all: Warp<All> = Warp::new();

        let data_i64 = PerLane::new(0x0000_0001_0000_0002_i64);
        let shuffled_i64 = all.shuffle_xor(data_i64, 1);
        assert_eq!(shuffled_i64.get(), 0x0000_0001_0000_0002_i64);

        let data_u64 = PerLane::new(u64::MAX);
        let shuffled_u64 = all.shuffle_xor(data_u64, 1);
        assert_eq!(shuffled_u64.get(), u64::MAX);

        // The literal is an arbitrary test value, not an approximation of PI.
        #[allow(clippy::approx_constant)]
        let data_f64 = PerLane::new(3.14159_f64);
        let shuffled_f64 = all.shuffle_xor(data_f64, 1);
        #[allow(clippy::approx_constant)]
        let expected_f64 = 3.14159_f64;
        assert_eq!(shuffled_f64.get(), expected_f64);

        let ones_i64 = PerLane::new(1_i64);
        let sum = all.reduce_sum(ones_i64);
        assert_eq!(sum.get(), 32_i64);
    }

    /// XOR permutations undo themselves.
    #[test]
    fn test_xor_self_dual() {
        assert!(Xor::<5>::is_self_dual());

        // Exercise the actual `Xor` implementation for a spread of masks,
        // including one with bits above the five lane bits.
        assert_involution::<Xor<0>>();
        assert_involution::<Xor<1>>();
        assert_involution::<Xor<5>>();
        assert_involution::<Xor<16>>();
        assert_involution::<Xor<31>>();
        assert_involution::<Xor<0x25>>();

        // The underlying arithmetic identity, checked for every mask.
        for mask in 0..32u32 {
            for lane in 0..32u32 {
                let after_two = (((lane ^ mask) & 0x1F) ^ mask) & 0x1F;
                assert_eq!(after_two, lane);
            }
        }
    }

    /// Rotating down and then up by the same amount is the identity, and
    /// each rotation's `inverse` undoes its own `forward`.
    #[test]
    fn test_rotate_duality() {
        for lane in 0..32u32 {
            let down_then_up = RotateUp::<1>::forward(RotateDown::<1>::forward(lane));
            assert_eq!(down_then_up, lane);

            let up_then_down = RotateDown::<1>::forward(RotateUp::<1>::forward(lane));
            assert_eq!(up_then_down, lane);

            assert_eq!(RotateDown::<1>::inverse(RotateDown::<1>::forward(lane)), lane);
            assert_eq!(RotateUp::<1>::inverse(RotateUp::<1>::forward(lane)), lane);
        }
    }

    /// Shuffling twice by a self-dual permutation restores the input.
    #[test]
    fn test_shuffle_roundtrip() {
        let original: [i32; 32] = core::array::from_fn(|i| i as i32);
        let shuffled = shuffle_by(original, Xor::<5>);
        // Slot `i` must hold the element that started at `i ^ 5`.
        for (i, value) in shuffled.iter().enumerate() {
            assert_eq!(*value, original[i ^ 5]);
        }
        let unshuffled = shuffle_by(shuffled, Xor::<5>);
        assert_eq!(unshuffled, original);
    }

    /// Shuffling by a rotation and then by its dual restores the input.
    #[test]
    fn test_shuffle_rotate_roundtrip() {
        let original: [i32; 32] = core::array::from_fn(|i| i as i32);
        let rotated = shuffle_by(original, RotateDown::<3>);
        let restored = shuffle_by(rotated, RotateUp::<3>);
        assert_eq!(restored, original);
    }

    /// The five butterfly stages together flip all five lane bits.
    #[test]
    fn test_butterfly_permutation() {
        for i in 0..32u32 {
            assert_eq!(FullButterfly::forward(i), i ^ 31);
        }
    }

    /// Nesting order of `Compose` does not change the resulting mapping.
    #[test]
    fn test_compose_associative() {
        for i in 0..32u32 {
            let ab_c = Compose::<Compose<Xor<3>, Xor<5>>, Xor<7>>::forward(i);
            let a_bc = Compose::<Xor<3>, Compose<Xor<5>, Xor<7>>>::forward(i);
            assert_eq!(ab_c, a_bc);
        }
    }
}