Skip to main content

ESP_GRID_SHADER

Constant ESP_GRID_SHADER 

Source
/// WGSL compute shader that evaluates a point-charge electrostatic potential on a
/// regular 3-D grid.
///
/// Each invocation handles one grid point `(ix, iy, iz)`, located at
/// `origin + (ix, iy, iz) * spacing`. It writes
/// `sum(charges[a] / (|r - positions[a]| * 1.88972599))` over atoms `a < n_atoms`.
/// Any atom closer than `0.01` to the grid point is skipped; the distance is in the
/// same units as the positions and origin.
/// The factor `1.88972599` matches the number of bohr per ångström. The shader
/// therefore presumably expects positions in Å and yields the potential in atomic
/// units (charge / bohr). NOTE(review): confirm these units against the caller.
///
/// # Bindings (all in `@group(0)`)
/// - `@binding(0)` `positions`: read-only storage, `array<AtomPos>`. Each `AtomPos`
///   is four `f32`s (`x, y, z, _pad`), i.e. 16 bytes per atom.
/// - `@binding(1)` `charges`: read-only storage, one `f32` per atom.
/// - `@binding(2)` `params`: uniform `GridParams`, eight 4-byte fields in this order:
///   - `origin_x, origin_y, origin_z, spacing` (`f32`);
///   - then `dims_x, dims_y, dims_z, n_atoms` (`u32`).
///
///   The struct is 32 bytes in total.
/// - `@binding(3)` `output`: read-write storage, one `f32` per grid point.
///
/// # Output layout
/// The value for `(ix, iy, iz)` is stored at `ix * dims_y * dims_z + iy * dims_z + iz`.
/// So `z` varies fastest and `x` slowest. `output` needs room for
/// `dims_x * dims_y * dims_z` values.
///
/// # Dispatch
/// The entry point is `main` with `@workgroup_size(8, 8, 4)`. To cover every grid
/// point, dispatch at least `ceil(dims_x / 8) x ceil(dims_y / 8) x ceil(dims_z / 4)`
/// workgroups. Invocations outside `dims` return without writing.
///
/// The shader does not check `n_atoms` against the lengths of `positions` or
/// `charges`; the caller is expected to keep them consistent.
pub const ESP_GRID_SHADER: &str = r#"
struct AtomPos {
    x: f32, y: f32, z: f32, _pad: f32,
};

struct GridParams {
    origin_x: f32, origin_y: f32, origin_z: f32,
    spacing: f32,
    dims_x: u32, dims_y: u32, dims_z: u32,
    n_atoms: u32,
};

@group(0) @binding(0) var<storage, read> positions: array<AtomPos>;
@group(0) @binding(1) var<storage, read> charges: array<f32>;
@group(0) @binding(2) var<uniform> params: GridParams;
@group(0) @binding(3) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(8, 8, 4)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let ix = gid.x;
    let iy = gid.y;
    let iz = gid.z;

    if (ix >= params.dims_x || iy >= params.dims_y || iz >= params.dims_z) {
        return;
    }

    let rx = params.origin_x + f32(ix) * params.spacing;
    let ry = params.origin_y + f32(iy) * params.spacing;
    let rz = params.origin_z + f32(iz) * params.spacing;
    let flat_idx = ix * params.dims_y * params.dims_z + iy * params.dims_z + iz;

    var phi: f32 = 0.0;
    for (var atom: u32 = 0u; atom < params.n_atoms; atom = atom + 1u) {
        let pos = positions[atom];
        let dx = rx - pos.x;
        let dy = ry - pos.y;
        let dz = rz - pos.z;
        let dist = sqrt(dx * dx + dy * dy + dz * dz);
        if (dist < 0.01) { continue; }
        phi += charges[atom] / (dist * 1.88972599);
    }
    output[flat_idx] = phi;
}
"#;