Skip to main content

tang_expr/
codegen.rs

1//! Multi-dialect compute shader code generation.
2//!
3//! Generates compute kernels in WGSL, MSL (Metal Shading Language),
4//! CUDA C, and plain C from the same expression graph. The expression
5//! syntax is nearly identical across dialects — only kernel boilerplate
6//! and a few edge cases (select, literal suffixes) differ.
7
8use std::fmt::Write;
9
10use crate::graph::ExprGraph;
11use crate::node::{ExprId, Node};
12
/// Target shader dialect.
///
/// Selects the kernel boilerplate emitted around the expression body as well
/// as the dialect-specific expression details (float literal suffix and
/// `select` builtin vs. ternary operator).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Dialect {
    /// WebGPU Shading Language.
    Wgsl,
    /// Metal Shading Language.
    Msl,
    /// CUDA C (compiled via NVRTC).
    Cuda,
    /// Plain C (for CPU fallback / verification).
    C,
}
25
26/// A generated compute kernel.
27pub struct ComputeKernel {
28    /// Complete kernel source code.
29    pub source: String,
30    /// Number of input values per work item.
31    pub n_inputs: usize,
32    /// Number of output values per work item.
33    pub n_outputs: usize,
34    /// Workgroup / threadgroup / block size.
35    pub workgroup_size: u32,
36    /// Which dialect this kernel was generated for.
37    pub dialect: Dialect,
38    /// Entry point function name.
39    pub entry_point: &'static str,
40}
41
42impl ExprGraph {
43    /// Generate a compute kernel for the given dialect.
44    ///
45    /// Each work item reads `n_inputs` f32 values and writes `outputs.len()` f32 values.
46    /// Shared subexpressions are computed once per thread.
47    pub fn to_kernel(&self, outputs: &[ExprId], n_inputs: usize, dialect: Dialect) -> ComputeKernel {
48        let workgroup_size = 256u32;
49        let n_outputs = outputs.len();
50        let live = self.live_set(outputs);
51        let max_id = if live.is_empty() {
52            0
53        } else {
54            *live.iter().max().unwrap()
55        };
56
57        let mut src = String::with_capacity(2048);
58
59        match dialect {
60            Dialect::Wgsl => emit_wgsl(&mut src, self, outputs, n_inputs, n_outputs, &live, max_id, workgroup_size),
61            Dialect::Msl => emit_msl(&mut src, self, outputs, n_inputs, n_outputs, &live, max_id),
62            Dialect::Cuda => emit_cuda(&mut src, self, outputs, n_inputs, n_outputs, &live, max_id),
63            Dialect::C => emit_c(&mut src, self, outputs, n_inputs, n_outputs, &live, max_id),
64        }
65
66        ComputeKernel {
67            source: src,
68            n_inputs,
69            n_outputs,
70            workgroup_size,
71            dialect,
72            entry_point: "k0",
73        }
74    }
75}
76
77// ---------------------------------------------------------------------------
78// Shared body emission
79// ---------------------------------------------------------------------------
80
81/// Emit SSA evaluation lines shared by all dialects.
82///
83/// `decl` is the variable declaration prefix: `"let"` for WGSL, `"float"` for MSL/CUDA/C.
84/// `lit_suffix` is appended to float literals: `""` for WGSL, `"f"` for MSL/CUDA/C.
85fn emit_body(
86    src: &mut String,
87    graph: &ExprGraph,
88    outputs: &[ExprId],
89    n_inputs: usize,
90    n_outputs: usize,
91    live: &std::collections::HashSet<usize>,
92    max_id: usize,
93    indent: &str,
94    decl: &str,
95    lit_suffix: &str,
96    thread_id: &str,
97    dialect: Dialect,
98) {
99    // Load inputs
100    if n_inputs > 0 {
101        let base = format!("{thread_id} * {n_inputs}u");
102        let base_wgsl = format!("{thread_id} * {n_inputs}");
103        for i in 0..n_inputs {
104            match dialect {
105                Dialect::Wgsl => {
106                    writeln!(src, "{indent}{decl} x{i} = inputs[{base} + {i}u];").unwrap();
107                }
108                Dialect::Msl => {
109                    writeln!(src, "{indent}{decl} x{i} = inputs[{base_wgsl} + {i}];").unwrap();
110                }
111                Dialect::Cuda => {
112                    writeln!(src, "{indent}{decl} x{i} = in{i}[{thread_id}];").unwrap();
113                }
114                Dialect::C => {
115                    writeln!(src, "{indent}{decl} x{i} = inputs[i * {n_inputs} + {i}];").unwrap();
116                }
117            }
118        }
119        writeln!(src).unwrap();
120    }
121
122    // Evaluate in topological order (SSA form)
123    for i in 0..=max_id {
124        if !live.contains(&i) {
125            continue;
126        }
127        let node = graph.node(ExprId(i as u32));
128        match node {
129            Node::Var(_) | Node::Lit(_) => continue,
130            _ => {}
131        }
132        let rhs = expr_str(graph, node, lit_suffix, dialect);
133        writeln!(src, "{indent}{decl} t{i} = {rhs};").unwrap();
134    }
135    writeln!(src).unwrap();
136
137    // Store outputs
138    if n_outputs > 0 {
139        for (k, out) in outputs.iter().enumerate() {
140            let val = ref_str(graph, *out, lit_suffix);
141            match dialect {
142                Dialect::Wgsl => {
143                    let base = format!("{thread_id} * {n_outputs}u");
144                    writeln!(src, "{indent}outputs[{base} + {k}u] = {val};").unwrap();
145                }
146                Dialect::Msl => {
147                    let base = format!("{thread_id} * {n_outputs}");
148                    writeln!(src, "{indent}outputs[{base} + {k}] = {val};").unwrap();
149                }
150                Dialect::Cuda => {
151                    let base = format!("{thread_id} * {n_outputs}");
152                    writeln!(src, "{indent}outputs[{base} + {k}] = {val};").unwrap();
153                }
154                Dialect::C => {
155                    writeln!(src, "{indent}outputs[i * {n_outputs} + {k}] = {val};").unwrap();
156                }
157            }
158        }
159    }
160}
161
162/// Generate expression string for a node.
163fn expr_str(graph: &ExprGraph, node: Node, suffix: &str, dialect: Dialect) -> String {
164    match node {
165        Node::Var(n) => format!("x{n}"),
166        Node::Lit(bits) => format_literal(f64::from_bits(bits), suffix),
167        Node::Add(a, b) => format!("({} + {})", ref_str(graph, a, suffix), ref_str(graph, b, suffix)),
168        Node::Mul(a, b) => format!("({} * {})", ref_str(graph, a, suffix), ref_str(graph, b, suffix)),
169        Node::Neg(a) => format!("(-{})", ref_str(graph, a, suffix)),
170        Node::Recip(a) => format!("(1.0{suffix} / {})", ref_str(graph, a, suffix)),
171        Node::Sqrt(a) => format!("sqrt({})", ref_str(graph, a, suffix)),
172        Node::Sin(a) => format!("sin({})", ref_str(graph, a, suffix)),
173        Node::Atan2(y, x) => format!("atan2({}, {})", ref_str(graph, y, suffix), ref_str(graph, x, suffix)),
174        Node::Exp2(a) => format!("exp2({})", ref_str(graph, a, suffix)),
175        Node::Log2(a) => format!("log2({})", ref_str(graph, a, suffix)),
176        Node::Select(c, a, b) => {
177            match dialect {
178                Dialect::Wgsl => {
179                    // WGSL: select(false_val, true_val, cond)
180                    format!(
181                        "select({}, {}, {} > 0.0)",
182                        ref_str(graph, b, suffix),
183                        ref_str(graph, a, suffix),
184                        ref_str(graph, c, suffix),
185                    )
186                }
187                _ => {
188                    // MSL/CUDA/C: ternary
189                    format!(
190                        "({} > 0.0{suffix} ? {} : {})",
191                        ref_str(graph, c, suffix),
192                        ref_str(graph, a, suffix),
193                        ref_str(graph, b, suffix),
194                    )
195                }
196            }
197        }
198    }
199}
200
201/// Reference a node inline: Var → x{n}, Lit → literal, others → t{index}.
202fn ref_str(graph: &ExprGraph, id: ExprId, suffix: &str) -> String {
203    match graph.node(id) {
204        Node::Var(n) => format!("x{n}"),
205        Node::Lit(bits) => format_literal(f64::from_bits(bits), suffix),
206        _ => format!("t{}", id.0),
207    }
208}
209
/// Format an `f64` as a float literal followed by `suffix`.
///
/// The text always contains a decimal point or exponent so the target
/// language parses it as a floating-point literal. Both `0.0` and `-0.0`
/// are written as `0.0`.
fn format_literal(v: f64, suffix: &str) -> String {
    // Zero of either sign is emitted without a sign.
    if v == 0.0 {
        return format!("0.0{suffix}");
    }

    let mut text = v.to_string();
    let looks_like_float = text.chars().any(|c| matches!(c, '.' | 'e' | 'E'));
    if !looks_like_float {
        // Integral values print without a fractional part; add one.
        text.push_str(".0");
    }
    text.push_str(suffix);
    text
}
230
231// ---------------------------------------------------------------------------
232// WGSL dialect
233// ---------------------------------------------------------------------------
234
235fn emit_wgsl(
236    src: &mut String,
237    graph: &ExprGraph,
238    outputs: &[ExprId],
239    n_inputs: usize,
240    n_outputs: usize,
241    live: &std::collections::HashSet<usize>,
242    max_id: usize,
243    workgroup_size: u32,
244) {
245    writeln!(src, "// Auto-generated by tang-expr").unwrap();
246    writeln!(src).unwrap();
247
248    // Params struct
249    writeln!(src, "struct Params {{").unwrap();
250    writeln!(src, "    count: u32,").unwrap();
251    writeln!(src, "    _pad1: u32,").unwrap();
252    writeln!(src, "    _pad2: u32,").unwrap();
253    writeln!(src, "    _pad3: u32,").unwrap();
254    writeln!(src, "}}").unwrap();
255    writeln!(src).unwrap();
256
257    // Bindings
258    writeln!(src, "@group(0) @binding(0) var<storage, read> inputs: array<f32>;").unwrap();
259    writeln!(src, "@group(0) @binding(1) var<storage, read_write> outputs: array<f32>;").unwrap();
260    writeln!(src, "@group(0) @binding(2) var<uniform> params: Params;").unwrap();
261    writeln!(src).unwrap();
262
263    // Entry point
264    writeln!(src, "@compute @workgroup_size({workgroup_size})").unwrap();
265    writeln!(src, "fn k0(@builtin(global_invocation_id) gid: vec3<u32>) {{").unwrap();
266    writeln!(src, "    let idx = gid.x;").unwrap();
267    writeln!(src, "    if (idx >= params.count) {{ return; }}").unwrap();
268    writeln!(src).unwrap();
269
270    emit_body(src, graph, outputs, n_inputs, n_outputs, live, max_id, "    ", "let", "", "idx", Dialect::Wgsl);
271
272    writeln!(src, "}}").unwrap();
273}
274
275// ---------------------------------------------------------------------------
276// MSL dialect
277// ---------------------------------------------------------------------------
278
279fn emit_msl(
280    src: &mut String,
281    graph: &ExprGraph,
282    outputs: &[ExprId],
283    n_inputs: usize,
284    n_outputs: usize,
285    live: &std::collections::HashSet<usize>,
286    max_id: usize,
287) {
288    writeln!(src, "// Auto-generated by tang-expr").unwrap();
289    writeln!(src, "#include <metal_stdlib>").unwrap();
290    writeln!(src, "using namespace metal;").unwrap();
291    writeln!(src).unwrap();
292
293    write!(src, "kernel void k0(").unwrap();
294    writeln!(src, "device const float* inputs [[buffer(0)]],").unwrap();
295    writeln!(src, "    device float* outputs [[buffer(1)]],").unwrap();
296    writeln!(src, "    device const uint& count [[buffer(2)]],").unwrap();
297    writeln!(src, "    uint gid [[thread_position_in_grid]]) {{").unwrap();
298    writeln!(src, "    if (gid >= count) {{ return; }}").unwrap();
299    writeln!(src).unwrap();
300
301    emit_body(src, graph, outputs, n_inputs, n_outputs, live, max_id, "    ", "float", "f", "gid", Dialect::Msl);
302
303    writeln!(src, "}}").unwrap();
304}
305
306// ---------------------------------------------------------------------------
307// CUDA dialect
308// ---------------------------------------------------------------------------
309
310fn emit_cuda(
311    src: &mut String,
312    graph: &ExprGraph,
313    outputs: &[ExprId],
314    n_inputs: usize,
315    n_outputs: usize,
316    live: &std::collections::HashSet<usize>,
317    max_id: usize,
318) {
319    writeln!(src, "// Auto-generated by tang-expr").unwrap();
320    // NVRTC provides all math builtins (expf, sqrtf, fmaxf, etc.) — no #include needed.
321    writeln!(src).unwrap();
322
323    // Separate input pointers: in0, in1, ... instead of one interleaved buffer
324    write!(src, "extern \"C\" __global__ void k0(").unwrap();
325    for i in 0..n_inputs {
326        writeln!(src, "const float* __restrict__ in{i},").unwrap();
327        write!(src, "    ").unwrap();
328    }
329    writeln!(src, "float* __restrict__ outputs,").unwrap();
330    writeln!(src, "    const unsigned int count) {{").unwrap();
331    writeln!(src, "    unsigned int gid = blockIdx.x * blockDim.x + threadIdx.x;").unwrap();
332    writeln!(src, "    if (gid >= count) {{ return; }}").unwrap();
333    writeln!(src).unwrap();
334
335    emit_body(src, graph, outputs, n_inputs, n_outputs, live, max_id, "    ", "float", "f", "gid", Dialect::Cuda);
336
337    writeln!(src, "}}").unwrap();
338}
339
340// ---------------------------------------------------------------------------
341// C dialect
342// ---------------------------------------------------------------------------
343
344fn emit_c(
345    src: &mut String,
346    graph: &ExprGraph,
347    outputs: &[ExprId],
348    n_inputs: usize,
349    n_outputs: usize,
350    live: &std::collections::HashSet<usize>,
351    max_id: usize,
352) {
353    writeln!(src, "// Auto-generated by tang-expr").unwrap();
354    writeln!(src, "#include <math.h>").unwrap();
355    writeln!(src).unwrap();
356
357    write!(src, "void k0(").unwrap();
358    writeln!(src, "const float* inputs,").unwrap();
359    writeln!(src, "    float* outputs,").unwrap();
360    writeln!(src, "    int count) {{").unwrap();
361    writeln!(src, "    for (int i = 0; i < count; i++) {{").unwrap();
362
363    emit_body(src, graph, outputs, n_inputs, n_outputs, live, max_id, "        ", "float", "f", "i", Dialect::C);
364
365    writeln!(src, "    }}").unwrap();
366    writeln!(src, "}}").unwrap();
367}
368
/// Smoke tests for the multi-dialect kernel generator.
///
/// These tests inspect the generated source text and kernel metadata only;
/// none of them compiles or runs the emitted kernels.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::ExprGraph;

    /// `to_kernel(Wgsl)` reports the same metadata as `to_wgsl()` and emits
    /// the expected WGSL constructs.
    #[test]
    fn wgsl_matches_original() {
        // Compare the metadata of to_kernel(Wgsl) against to_wgsl(); the
        // source text itself is only spot-checked, not compared verbatim.
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let y = g.var(1);
        let xx = g.mul(x, x);
        let yy = g.mul(y, y);
        let sum = g.add(xx, yy);
        let dist = g.sqrt(sum);

        let old = g.to_wgsl(&[dist], 2);
        let new = g.to_kernel(&[dist], 2, Dialect::Wgsl);

        assert_eq!(new.n_inputs, old.n_inputs);
        assert_eq!(new.n_outputs, old.n_outputs);
        assert_eq!(new.workgroup_size, old.workgroup_size);
        // Both should contain core elements
        assert!(new.source.contains("@compute"));
        assert!(new.source.contains("sqrt("));
    }

    /// The MSL kernel uses the `k0` entry point and Metal boilerplate.
    #[test]
    fn msl_entry_point() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let y = g.var(1);
        let sum = g.add(x, y);
        let kernel = g.to_kernel(&[sum], 2, Dialect::Msl);
        assert!(kernel.source.contains("kernel void k0("));
        assert!(kernel.source.contains("thread_position_in_grid"));
        assert!(kernel.source.contains("#include <metal_stdlib>"));
        assert_eq!(kernel.entry_point, "k0");
    }

    /// The CUDA kernel takes one pointer per input and loads from each.
    #[test]
    fn cuda_entry_point() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let y = g.var(1);
        let prod = g.mul(x, y);
        let kernel = g.to_kernel(&[prod], 2, Dialect::Cuda);
        assert!(kernel.source.contains("extern \"C\" __global__ void k0("));
        assert!(kernel.source.contains("blockIdx.x * blockDim.x + threadIdx.x"));
        // Separate input pointers instead of interleaved buffer
        assert!(kernel.source.contains("const float* __restrict__ in0,"));
        assert!(kernel.source.contains("const float* __restrict__ in1,"));
        assert!(!kernel.source.contains("inputs["));
        // Loads from separate pointers
        assert!(kernel.source.contains("in0[gid]"));
        assert!(kernel.source.contains("in1[gid]"));
    }

    /// The C kernel iterates over all work items in a `for` loop.
    #[test]
    fn c_loop() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let s = g.sin(x);
        let kernel = g.to_kernel(&[s], 1, Dialect::C);
        assert!(kernel.source.contains("for (int i = 0; i < count; i++)"));
        assert!(kernel.source.contains("sin("));
    }

    /// `Select` is rendered as a ternary expression in MSL.
    #[test]
    fn msl_select_ternary() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let a = g.lit(3.0);
        let b = g.lit(7.0);
        let s = g.select(x, a, b);
        let kernel = g.to_kernel(&[s], 1, Dialect::Msl);
        // MSL should use ternary, not select()
        assert!(kernel.source.contains("?"));
        assert!(!kernel.source.contains("select("));
    }

    /// `Select` is rendered with the `select()` builtin in WGSL.
    #[test]
    fn wgsl_select_builtin() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let a = g.lit(3.0);
        let b = g.lit(7.0);
        let s = g.select(x, a, b);
        let kernel = g.to_kernel(&[s], 1, Dialect::Wgsl);
        assert!(kernel.source.contains("select("));
    }

    /// Float literals carry the `f` suffix in MSL.
    #[test]
    fn msl_literal_suffix() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let c = g.lit(3.14);
        let prod = g.mul(x, c);
        let kernel = g.to_kernel(&[prod], 1, Dialect::Msl);
        assert!(kernel.source.contains("3.14f"));
    }

    /// Every dialect reports the requested input and output counts.
    #[test]
    fn multiple_outputs_all_dialects() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let y = g.var(1);
        let sum = g.add(x, y);
        let prod = g.mul(x, y);

        for dialect in [Dialect::Wgsl, Dialect::Msl, Dialect::Cuda, Dialect::C] {
            let kernel = g.to_kernel(&[sum, prod], 2, dialect);
            assert_eq!(kernel.n_outputs, 2);
            assert_eq!(kernel.n_inputs, 2);
        }
    }

    /// Differentiated and simplified expressions can be emitted in every dialect.
    #[test]
    fn full_pipeline_all_dialects() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let xx = g.mul(x, x);
        let dx = g.diff(xx, 0);
        let dx = g.simplify(dx);

        for dialect in [Dialect::Wgsl, Dialect::Msl, Dialect::Cuda, Dialect::C] {
            let kernel = g.to_kernel(&[xx, dx], 1, dialect);
            assert_eq!(kernel.n_outputs, 2);
            assert!(!kernel.source.is_empty());
        }
    }
}
500}