scirs2_core/gpu/kernels/reduction/min_max.rs

//! Min/max reduction kernels.
//!
//! Each kernel reduces its workgroup's slice of the input to a single partial
//! result in shared/workgroup memory, with one source variant per backend
//! (CUDA, ROCm, WGSL, Metal, OpenCL).

use std::collections::HashMap;

use crate::gpu::kernels::{
    BaseKernel, DataType, GpuKernel, KernelMetadata, KernelParams, OperationType,
};
use crate::gpu::{GpuBackend, GpuError};
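
// Reduction strategy shared by both kernels: each workgroup of 256 threads
// loads two elements per thread (512 per workgroup), reduces them with a
// halving stride loop, and writes one partial result per workgroup. A full
// reduction therefore re-launches on the partials until a single value
// remains; e.g. n = 1_000_000 leaves ceil(1_000_000 / 512) = 1954 partials
// after the first pass, then 4, then 1.
//
// Caveat: NaN handling is backend-dependent (CUDA's fminf/fmaxf return the
// non-NaN operand, while WGSL's min/max make no such guarantee), so inputs
// containing NaN may reduce differently per backend.
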
/// GPU kernel that computes the minimum of an array via parallel tree reduction.
pub struct MinKernel {
    base: BaseKernel,
}

impl MinKernel {
    pub fn new() -> Self {
        let metadata = KernelMetadata {
            workgroup_size: [256, 1, 1],
            local_memory_usage: 1024, // 256 f32 values (1 KiB)
            supports_tensor_cores: false,
            operationtype: OperationType::Balanced,
            backend_metadata: HashMap::new(),
        };

        let (cuda_source, rocm_source, wgpu_source, metal_source, opencl_source) =
            Self::get_kernel_sources();

        Self {
            base: BaseKernel::new(
                "min_reduce",
                &cuda_source,
                &rocm_source,
                &wgpu_source,
                &metal_source,
                &opencl_source,
                metadata,
            ),
        }
    }

    fn get_kernel_sources() -> (String, String, String, String, String) {
        let cuda_source = r#"
extern "C" __global__ void min_reduce(
    const float* __restrict__ input,
    float* __restrict__ output,
    int n
) {
    __shared__ float sdata[256];

    // Each thread loads up to two elements into shared memory
    unsigned int tid = threadIdx.x;
    unsigned int i = blockIdx.x * blockDim.x * 2 + threadIdx.x;

    // Initialize with this thread's first element or +infinity
    if (i < n) {
        sdata[tid] = input[i];
    } else {
        sdata[tid] = INFINITY;
    }

    // Load and compare this thread's second element
    if (i + blockDim.x < n) {
        sdata[tid] = fminf(sdata[tid], input[i + blockDim.x]);
    }

    __syncthreads();

    // Reduce within block
    for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] = fminf(sdata[tid], sdata[tid + s]);
        }
        __syncthreads();
    }

    // Write result for this block to output
    if (tid == 0) {
        output[blockIdx.x] = sdata[0];
    }
}
"#
        .to_string();

        let wgpu_source = r#"
struct Uniforms {
    n: u32,
};

@group(0) @binding(0) var<uniform> uniforms: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

var<workgroup> sdata: array<f32, 256>;

@compute @workgroup_size(256)
fn min_reduce(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) workgroup_id: vec3<u32>
) {
    let tid = local_id.x;
    let i = workgroup_id.x * 256u * 2u + local_id.x;

    // Initialize with this thread's first element or a stand-in for +infinity
    if (i < uniforms.n) {
        sdata[tid] = input[i];
    } else {
        sdata[tid] = 3.4028235e+38; // f32::MAX (WGSL has no infinity literal)
    }

    // Load and compare this thread's second element
    if (i + 256u < uniforms.n) {
        sdata[tid] = min(sdata[tid], input[i + 256u]);
    }

    workgroupBarrier();

    // Reduce within the workgroup
    var s = 256u / 2u;
    while (s > 0u) {
        if (tid < s) {
            sdata[tid] = min(sdata[tid], sdata[tid + s]);
        }
        workgroupBarrier();
        s = s / 2u;
    }

    // Write result for this workgroup
    if (tid == 0u) {
        output[workgroup_id.x] = sdata[0];
    }
}
"#
        .to_string();

        let metal_source = r#"
#include <metal_stdlib>
using namespace metal;

kernel void min_reduce(
    const device float* input [[buffer(0)]],
    device float* output [[buffer(1)]],
    constant uint& n [[buffer(2)]],
    uint global_id [[thread_position_in_grid]],
    uint local_id [[thread_position_in_threadgroup]],
    uint group_id [[threadgroup_position_in_grid]])
{
    threadgroup float sdata[256];

    uint tid = local_id;
    uint i = group_id * 256 * 2 + local_id;

    // Initialize with this thread's first element or +infinity
    if (i < n) {
        sdata[tid] = input[i];
    } else {
        sdata[tid] = INFINITY;
    }

    // Load and compare this thread's second element
    if (i + 256 < n) {
        sdata[tid] = min(sdata[tid], input[i + 256]);
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Do reduction in shared memory
    for (uint s = 256 / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] = min(sdata[tid], sdata[tid + s]);
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Write result for this threadgroup
    if (tid == 0) {
        output[group_id] = sdata[0];
    }
}
"#
        .to_string();

        let opencl_source = r#"
__kernel void min_reduce(
    __global const float* input,
    __global float* output,
    const int n)
{
    __local float sdata[256];

    unsigned int tid = get_local_id(0);
    unsigned int i = get_group_id(0) * get_local_size(0) * 2 + get_local_id(0);

    // Initialize with this thread's first element or +infinity
    if (i < n) {
        sdata[tid] = input[i];
    } else {
        sdata[tid] = INFINITY;
    }

    // Load and compare this thread's second element
    if (i + get_local_size(0) < n) {
        sdata[tid] = min(sdata[tid], input[i + get_local_size(0)]);
    }

    barrier(CLK_LOCAL_MEM_FENCE);

    // Do reduction in shared memory
    for (unsigned int s = get_local_size(0) / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] = min(sdata[tid], sdata[tid + s]);
        }

        barrier(CLK_LOCAL_MEM_FENCE);
    }

    // Write result for this workgroup
    if (tid == 0) {
        output[get_group_id(0)] = sdata[0];
    }
}
"#
        .to_string();

        // The CUDA source is valid HIP, so the ROCm backend reuses it as-is.
        let rocm_source = cuda_source.clone();

        (
            cuda_source,
            rocm_source,
            wgpu_source,
            metal_source,
            opencl_source,
        )
    }
}

impl Default for MinKernel {
    fn default() -> Self {
        Self::new()
    }
}

impl GpuKernel for MinKernel {
    fn name(&self) -> &str {
        self.base.name()
    }

    fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError> {
        self.base.source_for_backend(backend)
    }

    fn metadata(&self) -> KernelMetadata {
        self.base.metadata()
    }

    fn can_specialize(&self, params: &KernelParams) -> bool {
        matches!(
            params.datatype,
            DataType::Float32 | DataType::Float64 | DataType::Int32 | DataType::UInt32
        )
    }

    fn specialize(&self, params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError> {
        if !self.can_specialize(params) {
            return Err(GpuError::SpecializationNotSupported);
        }

        // No type-specialized variants yet; return an unmodified kernel.
        Ok(Box::new(Self::new()))
    }
}
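
// Example (a sketch, not part of this file's API): fetching a backend's
// source string for compilation. `GpuBackend::Wgpu` is an assumed variant
// name here; substitute whatever variants `crate::gpu::GpuBackend` actually
// defines.
//
//     let kernel = MinKernel::new();
//     let wgsl = kernel.source_for_backend(GpuBackend::Wgpu)?;
//     // hand `wgsl` to the pipeline builder of the chosen backend ...
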
/// GPU kernel that computes the maximum of an array via parallel tree reduction.
pub struct MaxKernel {
    base: BaseKernel,
}

impl MaxKernel {
    pub fn new() -> Self {
        let metadata = KernelMetadata {
            workgroup_size: [256, 1, 1],
            local_memory_usage: 1024, // 256 f32 values (1 KiB)
            supports_tensor_cores: false,
            operationtype: OperationType::Balanced,
            backend_metadata: HashMap::new(),
        };

        let (cuda_source, rocm_source, wgpu_source, metal_source, opencl_source) =
            Self::get_kernel_sources();

        Self {
            base: BaseKernel::new(
                "max_reduce",
                &cuda_source,
                &rocm_source,
                &wgpu_source,
                &metal_source,
                &opencl_source,
                metadata,
            ),
        }
    }

    fn get_kernel_sources() -> (String, String, String, String, String) {
        let cuda_source = r#"
extern "C" __global__ void max_reduce(
    const float* __restrict__ input,
    float* __restrict__ output,
    int n
) {
    __shared__ float sdata[256];

    // Each thread loads up to two elements into shared memory
    unsigned int tid = threadIdx.x;
    unsigned int i = blockIdx.x * blockDim.x * 2 + threadIdx.x;

    // Initialize with this thread's first element or -infinity
    if (i < n) {
        sdata[tid] = input[i];
    } else {
        sdata[tid] = -INFINITY;
    }

    // Load and compare this thread's second element
    if (i + blockDim.x < n) {
        sdata[tid] = fmaxf(sdata[tid], input[i + blockDim.x]);
    }

    __syncthreads();

    // Reduce within block
    for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] = fmaxf(sdata[tid], sdata[tid + s]);
        }
        __syncthreads();
    }

    // Write result for this block to output
    if (tid == 0) {
        output[blockIdx.x] = sdata[0];
    }
}
"#
        .to_string();

        let wgpu_source = r#"
struct Uniforms {
    n: u32,
};

@group(0) @binding(0) var<uniform> uniforms: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

var<workgroup> sdata: array<f32, 256>;

@compute @workgroup_size(256)
fn max_reduce(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) workgroup_id: vec3<u32>
) {
    let tid = local_id.x;
    let i = workgroup_id.x * 256u * 2u + local_id.x;

    // Initialize with this thread's first element or a stand-in for -infinity
    if (i < uniforms.n) {
        sdata[tid] = input[i];
    } else {
        sdata[tid] = -3.4028235e+38; // lowest finite f32 (WGSL has no infinity literal)
    }

    // Load and compare this thread's second element
    if (i + 256u < uniforms.n) {
        sdata[tid] = max(sdata[tid], input[i + 256u]);
    }

    workgroupBarrier();

    // Reduce within the workgroup
    var s = 256u / 2u;
    while (s > 0u) {
        if (tid < s) {
            sdata[tid] = max(sdata[tid], sdata[tid + s]);
        }
        workgroupBarrier();
        s = s / 2u;
    }

    // Write result for this workgroup
    if (tid == 0u) {
        output[workgroup_id.x] = sdata[0];
    }
}
"#
        .to_string();

        let metal_source = r#"
#include <metal_stdlib>
using namespace metal;

kernel void max_reduce(
    const device float* input [[buffer(0)]],
    device float* output [[buffer(1)]],
    constant uint& n [[buffer(2)]],
    uint global_id [[thread_position_in_grid]],
    uint local_id [[thread_position_in_threadgroup]],
    uint group_id [[threadgroup_position_in_grid]])
{
    threadgroup float sdata[256];

    uint tid = local_id;
    uint i = group_id * 256 * 2 + local_id;

    // Initialize with this thread's first element or -infinity
    if (i < n) {
        sdata[tid] = input[i];
    } else {
        sdata[tid] = -INFINITY;
    }

    // Load and compare this thread's second element
    if (i + 256 < n) {
        sdata[tid] = max(sdata[tid], input[i + 256]);
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Do reduction in shared memory
    for (uint s = 256 / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] = max(sdata[tid], sdata[tid + s]);
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Write result for this threadgroup
    if (tid == 0) {
        output[group_id] = sdata[0];
    }
}
"#
        .to_string();

        let opencl_source = r#"
__kernel void max_reduce(
    __global const float* input,
    __global float* output,
    const int n)
{
    __local float sdata[256];

    unsigned int tid = get_local_id(0);
    unsigned int i = get_group_id(0) * get_local_size(0) * 2 + get_local_id(0);

    // Initialize with this thread's first element or -infinity
    if (i < n) {
        sdata[tid] = input[i];
    } else {
        sdata[tid] = -INFINITY;
    }

    // Load and compare this thread's second element
    if (i + get_local_size(0) < n) {
        sdata[tid] = max(sdata[tid], input[i + get_local_size(0)]);
    }

    barrier(CLK_LOCAL_MEM_FENCE);

    // Do reduction in shared memory
    for (unsigned int s = get_local_size(0) / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] = max(sdata[tid], sdata[tid + s]);
        }

        barrier(CLK_LOCAL_MEM_FENCE);
    }

    // Write result for this workgroup
    if (tid == 0) {
        output[get_group_id(0)] = sdata[0];
    }
}
"#
        .to_string();

        // The CUDA source is valid HIP, so the ROCm backend reuses it as-is.
        let rocm_source = cuda_source.clone();

        (
            cuda_source,
            rocm_source,
            wgpu_source,
            metal_source,
            opencl_source,
        )
    }
}

518
519impl Default for MaxKernel {
520 fn default() -> Self {
521 Self::new()
522 }
523}
524
impl GpuKernel for MaxKernel {
    fn name(&self) -> &str {
        self.base.name()
    }

    fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError> {
        self.base.source_for_backend(backend)
    }

    fn metadata(&self) -> KernelMetadata {
        self.base.metadata()
    }

    fn can_specialize(&self, params: &KernelParams) -> bool {
        matches!(
            params.datatype,
            DataType::Float32 | DataType::Float64 | DataType::Int32 | DataType::UInt32
        )
    }

    fn specialize(&self, params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError> {
        if !self.can_specialize(params) {
            return Err(GpuError::SpecializationNotSupported);
        }

        // No type-specialized variants yet; return an unmodified kernel.
        Ok(Box::new(Self::new()))
    }
}
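
#[cfg(test)]
mod tests {
    use super::*;

    // Host-side reference model of the tree reduction used by every backend
    // above (a validation sketch only; it exercises no GPU code path and
    // relies solely on items defined in this file). Each of the 256
    // "threads" loads input[i] and input[i + 256] for i = start + tid, then
    // the stride loop halves the active range until sdata[0] holds the
    // block's result.
    fn block_reduce_reference(
        input: &[f32],
        block: usize,
        init: f32,
        f: impl Fn(f32, f32) -> f32,
    ) -> f32 {
        const BLOCK: usize = 256;
        let start = block * BLOCK * 2;
        let mut sdata = [init; BLOCK];
        for tid in 0..BLOCK {
            let i = start + tid;
            if i < input.len() {
                sdata[tid] = input[i];
            }
            if i + BLOCK < input.len() {
                sdata[tid] = f(sdata[tid], input[i + BLOCK]);
            }
        }
        let mut s = BLOCK / 2;
        while s > 0 {
            for tid in 0..s {
                sdata[tid] = f(sdata[tid], sdata[tid + s]);
            }
            s /= 2;
        }
        sdata[0]
    }

    #[test]
    fn reference_model_matches_scalar_fold() {
        // 1000 elements span two blocks of 512; combining the two partials
        // must reproduce the scalar min/max.
        let data: Vec<f32> = (0..1000).map(|k| ((k * 37 % 101) as f32) - 50.0).collect();
        let min = f32::min(
            block_reduce_reference(&data, 0, f32::INFINITY, f32::min),
            block_reduce_reference(&data, 1, f32::INFINITY, f32::min),
        );
        let max = f32::max(
            block_reduce_reference(&data, 0, f32::NEG_INFINITY, f32::max),
            block_reduce_reference(&data, 1, f32::NEG_INFINITY, f32::max),
        );
        assert_eq!(min, data.iter().copied().fold(f32::INFINITY, f32::min));
        assert_eq!(max, data.iter().copied().fold(f32::NEG_INFINITY, f32::max));
    }

    #[test]
    fn kernel_names_match_shader_entry_points() {
        assert_eq!(MinKernel::new().name(), "min_reduce");
        assert_eq!(MaxKernel::new().name(), "max_reduce");
    }
}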