Constant R2_KERNEL
Source pub const R2_KERNEL: &str = r#"
// Partial sums for the coefficient of determination (R^2).
//
// Every thread walks [0, n) with a grid-stride loop and accumulates
//   sum_res += (y_true[i] - ypred[i])^2
//   sum_tot += (y_true[i] - mean_true)^2
// The block then reduces its per-thread partials in shared memory, and
// thread 0 atomically adds the two block totals into *ss_res and *ss_tot.
//
// Contract:
//   - The kernel ACCUMULATES into *ss_res and *ss_tot and never resets them;
//     the caller must zero both before launch (e.g. cudaMemset).
//   - mean_true is not computed here. NOTE(review): presumably the host passes
//     the mean of y_true over the same n elements - confirm with the caller.
//   - Computing 1 - ss_res / ss_tot is left to the caller.
//   - n <= 0 is allowed: every block then adds 0 to both outputs.
//
// Launch: 1-D grid and 1-D block. Any gridDim.x >= 1 and any blockDim.x in
// [1, 1024] (the CUDA per-block limit), power of two or not. No dynamic
// shared memory is needed; the kernel uses 2 * 1024 floats (8 KiB) of static
// shared memory. Uses atomicAdd on float (SM 2.0+).
extern "C" __global__ void r2_kernel(
    const float* __restrict__ y_true,
    const float* __restrict__ ypred,
    float* ss_res,
    float* ss_tot,
    float mean_true,
    int n
) {
    // blockDim.x can never exceed 1024 on any CUDA architecture, so sizing the
    // buffers to that bound makes every legal block size safe (the previous
    // fixed size of 256 overflowed for larger blocks).
    const unsigned int kMaxThreads = 1024u;
    __shared__ float sdata_res[kMaxThreads];
    __shared__ float sdata_tot[kMaxThreads];

    const unsigned int tid = threadIdx.x;

    // 64-bit indexing: with a 32-bit int, `i += stride` can overflow (UB)
    // when n is close to INT_MAX.
    const long long stride = (long long)blockDim.x * gridDim.x;
    float sum_res = 0.0f;
    float sum_tot = 0.0f;
    for (long long i = (long long)blockIdx.x * blockDim.x + tid; i < n; i += stride) {
        const float yt = y_true[i];
        const float diff = yt - ypred[i];
        const float diff_mean = yt - mean_true;
        sum_res += diff * diff;
        sum_tot += diff_mean * diff_mean;
    }
    sdata_res[tid] = sum_res;
    sdata_tot[tid] = sum_tot;
    __syncthreads();

    // Tree reduction that is exact for non-power-of-two block sizes: each step
    // folds the upper (size - half) entries onto the lower ones, where
    // half = ceil(size / 2). `size` is uniform across the block, so every
    // thread reaches each __syncthreads().
    for (unsigned int size = blockDim.x; size > 1u; ) {
        const unsigned int half = (size + 1u) / 2u;
        if (tid < size - half) {
            sdata_res[tid] += sdata_res[tid + half];
            sdata_tot[tid] += sdata_tot[tid + half];
        }
        __syncthreads();
        size = half;
    }

    // One atomic per block per output keeps contention on the two global
    // accumulators low.
    if (tid == 0) {
        atomicAdd(ss_res, sdata_res[0]);
        atomicAdd(ss_tot, sdata_tot[0]);
    }
}
"#;