1use crate::GpuValue;
19
20pub trait SimdVector<T: GpuValue>: Copy {
22 const WIDTH: usize;
24
25 fn splat(value: T) -> Self;
27
28 fn extract(self, lane: usize) -> T;
30
31 fn insert(self, lane: usize, value: T) -> Self;
33}
34
35pub trait Platform: Copy + 'static {
40 const WIDTH: usize;
42
43 const NAME: &'static str;
45
46 type Vector<T: GpuValue>: SimdVector<T>;
48
49 type Mask: Copy;
51
52 fn broadcast<T: GpuValue>(value: T) -> Self::Vector<T> {
56 Self::Vector::splat(value)
57 }
58
59 fn shuffle<T: GpuValue>(source: Self::Vector<T>, indices: Self::Vector<u32>)
61 -> Self::Vector<T>;
62
63 fn shuffle_down<T: GpuValue>(source: Self::Vector<T>, delta: usize) -> Self::Vector<T>;
65
66 fn shuffle_xor<T: GpuValue>(source: Self::Vector<T>, mask: usize) -> Self::Vector<T>;
68
69 fn reduce_sum<T: GpuValue + core::ops::Add<Output = T>>(values: Self::Vector<T>) -> T;
73
74 fn reduce_max<T: GpuValue + Ord>(values: Self::Vector<T>) -> T;
76
77 fn reduce_min<T: GpuValue + Ord>(values: Self::Vector<T>) -> T;
79
80 fn ballot(predicates: Self::Vector<bool>) -> Self::Mask;
84
85 fn all(predicates: Self::Vector<bool>) -> bool;
87
88 fn any(predicates: Self::Vector<bool>) -> bool;
90
91 fn mask_popcount(mask: Self::Mask) -> u32;
93}
94
/// Marker type selecting the array-based [`Platform`] implementation with
/// `WIDTH` lanes.
///
/// The type has no fields; it is only used at the type level.
#[derive(Debug, Clone, Copy)]
pub struct CpuSimd<const WIDTH: usize>;
105
106#[derive(Copy, Clone, Debug)]
108pub struct PortableVector<T: GpuValue, const WIDTH: usize> {
109 data: [T; WIDTH],
110}
111
112impl<T: GpuValue, const WIDTH: usize> SimdVector<T> for PortableVector<T, WIDTH> {
113 const WIDTH: usize = WIDTH;
114
115 fn splat(value: T) -> Self {
116 PortableVector {
117 data: [value; WIDTH],
118 }
119 }
120
121 fn extract(self, lane: usize) -> T {
122 debug_assert!(lane < WIDTH, "extract: lane {lane} >= WIDTH {WIDTH}");
123 self.data[lane]
124 }
125
126 fn insert(self, lane: usize, value: T) -> Self {
127 debug_assert!(lane < WIDTH, "insert: lane {lane} >= WIDTH {WIDTH}");
128 let mut result = self;
129 result.data[lane] = value;
130 result
131 }
132}
133
134impl<T: GpuValue, const WIDTH: usize> Default for PortableVector<T, WIDTH> {
135 fn default() -> Self {
136 PortableVector {
137 data: [T::default(); WIDTH],
138 }
139 }
140}
141
142impl<const WIDTH: usize> Platform for CpuSimd<WIDTH>
143where
144 [(); WIDTH]: Sized,
145{
146 const WIDTH: usize = WIDTH;
147 const NAME: &'static str = "CpuSimd";
148
149 type Vector<T: GpuValue> = PortableVector<T, WIDTH>;
150 type Mask = u64;
151
152 fn shuffle<T: GpuValue>(
153 source: Self::Vector<T>,
154 indices: Self::Vector<u32>,
155 ) -> Self::Vector<T> {
156 let mut result = PortableVector::default();
157 for i in 0..WIDTH {
158 let src_idx = indices.data[i] as usize % WIDTH;
159 result.data[i] = source.data[src_idx];
160 }
161 result
162 }
163
164 fn shuffle_down<T: GpuValue>(source: Self::Vector<T>, delta: usize) -> Self::Vector<T> {
165 let mut result = PortableVector::default();
166 for i in 0..WIDTH {
167 let src_idx = (i + delta) % WIDTH;
168 result.data[i] = source.data[src_idx];
169 }
170 result
171 }
172
173 fn shuffle_xor<T: GpuValue>(source: Self::Vector<T>, mask: usize) -> Self::Vector<T> {
174 let mut result = PortableVector::default();
175 for i in 0..WIDTH {
176 let src_idx = (i ^ mask) % WIDTH;
177 result.data[i] = source.data[src_idx];
178 }
179 result
180 }
181
182 fn reduce_sum<T: GpuValue + core::ops::Add<Output = T>>(values: Self::Vector<T>) -> T {
183 const { assert!(WIDTH > 0, "CpuSimd<WIDTH>: reduce requires WIDTH > 0") };
184 values.data.into_iter().reduce(|a, b| a + b).unwrap()
185 }
186
187 fn reduce_max<T: GpuValue + Ord>(values: Self::Vector<T>) -> T {
188 const { assert!(WIDTH > 0, "CpuSimd<WIDTH>: reduce requires WIDTH > 0") };
189 values.data.into_iter().max().unwrap()
190 }
191
192 fn reduce_min<T: GpuValue + Ord>(values: Self::Vector<T>) -> T {
193 const { assert!(WIDTH > 0, "CpuSimd<WIDTH>: reduce requires WIDTH > 0") };
194 values.data.into_iter().min().unwrap()
195 }
196
197 fn ballot(predicates: Self::Vector<bool>) -> Self::Mask {
198 const {
200 assert!(
201 WIDTH <= 64,
202 "CpuSimd<WIDTH>: ballot requires WIDTH <= 64 (u64 mask)"
203 )
204 };
205 let mut mask = 0u64;
206 for i in 0..WIDTH {
207 if predicates.data[i] {
208 mask |= 1u64 << i;
209 }
210 }
211 mask
212 }
213
214 fn all(predicates: Self::Vector<bool>) -> bool {
215 predicates.data.iter().all(|&b| b)
216 }
217
218 fn any(predicates: Self::Vector<bool>) -> bool {
219 predicates.data.iter().any(|&b| b)
220 }
221
222 fn mask_popcount(mask: Self::Mask) -> u32 {
223 mask.count_ones()
224 }
225}
226
/// Marker type for the 32-lane [`Platform`] implementation that uses a `u32`
/// mask.
///
/// The type has no fields; it is only used at the type level.
#[derive(Debug, Clone, Copy)]
pub struct GpuWarp32;
239
/// Field-less marker type; by its name, the 64-lane counterpart of
/// `GpuWarp32`.
///
/// NOTE(review): no `Platform` implementation for this type is visible next to
/// the one for `GpuWarp32` — confirm whether it lives elsewhere or is pending.
#[derive(Debug, Clone, Copy)]
pub struct GpuWarp64;
247
248impl Platform for GpuWarp32 {
252 const WIDTH: usize = 32;
253 const NAME: &'static str = "GpuWarp32";
254
255 type Vector<T: GpuValue> = PortableVector<T, 32>;
256 type Mask = u32;
257
258 fn shuffle<T: GpuValue>(
259 source: Self::Vector<T>,
260 indices: Self::Vector<u32>,
261 ) -> Self::Vector<T> {
262 let mut result = PortableVector::default();
264 for i in 0..32 {
265 let src_idx = indices.data[i] as usize;
266 result.data[i] = if src_idx < 32 {
267 source.data[src_idx]
268 } else {
269 source.data[i]
270 };
271 }
272 result
273 }
274
275 fn shuffle_down<T: GpuValue>(source: Self::Vector<T>, delta: usize) -> Self::Vector<T> {
276 let mut result = PortableVector::default();
279 for i in 0..32 {
280 let src_idx = i + delta;
281 result.data[i] = if src_idx < 32 {
282 source.data[src_idx]
283 } else {
284 source.data[i]
285 };
286 }
287 result
288 }
289
290 fn shuffle_xor<T: GpuValue>(source: Self::Vector<T>, mask: usize) -> Self::Vector<T> {
291 CpuSimd::<32>::shuffle_xor(source, mask)
292 }
293
294 fn reduce_sum<T: GpuValue + core::ops::Add<Output = T>>(values: Self::Vector<T>) -> T {
295 CpuSimd::<32>::reduce_sum(values)
296 }
297
298 fn reduce_max<T: GpuValue + Ord>(values: Self::Vector<T>) -> T {
299 CpuSimd::<32>::reduce_max(values)
300 }
301
302 fn reduce_min<T: GpuValue + Ord>(values: Self::Vector<T>) -> T {
303 CpuSimd::<32>::reduce_min(values)
304 }
305
306 fn ballot(predicates: Self::Vector<bool>) -> Self::Mask {
307 CpuSimd::<32>::ballot(predicates) as u32
308 }
309
310 fn all(predicates: Self::Vector<bool>) -> bool {
311 CpuSimd::<32>::all(predicates)
312 }
313
314 fn any(predicates: Self::Vector<bool>) -> bool {
315 CpuSimd::<32>::any(predicates)
316 }
317
318 fn mask_popcount(mask: Self::Mask) -> u32 {
319 mask.count_ones()
320 }
321}
322
323pub fn butterfly_reduce_sum<const WIDTH: usize, T>(values: PortableVector<T, WIDTH>) -> T
333where
334 T: GpuValue + core::ops::Add<Output = T>,
335{
336 const {
337 assert!(
338 WIDTH.is_power_of_two(),
339 "butterfly_reduce_sum requires power-of-2 WIDTH"
340 )
341 };
342 let mut v = values;
343 let mut stride = 1;
344 while stride < WIDTH {
345 let mut shuffled: PortableVector<T, WIDTH> = PortableVector::default();
347 for i in 0..WIDTH {
348 shuffled.data[i] = v.data[i ^ stride];
349 }
350 for i in 0..WIDTH {
352 v.data[i] = v.data[i] + shuffled.data[i];
353 }
354 stride *= 2;
355 }
356 v.data[0]
357}
358
359pub fn prefix_sum<const WIDTH: usize, T>(
361 values: PortableVector<T, WIDTH>,
362) -> PortableVector<T, WIDTH>
363where
364 T: GpuValue + core::ops::Add<Output = T>,
365{
366 let mut v = values;
367 let mut stride = 1;
368 while stride < WIDTH {
369 let mut result = v;
370 for i in stride..WIDTH {
371 result.data[i] = v.data[i] + v.data[i - stride];
372 }
373 v = result;
374 stride *= 2;
375 }
376 v
377}
378
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `N`-lane `i32` vector whose lane `i` holds `lane_value(i)`.
    fn int_lanes<const N: usize, F>(lane_value: F) -> PortableVector<i32, N>
    where
        F: Fn(usize) -> i32,
    {
        (0..N).fold(PortableVector::<i32, N>::default(), |acc, lane| {
            acc.insert(lane, lane_value(lane))
        })
    }

    /// Builds an `N`-lane `bool` vector whose lane `i` holds `lane_value(i)`.
    fn bool_lanes<const N: usize, F>(lane_value: F) -> PortableVector<bool, N>
    where
        F: Fn(usize) -> bool,
    {
        (0..N).fold(PortableVector::<bool, N>::default(), |acc, lane| {
            acc.insert(lane, lane_value(lane))
        })
    }

    #[test]
    fn test_cpu_simd_broadcast() {
        let splatted = CpuSimd::<8>::broadcast(42i32);
        assert!((0..8).all(|lane| splatted.extract(lane) == 42));
    }

    #[test]
    fn test_cpu_simd_shuffle_xor() {
        let ramp: PortableVector<i32, 8> = int_lanes(|lane| lane as i32);

        // XOR with 1 swaps every even lane with its odd neighbour.
        let swapped = CpuSimd::<8>::shuffle_xor(ramp, 1);
        let first_four: [i32; 4] = core::array::from_fn(|lane| swapped.extract(lane));
        assert_eq!(first_four, [1, 0, 3, 2]);
    }

    #[test]
    fn test_cpu_simd_reduce_sum() {
        // 1 + 2 + ... + 8 = 36
        let one_to_eight: PortableVector<i32, 8> = int_lanes(|lane| (lane + 1) as i32);
        assert_eq!(CpuSimd::<8>::reduce_sum(one_to_eight), 36);
    }

    #[test]
    fn test_butterfly_reduce() {
        let one_to_eight: PortableVector<i32, 8> = int_lanes(|lane| (lane + 1) as i32);
        assert_eq!(butterfly_reduce_sum::<8, i32>(one_to_eight), 36);
    }

    #[test]
    fn test_ballot() {
        // Odd lanes vote `true`, so bits 1, 3, 5 and 7 are expected.
        let odd_lanes: PortableVector<bool, 8> = bool_lanes(|lane| lane % 2 == 1);
        let mask = CpuSimd::<8>::ballot(odd_lanes);
        assert_eq!(mask, 0b10101010);
        assert_eq!(CpuSimd::<8>::mask_popcount(mask), 4);
    }

    #[test]
    fn test_gpu_warp32_emulation() {
        let sevens = GpuWarp32::broadcast(7i32);
        assert_eq!(sevens.extract(0), 7);
        assert_eq!(sevens.extract(31), 7);

        let ones: PortableVector<i32, 32> = int_lanes(|_| 1);
        assert_eq!(GpuWarp32::reduce_sum(ones), 32);
    }

    #[test]
    fn test_prefix_sum() {
        // The inclusive scan of all ones counts the lanes: 1, 2, ..., 8.
        let ones: PortableVector<i32, 8> = int_lanes(|_| 1);
        let scanned = prefix_sum::<8, i32>(ones);
        let observed: [i32; 8] = core::array::from_fn(|lane| scanned.extract(lane));
        assert_eq!(observed, [1, 2, 3, 4, 5, 6, 7, 8]);
    }
}