pub const BIAS_ADD: &str = "\
struct Dims { N: u32, cols: u32 }
var<push_constant> dims: Dims;
@group(0) @binding(0) var<storage, read> z: array<f32>;
@group(0) @binding(1) var<storage, read> bias: array<f32>;
@group(0) @binding(2) var<storage, read_write> out: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    if i >= dims.N { return; }
    out[i] = z[i] + bias[i % dims.cols];
}";
Bias add: broadcasts a per-column bias vector across every row of a row-major [rows, cols] matrix, computing out[i] = z[i] + bias[i % cols].
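A CPU reference is handy for checking the shader's output element by element. The sketch below assumes the row-major [rows, cols] layout that the `i % cols` indexing implies and applies the same formula in plain Rust; `bias_add_cpu` is a hypothetical helper, not part of this crate.

```rust
/// Hypothetical CPU reference for the BIAS_ADD shader, for validating GPU output.
/// Assumes `z` is a row-major [rows, cols] matrix flattened to rows * cols f32s.
fn bias_add_cpu(z: &[f32], bias: &[f32], cols: usize) -> Vec<f32> {
    assert!(cols > 0, "cols must be non-zero");
    assert_eq!(z.len() % cols, 0, "z.len() must be a multiple of cols");
    assert_eq!(bias.len(), cols, "bias must have exactly cols elements");
    // Same indexing as the shader: flat element i belongs to column i % cols.
    z.iter()
        .enumerate()
        .map(|(i, &v)| v + bias[i % cols])
        .collect()
}
```

For a 2x3 input, row 0 and row 1 both receive bias[0..3], so the second row wraps back to bias[0] at flat index 3.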
Push constants: struct Dims { N: u32, cols: u32 } (8 bytes)
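On the host, the 8-byte push-constant block can be mirrored with a `#[repr(C)]` struct whose field order matches the WGSL `Dims`. This is a minimal sketch that serializes to little-endian bytes without extra crates; the `Dims` Rust type and its `to_bytes` method are hypothetical names chosen for illustration.

```rust
/// Hypothetical host-side mirror of the shader's `Dims` push-constant struct.
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq)]
struct Dims {
    n: u32,    // total element count, rows * cols (WGSL field `N`)
    cols: u32, // row length, i.e. the bias length
}

impl Dims {
    /// Serialize as 8 little-endian bytes in the same field order as the WGSL struct.
    fn to_bytes(self) -> [u8; 8] {
        let mut out = [0u8; 8];
        out[..4].copy_from_slice(&self.n.to_le_bytes());
        out[4..].copy_from_slice(&self.cols.to_le_bytes());
        out
    }
}

// Two u32 fields with no padding: the host struct must also be exactly 8 bytes.
const _: () = assert!(std::mem::size_of::<Dims>() == 8);
```

Assuming a wgpu version that exposes push constants, bytes like these would typically be passed to `ComputePass::set_push_constants(0, &bytes)`; push constants are a native-only extension gated behind `Features::PUSH_CONSTANTS`, and the pipeline layout must declare a compute-stage range covering these 8 bytes.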
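Workgroup size: 256. Dispatch ceil(N / 256) workgroups along x (N = rows * cols); the i >= N guard makes the surplus invocations in the last workgroup return early.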
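The dispatch arithmetic can be sketched as follows. `workgroup_count` rounds N up to whole workgroups and rejects sizes beyond WebGPU's default `maxComputeWorkgroupsPerDimension` of 65535 (an assumption: a device may be configured with a different limit), and `active_invocations` is a hypothetical CPU emulation showing that the bounds guard leaves exactly N invocations doing work.

```rust
/// Workgroup size declared by the shader's `@workgroup_size(256)`.
const WORKGROUP_SIZE: u32 = 256;
/// WebGPU's default `maxComputeWorkgroupsPerDimension` limit (assumed here).
const MAX_WORKGROUPS_PER_DIM: u32 = 65_535;

/// Workgroups to dispatch along x so that at least `n` invocations run,
/// or `None` if a 1-D dispatch would exceed the per-dimension limit.
fn workgroup_count(n: u32) -> Option<u32> {
    let groups = n.div_ceil(WORKGROUP_SIZE);
    (groups <= MAX_WORKGROUPS_PER_DIM).then_some(groups)
}

/// Emulates the dispatch grid on the CPU and counts invocations that pass the guard.
fn active_invocations(n: u32) -> u32 {
    let groups = workgroup_count(n).expect("dispatch too large for one dimension");
    let mut active = 0;
    for group in 0..groups {
        for local in 0..WORKGROUP_SIZE {
            let i = group * WORKGROUP_SIZE + local; // global_invocation_id.x
            if i >= n {
                continue; // mirrors `if i >= dims.N { return; }`
            }
            active += 1;
        }
    }
    active
}
```

For N = 1000 this dispatches 4 workgroups (1024 invocations), and the last 24 exit at the guard. Problems larger than 65535 * 256 elements would need a 2-D dispatch or multiple passes, which this shader as written does not handle.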
Bindings (all in @group(0)):
- @binding(0) z: array<f32> (read), the input matrix [rows, cols]
- @binding(1) bias: array<f32> (read), the bias vector [cols]
- @binding(2) out: array<f32> (read_write), the output matrix [rows, cols]
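Buffer sizes for these bindings follow directly from the shapes, since a WGSL `array<f32>` has a 4-byte element stride. A minimal sketch, assuming tightly packed f32 data; `BufferSizes` and `buffer_sizes` are hypothetical names, not part of this crate.

```rust
/// Hypothetical byte sizes of the three storage buffers for a [rows, cols] problem.
#[derive(Debug, PartialEq)]
struct BufferSizes {
    z: u64,    // @binding(0): rows * cols f32s
    bias: u64, // @binding(1): cols f32s
    out: u64,  // @binding(2): rows * cols f32s
}

fn buffer_sizes(rows: u64, cols: u64) -> BufferSizes {
    let f32_bytes = std::mem::size_of::<f32>() as u64; // 4-byte stride of array<f32>
    BufferSizes {
        z: rows * cols * f32_bytes,
        bias: cols * f32_bytes,
        out: rows * cols * f32_bytes,
    }
}
```

Because N is passed as a u32 push constant, rows * cols must also fit in a u32 for the shader's indexing to cover the whole matrix.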