wgpu_pp/
lib.rs

1#![feature(proc_macro_span)]
2
3extern crate proc_macro;
4
5mod preprocessor;
6
7use naga::front::wgsl::Frontend;
8use naga::valid::{Capabilities, ValidationFlags, Validator};
9
10use litrs::Literal;
11use preprocessor::{preprocess, PreprocessorError};
12use proc_macro::{Span, TokenStream, TokenTree};
13
14fn validate_wgsl(wgsl_source: &str) -> Result<(), TokenStream> {
15    let mut frontend = Frontend::new();
16    let module = frontend.parse(wgsl_source).map_err(|e| {
17        let msg = format!("failed to parse WGSL: {}", e.emit_to_string(wgsl_source));
18        format!("compile_error!(\"{}\")", msg)
19            .parse::<TokenStream>()
20            .unwrap()
21    })?;
22
23    let mut validator = Validator::new(ValidationFlags::all(), Capabilities::default());
24    validator.validate(&module).map_err(|e| {
25        let msg = format!("failed to validate WGSL: {}", e.emit_to_string(wgsl_source));
26        format!("compile_error!(\"{}\")", msg)
27            .parse::<TokenStream>()
28            .unwrap()
29    })?;
30
31    Ok(())
32}
33
34#[proc_macro]
35pub fn include_wgsl(input: TokenStream) -> TokenStream {
36    let input = input.into_iter().collect::<Vec<_>>();
37    if input.len() != 1 {
38        let msg = format!("expected exactly one input token, got {}", input.len());
39        return format!("compile_error!(\"{}\")", msg).parse().unwrap();
40    }
41
42    let call_site = Span::call_site();
43    let source_path = match call_site.local_file() {
44        Some(p) => p,
45        None => {
46            // This happens in the Rust Analyzer, just let it go...
47            return "\"\"".parse().unwrap();
48        }
49    };
50    let basepath = match source_path.parent() {
51        Some(p) => p,
52        _ => {
53            // This happens in the Rust Analyzer, just let it go...
54            return "\"\"".parse().unwrap();
55        }
56    };
57
58    let filename = match Literal::try_from(&input[0]) {
59        Ok(Literal::String(str)) => str.value().to_string(),
60        // Error if the token is not a string literal
61        Err(e) => return e.to_compile_error(),
62        _ => {
63            return "compile_error!(\"expected a string literal\")"
64                .to_string()
65                .parse()
66                .unwrap();
67        }
68    };
69
70    let shader = match preprocess(&filename, basepath) {
71        Ok(value) => value,
72        Err(e) => match e {
73            PreprocessorError::FileNotFound(filename) => {
74                let msg = format!(
75                    "file not found: {}",
76                    basepath.join(filename).to_string_lossy()
77                );
78                return format!("compile_error!(\"{}\")", msg).parse().unwrap();
79            }
80            PreprocessorError::FileNotValidUtf8(filename) => {
81                let msg = format!("file not valid utf-8: {}", filename);
82                return format!("compile_error!(\"{}\")", msg).parse().unwrap();
83            }
84            PreprocessorError::UnknownDirective(directive) => {
85                let msg = format!("unknown directive: {}", directive);
86                return format!("compile_error!(\"{}\")", msg).parse().unwrap();
87            }
88            PreprocessorError::IncludeIncorrectArgs => {
89                return "compile_error!(\"incorrect arguments to #include\")"
90                    .to_string()
91                    .parse()
92                    .unwrap();
93            }
94            PreprocessorError::MacroNoParenthesis => {
95                return "compile_error!(\"macro must have parenthesis\")"
96                    .to_string()
97                    .parse()
98                    .unwrap();
99            }
100            PreprocessorError::MacroIncorrectArgs(expected, got) => {
101                let msg = format!("macro expected {} arguments, got {}", expected, got);
102                return format!("compile_error!(\"{}\")", msg).parse().unwrap();
103            }
104        },
105    };
106
107    match validate_wgsl(&shader) {
108        Ok(_) => {}
109        Err(e) => return e,
110    }
111
112    TokenTree::Literal(proc_macro::Literal::string(&shader)).into()
113}