Skip to main content

GAMMA_MATRIX_SHADER

Constant GAMMA_MATRIX_SHADER 

Source
/// WGSL compute-shader source that fills a dense `n_atoms × n_atoms` matrix of
/// pairwise atom terms.
///
/// Each invocation handles one matrix entry. It uses `a = global_invocation_id.x`
/// and `b = global_invocation_id.y` and writes row-major to `output[a * n_atoms + b]`:
///
/// - diagonal (`a == b`): the value `eta[a]` is copied through unchanged;
/// - off-diagonal (`a != b`): `1.0 / r`, where `r` is the Euclidean distance between
///   `positions[a]` and `positions[b]` (the padding component is ignored).
///
/// Invocations with `a >= n_atoms` or `b >= n_atoms` return without writing anything.
///
/// # Bind group layout (group 0)
///
/// - binding 0, `eta`: read-only storage, `array<f32>`; must hold at least `n_atoms` values.
/// - binding 1, `positions`: read-only storage, `array<AtomPos>`; must hold at least
///   `n_atoms` entries. Each entry is four `f32`s (`x, y, z, _pad`), i.e. 16 bytes.
/// - binding 2, `params`: uniform, `Params`. This is `n_atoms: u32` followed by three
///   padding `u32`s, i.e. 16 bytes.
/// - binding 3, `output`: read-write storage, `array<f32>`; must hold at least
///   `n_atoms * n_atoms` values.
///
/// Host-side buffer contents must match these 16-byte struct layouts exactly.
///
/// # Dispatch
///
/// The workgroup size is `16 × 16 × 1`. To cover every `(a, b)` pair, dispatch at
/// least `ceil(n_atoms / 16)` workgroups in both x and y, and 1 in z.
///
/// # Caveats
///
/// - NOTE(review): if two distinct atoms share the same position, `r == 0` and the
///   off-diagonal `1.0 / r` is not finite. WGSL does not guarantee infinities at
///   runtime, so the written value may be indeterminate. Presumably callers guarantee
///   distinct positions — confirm.
/// - The flat index `a * n + b` is computed in `u32`, so `n_atoms * n_atoms` must fit
///   in `u32`; otherwise the index would wrap.
/// - NOTE(review): the diagonal/`1/r` form suggests `eta` is a per-atom hardness-like
///   term and the off-diagonal a bare, unshielded Coulomb interaction. The units of
///   `eta` and `positions` must be mutually consistent — confirm against the caller.
pub const GAMMA_MATRIX_SHADER: &str = r#"
struct AtomPos {
    x: f32, y: f32, z: f32, _pad: f32,
};

struct Params {
    n_atoms: u32,
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
};

@group(0) @binding(0) var<storage, read> eta: array<f32>;
@group(0) @binding(1) var<storage, read> positions: array<AtomPos>;
@group(0) @binding(2) var<uniform> params: Params;
@group(0) @binding(3) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(16, 16, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let a = gid.x;
    let b = gid.y;
    let n = params.n_atoms;

    if (a >= n || b >= n) {
        return;
    }

    if (a == b) {
        output[a * n + b] = eta[a];
        return;
    }

    let pa = positions[a];
    let pb = positions[b];
    let dx = pa.x - pb.x;
    let dy = pa.y - pb.y;
    let dz = pa.z - pb.z;
    let r = sqrt(dx * dx + dy * dy + dz * dz);

    output[a * n + b] = 1.0 / r;
}
"#;