R2_KERNEL

Constant R2_KERNEL 

Source
/// CUDA C source for the kernel `r2_kernel`. NOTE(review): presumably compiled at
/// runtime (e.g. NVRTC) by the caller; confirm.
///
/// The kernel accumulates the two sums used by the coefficient of determination:
///
/// * `*ss_res += Σ (y_true[i] - ypred[i])²`
/// * `*ss_tot += Σ (y_true[i] - mean_true)²`
///
/// NOTE(review): the host presumably then computes `R² = 1 - ss_res / ss_tot`;
/// confirm in the caller.
///
/// Launch contract:
/// * 1-D launch with `blockDim.x <= 256`. The limit is declared via
///   `__launch_bounds__`, so a larger block fails to launch rather than
///   overrunning the shared scratch arrays.
/// * Any block size up to that limit is reduced correctly, including
///   non-powers of two.
/// * Any grid size works, because each thread strides over the input by
///   `blockDim.x * gridDim.x`.
/// * Only static shared memory is used; pass 0 bytes of dynamic shared memory.
/// * `ss_res` and `ss_tot` are single device floats that each block adds into
///   with `atomicAdd`. The caller must zero them before the launch.
/// * `mean_true` is not computed here; the caller must supply it.
/// * Sums are accumulated in `f32`, and the order of the per-block atomic
///   adds is unspecified. Results may therefore differ in the last bits
///   between runs.
/// * The kernel is declared `extern "C"`, so its symbol name is the unmangled
///   `r2_kernel`.
pub const R2_KERNEL: &str = r#"
    // Maximum threads per block; also sizes the shared-memory scratch arrays.
    #define R2_BLOCK_MAX 256

    extern "C" __global__ void __launch_bounds__(R2_BLOCK_MAX) r2_kernel(
        const float* y_true,
        const float* ypred,
        float* ss_res,
        float* ss_tot,
        float mean_true,
        int n
    ) {
        __shared__ float sdata_res[R2_BLOCK_MAX];
        __shared__ float sdata_tot[R2_BLOCK_MAX];

        const unsigned int tid = threadIdx.x;
        // 64-bit index math so idx/stride/i cannot overflow on very large grids.
        const long long idx = (long long)blockIdx.x * blockDim.x + tid;
        const long long stride = (long long)blockDim.x * gridDim.x;

        float sum_res = 0.0f;
        float sum_tot = 0.0f;

        for (long long i = idx; i < n; i += stride) {
            float yt = y_true[i];
            float diff = yt - ypred[i];
            sum_res += diff * diff;

            float diff_mean = yt - mean_true;
            sum_tot += diff_mean * diff_mean;
        }

        sdata_res[tid] = sum_res;
        sdata_tot[tid] = sum_tot;
        __syncthreads();

        // Tree reduction that is also correct for non-power-of-two block sizes:
        // start at the largest power of two strictly below blockDim.x and only
        // fold in partner slots that actually exist.
        unsigned int pow2 = 1;
        while (pow2 < blockDim.x) pow2 <<= 1;
        for (unsigned int s = pow2 >> 1; s > 0; s >>= 1) {
            if (tid < s && tid + s < blockDim.x) {
                sdata_res[tid] += sdata_res[tid + s];
                sdata_tot[tid] += sdata_tot[tid + s];
            }
            __syncthreads();
        }

        // One atomic per block; the output buffers must be zeroed by the caller.
        if (tid == 0) {
            atomicAdd(ss_res, sdata_res[0]);
            atomicAdd(ss_tot, sdata_tot[0]);
        }
    }
    "#;