scirs2_metrics/optimization/gpu_kernels/kernels.rs

#![allow(clippy::too_many_arguments)]
#![allow(dead_code)]

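//! GPU kernel sources for regression and classification metrics, provided
//! as CUDA, OpenCL, Metal, and Vulkan (GLSL / SPIR-V) variants. Every
//! reduction kernel assumes a power-of-two workgroup/block size of exactly
//! 256 threads, matching its fixed 256-element shared-memory array.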
pub mod cuda_kernels {
    pub const MSE_KERNEL: &str = r#"
    extern "C" __global__ void mse_kernel(
        const float* y_true,
        const float* ypred,
        float* result,
        int n
    ) {
        int idx = blockIdx.x * blockDim.x + threadIdx.x;
        int stride = blockDim.x * gridDim.x;

        __shared__ float sdata[256];

        float sum = 0.0f;
        for (int i = idx; i < n; i += stride) {
            float diff = y_true[i] - ypred[i];
            sum += diff * diff;
        }

        sdata[threadIdx.x] = sum;
        __syncthreads();

        // Parallel reduction in shared memory
        for (int s = blockDim.x / 2; s > 0; s >>= 1) {
            if (threadIdx.x < s) {
                sdata[threadIdx.x] += sdata[threadIdx.x + s];
            }
            __syncthreads();
        }

        if (threadIdx.x == 0) {
            atomicAdd(result, sdata[0] / n);
        }
    }
    "#;

    pub const MAE_KERNEL: &str = r#"
    extern "C" __global__ void mae_kernel(
        const float* y_true,
        const float* ypred,
        float* result,
        int n
    ) {
        int idx = blockIdx.x * blockDim.x + threadIdx.x;
        int stride = blockDim.x * gridDim.x;

        __shared__ float sdata[256];

        float sum = 0.0f;
        for (int i = idx; i < n; i += stride) {
            sum += fabsf(y_true[i] - ypred[i]);
        }

        sdata[threadIdx.x] = sum;
        __syncthreads();

        for (int s = blockDim.x / 2; s > 0; s >>= 1) {
            if (threadIdx.x < s) {
                sdata[threadIdx.x] += sdata[threadIdx.x + s];
            }
            __syncthreads();
        }

        if (threadIdx.x == 0) {
            atomicAdd(result, sdata[0] / n);
        }
    }
    "#;

    pub const R2_KERNEL: &str = r#"
    extern "C" __global__ void r2_kernel(
        const float* y_true,
        const float* ypred,
        float* ss_res,
        float* ss_tot,
        float mean_true,
        int n
    ) {
        int idx = blockIdx.x * blockDim.x + threadIdx.x;
        int stride = blockDim.x * gridDim.x;

        __shared__ float sdata_res[256];
        __shared__ float sdata_tot[256];

        float sum_res = 0.0f;
        float sum_tot = 0.0f;

        for (int i = idx; i < n; i += stride) {
            float diff = y_true[i] - ypred[i];
            sum_res += diff * diff;

            float diff_mean = y_true[i] - mean_true;
            sum_tot += diff_mean * diff_mean;
        }

        sdata_res[threadIdx.x] = sum_res;
        sdata_tot[threadIdx.x] = sum_tot;
        __syncthreads();

        for (int s = blockDim.x / 2; s > 0; s >>= 1) {
            if (threadIdx.x < s) {
                sdata_res[threadIdx.x] += sdata_res[threadIdx.x + s];
                sdata_tot[threadIdx.x] += sdata_tot[threadIdx.x + s];
            }
            __syncthreads();
        }

        if (threadIdx.x == 0) {
            atomicAdd(ss_res, sdata_res[0]);
            atomicAdd(ss_tot, sdata_tot[0]);
        }
    }
    "#;
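
    // Sketch, not part of the original file: the kernel above only
    // accumulates ss_res and ss_tot; the score itself is assembled on the
    // host as R^2 = 1 - ss_res / ss_tot. The helper name and the
    // zero-variance convention below are assumptions.
    pub fn r2_from_sums(ss_res: f32, ss_tot: f32) -> f32 {
        if ss_tot == 0.0 {
            0.0 // degenerate case: y_true is constant
        } else {
            1.0 - ss_res / ss_tot
        }
    }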

    pub const PRECISION_RECALL_KERNEL: &str = r#"
    extern "C" __global__ void precision_recall_kernel(
        const float* y_true,
        const float* ypred,
        float* tp,
        float* fp,
        float* fn_ptr,
        float threshold,
        int n
    ) {
        int idx = blockIdx.x * blockDim.x + threadIdx.x;
        int stride = blockDim.x * gridDim.x;

        __shared__ float sdata_tp[256];
        __shared__ float sdata_fp[256];
        __shared__ float sdata_fn[256];

        float local_tp = 0.0f;
        float local_fp = 0.0f;
        float local_fn = 0.0f;

        for (int i = idx; i < n; i += stride) {
            float pred = ypred[i] > threshold ? 1.0f : 0.0f;
            float truth = y_true[i];

            if (pred == 1.0f && truth == 1.0f) local_tp += 1.0f;
            else if (pred == 1.0f && truth == 0.0f) local_fp += 1.0f;
            else if (pred == 0.0f && truth == 1.0f) local_fn += 1.0f;
        }

        sdata_tp[threadIdx.x] = local_tp;
        sdata_fp[threadIdx.x] = local_fp;
        sdata_fn[threadIdx.x] = local_fn;
        __syncthreads();

        for (int s = blockDim.x / 2; s > 0; s >>= 1) {
            if (threadIdx.x < s) {
                sdata_tp[threadIdx.x] += sdata_tp[threadIdx.x + s];
                sdata_fp[threadIdx.x] += sdata_fp[threadIdx.x + s];
                sdata_fn[threadIdx.x] += sdata_fn[threadIdx.x + s];
            }
            __syncthreads();
        }

        if (threadIdx.x == 0) {
            atomicAdd(tp, sdata_tp[0]);
            atomicAdd(fp, sdata_fp[0]);
            atomicAdd(fn_ptr, sdata_fn[0]);
        }
    }
    "#;
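
    // Sketch, not part of the original file: the kernel above only counts
    // tp/fp/fn; precision and recall are derived on the host. The helper
    // name and zero-denominator convention are assumptions.
    pub fn precision_recall_from_counts(tp: f32, fp: f32, fn_count: f32) -> (f32, f32) {
        let precision = if tp + fp > 0.0 { tp / (tp + fp) } else { 0.0 };
        let recall = if tp + fn_count > 0.0 { tp / (tp + fn_count) } else { 0.0 };
        (precision, recall)
    }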
}
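
// Sketch, not part of the original file: one way a host wrapper might size
// the launch for the kernels above. The block size is fixed at 256 to match
// the shared-memory arrays; only the grid size scales with the input, and
// the grid-stride loops absorb any remainder. MAX_BLOCKS is an assumed,
// tunable cap.
pub fn launch_dims(n: usize) -> (u32, u32) {
    const BLOCK_SIZE: u32 = 256; // must equal the kernels' sdata[256]
    const MAX_BLOCKS: u64 = 1024; // arbitrary cap; tune per device
    let blocks = ((n as u64 + BLOCK_SIZE as u64 - 1) / BLOCK_SIZE as u64)
        .clamp(1, MAX_BLOCKS) as u32;
    (blocks, BLOCK_SIZE)
}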

pub mod opencl_kernels {
    pub const MSE_KERNEL: &str = r#"
    // Standard OpenCL exposes atomic_add only for integers, so a float
    // atomic add is emulated with a compare-and-swap loop on the bit pattern.
    inline void atomic_add_global(volatile __global float* addr, float val) {
        union { uint u; float f; } old_val, new_val;
        do {
            old_val.f = *addr;
            new_val.f = old_val.f + val;
        } while (atomic_cmpxchg((volatile __global uint*)addr,
                                old_val.u, new_val.u) != old_val.u);
    }

    __kernel void mse_kernel(
        __global const float* y_true,
        __global const float* ypred,
        __global float* result,
        const int n
    ) {
        int idx = get_global_id(0);
        int stride = get_global_size(0);

        __local float sdata[256];
        int lid = get_local_id(0);

        float sum = 0.0f;
        for (int i = idx; i < n; i += stride) {
            float diff = y_true[i] - ypred[i];
            sum += diff * diff;
        }

        sdata[lid] = sum;
        barrier(CLK_LOCAL_MEM_FENCE);

        for (int s = get_local_size(0) / 2; s > 0; s >>= 1) {
            if (lid < s) {
                sdata[lid] += sdata[lid + s];
            }
            barrier(CLK_LOCAL_MEM_FENCE);
        }

        if (lid == 0) {
            atomic_add_global(result, sdata[0] / n);
        }
    }
    "#;

    pub const MAE_KERNEL: &str = r#"
    // Same CAS-based float atomic add as in MSE_KERNEL; each kernel string
    // is compiled as its own program, so the helper is repeated here.
    inline void atomic_add_global(volatile __global float* addr, float val) {
        union { uint u; float f; } old_val, new_val;
        do {
            old_val.f = *addr;
            new_val.f = old_val.f + val;
        } while (atomic_cmpxchg((volatile __global uint*)addr,
                                old_val.u, new_val.u) != old_val.u);
    }

    __kernel void mae_kernel(
        __global const float* y_true,
        __global const float* ypred,
        __global float* result,
        const int n
    ) {
        int idx = get_global_id(0);
        int stride = get_global_size(0);

        __local float sdata[256];
        int lid = get_local_id(0);

        float sum = 0.0f;
        for (int i = idx; i < n; i += stride) {
            sum += fabs(y_true[i] - ypred[i]);
        }

        sdata[lid] = sum;
        barrier(CLK_LOCAL_MEM_FENCE);

        for (int s = get_local_size(0) / 2; s > 0; s >>= 1) {
            if (lid < s) {
                sdata[lid] += sdata[lid + s];
            }
            barrier(CLK_LOCAL_MEM_FENCE);
        }

        if (lid == 0) {
            atomic_add_global(result, sdata[0] / n);
        }
    }
    "#;
}
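
// Sketch, not part of the original file: scalar CPU references for
// validating the GPU kernels. Exact equality is too strict because the
// device reductions reorder the float additions; compare within an epsilon.
pub fn mse_reference(y_true: &[f32], y_pred: &[f32]) -> f32 {
    assert_eq!(y_true.len(), y_pred.len());
    let sum: f32 = y_true
        .iter()
        .zip(y_pred)
        .map(|(t, p)| (t - p) * (t - p))
        .sum();
    sum / y_true.len() as f32
}

pub fn mae_reference(y_true: &[f32], y_pred: &[f32]) -> f32 {
    assert_eq!(y_true.len(), y_pred.len());
    let sum: f32 = y_true.iter().zip(y_pred).map(|(t, p)| (t - p).abs()).sum();
    sum / y_true.len() as f32
}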

pub mod metal_kernels {
    pub const MSE_KERNEL: &str = r#"
    #include <metal_stdlib>
    using namespace metal;

    // Requires MSL 2.3+ for atomic_float; dispatch with exactly 256
    // threads per threadgroup to match sdata and the reduction below.
    kernel void mse_kernel(
        device const float* y_true [[buffer(0)]],
        device const float* ypred [[buffer(1)]],
        device atomic_float* result [[buffer(2)]],
        constant uint& n [[buffer(3)]],
        uint id [[thread_position_in_grid]],
        uint threads_per_grid [[threads_per_grid]],
        uint lid [[thread_position_in_threadgroup]]
    ) {
        threadgroup float sdata[256];

        float sum = 0.0f;
        for (uint i = id; i < n; i += threads_per_grid) {
            float diff = y_true[i] - ypred[i];
            sum += diff * diff;
        }

        sdata[lid] = sum;
        threadgroup_barrier(mem_flags::mem_threadgroup);

        for (uint s = 128; s > 0; s >>= 1) {
            if (lid < s) {
                sdata[lid] += sdata[lid + s];
            }
            threadgroup_barrier(mem_flags::mem_threadgroup);
        }

        if (lid == 0) {
            atomic_fetch_add_explicit(result, sdata[0] / n, memory_order_relaxed);
        }
    }
    "#;

    pub const MAE_KERNEL: &str = r#"
    #include <metal_stdlib>
    using namespace metal;

    // Requires MSL 2.3+ for atomic_float; dispatch with exactly 256
    // threads per threadgroup to match sdata and the reduction below.
    kernel void mae_kernel(
        device const float* y_true [[buffer(0)]],
        device const float* ypred [[buffer(1)]],
        device atomic_float* result [[buffer(2)]],
        constant uint& n [[buffer(3)]],
        uint id [[thread_position_in_grid]],
        uint threads_per_grid [[threads_per_grid]],
        uint lid [[thread_position_in_threadgroup]]
    ) {
        threadgroup float sdata[256];

        float sum = 0.0f;
        for (uint i = id; i < n; i += threads_per_grid) {
            sum += abs(y_true[i] - ypred[i]);
        }

        sdata[lid] = sum;
        threadgroup_barrier(mem_flags::mem_threadgroup);

        for (uint s = 128; s > 0; s >>= 1) {
            if (lid < s) {
                sdata[lid] += sdata[lid + s];
            }
            threadgroup_barrier(mem_flags::mem_threadgroup);
        }

        if (lid == 0) {
            atomic_fetch_add_explicit(result, sdata[0] / n, memory_order_relaxed);
        }
    }
    "#;
}
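
// Sketch, not part of the original file: named constants mirroring the
// [[buffer(N)]] attributes in the Metal kernels above, so host-side
// buffer-binding calls stay in sync with the shader signatures. The module
// name is an assumption.
pub mod metal_buffer_index {
    pub const Y_TRUE: u64 = 0;
    pub const Y_PRED: u64 = 1;
    pub const RESULT: u64 = 2;
    pub const N: u64 = 3;
}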

pub mod vulkan_kernels {
    // Truncated placeholders: only the little-endian SPIR-V magic word
    // 0x07230203 is present here; the full binaries are presumably compiled
    // from the GLSL sources below.
    pub const MSE_SPIRV: &[u8] = &[
        0x03, 0x02, 0x23, 0x07, // ...
    ];

    pub const MAE_SPIRV: &[u8] = &[
        0x03, 0x02, 0x23, 0x07, // ...
    ];

    pub const MSE_GLSL_SOURCE: &str = r#"
    #version 450
    // Float atomicAdd on buffer memory needs this extension.
    #extension GL_EXT_shader_atomic_float : require

    layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;

    layout(set = 0, binding = 0) buffer YTrue {
        float y_true[];
    };

    layout(set = 0, binding = 1) buffer YPred {
        float ypred[];
    };

    layout(set = 0, binding = 2) buffer Result {
        float result[];
    };

    layout(push_constant) uniform PushConstants {
        uint n;
    } pc;

    shared float sdata[256];

    void main() {
        uint idx = gl_GlobalInvocationID.x;
        uint stride = gl_NumWorkGroups.x * gl_WorkGroupSize.x;
        uint lid = gl_LocalInvocationID.x;

        float sum = 0.0;
        for (uint i = idx; i < pc.n; i += stride) {
            float diff = y_true[i] - ypred[i];
            sum += diff * diff;
        }

        sdata[lid] = sum;
        barrier();

        for (uint s = gl_WorkGroupSize.x / 2; s > 0; s >>= 1) {
            if (lid < s) {
                sdata[lid] += sdata[lid + s];
            }
            barrier();
        }

        if (lid == 0) {
            atomicAdd(result[0], sdata[0] / float(pc.n));
        }
    }
    "#;

    pub const MAE_GLSL_SOURCE: &str = r#"
    #version 450
    // Float atomicAdd on buffer memory needs this extension.
    #extension GL_EXT_shader_atomic_float : require

    layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;

    layout(set = 0, binding = 0) buffer YTrue {
        float y_true[];
    };

    layout(set = 0, binding = 1) buffer YPred {
        float ypred[];
    };

    layout(set = 0, binding = 2) buffer Result {
        float result[];
    };

    layout(push_constant) uniform PushConstants {
        uint n;
    } pc;

    shared float sdata[256];

    void main() {
        uint idx = gl_GlobalInvocationID.x;
        uint stride = gl_NumWorkGroups.x * gl_WorkGroupSize.x;
        uint lid = gl_LocalInvocationID.x;

        float sum = 0.0;
        for (uint i = idx; i < pc.n; i += stride) {
            sum += abs(y_true[i] - ypred[i]);
        }

        sdata[lid] = sum;
        barrier();

        for (uint s = gl_WorkGroupSize.x / 2; s > 0; s >>= 1) {
            if (lid < s) {
                sdata[lid] += sdata[lid + s];
            }
            barrier();
        }

        if (lid == 0) {
            atomicAdd(result[0], sdata[0] / float(pc.n));
        }
    }
    "#;
}
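
// Sketch, not part of the original file: the SPIR-V arrays above are
// truncated placeholders, and would normally be produced by compiling the
// GLSL sources offline, e.g. with glslangValidator from the Vulkan SDK.
// This build-script-style helper is an assumption about how that step
// could be wired up; the `-V`/`-o` flags are glslangValidator's own.
#[allow(dead_code)]
fn compile_glsl_to_spirv(glsl_path: &str, spv_path: &str) -> std::io::Result<()> {
    use std::process::Command;
    // glslangValidator infers the compute stage from a `.comp` extension.
    let status = Command::new("glslangValidator")
        .args(["-V", glsl_path, "-o", spv_path])
        .status()?;
    if status.success() {
        Ok(())
    } else {
        Err(std::io::Error::new(
            std::io::ErrorKind::Other,
            format!("glslangValidator failed on {glsl_path}"),
        ))
    }
}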