scirs2_metrics/optimization/gpu_kernels/
kernels.rs

1//! GPU kernel source code for different compute backends
2//!
3//! This module contains kernel implementations for CUDA, OpenCL, Metal, and Vulkan
4//! optimized for metrics computation.
5
6#![allow(clippy::too_many_arguments)]
7#![allow(dead_code)]
8
/// CUDA kernel source code for metrics computation
///
/// All kernels here share one pattern: a grid-stride loop accumulates a
/// per-thread partial, a shared-memory tree reduction combines the partials
/// within each block, and thread 0 of each block publishes the block result
/// with a single `atomicAdd` on the output pointer.
///
/// Host-side preconditions (NOTE(review): inferred from the kernel bodies —
/// confirm against the launch code, which is not visible in this file):
/// - every output pointer (`result`, `ss_res`, `ss_tot`, `tp`, `fp`,
///   `fn_ptr`) must be zero-initialized before launch, because the kernels
///   accumulate into it with `atomicAdd`;
/// - `blockDim.x` must be a power of two and at most 256: the `__shared__`
///   scratch arrays are fixed at 256 floats and the reduction loop starts
///   at `blockDim.x / 2`, which only covers every element for power-of-two
///   block sizes.
pub mod cuda_kernels {
    /// Mean squared error. Each block atomically adds `block_sum / n` to
    /// `*result`, so after all blocks finish `*result` holds
    /// `sum((y_true[i] - ypred[i])^2) / n`.
    pub const MSE_KERNEL: &str = r#"
    extern "C" __global__ void mse_kernel(
        const float* y_true,
        const float* ypred,
        float* result,
        int n
    ) {
        int idx = blockIdx.x * blockDim.x + threadIdx.x;
        int stride = blockDim.x * gridDim.x;

        __shared__ float sdata[256];

        float sum = 0.0f;
        for (int i = idx; i < n; i += stride) {
            float diff = y_true[i] - ypred[i];
            sum += diff * diff;
        }

        sdata[threadIdx.x] = sum;
        __syncthreads();

        // Parallel reduction
        for (int s = blockDim.x / 2; s > 0; s >>= 1) {
            if (threadIdx.x < s) {
                sdata[threadIdx.x] += sdata[threadIdx.x + s];
            }
            __syncthreads();
        }

        if (threadIdx.x == 0) {
            atomicAdd(result, sdata[0] / n);
        }
    }
    "#;

    /// Mean absolute error. Same reduction structure as `MSE_KERNEL`, with
    /// `fabsf(diff)` in place of `diff * diff`; `*result` accumulates
    /// `sum(|y_true[i] - ypred[i]|) / n`.
    pub const MAE_KERNEL: &str = r#"
    extern "C" __global__ void mae_kernel(
        const float* y_true,
        const float* ypred,
        float* result,
        int n
    ) {
        int idx = blockIdx.x * blockDim.x + threadIdx.x;
        int stride = blockDim.x * gridDim.x;

        __shared__ float sdata[256];

        float sum = 0.0f;
        for (int i = idx; i < n; i += stride) {
            sum += fabsf(y_true[i] - ypred[i]);
        }

        sdata[threadIdx.x] = sum;
        __syncthreads();

        for (int s = blockDim.x / 2; s > 0; s >>= 1) {
            if (threadIdx.x < s) {
                sdata[threadIdx.x] += sdata[threadIdx.x + s];
            }
            __syncthreads();
        }

        if (threadIdx.x == 0) {
            atomicAdd(result, sdata[0] / n);
        }
    }
    "#;

    /// R² building blocks. Accumulates the residual sum of squares into
    /// `*ss_res` and the total sum of squares (around the host-supplied
    /// `mean_true`) into `*ss_tot`; the host is expected to compute
    /// `1 - ss_res / ss_tot` afterwards. Note the host must pre-compute
    /// `mean_true` — this kernel does not derive it.
    pub const R2_KERNEL: &str = r#"
    extern "C" __global__ void r2_kernel(
        const float* y_true,
        const float* ypred,
        float* ss_res,
        float* ss_tot,
        float mean_true,
        int n
    ) {
        int idx = blockIdx.x * blockDim.x + threadIdx.x;
        int stride = blockDim.x * gridDim.x;

        __shared__ float sdata_res[256];
        __shared__ float sdata_tot[256];

        float sum_res = 0.0f;
        float sum_tot = 0.0f;

        for (int i = idx; i < n; i += stride) {
            float diff = y_true[i] - ypred[i];
            sum_res += diff * diff;

            float diff_mean = y_true[i] - mean_true;
            sum_tot += diff_mean * diff_mean;
        }

        sdata_res[threadIdx.x] = sum_res;
        sdata_tot[threadIdx.x] = sum_tot;
        __syncthreads();

        for (int s = blockDim.x / 2; s > 0; s >>= 1) {
            if (threadIdx.x < s) {
                sdata_res[threadIdx.x] += sdata_res[threadIdx.x + s];
                sdata_tot[threadIdx.x] += sdata_tot[threadIdx.x + s];
            }
            __syncthreads();
        }

        if (threadIdx.x == 0) {
            atomicAdd(ss_res, sdata_res[0]);
            atomicAdd(ss_tot, sdata_tot[0]);
        }
    }
    "#;

