1use crate::GpuValue;
19
20pub trait SimdVector<T: GpuValue>: Copy {
22 const WIDTH: usize;
24
25 fn splat(value: T) -> Self;
27
28 fn extract(self, lane: usize) -> T;
30
31 fn insert(self, lane: usize, value: T) -> Self;
33}
34
35pub trait Platform: Copy + 'static {
40 const WIDTH: usize;
42
43 const NAME: &'static str;
45
46 type Vector<T: GpuValue>: SimdVector<T>;
48
49 type Mask: Copy;
51
52 fn broadcast<T: GpuValue>(value: T) -> Self::Vector<T> {
56 Self::Vector::splat(value)
57 }
58
59 fn shuffle<T: GpuValue>(source: Self::Vector<T>, indices: Self::Vector<u32>)
61 -> Self::Vector<T>;
62
63 fn shuffle_down<T: GpuValue>(source: Self::Vector<T>, delta: usize) -> Self::Vector<T>;
65
66 fn shuffle_xor<T: GpuValue>(source: Self::Vector<T>, mask: usize) -> Self::Vector<T>;
68
69 fn reduce_sum<T: GpuValue + core::ops::Add<Output = T>>(values: Self::Vector<T>) -> T;
73
74 fn reduce_max<T: GpuValue + Ord>(values: Self::Vector<T>) -> T;
76
77 fn reduce_min<T: GpuValue + Ord>(values: Self::Vector<T>) -> T;
79
80 fn ballot(predicates: Self::Vector<bool>) -> Self::Mask;
84
85 fn all(predicates: Self::Vector<bool>) -> bool;
87
88 fn any(predicates: Self::Vector<bool>) -> bool;
90
91 fn mask_popcount(mask: Self::Mask) -> u32;
93}
94
/// CPU platform that emulates `WIDTH` lanes in software using plain arrays.
///
/// This is a zero-sized marker type. All lane data lives in the vectors passed
/// to its `Platform` functions.
#[derive(Debug, Clone, Copy)]
pub struct CpuSimd<const WIDTH: usize>;
105
106#[derive(Copy, Clone, Debug)]
108pub struct PortableVector<T: GpuValue, const WIDTH: usize> {
109 data: [T; WIDTH],
110}
111
112impl<T: GpuValue, const WIDTH: usize> SimdVector<T> for PortableVector<T, WIDTH> {
113 const WIDTH: usize = WIDTH;
114
115 fn splat(value: T) -> Self {
116 PortableVector {
117 data: [value; WIDTH],
118 }
119 }
120
121 fn extract(self, lane: usize) -> T {
122 assert!(lane < WIDTH, "extract: lane {lane} >= WIDTH {WIDTH}");
123 self.data[lane]
124 }
125
126 fn insert(self, lane: usize, value: T) -> Self {
127 assert!(lane < WIDTH, "insert: lane {lane} >= WIDTH {WIDTH}");
128 let mut result = self;
129 result.data[lane] = value;
130 result
131 }
132}
133
134impl<T: GpuValue, const WIDTH: usize> Default for PortableVector<T, WIDTH> {
135 fn default() -> Self {
136 PortableVector {
137 data: [T::default(); WIDTH],
138 }
139 }
140}
141
142impl<const WIDTH: usize> Platform for CpuSimd<WIDTH>
143where
144 [(); WIDTH]: Sized,
145{
146 const WIDTH: usize = WIDTH;
147 const NAME: &'static str = "CpuSimd";
148
149 type Vector<T: GpuValue> = PortableVector<T, WIDTH>;
150 type Mask = u64;
151
152 fn shuffle<T: GpuValue>(
153 source: Self::Vector<T>,
154 indices: Self::Vector<u32>,
155 ) -> Self::Vector<T> {
156 let mut result = PortableVector::default();
157 for i in 0..WIDTH {
158 let src_idx = indices.data[i] as usize % WIDTH;
159 result.data[i] = source.data[src_idx];
160 }
161 result
162 }
163
164 fn shuffle_down<T: GpuValue>(source: Self::Vector<T>, delta: usize) -> Self::Vector<T> {
165 let mut result = PortableVector::default();
166 for i in 0..WIDTH {
167 let src_idx = if i + delta < WIDTH { i + delta } else { i };
169 result.data[i] = source.data[src_idx];
170 }
171 result
172 }
173
174 fn shuffle_xor<T: GpuValue>(source: Self::Vector<T>, mask: usize) -> Self::Vector<T> {
175 let mut result = PortableVector::default();
176 for i in 0..WIDTH {
177 let src_idx = (i ^ mask) % WIDTH;
178 result.data[i] = source.data[src_idx];
179 }
180 result
181 }
182
183 fn reduce_sum<T: GpuValue + core::ops::Add<Output = T>>(values: Self::Vector<T>) -> T {
184 const { assert!(WIDTH > 0, "CpuSimd<WIDTH>: reduce requires WIDTH > 0") };
185 values.data.into_iter().reduce(|a, b| a + b).unwrap()
186 }
187
188 fn reduce_max<T: GpuValue + Ord>(values: Self::Vector<T>) -> T {
189 const { assert!(WIDTH > 0, "CpuSimd<WIDTH>: reduce requires WIDTH > 0") };
190 values.data.into_iter().max().unwrap()
191 }
192
193 fn reduce_min<T: GpuValue + Ord>(values: Self::Vector<T>) -> T {
194 const { assert!(WIDTH > 0, "CpuSimd<WIDTH>: reduce requires WIDTH > 0") };
195 values.data.into_iter().min().unwrap()
196 }
197
198 fn ballot(predicates: Self::Vector<bool>) -> Self::Mask {
199 const {
201 assert!(
202 WIDTH <= 64,
203 "CpuSimd<WIDTH>: ballot requires WIDTH <= 64 (u64 mask)"
204 )
205 };
206 let mut mask = 0u64;
207 for i in 0..WIDTH {
208 if predicates.data[i] {
209 mask |= 1u64 << i;
210 }
211 }
212 mask
213 }
214
215 fn all(predicates: Self::Vector<bool>) -> bool {
216 predicates.data.iter().all(|&b| b)
217 }
218
219 fn any(predicates: Self::Vector<bool>) -> bool {
220 predicates.data.iter().any(|&b| b)
221 }
222
223 fn mask_popcount(mask: Self::Mask) -> u32 {
224 mask.count_ones()
225 }
226}
227
/// Platform modelling a 32-lane GPU warp.
///
/// NOTE: in this file its `Platform` impl emulates the warp on the CPU with
/// `PortableVector<T, 32>`. No device code is involved.
#[derive(Debug, Clone, Copy)]
pub struct GpuWarp32;
240
/// Marker type for a 64-lane warp/wavefront.
///
/// NOTE(review): no `Platform` impl for this type is visible in this file.
/// Confirm where (or whether) it is implemented before relying on it.
#[derive(Debug, Clone, Copy)]
pub struct GpuWarp64;
248
249impl Platform for GpuWarp32 {
253 const WIDTH: usize = 32;
254 const NAME: &'static str = "GpuWarp32";
255
256 type Vector<T: GpuValue> = PortableVector<T, 32>;
257 type Mask = u32;
258
259 fn shuffle<T: GpuValue>(
260 source: Self::Vector<T>,
261 indices: Self::Vector<u32>,
262 ) -> Self::Vector<T> {
263 let mut result = PortableVector::default();
265 for i in 0..32 {
266 let src_idx = indices.data[i] as usize % 32;
267 result.data[i] = source.data[src_idx];
268 }
269 result
270 }
271
272 fn shuffle_down<T: GpuValue>(source: Self::Vector<T>, delta: usize) -> Self::Vector<T> {
273 let mut result = PortableVector::default();
276 for i in 0..32 {
277 let src_idx = i + delta;
278 result.data[i] = if src_idx < 32 {
279 source.data[src_idx]
280 } else {
281 source.data[i]
282 };
283 }
284 result
285 }
286
287 fn shuffle_xor<T: GpuValue>(source: Self::Vector<T>, mask: usize) -> Self::Vector<T> {
288 CpuSimd::<32>::shuffle_xor(source, mask)
289 }
290
291 fn reduce_sum<T: GpuValue + core::ops::Add<Output = T>>(values: Self::Vector<T>) -> T {
292 CpuSimd::<32>::reduce_sum(values)
293 }
294
295 fn reduce_max<T: GpuValue + Ord>(values: Self::Vector<T>) -> T {
296 CpuSimd::<32>::reduce_max(values)
297 }
298
299 fn reduce_min<T: GpuValue + Ord>(values: Self::Vector<T>) -> T {
300 CpuSimd::<32>::reduce_min(values)
301 }
302
303 fn ballot(predicates: Self::Vector<bool>) -> Self::Mask {
304 CpuSimd::<32>::ballot(predicates) as u32
305 }
306
307 fn all(predicates: Self::Vector<bool>) -> bool {
308 CpuSimd::<32>::all(predicates)
309 }
310
311 fn any(predicates: Self::Vector<bool>) -> bool {
312 CpuSimd::<32>::any(predicates)
313 }
314
315 fn mask_popcount(mask: Self::Mask) -> u32 {
316 mask.count_ones()
317 }
318}
319
320pub fn butterfly_reduce_sum<const WIDTH: usize, T>(values: PortableVector<T, WIDTH>) -> T
330where
331 T: GpuValue + core::ops::Add<Output = T>,
332{
333 const {
334 assert!(
335 WIDTH.is_power_of_two(),
336 "butterfly_reduce_sum requires power-of-2 WIDTH"
337 )
338 };
339 let mut v = values;
340 let mut stride = 1;
341 while stride < WIDTH {
342 let mut shuffled: PortableVector<T, WIDTH> = PortableVector::default();
344 for i in 0..WIDTH {
345 shuffled.data[i] = v.data[i ^ stride];
346 }
347 for i in 0..WIDTH {
349 v.data[i] = v.data[i] + shuffled.data[i];
350 }
351 stride *= 2;
352 }
353 v.data[0]
354}
355
356pub fn prefix_sum<const WIDTH: usize, T>(
358 values: PortableVector<T, WIDTH>,
359) -> PortableVector<T, WIDTH>
360where
361 T: GpuValue + core::ops::Add<Output = T>,
362{
363 let mut v = values;
364 let mut stride = 1;
365 while stride < WIDTH {
366 let mut result = v;
367 for i in stride..WIDTH {
368 result.data[i] = v.data[i] + v.data[i - stride];
369 }
370 v = result;
371 stride *= 2;
372 }
373 v
374}
375
#[cfg(test)]
mod tests {
    //! Unit tests for the software-emulated platforms and the free lane helpers.
    use super::*;

