// warp_types/platform.rs
//! Platform abstraction for CPU/GPU unified targeting
//!
//! This module defines the `Platform` trait that allows the same algorithm
//! code to run on either CPU SIMD or GPU warps.
//!
//! # Key Insight
//!
//! CPU SIMD (AVX-512, NEON) and GPU warps are structurally similar:
//! - Both have "lanes" executing in lockstep
//! - Both have shuffle/permute operations
//! - Both benefit from uniform/varying tracking
//!
//! The `Platform` trait abstracts over these, enabling:
//! - Same algorithm code for both targets
//! - Fat binaries with CPU and GPU implementations
//! - Runtime dispatch based on data size or availability

use crate::GpuValue;

/// Platform-specific SIMD vector type
///
/// A value of this type holds `WIDTH` lanes of `T`. Implementations are
/// `Copy`, so lane updates (`insert`) return a modified copy rather than
/// mutating in place.
pub trait SimdVector<T: GpuValue>: Copy {
    /// Number of lanes in this vector
    const WIDTH: usize;

    /// Create a vector with all lanes set to the same value
    fn splat(value: T) -> Self;

    /// Extract value from a specific lane
    ///
    /// Valid lanes are `0..WIDTH`; implementations may debug-assert or
    /// panic on out-of-range indices (see `PortableVector`).
    fn extract(self, lane: usize) -> T;

