Skip to main content

tang_expr/
wgsl.rs

1//! WGSL compute shader code generation.
2
3use std::fmt::Write;
4
5use crate::graph::ExprGraph;
6use crate::node::{ExprId, Node};
7
/// A generated WGSL compute shader plus the metadata a caller needs to
/// size its buffers and dispatch it.
///
/// Values of this type are produced by `ExprGraph::to_wgsl`. The type is
/// plain data, so it derives the usual comparison/debug traits to make it
/// easy to log, cache and assert on in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WgslKernel {
    /// Complete WGSL shader source.
    pub source: String,
    /// Number of input values read per work item.
    pub n_inputs: usize,
    /// Number of output values written per work item.
    pub n_outputs: usize,
    /// Workgroup size emitted in the shader's `@workgroup_size(...)`
    /// attribute (`to_wgsl` currently always uses 256).
    pub workgroup_size: u32,
}
19
20impl ExprGraph {
21    /// Generate a WGSL compute shader that evaluates expressions in parallel.
22    ///
23    /// Each work item reads `n_inputs` values and writes `outputs.len()` values.
24    /// The generated shader uses f32 (GPU native). Shared subexpressions are
25    /// computed once per thread.
26    ///
27    /// The caller handles device/pipeline/dispatch (no wgpu dependency here).
28    pub fn to_wgsl(&self, outputs: &[ExprId], n_inputs: usize) -> WgslKernel {
29        let workgroup_size = 256u32;
30        let n_outputs = outputs.len();
31
32        // Find all live nodes (shared with codegen.rs and compile.rs)
33        let live = self.live_set(outputs);
34        let max_id = if live.is_empty() {
35            0
36        } else {
37            *live.iter().max().unwrap()
38        };
39
40        let mut src = String::with_capacity(2048);
41
42        // Header
43        writeln!(src, "// Auto-generated by tang-expr").unwrap();
44        writeln!(src).unwrap();
45
46        // Params struct
47        writeln!(src, "struct Params {{").unwrap();
48        writeln!(src, "    count: u32,").unwrap();
49        writeln!(src, "    _pad1: u32,").unwrap();
50        writeln!(src, "    _pad2: u32,").unwrap();
51        writeln!(src, "    _pad3: u32,").unwrap();
52        writeln!(src, "}}").unwrap();
53        writeln!(src).unwrap();
54
55        // Bindings
56        writeln!(
57            src,
58            "@group(0) @binding(0) var<storage, read> inputs: array<f32>;"
59        )
60        .unwrap();
61        writeln!(
62            src,
63            "@group(0) @binding(1) var<storage, read_write> outputs: array<f32>;"
64        )
65        .unwrap();
66        writeln!(src, "@group(0) @binding(2) var<uniform> params: Params;").unwrap();
67        writeln!(src).unwrap();
68
69        // Entry point
70        writeln!(src, "@compute @workgroup_size({workgroup_size})").unwrap();
71        writeln!(
72            src,
73            "fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{"
74        )
75        .unwrap();
76        writeln!(src, "    let idx = gid.x;").unwrap();
77        writeln!(src, "    if (idx >= params.count) {{ return; }}").unwrap();
78        writeln!(src).unwrap();
79
80        // Load inputs
81        if n_inputs > 0 {
82            writeln!(src, "    let base_in = idx * {n_inputs}u;").unwrap();
83            for i in 0..n_inputs {
84                writeln!(src, "    let x{i} = inputs[base_in + {i}u];").unwrap();
85            }
86            writeln!(src).unwrap();
87        }
88
89        // Evaluate in topological order (SSA form)
90        for i in 0..=max_id {
91            if !live.contains(&i) {
92                continue;
93            }
94            let node = self.node(ExprId(i as u32));
95            // Skip Var and Lit nodes that are used inline
96            match node {
97                Node::Var(_) | Node::Lit(_) => continue,
98                _ => {}
99            }
100            let rhs = self.wgsl_expr(node);
101            writeln!(src, "    let t{i} = {rhs};").unwrap();
102        }
103        writeln!(src).unwrap();
104
105        // Store outputs
106        if n_outputs > 0 {
107            writeln!(src, "    let base_out = idx * {n_outputs}u;").unwrap();
108            for (k, out) in outputs.iter().enumerate() {
109                let val = self.wgsl_ref(*out);
110                writeln!(src, "    outputs[base_out + {k}u] = {val};").unwrap();
111            }
112        }
113
114        writeln!(src, "}}").unwrap();
115
116        WgslKernel {
117            source: src,
118            n_inputs,
119            n_outputs,
120            workgroup_size,
121        }
122    }
123
124    /// Generate WGSL expression for a node.
125    fn wgsl_expr(&self, node: Node) -> String {
126        match node {
127            Node::Var(n) => format!("x{n}"),
128            Node::Lit(bits) => {
129                let v = f64::from_bits(bits);
130                format_f32_literal(v)
131            }
132            Node::Add(a, b) => {
133                format!("({} + {})", self.wgsl_ref(a), self.wgsl_ref(b))
134            }
135            Node::Mul(a, b) => {
136                format!("({} * {})", self.wgsl_ref(a), self.wgsl_ref(b))
137            }
138            Node::Neg(a) => format!("(-{})", self.wgsl_ref(a)),
139            Node::Recip(a) => format!("(1.0 / {})", self.wgsl_ref(a)),
140            Node::Sqrt(a) => format!("sqrt({})", self.wgsl_ref(a)),
141            Node::Sin(a) => format!("sin({})", self.wgsl_ref(a)),
142            Node::Atan2(y, x) => {
143                format!("atan2({}, {})", self.wgsl_ref(y), self.wgsl_ref(x))
144            }
145            Node::Exp2(a) => format!("exp2({})", self.wgsl_ref(a)),
146            Node::Log2(a) => format!("log2({})", self.wgsl_ref(a)),
147            Node::Select(c, a, b) => {
148                // WGSL select(false_val, true_val, cond) — false value FIRST
149                format!(
150                    "select({}, {}, {} > 0.0)",
151                    self.wgsl_ref(b),
152                    self.wgsl_ref(a),
153                    self.wgsl_ref(c)
154                )
155            }
156        }
157    }
158
159    /// Reference a node: Var → x{n}, Lit → literal, others → t{index}.
160    fn wgsl_ref(&self, id: ExprId) -> String {
161        match self.node(id) {
162            Node::Var(n) => format!("x{n}"),
163            Node::Lit(bits) => {
164                let v = f64::from_bits(bits);
165                format_f32_literal(v)
166            }
167            _ => format!("t{}", id.0),
168        }
169    }
170}
171
/// Render an `f64` constant as text usable as a WGSL float literal.
///
/// A handful of very common constants are returned from a fixed table
/// (note that `-0.0` compares equal to `0.0` and is therefore rendered as
/// `0.0`). Every other value goes through Rust's `Display` formatting and
/// gets a trailing `.0` when the result has neither a decimal point nor an
/// exponent, so that WGSL does not read it as an integer.
///
/// Non-finite inputs are not special-cased: NaN and the infinities come out
/// as `NaN.0` / `inf.0` / `-inf.0`, which are not valid WGSL, so callers are
/// expected to pass finite values only.
fn format_f32_literal(v: f64) -> String {
    const COMMON: [(f64, &str); 4] = [(0.0, "0.0"), (1.0, "1.0"), (-1.0, "-1.0"), (2.0, "2.0")];

    if let Some(&(_, text)) = COMMON.iter().find(|&&(value, _)| value == v) {
        return text.to_string();
    }

    let mut text = v.to_string();
    let looks_like_float = text.contains(|c: char| matches!(c, '.' | 'e' | 'E'));
    if !looks_like_float {
        // Ensure the literal has a decimal point for WGSL
        text.push_str(".0");
    }
    text
}
192
#[cfg(test)]
mod tests {
    use crate::graph::ExprGraph;

