// warp_types/tile.rs
//! Cooperative Groups: thread block tiles with typed shuffle safety.
//!
//! NVIDIA Cooperative Groups (CUDA 9.0+) partition warps into tiles of
//! 4, 8, 16, or 32 threads. Each tile supports collective operations
//! (shuffle, ballot, reduce) within its own lanes.
//!
//! # Key Difference from Divergence
//!
//! Diverged sub-warps (`Warp<Even>`) have inactive lanes — shuffle is unsafe.
//! Tiles are *partitions* — ALL threads within a tile are active by construction.
//! Shuffle within a tile is always safe because every lane participates.
//!
//! ```text
//! Warp (32 lanes)
//! ├── Tile<16> (lanes 0-15)  ← shuffle safe within tile
//! └── Tile<16> (lanes 16-31) ← shuffle safe within tile
//!     ├── Tile<8> (lanes 16-23)
//!     └── Tile<8> (lanes 24-31)
//! ```
//!
//! # Type System Guarantee
//!
//! `Tile<N>` has `shuffle_xor` for any N — because all N lanes participate.
//! This is unlike `Warp<S>` where only `Warp<All>` has shuffle.
//! The safety comes from the partition structure, not the active set.

27use crate::active_set::{sealed, All};
28use crate::data::PerLane;
29use crate::gpu::GpuShuffle;
30use crate::warp::Warp;
31use crate::GpuValue;
32use core::marker::PhantomData;
33
/// A thread block tile of `SIZE` threads.
///
/// All threads within a tile are guaranteed active — shuffle is always safe.
/// Created by partitioning a `Warp<All>` via `warp.tile::<N>()`.
///
/// # Supported Sizes
///
/// 4, 8, 16, 32 — matching NVIDIA's cooperative groups API.
/// Only power-of-two sizes that divide 32 are valid.
///
/// This is a zero-sized marker type: it carries no runtime data, only the
/// compile-time `SIZE` and the proof of how it was constructed.
pub struct Tile<const SIZE: usize> {
    // No data needed — the const parameter alone encodes the tile size.
    _phantom: PhantomData<()>,
}
46
/// Marker trait for valid tile sizes (powers of 2 that divide 32).
///
/// Sealed — only implemented for Tile<4>, Tile<8>, Tile<16>, Tile<32>.
/// External crates cannot implement this for arbitrary sizes.
pub trait ValidTileSize: sealed::Sealed {
    /// Mask for this tile within a warp (based on thread position).
    ///
    /// NOTE(review): every provided impl sets the *low* `SIZE` bits
    /// (e.g. `0xFF` for `Tile<8>`), i.e. the constant is NOT shifted to a
    /// tile's actual position within the warp — confirm whether callers
    /// are expected to rotate/shift it by the tile's base lane.
    const TILE_MASK: u32;
}
55
56#[allow(private_interfaces)]
57impl sealed::Sealed for Tile<4> {
58 fn _sealed() -> sealed::SealToken {
59 sealed::SealToken
60 }
61}
62#[allow(private_interfaces)]
63impl sealed::Sealed for Tile<8> {
64 fn _sealed() -> sealed::SealToken {
65 sealed::SealToken
66 }
67}
68#[allow(private_interfaces)]
69impl sealed::Sealed for Tile<16> {
70 fn _sealed() -> sealed::SealToken {
71 sealed::SealToken
72 }
73}
74#[allow(private_interfaces)]
75impl sealed::Sealed for Tile<32> {
76 fn _sealed() -> sealed::SealToken {
77 sealed::SealToken
78 }
79}
80
81impl ValidTileSize for Tile<4> {
82 const TILE_MASK: u32 = 0xF; // 4 lanes
83}
84impl ValidTileSize for Tile<8> {
85 const TILE_MASK: u32 = 0xFF; // 8 lanes
86}
87impl ValidTileSize for Tile<16> {
88 const TILE_MASK: u32 = 0xFFFF; // 16 lanes
89}
90impl ValidTileSize for Tile<32> {
91 const TILE_MASK: u32 = 0xFFFFFFFF; // 32 lanes = full warp
92}
93
94impl Warp<All> {
95 /// Partition the warp into tiles of `SIZE` threads.
96 ///
97 /// Equivalent to `cg::tiled_partition<SIZE>(cg::this_thread_block())`.
98 ///
99 /// Each thread gets a `Tile<SIZE>` representing its local tile.
100 /// All tiles have exactly `SIZE` active lanes — shuffle is safe.
101 ///
102 /// ```
103 /// use warp_types::*;
104 /// use warp_types::tile::Tile;
105 ///
106 /// let warp: Warp<All> = Warp::kernel_entry();
107 /// let tile: Tile<16> = warp.tile();
108 /// // tile.shuffle_xor is available — all 16 lanes participate
109 /// let data = data::PerLane::new(42i32);
110 /// let _partner = tile.shuffle_xor(data, 1);
111 /// ```
112 ///
113 /// Tiles can only be created from `Warp<All>`:
114 ///
115 /// ```compile_fail
116 /// use warp_types::prelude::*;
117 /// let warp = Warp::kernel_entry();
118 /// let (evens, _odds) = warp.diverge_even_odd();
119 /// let _tile: Tile<16> = evens.tile(); // ERROR: method not found
120 /// ```
121 pub fn tile<const SIZE: usize>(&self) -> Tile<SIZE>
122 where
123 Tile<SIZE>: ValidTileSize,
124 {
125 Tile {
126 _phantom: PhantomData,
127 }
128 }
129}
130
131impl<const SIZE: usize> Tile<SIZE>
132where
133 Tile<SIZE>: ValidTileSize,
134{
135 /// Shuffle XOR within the tile.
136 ///
137 /// Each thread exchanges with the thread at `(thread_rank XOR mask)` within
138 /// the tile. Caller must ensure mask < SIZE (no automatic clamping).
139 ///
140 /// **Always safe**: all `SIZE` threads in the tile participate.
141 ///
142 /// On GPU: emits `shfl.sync.bfly.b32` with `c = ((32-SIZE)<<8)|0x1F`,
143 /// confining the shuffle to SIZE-lane segments.
144 pub fn shuffle_xor<T: GpuValue + GpuShuffle>(&self, data: PerLane<T>, mask: u32) -> PerLane<T> {
145 debug_assert!(
146 mask < SIZE as u32,
147 "shuffle_xor: mask {mask} >= tile SIZE {SIZE}"
148 );
149 PerLane::new(data.get().gpu_shfl_xor_width(mask, SIZE as u32))
150 }
151
152 /// Shuffle down within the tile (confined to tile-sized segments).
153 pub fn shuffle_down<T: GpuValue + GpuShuffle>(
154 &self,
155 data: PerLane<T>,
156 delta: u32,
157 ) -> PerLane<T> {
158 debug_assert!(
159 delta < SIZE as u32,
160 "shuffle_down: delta {delta} >= tile SIZE {SIZE}"
161 );
162 PerLane::new(data.get().gpu_shfl_down_width(delta, SIZE as u32))
163 }
164
165 /// Sum reduction across all tile lanes.
166 ///
167 /// Uses butterfly reduction with `log2(SIZE)` shuffle-XOR steps.
168 pub fn reduce_sum<T: GpuValue + GpuShuffle + core::ops::Add<Output = T>>(
169 &self,
170 data: PerLane<T>,
171 ) -> T {
172 let mut val = data.get();
173 let mut stride = 1u32;
174 while stride < SIZE as u32 {
175 val = val + val.gpu_shfl_xor_width(stride, SIZE as u32);
176 stride *= 2;
177 }
178 val
179 }
180
181 /// Inclusive prefix sum within the tile.
182 ///
183 /// **WARNING:** Not correct on any target. On CPU, `shfl_up` is identity,
184 /// so each stage doubles (result: val × SIZE). On GPU, lanes where
185 /// `lane_id < stride` get clamped (own value), doubling instead of
186 /// preserving. Needs `if lane_id >= stride` guard (requires `lane_id()`).
187 /// Retained for type-system demonstration.
188 pub fn inclusive_sum<T: GpuValue + GpuShuffle + core::ops::Add<Output = T>>(
189 &self,
190 data: PerLane<T>,
191 ) -> PerLane<T> {
192 let mut val = data.get();
193 let mut stride = 1u32;
194 while stride < SIZE as u32 {
195 let s = val.gpu_shfl_up_width(stride, SIZE as u32);
196 val = val + s;
197 stride *= 2;
198 }
199 PerLane::new(val)
200 }
201
202 /// Number of threads in this tile.
203 pub const fn size(&self) -> usize {
204 SIZE
205 }
206}
207
208// ============================================================================
209// Sub-partitioning: Tile<N> → Tile<N/2>, Tile<N/4>, etc.
210// ============================================================================
211
212impl Tile<32> {
213 /// Sub-partition into tiles of 16.
214 pub fn partition_16(&self) -> Tile<16> {
215 Tile {
216 _phantom: PhantomData,
217 }
218 }
219 /// Sub-partition into tiles of 8.
220 pub fn partition_8(&self) -> Tile<8> {
221 Tile {
222 _phantom: PhantomData,
223 }
224 }
225 /// Sub-partition into tiles of 4.
226 pub fn partition_4(&self) -> Tile<4> {
227 Tile {
228 _phantom: PhantomData,
229 }
230 }
231}
232
233impl Tile<16> {
234 /// Sub-partition into tiles of 8.
235 pub fn partition_8(&self) -> Tile<8> {
236 Tile {
237 _phantom: PhantomData,
238 }
239 }
240 /// Sub-partition into tiles of 4.
241 pub fn partition_4(&self) -> Tile<4> {
242 Tile {
243 _phantom: PhantomData,
244 }
245 }
246}
247
248impl Tile<8> {
249 /// Sub-partition into tiles of 4.
250 pub fn partition_4(&self) -> Tile<4> {
251 Tile {
252 _phantom: PhantomData,
253 }
254 }
255}
256
257// ============================================================================
258// Tests
259// ============================================================================
260
261#[cfg(test)]
262mod tests {
263 use super::*;
264 use crate::data::PerLane;
265
266 #[test]
267 fn test_tile_from_warp() {
268 let warp: Warp<All> = Warp::kernel_entry();
269 let tile32: Tile<32> = warp.tile();
270 let tile16: Tile<16> = warp.tile();
271 let tile8: Tile<8> = warp.tile();
272 let tile4: Tile<4> = warp.tile();
273
274 assert_eq!(tile32.size(), 32);
275 assert_eq!(tile16.size(), 16);
276 assert_eq!(tile8.size(), 8);
277 assert_eq!(tile4.size(), 4);
278 }
279
280 #[test]
281 fn test_tile_shuffle() {
282 let warp: Warp<All> = Warp::kernel_entry();
283 let tile: Tile<16> = warp.tile();
284 let data = PerLane::new(42i32);
285
286 // Shuffle within tile — always safe
287 let result = tile.shuffle_xor(data, 1);
288 assert_eq!(result.get(), 42); // CPU identity
289 }
290
291 #[test]
292 fn test_tile_reduce() {
293 let warp: Warp<All> = Warp::kernel_entry();
294 let tile: Tile<8> = warp.tile();
295 let data = PerLane::new(1i32);
296
297 // Reduce: 1 + 1 = 2, 2 + 2 = 4, 4 + 4 = 8 (3 stages for tile<8>)
298 let sum = tile.reduce_sum(data);
299 assert_eq!(sum, 8);
300 }
301
302 #[test]
303 fn test_tile_reduce_32() {
304 let warp: Warp<All> = Warp::kernel_entry();
305 let tile: Tile<32> = warp.tile();
306 let data = PerLane::new(1i32);
307 let sum = tile.reduce_sum(data);
308 assert_eq!(sum, 32);
309 }
310
311 #[test]
312 fn test_tile_reduce_4() {
313 let warp: Warp<All> = Warp::kernel_entry();
314 let tile: Tile<4> = warp.tile();
315 let data = PerLane::new(1i32);
316 let sum = tile.reduce_sum(data);
317 assert_eq!(sum, 4);
318 }
319
320 #[test]
321 fn test_tile_sub_partition() {
322 let warp: Warp<All> = Warp::kernel_entry();
323 let t32: Tile<32> = warp.tile();
324 let t16 = t32.partition_16();
325 let t8 = t16.partition_8();
326 let t4 = t8.partition_4();
327
328 assert_eq!(t16.size(), 16);
329 assert_eq!(t8.size(), 8);
330 assert_eq!(t4.size(), 4);
331 }
332
333 #[test]
334 fn test_tile_shuffle_64bit() {
335 let warp: Warp<All> = Warp::kernel_entry();
336 let tile: Tile<16> = warp.tile();
337 let data = PerLane::new(123456789_i64);
338
339 // 64-bit shuffle within tile — two-pass on GPU, identity on CPU
340 let result = tile.shuffle_xor(data, 1);
341 assert_eq!(result.get(), 123456789_i64);
342 }
343
344 #[test]
345 fn test_tile_inclusive_sum() {
346 let warp: Warp<All> = Warp::kernel_entry();
347 let tile: Tile<8> = warp.tile();
348 let data = PerLane::new(1i32);
349 let result = tile.inclusive_sum(data);
350 // CPU emulation: 1 + 1 = 2, 2 + 2 = 4, 4 + 4 = 8
351 assert_eq!(result.get(), 8);
352 }
353}