    /// Confusion-matrix counts for binary classification at `threshold`.
    /// Predictions are thresholded as `ypred[i] > threshold`; ground-truth
    /// labels are compared against exactly `1.0f` / `0.0f`.
    /// NOTE(review): labels that are not exactly 0.0f or 1.0f fall through
    /// every branch and are silently ignored — confirm that callers
    /// guarantee hard 0/1 labels.
    pub const PRECISION_RECALL_KERNEL: &str = r#"
    extern "C" __global__ void precision_recall_kernel(
        const float* y_true,
        const float* ypred,
        float* tp,
        float* fp,
        float* fn_ptr,
        float threshold,
        int n
    ) {
        int idx = blockIdx.x * blockDim.x + threadIdx.x;
        int stride = blockDim.x * gridDim.x;

        __shared__ float sdata_tp[256];
        __shared__ float sdata_fp[256];
        __shared__ float sdata_fn[256];

        float local_tp = 0.0f;
        float local_fp = 0.0f;
        float local_fn = 0.0f;

        for (int i = idx; i < n; i += stride) {
            float pred = ypred[i] > threshold ? 1.0f : 0.0f;
            float truth = y_true[i];

            if (pred == 1.0f && truth == 1.0f) local_tp += 1.0f;
            else if (pred == 1.0f && truth == 0.0f) local_fp += 1.0f;
            else if (pred == 0.0f && truth == 1.0f) local_fn += 1.0f;
        }

        sdata_tp[threadIdx.x] = local_tp;
        sdata_fp[threadIdx.x] = local_fp;
        sdata_fn[threadIdx.x] = local_fn;
        __syncthreads();

        for (int s = blockDim.x / 2; s > 0; s >>= 1) {
            if (threadIdx.x < s) {
                sdata_tp[threadIdx.x] += sdata_tp[threadIdx.x + s];
                sdata_fp[threadIdx.x] += sdata_fp[threadIdx.x + s];
                sdata_fn[threadIdx.x] += sdata_fn[threadIdx.x + s];
            }
            __syncthreads();
        }

        if (threadIdx.x == 0) {
            atomicAdd(tp, sdata_tp[0]);
            atomicAdd(fp, sdata_fp[0]);
            atomicAdd(fn_ptr, sdata_fn[0]);
        }
    }
    "#;
}
176
/// OpenCL kernel source code for metrics computation
///
/// Host-side preconditions (NOTE(review): inferred from the kernel bodies —
/// confirm against the enqueue code, which is not visible in this file):
/// - `result` must be zero-initialized before each launch (kernels
///   accumulate into it);
/// - the work-group size must be a power of two and at most 256 (the
///   `__local` scratch array is fixed at 256 floats and the reduction loop
///   starts at `get_local_size(0) / 2`).
pub mod opencl_kernels {
    /// Mean squared error kernel.
    ///
    /// OpenCL C has no built-in atomic add for `float`, and the previously
    /// called `atomic_add_global` is not part of any OpenCL standard (the
    /// kernel would fail to compile). Each kernel therefore carries a
    /// compare-and-swap based `atomic_add_float` helper built on
    /// `atomic_cmpxchg` over the value's bit pattern; 32-bit global atomics
    /// are core in OpenCL 1.1+.
    pub const MSE_KERNEL: &str = r#"
    inline void atomic_add_float(volatile __global float* addr, float val) {
        union { unsigned int u; float f; } prev, next;
        do {
            prev.f = *addr;
            next.f = prev.f + val;
        } while (atomic_cmpxchg((volatile __global unsigned int*)addr,
                                prev.u, next.u) != prev.u);
    }

    __kernel void mse_kernel(
        __global const float* y_true,
        __global const float* ypred,
        __global float* result,
        const int n
    ) {
        int idx = get_global_id(0);
        int stride = get_global_size(0);

        __local float sdata[256];
        int lid = get_local_id(0);

        float sum = 0.0f;
        for (int i = idx; i < n; i += stride) {
            float diff = y_true[i] - ypred[i];
            sum += diff * diff;
        }

        sdata[lid] = sum;
        barrier(CLK_LOCAL_MEM_FENCE);

        for (int s = get_local_size(0) / 2; s > 0; s >>= 1) {
            if (lid < s) {
                sdata[lid] += sdata[lid + s];
            }
            barrier(CLK_LOCAL_MEM_FENCE);
        }

        if (lid == 0) {
            atomic_add_float(result, sdata[0] / n);
        }
    }
    "#;