    #[test]
    fn test_cpu_simd_broadcast() {
        let v = CpuSimd::<8>::broadcast(42i32);
        for i in 0..8 {
            assert_eq!(v.extract(i), 42);
        }
    }

    #[test]
    fn test_cpu_simd_shuffle_xor() {
        let mut v = PortableVector::<i32, 8>::default();
        for i in 0..8 {
            v = v.insert(i, i as i32);
        }

        let shuffled = CpuSimd::<8>::shuffle_xor(v, 1);
        assert_eq!(shuffled.extract(0), 1);
        assert_eq!(shuffled.extract(1), 0);
        assert_eq!(shuffled.extract(2), 3);
        assert_eq!(shuffled.extract(3), 2);
    }

    #[test]
    fn test_cpu_simd_shuffle_down() {
        let v = PortableVector { data: [0i32, 1, 2, 3, 4, 5, 6, 7] };
        // Lanes whose source would be past the end keep their own value.
        assert_eq!(CpuSimd::<8>::shuffle_down(v, 3).data, [3, 4, 5, 6, 7, 5, 6, 7]);
        assert_eq!(CpuSimd::<8>::shuffle_down(v, 0).data, v.data);
    }

    #[test]
    fn test_cpu_simd_shuffle_wraps_indices() {
        let v = PortableVector { data: [10i32, 11, 12, 13] };
        let indices = PortableVector { data: [3u32, 2, 5, 0] };
        // Index 5 wraps to 5 % 4 == 1.
        assert_eq!(CpuSimd::<4>::shuffle(v, indices).data, [13, 12, 11, 10]);
    }

