1#![deny(missing_docs)]
2
3use std::collections::HashMap;
9use std::fmt::Write;
10
11use t_ree::{
12 declaration::{Declaration, FunctionDefinition, Module, Newtype},
13 expression::{Binding, Block, Expression, ExpressionKind, Literal, Statement},
14 operator::{
15 ArithmeticOperator, BinaryOperator, BitwiseOperator, ComparisonOperator, LogicalOperator,
16 UnaryOperator,
17 },
18 types::{FloatWidth, IntWidth, Signedness, Type},
19};
20
21pub fn compile(module: &Module) -> Result<String, String> {
23 let mut emitter = Emitter::new();
24 emitter.emit_module(module)?;
25 Ok(emitter.output)
26}
27
28struct Emitter {
29 output: String,
30 indent_level: usize,
31 newtypes: HashMap<String, Type>,
32}
33
34impl Emitter {
35 fn new() -> Self {
36 Self {
37 output: String::new(),
38 indent_level: 0,
39 newtypes: HashMap::new(),
40 }
41 }
42
43 fn push(&mut self, string: &str) {
44 self.output.push_str(string);
45 }
46
47 fn indent(&mut self) {
48 for _ in 0..self.indent_level {
49 self.push(" ");
50 }
51 }
52
53 fn emit_module(&mut self, module: &Module) -> Result<(), String> {
54 for declaration in module {
55 if let Declaration::Type(newtype) = declaration {
56 self.newtypes
57 .insert(newtype.name.clone(), newtype.inner_type.clone());
58 }
59 }
60
61 for declaration in module {
62 self.emit_declaration(declaration)?;
63 }
64 Ok(())
65 }
66
67 fn emit_declaration(&mut self, declaration: &Declaration) -> Result<(), String> {
68 match declaration {
69 Declaration::Type(newtype) => self.emit_newtype(newtype),
70 Declaration::Function(function) => self.emit_function(function),
71 Declaration::Constant(constant) => {
72 self.push("const ");
73 self.push(&constant.name);
74 self.push(": ");
75 self.emit_type(&constant.constant_type)?;
76 self.push(" = ");
77 self.emit_expression(&constant.value)?;
78 self.push(";\n");
79 Ok(())
80 }
81 Declaration::Extern(_) => Ok(()),
82 Declaration::Import(_) => Ok(()),
83 }
84 }
85
86 fn emit_newtype(&mut self, newtype: &Newtype) -> Result<(), String> {
87 match &newtype.inner_type {
88 Type::Tuple(fields) if !fields.is_empty() => {
89 writeln!(self.output, "struct {} {{", newtype.name).unwrap();
90 self.indent_level += 1;
91 for field in fields {
92 if let Type::Named(field_name) = field {
93 self.indent();
94 self.push(field_name);
95 self.push(": ");
96 if let Some(inner) = self.newtypes.get(field_name).cloned() {
97 self.emit_type(&inner)?;
98 } else {
99 self.push(field_name);
100 }
101 self.push(",\n");
102 }
103 }
104 self.indent_level -= 1;
105 self.push("}\n\n");
106 }
107 _ => {
108 write!(self.output, "alias {} = ", newtype.name).unwrap();
109 self.emit_type(&newtype.inner_type)?;
110 self.push(";\n");
111 }
112 }
113 Ok(())
114 }
115
116 fn emit_type(&mut self, ty: &Type) -> Result<(), String> {
117 match ty {
118 Type::Bool => self.push("bool"),
119 Type::Int(IntWidth::W32, Signedness::Signed) => self.push("i32"),
120 Type::Int(IntWidth::W32, Signedness::Unsigned) => self.push("u32"),
121 Type::Int(_, Signedness::Signed) => self.push("i32"),
122 Type::Int(_, Signedness::Unsigned) => self.push("u32"),
123 Type::Float(FloatWidth::W32) => self.push("f32"),
124 Type::Float(FloatWidth::W64) => {
125 return Err("WGSL does not support f64 (use f32 instead)".into());
126 }
127 Type::Array(element, length) => {
128 self.push("array<");
129 self.emit_type(element)?;
130 write!(self.output, ", {length}>").unwrap();
131 }
132 Type::Vector(element, count) => {
133 write!(self.output, "vec{count}<").unwrap();
134 self.emit_type(element)?;
135 self.push(">");
136 }
137 Type::Named(name) => {
138 if let Some(inner) = self.newtypes.get(name).cloned() {
139 if matches!(inner, Type::Tuple(ref fields) if !fields.is_empty()) {
140 self.push(name);
141 } else {
142 self.emit_type(&inner)?;
143 }
144 } else {
145 self.push(name);
146 }
147 }
148 Type::Never => self.push("void"),
149 Type::Pointer(..) => {
150 return Err("WGSL does not support pointers".into());
151 }
152 _ => {
153 return Err(format!("unsupported WGSL type: {ty:?}"));
154 }
155 }
156 Ok(())
157 }
158
159 fn emit_function(&mut self, function: &FunctionDefinition) -> Result<(), String> {
160 if function.name == "main" {
161 return Ok(());
162 }
163 self.push("fn ");
164 self.push(&function.name);
165 self.push("(");
166 for (index, parameter) in function.parameters.iter().enumerate() {
167 if index > 0 {
168 self.push(", ");
169 }
170 self.push(¶meter.name);
171 self.push(": ");
172 if let Some(ref ty) = parameter.parameter_type {
173 self.emit_type(ty)?;
174 }
175 }
176 self.push(")");
177 if !function.return_type.is_unit() {
178 self.push(" -> ");
179 self.emit_type(&function.return_type)?;
180 }
181 self.push(" {\n");
182 self.indent_level += 1;
183 self.emit_block(&function.body)?;
184 self.indent_level -= 1;
185 self.push("}\n\n");
186 Ok(())
187 }
188
189 fn emit_block(&mut self, block: &Block) -> Result<(), String> {
190 let mut index = 0;
191 while index < block.statements.len() {
192 if let Statement::Label {
193 name,
194 parameters,
195 initial_arguments,
196 } = &block.statements[index]
197 {
198 self.emit_loop(
199 name,
200 parameters,
201 initial_arguments,
202 &block.statements[index + 1..],
203 )?;
204 break;
205 }
206 self.emit_statement(&block.statements[index])?;
207 index += 1;
208 }
209 if let Some(result) = &block.result {
210 self.indent();
211 self.push("return ");
212 self.emit_expression(result)?;
213 self.push(";\n");
214 }
215 Ok(())
216 }
217
218 fn emit_loop(
219 &mut self,
220 label_name: &str,
221 parameters: &[t_ree::declaration::Parameter],
222 initial_arguments: &[Expression],
223 body_statements: &[Statement],
224 ) -> Result<(), String> {
225 for (parameter, init) in parameters.iter().zip(initial_arguments) {
226 self.indent();
227 self.push("var ");
228 self.push(¶meter.name);
229 if let Some(ref ty) = parameter.parameter_type {
230 self.push(": ");
231 self.emit_type(ty)?;
232 }
233 self.push(" = ");
234 self.emit_expression(init)?;
235 self.push(";\n");
236 }
237 let loop_end = body_statements
238 .iter()
239 .rposition(|s| Self::contains_jump(s, label_name))
240 .map_or(body_statements.len(), |i| i + 1);
241 self.indent();
242 self.push("loop {\n");
243 self.indent_level += 1;
244 for statement in &body_statements[..loop_end] {
245 self.emit_loop_statement(statement, label_name, parameters)?;
246 }
247 self.indent_level -= 1;
248 self.indent();
249 self.push("}\n");
250 for statement in &body_statements[loop_end..] {
251 self.emit_statement(statement)?;
252 }
253 Ok(())
254 }
255
256 fn contains_jump(statement: &Statement, label: &str) -> bool {
257 match statement {
258 Statement::Jump { label: target, .. } => target == label,
259 Statement::Expression(expression) => match &expression.kind {
260 ExpressionKind::If {
261 then_branch,
262 else_branch,
263 ..
264 } => {
265 then_branch
266 .statements
267 .iter()
268 .any(|s| Self::contains_jump(s, label))
269 || else_branch.as_ref().is_some_and(|b| {
270 b.statements.iter().any(|s| Self::contains_jump(s, label))
271 })
272 }
273 _ => false,
274 },
275 _ => false,
276 }
277 }
278
279 fn emit_loop_statement(
280 &mut self,
281 statement: &Statement,
282 label_name: &str,
283 parameters: &[t_ree::declaration::Parameter],
284 ) -> Result<(), String> {
285 match statement {
286 Statement::Jump { label, arguments } if label == label_name => {
287 if !arguments.is_empty() {
288 for (index, (_parameter, argument)) in
289 parameters.iter().zip(arguments).enumerate()
290 {
291 self.indent();
292 let temp = format!("_t{index}");
293 self.push("let ");
294 self.push(&temp);
295 self.push(" = ");
296 self.emit_expression(argument)?;
297 self.push(";\n");
298 }
299 for (index, parameter) in parameters.iter().enumerate() {
300 self.indent();
301 self.push(¶meter.name);
302 writeln!(self.output, " = _t{index};").unwrap();
303 }
304 }
305 self.indent();
306 self.push("continue;\n");
307 Ok(())
308 }
309 Statement::Expression(expression) => {
310 if let ExpressionKind::If {
311 condition,
312 then_branch,
313 else_branch,
314 } = &expression.kind
315 {
316 if Self::block_ends_with_jump(then_branch, label_name)
317 && else_branch.is_none()
318 && then_branch.result.is_none()
319 {
320 self.indent();
321 self.push("if !(");
322 self.emit_expression(condition)?;
323 self.push(") { break; }\n");
324 for inner in &then_branch.statements {
325 self.emit_loop_statement(inner, label_name, parameters)?;
326 }
327 return Ok(());
328 }
329 self.indent();
330 self.push("if (");
331 self.emit_expression(condition)?;
332 self.push(") {\n");
333 self.indent_level += 1;
334 for inner in &then_branch.statements {
335 self.emit_loop_statement(inner, label_name, parameters)?;
336 }
337 self.indent_level -= 1;
338 self.indent();
339 if let Some(else_block) = else_branch {
340 self.push("} else {\n");
341 self.indent_level += 1;
342 for inner in &else_block.statements {
343 self.emit_loop_statement(inner, label_name, parameters)?;
344 }
345 self.indent_level -= 1;
346 self.indent();
347 }
348 self.push("}\n");
349 return Ok(());
350 }
351 self.emit_statement(statement)
352 }
353 _ => self.emit_statement(statement),
354 }
355 }
356
357 fn block_ends_with_jump(block: &Block, label_name: &str) -> bool {
358 block
359 .statements
360 .last()
361 .is_some_and(|s| matches!(s, Statement::Jump { label, .. } if label == label_name))
362 }
363
364 fn emit_statement(&mut self, statement: &Statement) -> Result<(), String> {
365 self.indent();
366 match statement {
367 Statement::Expression(expression) => {
368 self.emit_expression(expression)?;
369 self.push(";\n");
370 }
371 Statement::Let {
372 name,
373 binding,
374 value,
375 ..
376 } => {
377 match binding {
378 Binding::Value | Binding::Reference => self.push("let "),
379 Binding::Variable => self.push("var "),
380 }
381 self.push(name);
382 if let Some(ref ty) = value.resolved_type {
383 self.push(": ");
384 self.emit_type(ty)?;
385 }
386 self.push(" = ");
387 self.emit_expression(value)?;
388 self.push(";\n");
389 }
390 Statement::Assign(target, value) => {
391 self.emit_expression(target)?;
392 self.push(" = ");
393 self.emit_expression(value)?;
394 self.push(";\n");
395 }
396 Statement::Return(Some(value)) => {
397 self.push("return ");
398 self.emit_expression(value)?;
399 self.push(";\n");
400 }
401 Statement::Return(None) => {
402 self.push("return;\n");
403 }
404 Statement::Label { .. } => {}
405 Statement::Jump { label, .. } => {
406 self.push("break; // jump ");
407 self.push(label);
408 self.push("\n");
409 }
410 Statement::MultiReplace {
411 targets, values, ..
412 } => {
413 for (index, (target, value)) in targets.iter().zip(values).enumerate() {
414 if index > 0 {
415 self.indent();
416 }
417 self.emit_expression(target)?;
418 self.push(" = ");
419 self.emit_expression(value)?;
420 self.push(";\n");
421 }
422 }
423 Statement::Defer(_) => {
424 return Err("defer not supported in WGSL".into());
425 }
426 }
427 Ok(())
428 }
429
430 fn emit_expression(&mut self, expression: &Expression) -> Result<(), String> {
431 match &expression.kind {
432 ExpressionKind::Literal(literal) => self.emit_literal(literal),
433 ExpressionKind::Variable(name) => {
434 self.push(name);
435 Ok(())
436 }
437 ExpressionKind::BinaryOperation(operator, left, right) => {
438 self.push("(");
439 self.emit_expression(left)?;
440 self.push(" ");
441 self.emit_binary_operator(operator);
442 self.push(" ");
443 self.emit_expression(right)?;
444 self.push(")");
445 Ok(())
446 }
447 ExpressionKind::UnaryOperation(operator, operand) => {
448 match operator {
449 UnaryOperator::Negate => self.push("-("),
450 UnaryOperator::LogicalNot => self.push("!("),
451 UnaryOperator::BitwiseNot => self.push("~("),
452 }
453 self.emit_expression(operand)?;
454 self.push(")");
455 Ok(())
456 }
457 ExpressionKind::Call(callee, arguments) => {
458 self.emit_expression(callee)?;
459 self.push("(");
460 for (index, argument) in arguments.iter().enumerate() {
461 if index > 0 {
462 self.push(", ");
463 }
464 self.emit_expression(argument)?;
465 }
466 self.push(")");
467 Ok(())
468 }
469 ExpressionKind::Field(object, field) => {
470 self.emit_expression(object)?;
471 self.push(".");
472 self.push(field);
473 Ok(())
474 }
475 ExpressionKind::Index(array, index) => {
476 self.emit_expression(array)?;
477 self.push("[");
478 self.emit_expression(index)?;
479 self.push("]");
480 Ok(())
481 }
482 ExpressionKind::If {
483 condition,
484 then_branch,
485 else_branch,
486 } => {
487 self.push("if (");
488 self.emit_expression(condition)?;
489 self.push(") {\n");
490 self.indent_level += 1;
491 self.emit_block(then_branch)?;
492 self.indent_level -= 1;
493 self.indent();
494 if let Some(else_block) = else_branch {
495 self.push("} else {\n");
496 self.indent_level += 1;
497 self.emit_block(else_block)?;
498 self.indent_level -= 1;
499 self.indent();
500 }
501 self.push("}");
502 Ok(())
503 }
504 ExpressionKind::Convert(operand, target_type)
505 | ExpressionKind::Transmute(operand, target_type) => {
506 self.emit_type(target_type)?;
507 self.push("(");
508 self.emit_expression(operand)?;
509 self.push(")");
510 Ok(())
511 }
512 ExpressionKind::TypeConstruction(name, fields) => {
513 self.push(name);
514 self.push("(");
515 for (index, (_, value)) in fields.iter().enumerate() {
516 if index > 0 {
517 self.push(", ");
518 }
519 self.emit_expression(value)?;
520 }
521 self.push(")");
522 Ok(())
523 }
524 ExpressionKind::ArrayLiteral(elements) => {
525 if let Some(ref ty) = expression.resolved_type {
526 self.emit_type(ty)?;
527 }
528 self.push("(");
529 for (index, element) in elements.iter().enumerate() {
530 if index > 0 {
531 self.push(", ");
532 }
533 self.emit_expression(element)?;
534 }
535 self.push(")");
536 Ok(())
537 }
538 ExpressionKind::Block(block) => {
539 self.push("{\n");
540 self.indent_level += 1;
541 self.emit_block(block)?;
542 self.indent_level -= 1;
543 self.indent();
544 self.push("}");
545 Ok(())
546 }
547 ExpressionKind::Replace(target, value) => {
548 self.emit_expression(target)?;
549 self.push(" = ");
550 self.emit_expression(value)?;
551 Ok(())
552 }
553 ExpressionKind::OpAssign(operator, target, value) => {
554 self.emit_expression(target)?;
555 let symbol = match operator {
556 ArithmeticOperator::Add => " += ",
557 ArithmeticOperator::Subtract => " -= ",
558 ArithmeticOperator::Multiply => " *= ",
559 ArithmeticOperator::Divide => " /= ",
560 ArithmeticOperator::Remainder => " %= ",
561 };
562 self.push(symbol);
563 self.emit_expression(value)?;
564 Ok(())
565 }
566 ExpressionKind::Dereference(inner) => self.emit_expression(inner),
567 _ => Err(format!(
568 "unsupported WGSL expression: {:?}",
569 expression.kind
570 )),
571 }
572 }
573
574 fn emit_literal(&mut self, literal: &Literal) -> Result<(), String> {
575 match literal {
576 Literal::Integer(value) => write!(self.output, "{value}").unwrap(),
577 Literal::Float(value) => {
578 let string = format!("{value}");
579 if string.contains('.') {
580 self.push(&string);
581 } else {
582 write!(self.output, "{value}.0").unwrap();
583 }
584 }
585 Literal::Bool(value) => write!(self.output, "{value}").unwrap(),
586 Literal::String(_) => return Err("strings not supported in WGSL".into()),
587 Literal::Null => return Err("null pointers not supported in WGSL".into()),
588 }
589 Ok(())
590 }
591
592 fn emit_binary_operator(&mut self, operator: &BinaryOperator) {
593 let symbol = match operator {
594 BinaryOperator::Arithmetic(ArithmeticOperator::Add) => "+",
595 BinaryOperator::Arithmetic(ArithmeticOperator::Subtract) => "-",
596 BinaryOperator::Arithmetic(ArithmeticOperator::Multiply) => "*",
597 BinaryOperator::Arithmetic(ArithmeticOperator::Divide) => "/",
598 BinaryOperator::Arithmetic(ArithmeticOperator::Remainder) => "%",
599 BinaryOperator::Comparison(ComparisonOperator::Equal) => "==",
600 BinaryOperator::Comparison(ComparisonOperator::NotEqual) => "!=",
601 BinaryOperator::Comparison(ComparisonOperator::Less) => "<",
602 BinaryOperator::Comparison(ComparisonOperator::LessEqual) => "<=",
603 BinaryOperator::Comparison(ComparisonOperator::Greater) => ">",
604 BinaryOperator::Comparison(ComparisonOperator::GreaterEqual) => ">=",
605 BinaryOperator::Logical(LogicalOperator::And) => "&&",
606 BinaryOperator::Logical(LogicalOperator::Or) => "||",
607 BinaryOperator::Bitwise(BitwiseOperator::And) => "&",
608 BinaryOperator::Bitwise(BitwiseOperator::Or) => "|",
609 BinaryOperator::Bitwise(BitwiseOperator::Xor) => "^",
610 BinaryOperator::Bitwise(BitwiseOperator::ShiftLeft) => "<<",
611 BinaryOperator::Bitwise(BitwiseOperator::ShiftRight) => ">>",
612 };
613 self.push(symbol);
614 }
615}
616
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `source` with the C front end, resolves it, and compiles it to WGSL.
    fn compile_c_to_wgsl(source: &str) -> Result<String, String> {
        let mut module = match t_parser_c::parse(source) {
            Ok(module) => module,
            Err(error) => return Err(error.to_string()),
        };
        let errors = t_ree::resolve::resolve_module(&mut module, true);
        if errors.is_empty() {
            compile(&module)
        } else {
            Err(errors.join("\n"))
        }
    }

    /// Compiles `source` and checks that naga can parse the generated WGSL.
    fn validate_wgsl(source: &str) -> Result<(), String> {
        let wgsl = compile_c_to_wgsl(source)?;
        match naga::front::wgsl::parse_str(&wgsl) {
            Ok(_) => Ok(()),
            Err(error) => Err(format!(
                "WGSL validation failed:\n{error}\n\nGenerated WGSL:\n{wgsl}"
            )),
        }
    }

    /// Panics with the diagnostic from `validate_wgsl` unless `source`
    /// compiles to WGSL that naga can parse.
    fn assert_valid_wgsl(source: &str) {
        if let Err(message) = validate_wgsl(source) {
            panic!("{message}");
        }
    }

    /// Asserts that compiling `source` fails.
    fn assert_rejected(source: &str) {
        let result = compile_c_to_wgsl(source);
        assert!(result.is_err(), "expected compilation to fail, got {result:?}");
    }

    #[test]
    fn simple_function() {
        assert_valid_wgsl(
            "type a = f32; type b = f32; fn add(a, b) -> f32 { (f32: a) + (f32: b) }",
        );
    }

    #[test]
    fn newtype_alias() {
        assert_valid_wgsl("type meters = f32; fn test(meters) -> f32 { (f32: meters) }");
    }

    #[test]
    fn struct_type() {
        assert_valid_wgsl(
            "type x = f32; type y = f32; type z = f32; type point3 = x & y & z;\n\
             fn test() -> f32 { let v = point3 { x: 1.0, y: 2.0, z: 3.0 }; v.x }",
        );
    }

    #[test]
    fn constant() {
        assert_valid_wgsl("pub const PI: f32 = 3.14159; fn test() -> f32 { PI }");
    }

    #[test]
    fn if_else() {
        assert_valid_wgsl(
            "fn test(x: f32) -> f32 { if x > 0.0 { return x; } else { return 0.0 - x; } 0.0 }",
        );
    }

    #[test]
    fn variables() {
        assert_valid_wgsl(
            "type a = f32; type b = f32;\n\
             fn test(a, b) -> f32 {\n\
             var result: f32 = (f32: a);\n\
             result := result + (f32: b);\n\
             return result;\n\
             }",
        );
    }

    #[test]
    fn array_type() {
        assert_valid_wgsl(
            "fn sum(arr: [f32]3) -> f32 {\n\
             return arr[0] + arr[1] + arr[2];\n\
             }",
        );
    }

    #[test]
    fn arithmetic() {
        assert_valid_wgsl("type a = f32; type b = f32; fn test(a, b) -> f32 { ((f32: a) + (f32: b)) * ((f32: a) - (f32: b)) / ((f32: a) + 1.0) }");
    }

    #[test]
    fn comparison() {
        assert_valid_wgsl(
            "type a = f32; type b = f32; fn test(a, b) -> bool { (f32: a) < (f32: b) }",
        );
    }

    #[test]
    fn integer_ops() {
        assert_valid_wgsl(
            "type a = i32; type b = i32; fn test(a, b) -> i32 { ((i32: a) + (i32: b)) * 2 }",
        );
    }

    #[test]
    fn sdf_functions() {
        assert_valid_wgsl(
            "extern fn sqrt(x: f32) -> f32;\n\
             extern fn abs(x: f32) -> f32;\n\
             extern fn min(a: f32, b: f32) -> f32;\n\
             extern fn max(a: f32, b: f32) -> f32;\n\
             extern fn clamp(x: f32, lo: f32, hi: f32) -> f32;\n\
             extern fn mix(a: f32, b: f32, t: f32) -> f32;\n\
             type px = f32; type py = f32; type pz = f32;\n\
             type point = px & py & pz;\n\
             type radius = f32;\n\
             fn sphere_sdf(point, radius) -> f32 {\n\
             let d = sqrt(point.px * point.px + point.py * point.py + point.pz * point.pz);\n\
             return d - (f32: radius);\n\
             }\n\
             type distance_a = f32; type distance_b = f32; type smoothness = f32;\n\
             fn smooth_union(distance_a, distance_b, smoothness) -> f32 {\n\
             let h = clamp(0.5 + 0.5 * ((f32: distance_b) - (f32: distance_a)) / (f32: smoothness), 0.0, 1.0);\n\
             return mix((f32: distance_b), (f32: distance_a), h) - (f32: smoothness) * h * (1.0 - h);\n\
             }",
        );
    }

    #[test]
    fn label_jump_loop() {
        assert_valid_wgsl(
            "fn sum_to(n: i32) -> i32 {\n\
             label loop(i: i32 = 0, total: i32 = 0);\n\
             if i < n {\n\
             jump loop(i + 1, total + i);\n\
             }\n\
             return total;\n\
             }",
        );
    }

    #[test]
    fn factorial_loop() {
        assert_valid_wgsl(
            "fn factorial(n: i32) -> i32 {\n\
             label loop(i: i32 = 1, result: i32 = 1);\n\
             if i <= n {\n\
             jump loop(i + 1, result * i);\n\
             }\n\
             return result;\n\
             }",
        );
    }

    #[test]
    fn f64_rejected() {
        assert_rejected("fn test(x: f64) -> f64 { x }");
    }

    #[test]
    fn pointer_rejected() {
        assert_rejected("fn test(x: |i32) -> i32 { x }");
    }
}