// warp_types/tile.rs
//! Cooperative Groups: thread block tiles with typed shuffle safety.
//!
//! NVIDIA Cooperative Groups (CUDA 9.0+) partition warps into tiles of
//! 4, 8, 16, or 32 threads. Each tile supports collective operations
//! (shuffle, ballot, reduce) within its own lanes.
//!
//! # Key Difference from Divergence
//!
//! Diverged sub-warps (`Warp<Even>`) have inactive lanes — shuffle is unsafe.
//! Tiles are *partitions* — ALL threads within a tile are active by construction.
//! Shuffle within a tile is always safe because every lane participates.
//!
//! ```text
//! Warp (32 lanes)
//! ├── Tile<16> (lanes 0-15)   ← shuffle safe within tile
//! └── Tile<16> (lanes 16-31)  ← shuffle safe within tile
//!     ├── Tile<8> (lanes 16-23)
//!     └── Tile<8> (lanes 24-31)
//! ```
//!
//! # Type System Guarantee
//!
//! `Tile<N>` has `shuffle_xor` for any N — because all N lanes participate.
//! This is unlike `Warp<S>` where only `Warp<All>` has shuffle.
//! The safety comes from the partition structure, not the active set.
use crate::active_set::{sealed, All};
use crate::data::PerLane;
use crate::gpu::GpuShuffle;
use crate::warp::Warp;
use crate::GpuValue;
use core::marker::PhantomData;
33
/// A thread block tile of `SIZE` threads.
///
/// All threads within a tile are guaranteed active — shuffle is always safe.
/// Created by partitioning a `Warp<All>` via `warp.tile::<N>()`.
///
/// A tile is a zero-sized capability token: copying one grants no power that
/// another call to `warp.tile()` would not, so `Clone`/`Copy` are derived.
///
/// # Supported Sizes
///
/// 4, 8, 16, 32 — matching NVIDIA's cooperative groups API.
/// Only power-of-two sizes that divide 32 are valid.
#[derive(Clone, Copy, Debug)]
pub struct Tile<const SIZE: usize> {
    // Zero-sized placeholder; keeps the struct non-constructible outside
    // this module (field is private) while carrying no runtime data.
    _phantom: PhantomData<()>,
}
46
47/// Marker trait for valid tile sizes (powers of 2 that divide 32).
48///
49/// Sealed — only implemented for Tile<4>, Tile<8>, Tile<16>, Tile<32>.
50/// External crates cannot implement this for arbitrary sizes.
51pub trait ValidTileSize: sealed::Sealed {
52    /// Mask for this tile within a warp (based on thread position).
53    const TILE_MASK: u32;
54}
55
56#[allow(private_interfaces)]
57impl sealed::Sealed for Tile<4> {
58    fn _sealed() -> sealed::SealToken {
59        sealed::SealToken
60    }
61}
62#[allow(private_interfaces)]
63impl sealed::Sealed for Tile<8> {
64    fn _sealed() -> sealed::SealToken {
65        sealed::SealToken
66    }
67}
68#[allow(private_interfaces)]
69impl sealed::Sealed for Tile<16> {
70    fn _sealed() -> sealed::SealToken {
71        sealed::SealToken
72    }
73}
74#[allow(private_interfaces)]
75impl sealed::Sealed for Tile<32> {
76    fn _sealed() -> sealed::SealToken {
77        sealed::SealToken
78    }
79}
80
81impl ValidTileSize for Tile<4> {
82    const TILE_MASK: u32 = 0xF; // 4 lanes
83}
84impl ValidTileSize for Tile<8> {
85    const TILE_MASK: u32 = 0xFF; // 8 lanes
86}
87impl ValidTileSize for Tile<16> {
88    const TILE_MASK: u32 = 0xFFFF; // 16 lanes
89}
90impl ValidTileSize for Tile<32> {
91    const TILE_MASK: u32 = 0xFFFFFFFF; // 32 lanes = full warp
92}
93
94impl Warp<All> {
95    /// Partition the warp into tiles of `SIZE` threads.
96    ///
97    /// Equivalent to `cg::tiled_partition<SIZE>(cg::this_thread_block())`.
98    ///
99    /// Each thread gets a `Tile<SIZE>` representing its local tile.
100    /// All tiles have exactly `SIZE` active lanes — shuffle is safe.
101    ///
102    /// ```
103    /// use warp_types::*;
104    /// use warp_types::tile::Tile;
105    ///
106    /// let warp: Warp<All> = Warp::kernel_entry();
107    /// let tile: Tile<16> = warp.tile();
108    /// // tile.shuffle_xor is available — all 16 lanes participate
109    /// let data = data::PerLane::new(42i32);
110    /// let _partner = tile.shuffle_xor(data, 1);
111    /// ```
112    ///
113    /// Tiles can only be created from `Warp<All>`:
114    ///
115    /// ```compile_fail
116    /// use warp_types::prelude::*;
117    /// let warp = Warp::kernel_entry();
118    /// let (evens, _odds) = warp.diverge_even_odd();
119    /// let _tile: Tile<16> = evens.tile(); // ERROR: method not found
120    /// ```
121    pub fn tile<const SIZE: usize>(&self) -> Tile<SIZE>
122    where
123        Tile<SIZE>: ValidTileSize,
124    {
125        Tile {
126            _phantom: PhantomData,
127        }
128    }
129}
130
131impl<const SIZE: usize> Tile<SIZE>
132where
133    Tile<SIZE>: ValidTileSize,
134{
135    /// Shuffle XOR within the tile.
136    ///
137    /// Each thread exchanges with the thread at `(thread_rank XOR mask)` within
138    /// the tile. Caller must ensure mask < SIZE (no automatic clamping).
139    ///
140    /// **Always safe**: all `SIZE` threads in the tile participate.
141    ///
142    /// On GPU: emits `shfl.sync.bfly.b32` with `c = ((32-SIZE)<<8)|0x1F`,
143    /// confining the shuffle to SIZE-lane segments.
144    pub fn shuffle_xor<T: GpuValue + GpuShuffle>(&self, data: PerLane<T>, mask: u32) -> PerLane<T> {
145        debug_assert!(
146            mask < SIZE as u32,
147            "shuffle_xor: mask {mask} >= tile SIZE {SIZE}"
148        );
149        PerLane::new(data.get().gpu_shfl_xor_width(mask, SIZE as u32))
150    }
151
152    /// Shuffle down within the tile (confined to tile-sized segments).
153    pub fn shuffle_down<T: GpuValue + GpuShuffle>(
154        &self,
155        data: PerLane<T>,
156        delta: u32,
157    ) -> PerLane<T> {
158        debug_assert!(
159            delta < SIZE as u32,
160            "shuffle_down: delta {delta} >= tile SIZE {SIZE}"
161        );
162        PerLane::new(data.get().gpu_shfl_down_width(delta, SIZE as u32))
163    }
164
165    /// Sum reduction across all tile lanes.
166    ///
167    /// Uses butterfly reduction with `log2(SIZE)` shuffle-XOR steps.
168    pub fn reduce_sum<T: GpuValue + GpuShuffle + core::ops::Add<Output = T>>(
169        &self,
170        data: PerLane<T>,
171    ) -> T {
172        let mut val = data.get();
173        let mut stride = 1u32;
174        while stride < SIZE as u32 {
175            val = val + val.gpu_shfl_xor_width(stride, SIZE as u32);
176            stride *= 2;
177        }
178        val
179    }
180
181    /// Inclusive prefix sum within the tile.
182    ///
183    /// **WARNING:** Not correct on any target. On CPU, `shfl_up` is identity,
184    /// so each stage doubles (result: val × SIZE). On GPU, lanes where
185    /// `lane_id < stride` get clamped (own value), doubling instead of
186    /// preserving. Needs `if lane_id >= stride` guard (requires `lane_id()`).
187    /// Retained for type-system demonstration.
188    pub fn inclusive_sum<T: GpuValue + GpuShuffle + core::ops::Add<Output = T>>(
189        &self,
190        data: PerLane<T>,
191    ) -> PerLane<T> {
192        let mut val = data.get();
193        let mut stride = 1u32;
194        while stride < SIZE as u32 {
195            let s = val.gpu_shfl_up_width(stride, SIZE as u32);
196            val = val + s;
197            stride *= 2;
198        }
199        PerLane::new(val)
200    }
201
202    /// Number of threads in this tile.
203    pub const fn size(&self) -> usize {
204        SIZE
205    }
206}
207
// ============================================================================
// Sub-partitioning: Tile<N> → Tile<N/2>, Tile<N/4>, etc.
// ============================================================================
211
212impl Tile<32> {
213    /// Sub-partition into tiles of 16.
214    pub fn partition_16(&self) -> Tile<16> {
215        Tile {
216            _phantom: PhantomData,
217        }
218    }
219    /// Sub-partition into tiles of 8.
220    pub fn partition_8(&self) -> Tile<8> {
221        Tile {
222            _phantom: PhantomData,
223        }
224    }
225    /// Sub-partition into tiles of 4.
226    pub fn partition_4(&self) -> Tile<4> {
227        Tile {
228            _phantom: PhantomData,
229        }
230    }
231}
232
233impl Tile<16> {
234    /// Sub-partition into tiles of 8.
235    pub fn partition_8(&self) -> Tile<8> {
236        Tile {
237            _phantom: PhantomData,
238        }
239    }
240    /// Sub-partition into tiles of 4.
241    pub fn partition_4(&self) -> Tile<4> {
242        Tile {
243            _phantom: PhantomData,
244        }
245    }
246}
247
248impl Tile<8> {
249    /// Sub-partition into tiles of 4.
250    pub fn partition_4(&self) -> Tile<4> {
251        Tile {
252            _phantom: PhantomData,
253        }
254    }
255}
256
// ============================================================================
// Tests
// ============================================================================
260
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::PerLane;

    #[test]
    fn test_tile_from_warp() {
        let warp: Warp<All> = Warp::kernel_entry();
        // Every valid size reports its own lane count.
        assert_eq!(warp.tile::<32>().size(), 32);
        assert_eq!(warp.tile::<16>().size(), 16);
        assert_eq!(warp.tile::<8>().size(), 8);
        assert_eq!(warp.tile::<4>().size(), 4);
    }

    #[test]
    fn test_tile_shuffle() {
        let warp: Warp<All> = Warp::kernel_entry();
        let tile: Tile<16> = warp.tile();
        // Shuffle within a tile is always safe; CPU emulation is identity.
        let swapped = tile.shuffle_xor(PerLane::new(42i32), 1);
        assert_eq!(swapped.get(), 42);
    }

    #[test]
    fn test_tile_reduce() {
        let warp: Warp<All> = Warp::kernel_entry();
        let tile: Tile<8> = warp.tile();
        // CPU emulation doubles once per stage: 3 stages for Tile<8> → 8.
        assert_eq!(tile.reduce_sum(PerLane::new(1i32)), 8);
    }

    #[test]
    fn test_tile_reduce_32() {
        let warp: Warp<All> = Warp::kernel_entry();
        let tile: Tile<32> = warp.tile();
        assert_eq!(tile.reduce_sum(PerLane::new(1i32)), 32);
    }

    #[test]
    fn test_tile_reduce_4() {
        let warp: Warp<All> = Warp::kernel_entry();
        let tile: Tile<4> = warp.tile();
        assert_eq!(tile.reduce_sum(PerLane::new(1i32)), 4);
    }

    #[test]
    fn test_tile_sub_partition() {
        let warp: Warp<All> = Warp::kernel_entry();
        let t16 = warp.tile::<32>().partition_16();
        let t8 = t16.partition_8();
        let t4 = t8.partition_4();
        assert_eq!((t16.size(), t8.size(), t4.size()), (16, 8, 4));
    }

    #[test]
    fn test_tile_shuffle_64bit() {
        let warp: Warp<All> = Warp::kernel_entry();
        let tile: Tile<16> = warp.tile();
        // 64-bit shuffle within tile — two-pass on GPU, identity on CPU.
        let swapped = tile.shuffle_xor(PerLane::new(123456789_i64), 1);
        assert_eq!(swapped.get(), 123456789_i64);
    }

    #[test]
    fn test_tile_inclusive_sum() {
        let warp: Warp<All> = Warp::kernel_entry();
        let tile: Tile<8> = warp.tile();
        // Known-incorrect CPU emulation doubles per stage: 1 → 2 → 4 → 8.
        let scanned = tile.inclusive_sum(PerLane::new(1i32));
        assert_eq!(scanned.get(), 8);
    }
}
353}