1use std::path::{Path, PathBuf};
2
3use naga::{
4 front::wgsl,
5 valid::{Capabilities, ValidationFlags, Validator},
6};
7use proc_macro::TokenStream;
8use syn::{parse_macro_input, LitStr};
9
/// Resolves `shader_path` relative to the process's current working directory.
///
/// An absolute `shader_path` is returned unchanged, because [`Path::join`] replaces the base
/// when its argument is absolute.
///
/// # Panics
///
/// Panics if the current working directory cannot be determined.
// NOTE(review): for a proc macro the working directory is presumably the root of the crate being
// compiled — confirm; `CARGO_MANIFEST_DIR` may be a more explicit anchor.
fn resolve_shader_path<T: Into<PathBuf>>(shader_path: T) -> PathBuf {
    let base = std::env::current_dir().expect("failed to determine the current working directory");
    base.join(shader_path.into())
}

/// Replaces every `@include <path>` directive in `contents`, in place, with the recursively
/// expanded contents of the referenced shader.
///
/// The path is everything after `@include` up to the end of the line (or the end of the input),
/// trimmed of surrounding whitespace, so LF and CRLF sources both work. The line terminator
/// itself is kept. Directives are matched anywhere in the text, including inside comments.
/// `includes` is the set of files already spliced in during the current expansion.
///
/// # Panics
///
/// Panics if a directive has no path, or if a referenced file cannot be read.
fn expand_include_directives(contents: &mut String, includes: &mut Vec<String>) {
    const DIRECTIVE: &str = "@include";

    // Collect positions up front and process them back-to-front, so that splicing text in at a
    // later directive never shifts the byte offsets of the earlier ones.
    let starts: Vec<usize> = contents.match_indices(DIRECTIVE).map(|(i, _)| i).collect();
    for start in starts.into_iter().rev() {
        let line_end = contents[start..]
            .find('\n')
            .map_or(contents.len(), |offset| start + offset);
        // For CRLF, stop before the '\r' so the full "\r\n" terminator stays in the output.
        let directive_end = if contents[..line_end].ends_with('\r') {
            line_end - 1
        } else {
            line_end
        };

        let include_name = contents[start + DIRECTIVE.len()..directive_end]
            .trim()
            .to_owned();
        if include_name.is_empty() {
            panic!("`{}` directive is missing a shader path.", DIRECTIVE);
        }

        let included = parse_shader_includes_recursive(&include_name, includes);
        // `replace_range` splices in one pass; all bounds are char boundaries (ASCII positions).
        contents.replace_range(start..directive_end, &included);
    }
}

/// Loads the shader `name` (resolved with `resolve_shader_path`) and recursively expands its own
/// `@include` directives.
///
/// Each resolved path is recorded in `includes`. If a file is already listed, an empty string is
/// returned instead, so each file is spliced in at most once and include cycles terminate.
///
/// # Panics
///
/// Panics if the file cannot be read.
fn parse_shader_includes_recursive(name: &str, includes: &mut Vec<String>) -> String {
    let resolved = resolve_shader_path(name);
    // Lossy conversion: this string only serves as the de-duplication key, the file itself is
    // read through `resolved`, so non-UTF-8 paths no longer panic here.
    let file_path = resolved.to_string_lossy().into_owned();

    if includes.contains(&file_path) {
        return String::new();
    }
    includes.push(file_path);

    let mut contents = std::fs::read_to_string(&resolved).unwrap_or_else(|err| {
        panic!(
            "Failed to include shader \"{}\": {}.",
            resolved.to_string_lossy().replace('\\', "/"),
            err
        )
    });
    expand_include_directives(&mut contents, includes);
    contents
}

/// Expands all `@include` directives in a top-level shader source, then rewrites every `::` as
/// `_`.
///
/// The include set starts empty on each call. Within one call, a file referenced several times
/// is spliced in only once, at its *last* occurrence, because directives are processed
/// back-to-front.
// NOTE(review): the `::` -> `_` rewrite applies to the whole text (comments included);
// presumably it maps `module::item` style names onto valid WGSL identifiers — confirm.
fn parse_shader_includes(mut contents: String) -> String {
    let mut includes = Vec::new();
    expand_include_directives(&mut contents, &mut includes);
    contents.replace("::", "_")
}
71
72#[proc_macro]
85pub fn include_wgsl(input: TokenStream) -> TokenStream {
86 let input = parse_macro_input!(input as LitStr);
87 let file_path = input.value();
88
89 let resolved_file_path = resolve_shader_path(&file_path);
90
91 match std::fs::read_to_string(resolved_file_path.clone()) {
92 Ok(wgsl_str) => {
93 let wgsl_str = parse_shader_includes(wgsl_str);
95
96 match wgsl::parse_str(&wgsl_str) {
98 Ok(module) => {
99 match Validator::new(ValidationFlags::all(), Capabilities::all())
101 .validate(&module)
102 {
103 Ok(_) => {}
104 Err(e) => {
105 return syn::Error::new(input.span(), format!("{}: {}", file_path, e))
106 .to_compile_error()
107 .into();
108 }
109 }
110 }
111 Err(e) => {
112 return syn::Error::new(input.span(), format!("{}: {}", file_path, e))
113 .to_compile_error()
114 .into();
115 }
116 }
117
118 let resolved_file_path = resolved_file_path.to_str().unwrap().replace("\\", "/");
119
120 format!(
121 "{{ std::hint::black_box(include_str!(\"{}\")); r#\"{}\"# }}",
122 resolved_file_path, wgsl_str
123 )
124 .parse()
125 .unwrap()
126 }
127 Err(e) => syn::Error::new(input.span(), format!("{}: {}", file_path, e))
128 .to_compile_error()
129 .into(),
130 }
131}