Skip to main content

DENSITY_GRID_SHADER

Constant DENSITY_GRID_SHADER 

Source
/// WGSL compute-shader source that evaluates
/// `rho(r) = Σ_μ Σ_ν density[μ·n_basis + ν] · φ_μ(r) · φ_ν(r)` at every point of a regular
/// 3-D grid and writes one `f32` per point.
///
/// NOTE(review): the names (`density`, `basis`, `primitives`, Cartesian powers `lx/ly/lz`)
/// suggest an electron density built from contracted Cartesian Gaussian basis functions and a
/// density matrix. This interpretation is inferred from naming only; confirm against the host code.
///
/// # Basis-function evaluation (`phi_at`)
/// For basis function `μ` centred at `(center_x, center_y, center_z)`, with `d = r - center`:
/// `φ_μ(r) = norm_coeff · dx^lx · dy^ly · dz^lz · Σ_{p < n_primitives} prim.y · exp(-prim.x · |d|²)`.
/// Each `vec2<f32>` primitive is therefore read as `(x = exponent, y = coefficient)`.
///
/// # Bindings (all in `@group(0)`)
/// - `0` `basis`: `var<storage, read> array<BasisFunc>`, with one entry per basis function.
/// - `1` `density`: `var<storage, read> array<f32>`, an `n_basis × n_basis` matrix read
///   row-major as `density[mu * n_basis + nu]`.
/// - `2` `primitives`: `var<storage, read> array<vec2<f32>>`. Primitive `p` of basis function
///   `mu` is read from index `mu * 3 + p`, i.e. a fixed stride of 3 slots per basis function.
/// - `3` `params`: `var<uniform> GridParams`.
/// - `4` `output`: `var<storage, read_write> array<f32>`. It must hold at least
///   `dims_x * dims_y * dims_z` values.
///
/// `BasisFunc` and `GridParams` each consist of eight 4-byte scalars (`f32`/`u32`) in
/// declaration order (32 bytes, no padding). Host-side structs must use the same field order.
///
/// # Dispatch
/// The entry point is `main`, with `@workgroup_size(8, 8, 4)`. Each invocation handles grid point
/// `(ix, iy, iz) = global_invocation_id`, located at `origin + (ix, iy, iz) * spacing` (a single
/// spacing for all three axes). Invocations with any index `>= dims_*` return without writing,
/// so the dispatch may be rounded up. At least
/// `ceil(dims_x / 8) × ceil(dims_y / 8) × ceil(dims_z / 4)` workgroups are needed to cover the grid.
///
/// # Output layout
/// Point `(ix, iy, iz)` is stored at `ix * dims_y * dims_z + iy * dims_z + iz`
/// (z varies fastest, x slowest).
///
/// # Notes
/// - Screening: a basis function with `|φ_μ| < 1e-7` at the point is skipped only as the
///   *outer* index. The inner `ν` loop is not screened.
/// - Cost: `φ_ν` is recomputed inside the inner loop, so each grid point performs up to
///   `n_basis + n_basis²` basis-function evaluations.
///   NOTE(review): caching the `n_basis` values would make this linear, but it needs a bounded
///   local array or an extra buffer.
/// - NOTE(review): because of the fixed primitive stride of 3, any basis function with
///   `n_primitives > 3` reads primitives belonging to the next basis function, or beyond the
///   last element for the final one. The host presumably pads every contraction to 3 slots and
///   never exceeds 3; confirm.
/// - Index arithmetic is done in `u32`. NOTE(review): `flat_idx` and the density index wrap if
///   the grid or matrix size exceeds `u32::MAX` elements.
pub const DENSITY_GRID_SHADER: &str = r#"
struct BasisFunc {
    center_x: f32, center_y: f32, center_z: f32,
    lx: u32, ly: u32, lz: u32,
    n_primitives: u32,
    norm_coeff: f32,
};

struct GridParams {
    origin_x: f32, origin_y: f32, origin_z: f32,
    spacing: f32,
    dims_x: u32, dims_y: u32, dims_z: u32,
    n_basis: u32,
};

@group(0) @binding(0) var<storage, read> basis: array<BasisFunc>;
@group(0) @binding(1) var<storage, read> density: array<f32>;
@group(0) @binding(2) var<storage, read> primitives: array<vec2<f32>>;
@group(0) @binding(3) var<uniform> params: GridParams;
@group(0) @binding(4) var<storage, read_write> output: array<f32>;

fn phi_at(mu: u32, rx: f32, ry: f32, rz: f32) -> f32 {
    let bf = basis[mu];
    let dx = rx - bf.center_x;
    let dy = ry - bf.center_y;
    let dz = rz - bf.center_z;
    let r2 = dx * dx + dy * dy + dz * dz;

    var angular: f32 = 1.0;
    for (var i: u32 = 0u; i < bf.lx; i = i + 1u) { angular *= dx; }
    for (var i: u32 = 0u; i < bf.ly; i = i + 1u) { angular *= dy; }
    for (var i: u32 = 0u; i < bf.lz; i = i + 1u) { angular *= dz; }

    var radial: f32 = 0.0;
    for (var p: u32 = 0u; p < bf.n_primitives; p = p + 1u) {
        let prim = primitives[mu * 3u + p];
        radial += prim.y * exp(-prim.x * r2);
    }
    return bf.norm_coeff * angular * radial;
}

@compute @workgroup_size(8, 8, 4)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let ix = gid.x;
    let iy = gid.y;
    let iz = gid.z;

    if (ix >= params.dims_x || iy >= params.dims_y || iz >= params.dims_z) {
        return;
    }

    let rx = params.origin_x + f32(ix) * params.spacing;
    let ry = params.origin_y + f32(iy) * params.spacing;
    let rz = params.origin_z + f32(iz) * params.spacing;
    let flat_idx = ix * params.dims_y * params.dims_z + iy * params.dims_z + iz;

    var rho: f32 = 0.0;
    for (var mu: u32 = 0u; mu < params.n_basis; mu = mu + 1u) {
        let phi_mu = phi_at(mu, rx, ry, rz);
        if (abs(phi_mu) < 1e-7) { continue; }
        for (var nu: u32 = 0u; nu < params.n_basis; nu = nu + 1u) {
            let phi_nu = phi_at(nu, rx, ry, rz);
            rho += density[mu * params.n_basis + nu] * phi_mu * phi_nu;
        }
    }

    output[flat_idx] = rho;
}
"#;