    #[test]
    fn test_cpu_simd_reduce_sum() {
        let mut v = PortableVector::<i32, 8>::default();
        for i in 0..8 {
            v = v.insert(i, (i + 1) as i32);
        }
        assert_eq!(CpuSimd::<8>::reduce_sum(v), 36);
    }

    #[test]
    fn test_cpu_simd_reduce_min_max() {
        let v = PortableVector { data: [3i32, -1, 4, 1, -5, 9, 2, 6] };
        assert_eq!(CpuSimd::<8>::reduce_max(v), 9);
        assert_eq!(CpuSimd::<8>::reduce_min(v), -5);
    }

    #[test]
    fn test_butterfly_reduce() {
        let mut v = PortableVector::<i32, 8>::default();
        for i in 0..8 {
            v = v.insert(i, (i + 1) as i32);
        }
        let sum = butterfly_reduce_sum::<8, i32>(v);
        assert_eq!(sum, 36);
    }

    #[test]
    fn test_butterfly_matches_linear_reduce() {
        let v = PortableVector {
            data: [5i32, -2, 7, 0, 1, 1, 3, 8, 4, 4, 4, 4, -9, 2, 6, 10],
        };
        assert_eq!(butterfly_reduce_sum(v), 48);
        assert_eq!(butterfly_reduce_sum(v), CpuSimd::<16>::reduce_sum(v));
    }

