scirs2_core/gpu/kernels/reduction/std_dev.rs

//! Standard deviation reduction kernel
//!
//! Computes the standard deviation of all elements in an array.
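//!
//! The sources below implement the textbook two-pass scheme, staged as
//! three kernel launches per backend:
//!   1. `std_dev_reduce_sum` writes per-block partial sums; the host adds
//!      them and divides by `n` to obtain the mean.
//!   2. `std_dev_reduce_variance` writes per-block partial sums of squared
//!      deviations from that mean.
//!   3. `std_dev_reduce_finalize` totals the partials and applies Bessel's
//!      correction: s = sqrt(sum((x_i - mean)^2) / (n - 1)).
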
use std::collections::HashMap;

use crate::gpu::kernels::{
    BaseKernel, DataType, GpuKernel, KernelMetadata, KernelParams, OperationType,
};
use crate::gpu::{GpuBackend, GpuError};

/// Standard deviation reduction kernel
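///
/// A host-side driver might stage the passes like this (hypothetical
/// `launch` / `read_partials` helpers for illustration; the real execution
/// API is not part of this file):
///
/// ```ignore
/// let partials = launch("std_dev_reduce_sum", &input, n);
/// let mean = read_partials(&partials).iter().sum::<f32>() / n as f32;
/// let sq_devs = launch("std_dev_reduce_variance", &input, mean, n);
/// launch("std_dev_reduce_finalize", &sq_devs, num_blocks, n); // sqrt(var) -> output[0]
/// ```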
pub struct StdDevKernel {
    base: BaseKernel,
}

impl StdDevKernel {
    /// Create a new standard deviation reduction kernel
    pub fn new() -> Self {
        let metadata = KernelMetadata {
            workgroup_size: [256, 1, 1],
            local_memory_usage: 1024, // 256 * sizeof(float)
            supports_tensor_cores: false,
            operationtype: OperationType::ComputeIntensive,
            backend_metadata: HashMap::new(),
        };

        let (cuda_source, rocm_source, wgpu_source, metal_source, opencl_source) =
            Self::get_kernel_sources();

        Self {
            base: BaseKernel::new(
                "std_dev_reduce",
                &cuda_source,
                &rocm_source,
                &wgpu_source,
                &metal_source,
                &opencl_source,
                metadata,
            ),
        }
    }

    /// Get kernel sources for different backends
    fn get_kernel_sources() -> (String, String, String, String, String) {
        // CUDA kernels for standard deviation - two-pass algorithm staged
        // as three launches (partial sums, squared deviations, finalize)
        let cuda_source = r#"
// First pass: per-block partial sums
extern "C" __global__ void std_dev_reduce_sum(
    const float* __restrict__ input,
    float* __restrict__ output,
    int n
) {
    __shared__ float sdata[256];

    unsigned int tid = threadIdx.x;
    unsigned int i = blockIdx.x * blockDim.x * 2 + threadIdx.x;

    // Each thread loads up to two elements, halving the block count
    sdata[tid] = 0.0f;

    if (i < n) {
        sdata[tid] = input[i];
    }

    if (i + blockDim.x < n) {
        sdata[tid] += input[i + blockDim.x];
    }

    __syncthreads();

    // Tree reduction in shared memory
    for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] += sdata[tid + s];
        }
        __syncthreads();
    }

    if (tid == 0) {
        output[blockIdx.x] = sdata[0];
    }
}

// Second pass: per-block partial sums of squared differences from the
// mean (computed on the host from the first pass's partial sums)
extern "C" __global__ void std_dev_reduce_variance(
    const float* __restrict__ input,
    float* __restrict__ output,
    float mean,
    int n
) {
    __shared__ float sdata[256];

    unsigned int tid = threadIdx.x;
    unsigned int i = blockIdx.x * blockDim.x * 2 + threadIdx.x;

    sdata[tid] = 0.0f;

    if (i < n) {
        float diff = input[i] - mean;
        sdata[tid] = diff * diff;
    }

    if (i + blockDim.x < n) {
        float diff = input[i + blockDim.x] - mean;
        sdata[tid] += diff * diff;
    }

    __syncthreads();

    for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] += sdata[tid + s];
        }
        __syncthreads();
    }

    if (tid == 0) {
        output[blockIdx.x] = sdata[0];
    }
}

// Third pass: finalize standard deviation
extern "C" __global__ void std_dev_reduce_finalize(
    const float* __restrict__ variances,
    float* __restrict__ output,
    int num_blocks,
    int total_elements
) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;

    if (i == 0) {
        float total_variance = 0.0f;
        for (int j = 0; j < num_blocks; j++) {
            total_variance += variances[j];
        }

        // Sample variance (Bessel's correction)
        float variance = total_variance / (float)(total_elements - 1);
        output[0] = sqrtf(variance);
    }
}
"#
        .to_string();

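        // Between the first and second pass the caller is expected to add
        // the per-block partial sums and divide by n to produce the `mean`
        // argument consumed by each backend's variance kernel.
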
        // WebGPU (WGSL) kernel for standard deviation
        let wgpu_source = r#"
struct Uniforms {
    n: u32,
    total_elements: u32,
    mean: f32,
}

@group(0) @binding(0) var<uniform> uniforms: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<f32>;
// read_write so the finalize entry point can re-read the partial results
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

var<workgroup> sdata: array<f32, 256>;

@compute @workgroup_size(256)
fn std_dev_reduce_sum(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) workgroup_id: vec3<u32>
) {
    let tid = local_id.x;
    let i = workgroup_id.x * 256u * 2u + local_id.x;

    sdata[tid] = 0.0;

    if (i < uniforms.n) {
        sdata[tid] = input[i];
    }

    if (i + 256u < uniforms.n) {
        sdata[tid] = sdata[tid] + input[i + 256u];
    }

    workgroupBarrier();

    for (var s = 128u; s > 0u; s = s / 2u) {
        if (tid < s) {
            sdata[tid] = sdata[tid] + sdata[tid + s];
        }

        workgroupBarrier();
    }

    if (tid == 0u) {
        output[workgroup_id.x] = sdata[0];
    }
}

@compute @workgroup_size(256)
fn std_dev_reduce_variance(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) workgroup_id: vec3<u32>
) {
    let tid = local_id.x;
    let i = workgroup_id.x * 256u * 2u + local_id.x;

    sdata[tid] = 0.0;

    if (i < uniforms.n) {
        let diff = input[i] - uniforms.mean;
        sdata[tid] = diff * diff;
    }

    if (i + 256u < uniforms.n) {
        let diff = input[i + 256u] - uniforms.mean;
        sdata[tid] = sdata[tid] + (diff * diff);
    }

    workgroupBarrier();

    for (var s = 128u; s > 0u; s = s / 2u) {
        if (tid < s) {
            sdata[tid] = sdata[tid] + sdata[tid + s];
        }

        workgroupBarrier();
    }

    if (tid == 0u) {
        output[workgroup_id.x] = sdata[0];
    }
}

@compute @workgroup_size(1)
fn std_dev_reduce_finalize(
    @builtin(global_invocation_id) global_id: vec3<u32>
) {
    if (global_id.x == 0u) {
        var total_variance = 0.0;

        for (var i = 0u; i < arrayLength(&output); i = i + 1u) {
            total_variance = total_variance + output[i];
        }

        let variance = total_variance / f32(uniforms.total_elements - 1u);
        output[0] = sqrt(variance);
    }
}
"#
        .to_string();

        // Metal kernel for standard deviation
        let metal_source = r#"
#include <metal_stdlib>
using namespace metal;

kernel void std_dev_reduce_sum(
    const device float* input [[buffer(0)]],
    device float* output [[buffer(1)]],
    constant uint& n [[buffer(2)]],
    uint global_id [[thread_position_in_grid]],
    uint local_id [[thread_position_in_threadgroup]],
    uint group_id [[threadgroup_position_in_grid]])
{
    threadgroup float sdata[256];

    uint tid = local_id;
    uint i = group_id * 256 * 2 + local_id;

    sdata[tid] = 0.0f;

    if (i < n) {
        sdata[tid] = input[i];
    }

    if (i + 256 < n) {
        sdata[tid] += input[i + 256];
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    for (uint s = 256 / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] += sdata[tid + s];
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    if (tid == 0) {
        output[group_id] = sdata[0];
    }
}

kernel void std_dev_reduce_variance(
    const device float* input [[buffer(0)]],
    device float* output [[buffer(1)]],
    constant uint& n [[buffer(2)]],
    constant float& mean [[buffer(3)]],
    uint global_id [[thread_position_in_grid]],
    uint local_id [[thread_position_in_threadgroup]],
    uint group_id [[threadgroup_position_in_grid]])
{
    threadgroup float sdata[256];

    uint tid = local_id;
    uint i = group_id * 256 * 2 + local_id;

    sdata[tid] = 0.0f;

    if (i < n) {
        float diff = input[i] - mean;
        sdata[tid] = diff * diff;
    }

    if (i + 256 < n) {
        float diff = input[i + 256] - mean;
        sdata[tid] += diff * diff;
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    for (uint s = 256 / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] += sdata[tid + s];
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    if (tid == 0) {
        output[group_id] = sdata[0];
    }
}

kernel void std_dev_reduce_finalize(
    const device float* variances [[buffer(0)]],
    device float* output [[buffer(1)]],
    constant uint& num_blocks [[buffer(2)]],
    constant uint& total_elements [[buffer(3)]],
    uint global_id [[thread_position_in_grid]])
{
    if (global_id == 0) {
        float total_variance = 0.0f;

        for (uint i = 0; i < num_blocks; i++) {
            total_variance += variances[i];
        }

        // Sample variance (Bessel's correction)
        float variance = total_variance / float(total_elements - 1);
        output[0] = sqrt(variance);
    }
}
"#
        .to_string();

        // OpenCL kernel for standard deviation
        let opencl_source = r#"
__kernel void std_dev_reduce_sum(
    __global const float* input,
    __global float* output,
    const int n)
{
    __local float sdata[256];

    unsigned int tid = get_local_id(0);
    unsigned int i = get_group_id(0) * get_local_size(0) * 2 + get_local_id(0);

    sdata[tid] = 0.0f;

    if (i < n) {
        sdata[tid] = input[i];
    }

    if (i + get_local_size(0) < n) {
        sdata[tid] += input[i + get_local_size(0)];
    }

    barrier(CLK_LOCAL_MEM_FENCE);

    for (unsigned int s = get_local_size(0) / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] += sdata[tid + s];
        }

        barrier(CLK_LOCAL_MEM_FENCE);
    }

    if (tid == 0) {
        output[get_group_id(0)] = sdata[0];
    }
}

__kernel void std_dev_reduce_variance(
    __global const float* input,
    __global float* output,
    const float mean,
    const int n)
{
    __local float sdata[256];

    unsigned int tid = get_local_id(0);
    unsigned int i = get_group_id(0) * get_local_size(0) * 2 + get_local_id(0);

    sdata[tid] = 0.0f;

    if (i < n) {
        float diff = input[i] - mean;
        sdata[tid] = diff * diff;
    }

    if (i + get_local_size(0) < n) {
        float diff = input[i + get_local_size(0)] - mean;
        sdata[tid] += diff * diff;
    }

    barrier(CLK_LOCAL_MEM_FENCE);

    for (unsigned int s = get_local_size(0) / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] += sdata[tid + s];
        }

        barrier(CLK_LOCAL_MEM_FENCE);
    }

    if (tid == 0) {
        output[get_group_id(0)] = sdata[0];
    }
}

__kernel void std_dev_reduce_finalize(
    __global const float* variances,
    __global float* output,
    const int num_blocks,
    const int total_elements)
{
    int i = get_global_id(0);

    if (i == 0) {
        float total_variance = 0.0f;

        for (int j = 0; j < num_blocks; j++) {
            total_variance += variances[j];
        }

        // Sample variance (Bessel's correction)
        float variance = total_variance / (float)(total_elements - 1);
        output[0] = sqrt(variance);
    }
}
"#
        .to_string();

        // ROCm (HIP) reuses the CUDA source unchanged
        let rocm_source = cuda_source.clone();

        (
            cuda_source,
            rocm_source,
            wgpu_source,
            metal_source,
            opencl_source,
        )
    }
}

impl Default for StdDevKernel {
    fn default() -> Self {
        Self::new()
    }
}

impl GpuKernel for StdDevKernel {
    fn name(&self) -> &str {
        self.base.name()
    }

    fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError> {
        self.base.source_for_backend(backend)
    }

    fn metadata(&self) -> KernelMetadata {
        self.base.metadata()
    }

    fn can_specialize(&self, params: &KernelParams) -> bool {
        matches!(params.datatype, DataType::Float32 | DataType::Float64)
    }

    fn specialize(&self, params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError> {
        if !self.can_specialize(params) {
            return Err(GpuError::SpecializationNotSupported);
        }

        Ok(Box::new(Self::new()))
    }
}
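
// A minimal sanity check, assuming only what this file defines: the first
// test verifies the metadata wired up by `new()`, the second mirrors the
// kernels' two-pass math on the CPU as a reference value.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn kernel_reports_expected_name_and_workgroup() {
        let kernel = StdDevKernel::new();
        assert_eq!(kernel.name(), "std_dev_reduce");
        assert_eq!(kernel.metadata().workgroup_size, [256, 1, 1]);
    }

    #[test]
    fn cpu_reference_matches_two_pass_formula() {
        let data = [2.0f32, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0];
        let n = data.len() as f32;
        // Pass 1: mean
        let mean = data.iter().sum::<f32>() / n;
        // Pass 2: sum of squared deviations
        let ss: f32 = data.iter().map(|x| (x - mean) * (x - mean)).sum();
        // Finalize: Bessel-corrected sample standard deviation
        let std_dev = (ss / (n - 1.0)).sqrt();
        assert!((std_dev - 2.138_089_9).abs() < 1e-5);
    }
}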