Constant GAMMA_MATRIX_SHADER
/// WGSL compute shader that fills a dense `n × n` "gamma" matrix, with
/// `n = params.n_atoms`.
///
/// Invocation `(a, b)` writes `output[a * n + b]` (row-major):
/// - diagonal (`a == b`): `eta[a]`;
/// - off-diagonal: `1 / r_ab`, where `r_ab` is the Euclidean distance between
///   `positions[a]` and `positions[b]`.
///
/// NOTE(review): presumably this is the hardness / bare-Coulomb matrix of a
/// charge-equilibration scheme. Units are whatever the host uses for `eta` and
/// `positions`, so confirm against the caller.
///
/// Bind group 0 layout expected by the shader:
/// - binding 0: `eta`, read-only storage buffer, at least `n` `f32`s;
/// - binding 1: `positions`, read-only storage buffer, at least `n` elements of
///   `{x, y, z, _pad}` (4 × `f32` = 16 bytes each; `_pad` is never read);
/// - binding 2: `Params` uniform, 16 bytes (`n_atoms: u32` + three padding `u32`s);
/// - binding 3: `output`, read-write storage buffer, at least `n * n` `f32`s.
///
/// The entry point is `main` with `@workgroup_size(16, 16, 1)`. Invocations with
/// `a >= n` or `b >= n` exit immediately, so dispatching `ceil(n / 16)` workgroups
/// in both x and y covers the whole matrix.
///
/// Two distinct atoms at the same position would make the off-diagonal entry
/// `1 / 0`. WGSL does not guarantee an IEEE infinity at runtime; it yields an
/// indeterminate value. To avoid that, the squared distance is clamped to `1e-12`
/// (so `r_ab >= 1e-6`) and every entry stays finite.
pub const GAMMA_MATRIX_SHADER: &str = r#"
struct AtomPos {
    x: f32, y: f32, z: f32, _pad: f32,
};
struct Params {
    n_atoms: u32,
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
};
@group(0) @binding(0) var<storage, read> eta: array<f32>;
@group(0) @binding(1) var<storage, read> positions: array<AtomPos>;
@group(0) @binding(2) var<uniform> params: Params;
@group(0) @binding(3) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(16, 16, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let a = gid.x;
    let b = gid.y;
    let n = params.n_atoms;
    // Invocations past the matrix edge (partial workgroups) do nothing.
    if (a >= n || b >= n) {
        return;
    }
    if (a == b) {
        output[a * n + b] = eta[a];
        return;
    }
    let pa = positions[a];
    let pb = positions[b];
    let dx = pa.x - pb.x;
    let dy = pa.y - pb.y;
    let dz = pa.z - pb.z;
    // Clamp so coincident atoms give a large finite value instead of 1/0,
    // whose runtime result WGSL leaves indeterminate.
    let r = sqrt(max(dx * dx + dy * dy + dz * dz, 1e-12));
    output[a * n + b] = 1.0 / r;
}
"#;