scirs2_core/gpu/kernels/reduction/sum.rs

//! Sum reduction kernel
//!
//! Computes the sum of all elements in an array using a two-stage tree
//! reduction: each workgroup reduces its slice of the input in shared memory
//! and writes one partial sum per workgroup. Reducing to a single value
//! requires launching the kernel repeatedly (or finishing on the host) until
//! one element remains.

use std::collections::HashMap;

use crate::gpu::kernels::{
    BaseKernel, DataType, GpuKernel, KernelMetadata, KernelParams, OperationType,
};
use crate::gpu::{GpuBackend, GpuError};

/// Sum reduction kernel
pub struct SumKernel {
    base: BaseKernel,
}

impl SumKernel {
    /// Create a new sum reduction kernel
    pub fn new() -> Self {
        let metadata = KernelMetadata {
            workgroup_size: [256, 1, 1],
            local_memory_usage: 1024, // 256 * sizeof(float)
            supports_tensor_cores: false,
            operationtype: OperationType::Balanced,
            backend_metadata: HashMap::new(),
        };

        let (cuda_source, rocm_source, wgpu_source, metal_source, opencl_source) =
            Self::get_kernel_sources();

        Self {
            base: BaseKernel::new(
                "sum_reduce",
                &cuda_source,
                &rocm_source,
                &wgpu_source,
                &metal_source,
                &opencl_source,
                metadata,
            ),
        }
    }

    /// Get kernel sources for different backends
    fn get_kernel_sources() -> (String, String, String, String, String) {
        // CUDA kernel
        let cuda_source = r#"
extern "C" __global__ void sum_reduce(
    const float* __restrict__ input,
    float* __restrict__ output,
    int n
) {
    __shared__ float sdata[256];

    // Each block loads data into shared memory
    unsigned int tid = threadIdx.x;
    unsigned int i = blockIdx.x * blockDim.x * 2 + threadIdx.x;

    // Initialize with the identity value
    sdata[tid] = 0.0f;

    // Load the first element of this thread's pair
    if (i < n) {
        sdata[tid] = input[i];
    }

    // Add the second element
    if (i + blockDim.x < n) {
        sdata[tid] += input[i + blockDim.x];
    }

    __syncthreads();

    // Reduce within the block by halving the active stride
    for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] += sdata[tid + s];
        }
        __syncthreads();
    }

    // Write this block's partial sum to output
    if (tid == 0) {
        output[blockIdx.x] = sdata[0];
    }
}
"#
        .to_string();

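        // Every backend below follows the same pattern as the CUDA kernel:
        // each thread loads up to two elements, the group folds them in
        // shared/local memory by halving the active stride, and thread 0
        // writes one partial sum. The output buffer must therefore hold at
        // least one element per group launched.
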
        // WebGPU kernel
        let wgpu_source = r#"
struct Uniforms {
    n: u32,
};

@group(0) @binding(0) var<uniform> uniforms: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

var<workgroup> sdata: array<f32, 256>;

@compute @workgroup_size(256)
fn sum_reduce(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) workgroup_id: vec3<u32>
) {
    let tid = local_id.x;
    let i = workgroup_id.x * 256u * 2u + local_id.x;

    // Initialize with the identity value
    sdata[tid] = 0.0;

    // Load the first element of this invocation's pair
    if (i < uniforms.n) {
        sdata[tid] = input[i];
    }

    // Add the second element
    if (i + 256u < uniforms.n) {
        sdata[tid] = sdata[tid] + input[i + 256u];
    }

    workgroupBarrier();

    // Reduce within the workgroup by halving the active stride
    for (var s = 256u / 2u; s > 0u; s = s / 2u) {
        if (tid < s) {
            sdata[tid] = sdata[tid] + sdata[tid + s];
        }
        workgroupBarrier();
    }

    // Write this workgroup's partial sum
    if (tid == 0u) {
        output[workgroup_id.x] = sdata[0];
    }
}
"#
        .to_string();

        // Metal kernel
        let metal_source = r#"
#include <metal_stdlib>
using namespace metal;

kernel void sum_reduce(
    const device float* input [[buffer(0)]],
    device float* output [[buffer(1)]],
    constant uint& n [[buffer(2)]],
    uint global_id [[thread_position_in_grid]],
    uint local_id [[thread_position_in_threadgroup]],
    uint group_id [[threadgroup_position_in_grid]])
{
    threadgroup float sdata[256];

    uint tid = local_id;
    uint i = group_id * 256 * 2 + local_id;

    // Initialize with the identity value
    sdata[tid] = 0.0f;

    // Load the first element of this thread's pair
    if (i < n) {
        sdata[tid] = input[i];
    }

    // Add the second element
    if (i + 256 < n) {
        sdata[tid] += input[i + 256];
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Reduce within the threadgroup by halving the active stride
    for (uint s = 256 / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] += sdata[tid + s];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Write this threadgroup's partial sum
    if (tid == 0) {
        output[group_id] = sdata[0];
    }
}
"#
        .to_string();

        // OpenCL kernel
        let opencl_source = r#"
__kernel void sum_reduce(
    __global const float* input,
    __global float* output,
    const int n)
{
    __local float sdata[256];

    unsigned int tid = get_local_id(0);
    unsigned int i = get_group_id(0) * get_local_size(0) * 2 + get_local_id(0);

    // Initialize with the identity value
    sdata[tid] = 0.0f;

    // Load the first element of this work-item's pair
    if (i < n) {
        sdata[tid] = input[i];
    }

    // Add the second element
    if (i + get_local_size(0) < n) {
        sdata[tid] += input[i + get_local_size(0)];
    }

    barrier(CLK_LOCAL_MEM_FENCE);

    // Reduce within the work-group by halving the active stride
    for (unsigned int s = get_local_size(0) / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] += sdata[tid + s];
        }
        barrier(CLK_LOCAL_MEM_FENCE);
    }

    // Write this work-group's partial sum
    if (tid == 0) {
        output[get_group_id(0)] = sdata[0];
    }
}
"#
        .to_string();

        // ROCm (HIP) kernel - similar to CUDA
        let rocm_source = cuda_source.clone();

        (
            cuda_source,
            rocm_source,
            wgpu_source,
            metal_source,
            opencl_source,
        )
    }
}
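
// Each thread loads two elements, so one launch folds 2 * workgroup_size
// input elements into a single partial sum per workgroup. A minimal sketch of
// the host-side launch arithmetic (illustrative helpers, not crate API):
#[allow(dead_code)]
fn partial_sums_len(n: usize, workgroup_size: usize) -> usize {
    // Workgroups needed when each consumes 2 * workgroup_size elements;
    // also the minimum length of the output buffer for one pass.
    (n + 2 * workgroup_size - 1) / (2 * workgroup_size)
}

#[allow(dead_code)]
fn passes_to_scalar(mut n: usize, workgroup_size: usize) -> usize {
    // Launches required until a single element remains, e.g. n = 1_000_000
    // with workgroup_size = 256 takes 3 passes (1954 -> 4 -> 1).
    let mut passes = 0;
    while n > 1 {
        n = partial_sums_len(n, workgroup_size);
        passes += 1;
    }
    passes
}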

impl Default for SumKernel {
    fn default() -> Self {
        Self::new()
    }
}

impl GpuKernel for SumKernel {
    fn name(&self) -> &str {
        self.base.name()
    }

    fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError> {
        self.base.source_for_backend(backend)
    }

    fn metadata(&self) -> KernelMetadata {
        self.base.metadata()
    }

    fn can_specialize(&self, params: &KernelParams) -> bool {
        matches!(
            params.datatype,
            DataType::Float32 | DataType::Float64 | DataType::Int32 | DataType::UInt32
        )
    }

    fn specialize(&self, params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError> {
        // This implementation does not generate type-specialized variants; it
        // validates the data type and returns an equivalent kernel. A fuller
        // version might rewrite the scalar type in the kernel sources (see
        // the sketch after this impl).
        if !self.can_specialize(params) {
            return Err(GpuError::SpecializationNotSupported);
        }

        Ok(Box::new(Self::new()))
    }
}
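
// A hedged sketch of what `specialize` could do for other scalar types:
// textual substitution over the C-style (CUDA/HIP/OpenCL) sources. This
// helper is illustrative only and not part of the crate's API.
#[allow(dead_code)]
fn specialize_scalar_type(source: &str, dtype: DataType) -> Option<String> {
    // Map each supported data type to its C scalar name and identity literal.
    let (scalar, zero) = match dtype {
        DataType::Float32 => ("float", "0.0f"),
        DataType::Float64 => ("double", "0.0"),
        DataType::Int32 => ("int", "0"),
        DataType::UInt32 => ("unsigned int", "0u"),
        _ => return None,
    };
    // Naive token swap (identity literal first, so "0.0f" is rewritten before
    // "float"); a production version would template the sources instead.
    Some(source.replace("0.0f", zero).replace("float", scalar))
}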
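
#[cfg(test)]
mod tests {
    // CPU reference for the per-group tree reduction the shaders implement;
    // a sketch for validating GPU results (assumes f32 input, as the default
    // kernels use).
    fn cpu_reference_pass(input: &[f32], workgroup_size: usize) -> Vec<f32> {
        // Each group covers 2 * workgroup_size consecutive elements and
        // emits their sum, mirroring one kernel launch.
        input
            .chunks(2 * workgroup_size)
            .map(|chunk| chunk.iter().sum())
            .collect()
    }

    #[test]
    fn reference_reduction_matches_direct_sum() {
        let input: Vec<f32> = (0..1000).map(|x| x as f32).collect();
        let expected: f32 = input.iter().sum();

        // Repeated passes mirror repeated kernel launches.
        let mut data = input;
        while data.len() > 1 {
            data = cpu_reference_pass(&data, 256);
        }

        assert!((data[0] - expected).abs() <= f32::EPSILON * expected);
    }
}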