    /// sqrt(x² + y²): checks the fixed scaffolding, the input loads and the
    /// kernel metadata.
    #[test]
    fn wgsl_basic() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let y = g.var(1);
        let xx = g.mul(x, x);
        let yy = g.mul(y, y);
        let sum = g.add(xx, yy);
        let dist = g.sqrt(sum);

        let kernel = g.to_wgsl(&[dist], 2);
        assert!(kernel.source.contains("@compute"));
        assert!(kernel.source.contains("@workgroup_size(256)"));
        assert!(kernel.source.contains("let x0 = inputs[base_in + 0u];"));
        assert!(kernel.source.contains("let x1 = inputs[base_in + 1u];"));
        assert!(kernel.source.contains("sqrt("));
        assert_eq!(kernel.n_inputs, 2);
        assert_eq!(kernel.n_outputs, 1);
        assert_eq!(kernel.workgroup_size, 256);
    }

    /// Two outputs must be stored at consecutive slots of the output stride.
    #[test]
    fn wgsl_multiple_outputs() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let y = g.var(1);
        let sum = g.add(x, y);
        let prod = g.mul(x, y);

        let kernel = g.to_wgsl(&[sum, prod], 2);
        assert_eq!(kernel.n_outputs, 2);
        assert!(kernel.source.contains("let base_out = idx * 2u;"));
        assert!(kernel.source.contains("outputs[base_out + 0u]"));
        assert!(kernel.source.contains("outputs[base_out + 1u]"));
    }

    /// Variables are referenced inline as `x{n}` inside builtin calls.
    #[test]
    fn wgsl_sin() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let s = g.sin(x);

        let kernel = g.to_wgsl(&[s], 1);
        assert!(kernel.source.contains("sin(x0)"));
    }

    #[test]
    fn wgsl_lit_inline() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let c = g.lit(3.14);
        let prod = g.mul(x, c);

        let kernel = g.to_wgsl(&[prod], 1);
        // Literal should be inlined, not assigned to a t variable
        assert!(kernel.source.contains("3.14"));
        assert!(
            !kernel.source.contains("= 3.14;"),
            "literal must not get its own `let` binding:\n{}",
            kernel.source
        );
    }

    #[test]
    fn wgsl_select() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let a = g.lit(3.0);
        let b = g.lit(7.0);
        let s = g.select(x, a, b);

        let kernel = g.to_wgsl(&[s], 1);
        // WGSL select(false_val, true_val, cond): the else-branch (7.0) must
        // come first, then the then-branch (3.0), then the `> 0.0` test.
        assert!(kernel.source.contains("select("));
        assert!(kernel.source.contains("> 0.0)"));
        assert!(
            kernel.source.contains("select(7.0, 3.0, x0 > 0.0)"),
            "unexpected select argument order:\n{}",
            kernel.source
        );
    }

    #[test]
    fn wgsl_full_pipeline() {
        // Build a small expression, differentiate, simplify, compile to WGSL
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let xx = g.mul(x, x);
        let dx = g.diff(xx, 0);
        let dx = g.simplify(dx);

        let kernel = g.to_wgsl(&[xx, dx], 1);
        assert_eq!(kernel.n_outputs, 2);
        assert!(kernel.source.contains("@compute"));
    }
}