    #[test]
    fn test_ballot() {
        let mut predicates = PortableVector::<bool, 8>::default();
        for i in 0..8 {
            predicates = predicates.insert(i, i % 2 == 1);
        }
        let mask = CpuSimd::<8>::ballot(predicates);
        assert_eq!(mask, 0b10101010);
        assert_eq!(CpuSimd::<8>::mask_popcount(mask), 4);
    }

    #[test]
    fn test_cpu_simd_all_any() {
        let mixed = PortableVector { data: [true, false, true, true] };
        assert!(!CpuSimd::<4>::all(mixed));
        assert!(CpuSimd::<4>::any(mixed));
        assert!(CpuSimd::<4>::all(PortableVector { data: [true; 4] }));
        assert!(!CpuSimd::<4>::any(PortableVector { data: [false; 4] }));
    }

    #[test]
    fn test_gpu_warp32_emulation() {
        let v = GpuWarp32::broadcast(7i32);
        assert_eq!(v.extract(0), 7);
        assert_eq!(v.extract(31), 7);

        let mut values = PortableVector::<i32, 32>::default();
        for i in 0..32 {
            values = values.insert(i, 1);
        }
        assert_eq!(GpuWarp32::reduce_sum(values), 32);
    }

    #[test]
    fn test_gpu_warp32_ballot_uses_high_lane() {
        let predicates = PortableVector::<bool, 32>::default()
            .insert(0, true)
            .insert(31, true);
        let mask = GpuWarp32::ballot(predicates);
        assert_eq!(mask, 0x8000_0001);
        assert_eq!(GpuWarp32::mask_popcount(mask), 2);
    }

    #[test]
    fn test_gpu_warp32_shuffle_down() {
        let v = PortableVector {
            data: core::array::from_fn::<i32, 32, _>(|i| i as i32),
        };
        let shifted = GpuWarp32::shuffle_down(v, 16);
        assert_eq!(shifted.extract(0), 16);
        assert_eq!(shifted.extract(15), 31);
        assert_eq!(shifted.extract(16), 16);
        assert_eq!(shifted.extract(31), 31);
    }

    #[test]
    fn test_prefix_sum() {
        let mut v = PortableVector::<i32, 8>::default();
        for i in 0..8 {
            v = v.insert(i, 1);
        }
        let result = prefix_sum::<8, i32>(v);
        for i in 0..8 {
            assert_eq!(result.extract(i), (i + 1) as i32);
        }
    }

    #[test]
    fn test_prefix_sum_non_power_of_two() {
        let result = prefix_sum(PortableVector { data: [3i32, 1, 4, 1, 5] });
        assert_eq!(result.data, [3, 4, 8, 9, 14]);
    }
}