// scry_gpu/shaders.rs

// SPDX-License-Identifier: MIT OR Apache-2.0
//! Shared WGSL shader sources for reuse across crates.
//!
//! Each constant is a complete WGSL shader string ready to pass to
//! [`Device::compile`](crate::Device::compile). Push constant layouts
//! and workgroup sizes are documented per shader.

/// Matrix multiplication shaders.
///
/// Each shader is available as a WGSL constant (for Vulkan dispatch) and,
/// when the `cuda` feature is enabled, as a CUDA C constant (for NVRTC
/// compilation via [`Device::compile_cuda`](crate::Device::compile_cuda)).
///
/// For CUDA matmul, prefer [`Device::cublas_matmul`](crate::Device::cublas_matmul)
/// over custom kernels — cuBLAS typically reaches 80%+ of peak throughput out of the box.
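///
/// # Example
///
/// A minimal CPU reference (plain `std` Rust, no crate API assumed) that
/// any of the kernels below can be validated against:
///
/// ```
/// // Row-major C = A (m x k) * B (k x n), matching the shaders' data layout.
/// fn matmul_ref(a: &[f32], b: &[f32], m: usize, n: usize, k: usize) -> Vec<f32> {
///     let mut c = vec![0.0f32; m * n];
///     for i in 0..m {
///         for kk in 0..k {
///             let a_ik = a[i * k + kk];
///             for j in 0..n {
///                 c[i * n + j] += a_ik * b[kk * n + j];
///             }
///         }
///     }
///     c
/// }
///
/// let a = [1.0, 2.0, 3.0, 4.0]; // 2x2
/// let b = [5.0, 6.0, 7.0, 8.0]; // 2x2
/// assert_eq!(matmul_ref(&a, &b, 2, 2, 2), vec![19.0, 22.0, 43.0, 50.0]);
/// ```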
pub mod matmul {
    /// Tiled matmul: 16x16 shared-memory tiles, 1 element per thread.
    ///
    /// **Push constants:** `struct Dims { M: u32, N: u32, K: u32 }` (12 bytes)
    /// **Workgroup size:** `(16, 16)` — dispatch `[N.div_ceil(16), M.div_ceil(16), 1]`
    /// **Shared memory:** 2 x 256 floats (2 KB)
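    ///
    /// Host-side sketch of the dispatch math. This is an illustration only:
    /// it assumes the usual little-endian packing of three consecutive `u32`s
    /// for the push constants, and the actual compile/dispatch calls depend
    /// on your `Device` wrapper and are not shown.
    ///
    /// ```
    /// let (m, n, k) = (1000u32, 1000u32, 1000u32);
    /// // Push constants: M, N, K packed as 12 little-endian bytes.
    /// let mut push = Vec::new();
    /// for v in [m, n, k] {
    ///     push.extend_from_slice(&v.to_le_bytes());
    /// }
    /// assert_eq!(push.len(), 12);
    /// // One 16x16 workgroup per 16x16 output tile.
    /// assert_eq!([n.div_ceil(16), m.div_ceil(16), 1], [63, 63, 1]);
    /// ```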
    pub const TILED_16X16: &str = "\
struct Dims { M: u32, N: u32, K: u32 }
var<push_constant> dims: Dims;

@group(0) @binding(0) var<storage, read> A: array<f32>;
@group(0) @binding(1) var<storage, read> B: array<f32>;
@group(0) @binding(2) var<storage, read_write> C: array<f32>;

var<workgroup> tile_a: array<f32, 256>;
var<workgroup> tile_b: array<f32, 256>;

@compute @workgroup_size(16, 16)
fn main(
    @builtin(local_invocation_id) lid: vec3<u32>,
    @builtin(workgroup_id) wid: vec3<u32>,
) {
    let row = wid.y * 16u + lid.y;
    let col = wid.x * 16u + lid.x;
    let lr = lid.y;
    let lc = lid.x;

    var sum = 0.0;
    let num_tiles = (dims.K + 15u) / 16u;

    for (var t = 0u; t < num_tiles; t++) {
        let a_col = t * 16u + lc;
        if row < dims.M && a_col < dims.K {
            tile_a[lr * 16u + lc] = A[row * dims.K + a_col];
        } else {
            tile_a[lr * 16u + lc] = 0.0;
        }

        let b_row = t * 16u + lr;
        if b_row < dims.K && col < dims.N {
            tile_b[lr * 16u + lc] = B[b_row * dims.N + col];
        } else {
            tile_b[lr * 16u + lc] = 0.0;
        }

        workgroupBarrier();

        for (var k = 0u; k < 16u; k++) {
            sum += tile_a[lr * 16u + k] * tile_b[k * 16u + lc];
        }

        workgroupBarrier();
    }

    if row < dims.M && col < dims.N {
        C[row * dims.N + col] = sum;
    }
}";

    /// CUDA C equivalent of [`TILED_16X16`].
    ///
    /// **Kernel signature:** `matmul_tiled_16x16(const float* A, const float* B, float* C, unsigned int M, unsigned int N, unsigned int K)`
    /// **Block size:** `(16, 16)` — dispatch `[N.div_ceil(16), M.div_ceil(16), 1]`
    #[cfg(feature = "cuda")]
    pub const TILED_16X16_CUDA: &str = "\
extern \"C\" __global__ void matmul_tiled_16x16(
    const float* A, const float* B, float* C,
    unsigned int M, unsigned int N, unsigned int K
) {
    __shared__ float tile_a[256];
    __shared__ float tile_b[256];

    unsigned int row = blockIdx.y * 16 + threadIdx.y;
    unsigned int col = blockIdx.x * 16 + threadIdx.x;
    unsigned int lr = threadIdx.y;
    unsigned int lc = threadIdx.x;

    float sum = 0.0f;
    unsigned int num_tiles = (K + 15) / 16;

    for (unsigned int t = 0; t < num_tiles; t++) {
        unsigned int a_col = t * 16 + lc;
        tile_a[lr * 16 + lc] = (row < M && a_col < K) ? A[row * K + a_col] : 0.0f;

        unsigned int b_row = t * 16 + lr;
        tile_b[lr * 16 + lc] = (b_row < K && col < N) ? B[b_row * N + col] : 0.0f;

        __syncthreads();

        for (unsigned int k = 0; k < 16; k++) {
            sum += tile_a[lr * 16 + k] * tile_b[k * 16 + lc];
        }

        __syncthreads();
    }

    if (row < M && col < N) {
        C[row * N + col] = sum;
    }
}";

    /// Thread-coarsened matmul: 64x64 output tile, each thread computes 4x4.
    ///
    /// **Push constants:** `struct Dims { M: u32, N: u32, K: u32 }` (12 bytes)
    /// **Workgroup size:** `(16, 16)` = 256 threads, each owns a 4x4 output block.
    /// **Dispatch:** `[N.div_ceil(64), M.div_ceil(64), 1]`
    /// **Shared memory:** A\[64x(16+1)\] + B\[16x64\] = 2112 floats (≈8.4 KB; A padded
    /// to stride 17 to eliminate bank conflicts).
    /// **Arithmetic intensity:** 16 FMAs per 8 shared-memory loads = 2 FMA/load
    /// (4x over the simple tiled kernel).
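    ///
    /// Why stride 17: with a stride-16 A tile, the column reads from
    /// successive rows cycle through only a couple of shared-memory bank
    /// slots and serialize (assuming the common 32-bank layout); one padding
    /// float per row staggers them. A quick illustration, where bank is the
    /// word index mod 32:
    ///
    /// ```
    /// let bank = |r: u32, stride: u32| (r * stride) % 32;
    /// let b16: Vec<u32> = (0..8u32).map(|r| bank(r, 16)).collect();
    /// let b17: Vec<u32> = (0..8u32).map(|r| bank(r, 17)).collect();
    /// assert_eq!(b16, [0, 16, 0, 16, 0, 16, 0, 16]); // two banks, 4-way conflict
    /// assert_eq!(b17, [0, 17, 2, 19, 4, 21, 6, 23]); // eight distinct banks
    /// ```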
    pub const COARSE_64X64: &str = "\
struct Dims { M: u32, N: u32, K: u32 }
var<push_constant> dims: Dims;

@group(0) @binding(0) var<storage, read> A: array<f32>;
@group(0) @binding(1) var<storage, read> B: array<f32>;
@group(0) @binding(2) var<storage, read_write> C: array<f32>;

var<workgroup> sa: array<f32, 1088>;
var<workgroup> sb: array<f32, 1024>;

@compute @workgroup_size(16, 16)
fn main(
    @builtin(local_invocation_id) lid: vec3<u32>,
    @builtin(local_invocation_index) li: u32,
    @builtin(workgroup_id) wid: vec3<u32>,
) {
    let block_row = wid.y * 64u;
    let block_col = wid.x * 64u;
    let tr = lid.y * 4u;
    let tc = lid.x * 4u;

    var acc: array<f32, 16>;
    for (var i = 0u; i < 16u; i++) { acc[i] = 0.0; }

    let num_k_tiles = (dims.K + 15u) / 16u;

    for (var kt = 0u; kt < num_k_tiles; kt++) {
        // Load A tile [64x16] into padded layout (stride 17)
        for (var x = 0u; x < 4u; x++) {
            let flat = li * 4u + x;
            let r = flat / 16u;
            let c = flat % 16u;
            let gr = block_row + r;
            let gc = kt * 16u + c;
            if gr < dims.M && gc < dims.K {
                sa[r * 17u + c] = A[gr * dims.K + gc];
            } else {
                sa[r * 17u + c] = 0.0;
            }
        }

        // Load B tile [16x64]
        for (var x = 0u; x < 4u; x++) {
            let flat = li * 4u + x;
            let r = flat / 64u;
            let c = flat % 64u;
            let gr = kt * 16u + r;
            let gc = block_col + c;
            if gr < dims.K && gc < dims.N {
                sb[flat] = B[gr * dims.N + gc];
            } else {
                sb[flat] = 0.0;
            }
        }

        workgroupBarrier();

        for (var k = 0u; k < 16u; k++) {
            for (var i = 0u; i < 4u; i++) {
                let a_val = sa[(tr + i) * 17u + k];
                for (var j = 0u; j < 4u; j++) {
                    acc[i * 4u + j] += a_val * sb[k * 64u + tc + j];
                }
            }
        }

        workgroupBarrier();
    }

    for (var i = 0u; i < 4u; i++) {
        for (var j = 0u; j < 4u; j++) {
            let gr = block_row + tr + i;
            let gc = block_col + tc + j;
            if gr < dims.M && gc < dims.N {
                C[gr * dims.N + gc] = acc[i * 4u + j];
            }
        }
    }
}";

    /// Thread-coarsened matmul: 128x128 tile, 8x8 per thread with vec4 accumulators.
    ///
    /// Uses 16 named `vec4<f32>` accumulator variables instead of `array<f32, 64>`
    /// to avoid NVIDIA SPIR-V register spill (which triggers at `array<f32, 32+>`).
    /// Grouping B-tile values into `vec4` operands also cuts the number of
    /// load and FMA instructions the backend has to issue.
    ///
    /// **Push constants:** `struct Dims { M: u32, N: u32, K: u32 }` (12 bytes)
    /// **Workgroup size:** `(16, 16)` = 256 threads, each owns an 8×8 output block.
    /// **Dispatch:** `[N.div_ceil(128), M.div_ceil(128), 1]`
    /// **Shared memory:** A\[128×(16+1)\] + B\[16×128\] = 4224 floats (≈16.9 KB)
    /// **Arithmetic intensity:** 64 FMAs per (8+2) loads ≈ 6.4 FMA/load (3.2× over 4×4).
    pub const COARSE_8X8: &str = "\
struct Dims { M: u32, N: u32, K: u32 }
var<push_constant> dims: Dims;

@group(0) @binding(0) var<storage, read> A: array<f32>;
@group(0) @binding(1) var<storage, read> B: array<f32>;
@group(0) @binding(2) var<storage, read_write> C: array<f32>;

var<workgroup> sa: array<f32, 2176>;
var<workgroup> sb: array<f32, 2048>;

fn store_row(gr: u32, gc: u32, lo: vec4<f32>, hi: vec4<f32>) {
    if gr >= dims.M { return; }
    let base = gr * dims.N + gc;
    if gc < dims.N { C[base] = lo.x; }
    if gc + 1u < dims.N { C[base + 1u] = lo.y; }
    if gc + 2u < dims.N { C[base + 2u] = lo.z; }
    if gc + 3u < dims.N { C[base + 3u] = lo.w; }
    if gc + 4u < dims.N { C[base + 4u] = hi.x; }
    if gc + 5u < dims.N { C[base + 5u] = hi.y; }
    if gc + 6u < dims.N { C[base + 6u] = hi.z; }
    if gc + 7u < dims.N { C[base + 7u] = hi.w; }
}

@compute @workgroup_size(16, 16)
fn main(
    @builtin(local_invocation_id) lid: vec3<u32>,
    @builtin(local_invocation_index) li: u32,
    @builtin(workgroup_id) wid: vec3<u32>,
) {
    let block_row = wid.y * 128u;
    let block_col = wid.x * 128u;
    let tr = lid.y * 8u;
    let tc = lid.x * 8u;

    // 16 named vec4 accumulators — avoids array-based register spill.
    var r0l = vec4<f32>(0.0); var r0h = vec4<f32>(0.0);
    var r1l = vec4<f32>(0.0); var r1h = vec4<f32>(0.0);
    var r2l = vec4<f32>(0.0); var r2h = vec4<f32>(0.0);
    var r3l = vec4<f32>(0.0); var r3h = vec4<f32>(0.0);
    var r4l = vec4<f32>(0.0); var r4h = vec4<f32>(0.0);
    var r5l = vec4<f32>(0.0); var r5h = vec4<f32>(0.0);
    var r6l = vec4<f32>(0.0); var r6h = vec4<f32>(0.0);
    var r7l = vec4<f32>(0.0); var r7h = vec4<f32>(0.0);

    let num_k_tiles = (dims.K + 15u) / 16u;

    for (var kt = 0u; kt < num_k_tiles; kt++) {
        // Load A tile [128x16] — 2048 elements, 8 per thread, padded stride 17
        for (var x = 0u; x < 8u; x++) {
            let flat = li * 8u + x;
            let r = flat / 16u;
            let c = flat % 16u;
            let gr = block_row + r;
            let gc = kt * 16u + c;
            if gr < dims.M && gc < dims.K {
                sa[r * 17u + c] = A[gr * dims.K + gc];
            } else {
                sa[r * 17u + c] = 0.0;
            }
        }

        // Load B tile [16x128] — 2048 elements, 8 per thread
        for (var x = 0u; x < 8u; x++) {
            let flat = li * 8u + x;
            let r = flat / 128u;
            let c = flat % 128u;
            let gr = kt * 16u + r;
            let gc = block_col + c;
            if gr < dims.K && gc < dims.N {
                sb[flat] = B[gr * dims.N + gc];
            } else {
                sb[flat] = 0.0;
            }
        }

        workgroupBarrier();

        // Inner loop: 8 a-scalar loads + 2 vec4 b-loads + 16 vec4 FMAs per k
        for (var k = 0u; k < 16u; k++) {
            let bk = k * 128u + tc;
            let bl = vec4<f32>(sb[bk], sb[bk+1u], sb[bk+2u], sb[bk+3u]);
            let bh = vec4<f32>(sb[bk+4u], sb[bk+5u], sb[bk+6u], sb[bk+7u]);

            let a0 = sa[(tr    ) * 17u + k]; r0l += a0 * bl; r0h += a0 * bh;
            let a1 = sa[(tr+1u) * 17u + k]; r1l += a1 * bl; r1h += a1 * bh;
            let a2 = sa[(tr+2u) * 17u + k]; r2l += a2 * bl; r2h += a2 * bh;
            let a3 = sa[(tr+3u) * 17u + k]; r3l += a3 * bl; r3h += a3 * bh;
            let a4 = sa[(tr+4u) * 17u + k]; r4l += a4 * bl; r4h += a4 * bh;
            let a5 = sa[(tr+5u) * 17u + k]; r5l += a5 * bl; r5h += a5 * bh;
            let a6 = sa[(tr+6u) * 17u + k]; r6l += a6 * bl; r6h += a6 * bh;
            let a7 = sa[(tr+7u) * 17u + k]; r7l += a7 * bl; r7h += a7 * bh;
        }

        workgroupBarrier();
    }

    let gc = block_col + tc;
    store_row(block_row + tr,      gc, r0l, r0h);
    store_row(block_row + tr + 1u, gc, r1l, r1h);
    store_row(block_row + tr + 2u, gc, r2l, r2h);
    store_row(block_row + tr + 3u, gc, r3l, r3h);
    store_row(block_row + tr + 4u, gc, r4l, r4h);
    store_row(block_row + tr + 5u, gc, r5l, r5h);
    store_row(block_row + tr + 6u, gc, r6l, r6h);
    store_row(block_row + tr + 7u, gc, r7l, r7h);
}";
}

/// Element-wise activation and bias shaders.
///
/// All shaders use workgroup size 256 (1D), take at least an `N: u32` push
/// constant for bounds checking, and process one element per thread.
pub mod elementwise {
    /// Bias add: `out[i] = z[i] + bias[i % cols]`.
    ///
    /// **Push constants:** `struct Dims { N: u32, cols: u32 }` (8 bytes)
    /// **Workgroup size:** 256 — dispatch `N` invocations (N = rows * cols)
    /// **Bindings:**
    ///   - `@binding(0)` `z: array<f32>` (read) — input matrix `[rows, cols]`
    ///   - `@binding(1)` `bias: array<f32>` (read) — bias vector `[cols]`
    ///   - `@binding(2)` `out: array<f32>` (`read_write`) — output `[rows, cols]`
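    ///
    /// Host-side sketch of the 1D dispatch math (illustrative only; it
    /// assumes little-endian push-constant packing, and the dispatch call
    /// itself depends on your `Device` wrapper):
    ///
    /// ```
    /// let (rows, cols) = (32u32, 128u32);
    /// let n = rows * cols;
    /// // Push constants: N, then cols — 8 little-endian bytes.
    /// let mut push = Vec::new();
    /// push.extend_from_slice(&n.to_le_bytes());
    /// push.extend_from_slice(&cols.to_le_bytes());
    /// assert_eq!(push.len(), 8);
    /// // One workgroup per 256 elements.
    /// assert_eq!(n.div_ceil(256), 16);
    /// ```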
    pub const BIAS_ADD: &str = "\
struct Dims { N: u32, cols: u32 }
var<push_constant> dims: Dims;

@group(0) @binding(0) var<storage, read> z: array<f32>;
@group(0) @binding(1) var<storage, read> bias: array<f32>;
@group(0) @binding(2) var<storage, read_write> out: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    if i >= dims.N { return; }
    out[i] = z[i] + bias[i % dims.cols];
}";

    /// CUDA C equivalent of [`BIAS_ADD`].
    #[cfg(feature = "cuda")]
    pub const BIAS_ADD_CUDA: &str = "\
extern \"C\" __global__ void bias_add(
    const float* z, const float* bias, float* out,
    unsigned int N, unsigned int cols
) {
    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= N) return;
    out[i] = z[i] + bias[i % cols];
}";

    /// `ReLU` activation: `out[i] = max(0, in[i])`.
    ///
    /// **Push constants:** `struct Dims { N: u32 }` (4 bytes)
    /// **Workgroup size:** 256 — dispatch `N` invocations
    /// **Bindings:**
    ///   - `@binding(0)` `input: array<f32>` (read)
    ///   - `@binding(1)` `out: array<f32>` (`read_write`)
    pub const RELU: &str = "\
struct Dims { N: u32 }
var<push_constant> dims: Dims;

@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> out: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    if i >= dims.N { return; }
    out[i] = max(0.0, input[i]);
}";

    /// CUDA C equivalent of [`RELU`].
    #[cfg(feature = "cuda")]
    pub const RELU_CUDA: &str = "\
extern \"C\" __global__ void relu(
    const float* input, float* out,
    unsigned int N
) {
    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= N) return;
    out[i] = fmaxf(0.0f, input[i]);
}";

    /// Tanh activation: `out[i] = tanh(in[i])`.
    ///
    /// **Push constants:** `struct Dims { N: u32 }` (4 bytes)
    /// **Workgroup size:** 256 — dispatch `N` invocations
    /// **Bindings:**
    ///   - `@binding(0)` `input: array<f32>` (read)
    ///   - `@binding(1)` `out: array<f32>` (`read_write`)
    pub const TANH: &str = "\
struct Dims { N: u32 }
var<push_constant> dims: Dims;

@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> out: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    if i >= dims.N { return; }
    out[i] = tanh(input[i]);
}";

    /// CUDA C equivalent of [`TANH`].
    #[cfg(feature = "cuda")]
    pub const TANH_CUDA: &str = "\
extern \"C\" __global__ void tanh_fwd(
    const float* input, float* out,
    unsigned int N
) {
    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= N) return;
    out[i] = tanhf(input[i]);
}";

    /// Sigmoid activation: `out[i] = 1 / (1 + exp(-in[i]))`.
    ///
    /// **Push constants:** `struct Dims { N: u32 }` (4 bytes)
    /// **Workgroup size:** 256 — dispatch `N` invocations
    /// **Bindings:**
    ///   - `@binding(0)` `input: array<f32>` (read)
    ///   - `@binding(1)` `out: array<f32>` (`read_write`)
    pub const SIGMOID: &str = "\
struct Dims { N: u32 }
var<push_constant> dims: Dims;

@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> out: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    if i >= dims.N { return; }
    out[i] = 1.0 / (1.0 + exp(-input[i]));
}";

    /// CUDA C equivalent of [`SIGMOID`].
    #[cfg(feature = "cuda")]
    pub const SIGMOID_CUDA: &str = "\
extern \"C\" __global__ void sigmoid(
    const float* input, float* out,
    unsigned int N
) {
    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= N) return;
    out[i] = 1.0f / (1.0f + expf(-input[i]));
}";
}

/// Backward activation and utility shaders for backpropagation.
///
/// All shaders use workgroup size 256 (1D) and follow the same dispatch
/// pattern as the [`elementwise`] forward shaders.
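///
/// # Example
///
/// The analytic derivatives below can be spot-checked against a central
/// finite difference on the CPU (pure `std` Rust, no crate API assumed):
///
/// ```
/// // d/dz sigmoid(z) should equal a * (1 - a) with a = sigmoid(z).
/// let sigmoid = |z: f64| 1.0 / (1.0 + (-z).exp());
/// let (z, h) = (0.7f64, 1e-3f64);
/// let numeric = (sigmoid(z + h) - sigmoid(z - h)) / (2.0 * h);
/// let a = sigmoid(z);
/// assert!((numeric - a * (1.0 - a)).abs() < 1e-6);
/// ```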
pub mod backward {
    /// `ReLU` backward: `out[i] = grad[i] * (z[i] > 0 ? 1 : 0)`.
    ///
    /// Uses the pre-activation value `z` (not the activated output).
    ///
    /// **Push constants:** `struct Dims { N: u32 }` (4 bytes)
    /// **Workgroup size:** 256 — dispatch `N` invocations
    /// **Bindings:**
    ///   - `@binding(0)` `grad: array<f32>` (read) — upstream gradient
    ///   - `@binding(1)` `z: array<f32>` (read) — pre-activation values
    ///   - `@binding(2)` `out: array<f32>` (`read_write`) — output gradient
    pub const RELU_BACKWARD: &str = "\
struct Dims { N: u32 }
var<push_constant> dims: Dims;

@group(0) @binding(0) var<storage, read> grad: array<f32>;
@group(0) @binding(1) var<storage, read> z: array<f32>;
@group(0) @binding(2) var<storage, read_write> out: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    if i >= dims.N { return; }
    out[i] = select(0.0, grad[i], z[i] > 0.0);
}";

    /// CUDA C equivalent of [`RELU_BACKWARD`].
    #[cfg(feature = "cuda")]
    pub const RELU_BACKWARD_CUDA: &str = "\
extern \"C\" __global__ void relu_backward(
    const float* grad, const float* z, float* out,
    unsigned int N
) {
    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= N) return;
    out[i] = z[i] > 0.0f ? grad[i] : 0.0f;
}";

    /// Sigmoid backward: `out[i] = grad[i] * a[i] * (1 - a[i])`.
    ///
    /// Uses the post-activation value `a = sigmoid(z)`.
    ///
    /// **Push constants:** `struct Dims { N: u32 }` (4 bytes)
    /// **Workgroup size:** 256 — dispatch `N` invocations
    /// **Bindings:**
    ///   - `@binding(0)` `grad: array<f32>` (read) — upstream gradient
    ///   - `@binding(1)` `activated: array<f32>` (read) — post-activation values
    ///   - `@binding(2)` `out: array<f32>` (`read_write`) — output gradient
    pub const SIGMOID_BACKWARD: &str = "\
struct Dims { N: u32 }
var<push_constant> dims: Dims;

@group(0) @binding(0) var<storage, read> grad: array<f32>;
@group(0) @binding(1) var<storage, read> activated: array<f32>;
@group(0) @binding(2) var<storage, read_write> out: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    if i >= dims.N { return; }
    let a = activated[i];
    out[i] = grad[i] * a * (1.0 - a);
}";

    /// CUDA C equivalent of [`SIGMOID_BACKWARD`].
    #[cfg(feature = "cuda")]
    pub const SIGMOID_BACKWARD_CUDA: &str = "\
extern \"C\" __global__ void sigmoid_backward(
    const float* grad, const float* activated, float* out,
    unsigned int N
) {
    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= N) return;
    float a = activated[i];
    out[i] = grad[i] * a * (1.0f - a);
}";

    /// Tanh backward: `out[i] = grad[i] * (1 - a[i]^2)`.
    ///
    /// Uses the post-activation value `a = tanh(z)`.
    ///
    /// **Push constants:** `struct Dims { N: u32 }` (4 bytes)
    /// **Workgroup size:** 256 — dispatch `N` invocations
    /// **Bindings:**
    ///   - `@binding(0)` `grad: array<f32>` (read) — upstream gradient
    ///   - `@binding(1)` `activated: array<f32>` (read) — post-activation values
    ///   - `@binding(2)` `out: array<f32>` (`read_write`) — output gradient
    pub const TANH_BACKWARD: &str = "\
struct Dims { N: u32 }
var<push_constant> dims: Dims;

@group(0) @binding(0) var<storage, read> grad: array<f32>;
@group(0) @binding(1) var<storage, read> activated: array<f32>;
@group(0) @binding(2) var<storage, read_write> out: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    if i >= dims.N { return; }
    let a = activated[i];
    out[i] = grad[i] * (1.0 - a * a);
}";

    /// CUDA C equivalent of [`TANH_BACKWARD`].
    #[cfg(feature = "cuda")]
    pub const TANH_BACKWARD_CUDA: &str = "\
extern \"C\" __global__ void tanh_backward(
    const float* grad, const float* activated, float* out,
    unsigned int N
) {
    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= N) return;
    float a = activated[i];
    out[i] = grad[i] * (1.0f - a * a);
}";

    /// Matrix transpose: `out[col * rows + row] = in[row * cols + col]`.
    ///
    /// Transposes a row-major `[rows, cols]` matrix to `[cols, rows]`.
    /// Each thread handles one element.
    ///
    /// **Push constants:** `struct Dims { rows: u32, cols: u32 }` (8 bytes)
    /// **Workgroup size:** 256 — dispatch `rows * cols` invocations
    /// **Bindings:**
    ///   - `@binding(0)` `input: array<f32>` (read)
    ///   - `@binding(1)` `out: array<f32>` (`read_write`)
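    ///
    /// CPU reference for validation (pure `std` Rust); transposing twice
    /// round-trips:
    ///
    /// ```
    /// fn transpose_ref(input: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    ///     let mut out = vec![0.0f32; rows * cols];
    ///     for i in 0..rows * cols {
    ///         let (row, col) = (i / cols, i % cols);
    ///         out[col * rows + row] = input[i];
    ///     }
    ///     out
    /// }
    ///
    /// let m = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; // 2x3, row-major
    /// let t = transpose_ref(&m, 2, 3);        // 3x2
    /// assert_eq!(t, [1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    /// assert_eq!(transpose_ref(&t, 3, 2), m);
    /// ```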
    pub const TRANSPOSE: &str = "\
struct Dims { rows: u32, cols: u32 }
var<push_constant> dims: Dims;

@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> out: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    let n = dims.rows * dims.cols;
    if i >= n { return; }
    let row = i / dims.cols;
    let col = i % dims.cols;
    out[col * dims.rows + row] = input[i];
}";

    /// CUDA C equivalent of [`TRANSPOSE`].
    #[cfg(feature = "cuda")]
    pub const TRANSPOSE_CUDA: &str = "\
extern \"C\" __global__ void transpose_2d(
    const float* input, float* out,
    unsigned int rows, unsigned int cols
) {
    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= rows * cols) return;
    unsigned int row = i / cols;
    unsigned int col = i % cols;
    out[col * rows + row] = input[i];
}";

    /// Element-wise scale: `out[i] = in[i] * alpha`.
    ///
    /// **Push constants:** `struct Dims { N: u32, alpha: f32 }` (8 bytes)
    /// **Workgroup size:** 256 — dispatch `N` invocations
    /// **Bindings:**
    ///   - `@binding(0)` `input: array<f32>` (read)
    ///   - `@binding(1)` `out: array<f32>` (`read_write`)
    pub const SCALE: &str = "\
struct Dims { N: u32, alpha: f32 }
var<push_constant> dims: Dims;

@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> out: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    if i >= dims.N { return; }
    out[i] = input[i] * dims.alpha;
}";

    /// CUDA C equivalent of [`SCALE`].
    #[cfg(feature = "cuda")]
    pub const SCALE_CUDA: &str = "\
extern \"C\" __global__ void scale_fwd(
    const float* input, float* out,
    unsigned int N, float alpha
) {
    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= N) return;
    out[i] = input[i] * alpha;
}";

    /// Column-wise reduction: `out[j] = sum_i(in[i * cols + j]) * scale`.
    ///
    /// Sums over the row (batch) dimension for each column, then scales.
    /// Used for bias gradient computation: `db = reduce_cols(delta, 1/batch)`.
    ///
    /// **Push constants:** `struct Dims { rows: u32, cols: u32, scale: f32 }` (12 bytes)
    /// **Workgroup size:** 256 — dispatch `cols` invocations
    /// **Bindings:**
    ///   - `@binding(0)` `input: array<f32>` (read) — `[rows, cols]` matrix
    ///   - `@binding(1)` `out: array<f32>` (`read_write`) — `[cols]` vector
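    ///
    /// CPU reference matching the bias-gradient use above (pure `std` Rust):
    ///
    /// ```
    /// fn reduce_cols_ref(input: &[f32], rows: usize, cols: usize, scale: f32) -> Vec<f32> {
    ///     (0..cols)
    ///         .map(|j| (0..rows).map(|i| input[i * cols + j]).sum::<f32>() * scale)
    ///         .collect()
    /// }
    ///
    /// // 2x3 delta matrix; scale = 1/batch = 0.5 gives per-column means.
    /// let delta = [1.0, 2.0, 3.0, 5.0, 6.0, 7.0];
    /// assert_eq!(reduce_cols_ref(&delta, 2, 3, 0.5), [3.0, 4.0, 5.0]);
    /// ```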
    pub const REDUCE_COLS: &str = "\
struct Dims { rows: u32, cols: u32, scale: f32 }
var<push_constant> dims: Dims;

@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> out: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let j = gid.x;
    if j >= dims.cols { return; }
    var sum = 0.0;
    for (var i = 0u; i < dims.rows; i++) {
        sum += input[i * dims.cols + j];
    }
    out[j] = sum * dims.scale;
}";

    /// CUDA C equivalent of [`REDUCE_COLS`].
    #[cfg(feature = "cuda")]
    pub const REDUCE_COLS_CUDA: &str = "\
extern \"C\" __global__ void reduce_cols(
    const float* input, float* out,
    unsigned int rows, unsigned int cols, float scale
) {
    unsigned int j = blockIdx.x * blockDim.x + threadIdx.x;
    if (j >= cols) return;
    float sum = 0.0f;
    for (unsigned int i = 0; i < rows; i++) {
        sum += input[i * cols + j];
    }
    out[j] = sum * scale;
}";
}

/// Pairwise distance shaders.
pub mod distance {
    /// Pairwise squared Euclidean distance.
    ///
    /// For `n_q` query points and `n_t` training points in `dim` dimensions,
    /// computes the `n_q x n_t` distance matrix where:
    ///   `D[i][j] = sum_d (Q[i*dim+d] - T[j*dim+d])^2`
    ///
    /// Each thread computes one (query, train) pair.
    ///
    /// **Push constants:** `struct Dims { n_q: u32, n_t: u32, dim: u32 }` (12 bytes)
    /// **Workgroup size:** 256 (1D) — dispatch `n_q * n_t` invocations
    /// **Bindings:**
    ///   - `@binding(0)` `queries: array<f32>` (read)
    ///   - `@binding(1)` `train: array<f32>` (read)
    ///   - `@binding(2)` `dists: array<f32>` (`read_write`)
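    ///
    /// CPU reference for validation (pure `std` Rust). Distances are squared;
    /// apply `sqrt` afterwards only if true Euclidean distance is needed:
    ///
    /// ```
    /// fn pairwise_sq_ref(q: &[f32], t: &[f32], dim: usize) -> Vec<f32> {
    ///     let (n_q, n_t) = (q.len() / dim, t.len() / dim);
    ///     let mut dists = Vec::with_capacity(n_q * n_t);
    ///     for i in 0..n_q {
    ///         for j in 0..n_t {
    ///             let sum: f32 = (0..dim)
    ///                 .map(|d| {
    ///                     let diff = q[i * dim + d] - t[j * dim + d];
    ///                     diff * diff
    ///                 })
    ///                 .sum();
    ///             dists.push(sum);
    ///         }
    ///     }
    ///     dists
    /// }
    ///
    /// // Two 2-D queries against a single training point at the origin.
    /// let d = pairwise_sq_ref(&[3.0, 4.0, 1.0, 0.0], &[0.0, 0.0], 2);
    /// assert_eq!(d, [25.0, 1.0]);
    /// ```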
    pub const PAIRWISE_EUCLIDEAN: &str = "\
struct Dims { n_q: u32, n_t: u32, dim: u32 }
var<push_constant> dims: Dims;

@group(0) @binding(0) var<storage, read> queries: array<f32>;
@group(0) @binding(1) var<storage, read> train: array<f32>;
@group(0) @binding(2) var<storage, read_write> dists: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    let total = dims.n_q * dims.n_t;
    if (idx >= total) {
        return;
    }

    let i = idx / dims.n_t;
    let j = idx % dims.n_t;

    var sum: f32 = 0.0;
    let q_base = i * dims.dim;
    let t_base = j * dims.dim;

    for (var d: u32 = 0u; d < dims.dim; d = d + 1u) {
        let diff = queries[q_base + d] - train[t_base + d];
        sum = sum + diff * diff;
    }

    dists[idx] = sum;
}";

    /// CUDA C equivalent of [`PAIRWISE_EUCLIDEAN`].
    ///
    /// **Kernel signature:** `pairwise_euclidean(const float* queries, const float* train, float* dists, unsigned int n_q, unsigned int n_t, unsigned int dim)`
    /// **Block size:** `(256, 1, 1)` — dispatch `n_q * n_t` invocations
    #[cfg(feature = "cuda")]
    pub const PAIRWISE_EUCLIDEAN_CUDA: &str = "\
extern \"C\" __global__ void pairwise_euclidean(
    const float* queries, const float* train, float* dists,
    unsigned int n_q, unsigned int n_t, unsigned int dim
) {
    unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x;
    unsigned int total = n_q * n_t;
    if (idx >= total) return;

    unsigned int i = idx / n_t;
    unsigned int j = idx % n_t;

    float sum = 0.0f;
    unsigned int q_base = i * dim;
    unsigned int t_base = j * dim;

    for (unsigned int d = 0; d < dim; d++) {
        float diff = queries[q_base + d] - train[t_base + d];
        sum += diff * diff;
    }

    dists[idx] = sum;
}";
}