/// WGSL source for an element-wise tanh compute shader: `out[i] = tanh(input[i])`.
///
/// Interface declared by the shader text:
/// - Push constant `dims: Dims`, a struct with a single `N: u32` field used as
///   the upper bound on the element index.
/// - `@group(0) @binding(0)`: `input`, a read-only storage `array<f32>`.
/// - `@group(0) @binding(1)`: `out`, a read-write storage `array<f32>`.
/// - Entry point `main`, `@workgroup_size(256)`; each invocation handles the
///   element at `global_invocation_id.x` and returns early when that index is
///   `>= dims.N`, so a partially filled final workgroup writes nothing out of range.
///
/// NOTE(review): callers presumably dispatch ceil(N / 256) workgroups along x —
/// confirm against the dispatch site, which is not visible here.
///
/// The `\` right after the opening quote is a Rust line continuation, so the
/// string begins at `struct Dims` rather than with a newline.
pub const TANH: &str = "\
struct Dims { N: u32 }
var<push_constant> dims: Dims;
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> out: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let i = gid.x;
if i >= dims.N { return; }
out[i] = tanh(input[i]);
}";Expand description
Tanh activation: out[i] = tanh(input[i]).
Push constants: struct Dims { N: u32 } (4 bytes)
Workgroup size: 256 — dispatch ceil(N / 256) workgroups so that at least N invocations run; invocations with index >= N exit without writing.
Bindings:
- @binding(0) input: array<f32> (read)
- @binding(1) out: array<f32> (read_write)