// warp_types/tile.rs
1//! Cooperative Groups: thread block tiles with typed shuffle safety.
2//!
3//! NVIDIA Cooperative Groups (CUDA 9.0+) partition warps into tiles of
4//! 4, 8, 16, or 32 threads. Each tile supports collective operations
5//! (shuffle, ballot, reduce) within its own lanes.
6//!
7//! # Key Difference from Divergence
8//!
9//! Diverged sub-warps (`Warp<Even>`) have inactive lanes — shuffle is unsafe.
10//! Tiles are *partitions* — ALL threads within a tile are active by construction.
11//! Shuffle within a tile is always safe because every lane participates.
12//!
13//! ```text
14//! Warp (32 lanes)
15//! ├── Tile<16> (lanes 0-15)   ← shuffle safe within tile
16//! └── Tile<16> (lanes 16-31)  ← shuffle safe within tile
17//!     ├── Tile<8> (lanes 16-23)
18//!     └── Tile<8> (lanes 24-31)
19//! ```
20//!
21//! # Type System Guarantee
22//!
23//! `Tile<N>` has `shuffle_xor` for any N — because all N lanes participate.
24//! This is unlike `Warp<S>` where only `Warp<All>` has shuffle.
25//! The safety comes from the partition structure, not the active set.
26
27use crate::active_set::{sealed, All};
28use crate::data::PerLane;
29use crate::gpu::GpuShuffle;
30use crate::warp::Warp;
31use crate::GpuValue;
32use core::marker::PhantomData;
33
/// A thread block tile of `SIZE` threads.
///
/// All threads within a tile are guaranteed active — shuffle is always safe.
/// Created by partitioning a `Warp<All>` via `warp.tile::<N>()`, or by
/// sub-partitioning a larger tile (e.g. [`Tile::<32>::partition_16`]).
///
/// Zero-sized compile-time witness: it carries no runtime state. The private
/// `PhantomData` field prevents construction outside this module, so a
/// `Tile<SIZE>` can only come from the partitioning APIs above.
///
/// # Supported Sizes
///
/// 4, 8, 16, 32 — matching NVIDIA's cooperative groups API.
/// Only power-of-two sizes that divide 32 are valid; this is enforced by the
/// sealed `ValidTileSize` bound on `Warp::tile` and the tile methods.
#[must_use = "a Tile represents a partitioned warp — dropping it silently discards the partition"]
pub struct Tile<const SIZE: usize> {
    _phantom: PhantomData<()>,
}
47
/// Marker trait for valid tile sizes (powers of 2 that divide 32).
///
/// Sealed — only implemented for `Tile<4>`, `Tile<8>`, `Tile<16>`, `Tile<32>`.
/// External crates cannot implement this for arbitrary sizes.
pub trait ValidTileSize: sealed::Sealed {
    /// Lane mask with the low `SIZE` bits set (e.g. `0xFF` for `Tile<8>`).
    ///
    /// NOTE(review): this is one constant per tile size — it describes the
    /// lowest tile segment only, not the segment of the calling thread.
    /// A position-dependent mask would require the runtime lane id.
    const TILE_MASK: u32;
}
56
57#[allow(private_interfaces)]
58impl sealed::Sealed for Tile<4> {
59    fn _sealed() -> sealed::SealToken {
60        sealed::SealToken
61    }
62}
63#[allow(private_interfaces)]
64impl sealed::Sealed for Tile<8> {
65    fn _sealed() -> sealed::SealToken {
66        sealed::SealToken
67    }
68}
69#[allow(private_interfaces)]
70impl sealed::Sealed for Tile<16> {
71    fn _sealed() -> sealed::SealToken {
72        sealed::SealToken
73    }
74}
75#[allow(private_interfaces)]
76impl sealed::Sealed for Tile<32> {
77    fn _sealed() -> sealed::SealToken {
78        sealed::SealToken
79    }
80}
81
82impl ValidTileSize for Tile<4> {
83    const TILE_MASK: u32 = 0xF; // 4 lanes
84}
85impl ValidTileSize for Tile<8> {
86    const TILE_MASK: u32 = 0xFF; // 8 lanes
87}
88impl ValidTileSize for Tile<16> {
89    const TILE_MASK: u32 = 0xFFFF; // 16 lanes
90}
91impl ValidTileSize for Tile<32> {
92    const TILE_MASK: u32 = 0xFFFFFFFF; // 32 lanes = full warp
93}
94
95impl Warp<All> {
96    /// Partition the warp into tiles of `SIZE` threads.
97    ///
98    /// Equivalent to `cg::tiled_partition<SIZE>(cg::this_thread_block())`.
99    ///
100    /// Each thread gets a `Tile<SIZE>` representing its local tile.
101    /// All tiles have exactly `SIZE` active lanes — shuffle is safe.
102    ///
103    /// ```
104    /// use warp_types::*;
105    /// use warp_types::tile::Tile;
106    ///
107    /// let warp: Warp<All> = Warp::kernel_entry();
108    /// let tile: Tile<16> = warp.tile();
109    /// // tile.shuffle_xor is available — all 16 lanes participate
110    /// let data = data::PerLane::new(42i32);
111    /// let _partner = tile.shuffle_xor(data, 1);
112    /// ```
113    ///
114    /// Tiles can only be created from `Warp<All>`:
115    ///
116    /// ```compile_fail
117    /// use warp_types::prelude::*;
118    /// let warp = Warp::kernel_entry();
119    /// let (evens, _odds) = warp.diverge_even_odd();
120    /// let _tile: Tile<16> = evens.tile(); // ERROR: method not found
121    /// ```
122    pub fn tile<const SIZE: usize>(&self) -> Tile<SIZE>
123    where
124        Tile<SIZE>: ValidTileSize,
125    {
126        Tile {
127            _phantom: PhantomData,
128        }
129    }
130}
131
impl<const SIZE: usize> Tile<SIZE>
where
    Tile<SIZE>: ValidTileSize,
{
    /// Shuffle XOR within the tile.
    ///
    /// Each thread exchanges with the thread at `(thread_rank XOR mask)` within
    /// the tile.
    ///
    /// **Always safe**: all `SIZE` threads in the tile participate.
    ///
    /// On GPU: emits `shfl.sync.bfly.b32` with `c = ((32-SIZE)<<8)|0x1F`,
    /// confining the shuffle to SIZE-lane segments.
    ///
    /// # Panics
    ///
    /// Panics if `mask >= SIZE` — an exchange that would leave the tile is
    /// rejected rather than clamped.
    pub fn shuffle_xor<T: GpuValue + GpuShuffle>(&self, data: PerLane<T>, mask: u32) -> PerLane<T> {
        assert!(
            mask < SIZE as u32,
            "shuffle_xor: mask {mask} >= tile SIZE {SIZE}"
        );
        PerLane::new(data.get().gpu_shfl_xor_width(mask, SIZE as u32))
    }

    /// Shuffle down within the tile (confined to tile-sized segments).
    ///
    /// # Panics
    ///
    /// Panics if `delta >= SIZE`.
    pub fn shuffle_down<T: GpuValue + GpuShuffle>(
        &self,
        data: PerLane<T>,
        delta: u32,
    ) -> PerLane<T> {
        assert!(
            delta < SIZE as u32,
            "shuffle_down: delta {delta} >= tile SIZE {SIZE}"
        );
        PerLane::new(data.get().gpu_shfl_down_width(delta, SIZE as u32))
    }

    /// Sum reduction across all tile lanes.
    ///
    /// Uses butterfly reduction with `log2(SIZE)` shuffle-XOR steps
    /// (strides 1, 2, 4, … up to SIZE/2).
    pub fn reduce_sum<T: GpuValue + GpuShuffle + core::ops::Add<Output = T>>(
        &self,
        data: PerLane<T>,
    ) -> T {
        let mut val = data.get();
        let mut stride = 1u32;
        while stride < SIZE as u32 {
            // Each butterfly stage pairs lanes `stride` apart within the
            // SIZE-lane segment and accumulates both halves.
            val = val + val.gpu_shfl_xor_width(stride, SIZE as u32);
            stride *= 2;
        }
        val
    }

    /// Inclusive prefix sum within the tile.
    ///
    /// **WARNING:** Not correct on any target. On CPU, `shfl_up` is identity,
    /// so each stage doubles (result: val × SIZE). On GPU, lanes where
    /// `lane_id < stride` get clamped (own value), doubling instead of
    /// preserving. Needs `if lane_id >= stride` guard (requires `lane_id()`).
    /// Retained for type-system demonstration.
    #[deprecated(
        note = "Not correct on any target — Hillis-Steele without lane_id guard. Use SimWarp for tested scan."
    )]
    pub fn inclusive_sum<T: GpuValue + GpuShuffle + core::ops::Add<Output = T>>(
        &self,
        data: PerLane<T>,
    ) -> PerLane<T> {
        let mut val = data.get();
        let mut stride = 1u32;
        while stride < SIZE as u32 {
            // Hillis–Steele stage: unconditionally adds the value from
            // `stride` lanes below — missing the lane guard (see WARNING).
            let s = val.gpu_shfl_up_width(stride, SIZE as u32);
            val = val + s;
            stride *= 2;
        }
        PerLane::new(val)
    }

    /// Number of threads in this tile (the `SIZE` const generic).
    pub const fn size(&self) -> usize {
        SIZE
    }
}
211
212// ============================================================================
213// Sub-partitioning: Tile<N> → Tile<N/2>, Tile<N/4>, etc.
214// ============================================================================
215
216impl Tile<32> {
217    /// Sub-partition into tiles of 16.
218    pub fn partition_16(&self) -> Tile<16> {
219        Tile {
220            _phantom: PhantomData,
221        }
222    }
223    /// Sub-partition into tiles of 8.
224    pub fn partition_8(&self) -> Tile<8> {
225        Tile {
226            _phantom: PhantomData,
227        }
228    }
229    /// Sub-partition into tiles of 4.
230    pub fn partition_4(&self) -> Tile<4> {
231        Tile {
232            _phantom: PhantomData,
233        }
234    }
235}
236
237impl Tile<16> {
238    /// Sub-partition into tiles of 8.
239    pub fn partition_8(&self) -> Tile<8> {
240        Tile {
241            _phantom: PhantomData,
242        }
243    }
244    /// Sub-partition into tiles of 4.
245    pub fn partition_4(&self) -> Tile<4> {
246        Tile {
247            _phantom: PhantomData,
248        }
249    }
250}
251
252impl Tile<8> {
253    /// Sub-partition into tiles of 4.
254    pub fn partition_4(&self) -> Tile<4> {
255        Tile {
256            _phantom: PhantomData,
257        }
258    }
259}
260
261// ============================================================================
262// Tests
263// ============================================================================
264
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::PerLane;

    #[test]
    fn test_tile_from_warp() {
        // Every legal size is constructible from Warp<All> and reports SIZE.
        let w: Warp<All> = Warp::kernel_entry();
        assert_eq!(w.tile::<32>().size(), 32);
        assert_eq!(w.tile::<16>().size(), 16);
        assert_eq!(w.tile::<8>().size(), 8);
        assert_eq!(w.tile::<4>().size(), 4);
    }

    #[test]
    fn test_tile_shuffle() {
        let w: Warp<All> = Warp::kernel_entry();
        let t: Tile<16> = w.tile();
        // Shuffle inside a tile is always permitted; CPU backend is identity.
        let out = t.shuffle_xor(PerLane::new(42i32), 1);
        assert_eq!(out.get(), 42);
    }

    #[test]
    fn test_tile_reduce() {
        let w: Warp<All> = Warp::kernel_entry();
        let t: Tile<8> = w.tile();
        // Tile<8> runs 3 doubling stages: 1 → 2 → 4 → 8.
        assert_eq!(t.reduce_sum(PerLane::new(1i32)), 8);
    }

    #[test]
    fn test_tile_reduce_32() {
        let w: Warp<All> = Warp::kernel_entry();
        let t: Tile<32> = w.tile();
        assert_eq!(t.reduce_sum(PerLane::new(1i32)), 32);
    }

    #[test]
    fn test_tile_reduce_4() {
        let w: Warp<All> = Warp::kernel_entry();
        let t: Tile<4> = w.tile();
        assert_eq!(t.reduce_sum(PerLane::new(1i32)), 4);
    }

    #[test]
    fn test_tile_sub_partition() {
        // 32 → 16 → 8 → 4 chain of sub-partitions.
        let w: Warp<All> = Warp::kernel_entry();
        let t16 = w.tile::<32>().partition_16();
        let t8 = t16.partition_8();
        let t4 = t8.partition_4();

        assert_eq!(t16.size(), 16);
        assert_eq!(t8.size(), 8);
        assert_eq!(t4.size(), 4);
    }

    #[test]
    fn test_tile_shuffle_64bit() {
        let w: Warp<All> = Warp::kernel_entry();
        let t: Tile<16> = w.tile();
        // 64-bit payloads shuffle too — identity on the CPU backend.
        let out = t.shuffle_xor(PerLane::new(123456789_i64), 1);
        assert_eq!(out.get(), 123456789_i64);
    }

    #[test]
    #[allow(deprecated)]
    fn test_tile_inclusive_sum() {
        let w: Warp<All> = Warp::kernel_entry();
        let t: Tile<8> = w.tile();
        // Known-broken scan: CPU emulation doubles per stage, 1 → 2 → 4 → 8.
        let out = t.inclusive_sum(PerLane::new(1i32));
        assert_eq!(out.get(), 8);
    }
}