MAE_GLSL_SOURCE

Constant MAE_GLSL_SOURCE 

Source
/// GLSL compute shader (Vulkan-style `#version 450`) that accumulates the mean
/// absolute error between two `f32` storage buffers into `result[0]`.
///
/// Bindings (descriptor set 0):
/// - binding 0: `y_true` — reference values (read only)
/// - binding 1: `ypred`  — predicted values (read only)
/// - binding 2: `result` — only `result[0]` is written
///
/// Push constant: `n` (u32), the number of elements of `y_true`/`ypred` to use.
///
/// Each 256-invocation work group reduces its partial sum of `|y_true[i] - ypred[i]|`
/// in shared memory and atomically adds `partial / n` to `result[0]`. The shader never
/// initializes `result[0]`, so after the dispatch it holds its previous value plus the
/// MAE. The caller must zero it beforehand to get the MAE alone. When `n == 0` nothing
/// is added.
///
/// Float `atomicAdd` is not part of core GLSL 4.50. The shader therefore requires
/// `GL_EXT_shader_atomic_float` at compile time and, in Vulkan, the
/// `shaderBufferFloat32AtomicAdd` device feature at runtime.
pub const MAE_GLSL_SOURCE: &str = r#"
    #version 450
    // atomicAdd on a float buffer element is only available through this extension.
    #extension GL_EXT_shader_atomic_float : require

    layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;

    layout(set = 0, binding = 0) readonly buffer YTrue {
        float y_true[];
    };

    layout(set = 0, binding = 1) readonly buffer YPred {
        float ypred[];
    };

    layout(set = 0, binding = 2) buffer Result {
        float result[];
    };

    layout(push_constant) uniform PushConstants {
        uint n;
    } pc;

    // One slot per invocation; the size must match local_size_x above.
    shared float sdata[256];

    void main() {
        uint idx = gl_GlobalInvocationID.x;
        uint stride = gl_NumWorkGroups.x * gl_WorkGroupSize.x;
        uint lid = gl_LocalInvocationID.x;

        // Striding by the total invocation count covers all n elements for any dispatch size.
        float sum = 0.0;
        for (uint i = idx; i < pc.n; i += stride) {
            sum += abs(y_true[i] - ypred[i]);
        }

        sdata[lid] = sum;
        barrier();

        // Tree reduction in shared memory. The loop bound is uniform, so every
        // invocation reaches each barrier().
        for (uint s = gl_WorkGroupSize.x / 2; s > 0; s >>= 1) {
            if (lid < s) {
                sdata[lid] += sdata[lid + s];
            }
            barrier();
        }

        // Each work group contributes its share of the mean. Skipping n == 0 avoids
        // adding 0.0 / 0 (NaN) into the accumulator.
        if (lid == 0 && pc.n > 0u) {
            atomicAdd(result[0], sdata[0] / float(pc.n));
        }
    }
    "#;