wit_bindgen_rust_macro/
lib.rs

1use proc_macro2::{Span, TokenStream};
2use quote::ToTokens;
3use std::collections::HashMap;
4use std::path::{Path, PathBuf};
5use std::sync::atomic::{AtomicUsize, Ordering::Relaxed};
6use syn::parse::{Error, Parse, ParseStream, Result};
7use syn::punctuated::Punctuated;
8use syn::spanned::Spanned;
9use syn::{LitStr, Token, braced, token};
10use wit_bindgen_core::AsyncFilterSet;
11use wit_bindgen_core::WorldGenerator;
12use wit_bindgen_core::wit_parser::{PackageId, Resolve, UnresolvedPackageGroup, WorldId};
13use wit_bindgen_rust::{Opts, Ownership, WithOption};
14
15#[proc_macro]
16pub fn generate(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
17    syn::parse_macro_input!(input as Config)
18        .expand()
19        .unwrap_or_else(Error::into_compile_error)
20        .into()
21}
22
23fn anyhow_to_syn(span: Span, err: anyhow::Error) -> Error {
24    let err = attach_with_context(err);
25    let mut msg = err.to_string();
26    for cause in err.chain().skip(1) {
27        msg.push_str(&format!("\n\nCaused by:\n  {cause}"));
28    }
29    Error::new(span, msg)
30}
31
32fn attach_with_context(err: anyhow::Error) -> anyhow::Error {
33    if let Some(e) = err.downcast_ref::<wit_bindgen_rust::MissingWith>() {
34        let option = e.0.clone();
35        return err.context(format!(
36            "missing one of:\n\
37            * `generate_all` option\n\
38            * `with: {{ \"{option}\": path::to::bindings, }}`\n\
39            * `with: {{ \"{option}\": generate, }}`\
40            "
41        ));
42    }
43    err
44}
45
46struct Config {
47    opts: Opts,
48    resolve: Resolve,
49    world: WorldId,
50    files: Vec<PathBuf>,
51    debug: bool,
52}
53
/// Where the WIT package definitions for the macro come from.
enum Source {
    /// WIT text given through `inline:`. The optional `path:` entries point at
    /// the dependencies of that inline document.
    Inline(String, Option<Vec<PathBuf>>),
    /// Paths given through `path:`, resolved relative to `CARGO_MANIFEST_DIR`.
    Paths(Vec<PathBuf>),
}
61
62impl Parse for Config {
63    fn parse(input: ParseStream<'_>) -> Result<Self> {
64        let call_site = Span::call_site();
65        let mut opts = Opts::default();
66        let mut world = None;
67        let mut source = None;
68        let mut features = Vec::new();
69        let mut async_configured = false;
70        let mut debug = false;
71
72        if input.peek(token::Brace) {
73            let content;
74            syn::braced!(content in input);
75            let fields = Punctuated::<Opt, Token![,]>::parse_terminated(&content)?;
76            for field in fields.into_pairs() {
77                match field.into_value() {
78                    Opt::Path(span, p) => {
79                        let paths = p.into_iter().map(|f| PathBuf::from(f.value())).collect();
80
81                        source = Some(match source {
82                            Some(Source::Paths(_)) | Some(Source::Inline(_, Some(_))) => {
83                                return Err(Error::new(span, "cannot specify second source"));
84                            }
85                            Some(Source::Inline(i, None)) => Source::Inline(i, Some(paths)),
86                            None => Source::Paths(paths),
87                        })
88                    }
89                    Opt::World(s) => {
90                        if world.is_some() {
91                            return Err(Error::new(s.span(), "cannot specify second world"));
92                        }
93                        world = Some(s.value());
94                    }
95                    Opt::Inline(s) => {
96                        source = Some(match source {
97                            Some(Source::Inline(_, _)) => {
98                                return Err(Error::new(s.span(), "cannot specify second source"));
99                            }
100                            Some(Source::Paths(p)) => Source::Inline(s.value(), Some(p)),
101                            None => Source::Inline(s.value(), None),
102                        })
103                    }
104                    Opt::UseStdFeature => opts.std_feature = true,
105                    Opt::RawStrings => opts.raw_strings = true,
106                    Opt::Ownership(ownership) => opts.ownership = ownership,
107                    Opt::Skip(list) => opts.skip.extend(list.iter().map(|i| i.value())),
108                    Opt::RuntimePath(path) => opts.runtime_path = Some(path.value()),
109                    Opt::BitflagsPath(path) => opts.bitflags_path = Some(path.value()),
110                    Opt::Stubs => {
111                        opts.stubs = true;
112                    }
113                    Opt::ExportPrefix(prefix) => opts.export_prefix = Some(prefix.value()),
114                    Opt::AdditionalDerives(paths) => {
115                        opts.additional_derive_attributes = paths
116                            .into_iter()
117                            .map(|p| p.into_token_stream().to_string())
118                            .collect()
119                    }
120                    Opt::AdditionalDerivesIgnore(list) => {
121                        opts.additional_derive_ignore =
122                            list.into_iter().map(|i| i.value()).collect()
123                    }
124                    Opt::With(with) => opts.with.extend(with),
125                    Opt::GenerateAll => {
126                        opts.generate_all = true;
127                    }
128                    Opt::TypeSectionSuffix(suffix) => {
129                        opts.type_section_suffix = Some(suffix.value());
130                    }
131                    Opt::DisableRunCtorsOnceWorkaround(enable) => {
132                        opts.disable_run_ctors_once_workaround = enable.value();
133                    }
134                    Opt::DefaultBindingsModule(enable) => {
135                        opts.default_bindings_module = Some(enable.value());
136                    }
137                    Opt::ExportMacroName(name) => {
138                        opts.export_macro_name = Some(name.value());
139                    }
140                    Opt::PubExportMacro(enable) => {
141                        opts.pub_export_macro = enable.value();
142                    }
143                    Opt::GenerateUnusedTypes(enable) => {
144                        opts.generate_unused_types = enable.value();
145                    }
146                    Opt::Features(f) => {
147                        features.extend(f.into_iter().map(|f| f.value()));
148                    }
149                    Opt::DisableCustomSectionLinkHelpers(disable) => {
150                        opts.disable_custom_section_link_helpers = disable.value();
151                    }
152                    Opt::Debug(enable) => {
153                        debug = enable.value();
154                    }
155                    Opt::Async(val, span) => {
156                        if async_configured {
157                            return Err(Error::new(span, "cannot specify second async config"));
158                        }
159                        async_configured = true;
160                        if val.any_enabled() && !cfg!(feature = "async") {
161                            return Err(Error::new(
162                                span,
163                                "must enable `async` feature to enable async imports and/or exports",
164                            ));
165                        }
166                        opts.async_ = val;
167                    }
168                }
169            }
170        } else {
171            world = input.parse::<Option<syn::LitStr>>()?.map(|s| s.value());
172            if input.parse::<Option<syn::token::In>>()?.is_some() {
173                source = Some(Source::Paths(vec![PathBuf::from(
174                    input.parse::<syn::LitStr>()?.value(),
175                )]));
176            }
177        }
178        let (resolve, main_packages, files) =
179            parse_source(&source, &features).map_err(|err| anyhow_to_syn(call_site, err))?;
180        let world = resolve
181            .select_world(&main_packages, world.as_deref())
182            .map_err(|e| anyhow_to_syn(call_site, e))?;
183        Ok(Config {
184            opts,
185            resolve,
186            world,
187            files,
188            debug,
189        })
190    }
191}
192
193/// Parse the source
194fn parse_source(
195    source: &Option<Source>,
196    features: &[String],
197) -> anyhow::Result<(Resolve, Vec<PackageId>, Vec<PathBuf>)> {
198    let mut resolve = Resolve::default();
199    resolve.features.extend(features.iter().cloned());
200    let mut files = Vec::new();
201    let mut pkgs = Vec::new();
202    let root = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap());
203    let mut parse = |paths: &[PathBuf]| -> anyhow::Result<()> {
204        for path in paths {
205            let p = root.join(path);
206            // Try to normalize the path to make the error message more understandable when
207            // the path is not correct. Fallback to the original path if normalization fails
208            // (probably return an error somewhere else).
209            let normalized_path = match std::fs::canonicalize(&p) {
210                Ok(p) => p,
211                Err(_) => p.to_path_buf(),
212            };
213            let (pkg, sources) = resolve.push_path(normalized_path)?;
214            pkgs.push(pkg);
215            files.extend(sources.paths().map(|p| p.to_owned()));
216        }
217        Ok(())
218    };
219    let default = root.join("wit");
220    match source {
221        Some(Source::Inline(s, path)) => {
222            match path {
223                Some(p) => parse(p)?,
224                // If no `path` is explicitly specified still parse the default
225                // `wit` directory if it exists. Don't require its existence,
226                // however, as `inline` can be used in lieu of a folder. Test
227                // whether it exists and only if there is it parsed.
228                None => {
229                    if default.exists() {
230                        parse(&[default])?;
231                    }
232                }
233            }
234            pkgs.truncate(0);
235            pkgs.push(resolve.push_group(UnresolvedPackageGroup::parse("macro-input", s)?)?);
236        }
237        Some(Source::Paths(p)) => parse(p)?,
238        None => parse(&[default])?,
239    };
240
241    Ok((resolve, pkgs, files))
242}
243
244impl Config {
245    fn expand(self) -> Result<TokenStream> {
246        let mut files = Default::default();
247        let mut generator = self.opts.build();
248        generator
249            .generate(&self.resolve, self.world, &mut files)
250            .map_err(|e| anyhow_to_syn(Span::call_site(), e))?;
251        let (_, src) = files.iter().next().unwrap();
252        let mut src = std::str::from_utf8(src).unwrap().to_string();
253
254        // If a magical `WIT_BINDGEN_DEBUG` environment variable is set then
255        // place a formatted version of the expanded code into a file. This file
256        // will then show up in rustc error messages for any codegen issues and can
257        // be inspected manually.
258        if std::env::var("WIT_BINDGEN_DEBUG").is_ok() || self.debug {
259            static INVOCATION: AtomicUsize = AtomicUsize::new(0);
260            let root = Path::new(env!("DEBUG_OUTPUT_DIR"));
261            let world_name = &self.resolve.worlds[self.world].name;
262            let n = INVOCATION.fetch_add(1, Relaxed);
263            let path = root.join(format!("{world_name}{n}.rs"));
264
265            // optimistically format the code but don't require success
266            let contents = match fmt(&src) {
267                Ok(formatted) => formatted,
268                Err(_) => src.clone(),
269            };
270            std::fs::write(&path, contents.as_bytes()).unwrap();
271
272            src = format!("include!({path:?});");
273        }
274        let mut contents = src.parse::<TokenStream>().unwrap();
275
276        // Include a dummy `include_bytes!` for any files we read so rustc knows that
277        // we depend on the contents of those files.
278        for file in self.files.iter() {
279            contents.extend(
280                format!(
281                    "const _: &[u8] = include_bytes!(r#\"{}\"#);\n",
282                    file.display()
283                )
284                .parse::<TokenStream>()
285                .unwrap(),
286            );
287        }
288
289        Ok(contents)
290    }
291}
292
293mod kw {
294    syn::custom_keyword!(std_feature);
295    syn::custom_keyword!(raw_strings);
296    syn::custom_keyword!(skip);
297    syn::custom_keyword!(world);
298    syn::custom_keyword!(path);
299    syn::custom_keyword!(inline);
300    syn::custom_keyword!(ownership);
301    syn::custom_keyword!(runtime_path);
302    syn::custom_keyword!(bitflags_path);
303    syn::custom_keyword!(exports);
304    syn::custom_keyword!(stubs);
305    syn::custom_keyword!(export_prefix);
306    syn::custom_keyword!(additional_derives);
307    syn::custom_keyword!(additional_derives_ignore);
308    syn::custom_keyword!(with);
309    syn::custom_keyword!(generate_all);
310    syn::custom_keyword!(type_section_suffix);
311    syn::custom_keyword!(disable_run_ctors_once_workaround);
312    syn::custom_keyword!(default_bindings_module);
313    syn::custom_keyword!(export_macro_name);
314    syn::custom_keyword!(pub_export_macro);
315    syn::custom_keyword!(generate_unused_types);
316    syn::custom_keyword!(features);
317    syn::custom_keyword!(disable_custom_section_link_helpers);
318    syn::custom_keyword!(imports);
319    syn::custom_keyword!(debug);
320}
321
322#[derive(Clone)]
323enum ExportKey {
324    World,
325    Name(syn::LitStr),
326}
327
328impl Parse for ExportKey {
329    fn parse(input: ParseStream<'_>) -> Result<Self> {
330        let l = input.lookahead1();
331        Ok(if l.peek(kw::world) {
332            input.parse::<kw::world>()?;
333            Self::World
334        } else {
335            Self::Name(input.parse()?)
336        })
337    }
338}
339
340impl From<ExportKey> for wit_bindgen_rust::ExportKey {
341    fn from(key: ExportKey) -> Self {
342        match key {
343            ExportKey::World => Self::World,
344            ExportKey::Name(s) => Self::Name(s.value()),
345        }
346    }
347}
348
349enum Opt {
350    World(syn::LitStr),
351    Path(Span, Vec<syn::LitStr>),
352    Inline(syn::LitStr),
353    UseStdFeature,
354    RawStrings,
355    Skip(Vec<syn::LitStr>),
356    Ownership(Ownership),
357    RuntimePath(syn::LitStr),
358    BitflagsPath(syn::LitStr),
359    Stubs,
360    ExportPrefix(syn::LitStr),
361    // Parse as paths so we can take the concrete types/macro names rather than raw strings
362    AdditionalDerives(Vec<syn::Path>),
363    AdditionalDerivesIgnore(Vec<syn::LitStr>),
364    With(HashMap<String, WithOption>),
365    GenerateAll,
366    TypeSectionSuffix(syn::LitStr),
367    DisableRunCtorsOnceWorkaround(syn::LitBool),
368    DefaultBindingsModule(syn::LitStr),
369    ExportMacroName(syn::LitStr),
370    PubExportMacro(syn::LitBool),
371    GenerateUnusedTypes(syn::LitBool),
372    Features(Vec<syn::LitStr>),
373    DisableCustomSectionLinkHelpers(syn::LitBool),
374    Async(AsyncFilterSet, Span),
375    Debug(syn::LitBool),
376}
377
378impl Parse for Opt {
379    fn parse(input: ParseStream<'_>) -> Result<Self> {
380        let l = input.lookahead1();
381        if l.peek(kw::path) {
382            input.parse::<kw::path>()?;
383            input.parse::<Token![:]>()?;
384            // the `path` supports two forms:
385            // * path: "xxx"
386            // * path: ["aaa", "bbb"]
387            if input.peek(token::Bracket) {
388                let contents;
389                syn::bracketed!(contents in input);
390                let list = Punctuated::<_, Token![,]>::parse_terminated(&contents)?;
391                Ok(Opt::Path(list.span(), list.into_iter().collect()))
392            } else {
393                let path: LitStr = input.parse()?;
394                Ok(Opt::Path(path.span(), vec![path]))
395            }
396        } else if l.peek(kw::inline) {
397            input.parse::<kw::inline>()?;
398            input.parse::<Token![:]>()?;
399            Ok(Opt::Inline(input.parse()?))
400        } else if l.peek(kw::world) {
401            input.parse::<kw::world>()?;
402            input.parse::<Token![:]>()?;
403            Ok(Opt::World(input.parse()?))
404        } else if l.peek(kw::std_feature) {
405            input.parse::<kw::std_feature>()?;
406            Ok(Opt::UseStdFeature)
407        } else if l.peek(kw::raw_strings) {
408            input.parse::<kw::raw_strings>()?;
409            Ok(Opt::RawStrings)
410        } else if l.peek(kw::ownership) {
411            input.parse::<kw::ownership>()?;
412            input.parse::<Token![:]>()?;
413            let ownership = input.parse::<syn::Ident>()?;
414            Ok(Opt::Ownership(match ownership.to_string().as_str() {
415                "Owning" => Ownership::Owning,
416                "Borrowing" => Ownership::Borrowing {
417                    duplicate_if_necessary: {
418                        let contents;
419                        braced!(contents in input);
420                        let field = contents.parse::<syn::Ident>()?;
421                        match field.to_string().as_str() {
422                            "duplicate_if_necessary" => {
423                                contents.parse::<Token![:]>()?;
424                                contents.parse::<syn::LitBool>()?.value
425                            }
426                            name => {
427                                return Err(Error::new(
428                                    field.span(),
429                                    format!(
430                                        "unrecognized `Ownership::Borrowing` field: `{name}`; \
431                                         expected `duplicate_if_necessary`"
432                                    ),
433                                ));
434                            }
435                        }
436                    },
437                },
438                name => {
439                    return Err(Error::new(
440                        ownership.span(),
441                        format!(
442                            "unrecognized ownership: `{name}`; \
443                             expected `Owning` or `Borrowing`"
444                        ),
445                    ));
446                }
447            }))
448        } else if l.peek(kw::skip) {
449            input.parse::<kw::skip>()?;
450            input.parse::<Token![:]>()?;
451            let contents;
452            syn::bracketed!(contents in input);
453            let list = Punctuated::<_, Token![,]>::parse_terminated(&contents)?;
454            Ok(Opt::Skip(list.iter().cloned().collect()))
455        } else if l.peek(kw::runtime_path) {
456            input.parse::<kw::runtime_path>()?;
457            input.parse::<Token![:]>()?;
458            Ok(Opt::RuntimePath(input.parse()?))
459        } else if l.peek(kw::bitflags_path) {
460            input.parse::<kw::bitflags_path>()?;
461            input.parse::<Token![:]>()?;
462            Ok(Opt::BitflagsPath(input.parse()?))
463        } else if l.peek(kw::stubs) {
464            input.parse::<kw::stubs>()?;
465            Ok(Opt::Stubs)
466        } else if l.peek(kw::export_prefix) {
467            input.parse::<kw::export_prefix>()?;
468            input.parse::<Token![:]>()?;
469            Ok(Opt::ExportPrefix(input.parse()?))
470        } else if l.peek(kw::additional_derives) {
471            input.parse::<kw::additional_derives>()?;
472            input.parse::<Token![:]>()?;
473            let contents;
474            syn::bracketed!(contents in input);
475            let list = Punctuated::<_, Token![,]>::parse_terminated(&contents)?;
476            Ok(Opt::AdditionalDerives(list.iter().cloned().collect()))
477        } else if l.peek(kw::additional_derives_ignore) {
478            input.parse::<kw::additional_derives_ignore>()?;
479            input.parse::<Token![:]>()?;
480            let contents;
481            syn::bracketed!(contents in input);
482            let list = Punctuated::<_, Token![,]>::parse_terminated(&contents)?;
483            Ok(Opt::AdditionalDerivesIgnore(list.iter().cloned().collect()))
484        } else if l.peek(kw::with) {
485            input.parse::<kw::with>()?;
486            input.parse::<Token![:]>()?;
487            let contents;
488            let _lbrace = braced!(contents in input);
489            let fields: Punctuated<_, Token![,]> =
490                contents.parse_terminated(with_field_parse, Token![,])?;
491            Ok(Opt::With(HashMap::from_iter(fields.into_iter())))
492        } else if l.peek(kw::generate_all) {
493            input.parse::<kw::generate_all>()?;
494            Ok(Opt::GenerateAll)
495        } else if l.peek(kw::type_section_suffix) {
496            input.parse::<kw::type_section_suffix>()?;
497            input.parse::<Token![:]>()?;
498            Ok(Opt::TypeSectionSuffix(input.parse()?))
499        } else if l.peek(kw::disable_run_ctors_once_workaround) {
500            input.parse::<kw::disable_run_ctors_once_workaround>()?;
501            input.parse::<Token![:]>()?;
502            Ok(Opt::DisableRunCtorsOnceWorkaround(input.parse()?))
503        } else if l.peek(kw::default_bindings_module) {
504            input.parse::<kw::default_bindings_module>()?;
505            input.parse::<Token![:]>()?;
506            Ok(Opt::DefaultBindingsModule(input.parse()?))
507        } else if l.peek(kw::export_macro_name) {
508            input.parse::<kw::export_macro_name>()?;
509            input.parse::<Token![:]>()?;
510            Ok(Opt::ExportMacroName(input.parse()?))
511        } else if l.peek(kw::pub_export_macro) {
512            input.parse::<kw::pub_export_macro>()?;
513            input.parse::<Token![:]>()?;
514            Ok(Opt::PubExportMacro(input.parse()?))
515        } else if l.peek(kw::generate_unused_types) {
516            input.parse::<kw::generate_unused_types>()?;
517            input.parse::<Token![:]>()?;
518            Ok(Opt::GenerateUnusedTypes(input.parse()?))
519        } else if l.peek(kw::features) {
520            input.parse::<kw::features>()?;
521            input.parse::<Token![:]>()?;
522            let contents;
523            syn::bracketed!(contents in input);
524            let list = Punctuated::<_, Token![,]>::parse_terminated(&contents)?;
525            Ok(Opt::Features(list.into_iter().collect()))
526        } else if l.peek(kw::disable_custom_section_link_helpers) {
527            input.parse::<kw::disable_custom_section_link_helpers>()?;
528            input.parse::<Token![:]>()?;
529            Ok(Opt::DisableCustomSectionLinkHelpers(input.parse()?))
530        } else if l.peek(kw::debug) {
531            input.parse::<kw::debug>()?;
532            input.parse::<Token![:]>()?;
533            Ok(Opt::Debug(input.parse()?))
534        } else if l.peek(Token![async]) {
535            let span = input.parse::<Token![async]>()?.span;
536            input.parse::<Token![:]>()?;
537            if input.peek(syn::LitBool) {
538                let enabled = input.parse::<syn::LitBool>()?.value;
539                Ok(Opt::Async(AsyncFilterSet::all(enabled), span))
540            } else {
541                let mut set = AsyncFilterSet::default();
542                let contents;
543                syn::bracketed!(contents in input);
544                for val in contents.parse_terminated(|p| p.parse::<syn::LitStr>(), Token![,])? {
545                    set.push(&val.value());
546                }
547                Ok(Opt::Async(set, span))
548            }
549        } else {
550            Err(l.error())
551        }
552    }
553}
554
555fn with_field_parse(input: ParseStream<'_>) -> Result<(String, WithOption)> {
556    let interface = input.parse::<syn::LitStr>()?.value();
557    input.parse::<Token![:]>()?;
558    let start = input.span();
559    let path = input.parse::<syn::Path>()?;
560
561    // It's not possible for the segments of a path to be empty
562    let span = start
563        .join(path.segments.last().unwrap().ident.span())
564        .unwrap_or(start);
565
566    if path.is_ident("generate") {
567        return Ok((interface, WithOption::Generate));
568    }
569
570    let mut buf = String::new();
571    let append = |buf: &mut String, segment: syn::PathSegment| -> Result<()> {
572        if !segment.arguments.is_none() {
573            return Err(Error::new(
574                span,
575                "Module path must not contain angles or parens",
576            ));
577        }
578
579        buf.push_str(&segment.ident.to_string());
580
581        Ok(())
582    };
583
584    if path.leading_colon.is_some() {
585        buf.push_str("::");
586    }
587
588    let mut segments = path.segments.into_iter();
589
590    if let Some(segment) = segments.next() {
591        append(&mut buf, segment)?;
592    }
593
594    for segment in segments {
595        buf.push_str("::");
596        append(&mut buf, segment)?;
597    }
598
599    Ok((interface, WithOption::Path(buf)))
600}
601
602/// Format a valid Rust string
603fn fmt(input: &str) -> Result<String> {
604    let syntax_tree = syn::parse_file(&input)?;
605    Ok(prettyplease::unparse(&syntax_tree))
606}