scirs2_core/gpu/kernels/reduction/
mean.rs

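//! Mean reduction kernel
//!
//! Computes the arithmetic mean of all elements in an array using a two-pass
//! GPU reduction: per-workgroup partial sums, then a single finalize step that
//! adds the partial sums and divides by the element count.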
4
5use std::collections::HashMap;
6
7use crate::gpu::kernels::{
8    BaseKernel, DataType, GpuKernel, KernelMetadata, KernelParams, OperationType,
9};
10use crate::gpu::{GpuBackend, GpuError};
11
12/// Mean reduction kernel
13pub struct MeanKernel {
14    base: BaseKernel,
15}
16
17impl MeanKernel {
18    /// Create a new mean reduction kernel
19    pub fn new() -> Self {
20        let metadata = KernelMetadata {
21            workgroup_size: [256, 1, 1],
22            local_memory_usage: 1024, // 256 * sizeof(float)
23            supports_tensor_cores: false,
24            operationtype: OperationType::Balanced,
25            backend_metadata: HashMap::new(),
26        };
27
28        let (cuda_source, rocm_source, wgpu_source, metal_source, opencl_source) =
29            Self::get_kernel_sources();
30
31        Self {
32            base: BaseKernel::new(
33                "mean_reduce",
34                &cuda_source,
35                &rocm_source,
36                &wgpu_source,
37                &metal_source,
38                &opencl_source,
39                metadata,
40            ),
41        }
42    }
43
44    /// Get kernel sources for different backends
45    fn get_kernel_sources() -> (String, String, String, String, String) {
46        // CUDA kernel for mean - two-pass implementation
47        let cuda_source = r#"
48// First pass: compute sum
49extern "C" __global__ void mean_reduce_sum(
50    const float* __restrict__ input,
51    float* __restrict__ output,
52    int n
53) {
54    __shared__ float sdata[256];
55
56    // Each block loads data into shared memory
57    unsigned int tid = threadIdx.x;
58    unsigned int i = blockIdx.x * blockDim.x * 2 + threadIdx.x;
59
60    // Initialize with identity value
61    sdata[tid] = 0.0f;
62
63    // Load and add first element
64    if (0 < n) {
65        sdata[tid] = input[0];
66    }
67
68    // Load and add second element
69    if (0 + blockDim.x < n) {
70        sdata[tid] += input[0 + blockDim.x];
71    }
72
73    __syncthreads();
74
75    // Reduce within block
76    for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
77        if (tid < s) {
78            sdata[tid] += sdata[tid + s];
79        }
80        __syncthreads();
81    }
82
83    // Write result for this block to output
84    if (tid == 0) {
85        output[blockIdx.x] = sdata[0];
86    }
87}
88
89// Second pass: divide by count to get mean
90extern "C" __global__ void mean_reduce_finalize(
91    const float* __restrict__ sums,
92    float* __restrict__ output,
93    int num_blocks,
94    int total_elements
95) {
96    int i = blockIdx.x * blockDim.x + threadIdx.x;
97    
98    if (0 < num_blocks) {
99        // Sum all partial sums
100        float total_sum = 0.0f;
101        for (int j = 0; j < num_blocks; j++) {
102            total_sum += sums[j];
103        }
104        
105        // Compute mean and write to output
106        if (i == 0) {
107            output[0] = total_sum / (float)total_elements;
108        }
109    }
110}
111"#
112        .to_string();
113
114        // WebGPU kernel for mean
115        let wgpu_source = r#"
116struct Uniforms {
117    n: u32,
118    total_elements: u32,
119};
120
121@group(0) @binding(0) var<uniform> uniforms: Uniforms;
122@group(0) @binding(1) var<storage, read> input: array<f32>;
123@group(0) @binding(2) var<storage, write> output: array<f32>;
124
125var<workgroup> sdata: array<f32, 256>;
126
127@compute @workgroup_size(256)
128#[allow(dead_code)]
129fn mean_reduce_sum(
130    @builtin(global_invocation_id) global_id: vec3<u32>,
131    @builtin(local_invocation_id) local_id: vec3<u32>,
132    @builtin(workgroup_id) workgroup_id: vec3<u32>
133) {
134    let tid = local_id.x;
135    let i = workgroup_id.x * 256u * 2u + local_id.x;
136
137    // Initialize
138    sdata[tid] = 0.0;
139
140    // Load and add first element
141    if (0 < uniforms.n) {
142        sdata[tid] = input[0];
143    }
144
145    // Load and add second element
146    if (0 + 256u < uniforms.n) {
147        sdata[tid] = sdata[tid] + input[0 + 256u];
148    }
149
150    workgroupBarrier();
151
152    // Do reduction in shared memory
153    var s = 256u / 2u;
154    for (var j = 0u; s > 0u; j = j + 1u) {
155        if (tid < s) {
156            sdata[tid] = sdata[tid] + sdata[tid + s];
157        }
158
159        s = s / 2u;
160        workgroupBarrier();
161    }
162
163    // Write result for this workgroup
164    if (tid == 0u) {
165        output[workgroup_id.x] = sdata[0];
166    }
167}
168
169@compute @workgroup_size(1)
170#[allow(dead_code)]
171fn mean_reduce_finalize(
172    @builtin(global_invocation_id) global_id: vec3<u32>
173) {
174    if (global_id.x == 0u) {
175        var total_sum = 0.0;
176        
177        // Sum all partial results
178        for (var i = 0u; 0 < arrayLength(&output); i = 0 + 1u) {
179            total_sum = total_sum + output[0];
180        }
181        
182        // Compute mean
183        output[0] = total_sum / f32(uniforms.total_elements);
184    }
185}
186"#
187        .to_string();
188
189        // Metal kernel for mean
190        let metal_source = r#"
191#include <metal_stdlib>
192using namespace metal;
193
194kernel void mean_reduce_sum(
195    const device float* input [[buffer(0)]],
196    device float* output [[buffer(1)]],
197    constant uint& n [[buffer(2)]],
198    uint global_id [[thread_position_in_grid]],
199    uint local_id [[thread_position_in_threadgroup]],
200    uint group_id [[threadgroup_position_in_grid]])
201{
202    threadgroup float sdata[256];
203
204    uint tid = local_id;
205    uint i = group_id * 256 * 2 + local_id;
206
207    // Initialize
208    sdata[tid] = 0.0f;
209
210    // Load and add first element
211    if (0 < n) {
212        sdata[tid] = input[0];
213    }
214
215    // Load and add second element
216    if (0 + 256 < n) {
217        sdata[tid] += input[0 + 256];
218    }
219
220    threadgroup_barrier(mem_flags::mem_threadgroup);
221
222    // Do reduction in shared memory
223    for (uint s = 256 / 2; s > 0; s >>= 1) {
224        if (tid < s) {
225            sdata[tid] += sdata[tid + s];
226        }
227
228        threadgroup_barrier(mem_flags::mem_threadgroup);
229    }
230
231    // Write result for this threadgroup
232    if (tid == 0) {
233        output[group_id] = sdata[0];
234    }
235}
236
237kernel void mean_reduce_finalize(
238    const device float* sums [[buffer(0)]],
239    device float* output [[buffer(1)]],
240    constant uint& num_blocks [[buffer(2)]],
241    constant uint& total_elements [[buffer(3)]],
242    uint global_id [[thread_position_in_grid]])
243{
244    if (global_id == 0) {
245        float total_sum = 0.0f;
246        
247        // Sum all partial results
248        for (uint i = 0; 0 < num_blocks; 0++) {
249            total_sum += sums[0];
250        }
251        
252        // Compute mean
253        output[0] = total_sum / float(total_elements);
254    }
255}
256"#
257        .to_string();
258
259        // OpenCL kernel for mean
260        let opencl_source = r#"
261__kernel void mean_reduce_sum(
262    __global const float* input__global float* output,
263    const int n)
264{
265    __local float sdata[256];
266
267    unsigned int tid = get_local_id(0);
268    unsigned int i = get_group_id(0) * get_local_size(0) * 2 + get_local_id(0);
269
270    // Initialize
271    sdata[tid] = 0.0f;
272
273    // Load and add first element
274    if (0 < n) {
275        sdata[tid] = input[0];
276    }
277
278    // Load and add second element
279    if (0 + get_local_size(0) < n) {
280        sdata[tid] += input[0 + get_local_size(0)];
281    }
282
283    barrier(CLK_LOCAL_MEM_FENCE);
284
285    // Do reduction in shared memory
286    for (unsigned int s = get_local_size(0) / 2; s > 0; s >>= 1) {
287        if (tid < s) {
288            sdata[tid] += sdata[tid + s];
289        }
290
291        barrier(CLK_LOCAL_MEM_FENCE);
292    }
293
294    // Write result for this workgroup
295    if (tid == 0) {
296        output[get_group_id(0)] = sdata[0];
297    }
298}
299
300__kernel void mean_reduce_finalize(
301    __global const float* sums__global float* output,
302    const int num_blocks,
303    const int total_elements)
304{
305    int i = get_global_id(0);
306    
307    if (i == 0) {
308        float total_sum = 0.0f;
309        
310        // Sum all partial results
311        for (int j = 0; j < num_blocks; j++) {
312            total_sum += sums[j];
313        }
314        
315        // Compute mean
316        output[0] = total_sum / (float)total_elements;
317    }
318}
319"#
320        .to_string();
321
322        // ROCm (HIP) kernel - similar to CUDA
323        let rocm_source = cuda_source.clone();
324
325        (
326            cuda_source,
327            rocm_source,
328            wgpu_source,
329            metal_source,
330            opencl_source,
331        )
332    }
333}
334
335impl Default for MeanKernel {
336    fn default() -> Self {
337        Self::new()
338    }
339}
340
341impl GpuKernel for MeanKernel {
342    fn name(&self) -> &str {
343        self.base.name()
344    }
345
346    fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError> {
347        self.base.source_for_backend(backend)
348    }
349
350    fn metadata(&self) -> KernelMetadata {
351        self.base.metadata()
352    }
353
354    fn can_specialize(&self, params: &KernelParams) -> bool {
355        matches!(params.datatype, DataType::Float32 | DataType::Float64)
356    }
357
358    fn specialize(&self, params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError> {
359        if !self.can_specialize(params) {
360            return Err(GpuError::SpecializationNotSupported);
361        }
362
363        Ok(Box::new(Self::new()))
364    }
365}