    /// Insert value into a specific lane, returning the updated vector
    fn insert(self, lane: usize, value: T) -> Self;
}
34
/// A computation platform (CPU SIMD or GPU warp)
///
/// This trait abstracts over the execution model, allowing algorithms
/// to be written once and compiled for multiple targets.
///
/// Out-of-range lane handling in the shuffle operations (wrapping vs.
/// clamping) is platform-defined; see each implementation for its exact
/// convention.
pub trait Platform: Copy + 'static {
    /// Number of parallel lanes
    const WIDTH: usize;

    /// Platform name for debugging
    const NAME: &'static str;

    /// The vector type for this platform (carries `WIDTH` lanes of `T`)
    type Vector<T: GpuValue>: SimdVector<T>;

    /// Mask type for predicated operations (one bit per lane)
    type Mask: Copy;

    // === Core Operations ===

    /// Broadcast a scalar to all lanes
    fn broadcast<T: GpuValue>(value: T) -> Self::Vector<T> {
        Self::Vector::splat(value)
    }

    /// Shuffle: each lane reads from `source[indices[lane]]`
    ///
    /// How indices outside `0..WIDTH` behave is platform-defined.
    fn shuffle<T: GpuValue>(source: Self::Vector<T>, indices: Self::Vector<u32>)
        -> Self::Vector<T>;

    /// Shuffle down: lane i reads from lane i+delta (with wrapping or clamping)
    fn shuffle_down<T: GpuValue>(source: Self::Vector<T>, delta: usize) -> Self::Vector<T>;

    /// Shuffle XOR: lane i reads from lane i^mask
    fn shuffle_xor<T: GpuValue>(source: Self::Vector<T>, mask: usize) -> Self::Vector<T>;

    // === Reductions ===

    /// Sum all lanes, result available in all lanes (uniform)
    fn reduce_sum<T: GpuValue + core::ops::Add<Output = T>>(values: Self::Vector<T>) -> T;

    /// Maximum across all lanes
    fn reduce_max<T: GpuValue + Ord>(values: Self::Vector<T>) -> T;

    /// Minimum across all lanes
    fn reduce_min<T: GpuValue + Ord>(values: Self::Vector<T>) -> T;

    // === Predicates ===

    /// Ballot: collect per-lane bools into a mask (bit i set iff lane i is true)
    fn ballot(predicates: Self::Vector<bool>) -> Self::Mask;

    /// All lanes true?
    fn all(predicates: Self::Vector<bool>) -> bool;

    /// Any lane true?
    fn any(predicates: Self::Vector<bool>) -> bool;

    /// Population count of a mask (number of set lanes)
    fn mask_popcount(mask: Self::Mask) -> u32;
}
94
// ============================================================================
// CPU SIMD Implementation (Portable)
// ============================================================================
98
/// Portable CPU SIMD platform
///
/// Zero-sized marker type; all lane data lives in `PortableVector` values.
///
/// This uses scalar emulation for portability. In a real implementation,
/// this would use std::simd or platform-specific intrinsics (AVX-512, NEON).
#[derive(Copy, Clone, Debug)]
pub struct CpuSimd<const WIDTH: usize>;
105
/// Portable vector type (array-based)
///
/// Backing storage shared by every platform in this module: a plain
/// `[T; WIDTH]` array, one element per lane.
#[derive(Copy, Clone, Debug)]
pub struct PortableVector<T: GpuValue, const WIDTH: usize> {
    // Lane i is stored at data[i].
    data: [T; WIDTH],
}
111
112impl<T: GpuValue, const WIDTH: usize> SimdVector<T> for PortableVector<T, WIDTH> {
113    const WIDTH: usize = WIDTH;
114
115    fn splat(value: T) -> Self {
116        PortableVector {
117            data: [value; WIDTH],
118        }
119    }
120
121    fn extract(self, lane: usize) -> T {
122        debug_assert!(lane < WIDTH, "extract: lane {lane} >= WIDTH {WIDTH}");
123        self.data[lane]
124    }
125
126    fn insert(self, lane: usize, value: T) -> Self {
127        debug_assert!(lane < WIDTH, "insert: lane {lane} >= WIDTH {WIDTH}");
128        let mut result = self;
129        result.data[lane] = value;
130        result
131    }
132}
133
134impl<T: GpuValue, const WIDTH: usize> Default for PortableVector<T, WIDTH> {
135    fn default() -> Self {
136        PortableVector {
137            data: [T::default(); WIDTH],
138        }
139    }
140}
141
142impl<const WIDTH: usize> Platform for CpuSimd<WIDTH>
143where
144    [(); WIDTH]: Sized,
145{
146    const WIDTH: usize = WIDTH;
147    const NAME: &'static str = "CpuSimd";
148
149    type Vector<T: GpuValue> = PortableVector<T, WIDTH>;
150    type Mask = u64;
151
152    fn shuffle<T: GpuValue>(
153        source: Self::Vector<T>,
154        indices: Self::Vector<u32>,
155    ) -> Self::Vector<T> {
156        let mut result = PortableVector::default();
157        for i in 0..WIDTH {
158            let src_idx = indices.data[i] as usize % WIDTH;
159            result.data[i] = source.data[src_idx];
160        }
161        result
162    }
163
164    fn shuffle_down<T: GpuValue>(source: Self::Vector<T>, delta: usize) -> Self::Vector<T> {
165        let mut result = PortableVector::default();
166        for i in 0..WIDTH {
167            let src_idx = (i + delta) % WIDTH;
168            result.data[i] = source.data[src_idx];
169        }
170        result
171    }
172
173    fn shuffle_xor<T: GpuValue>(source: Self::Vector<T>, mask: usize) -> Self::Vector<T> {
174        let mut result = PortableVector::default();
175        for i in 0..WIDTH {
176            let src_idx = (i ^ mask) % WIDTH;
177            result.data[i] = source.data[src_idx];
178        }
179        result
180    }
181
182    fn reduce_sum<T: GpuValue + core::ops::Add<Output = T>>(values: Self::Vector<T>) -> T {
183        const { assert!(WIDTH > 0, "CpuSimd<WIDTH>: reduce requires WIDTH > 0") };
184        values.data.into_iter().reduce(|a, b| a + b).unwrap()
185    }
186
187    fn reduce_max<T: GpuValue + Ord>(values: Self::Vector<T>) -> T {
188        const { assert!(WIDTH > 0, "CpuSimd<WIDTH>: reduce requires WIDTH > 0") };
189        values.data.into_iter().max().unwrap()
190    }
191
192    fn reduce_min<T: GpuValue + Ord>(values: Self::Vector<T>) -> T {
193        const { assert!(WIDTH > 0, "CpuSimd<WIDTH>: reduce requires WIDTH > 0") };
194        values.data.into_iter().min().unwrap()
195    }
196
197    fn ballot(predicates: Self::Vector<bool>) -> Self::Mask {
198        // Mask is u64, so WIDTH must not exceed 64 bits.
199        const {
200            assert!(
201                WIDTH <= 64,
202                "CpuSimd<WIDTH>: ballot requires WIDTH <= 64 (u64 mask)"
203            )
204        };
205        let mut mask = 0u64;
206        for i in 0..WIDTH {
207            if predicates.data[i] {
208                mask |= 1u64 << i;
209            }
210        }
211        mask
212    }
213
214    fn all(predicates: Self::Vector<bool>) -> bool {
215        predicates.data.iter().all(|&b| b)
216    }
217
218    fn any(predicates: Self::Vector<bool>) -> bool {
219        predicates.data.iter().any(|&b| b)
220    }
221
222    fn mask_popcount(mask: Self::Mask) -> u32 {
223        mask.count_ones()
224    }
225}
226
// ============================================================================
// GPU Warp Implementation (Placeholder)
// ============================================================================
230
/// GPU warp platform (32 lanes for NVIDIA)
///
/// Zero-sized marker type; the current `Platform` impl emulates warp
/// operations on the CPU via `PortableVector<T, 32>`.
///
/// In a real implementation, this would lower to PTX intrinsics:
/// - shuffle → __shfl_sync
/// - reduce_sum → __reduce_add_sync or butterfly reduction
/// - ballot → __ballot_sync
#[derive(Copy, Clone, Debug)]
pub struct GpuWarp32;
239
/// GPU warp platform (64 lanes for AMD).
///
/// **Placeholder:** Does not implement `Platform` yet. AMD amdgcn inline
/// assembly support is not stable in Rust. When available, this will mirror
/// `GpuWarp32` with `WIDTH = 64` and `Mask = u64`.
///
/// Marker type only; it carries no data.
#[derive(Copy, Clone, Debug)]
pub struct GpuWarp64;
247
// For now, GPU platforms are just CPU emulation.
// In a real compiler, these would emit different IR.
250
251impl Platform for GpuWarp32 {
252    const WIDTH: usize = 32;
253    const NAME: &'static str = "GpuWarp32";
254
255    type Vector<T: GpuValue> = PortableVector<T, 32>;
256    type Mask = u32;
257
258    fn shuffle<T: GpuValue>(
259        source: Self::Vector<T>,
260        indices: Self::Vector<u32>,
261    ) -> Self::Vector<T> {
262        // GPU shfl.sync.idx clamps: OOB indices read own lane (not % WIDTH).
263        let mut result = PortableVector::default();
264        for i in 0..32 {
265            let src_idx = indices.data[i] as usize;
266            result.data[i] = if src_idx < 32 {
267                source.data[src_idx]
268            } else {
269                source.data[i]
270            };
271        }
272        result
273    }
274
275    fn shuffle_down<T: GpuValue>(source: Self::Vector<T>, delta: usize) -> Self::Vector<T> {
276        // GPU shfl.sync.down clamps: lanes where lane + delta >= WIDTH read
277        // their own value (not wrapped).  CpuSimd wraps, so override here.
278        let mut result = PortableVector::default();
279        for i in 0..32 {
280            let src_idx = i + delta;
281            result.data[i] = if src_idx < 32 {
282                source.data[src_idx]
283            } else {
284                source.data[i]
285            };
286        }
287        result
288    }
289
290    fn shuffle_xor<T: GpuValue>(source: Self::Vector<T>, mask: usize) -> Self::Vector<T> {
291        CpuSimd::<32>::shuffle_xor(source, mask)
292    }
293
294    fn reduce_sum<T: GpuValue + core::ops::Add<Output = T>>(values: Self::Vector<T>) -> T {
295        CpuSimd::<32>::reduce_sum(values)
296    }
297
298    fn reduce_max<T: GpuValue + Ord>(values: Self::Vector<T>) -> T {
299        CpuSimd::<32>::reduce_max(values)
300    }
301
302    fn reduce_min<T: GpuValue + Ord>(values: Self::Vector<T>) -> T {
303        CpuSimd::<32>::reduce_min(values)
304    }
305
306    fn ballot(predicates: Self::Vector<bool>) -> Self::Mask {
307        CpuSimd::<32>::ballot(predicates) as u32
308    }
309
310    fn all(predicates: Self::Vector<bool>) -> bool {
311        CpuSimd::<32>::all(predicates)
312    }
313
314    fn any(predicates: Self::Vector<bool>) -> bool {
315        CpuSimd::<32>::any(predicates)
316    }
317
318    fn mask_popcount(mask: Self::Mask) -> u32 {
319        mask.count_ones()
320    }
321}
322
// ============================================================================
// Generic Algorithms (using PortableVector)
// ============================================================================
326
327/// Parallel reduction using butterfly pattern
328///
329/// This shows the key insight: the same algorithm works for CPU SIMD and GPU warps
330/// because both support shuffle_xor. The actual implementation would be
331/// specialized per platform for optimal codegen.
332pub fn butterfly_reduce_sum<const WIDTH: usize, T>(values: PortableVector<T, WIDTH>) -> T
333where
334    T: GpuValue + core::ops::Add<Output = T>,
335{
336    const {
337        assert!(
338            WIDTH.is_power_of_two(),
339            "butterfly_reduce_sum requires power-of-2 WIDTH"
340        )
341    };
342    let mut v = values;
343    let mut stride = 1;
344    while stride < WIDTH {
345        // XOR shuffle swaps elements at distance `stride`
346        let mut shuffled: PortableVector<T, WIDTH> = PortableVector::default();
347        for i in 0..WIDTH {
348            shuffled.data[i] = v.data[i ^ stride];
349        }
350        // Add corresponding elements
351        for i in 0..WIDTH {
352            v.data[i] = v.data[i] + shuffled.data[i];
353        }
354        stride *= 2;
355    }
356    v.data[0]
357}
358
359/// Prefix sum (inclusive scan)
360pub fn prefix_sum<const WIDTH: usize, T>(
361    values: PortableVector<T, WIDTH>,
362) -> PortableVector<T, WIDTH>
363where
364    T: GpuValue + core::ops::Add<Output = T>,
365{
366    let mut v = values;
367    let mut stride = 1;
368    while stride < WIDTH {
369        let mut result = v;
370        for i in stride..WIDTH {
371            result.data[i] = v.data[i] + v.data[i - stride];
372        }
373        v = result;
374        stride *= 2;
375    }
376    v
377}
378
// ============================================================================
// Tests
// ============================================================================
382
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an 8-lane i32 vector whose lane `i` holds `f(i)`.
    fn lanes8(f: impl Fn(usize) -> i32) -> PortableVector<i32, 8> {
        let mut v = PortableVector::default();
        for lane in 0..8 {
            v = v.insert(lane, f(lane));
        }
        v
    }

    #[test]
    fn test_cpu_simd_broadcast() {
        let broadcasted = CpuSimd::<8>::broadcast(42i32);
        for lane in 0..8 {
            assert_eq!(broadcasted.extract(lane), 42);
        }
    }

    #[test]
    fn test_cpu_simd_shuffle_xor() {
        // Lanes hold their own index; XOR with 1 swaps adjacent pairs:
        // [0, 1, 2, 3, ...] -> [1, 0, 3, 2, ...]
        let identity = lanes8(|lane| lane as i32);
        let swapped = CpuSimd::<8>::shuffle_xor(identity, 1);
        for (lane, expected) in [(0, 1), (1, 0), (2, 3), (3, 2)] {
            assert_eq!(swapped.extract(lane), expected);
        }
    }

    #[test]
    fn test_cpu_simd_reduce_sum() {
        // Lanes hold 1..=8, which sums to 36.
        let ramp = lanes8(|lane| lane as i32 + 1);
        assert_eq!(CpuSimd::<8>::reduce_sum(ramp), 36);
    }

    #[test]
    fn test_butterfly_reduce() {
        let ramp = lanes8(|lane| lane as i32 + 1);
        assert_eq!(butterfly_reduce_sum::<8, i32>(ramp), 36);
    }

    #[test]
    fn test_ballot() {
        // Odd lanes true -> binary 10101010.
        let mut odd_lanes = PortableVector::<bool, 8>::default();
        for lane in 0..8 {
            odd_lanes = odd_lanes.insert(lane, lane % 2 == 1);
        }
        let mask = CpuSimd::<8>::ballot(odd_lanes);
        assert_eq!(mask, 0b1010_1010);
        assert_eq!(CpuSimd::<8>::mask_popcount(mask), 4);
    }

    #[test]
    fn test_gpu_warp32_emulation() {
        let sevens = GpuWarp32::broadcast(7i32);
        assert_eq!(sevens.extract(0), 7);
        assert_eq!(sevens.extract(31), 7);

        let mut ones = PortableVector::<i32, 32>::default();
        for lane in 0..32 {
            ones = ones.insert(lane, 1);
        }
        assert_eq!(GpuWarp32::reduce_sum(ones), 32);
    }

    #[test]
    fn test_prefix_sum() {
        // Inclusive scan of all-ones is [1, 2, 3, 4, 5, 6, 7, 8].
        let scanned = prefix_sum::<8, i32>(lanes8(|_| 1));
        for lane in 0..8 {
            assert_eq!(scanned.extract(lane), lane as i32 + 1);
        }
    }
}