1use std::fmt::Write;
9
10use crate::graph::ExprGraph;
11use crate::node::{ExprId, Node};
12
/// Target language for [`ExprGraph::to_kernel`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Dialect {
    /// WebGPU Shading Language compute shader (`@compute` entry point).
    Wgsl,
    /// Metal Shading Language `kernel` function.
    Msl,
    /// CUDA `extern "C" __global__` kernel.
    Cuda,
    /// Plain C function that loops over every element serially.
    C,
}
25
26pub struct ComputeKernel {
28 pub source: String,
30 pub n_inputs: usize,
32 pub n_outputs: usize,
34 pub workgroup_size: u32,
36 pub dialect: Dialect,
38 pub entry_point: &'static str,
40}
41
42impl ExprGraph {
43 pub fn to_kernel(&self, outputs: &[ExprId], n_inputs: usize, dialect: Dialect) -> ComputeKernel {
48 let workgroup_size = 256u32;
49 let n_outputs = outputs.len();
50 let live = self.live_set(outputs);
51 let max_id = if live.is_empty() {
52 0
53 } else {
54 *live.iter().max().unwrap()
55 };
56
57 let mut src = String::with_capacity(2048);
58
59 match dialect {
60 Dialect::Wgsl => emit_wgsl(&mut src, self, outputs, n_inputs, n_outputs, &live, max_id, workgroup_size),
61 Dialect::Msl => emit_msl(&mut src, self, outputs, n_inputs, n_outputs, &live, max_id),
62 Dialect::Cuda => emit_cuda(&mut src, self, outputs, n_inputs, n_outputs, &live, max_id),
63 Dialect::C => emit_c(&mut src, self, outputs, n_inputs, n_outputs, &live, max_id),
64 }
65
66 ComputeKernel {
67 source: src,
68 n_inputs,
69 n_outputs,
70 workgroup_size,
71 dialect,
72 entry_point: "k0",
73 }
74 }
75}
76
77fn emit_body(
86 src: &mut String,
87 graph: &ExprGraph,
88 outputs: &[ExprId],
89 n_inputs: usize,
90 n_outputs: usize,
91 live: &std::collections::HashSet<usize>,
92 max_id: usize,
93 indent: &str,
94 decl: &str,
95 lit_suffix: &str,
96 thread_id: &str,
97 dialect: Dialect,
98) {
99 if n_inputs > 0 {
101 let base = format!("{thread_id} * {n_inputs}u");
102 let base_wgsl = format!("{thread_id} * {n_inputs}");
103 for i in 0..n_inputs {
104 match dialect {
105 Dialect::Wgsl => {
106 writeln!(src, "{indent}{decl} x{i} = inputs[{base} + {i}u];").unwrap();
107 }
108 Dialect::Msl => {
109 writeln!(src, "{indent}{decl} x{i} = inputs[{base_wgsl} + {i}];").unwrap();
110 }
111 Dialect::Cuda => {
112 writeln!(src, "{indent}{decl} x{i} = in{i}[{thread_id}];").unwrap();
113 }
114 Dialect::C => {
115 writeln!(src, "{indent}{decl} x{i} = inputs[i * {n_inputs} + {i}];").unwrap();
116 }
117 }
118 }
119 writeln!(src).unwrap();
120 }
121
122 for i in 0..=max_id {
124 if !live.contains(&i) {
125 continue;
126 }
127 let node = graph.node(ExprId(i as u32));
128 match node {
129 Node::Var(_) | Node::Lit(_) => continue,
130 _ => {}
131 }
132 let rhs = expr_str(graph, node, lit_suffix, dialect);
133 writeln!(src, "{indent}{decl} t{i} = {rhs};").unwrap();
134 }
135 writeln!(src).unwrap();
136
137 if n_outputs > 0 {
139 for (k, out) in outputs.iter().enumerate() {
140 let val = ref_str(graph, *out, lit_suffix);
141 match dialect {
142 Dialect::Wgsl => {
143 let base = format!("{thread_id} * {n_outputs}u");
144 writeln!(src, "{indent}outputs[{base} + {k}u] = {val};").unwrap();
145 }
146 Dialect::Msl => {
147 let base = format!("{thread_id} * {n_outputs}");
148 writeln!(src, "{indent}outputs[{base} + {k}] = {val};").unwrap();
149 }
150 Dialect::Cuda => {
151 let base = format!("{thread_id} * {n_outputs}");
152 writeln!(src, "{indent}outputs[{base} + {k}] = {val};").unwrap();
153 }
154 Dialect::C => {
155 writeln!(src, "{indent}outputs[i * {n_outputs} + {k}] = {val};").unwrap();
156 }
157 }
158 }
159 }
160}
161
162fn expr_str(graph: &ExprGraph, node: Node, suffix: &str, dialect: Dialect) -> String {
164 match node {
165 Node::Var(n) => format!("x{n}"),
166 Node::Lit(bits) => format_literal(f64::from_bits(bits), suffix),
167 Node::Add(a, b) => format!("({} + {})", ref_str(graph, a, suffix), ref_str(graph, b, suffix)),
168 Node::Mul(a, b) => format!("({} * {})", ref_str(graph, a, suffix), ref_str(graph, b, suffix)),
169 Node::Neg(a) => format!("(-{})", ref_str(graph, a, suffix)),
170 Node::Recip(a) => format!("(1.0{suffix} / {})", ref_str(graph, a, suffix)),
171 Node::Sqrt(a) => format!("sqrt({})", ref_str(graph, a, suffix)),
172 Node::Sin(a) => format!("sin({})", ref_str(graph, a, suffix)),
173 Node::Atan2(y, x) => format!("atan2({}, {})", ref_str(graph, y, suffix), ref_str(graph, x, suffix)),
174 Node::Exp2(a) => format!("exp2({})", ref_str(graph, a, suffix)),
175 Node::Log2(a) => format!("log2({})", ref_str(graph, a, suffix)),
176 Node::Select(c, a, b) => {
177 match dialect {
178 Dialect::Wgsl => {
179 format!(
181 "select({}, {}, {} > 0.0)",
182 ref_str(graph, b, suffix),
183 ref_str(graph, a, suffix),
184 ref_str(graph, c, suffix),
185 )
186 }
187 _ => {
188 format!(
190 "({} > 0.0{suffix} ? {} : {})",
191 ref_str(graph, c, suffix),
192 ref_str(graph, a, suffix),
193 ref_str(graph, b, suffix),
194 )
195 }
196 }
197 }
198 }
199}
200
201fn ref_str(graph: &ExprGraph, id: ExprId, suffix: &str) -> String {
203 match graph.node(id) {
204 Node::Var(n) => format!("x{n}"),
205 Node::Lit(bits) => format_literal(f64::from_bits(bits), suffix),
206 _ => format!("t{}", id.0),
207 }
208}
209
/// Formats `v` as a floating-point literal followed by `suffix`
/// (callers pass `""` for WGSL and `"f"` for the C-like dialects).
///
/// Finite values always contain a decimal point (or exponent), so they
/// parse as floats rather than integers. Values with no literal spelling
/// become parenthesised constant expressions:
/// * NaN is spelled `0/0`.
/// * Infinities are spelled `±1/0`.
/// * Negative zero is spelled `(-0.0)`, keeping the sign that affects
///   `1/x` and `atan2`.
///
/// The parentheses keep each result a single primary expression wherever
/// it is spliced.
///
/// NOTE(review): WGSL reports division by zero in const-expressions as a
/// shader-creation error, so non-finite literals still fail there. The old
/// `NaN.0`/`inf.0` spelling failed in every dialect. Metal's default
/// fast-math mode may also not preserve NaN/inf/-0 semantics.
fn format_literal(v: f64, suffix: &str) -> String {
    if v.is_nan() {
        return format!("(0.0{suffix} / 0.0{suffix})");
    }
    if v.is_infinite() {
        let sign = if v.is_sign_negative() { "-" } else { "" };
        return format!("({sign}1.0{suffix} / 0.0{suffix})");
    }
    if v == 0.0 && v.is_sign_negative() {
        return format!("(-0.0{suffix})");
    }
    // `Display` prints whole numbers without a fraction ("2"), which
    // would be an integer literal in every target, so append ".0".
    let digits = v.to_string();
    if digits.contains(|c: char| matches!(c, '.' | 'e' | 'E')) {
        format!("{digits}{suffix}")
    } else {
        format!("{digits}.0{suffix}")
    }
}
230
231fn emit_wgsl(
236 src: &mut String,
237 graph: &ExprGraph,
238 outputs: &[ExprId],
239 n_inputs: usize,
240 n_outputs: usize,
241 live: &std::collections::HashSet<usize>,
242 max_id: usize,
243 workgroup_size: u32,
244) {
245 writeln!(src, "// Auto-generated by tang-expr").unwrap();
246 writeln!(src).unwrap();
247
248 writeln!(src, "struct Params {{").unwrap();
250 writeln!(src, " count: u32,").unwrap();
251 writeln!(src, " _pad1: u32,").unwrap();
252 writeln!(src, " _pad2: u32,").unwrap();
253 writeln!(src, " _pad3: u32,").unwrap();
254 writeln!(src, "}}").unwrap();
255 writeln!(src).unwrap();
256
257 writeln!(src, "@group(0) @binding(0) var<storage, read> inputs: array<f32>;").unwrap();
259 writeln!(src, "@group(0) @binding(1) var<storage, read_write> outputs: array<f32>;").unwrap();
260 writeln!(src, "@group(0) @binding(2) var<uniform> params: Params;").unwrap();
261 writeln!(src).unwrap();
262
263 writeln!(src, "@compute @workgroup_size({workgroup_size})").unwrap();
265 writeln!(src, "fn k0(@builtin(global_invocation_id) gid: vec3<u32>) {{").unwrap();
266 writeln!(src, " let idx = gid.x;").unwrap();
267 writeln!(src, " if (idx >= params.count) {{ return; }}").unwrap();
268 writeln!(src).unwrap();
269
270 emit_body(src, graph, outputs, n_inputs, n_outputs, live, max_id, " ", "let", "", "idx", Dialect::Wgsl);
271
272 writeln!(src, "}}").unwrap();
273}
274
275fn emit_msl(
280 src: &mut String,
281 graph: &ExprGraph,
282 outputs: &[ExprId],
283 n_inputs: usize,
284 n_outputs: usize,
285 live: &std::collections::HashSet<usize>,
286 max_id: usize,
287) {
288 writeln!(src, "// Auto-generated by tang-expr").unwrap();
289 writeln!(src, "#include <metal_stdlib>").unwrap();
290 writeln!(src, "using namespace metal;").unwrap();
291 writeln!(src).unwrap();
292
293 write!(src, "kernel void k0(").unwrap();
294 writeln!(src, "device const float* inputs [[buffer(0)]],").unwrap();
295 writeln!(src, " device float* outputs [[buffer(1)]],").unwrap();
296 writeln!(src, " device const uint& count [[buffer(2)]],").unwrap();
297 writeln!(src, " uint gid [[thread_position_in_grid]]) {{").unwrap();
298 writeln!(src, " if (gid >= count) {{ return; }}").unwrap();
299 writeln!(src).unwrap();
300
301 emit_body(src, graph, outputs, n_inputs, n_outputs, live, max_id, " ", "float", "f", "gid", Dialect::Msl);
302
303 writeln!(src, "}}").unwrap();
304}
305
306fn emit_cuda(
311 src: &mut String,
312 graph: &ExprGraph,
313 outputs: &[ExprId],
314 n_inputs: usize,
315 n_outputs: usize,
316 live: &std::collections::HashSet<usize>,
317 max_id: usize,
318) {
319 writeln!(src, "// Auto-generated by tang-expr").unwrap();
320 writeln!(src).unwrap();
322
323 write!(src, "extern \"C\" __global__ void k0(").unwrap();
325 for i in 0..n_inputs {
326 writeln!(src, "const float* __restrict__ in{i},").unwrap();
327 write!(src, " ").unwrap();
328 }
329 writeln!(src, "float* __restrict__ outputs,").unwrap();
330 writeln!(src, " const unsigned int count) {{").unwrap();
331 writeln!(src, " unsigned int gid = blockIdx.x * blockDim.x + threadIdx.x;").unwrap();
332 writeln!(src, " if (gid >= count) {{ return; }}").unwrap();
333 writeln!(src).unwrap();
334
335 emit_body(src, graph, outputs, n_inputs, n_outputs, live, max_id, " ", "float", "f", "gid", Dialect::Cuda);
336
337 writeln!(src, "}}").unwrap();
338}
339
340fn emit_c(
345 src: &mut String,
346 graph: &ExprGraph,
347 outputs: &[ExprId],
348 n_inputs: usize,
349 n_outputs: usize,
350 live: &std::collections::HashSet<usize>,
351 max_id: usize,
352) {
353 writeln!(src, "// Auto-generated by tang-expr").unwrap();
354 writeln!(src, "#include <math.h>").unwrap();
355 writeln!(src).unwrap();
356
357 write!(src, "void k0(").unwrap();
358 writeln!(src, "const float* inputs,").unwrap();
359 writeln!(src, " float* outputs,").unwrap();
360 writeln!(src, " int count) {{").unwrap();
361 writeln!(src, " for (int i = 0; i < count; i++) {{").unwrap();
362
363 emit_body(src, graph, outputs, n_inputs, n_outputs, live, max_id, " ", "float", "f", "i", Dialect::C);
364
365 writeln!(src, " }}").unwrap();
366 writeln!(src, "}}").unwrap();
367}
368
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::ExprGraph;

    #[test]
    fn wgsl_matches_original() {
        let mut g = ExprGraph::new();
        // dist = sqrt(x^2 + y^2)
        let x = g.var(0);
        let y = g.var(1);
        let xx = g.mul(x, x);
        let yy = g.mul(y, y);
        let sum = g.add(xx, yy);
        let dist = g.sqrt(sum);

        let old = g.to_wgsl(&[dist], 2);
        let new = g.to_kernel(&[dist], 2, Dialect::Wgsl);

        assert_eq!(new.n_inputs, old.n_inputs);
        assert_eq!(new.n_outputs, old.n_outputs);
        assert_eq!(new.workgroup_size, old.workgroup_size);
        assert!(new.source.contains("@compute"));
        assert!(new.source.contains("sqrt("));
    }

    #[test]
    fn msl_entry_point() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let y = g.var(1);
        let sum = g.add(x, y);
        let kernel = g.to_kernel(&[sum], 2, Dialect::Msl);
        assert!(kernel.source.contains("kernel void k0("));
        assert!(kernel.source.contains("thread_position_in_grid"));
        assert!(kernel.source.contains("#include <metal_stdlib>"));
        assert_eq!(kernel.entry_point, "k0");
    }

    #[test]
    fn cuda_entry_point() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let y = g.var(1);
        let prod = g.mul(x, y);
        let kernel = g.to_kernel(&[prod], 2, Dialect::Cuda);
        assert!(kernel.source.contains("extern \"C\" __global__ void k0("));
        assert!(kernel.source.contains("blockIdx.x * blockDim.x + threadIdx.x"));
        // One pointer per input instead of an interleaved buffer.
        assert!(kernel.source.contains("const float* __restrict__ in0,"));
        assert!(kernel.source.contains("const float* __restrict__ in1,"));
        assert!(!kernel.source.contains("inputs["));
        assert!(kernel.source.contains("in0[gid]"));
        assert!(kernel.source.contains("in1[gid]"));
    }

    #[test]
    fn c_loop() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let s = g.sin(x);
        let kernel = g.to_kernel(&[s], 1, Dialect::C);
        assert!(kernel.source.contains("for (int i = 0; i < count; i++)"));
        assert!(kernel.source.contains("sin("));
    }

    #[test]
    fn msl_select_ternary() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let a = g.lit(3.0);
        let b = g.lit(7.0);
        let s = g.select(x, a, b);
        let kernel = g.to_kernel(&[s], 1, Dialect::Msl);
        // MSL uses the ternary operator, not WGSL's builtin.
        assert!(kernel.source.contains("?"));
        assert!(!kernel.source.contains("select("));
    }

    #[test]
    fn wgsl_select_builtin() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let a = g.lit(3.0);
        let b = g.lit(7.0);
        let s = g.select(x, a, b);
        let kernel = g.to_kernel(&[s], 1, Dialect::Wgsl);
        assert!(kernel.source.contains("select("));
    }

    #[test]
    fn msl_literal_suffix() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        // Not 3.14: clippy's deny-by-default `approx_constant` lint rejects
        // literals that approximate PI.
        let c = g.lit(2.75);
        let prod = g.mul(x, c);
        let kernel = g.to_kernel(&[prod], 1, Dialect::Msl);
        assert!(kernel.source.contains("2.75f"));
    }

    #[test]
    fn multiple_outputs_all_dialects() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let y = g.var(1);
        let sum = g.add(x, y);
        let prod = g.mul(x, y);

        for dialect in [Dialect::Wgsl, Dialect::Msl, Dialect::Cuda, Dialect::C] {
            let kernel = g.to_kernel(&[sum, prod], 2, dialect);
            assert_eq!(kernel.n_outputs, 2);
            assert_eq!(kernel.n_inputs, 2);
            assert_eq!(kernel.dialect, dialect);
            assert_eq!(kernel.entry_point, "k0");
        }
    }

    #[test]
    fn full_pipeline_all_dialects() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let xx = g.mul(x, x);
        let dx = g.diff(xx, 0);
        let dx = g.simplify(dx);

        for dialect in [Dialect::Wgsl, Dialect::Msl, Dialect::Cuda, Dialect::C] {
            let kernel = g.to_kernel(&[xx, dx], 1, dialect);
            assert_eq!(kernel.n_outputs, 2);
            assert!(!kernel.source.is_empty());
        }
    }
}