wgsl_includes/
lib.rs

1use std::path::{Path, PathBuf};
2
3use naga::{
4    front::wgsl,
5    valid::{Capabilities, ValidationFlags, Validator},
6};
7use proc_macro::TokenStream;
8use syn::{parse_macro_input, LitStr};
9
/// Marker that introduces an include directive: `@include <path>`.
const INCLUDE_DIRECTIVE: &str = "@include";

/// Resolves `shader_path` against the root of the crate invoking the macro.
///
/// The crate root is read from `CARGO_MANIFEST_DIR`, which Cargo sets for the crate being
/// compiled. The process working directory is used instead when that variable is unset. It is
/// also used when the file exists only relative to the working directory, since Cargo may run
/// `rustc` from a workspace root rather than the crate directory. Absolute paths are returned
/// unchanged (`Path::join` with an absolute path yields that path).
fn resolve_shader_path<T: Into<PathBuf>>(shader_path: T) -> PathBuf {
    let shader_path = shader_path.into();
    let from_cwd = || match std::env::current_dir() {
        Ok(dir) => dir.join(&shader_path),
        Err(_) => shader_path.clone(),
    };

    let Some(manifest_dir) = std::env::var_os("CARGO_MANIFEST_DIR") else {
        return from_cwd();
    };
    let from_manifest = Path::new(&manifest_dir).join(&shader_path);
    if from_manifest.exists() {
        return from_manifest;
    }
    // Keep working for callers whose paths were written relative to the working directory.
    let legacy = from_cwd();
    if legacy.exists() {
        legacy
    } else {
        from_manifest
    }
}

/// Loads the shader `name` (resolved with `resolve_shader_path`) and expands its own
/// `@include` directives.
///
/// `includes` holds every file already pulled in. A file found there expands to an empty
/// string, so each file is included at most once and include cycles terminate.
///
/// # Panics
///
/// Panics if the file cannot be read or contains an `@include` with no path. Inside a proc
/// macro the compiler reports such a panic as an error at the macro invocation.
fn parse_shader_includes_recursive(name: &str, includes: &mut Vec<String>) -> String {
    let resolved = resolve_shader_path(name);
    // The string form is both the de-duplication key and the path shown in error messages.
    let file_path = resolved.to_string_lossy().into_owned();

    if includes.contains(&file_path) {
        return String::new();
    }
    includes.push(file_path.clone());

    let mut contents = std::fs::read_to_string(&resolved).unwrap_or_else(|_| {
        panic!(
            "Failed to include shader \"{}\".",
            file_path.replace('\\', "/")
        )
    });
    expand_include_directives(&mut contents, includes);
    contents
}

/// Expands all `@include <path>` directives in the top-level shader `contents`, then rewrites
/// every `::` in the result to `_`.
fn parse_shader_includes(mut contents: String) -> String {
    let mut includes = Vec::new();
    expand_include_directives(&mut contents, &mut includes);
    contents.replace("::", "_")
}

/// Replaces each `@include <path>` directive in `contents` with the expanded source of `<path>`.
///
/// The path is the rest of the directive's line with surrounding whitespace trimmed. This lets
/// both `\n` and `\r\n` endings work, and a directive on the final line needs no trailing
/// newline. Everything after the path, including the line terminator, is left in place.
fn expand_include_directives(contents: &mut String, includes: &mut Vec<String>) {
    let directive_starts: Vec<usize> = contents
        .match_indices(INCLUDE_DIRECTIVE)
        .map(|(start, _)| start)
        .collect();

    // Splice back-to-front: rewriting a later directive never moves the byte offset of an
    // earlier one, so the offsets collected above stay valid.
    for start in directive_starts.into_iter().rev() {
        let line_end = contents[start..]
            .find('\n')
            .map_or(contents.len(), |offset| start + offset);
        // Stop before trailing whitespace (notably the `\r` of CRLF) so it is preserved.
        let directive_end = start + contents[start..line_end].trim_end().len();
        let include_name = contents[start + INCLUDE_DIRECTIVE.len()..directive_end]
            .trim()
            .to_owned();
        assert!(
            !include_name.is_empty(),
            "`{}` directive is missing a shader path.",
            INCLUDE_DIRECTIVE
        );

        let expanded = parse_shader_includes_recursive(&include_name, includes);
        contents.replace_range(start..directive_end, &expanded);
    }
}
71
72/// Include shader source code & validate at compile time. Path must be relative to the crate root.
73///
74/// # Example
75///
76/// ```
77/// let shader_str = include_wgsl!("src/shader.wgsl");
78/// device.create_shader_module(&ShaderModuleDescriptor {
79///     source: ShaderSource::Wgsl(Cow::Borrowed(&shader_str)),
80///     flags: ShaderFlags::default(),
81///     label: None,
82/// })
83/// ```
84#[proc_macro]
85pub fn include_wgsl(input: TokenStream) -> TokenStream {
86    let input = parse_macro_input!(input as LitStr);
87    let file_path = input.value();
88
89    let resolved_file_path = resolve_shader_path(&file_path);
90
91    match std::fs::read_to_string(resolved_file_path.clone()) {
92        Ok(wgsl_str) => {
93            // Resolve shader includes
94            let wgsl_str = parse_shader_includes(wgsl_str);
95
96            // Attempt to parse WGSL
97            match wgsl::parse_str(&wgsl_str) {
98                Ok(module) => {
99                    // Attempt to validate WGSL
100                    match Validator::new(ValidationFlags::all(), Capabilities::all())
101                        .validate(&module)
102                    {
103                        Ok(_) => {}
104                        Err(e) => {
105                            return syn::Error::new(input.span(), format!("{}: {}", file_path, e))
106                                .to_compile_error()
107                                .into();
108                        }
109                    }
110                }
111                Err(e) => {
112                    return syn::Error::new(input.span(), format!("{}: {}", file_path, e))
113                        .to_compile_error()
114                        .into();
115                }
116            }
117
118            let resolved_file_path = resolved_file_path.to_str().unwrap().replace("\\", "/");
119
120            format!(
121                "{{ std::hint::black_box(include_str!(\"{}\")); r#\"{}\"# }}",
122                resolved_file_path, wgsl_str
123            )
124            .parse()
125            .unwrap()
126        }
127        Err(e) => syn::Error::new(input.span(), format!("{}: {}", file_path, e))
128            .to_compile_error()
129            .into(),
130    }
131}