scirs2_core/gpu/kernels/reduction/mean.rs

//! Mean reduction kernel with per-backend sources (CUDA, ROCm, WGSL, Metal, OpenCL).

use std::collections::HashMap;

use crate::gpu::kernels::{
    BaseKernel, DataType, GpuKernel, KernelMetadata, KernelParams, OperationType,
};
use crate::gpu::{GpuBackend, GpuError};

pub struct MeanKernel {
    base: BaseKernel,
}

impl MeanKernel {
    pub fn new() -> Self {
        let metadata = KernelMetadata {
            workgroup_size: [256, 1, 1],
            local_memory_usage: 1024, // 256 floats * 4 bytes of shared memory
            supports_tensor_cores: false,
            operationtype: OperationType::Balanced,
            backend_metadata: HashMap::new(),
        };

        let (cuda_source, rocm_source, wgpu_source, metal_source, opencl_source) =
            Self::get_kernel_sources();

        Self {
            base: BaseKernel::new(
                "mean_reduce",
                &cuda_source,
                &rocm_source,
                &wgpu_source,
                &metal_source,
                &opencl_source,
                metadata,
            ),
        }
    }

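    // Note: the [256, 1, 1] workgroup size in the metadata must stay in sync
    // with the 256-element shared/workgroup arrays hard-coded in every kernel
    // source below; changing one without the other would break the reduction.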
    fn get_kernel_sources() -> (String, String, String, String, String) {
        let cuda_source = r#"
// First pass: compute per-block partial sums
extern "C" __global__ void mean_reduce_sum(
    const float* __restrict__ input,
    float* __restrict__ output,
    int n
) {
    __shared__ float sdata[256];

    // Each block loads two elements per thread into shared memory
    unsigned int tid = threadIdx.x;
    unsigned int i = blockIdx.x * blockDim.x * 2 + threadIdx.x;

    // Initialize with the identity value for addition
    sdata[tid] = 0.0f;

    // Load first element
    if (i < n) {
        sdata[tid] = input[i];
    }

    // Add second element
    if (i + blockDim.x < n) {
        sdata[tid] += input[i + blockDim.x];
    }

    __syncthreads();

    // Tree reduction within the block
    for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] += sdata[tid + s];
        }
        __syncthreads();
    }

    // Write this block's partial sum to output
    if (tid == 0) {
        output[blockIdx.x] = sdata[0];
    }
}

// Second pass: sum the partial sums and divide by the element count
extern "C" __global__ void mean_reduce_finalize(
    const float* __restrict__ sums,
    float* __restrict__ output,
    int num_blocks,
    int total_elements
) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;

    if (i == 0) {
        // Sum all partial sums
        float total_sum = 0.0f;
        for (int j = 0; j < num_blocks; j++) {
            total_sum += sums[j];
        }

        // Compute mean and write to output
        output[0] = total_sum / (float)total_elements;
    }
}
"#
        .to_string();

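        // Launch shape for mean_reduce_sum above: each 256-thread block
        // consumes 512 input elements, so n elements need ceil(n / 512)
        // blocks. For n = 1_000_000 that is 1954 blocks, hence 1954 partial
        // sums for mean_reduce_finalize to fold together.
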
        let wgpu_source = r#"
struct Uniforms {
    n: u32,
    total_elements: u32,
}

@group(0) @binding(0) var<uniform> uniforms: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

var<workgroup> sdata: array<f32, 256>;

@compute @workgroup_size(256)
fn mean_reduce_sum(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) workgroup_id: vec3<u32>
) {
    let tid = local_id.x;
    let i = workgroup_id.x * 256u * 2u + local_id.x;

    // Initialize with the identity value for addition
    sdata[tid] = 0.0;

    // Load first element
    if (i < uniforms.n) {
        sdata[tid] = input[i];
    }

    // Add second element
    if (i + 256u < uniforms.n) {
        sdata[tid] = sdata[tid] + input[i + 256u];
    }

    workgroupBarrier();

    // Tree reduction in workgroup memory
    for (var s = 256u / 2u; s > 0u; s = s / 2u) {
        if (tid < s) {
            sdata[tid] = sdata[tid] + sdata[tid + s];
        }

        workgroupBarrier();
    }

    // Write this workgroup's partial sum
    if (tid == 0u) {
        output[workgroup_id.x] = sdata[0];
    }
}

@compute @workgroup_size(1)
fn mean_reduce_finalize(
    @builtin(global_invocation_id) global_id: vec3<u32>
) {
    if (global_id.x == 0u) {
        var total_sum = 0.0;

        // Sum all partial results written by mean_reduce_sum
        for (var i = 0u; i < arrayLength(&output); i = i + 1u) {
            total_sum = total_sum + output[i];
        }

        // Compute mean
        output[0] = total_sum / f32(uniforms.total_elements);
    }
}
"#
        .to_string();

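        // WGSL requires workgroupBarrier() to execute in uniform control
        // flow; the reduction loop above satisfies this because `s` takes
        // the same value in every invocation of the workgroup.
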
        let metal_source = r#"
#include <metal_stdlib>
using namespace metal;

kernel void mean_reduce_sum(
    const device float* input [[buffer(0)]],
    device float* output [[buffer(1)]],
    constant uint& n [[buffer(2)]],
    uint global_id [[thread_position_in_grid]],
    uint local_id [[thread_position_in_threadgroup]],
    uint group_id [[threadgroup_position_in_grid]])
{
    threadgroup float sdata[256];

    uint tid = local_id;
    uint i = group_id * 256 * 2 + local_id;

    // Initialize with the identity value for addition
    sdata[tid] = 0.0f;

    // Load first element
    if (i < n) {
        sdata[tid] = input[i];
    }

    // Add second element
    if (i + 256 < n) {
        sdata[tid] += input[i + 256];
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Tree reduction in threadgroup memory
    for (uint s = 256 / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] += sdata[tid + s];
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Write this threadgroup's partial sum
    if (tid == 0) {
        output[group_id] = sdata[0];
    }
}

kernel void mean_reduce_finalize(
    const device float* sums [[buffer(0)]],
    device float* output [[buffer(1)]],
    constant uint& num_blocks [[buffer(2)]],
    constant uint& total_elements [[buffer(3)]],
    uint global_id [[thread_position_in_grid]])
{
    if (global_id == 0) {
        float total_sum = 0.0f;

        // Sum all partial results
        for (uint i = 0; i < num_blocks; i++) {
            total_sum += sums[i];
        }

        // Compute mean
        output[0] = total_sum / float(total_elements);
    }
}
"#
        .to_string();

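        // Each finalize kernel folds the partial sums serially on a single
        // thread: O(num_blocks) work, but cheap for the few thousand
        // partials the first pass typically produces.
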
        let opencl_source = r#"
__kernel void mean_reduce_sum(
    __global const float* input,
    __global float* output,
    const int n)
{
    __local float sdata[256];

    unsigned int tid = get_local_id(0);
    unsigned int i = get_group_id(0) * get_local_size(0) * 2 + get_local_id(0);

    // Initialize with the identity value for addition
    sdata[tid] = 0.0f;

    // Load first element
    if (i < n) {
        sdata[tid] = input[i];
    }

    // Add second element
    if (i + get_local_size(0) < n) {
        sdata[tid] += input[i + get_local_size(0)];
    }

    barrier(CLK_LOCAL_MEM_FENCE);

    // Tree reduction in local memory
    for (unsigned int s = get_local_size(0) / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] += sdata[tid + s];
        }

        barrier(CLK_LOCAL_MEM_FENCE);
    }

    // Write this work-group's partial sum
    if (tid == 0) {
        output[get_group_id(0)] = sdata[0];
    }
}

__kernel void mean_reduce_finalize(
    __global const float* sums,
    __global float* output,
    const int num_blocks,
    const int total_elements)
{
    int i = get_global_id(0);

    if (i == 0) {
        float total_sum = 0.0f;

        // Sum all partial results
        for (int j = 0; j < num_blocks; j++) {
            total_sum += sums[j];
        }

        // Compute mean
        output[0] = total_sum / (float)total_elements;
    }
}
"#
        .to_string();

        // HIP compiles CUDA-style kernel syntax directly, so ROCm reuses the
        // CUDA source unchanged.
        let rocm_source = cuda_source.clone();

        (
            cuda_source,
            rocm_source,
            wgpu_source,
            metal_source,
            opencl_source,
        )
    }
}

impl Default for MeanKernel {
    fn default() -> Self {
        Self::new()
    }
}

impl GpuKernel for MeanKernel {
    fn name(&self) -> &str {
        self.base.name()
    }

    fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError> {
        self.base.source_for_backend(backend)
    }

    fn metadata(&self) -> KernelMetadata {
        self.base.metadata()
    }

    fn can_specialize(&self, params: &KernelParams) -> bool {
        matches!(params.datatype, DataType::Float32 | DataType::Float64)
    }

    fn specialize(&self, params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError> {
        if !self.can_specialize(params) {
            return Err(GpuError::SpecializationNotSupported);
        }

        Ok(Box::new(Self::new()))
    }
}
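
// A minimal usage sketch. `GpuBackend::Cuda` and the exact shape of
// `KernelParams` are assumptions about the surrounding crate, not defined in
// this file:
//
//     let kernel = MeanKernel::new();
//     assert_eq!(kernel.name(), "mean_reduce");
//     let cuda_src = kernel.source_for_backend(GpuBackend::Cuda)?;
//     // Specialization only succeeds for Float32/Float64 parameters:
//     if kernel.can_specialize(&params) {
//         let specialized = kernel.specialize(&params)?;
//     }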