scirs2_core/gpu/kernels/reduction/std_dev.rs

use std::collections::HashMap;

use crate::gpu::kernels::{
    BaseKernel, DataType, GpuKernel, KernelMetadata, KernelParams, OperationType,
};
use crate::gpu::{GpuBackend, GpuError};

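/// Kernel that computes the sample standard deviation of a buffer with a
/// multi-pass parallel reduction: a block-wise sum pass, a block-wise
/// sum-of-squared-differences pass that takes a precomputed mean, and a
/// finalize pass that divides by `n - 1` and takes the square root.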
pub struct StdDevKernel {
    base: BaseKernel,
}

impl StdDevKernel {
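    /// Creates the kernel with a 256-thread workgroup and registers the
    /// per-backend sources under the name `"std_dev_reduce"`.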
    pub fn new() -> Self {
        let metadata = KernelMetadata {
            workgroup_size: [256, 1, 1],
            local_memory_usage: 1024,
            supports_tensor_cores: false,
            operationtype: OperationType::ComputeIntensive,
            backend_metadata: HashMap::new(),
        };

        let (cuda_source, rocm_source, wgpu_source, metal_source, opencl_source) =
            Self::get_kernel_sources();

        Self {
            base: BaseKernel::new(
                "std_dev_reduce",
                &cuda_source,
                &rocm_source,
                &wgpu_source,
                &metal_source,
                &opencl_source,
                metadata,
            ),
        }
    }

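    /// Returns the kernel sources as `(CUDA, ROCm, WGPU, Metal, OpenCL)`.
    /// Each backend exposes the same three entry points: `std_dev_reduce_sum`,
    /// `std_dev_reduce_variance`, and `std_dev_reduce_finalize`.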
    fn get_kernel_sources() -> (String, String, String, String, String) {
        let cuda_source = r#"
// First pass: compute sum
extern "C" __global__ void std_dev_reduce_sum(
    const float* __restrict__ input,
    float* __restrict__ output,
    int n
) {
    __shared__ float sdata[256];

    unsigned int tid = threadIdx.x;
    unsigned int i = blockIdx.x * blockDim.x * 2 + threadIdx.x;

    // Each thread loads up to two elements into shared memory
    sdata[tid] = 0.0f;

    if (i < n) {
        sdata[tid] = input[i];
    }

    if (i + blockDim.x < n) {
        sdata[tid] += input[i + blockDim.x];
    }

    __syncthreads();

    // Tree reduction in shared memory
    for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] += sdata[tid + s];
        }
        __syncthreads();
    }

    if (tid == 0) {
        output[blockIdx.x] = sdata[0];
    }
}

// Second pass: compute sum of squared differences from the mean
extern "C" __global__ void std_dev_reduce_variance(
    const float* __restrict__ input,
    float* __restrict__ output,
    float mean,
    int n
) {
    __shared__ float sdata[256];

    unsigned int tid = threadIdx.x;
    unsigned int i = blockIdx.x * blockDim.x * 2 + threadIdx.x;

    sdata[tid] = 0.0f;

    if (i < n) {
        float diff = input[i] - mean;
        sdata[tid] = diff * diff;
    }

    if (i + blockDim.x < n) {
        float diff = input[i + blockDim.x] - mean;
        sdata[tid] += diff * diff;
    }

    __syncthreads();

    for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] += sdata[tid + s];
        }
        __syncthreads();
    }

    if (tid == 0) {
        output[blockIdx.x] = sdata[0];
    }
}

// Third pass: finalize standard deviation
extern "C" __global__ void std_dev_reduce_finalize(
    const float* __restrict__ variances,
    float* __restrict__ output,
    int num_blocks,
    int total_elements
) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;

    if (i == 0) {
        float total_variance = 0.0f;
        for (int j = 0; j < num_blocks; j++) {
            total_variance += variances[j];
        }

        float variance = total_variance / (float)(total_elements - 1); // Sample variance
        output[0] = sqrtf(variance);
    }
}
"#
        .to_string();

        let wgpu_source = r#"
struct Uniforms {
    n: u32,
    total_elements: u32,
    mean: f32,
};

@group(0) @binding(0) var<uniform> uniforms: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<f32>;
// read_write: the finalize pass also reads the partial results stored here.
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

var<workgroup> sdata: array<f32, 256>;

// First pass: compute per-workgroup sums
@compute @workgroup_size(256)
fn std_dev_reduce_sum(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) workgroup_id: vec3<u32>
) {
    let tid = local_id.x;
    let i = workgroup_id.x * 256u * 2u + local_id.x;

    sdata[tid] = 0.0;

    if (i < uniforms.n) {
        sdata[tid] = input[i];
    }

    if (i + 256u < uniforms.n) {
        sdata[tid] = sdata[tid] + input[i + 256u];
    }

    workgroupBarrier();

    var s = 256u / 2u;
    for (var j = 0u; s > 0u; j = j + 1u) {
        if (tid < s) {
            sdata[tid] = sdata[tid] + sdata[tid + s];
        }

        s = s / 2u;
        workgroupBarrier();
    }

    if (tid == 0u) {
        output[workgroup_id.x] = sdata[0];
    }
}

// Second pass: compute per-workgroup sums of squared differences from the mean
@compute @workgroup_size(256)
fn std_dev_reduce_variance(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) workgroup_id: vec3<u32>
) {
    let tid = local_id.x;
    let i = workgroup_id.x * 256u * 2u + local_id.x;

    sdata[tid] = 0.0;

    if (i < uniforms.n) {
        let diff = input[i] - uniforms.mean;
        sdata[tid] = diff * diff;
    }

    if (i + 256u < uniforms.n) {
        let diff = input[i + 256u] - uniforms.mean;
        sdata[tid] = sdata[tid] + (diff * diff);
    }

    workgroupBarrier();

    var s = 256u / 2u;
    for (var j = 0u; s > 0u; j = j + 1u) {
        if (tid < s) {
            sdata[tid] = sdata[tid] + sdata[tid + s];
        }

        s = s / 2u;
        workgroupBarrier();
    }

    if (tid == 0u) {
        output[workgroup_id.x] = sdata[0];
    }
}

// Third pass: finalize the standard deviation from the per-workgroup partials
@compute @workgroup_size(1)
fn std_dev_reduce_finalize(
    @builtin(global_invocation_id) global_id: vec3<u32>
) {
    if (global_id.x == 0u) {
        var total_variance = 0.0;

        for (var i = 0u; i < arrayLength(&output); i = i + 1u) {
            total_variance = total_variance + output[i];
        }

        let variance = total_variance / f32(uniforms.total_elements - 1u);
        output[0] = sqrt(variance);
    }
}
"#
        .to_string();

        let metal_source = r#"
#include <metal_stdlib>
using namespace metal;

// First pass: compute per-threadgroup sums
kernel void std_dev_reduce_sum(
    const device float* input [[buffer(0)]],
    device float* output [[buffer(1)]],
    constant uint& n [[buffer(2)]],
    uint global_id [[thread_position_in_grid]],
    uint local_id [[thread_position_in_threadgroup]],
    uint group_id [[threadgroup_position_in_grid]])
{
    threadgroup float sdata[256];

    uint tid = local_id;
    uint i = group_id * 256 * 2 + local_id;

    sdata[tid] = 0.0f;

    if (i < n) {
        sdata[tid] = input[i];
    }

    if (i + 256 < n) {
        sdata[tid] += input[i + 256];
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    for (uint s = 256 / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] += sdata[tid + s];
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    if (tid == 0) {
        output[group_id] = sdata[0];
    }
}

// Second pass: compute per-threadgroup sums of squared differences from the mean
kernel void std_dev_reduce_variance(
    const device float* input [[buffer(0)]],
    device float* output [[buffer(1)]],
    constant uint& n [[buffer(2)]],
    constant float& mean [[buffer(3)]],
    uint global_id [[thread_position_in_grid]],
    uint local_id [[thread_position_in_threadgroup]],
    uint group_id [[threadgroup_position_in_grid]])
{
    threadgroup float sdata[256];

    uint tid = local_id;
    uint i = group_id * 256 * 2 + local_id;

    sdata[tid] = 0.0f;

    if (i < n) {
        float diff = input[i] - mean;
        sdata[tid] = diff * diff;
    }

    if (i + 256 < n) {
        float diff = input[i + 256] - mean;
        sdata[tid] += diff * diff;
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    for (uint s = 256 / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] += sdata[tid + s];
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    if (tid == 0) {
        output[group_id] = sdata[0];
    }
}

// Third pass: finalize the standard deviation from the per-threadgroup partials
kernel void std_dev_reduce_finalize(
    const device float* variances [[buffer(0)]],
    device float* output [[buffer(1)]],
    constant uint& num_blocks [[buffer(2)]],
    constant uint& total_elements [[buffer(3)]],
    uint global_id [[thread_position_in_grid]])
{
    if (global_id == 0) {
        float total_variance = 0.0f;

        for (uint i = 0; i < num_blocks; i++) {
            total_variance += variances[i];
        }

        float variance = total_variance / float(total_elements - 1);
        output[0] = sqrt(variance);
    }
}
"#
        .to_string();

        let opencl_source = r#"
// First pass: compute per-work-group sums
__kernel void std_dev_reduce_sum(
    __global const float* input,
    __global float* output,
    const int n)
{
    __local float sdata[256];

    unsigned int tid = get_local_id(0);
    unsigned int i = get_group_id(0) * get_local_size(0) * 2 + get_local_id(0);

    sdata[tid] = 0.0f;

    if (i < n) {
        sdata[tid] = input[i];
    }

    if (i + get_local_size(0) < n) {
        sdata[tid] += input[i + get_local_size(0)];
    }

    barrier(CLK_LOCAL_MEM_FENCE);

    for (unsigned int s = get_local_size(0) / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] += sdata[tid + s];
        }

        barrier(CLK_LOCAL_MEM_FENCE);
    }

    if (tid == 0) {
        output[get_group_id(0)] = sdata[0];
    }
}

// Second pass: compute per-work-group sums of squared differences from the mean
__kernel void std_dev_reduce_variance(
    __global const float* input,
    __global float* output,
    const float mean,
    const int n)
{
    __local float sdata[256];

    unsigned int tid = get_local_id(0);
    unsigned int i = get_group_id(0) * get_local_size(0) * 2 + get_local_id(0);

    sdata[tid] = 0.0f;

    if (i < n) {
        float diff = input[i] - mean;
        sdata[tid] = diff * diff;
    }

    if (i + get_local_size(0) < n) {
        float diff = input[i + get_local_size(0)] - mean;
        sdata[tid] += diff * diff;
    }

    barrier(CLK_LOCAL_MEM_FENCE);

    for (unsigned int s = get_local_size(0) / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] += sdata[tid + s];
        }

        barrier(CLK_LOCAL_MEM_FENCE);
    }

    if (tid == 0) {
        output[get_group_id(0)] = sdata[0];
    }
}

// Third pass: finalize the standard deviation from the per-work-group partials
__kernel void std_dev_reduce_finalize(
    __global const float* variances,
    __global float* output,
    const int num_blocks,
    const int total_elements)
{
    int i = get_global_id(0);

    if (i == 0) {
        float total_variance = 0.0f;

        for (int j = 0; j < num_blocks; j++) {
            total_variance += variances[j];
        }

        float variance = total_variance / (float)(total_elements - 1);
        output[0] = sqrt(variance);
    }
}
"#
        .to_string();

        // ROCm reuses the CUDA source; HIP accepts the same kernel syntax.
        let rocm_source = cuda_source.clone();

        (
            cuda_source,
            rocm_source,
            wgpu_source,
            metal_source,
            opencl_source,
        )
    }
}

impl Default for StdDevKernel {
    fn default() -> Self {
        Self::new()
    }
}

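// The trait implementation delegates to `BaseKernel`; specialization is
// limited to `Float32` and `Float64` inputs.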
impl GpuKernel for StdDevKernel {
    fn name(&self) -> &str {
        self.base.name()
    }

    fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError> {
        self.base.source_for_backend(backend)
    }

    fn metadata(&self) -> KernelMetadata {
        self.base.metadata()
    }

    fn can_specialize(&self, params: &KernelParams) -> bool {
        matches!(params.datatype, DataType::Float32 | DataType::Float64)
    }

    fn specialize(&self, params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError> {
        if !self.can_specialize(params) {
            return Err(GpuError::SpecializationNotSupported);
        }

        Ok(Box::new(Self::new()))
    }
}
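
// Usage sketch: retrieving a backend-specific source through the `GpuKernel`
// trait. `GpuBackend::Cuda` is assumed to be one of the enum's variants;
// substitute the actual variant names as needed.
//
//     let kernel = StdDevKernel::new();
//     assert_eq!(kernel.name(), "std_dev_reduce");
//     let cuda_src = kernel.source_for_backend(GpuBackend::Cuda)?;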