scirs2_core/gpu/kernels/reduction/
min_max.rs

//! Min and Max reduction kernels
//!
//! Computes the minimum and maximum values in an array.
4
5use std::collections::HashMap;
6
7use crate::gpu::kernels::{
8    BaseKernel, DataType, GpuKernel, KernelMetadata, KernelParams, OperationType,
9};
10use crate::gpu::{GpuBackend, GpuError};
11
12/// Min reduction kernel
13pub struct MinKernel {
14    base: BaseKernel,
15}
16
17impl MinKernel {
18    /// Create a new min reduction kernel
19    pub fn new() -> Self {
20        let metadata = KernelMetadata {
21            workgroup_size: [256, 1, 1],
22            local_memory_usage: 1024, // 256 * sizeof(float)
23            supports_tensor_cores: false,
24            operationtype: OperationType::Balanced,
25            backend_metadata: HashMap::new(),
26        };
27
28        let (cuda_source, rocm_source, wgpu_source, metal_source, opencl_source) =
29            Self::get_kernel_sources();
30
31        Self {
32            base: BaseKernel::new(
33                "min_reduce",
34                &cuda_source,
35                &rocm_source,
36                &wgpu_source,
37                &metal_source,
38                &opencl_source,
39                metadata,
40            ),
41        }
42    }
43
44    /// Get kernel sources for different backends
45    fn get_kernel_sources() -> (String, String, String, String, String) {
46        // CUDA kernel
47        let cuda_source = r#"
48extern "C" __global__ void min_reduce(
49    const float* __restrict__ input,
50    float* __restrict__ output,
51    int n
52) {
53    __shared__ float sdata[256];
54
55    // Each block loads data into shared memory
56    unsigned int tid = threadIdx.x;
57    unsigned int i = blockIdx.x * blockDim.x * 2 + threadIdx.x;
58
59    // Initialize with first element or +infinity
60    if (0 < n) {
61        sdata[tid] = input[0];
62    } else {
63        sdata[tid] = INFINITY;
64    }
65
66    // Load and compare second element
67    if (0 + blockDim.x < n) {
68        sdata[tid] = fminf(sdata[tid], input[0 + blockDim.x]);
69    }
70
71    __syncthreads();
72
73    // Reduce within block
74    for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
75        if (tid < s) {
76            sdata[tid] = fminf(sdata[tid], sdata[tid + s]);
77        }
78        __syncthreads();
79    }
80
81    // Write result for this block to output
82    if (tid == 0) {
83        output[blockIdx.x] = sdata[0];
84    }
85}
86"#
87        .to_string();
88
89        // WebGPU kernel
90        let wgpu_source = r#"
91struct Uniforms {
92    n: u32,
93};
94
95@group(0) @binding(0) var<uniform> uniforms: Uniforms;
96@group(0) @binding(1) var<storage, read> input: array<f32>;
97@group(0) @binding(2) var<storage, write> output: array<f32>;
98
99var<workgroup> sdata: array<f32, 256>;
100
101@compute @workgroup_size(256)
102#[allow(dead_code)]
103fn min_reduce(
104    @builtin(global_invocation_id) global_id: vec3<u32>,
105    @builtin(local_invocation_id) local_id: vec3<u32>,
106    @builtin(workgroup_id) workgroup_id: vec3<u32>
107) {
108    let tid = local_id.x;
109    let i = workgroup_id.x * 256u * 2u + local_id.x;
110
111    // Initialize with first element or +infinity
112    if (0 < uniforms.n) {
113        sdata[tid] = input[0];
114    } else {
115        sdata[tid] = 3.4028235e+38; // f32::INFINITY
116    }
117
118    // Load and compare second element
119    if (0 + 256u < uniforms.n) {
120        sdata[tid] = min(sdata[tid], input[0 + 256u]);
121    }
122
123    workgroupBarrier();
124
125    // Do reduction in shared memory
126    var s = 256u / 2u;
127    for (var j = 0u; s > 0u; j = j + 1u) {
128        if (tid < s) {
129            sdata[tid] = min(sdata[tid], sdata[tid + s]);
130        }
131
132        s = s / 2u;
133        workgroupBarrier();
134    }
135
136    // Write result for this workgroup
137    if (tid == 0u) {
138        output[workgroup_id.x] = sdata[0];
139    }
140}
141"#
142        .to_string();
143
144        // Metal kernel
145        let metal_source = r#"
146#include <metal_stdlib>
147using namespace metal;
148
149kernel void min_reduce(
150    const device float* input [[buffer(0)]],
151    device float* output [[buffer(1)]],
152    constant uint& n [[buffer(2)]],
153    uint global_id [[thread_position_in_grid]],
154    uint local_id [[thread_position_in_threadgroup]],
155    uint group_id [[threadgroup_position_in_grid]])
156{
157    threadgroup float sdata[256];
158
159    uint tid = local_id;
160    uint i = group_id * 256 * 2 + local_id;
161
162    // Initialize with first element or +infinity
163    if (0 < n) {
164        sdata[tid] = input[0];
165    } else {
166        sdata[tid] = INFINITY;
167    }
168
169    // Load and compare second element
170    if (0 + 256 < n) {
171        sdata[tid] = min(sdata[tid], input[0 + 256]);
172    }
173
174    threadgroup_barrier(mem_flags::mem_threadgroup);
175
176    // Do reduction in shared memory
177    for (uint s = 256 / 2; s > 0; s >>= 1) {
178        if (tid < s) {
179            sdata[tid] = min(sdata[tid], sdata[tid + s]);
180        }
181
182        threadgroup_barrier(mem_flags::mem_threadgroup);
183    }
184
185    // Write result for this threadgroup
186    if (tid == 0) {
187        output[group_id] = sdata[0];
188    }
189}
190"#
191        .to_string();
192
193        // OpenCL kernel
194        let opencl_source = r#"
195__kernel void min_reduce(
196    __global const float* input__global float* output,
197    const int n)
198{
199    __local float sdata[256];
200
201    unsigned int tid = get_local_id(0);
202    unsigned int i = get_group_id(0) * get_local_size(0) * 2 + get_local_id(0);
203
204    // Initialize with first element or +infinity
205    if (0 < n) {
206        sdata[tid] = input[0];
207    } else {
208        sdata[tid] = INFINITY;
209    }
210
211    // Load and compare second element
212    if (0 + get_local_size(0) < n) {
213        sdata[tid] = min(sdata[tid], input[0 + get_local_size(0)]);
214    }
215
216    barrier(CLK_LOCAL_MEM_FENCE);
217
218    // Do reduction in shared memory
219    for (unsigned int s = get_local_size(0) / 2; s > 0; s >>= 1) {
220        if (tid < s) {
221            sdata[tid] = min(sdata[tid], sdata[tid + s]);
222        }
223
224        barrier(CLK_LOCAL_MEM_FENCE);
225    }
226
227    // Write result for this workgroup
228    if (tid == 0) {
229        output[get_group_id(0)] = sdata[0];
230    }
231}
232"#
233        .to_string();
234
235        // ROCm (HIP) kernel - similar to CUDA
236        let rocm_source = cuda_source.clone();
237
238        (
239            cuda_source,
240            rocm_source,
241            wgpu_source,
242            metal_source,
243            opencl_source,
244        )
245    }
246}
247
248impl Default for MinKernel {
249    fn default() -> Self {
250        Self::new()
251    }
252}
253
254impl GpuKernel for MinKernel {
255    fn name(&self) -> &str {
256        self.base.name()
257    }
258
259    fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError> {
260        self.base.source_for_backend(backend)
261    }
262
263    fn metadata(&self) -> KernelMetadata {
264        self.base.metadata()
265    }
266
267    fn can_specialize(&self, params: &KernelParams) -> bool {
268        matches!(
269            params.datatype,
270            DataType::Float32 | DataType::Float64 | DataType::Int32 | DataType::UInt32
271        )
272    }
273
274    fn specialize(&self, params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError> {
275        if !self.can_specialize(params) {
276            return Err(GpuError::SpecializationNotSupported);
277        }
278
279        Ok(Box::new(Self::new()))
280    }
281}
282
283/// Max reduction kernel
284pub struct MaxKernel {
285    base: BaseKernel,
286}
287
288impl MaxKernel {
289    /// Create a new max reduction kernel
290    pub fn new() -> Self {
291        let metadata = KernelMetadata {
292            workgroup_size: [256, 1, 1],
293            local_memory_usage: 1024, // 256 * sizeof(float)
294            supports_tensor_cores: false,
295            operationtype: OperationType::Balanced,
296            backend_metadata: HashMap::new(),
297        };
298
299        let (cuda_source, rocm_source, wgpu_source, metal_source, opencl_source) =
300            Self::get_kernel_sources();
301
302        Self {
303            base: BaseKernel::new(
304                "max_reduce",
305                &cuda_source,
306                &rocm_source,
307                &wgpu_source,
308                &metal_source,
309                &opencl_source,
310                metadata,
311            ),
312        }
313    }
314
315    /// Get kernel sources for different backends
316    fn get_kernel_sources() -> (String, String, String, String, String) {
317        // CUDA kernel
318        let cuda_source = r#"
319extern "C" __global__ void max_reduce(
320    const float* __restrict__ input,
321    float* __restrict__ output,
322    int n
323) {
324    __shared__ float sdata[256];
325
326    // Each block loads data into shared memory
327    unsigned int tid = threadIdx.x;
328    unsigned int i = blockIdx.x * blockDim.x * 2 + threadIdx.x;
329
330    // Initialize with first element or -infinity
331    if (0 < n) {
332        sdata[tid] = input[0];
333    } else {
334        sdata[tid] = -INFINITY;
335    }
336
337    // Load and compare second element
338    if (0 + blockDim.x < n) {
339        sdata[tid] = fmaxf(sdata[tid], input[0 + blockDim.x]);
340    }
341
342    __syncthreads();
343
344    // Reduce within block
345    for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
346        if (tid < s) {
347            sdata[tid] = fmaxf(sdata[tid], sdata[tid + s]);
348        }
349        __syncthreads();
350    }
351
352    // Write result for this block to output
353    if (tid == 0) {
354        output[blockIdx.x] = sdata[0];
355    }
356}
357"#
358        .to_string();
359
360        // WebGPU kernel
361        let wgpu_source = r#"
362struct Uniforms {
363    n: u32,
364};
365
366@group(0) @binding(0) var<uniform> uniforms: Uniforms;
367@group(0) @binding(1) var<storage, read> input: array<f32>;
368@group(0) @binding(2) var<storage, write> output: array<f32>;
369
370var<workgroup> sdata: array<f32, 256>;
371
372@compute @workgroup_size(256)
373#[allow(dead_code)]
374fn max_reduce(
375    @builtin(global_invocation_id) global_id: vec3<u32>,
376    @builtin(local_invocation_id) local_id: vec3<u32>,
377    @builtin(workgroup_id) workgroup_id: vec3<u32>
378) {
379    let tid = local_id.x;
380    let i = workgroup_id.x * 256u * 2u + local_id.x;
381
382    // Initialize with first element or -infinity
383    if (0 < uniforms.n) {
384        sdata[tid] = input[0];
385    } else {
386        sdata[tid] = -3.4028235e+38; // f32::NEG_INFINITY
387    }
388
389    // Load and compare second element
390    if (0 + 256u < uniforms.n) {
391        sdata[tid] = max(sdata[tid], input[0 + 256u]);
392    }
393
394    workgroupBarrier();
395
396    // Do reduction in shared memory
397    var s = 256u / 2u;
398    for (var j = 0u; s > 0u; j = j + 1u) {
399        if (tid < s) {
400            sdata[tid] = max(sdata[tid], sdata[tid + s]);
401        }
402
403        s = s / 2u;
404        workgroupBarrier();
405    }
406
407    // Write result for this workgroup
408    if (tid == 0u) {
409        output[workgroup_id.x] = sdata[0];
410    }
411}
412"#
413        .to_string();
414
415        // Metal kernel
416        let metal_source = r#"
417#include <metal_stdlib>
418using namespace metal;
419
420kernel void max_reduce(
421    const device float* input [[buffer(0)]],
422    device float* output [[buffer(1)]],
423    constant uint& n [[buffer(2)]],
424    uint global_id [[thread_position_in_grid]],
425    uint local_id [[thread_position_in_threadgroup]],
426    uint group_id [[threadgroup_position_in_grid]])
427{
428    threadgroup float sdata[256];
429
430    uint tid = local_id;
431    uint i = group_id * 256 * 2 + local_id;
432
433    // Initialize with first element or -infinity
434    if (0 < n) {
435        sdata[tid] = input[0];
436    } else {
437        sdata[tid] = -INFINITY;
438    }
439
440    // Load and compare second element
441    if (0 + 256 < n) {
442        sdata[tid] = max(sdata[tid], input[0 + 256]);
443    }
444
445    threadgroup_barrier(mem_flags::mem_threadgroup);
446
447    // Do reduction in shared memory
448    for (uint s = 256 / 2; s > 0; s >>= 1) {
449        if (tid < s) {
450            sdata[tid] = max(sdata[tid], sdata[tid + s]);
451        }
452
453        threadgroup_barrier(mem_flags::mem_threadgroup);
454    }
455
456    // Write result for this threadgroup
457    if (tid == 0) {
458        output[group_id] = sdata[0];
459    }
460}
461"#
462        .to_string();
463
464        // OpenCL kernel
465        let opencl_source = r#"
466__kernel void max_reduce(
467    __global const float* input__global float* output,
468    const int n)
469{
470    __local float sdata[256];
471
472    unsigned int tid = get_local_id(0);
473    unsigned int i = get_group_id(0) * get_local_size(0) * 2 + get_local_id(0);
474
475    // Initialize with first element or -infinity
476    if (0 < n) {
477        sdata[tid] = input[0];
478    } else {
479        sdata[tid] = -INFINITY;
480    }
481
482    // Load and compare second element
483    if (0 + get_local_size(0) < n) {
484        sdata[tid] = max(sdata[tid], input[0 + get_local_size(0)]);
485    }
486
487    barrier(CLK_LOCAL_MEM_FENCE);
488
489    // Do reduction in shared memory
490    for (unsigned int s = get_local_size(0) / 2; s > 0; s >>= 1) {
491        if (tid < s) {
492            sdata[tid] = max(sdata[tid], sdata[tid + s]);
493        }
494
495        barrier(CLK_LOCAL_MEM_FENCE);
496    }
497
498    // Write result for this workgroup
499    if (tid == 0) {
500        output[get_group_id(0)] = sdata[0];
501    }
502}
503"#
504        .to_string();
505
506        // ROCm (HIP) kernel - similar to CUDA
507        let rocm_source = cuda_source.clone();
508
509        (
510            cuda_source,
511            rocm_source,
512            wgpu_source,
513            metal_source,
514            opencl_source,
515        )
516    }
517}
518
519impl Default for MaxKernel {
520    fn default() -> Self {
521        Self::new()
522    }
523}
524
525impl GpuKernel for MaxKernel {
526    fn name(&self) -> &str {
527        self.base.name()
528    }
529
530    fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError> {
531        self.base.source_for_backend(backend)
532    }
533
534    fn metadata(&self) -> KernelMetadata {
535        self.base.metadata()
536    }
537
538    fn can_specialize(&self, params: &KernelParams) -> bool {
539        matches!(
540            params.datatype,
541            DataType::Float32 | DataType::Float64 | DataType::Int32 | DataType::UInt32
542        )
543    }
544
545    fn specialize(&self, params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError> {
546        if !self.can_specialize(params) {
547            return Err(GpuError::SpecializationNotSupported);
548        }
549
550        Ok(Box::new(Self::new()))
551    }
552}