    /// Mean absolute error kernel. Same structure and same CAS-based float
    /// atomic helper as `MSE_KERNEL`, with `fabs(diff)` instead of
    /// `diff * diff`.
    pub const MAE_KERNEL: &str = r#"
    inline void atomic_add_float(volatile __global float* addr, float val) {
        union { unsigned int u; float f; } prev, next;
        do {
            prev.f = *addr;
            next.f = prev.f + val;
        } while (atomic_cmpxchg((volatile __global unsigned int*)addr,
                                prev.u, next.u) != prev.u);
    }

    __kernel void mae_kernel(
        __global const float* y_true,
        __global const float* ypred,
        __global float* result,
        const int n
    ) {
        int idx = get_global_id(0);
        int stride = get_global_size(0);

        __local float sdata[256];
        int lid = get_local_id(0);

        float sum = 0.0f;
        for (int i = idx; i < n; i += stride) {
            sum += fabs(y_true[i] - ypred[i]);
        }

        sdata[lid] = sum;
        barrier(CLK_LOCAL_MEM_FENCE);

        for (int s = get_local_size(0) / 2; s > 0; s >>= 1) {
            if (lid < s) {
                sdata[lid] += sdata[lid + s];
            }
            barrier(CLK_LOCAL_MEM_FENCE);
        }

        if (lid == 0) {
            atomic_add_float(result, sdata[0] / n);
        }
    }
    "#;
}
248
/// Metal compute shader kernels for metrics computation
///
/// Host-side preconditions (NOTE(review): inferred from the shader bodies —
/// confirm against the dispatch code, which is not visible in this file):
/// - dispatch with a threadgroup size of exactly 256 threads: the
///   threadgroup scratch array is 256 floats and the reduction loop starts
///   at a hard-coded offset of 128;
/// - zero-initialize the `result` buffer before each dispatch (kernels
///   accumulate into it).
///
/// Float atomics (`atomic_float` + `atomic_fetch_add_explicit`) require
/// MSL 2.4+ and hardware that supports them — verify the deployment target.
pub mod metal_kernels {
    /// Mean squared error kernel.
    ///
    /// Fixes over the previous revision:
    /// - the thread's index within its threadgroup is taken from the
    ///   `[[thread_position_in_threadgroup]]` attribute parameter; the old
    ///   code read a bare `threadgroup_position_in_grid` identifier, which
    ///   is an attribute name, not a variable, and does not compile;
    /// - `result` is declared `device atomic_float*`, which
    ///   `atomic_fetch_add_explicit` requires (a plain `device float*` is
    ///   not a valid atomic operand).
    pub const MSE_KERNEL: &str = r#"
    #include <metal_stdlib>
    using namespace metal;

    kernel void mse_kernel(
        device const float* y_true [[buffer(0)]],
        device const float* ypred [[buffer(1)]],
        device atomic_float* result [[buffer(2)]],
        constant uint& n [[buffer(3)]],
        uint id [[thread_position_in_grid]],
        uint threads_per_grid [[threads_per_grid]],
        uint lid [[thread_position_in_threadgroup]]
    ) {
        threadgroup float sdata[256];

        float sum = 0.0f;
        for (uint i = id; i < n; i += threads_per_grid) {
            float diff = y_true[i] - ypred[i];
            sum += diff * diff;
        }

        sdata[lid] = sum;
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Tree reduction; assumes a 256-thread threadgroup.
        for (uint s = 128; s > 0; s >>= 1) {
            if (lid < s) {
                sdata[lid] += sdata[lid + s];
            }
            threadgroup_barrier(mem_flags::mem_threadgroup);
        }

        if (lid == 0) {
            atomic_fetch_add_explicit(result, sdata[0] / n, memory_order_relaxed);
        }
    }
    "#;

    /// Mean absolute error kernel. Same structure and same fixes as
    /// `MSE_KERNEL`, with `abs(diff)` instead of `diff * diff`.
    pub const MAE_KERNEL: &str = r#"
    #include <metal_stdlib>
    using namespace metal;

