MSE_KERNEL

Constant MSE_KERNEL 

Source
/// CUDA C source for `mse_kernel`, a mean-squared-error reduction.
///
/// Each block reduces its share of `(y_true[i] - ypred[i])^2` and adds
/// `block_partial / n` into `*result` with `atomicAdd`. Because blocks
/// accumulate into the same location, the caller is expected to zero
/// `*result` before launch. The value is the full MSE only after every
/// block has finished.
///
/// Launch contract:
/// * Only the `x` dimensions of the grid and block are used. The element
///   loop is grid-stride, so any grid size covers all `n` elements.
/// * `blockDim.x` may be any value from 1 to 1024. It does not have to be a
///   power of two, and no dynamic shared memory is required. A fixed
///   256-float static buffer is used.
/// * If `n <= 0`, the kernel leaves `*result` untouched instead of writing NaN.
/// * `atomicAdd` on a global `float` requires compute capability 2.0+.
///
/// `extern "C"` keeps the symbol unmangled, so it can be looked up by the
/// plain name `mse_kernel`.
pub const MSE_KERNEL: &str = r#"
    extern "C" __global__ void mse_kernel(
        const float* __restrict__ y_true,
        const float* __restrict__ ypred,
        float* result,
        int n
    ) {
        // Capacity of the static reduction buffer. Blocks wider than this
        // are folded into it before the tree reduction.
        const unsigned int SLOTS_MAX = 256u;
        __shared__ float sdata[SLOTS_MAX];

        // Nothing to average: skip instead of adding 0/0 = NaN. n is the
        // same for every thread, so the whole block returns before any barrier.
        if (n <= 0) {
            return;
        }

        const unsigned int tid = threadIdx.x;

        // 64-bit indices: i += stride must not overflow when n is close to
        // INT_MAX or the grid is very large.
        const long long idx = (long long)blockIdx.x * blockDim.x + tid;
        const long long stride = (long long)blockDim.x * gridDim.x;

        float sum = 0.0f;
        for (long long i = idx; i < n; i += stride) {
            float diff = y_true[i] - ypred[i];
            sum += diff * diff;
        }

        // Number of buffer slots this block actually uses.
        const unsigned int slots = blockDim.x < SLOTS_MAX ? blockDim.x : SLOTS_MAX;

        if (tid < slots) {
            sdata[tid] = sum;
        }
        __syncthreads();

        // Blocks wider than SLOTS_MAX: add the remaining threads into the
        // buffer one chunk at a time. Each slot has a single writer per pass.
        for (unsigned int base = slots; base < blockDim.x; base += slots) {
            if (tid >= base && tid < base + slots) {
                sdata[tid - base] += sum;
            }
            __syncthreads();
        }

        // Tree reduction over `active` entries. It handles any count, not only
        // powers of two: with an odd count, the middle entry carries over to
        // the next round. For a power-of-two block of at most 256 threads,
        // the additions happen in the same order as the classic s = blockDim.x / 2 loop.
        for (unsigned int active = slots; active > 1; ) {
            const unsigned int half = (active + 1) / 2;
            if (tid < active - half) {
                sdata[tid] += sdata[tid + half];
            }
            __syncthreads();
            active = half;
        }

        if (tid == 0) {
            atomicAdd(result, sdata[0] / (float)n);
        }
    }
    "#;