1use crate::active_set::{sealed, All};
28use crate::data::PerLane;
29use crate::gpu::GpuShuffle;
30use crate::warp::Warp;
31use crate::GpuValue;
32use core::marker::PhantomData;
33
/// Zero-sized handle for a `SIZE`-lane partition ("tile") of a warp.
///
/// The only field is a `PhantomData<()>`, so a `Tile` carries no runtime
/// data: the tile width lives entirely in the const parameter `SIZE`.
/// Handles are produced by `Warp::<All>::tile` and by the `partition_*`
/// methods on wider tiles.
///
/// `Debug` is derived so the handle can appear in `{:?}` output and in
/// assertion failure messages.
#[derive(Debug)]
#[must_use = "a Tile represents a partitioned warp — dropping it silently discards the partition"]
pub struct Tile<const SIZE: usize> {
    // Private marker field: keeps `Tile` from being constructed outside this module.
    _phantom: PhantomData<()>,
}
47
/// Marker trait identifying which `Tile<SIZE>` instantiations are usable.
///
/// `sealed::Sealed` is a supertrait, so only types that already implement
/// `Sealed` can implement this trait. In this file that is `Tile<4>`,
/// `Tile<8>`, `Tile<16>` and `Tile<32>`.
// NOTE(review): presumably `sealed` is crate-private, which would stop
// downstream crates from adding further sizes — confirm in `active_set`.
pub trait ValidTileSize: sealed::Sealed {
    /// Bitmask associated with the tile width. Every implementation in this
    /// file sets exactly the low `SIZE` bits (e.g. `0xF` for `Tile<4>`).
    const TILE_MASK: u32;
}
56
57#[allow(private_interfaces)]
58impl sealed::Sealed for Tile<4> {
59 fn _sealed() -> sealed::SealToken {
60 sealed::SealToken
61 }
62}
63#[allow(private_interfaces)]
64impl sealed::Sealed for Tile<8> {
65 fn _sealed() -> sealed::SealToken {
66 sealed::SealToken
67 }
68}
69#[allow(private_interfaces)]
70impl sealed::Sealed for Tile<16> {
71 fn _sealed() -> sealed::SealToken {
72 sealed::SealToken
73 }
74}
75#[allow(private_interfaces)]
76impl sealed::Sealed for Tile<32> {
77 fn _sealed() -> sealed::SealToken {
78 sealed::SealToken
79 }
80}
81
82impl ValidTileSize for Tile<4> {
83 const TILE_MASK: u32 = 0xF; }
85impl ValidTileSize for Tile<8> {
86 const TILE_MASK: u32 = 0xFF; }
88impl ValidTileSize for Tile<16> {
89 const TILE_MASK: u32 = 0xFFFF; }
91impl ValidTileSize for Tile<32> {
92 const TILE_MASK: u32 = 0xFFFFFFFF; }
94
95impl Warp<All> {
96 pub fn tile<const SIZE: usize>(&self) -> Tile<SIZE>
123 where
124 Tile<SIZE>: ValidTileSize,
125 {
126 Tile {
127 _phantom: PhantomData,
128 }
129 }
130}
131
132impl<const SIZE: usize> Tile<SIZE>
133where
134 Tile<SIZE>: ValidTileSize,
135{
136 pub fn shuffle_xor<T: GpuValue + GpuShuffle>(&self, data: PerLane<T>, mask: u32) -> PerLane<T> {
146 assert!(
147 mask < SIZE as u32,
148 "shuffle_xor: mask {mask} >= tile SIZE {SIZE}"
149 );
150 PerLane::new(data.get().gpu_shfl_xor_width(mask, SIZE as u32))
151 }
152
153 pub fn shuffle_down<T: GpuValue + GpuShuffle>(
155 &self,
156 data: PerLane<T>,
157 delta: u32,
158 ) -> PerLane<T> {
159 assert!(
160 delta < SIZE as u32,
161 "shuffle_down: delta {delta} >= tile SIZE {SIZE}"
162 );
163 PerLane::new(data.get().gpu_shfl_down_width(delta, SIZE as u32))
164 }
165
166 pub fn reduce_sum<T: GpuValue + GpuShuffle + core::ops::Add<Output = T>>(
170 &self,
171 data: PerLane<T>,
172 ) -> T {
173 let mut val = data.get();
174 let mut stride = 1u32;
175 while stride < SIZE as u32 {
176 val = val + val.gpu_shfl_xor_width(stride, SIZE as u32);
177 stride *= 2;
178 }
179 val
180 }
181
182 #[deprecated(
190 note = "Not correct on any target — Hillis-Steele without lane_id guard. Use SimWarp for tested scan."
191 )]
192 pub fn inclusive_sum<T: GpuValue + GpuShuffle + core::ops::Add<Output = T>>(
193 &self,
194 data: PerLane<T>,
195 ) -> PerLane<T> {
196 let mut val = data.get();
197 let mut stride = 1u32;
198 while stride < SIZE as u32 {
199 let s = val.gpu_shfl_up_width(stride, SIZE as u32);
200 val = val + s;
201 stride *= 2;
202 }
203 PerLane::new(val)
204 }
205
206 pub const fn size(&self) -> usize {
208 SIZE
209 }
210}
211
212impl Tile<32> {
217 pub fn partition_16(&self) -> Tile<16> {
219 Tile {
220 _phantom: PhantomData,
221 }
222 }
223 pub fn partition_8(&self) -> Tile<8> {
225 Tile {
226 _phantom: PhantomData,
227 }
228 }
229 pub fn partition_4(&self) -> Tile<4> {
231 Tile {
232 _phantom: PhantomData,
233 }
234 }
235}
236
237impl Tile<16> {
238 pub fn partition_8(&self) -> Tile<8> {
240 Tile {
241 _phantom: PhantomData,
242 }
243 }
244 pub fn partition_4(&self) -> Tile<4> {
246 Tile {
247 _phantom: PhantomData,
248 }
249 }
250}
251
252impl Tile<8> {
253 pub fn partition_4(&self) -> Tile<4> {
255 Tile {
256 _phantom: PhantomData,
257 }
258 }
259}
260
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::PerLane;

    // NOTE(review): every test feeds the same value to all lanes. The expected
    // results assume the shuffle implementation used in test builds hands back
    // that same value — confirm against the `GpuShuffle` host implementation.

    /// Builds the `Warp<All>` that every test starts from.
    fn entry_warp() -> Warp<All> {
        Warp::kernel_entry()
    }

    /// Reduces a `SIZE`-lane tile in which the input value is `1`.
    fn sum_of_ones<const SIZE: usize>() -> i32
    where
        Tile<SIZE>: ValidTileSize,
    {
        entry_warp().tile::<SIZE>().reduce_sum(PerLane::new(1i32))
    }

    // Every supported width can be requested from the warp and reports its size.
    #[test]
    fn test_tile_from_warp() {
        let warp = entry_warp();

        assert_eq!(warp.tile::<32>().size(), 32);
        assert_eq!(warp.tile::<16>().size(), 16);
        assert_eq!(warp.tile::<8>().size(), 8);
        assert_eq!(warp.tile::<4>().size(), 4);
    }

    // XOR shuffle of a 32-bit value.
    #[test]
    fn test_tile_shuffle() {
        let tile = entry_warp().tile::<16>();
        let shuffled = tile.shuffle_xor(PerLane::new(42i32), 1);
        assert_eq!(shuffled.get(), 42);
    }

    // Summing 1 over a tile yields the tile width.
    #[test]
    fn test_tile_reduce() {
        assert_eq!(sum_of_ones::<8>(), 8);
    }

    #[test]
    fn test_tile_reduce_32() {
        assert_eq!(sum_of_ones::<32>(), 32);
    }

    #[test]
    fn test_tile_reduce_4() {
        assert_eq!(sum_of_ones::<4>(), 4);
    }

    // 32 -> 16 -> 8 -> 4: each partition step reports the narrower width.
    #[test]
    fn test_tile_sub_partition() {
        let t32 = entry_warp().tile::<32>();
        let t16 = t32.partition_16();
        let t8 = t16.partition_8();
        let t4 = t8.partition_4();

        assert_eq!(t16.size(), 16);
        assert_eq!(t8.size(), 8);
        assert_eq!(t4.size(), 4);
    }

    // XOR shuffle of a 64-bit value.
    #[test]
    fn test_tile_shuffle_64bit() {
        let tile = entry_warp().tile::<16>();
        let shuffled = tile.shuffle_xor(PerLane::new(123456789_i64), 1);
        assert_eq!(shuffled.get(), 123456789_i64);
    }

    // Pins the current output of the deprecated scan; the deprecation note
    // says the scan is not correct, so this is not evidence of correctness.
    #[test]
    #[allow(deprecated)]
    fn test_tile_inclusive_sum() {
        let tile = entry_warp().tile::<8>();
        let scanned = tile.inclusive_sum(PerLane::new(1i32));
        assert_eq!(scanned.get(), 8);
    }
}