Skip to main content

warp_types/
platform.rs

//! Platform abstraction for CPU/GPU unified targeting
//!
//! This module defines the `Platform` trait that allows the same algorithm
//! code to run on either CPU SIMD or GPU warps.
//!
//! # Key Insight
//!
//! CPU SIMD (AVX-512, NEON) and GPU warps are structurally similar:
//! - Both have "lanes" executing in lockstep
//! - Both have shuffle/permute operations
//! - Both benefit from uniform/varying tracking
//!
//! The `Platform` trait abstracts over these, enabling:
//! - Same algorithm code for both targets
//! - Fat binaries with CPU and GPU implementations
//! - Runtime dispatch based on data size or availability
//!
//! Note: the GPU platforms in this module are currently emulated on the CPU
//! with the same array-backed vectors that `CpuSimd` uses, and `GpuWarp64`
//! does not implement `Platform` yet.

18use crate::GpuValue;
19
20/// Platform-specific SIMD vector type
21pub trait SimdVector<T: GpuValue>: Copy {
22    /// Number of lanes in this vector
23    const WIDTH: usize;
24
25    /// Create a vector with all lanes set to the same value
26    fn splat(value: T) -> Self;
27
28    /// Extract value from a specific lane
29    fn extract(self, lane: usize) -> T;
30
31    /// Insert value into a specific lane
32    fn insert(self, lane: usize, value: T) -> Self;
33}
34
35/// A computation platform (CPU SIMD or GPU warp)
36///
37/// This trait abstracts over the execution model, allowing algorithms
38/// to be written once and compiled for multiple targets.
39pub trait Platform: Copy + 'static {
40    /// Number of parallel lanes
41    const WIDTH: usize;
42
43    /// Platform name for debugging
44    const NAME: &'static str;
45
46    /// The vector type for this platform
47    type Vector<T: GpuValue>: SimdVector<T>;
48
49    /// Mask type for predicated operations
50    type Mask: Copy;
51
52    // === Core Operations ===
53
54    /// Broadcast a scalar to all lanes
55    fn broadcast<T: GpuValue>(value: T) -> Self::Vector<T> {
56        Self::Vector::splat(value)
57    }
58
59    /// Shuffle: each lane reads from source\[indices\[lane\]\]
60    fn shuffle<T: GpuValue>(source: Self::Vector<T>, indices: Self::Vector<u32>)
61        -> Self::Vector<T>;
62
63    /// Shuffle down: lane i reads from lane i+delta (with wrapping or clamping)
64    fn shuffle_down<T: GpuValue>(source: Self::Vector<T>, delta: usize) -> Self::Vector<T>;
65
66    /// Shuffle XOR: lane i reads from lane i^mask
67    fn shuffle_xor<T: GpuValue>(source: Self::Vector<T>, mask: usize) -> Self::Vector<T>;
68
69    // === Reductions ===
70
71    /// Sum all lanes, result available in all lanes (uniform)
72    fn reduce_sum<T: GpuValue + core::ops::Add<Output = T>>(values: Self::Vector<T>) -> T;
73
74    /// Maximum across all lanes
75    fn reduce_max<T: GpuValue + Ord>(values: Self::Vector<T>) -> T;
76
77    /// Minimum across all lanes
78    fn reduce_min<T: GpuValue + Ord>(values: Self::Vector<T>) -> T;
79
80    // === Predicates ===
81
82    /// Ballot: collect per-lane bools into a mask
83    fn ballot(predicates: Self::Vector<bool>) -> Self::Mask;
84
85    /// All lanes true?
86    fn all(predicates: Self::Vector<bool>) -> bool;
87
88    /// Any lane true?
89    fn any(predicates: Self::Vector<bool>) -> bool;
90
91    /// Population count of a mask
92    fn mask_popcount(mask: Self::Mask) -> u32;
93}
94
95// ============================================================================
96// CPU SIMD Implementation (Portable)
97// ============================================================================
98
/// Portable CPU SIMD platform with `WIDTH` lanes.
///
/// Lanes are emulated with plain scalar array code for portability. A
/// production backend would instead lower to `std::simd` or to
/// platform-specific intrinsics (AVX-512, NEON).
#[derive(Clone, Copy, Debug)]
pub struct CpuSimd<const WIDTH: usize>;
105
106/// Portable vector type (array-based)
107#[derive(Copy, Clone, Debug)]
108pub struct PortableVector<T: GpuValue, const WIDTH: usize> {
109    data: [T; WIDTH],
110}
111
112impl<T: GpuValue, const WIDTH: usize> SimdVector<T> for PortableVector<T, WIDTH> {
113    const WIDTH: usize = WIDTH;
114
115    fn splat(value: T) -> Self {
116        PortableVector {
117            data: [value; WIDTH],
118        }
119    }
120
121    fn extract(self, lane: usize) -> T {
122        assert!(lane < WIDTH, "extract: lane {lane} >= WIDTH {WIDTH}");
123        self.data[lane]
124    }
125
126    fn insert(self, lane: usize, value: T) -> Self {
127        assert!(lane < WIDTH, "insert: lane {lane} >= WIDTH {WIDTH}");
128        let mut result = self;
129        result.data[lane] = value;
130        result
131    }
132}
133
134impl<T: GpuValue, const WIDTH: usize> Default for PortableVector<T, WIDTH> {
135    fn default() -> Self {
136        PortableVector {
137            data: [T::default(); WIDTH],
138        }
139    }
140}
141
142impl<const WIDTH: usize> Platform for CpuSimd<WIDTH>
143where
144    [(); WIDTH]: Sized,
145{
146    const WIDTH: usize = WIDTH;
147    const NAME: &'static str = "CpuSimd";
148
149    type Vector<T: GpuValue> = PortableVector<T, WIDTH>;
150    type Mask = u64;
151
152    fn shuffle<T: GpuValue>(
153        source: Self::Vector<T>,
154        indices: Self::Vector<u32>,
155    ) -> Self::Vector<T> {
156        let mut result = PortableVector::default();
157        for i in 0..WIDTH {
158            let src_idx = indices.data[i] as usize % WIDTH;
159            result.data[i] = source.data[src_idx];
160        }
161        result
162    }
163
164    fn shuffle_down<T: GpuValue>(source: Self::Vector<T>, delta: usize) -> Self::Vector<T> {
165        let mut result = PortableVector::default();
166        for i in 0..WIDTH {
167            // Clamp: lanes where i+delta >= WIDTH read their own value (GPU semantics).
168            let src_idx = if i + delta < WIDTH { i + delta } else { i };
169            result.data[i] = source.data[src_idx];
170        }
171        result
172    }
173
174    fn shuffle_xor<T: GpuValue>(source: Self::Vector<T>, mask: usize) -> Self::Vector<T> {
175        let mut result = PortableVector::default();
176        for i in 0..WIDTH {
177            let src_idx = (i ^ mask) % WIDTH;
178            result.data[i] = source.data[src_idx];
179        }
180        result
181    }
182
183    fn reduce_sum<T: GpuValue + core::ops::Add<Output = T>>(values: Self::Vector<T>) -> T {
184        const { assert!(WIDTH > 0, "CpuSimd<WIDTH>: reduce requires WIDTH > 0") };
185        values.data.into_iter().reduce(|a, b| a + b).unwrap()
186    }
187
188    fn reduce_max<T: GpuValue + Ord>(values: Self::Vector<T>) -> T {
189        const { assert!(WIDTH > 0, "CpuSimd<WIDTH>: reduce requires WIDTH > 0") };
190        values.data.into_iter().max().unwrap()
191    }
192
193    fn reduce_min<T: GpuValue + Ord>(values: Self::Vector<T>) -> T {
194        const { assert!(WIDTH > 0, "CpuSimd<WIDTH>: reduce requires WIDTH > 0") };
195        values.data.into_iter().min().unwrap()
196    }
197
198    fn ballot(predicates: Self::Vector<bool>) -> Self::Mask {
199        // Mask is u64, so WIDTH must not exceed 64 bits.
200        const {
201            assert!(
202                WIDTH <= 64,
203                "CpuSimd<WIDTH>: ballot requires WIDTH <= 64 (u64 mask)"
204            )
205        };
206        let mut mask = 0u64;
207        for i in 0..WIDTH {
208            if predicates.data[i] {
209                mask |= 1u64 << i;
210            }
211        }
212        mask
213    }
214
215    fn all(predicates: Self::Vector<bool>) -> bool {
216        predicates.data.iter().all(|&b| b)
217    }
218
219    fn any(predicates: Self::Vector<bool>) -> bool {
220        predicates.data.iter().any(|&b| b)
221    }
222
223    fn mask_popcount(mask: Self::Mask) -> u32 {
224        mask.count_ones()
225    }
226}
227
228// ============================================================================
229// GPU Warp Implementation (Placeholder)
230// ============================================================================
231
/// GPU warp platform with 32 lanes, matching an NVIDIA warp.
///
/// A real backend would lower the operations to PTX warp intrinsics:
/// - `shuffle` → `__shfl_sync`
/// - `reduce_sum` → `__reduce_add_sync` or a butterfly reduction
/// - `ballot` → `__ballot_sync`
#[derive(Clone, Copy, Debug)]
pub struct GpuWarp32;
240
/// GPU warp platform with 64 lanes, as used on AMD GPUs.
///
/// **Placeholder:** this type does not implement `Platform` yet, because
/// inline assembly for the AMD `amdgcn` target is not stable in Rust. Once it
/// is, the implementation will mirror `GpuWarp32` with `WIDTH = 64` and
/// `Mask = u64`.
#[derive(Clone, Copy, Debug)]
pub struct GpuWarp64;
248
249// For now, GPU platforms are just CPU emulation
250// In a real compiler, these would emit different IR
251
252impl Platform for GpuWarp32 {
253    const WIDTH: usize = 32;
254    const NAME: &'static str = "GpuWarp32";
255
256    type Vector<T: GpuValue> = PortableVector<T, 32>;
257    type Mask = u32;
258
259    fn shuffle<T: GpuValue>(
260        source: Self::Vector<T>,
261        indices: Self::Vector<u32>,
262    ) -> Self::Vector<T> {
263        // GPU shfl.sync.idx wraps: src_lane & 0x1F (hardware-verified on RTX 4000 Ada).
264        let mut result = PortableVector::default();
265        for i in 0..32 {
266            let src_idx = indices.data[i] as usize % 32;
267            result.data[i] = source.data[src_idx];
268        }
269        result
270    }
271
272    fn shuffle_down<T: GpuValue>(source: Self::Vector<T>, delta: usize) -> Self::Vector<T> {
273        // GPU shfl.sync.down clamps: lanes where lane + delta >= WIDTH read
274        // their own value (not wrapped).  CpuSimd also clamps (same behavior).
275        let mut result = PortableVector::default();
276        for i in 0..32 {
277            let src_idx = i + delta;
278            result.data[i] = if src_idx < 32 {
279                source.data[src_idx]
280            } else {
281                source.data[i]
282            };
283        }
284        result
285    }
286
287    fn shuffle_xor<T: GpuValue>(source: Self::Vector<T>, mask: usize) -> Self::Vector<T> {
288        CpuSimd::<32>::shuffle_xor(source, mask)
289    }
290
291    fn reduce_sum<T: GpuValue + core::ops::Add<Output = T>>(values: Self::Vector<T>) -> T {
292        CpuSimd::<32>::reduce_sum(values)
293    }
294
295    fn reduce_max<T: GpuValue + Ord>(values: Self::Vector<T>) -> T {
296        CpuSimd::<32>::reduce_max(values)
297    }
298
299    fn reduce_min<T: GpuValue + Ord>(values: Self::Vector<T>) -> T {
300        CpuSimd::<32>::reduce_min(values)
301    }
302
303    fn ballot(predicates: Self::Vector<bool>) -> Self::Mask {
304        CpuSimd::<32>::ballot(predicates) as u32
305    }
306
307    fn all(predicates: Self::Vector<bool>) -> bool {
308        CpuSimd::<32>::all(predicates)
309    }
310
311    fn any(predicates: Self::Vector<bool>) -> bool {
312        CpuSimd::<32>::any(predicates)
313    }
314
315    fn mask_popcount(mask: Self::Mask) -> u32 {
316        mask.count_ones()
317    }
318}
319
320// ============================================================================
321// Generic Algorithms (using PortableVector)
322// ============================================================================
323
324/// Parallel reduction using butterfly pattern
325///
326/// This shows the key insight: the same algorithm works for CPU SIMD and GPU warps
327/// because both support shuffle_xor. The actual implementation would be
328/// specialized per platform for optimal codegen.
329pub fn butterfly_reduce_sum<const WIDTH: usize, T>(values: PortableVector<T, WIDTH>) -> T
330where
331    T: GpuValue + core::ops::Add<Output = T>,
332{
333    const {
334        assert!(
335            WIDTH.is_power_of_two(),
336            "butterfly_reduce_sum requires power-of-2 WIDTH"
337        )
338    };
339    let mut v = values;
340    let mut stride = 1;
341    while stride < WIDTH {
342        // XOR shuffle swaps elements at distance `stride`
343        let mut shuffled: PortableVector<T, WIDTH> = PortableVector::default();
344        for i in 0..WIDTH {
345            shuffled.data[i] = v.data[i ^ stride];
346        }
347        // Add corresponding elements
348        for i in 0..WIDTH {
349            v.data[i] = v.data[i] + shuffled.data[i];
350        }
351        stride *= 2;
352    }
353    v.data[0]
354}
355
356/// Prefix sum (inclusive scan)
357pub fn prefix_sum<const WIDTH: usize, T>(
358    values: PortableVector<T, WIDTH>,
359) -> PortableVector<T, WIDTH>
360where
361    T: GpuValue + core::ops::Add<Output = T>,
362{
363    let mut v = values;
364    let mut stride = 1;
365    while stride < WIDTH {
366        let mut result = v;
367        for i in stride..WIDTH {
368            result.data[i] = v.data[i] + v.data[i - stride];
369        }
370        v = result;
371        stride *= 2;
372    }
373    v
374}
375
376// ============================================================================
377// Tests
378// ============================================================================
379
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `i32` vector directly from an array literal.
    fn lanes<const W: usize>(data: [i32; W]) -> PortableVector<i32, W> {
        PortableVector { data }
    }

    #[test]
    fn test_cpu_simd_broadcast() {
        let v = CpuSimd::<8>::broadcast(42i32);
        for i in 0..8 {
            assert_eq!(v.extract(i), 42);
        }
    }

    #[test]
    fn test_cpu_simd_shuffle_xor() {
        // Create vector [0, 1, 2, 3, 4, 5, 6, 7]
        let mut v = PortableVector::<i32, 8>::default();
        for i in 0..8 {
            v = v.insert(i, i as i32);
        }

        // XOR with 1 swaps adjacent pairs: [1, 0, 3, 2, 5, 4, 7, 6]
        let shuffled = CpuSimd::<8>::shuffle_xor(v, 1);
        assert_eq!(shuffled.data, [1, 0, 3, 2, 5, 4, 7, 6]);
    }

    #[test]
    fn test_cpu_simd_shuffle_wraps_indices() {
        let source = lanes([10, 11, 12, 13]);
        let indices = PortableVector::<u32, 4> { data: [5, 0, 7, 2] };
        assert_eq!(CpuSimd::<4>::shuffle(source, indices).data, [11, 10, 13, 12]);
    }

    #[test]
    fn test_cpu_simd_shuffle_down_clamps() {
        let v = lanes([10, 11, 12, 13]);
        // Lanes whose source would be past the end keep their own value.
        assert_eq!(CpuSimd::<4>::shuffle_down(v, 2).data, [12, 13, 12, 13]);
        assert_eq!(CpuSimd::<4>::shuffle_down(v, 4).data, [10, 11, 12, 13]);
    }

    #[test]
    fn test_cpu_simd_reduce_sum() {
        let mut v = PortableVector::<i32, 8>::default();
        for i in 0..8 {
            v = v.insert(i, (i + 1) as i32);
        }
        // Sum of 1..=8 = 36
        assert_eq!(CpuSimd::<8>::reduce_sum(v), 36);
    }

    #[test]
    fn test_cpu_simd_reduce_min_max() {
        let v = lanes([3, -7, 12, 0]);
        assert_eq!(CpuSimd::<4>::reduce_max(v), 12);
        assert_eq!(CpuSimd::<4>::reduce_min(v), -7);
    }

    #[test]
    fn test_butterfly_reduce() {
        let mut v = PortableVector::<i32, 8>::default();
        for i in 0..8 {
            v = v.insert(i, (i + 1) as i32);
        }
        let sum = butterfly_reduce_sum::<8, i32>(v);
        assert_eq!(sum, 36);
        assert_eq!(sum, CpuSimd::<8>::reduce_sum(v));
    }

    #[test]
    fn test_ballot() {
        // Odd lanes are true
        let mut predicates = PortableVector::<bool, 8>::default();
        for i in 0..8 {
            predicates = predicates.insert(i, i % 2 == 1);
        }
        let mask = CpuSimd::<8>::ballot(predicates);
        // Binary: 10101010 = 0xAA
        assert_eq!(mask, 0b10101010);
        assert_eq!(CpuSimd::<8>::mask_popcount(mask), 4);
    }

    #[test]
    fn test_all_any() {
        let mixed = PortableVector::<bool, 4> { data: [true, false, true, true] };
        assert!(!CpuSimd::<4>::all(mixed));
        assert!(CpuSimd::<4>::any(mixed));
        let none = PortableVector::<bool, 4>::default();
        assert!(!CpuSimd::<4>::any(none));
        let every = PortableVector::<bool, 4> { data: [true; 4] };
        assert!(CpuSimd::<4>::all(every));
    }

    #[test]
    fn test_gpu_warp32_emulation() {
        let v = GpuWarp32::broadcast(7i32);
        assert_eq!(v.extract(0), 7);
        assert_eq!(v.extract(31), 7);

        let mut values = PortableVector::<i32, 32>::default();
        for i in 0..32 {
            values = values.insert(i, 1);
        }
        assert_eq!(GpuWarp32::reduce_sum(values), 32);
    }

    #[test]
    fn test_gpu_warp32_matches_cpu_simd_32() {
        let v = PortableVector::<i32, 32> { data: core::array::from_fn(|i| i as i32) };
        for delta in [0, 1, 5, 31, 32] {
            assert_eq!(
                GpuWarp32::shuffle_down(v, delta).data,
                CpuSimd::<32>::shuffle_down(v, delta).data
            );
        }

        let idx = PortableVector::<u32, 32> { data: core::array::from_fn(|i| (i as u32) * 7 + 40) };
        assert_eq!(GpuWarp32::shuffle(v, idx).data, CpuSimd::<32>::shuffle(v, idx).data);

        let preds = PortableVector::<bool, 32> { data: core::array::from_fn(|i| i % 3 == 0) };
        let mask = GpuWarp32::ballot(preds);
        assert_eq!(u64::from(mask), CpuSimd::<32>::ballot(preds));
        assert_eq!(GpuWarp32::mask_popcount(mask), 11);
    }

    #[test]
    fn test_prefix_sum() {
        let mut v = PortableVector::<i32, 8>::default();
        for i in 0..8 {
            v = v.insert(i, 1); // All ones
        }
        let result = prefix_sum::<8, i32>(v);
        // Should be [1, 2, 3, 4, 5, 6, 7, 8]
        for i in 0..8 {
            assert_eq!(result.extract(i), (i + 1) as i32);
        }
    }

    #[test]
    fn test_prefix_sum_non_power_of_two() {
        assert_eq!(prefix_sum(lanes([5, 1, 2])).data, [5, 6, 8]);
    }
}