    kernel void mae_kernel(
        device const float* y_true [[buffer(0)]],
        device const float* ypred [[buffer(1)]],
        device atomic_float* result [[buffer(2)]],
        constant uint& n [[buffer(3)]],
        uint id [[thread_position_in_grid]],
        uint threads_per_grid [[threads_per_grid]],
        uint lid [[thread_position_in_threadgroup]]
    ) {
        threadgroup float sdata[256];

        float sum = 0.0f;
        for (uint i = id; i < n; i += threads_per_grid) {
            sum += abs(y_true[i] - ypred[i]);
        }

        sdata[lid] = sum;
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Tree reduction; assumes a 256-thread threadgroup.
        for (uint s = 128; s > 0; s >>= 1) {
            if (lid < s) {
                sdata[lid] += sdata[lid + s];
            }
            threadgroup_barrier(mem_flags::mem_threadgroup);
        }

        if (lid == 0) {
            atomic_fetch_add_explicit(result, sdata[0] / n, memory_order_relaxed);
        }
    }
    "#;
}
324
/// Vulkan SPIR-V compute shader kernels
///
/// The GLSL sources perform the same grid-stride + shared-memory reduction
/// as the other backends. `atomicAdd` on a `float` SSBO member is not part
/// of core GLSL 450 / SPIR-V: both shaders now declare
/// `GL_EXT_shader_atomic_float`, and the host must enable the matching
/// `VK_EXT_shader_atomic_float` device feature (shaderBufferFloat32AtomicAdd)
/// — without the directive glslang rejects the shader.
///
/// Host-side preconditions (NOTE(review): inferred from the shader bodies —
/// confirm against the dispatch code, which is not visible in this file):
/// zero-initialize `result[0]` before each dispatch.
pub mod vulkan_kernels {
    /// Placeholder SPIR-V for the MSE kernel: only the 4-byte little-endian
    /// SPIR-V magic number (0x07230203). Real bytecode (e.g. from compiling
    /// `MSE_GLSL_SOURCE` with glslangValidator) must replace this before use.
    pub const MSE_SPIRV: &[u8] = &[
        0x03, 0x02, 0x23, 0x07, // SPIR-V magic number (little-endian)
        // ... actual SPIR-V bytecode would follow
    ];

    /// Placeholder SPIR-V for the MAE kernel — see `MSE_SPIRV`.
    pub const MAE_SPIRV: &[u8] = &[
        0x03, 0x02, 0x23, 0x07, // SPIR-V magic number (little-endian)
        // ... actual SPIR-V bytecode would follow
    ];

    /// GLSL compute source for mean squared error (256-thread workgroups).
    pub const MSE_GLSL_SOURCE: &str = r#"
    #version 450
    #extension GL_EXT_shader_atomic_float : require

    layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;

    layout(set = 0, binding = 0) buffer YTrue {
        float y_true[];
    };

    layout(set = 0, binding = 1) buffer YPred {
        float ypred[];
    };

    layout(set = 0, binding = 2) buffer Result {
        float result[];
    };

    layout(push_constant) uniform PushConstants {
        uint n;
    } pc;

    shared float sdata[256];

    void main() {
        uint idx = gl_GlobalInvocationID.x;
        uint stride = gl_NumWorkGroups.x * gl_WorkGroupSize.x;
        uint lid = gl_LocalInvocationID.x;

        float sum = 0.0;
        for (uint i = idx; i < pc.n; i += stride) {
            float diff = y_true[i] - ypred[i];
            sum += diff * diff;
        }

        sdata[lid] = sum;
        barrier();

        for (uint s = gl_WorkGroupSize.x / 2; s > 0; s >>= 1) {
            if (lid < s) {
                sdata[lid] += sdata[lid + s];
            }
            barrier();
        }

        if (lid == 0) {
            atomicAdd(result[0], sdata[0] / pc.n);
        }
    }
    "#;

    /// GLSL compute source for mean absolute error (256-thread workgroups).
    pub const MAE_GLSL_SOURCE: &str = r#"
    #version 450
    #extension GL_EXT_shader_atomic_float : require

    layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;

    layout(set = 0, binding = 0) buffer YTrue {
        float y_true[];
    };

    layout(set = 0, binding = 1) buffer YPred {
        float ypred[];
    };

    layout(set = 0, binding = 2) buffer Result {
        float result[];
    };

    layout(push_constant) uniform PushConstants {
        uint n;
    } pc;

    shared float sdata[256];

    void main() {
        uint idx = gl_GlobalInvocationID.x;
        uint stride = gl_NumWorkGroups.x * gl_WorkGroupSize.x;
        uint lid = gl_LocalInvocationID.x;

        float sum = 0.0;
        for (uint i = idx; i < pc.n; i += stride) {
            sum += abs(y_true[i] - ypred[i]);
        }

        sdata[lid] = sum;
        barrier();

        for (uint s = gl_WorkGroupSize.x / 2; s > 0; s >>= 1) {
            if (lid < s) {
                sdata[lid] += sdata[lid + s];
            }
            barrier();
        }

        if (lid == 0) {
            atomicAdd(result[0], sdata[0] / pc.n);
        }
    }
    "#;
}