Constant MSE_KERNEL
Source pub const MSE_KERNEL: &str = r#"
/*
 * mse_kernel: adds the mean squared error of (y_true, ypred) into *result.
 *
 * Each thread accumulates squared differences over a grid-stride loop, the
 * block combines its per-thread sums in shared memory, and thread 0 of every
 * block atomically adds (block_sum / n) to *result.
 *
 * Parameters:
 *   y_true, ypred  input arrays, read at indices [0, n)
 *   result         single float accumulator; the kernel only ever adds to it,
 *                  so the caller presumably zeroes it before launch
 *                  (NOTE(review): confirm at the launch site)
 *   n              number of elements
 *
 * Launch expectations:
 *   - 1D blocks: only threadIdx.x / blockDim.x / gridDim.x are consulted.
 *   - Any blockDim.x from 1 to 1024 works, power of two or not.
 *   - Uses 4 KB of static shared memory; no dynamic shared memory is needed.
 *
 * Edge case: when n == 0 every partial sum is 0.0f, so 0.0f / 0 yields NaN
 * and that NaN is added to *result (behaviour kept from the original code).
 */
extern "C" __global__ void mse_kernel(
    const float* y_true,
    const float* ypred,
    float* result,
    int n
) {
    // One slot per thread. 1024 is the maximum number of threads per block,
    // so indexing by threadIdx.x can never run past the end of the buffer.
    __shared__ float sdata[1024];

    const unsigned int tid = threadIdx.x;

    // 64-bit index arithmetic: the 32-bit product blockDim.x * gridDim.x and
    // the running index `i += stride` could otherwise overflow for large n.
    const long long stride = (long long)blockDim.x * (long long)gridDim.x;
    long long i = (long long)blockIdx.x * (long long)blockDim.x + (long long)tid;

    float sum = 0.0f;
    for (; i < n; i += stride) {
        float diff = y_true[i] - ypred[i];
        sum += diff * diff;
    }

    sdata[tid] = sum;
    __syncthreads();

    // Block-level tree reduction that tolerates non-power-of-two block sizes:
    // each round folds the upper part of the active range onto the lower part,
    // rounding the split point up so an odd middle element is carried forward
    // instead of being dropped. For power-of-two sizes this performs exactly
    // the same additions as a plain halving loop.
    for (unsigned int active = blockDim.x; active > 1; ) {
        const unsigned int half = (active + 1) / 2;
        if (tid + half < active) {
            sdata[tid] += sdata[tid + half];
        }
        // Loop bounds depend only on blockDim.x, so every thread in the block
        // reaches this barrier the same number of times.
        __syncthreads();
        active = half;
    }

    // One atomic per block; dividing here makes *result accumulate the mean.
    if (tid == 0) {
        atomicAdd(result, sdata[0] / n);
    }
}
"#;