

Constant BIAS_ADD 

pub const BIAS_ADD: &str = "\
struct Dims { N: u32, cols: u32 }
var<push_constant> dims: Dims;

@group(0) @binding(0) var<storage, read> z: array<f32>;
@group(0) @binding(1) var<storage, read> bias: array<f32>;
@group(0) @binding(2) var<storage, read_write> out: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    if i >= dims.N { return; }
    out[i] = z[i] + bias[i % dims.cols];
}";
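The shader body can be mirrored on the CPU to show why the `if i >= dims.N { return; }` guard is needed. Dispatching whole workgroups of 256 usually launches more invocations than N, and the extra ones must not touch memory. This is a minimal sketch that assumes a one-dimensional dispatch of ceil(N / 256) workgroups. `emulate_dispatch` and `WORKGROUP_SIZE` are hypothetical names and are not part of this crate.

```rust
/// Workgroup size declared by `@workgroup_size(256)` in the shader.
const WORKGROUP_SIZE: u32 = 256;

/// Hypothetical CPU emulation of a 1-D dispatch of the BIAS_ADD shader.
/// Returns (output, number of workgroups dispatched).
fn emulate_dispatch(z: &[f32], bias: &[f32], cols: u32) -> (Vec<f32>, u32) {
    let n = z.len() as u32; // dims.N = rows * cols
    // Round up so every element is covered by some invocation.
    let workgroups = n.div_ceil(WORKGROUP_SIZE);
    let mut out = vec![0.0f32; z.len()];
    for wg in 0..workgroups {
        for local in 0..WORKGROUP_SIZE {
            // global_invocation_id.x = workgroup_id.x * 256 + local_invocation_id.x
            let i = wg * WORKGROUP_SIZE + local;
            // Same early exit as `if i >= dims.N { return; }`:
            // padding invocations past the end must not touch memory.
            if i >= n {
                continue;
            }
            let i = i as usize;
            out[i] = z[i] + bias[i % cols as usize];
        }
    }
    (out, workgroups)
}

fn main() {
    // 2 rows x 3 cols, row-major.
    let z: [f32; 6] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let bias: [f32; 3] = [10.0, 20.0, 30.0];
    let (out, workgroups) = emulate_dispatch(&z, &bias, 3);
    // prints [11.0, 22.0, 33.0, 14.0, 25.0, 36.0] using 1 workgroup(s)
    println!("{out:?} using {workgroups} workgroup(s)");
}
```

With N = 6, a single workgroup launches 256 invocations, and 250 of them fall past the end of the buffers. Those are exactly the invocations the guard discards.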

Bias add: out[i] = z[i] + bias[i % cols].
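Because z is stored row-major, the flat index is i = r * cols + c, so i % cols = c. The formula therefore adds the same bias vector to every row. The sketch below restates the operation in row/column form as a reference implementation. `bias_add_rows` is a hypothetical helper and is not part of this crate.

```rust
/// Hypothetical reference implementation of BIAS_ADD in row/column form:
/// out[r][c] = z[r][c] + bias[c] for a row-major [rows, cols] matrix.
fn bias_add_rows(z: &[f32], bias: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    assert_eq!(z.len(), rows * cols, "z must hold rows * cols elements");
    assert_eq!(bias.len(), cols, "bias must hold cols elements");
    let mut out = Vec::with_capacity(rows * cols);
    for r in 0..rows {
        for c in 0..cols {
            // Flat row-major index; here i % cols == c, which is exactly
            // the column the shader reads via bias[i % dims.cols].
            let i = r * cols + c;
            debug_assert_eq!(i % cols, c);
            out.push(z[i] + bias[c]);
        }
    }
    out
}

fn main() {
    let z: [f32; 6] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let bias: [f32; 3] = [0.5, -1.0, 2.0];
    // prints [1.5, 1.0, 5.0, 4.5, 4.0, 8.0]
    println!("{:?}", bias_add_rows(&z, &bias, 2, 3));
}
```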

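Push constants: struct Dims { N: u32, cols: u32 } (8 bytes), where N = rows * cols.
Workgroup size: 256. Dispatch ceil(N / 256) workgroups along x so that at least N invocations run; invocations with index i >= N return without writing.
Bindings: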

  • @binding(0) z: array<f32> (read): input matrix [rows, cols], stored row-major
  • @binding(1) bias: array<f32> (read): bias vector [cols], added to every row
  • @binding(2) out: array<f32> (read_write): output [rows, cols], same layout as z
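On the host side, the 8-byte push-constant block and the three storage buffers must match the layout listed above. The sketch below uses hypothetical names (`Dims`, `to_push_constant_bytes`, `buffer_sizes`). It assumes the two u32 fields are serialized in declaration order as little-endian bytes, which is the byte order of common GPU hosts. If the crate drives the shader through wgpu, these bytes would go to `ComputePass::set_push_constants`, which requires `Features::PUSH_CONSTANTS`. That wiring is an assumption and is not shown here.

```rust
/// Host-side mirror of the shader's `struct Dims { N: u32, cols: u32 }`.
/// `#[repr(C)]` keeps field order; two u32 fields give 8 bytes, no padding.
#[repr(C)]
#[derive(Clone, Copy, Debug)]
struct Dims {
    n: u32, // N = rows * cols
    cols: u32,
}

impl Dims {
    fn new(rows: u32, cols: u32) -> Self {
        Dims { n: rows * cols, cols }
    }

    /// Serialize in field order as little-endian u32s (assumed GPU byte order).
    fn to_push_constant_bytes(self) -> [u8; 8] {
        let mut bytes = [0u8; 8];
        bytes[0..4].copy_from_slice(&self.n.to_le_bytes());
        bytes[4..8].copy_from_slice(&self.cols.to_le_bytes());
        bytes
    }
}

/// Byte sizes of bindings 0 (z), 1 (bias) and 2 (out), all f32 arrays.
fn buffer_sizes(dims: Dims) -> (u64, u64, u64) {
    let f32_size = std::mem::size_of::<f32>() as u64;
    let elems = u64::from(dims.n);
    (elems * f32_size, u64::from(dims.cols) * f32_size, elems * f32_size)
}

fn main() {
    let dims = Dims::new(4, 3); // 4 rows, 3 columns, so N = 12
    println!("{:?}", dims.to_push_constant_bytes()); // [12, 0, 0, 0, 3, 0, 0, 0]
    println!("{:?}", buffer_sizes(dims)); // (48, 12, 48)
    println!("{}", std::mem::size_of::<Dims>()); // 8
}
```

Keeping `Dims` as `#[repr(C)]` with only u32 fields means its size matches the 8 bytes the shader expects, with no implicit padding to account for.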