Skip to main content

wit_bindgen_rust_macro/
lib.rs

1use proc_macro2::{Span, TokenStream};
2use quote::ToTokens;
3use std::collections::HashMap;
4use std::path::{Path, PathBuf};
5use std::sync::atomic::{AtomicUsize, Ordering::Relaxed};
6use syn::parse::{Error, Parse, ParseStream, Result};
7use syn::punctuated::Punctuated;
8use syn::{Token, braced, token};
9use wit_bindgen_core::AsyncFilterSet;
10use wit_bindgen_core::WorldGenerator;
11use wit_bindgen_core::wit_parser::{PackageId, Resolve, UnresolvedPackageGroup, WorldId};
12use wit_bindgen_rust::{Opts, Ownership, WithOption};
13
14#[proc_macro]
15pub fn generate(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
16    syn::parse_macro_input!(input as Config)
17        .expand()
18        .unwrap_or_else(Error::into_compile_error)
19        .into()
20}
21
22fn anyhow_to_syn(span: Span, err: anyhow::Error) -> Error {
23    let err = attach_with_context(err);
24    let mut msg = err.to_string();
25    for cause in err.chain().skip(1) {
26        msg.push_str(&format!("\n\nCaused by:\n  {cause}"));
27    }
28    Error::new(span, msg)
29}
30
31fn attach_with_context(err: anyhow::Error) -> anyhow::Error {
32    if let Some(e) = err.downcast_ref::<wit_bindgen_rust::MissingWith>() {
33        let option = e.0.clone();
34        return err.context(format!(
35            "missing one of:\n\
36            * `generate_all` option\n\
37            * `with: {{ \"{option}\": path::to::bindings, }}`\n\
38            * `with: {{ \"{option}\": generate, }}`\
39            "
40        ));
41    }
42    err
43}
44
45struct Config {
46    opts: Opts,
47    resolve: Resolve,
48    world: WorldId,
49    files: Vec<PathBuf>,
50    debug: bool,
51}
52
/// Where the WIT package definition for a `generate!` invocation comes from.
enum Source {
    /// One or more WIT paths (from `path` or the `"world" in "path"`
    /// shorthand), resolved against `CARGO_MANIFEST_DIR`.
    Paths(Vec<PathBuf>),
    /// WIT text given through `inline`, plus optional `path` entries that are
    /// loaded before it (e.g. for its dependencies).
    Inline(String, Option<Vec<PathBuf>>),
}
60
61impl Parse for Config {
62    fn parse(input: ParseStream<'_>) -> Result<Self> {
63        let call_site = Span::call_site();
64        let mut opts = Opts::default();
65        let mut world = None;
66        let mut source = None;
67        let mut features = Vec::new();
68        let mut async_configured = false;
69        let mut debug = false;
70
71        if input.peek(token::Brace) {
72            let content;
73            syn::braced!(content in input);
74            let fields = Punctuated::<Opt, Token![,]>::parse_terminated(&content)?;
75            for field in fields.into_pairs() {
76                match field.into_value() {
77                    Opt::Path(span, p) => {
78                        let paths = p
79                            .into_iter()
80                            .map(|f| f.evaluate_string())
81                            .collect::<Result<Vec<_>>>()?
82                            .into_iter()
83                            .map(PathBuf::from)
84                            .collect();
85
86                        source = Some(match source {
87                            Some(Source::Paths(_)) | Some(Source::Inline(_, Some(_))) => {
88                                return Err(Error::new(span, "cannot specify second source"));
89                            }
90                            Some(Source::Inline(i, None)) => Source::Inline(i, Some(paths)),
91                            None => Source::Paths(paths),
92                        })
93                    }
94                    Opt::World(s) => {
95                        if world.is_some() {
96                            return Err(Error::new(s.span(), "cannot specify second world"));
97                        }
98                        world = Some(s.value());
99                    }
100                    Opt::Inline(s) => {
101                        source = Some(match source {
102                            Some(Source::Inline(_, _)) => {
103                                return Err(Error::new(s.span(), "cannot specify second source"));
104                            }
105                            Some(Source::Paths(p)) => Source::Inline(s.value(), Some(p)),
106                            None => Source::Inline(s.value(), None),
107                        })
108                    }
109                    Opt::UseStdFeature => opts.std_feature = true,
110                    Opt::RawStrings => opts.raw_strings = true,
111                    Opt::Ownership(ownership) => opts.ownership = ownership,
112                    Opt::Skip(list) => opts.skip.extend(list.iter().map(|i| i.value())),
113                    Opt::RuntimePath(path) => opts.runtime_path = Some(path.value()),
114                    Opt::BitflagsPath(path) => opts.bitflags_path = Some(path.value()),
115                    Opt::Stubs => {
116                        opts.stubs = true;
117                    }
118                    Opt::ExportPrefix(prefix) => opts.export_prefix = Some(prefix.value()),
119                    Opt::AdditionalDerives(paths) => {
120                        opts.additional_derive_attributes = paths
121                            .into_iter()
122                            .map(|p| p.into_token_stream().to_string())
123                            .collect()
124                    }
125                    Opt::AdditionalDerivesIgnore(list) => {
126                        opts.additional_derive_ignore =
127                            list.into_iter().map(|i| i.value()).collect()
128                    }
129                    Opt::With(with) => opts.with.extend(with),
130                    Opt::GenerateAll => {
131                        opts.generate_all = true;
132                    }
133                    Opt::TypeSectionSuffix(suffix) => {
134                        opts.type_section_suffix = Some(suffix.value());
135                    }
136                    Opt::DisableRunCtorsOnceWorkaround(enable) => {
137                        opts.disable_run_ctors_once_workaround = enable.value();
138                    }
139                    Opt::DefaultBindingsModule(enable) => {
140                        opts.default_bindings_module = Some(enable.value());
141                    }
142                    Opt::ExportMacroName(name) => {
143                        opts.export_macro_name = Some(name.value());
144                    }
145                    Opt::PubExportMacro(enable) => {
146                        opts.pub_export_macro = enable.value();
147                    }
148                    Opt::GenerateUnusedTypes(enable) => {
149                        opts.generate_unused_types = enable.value();
150                    }
151                    Opt::Features(f) => {
152                        features.extend(f.into_iter().map(|f| f.value()));
153                    }
154                    Opt::DisableCustomSectionLinkHelpers(disable) => {
155                        opts.disable_custom_section_link_helpers = disable.value();
156                    }
157                    Opt::Debug(enable) => {
158                        debug = enable.value();
159                    }
160                    Opt::Async(val, span) => {
161                        if async_configured {
162                            return Err(Error::new(span, "cannot specify second async config"));
163                        }
164                        async_configured = true;
165                        if val.any_enabled() && !cfg!(feature = "async") {
166                            return Err(Error::new(
167                                span,
168                                "must enable `async` feature to enable async imports and/or exports",
169                            ));
170                        }
171                        opts.async_ = val;
172                    }
173                }
174            }
175        } else {
176            world = input.parse::<Option<syn::LitStr>>()?.map(|s| s.value());
177            if input.parse::<Option<syn::token::In>>()?.is_some() {
178                source = Some(Source::Paths(vec![PathBuf::from(
179                    input.parse::<syn::LitStr>()?.value(),
180                )]));
181            }
182        }
183        let (resolve, main_packages, files) =
184            parse_source(&source, &features).map_err(|err| anyhow_to_syn(call_site, err))?;
185        let world = resolve
186            .select_world(&main_packages, world.as_deref())
187            .map_err(|e| anyhow_to_syn(call_site, e))?;
188        Ok(Config {
189            opts,
190            resolve,
191            world,
192            files,
193            debug,
194        })
195    }
196}
197
198/// Parse the source
199fn parse_source(
200    source: &Option<Source>,
201    features: &[String],
202) -> anyhow::Result<(Resolve, Vec<PackageId>, Vec<PathBuf>)> {
203    let mut resolve = Resolve::default();
204    resolve.features.extend(features.iter().cloned());
205    let mut files = Vec::new();
206    let mut pkgs = Vec::new();
207    let root = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap());
208    let mut parse = |paths: &[PathBuf]| -> anyhow::Result<()> {
209        for path in paths {
210            let p = root.join(path);
211            // Try to normalize the path to make the error message more understandable when
212            // the path is not correct. Fallback to the original path if normalization fails
213            // (probably return an error somewhere else).
214            let normalized_path = match std::fs::canonicalize(&p) {
215                Ok(p) => p,
216                Err(_) => p.to_path_buf(),
217            };
218            let (pkg, sources) = resolve.push_path(normalized_path)?;
219            pkgs.push(pkg);
220            files.extend(sources.paths().map(|p| p.to_owned()));
221        }
222        Ok(())
223    };
224    let default = root.join("wit");
225    match source {
226        Some(Source::Inline(s, path)) => {
227            match path {
228                Some(p) => parse(p)?,
229                // If no `path` is explicitly specified still parse the default
230                // `wit` directory if it exists. Don't require its existence,
231                // however, as `inline` can be used in lieu of a folder. Test
232                // whether it exists and only if there is it parsed.
233                None => {
234                    if default.exists() {
235                        parse(&[default])?;
236                    }
237                }
238            }
239            pkgs.truncate(0);
240            pkgs.push(resolve.push_group(UnresolvedPackageGroup::parse("macro-input", s)?)?);
241        }
242        Some(Source::Paths(p)) => parse(p)?,
243        None => parse(&[default])?,
244    };
245
246    Ok((resolve, pkgs, files))
247}
248
249impl Config {
250    fn expand(mut self) -> Result<TokenStream> {
251        let mut files = Default::default();
252        let mut generator = self.opts.build();
253        generator
254            .generate(&mut self.resolve, self.world, &mut files)
255            .map_err(|e| anyhow_to_syn(Span::call_site(), e))?;
256        let (_, src) = files.iter().next().unwrap();
257        let mut src = std::str::from_utf8(src).unwrap().to_string();
258
259        // If a magical `WIT_BINDGEN_DEBUG` environment variable is set then
260        // place a formatted version of the expanded code into a file. This file
261        // will then show up in rustc error messages for any codegen issues and can
262        // be inspected manually.
263        if std::env::var("WIT_BINDGEN_DEBUG").is_ok() || self.debug {
264            static INVOCATION: AtomicUsize = AtomicUsize::new(0);
265            let root = Path::new(env!("DEBUG_OUTPUT_DIR"));
266            let world_name = &self.resolve.worlds[self.world].name;
267            let n = INVOCATION.fetch_add(1, Relaxed);
268            let path = root.join(format!("{world_name}{n}.rs"));
269
270            // optimistically format the code but don't require success
271            let contents = match fmt(&src) {
272                Ok(formatted) => formatted,
273                Err(_) => src.clone(),
274            };
275            std::fs::write(&path, contents.as_bytes()).unwrap();
276
277            src = format!("include!({path:?});");
278        }
279        let mut contents = src.parse::<TokenStream>().unwrap();
280
281        // Include a dummy `include_bytes!` for any files we read so rustc knows that
282        // we depend on the contents of those files.
283        for file in self.files.iter() {
284            contents.extend(
285                format!(
286                    "const _: &[u8] = include_bytes!(r#\"{}\"#);\n",
287                    file.display()
288                )
289                .parse::<TokenStream>()
290                .unwrap(),
291            );
292        }
293
294        Ok(contents)
295    }
296}
297
298mod kw {
299    syn::custom_keyword!(std_feature);
300    syn::custom_keyword!(raw_strings);
301    syn::custom_keyword!(skip);
302    syn::custom_keyword!(world);
303    syn::custom_keyword!(path);
304    syn::custom_keyword!(inline);
305    syn::custom_keyword!(ownership);
306    syn::custom_keyword!(runtime_path);
307    syn::custom_keyword!(bitflags_path);
308    syn::custom_keyword!(exports);
309    syn::custom_keyword!(stubs);
310    syn::custom_keyword!(export_prefix);
311    syn::custom_keyword!(additional_derives);
312    syn::custom_keyword!(additional_derives_ignore);
313    syn::custom_keyword!(with);
314    syn::custom_keyword!(generate_all);
315    syn::custom_keyword!(type_section_suffix);
316    syn::custom_keyword!(disable_run_ctors_once_workaround);
317    syn::custom_keyword!(default_bindings_module);
318    syn::custom_keyword!(export_macro_name);
319    syn::custom_keyword!(pub_export_macro);
320    syn::custom_keyword!(generate_unused_types);
321    syn::custom_keyword!(features);
322    syn::custom_keyword!(disable_custom_section_link_helpers);
323    syn::custom_keyword!(imports);
324    syn::custom_keyword!(debug);
325}
326
327#[derive(Clone)]
328enum ExportKey {
329    World,
330    Name(syn::LitStr),
331}
332
333impl Parse for ExportKey {
334    fn parse(input: ParseStream<'_>) -> Result<Self> {
335        let l = input.lookahead1();
336        Ok(if l.peek(kw::world) {
337            input.parse::<kw::world>()?;
338            Self::World
339        } else {
340            Self::Name(input.parse()?)
341        })
342    }
343}
344
345impl From<ExportKey> for wit_bindgen_rust::ExportKey {
346    fn from(key: ExportKey) -> Self {
347        match key {
348            ExportKey::World => Self::World,
349            ExportKey::Name(s) => Self::Name(s.value()),
350        }
351    }
352}
353
354#[cfg(feature = "macro-string")]
355type PathType = macro_string::MacroString;
356#[cfg(not(feature = "macro-string"))]
357type PathType = syn::LitStr;
358
359trait EvaluateString {
360    fn evaluate_string(&self) -> Result<String>;
361}
362
363#[cfg(feature = "macro-string")]
364impl EvaluateString for macro_string::MacroString {
365    fn evaluate_string(&self) -> Result<String> {
366        self.eval()
367    }
368}
369
370#[cfg(not(feature = "macro-string"))]
371impl EvaluateString for syn::LitStr {
372    fn evaluate_string(&self) -> Result<String> {
373        Ok(self.value())
374    }
375}
376
377enum Opt {
378    World(syn::LitStr),
379    Path(Span, Vec<PathType>),
380    Inline(syn::LitStr),
381    UseStdFeature,
382    RawStrings,
383    Skip(Vec<syn::LitStr>),
384    Ownership(Ownership),
385    RuntimePath(syn::LitStr),
386    BitflagsPath(syn::LitStr),
387    Stubs,
388    ExportPrefix(syn::LitStr),
389    // Parse as paths so we can take the concrete types/macro names rather than raw strings
390    AdditionalDerives(Vec<syn::Path>),
391    AdditionalDerivesIgnore(Vec<syn::LitStr>),
392    With(HashMap<String, WithOption>),
393    GenerateAll,
394    TypeSectionSuffix(syn::LitStr),
395    DisableRunCtorsOnceWorkaround(syn::LitBool),
396    DefaultBindingsModule(syn::LitStr),
397    ExportMacroName(syn::LitStr),
398    PubExportMacro(syn::LitBool),
399    GenerateUnusedTypes(syn::LitBool),
400    Features(Vec<syn::LitStr>),
401    DisableCustomSectionLinkHelpers(syn::LitBool),
402    Async(AsyncFilterSet, Span),
403    Debug(syn::LitBool),
404}
405
406impl Parse for Opt {
407    fn parse(input: ParseStream<'_>) -> Result<Self> {
408        let l = input.lookahead1();
409        if l.peek(kw::path) {
410            input.parse::<kw::path>()?;
411            input.parse::<Token![:]>()?;
412            // the `path` supports two forms:
413            // * path: "xxx"
414            // * path: ["aaa", "bbb"]
415            if input.peek(token::Bracket) {
416                let contents;
417                syn::bracketed!(contents in input);
418                let span = input.span();
419                let list = Punctuated::<PathType, Token![,]>::parse_terminated(&contents)?;
420                Ok(Opt::Path(span, list.into_iter().collect()))
421            } else {
422                let span = input.span();
423                let path: PathType = input.parse()?;
424                Ok(Opt::Path(span, vec![path]))
425            }
426        } else if l.peek(kw::inline) {
427            input.parse::<kw::inline>()?;
428            input.parse::<Token![:]>()?;
429            Ok(Opt::Inline(input.parse()?))
430        } else if l.peek(kw::world) {
431            input.parse::<kw::world>()?;
432            input.parse::<Token![:]>()?;
433            Ok(Opt::World(input.parse()?))
434        } else if l.peek(kw::std_feature) {
435            input.parse::<kw::std_feature>()?;
436            Ok(Opt::UseStdFeature)
437        } else if l.peek(kw::raw_strings) {
438            input.parse::<kw::raw_strings>()?;
439            Ok(Opt::RawStrings)
440        } else if l.peek(kw::ownership) {
441            input.parse::<kw::ownership>()?;
442            input.parse::<Token![:]>()?;
443            let ownership = input.parse::<syn::Ident>()?;
444            Ok(Opt::Ownership(match ownership.to_string().as_str() {
445                "Owning" => Ownership::Owning,
446                "Borrowing" => Ownership::Borrowing {
447                    duplicate_if_necessary: {
448                        let contents;
449                        braced!(contents in input);
450                        let field = contents.parse::<syn::Ident>()?;
451                        match field.to_string().as_str() {
452                            "duplicate_if_necessary" => {
453                                contents.parse::<Token![:]>()?;
454                                contents.parse::<syn::LitBool>()?.value
455                            }
456                            name => {
457                                return Err(Error::new(
458                                    field.span(),
459                                    format!(
460                                        "unrecognized `Ownership::Borrowing` field: `{name}`; \
461                                         expected `duplicate_if_necessary`"
462                                    ),
463                                ));
464                            }
465                        }
466                    },
467                },
468                name => {
469                    return Err(Error::new(
470                        ownership.span(),
471                        format!(
472                            "unrecognized ownership: `{name}`; \
473                             expected `Owning` or `Borrowing`"
474                        ),
475                    ));
476                }
477            }))
478        } else if l.peek(kw::skip) {
479            input.parse::<kw::skip>()?;
480            input.parse::<Token![:]>()?;
481            let contents;
482            syn::bracketed!(contents in input);
483            let list = Punctuated::<_, Token![,]>::parse_terminated(&contents)?;
484            Ok(Opt::Skip(list.iter().cloned().collect()))
485        } else if l.peek(kw::runtime_path) {
486            input.parse::<kw::runtime_path>()?;
487            input.parse::<Token![:]>()?;
488            Ok(Opt::RuntimePath(input.parse()?))
489        } else if l.peek(kw::bitflags_path) {
490            input.parse::<kw::bitflags_path>()?;
491            input.parse::<Token![:]>()?;
492            Ok(Opt::BitflagsPath(input.parse()?))
493        } else if l.peek(kw::stubs) {
494            input.parse::<kw::stubs>()?;
495            Ok(Opt::Stubs)
496        } else if l.peek(kw::export_prefix) {
497            input.parse::<kw::export_prefix>()?;
498            input.parse::<Token![:]>()?;
499            Ok(Opt::ExportPrefix(input.parse()?))
500        } else if l.peek(kw::additional_derives) {
501            input.parse::<kw::additional_derives>()?;
502            input.parse::<Token![:]>()?;
503            let contents;
504            syn::bracketed!(contents in input);
505            let list = Punctuated::<_, Token![,]>::parse_terminated(&contents)?;
506            Ok(Opt::AdditionalDerives(list.iter().cloned().collect()))
507        } else if l.peek(kw::additional_derives_ignore) {
508            input.parse::<kw::additional_derives_ignore>()?;
509            input.parse::<Token![:]>()?;
510            let contents;
511            syn::bracketed!(contents in input);
512            let list = Punctuated::<_, Token![,]>::parse_terminated(&contents)?;
513            Ok(Opt::AdditionalDerivesIgnore(list.iter().cloned().collect()))
514        } else if l.peek(kw::with) {
515            input.parse::<kw::with>()?;
516            input.parse::<Token![:]>()?;
517            let contents;
518            let _lbrace = braced!(contents in input);
519            let fields: Punctuated<_, Token![,]> =
520                contents.parse_terminated(with_field_parse, Token![,])?;
521            Ok(Opt::With(HashMap::from_iter(fields)))
522        } else if l.peek(kw::generate_all) {
523            input.parse::<kw::generate_all>()?;
524            Ok(Opt::GenerateAll)
525        } else if l.peek(kw::type_section_suffix) {
526            input.parse::<kw::type_section_suffix>()?;
527            input.parse::<Token![:]>()?;
528            Ok(Opt::TypeSectionSuffix(input.parse()?))
529        } else if l.peek(kw::disable_run_ctors_once_workaround) {
530            input.parse::<kw::disable_run_ctors_once_workaround>()?;
531            input.parse::<Token![:]>()?;
532            Ok(Opt::DisableRunCtorsOnceWorkaround(input.parse()?))
533        } else if l.peek(kw::default_bindings_module) {
534            input.parse::<kw::default_bindings_module>()?;
535            input.parse::<Token![:]>()?;
536            Ok(Opt::DefaultBindingsModule(input.parse()?))
537        } else if l.peek(kw::export_macro_name) {
538            input.parse::<kw::export_macro_name>()?;
539            input.parse::<Token![:]>()?;
540            Ok(Opt::ExportMacroName(input.parse()?))
541        } else if l.peek(kw::pub_export_macro) {
542            input.parse::<kw::pub_export_macro>()?;
543            input.parse::<Token![:]>()?;
544            Ok(Opt::PubExportMacro(input.parse()?))
545        } else if l.peek(kw::generate_unused_types) {
546            input.parse::<kw::generate_unused_types>()?;
547            input.parse::<Token![:]>()?;
548            Ok(Opt::GenerateUnusedTypes(input.parse()?))
549        } else if l.peek(kw::features) {
550            input.parse::<kw::features>()?;
551            input.parse::<Token![:]>()?;
552            let contents;
553            syn::bracketed!(contents in input);
554            let list = Punctuated::<_, Token![,]>::parse_terminated(&contents)?;
555            Ok(Opt::Features(list.into_iter().collect()))
556        } else if l.peek(kw::disable_custom_section_link_helpers) {
557            input.parse::<kw::disable_custom_section_link_helpers>()?;
558            input.parse::<Token![:]>()?;
559            Ok(Opt::DisableCustomSectionLinkHelpers(input.parse()?))
560        } else if l.peek(kw::debug) {
561            input.parse::<kw::debug>()?;
562            input.parse::<Token![:]>()?;
563            Ok(Opt::Debug(input.parse()?))
564        } else if l.peek(Token![async]) {
565            let span = input.parse::<Token![async]>()?.span;
566            input.parse::<Token![:]>()?;
567            if input.peek(syn::LitBool) {
568                let enabled = input.parse::<syn::LitBool>()?.value;
569                Ok(Opt::Async(AsyncFilterSet::all(enabled), span))
570            } else {
571                let mut set = AsyncFilterSet::default();
572                let contents;
573                syn::bracketed!(contents in input);
574                for val in contents.parse_terminated(|p| p.parse::<syn::LitStr>(), Token![,])? {
575                    set.push(&val.value());
576                }
577                Ok(Opt::Async(set, span))
578            }
579        } else {
580            Err(l.error())
581        }
582    }
583}
584
585fn with_field_parse(input: ParseStream<'_>) -> Result<(String, WithOption)> {
586    let interface = input.parse::<syn::LitStr>()?.value();
587    input.parse::<Token![:]>()?;
588    let start = input.span();
589    let path = input.parse::<syn::Path>()?;
590
591    // It's not possible for the segments of a path to be empty
592    let span = start
593        .join(path.segments.last().unwrap().ident.span())
594        .unwrap_or(start);
595
596    if path.is_ident("generate") {
597        return Ok((interface, WithOption::Generate));
598    }
599
600    let mut buf = String::new();
601    let append = |buf: &mut String, segment: syn::PathSegment| -> Result<()> {
602        if !segment.arguments.is_none() {
603            return Err(Error::new(
604                span,
605                "Module path must not contain angles or parens",
606            ));
607        }
608
609        buf.push_str(&segment.ident.to_string());
610
611        Ok(())
612    };
613
614    if path.leading_colon.is_some() {
615        buf.push_str("::");
616    }
617
618    let mut segments = path.segments.into_iter();
619
620    if let Some(segment) = segments.next() {
621        append(&mut buf, segment)?;
622    }
623
624    for segment in segments {
625        buf.push_str("::");
626        append(&mut buf, segment)?;
627    }
628
629    Ok((interface, WithOption::Path(buf)))
630}
631
632/// Format a valid Rust string
633fn fmt(input: &str) -> Result<String> {
634    let syntax_tree = syn::parse_file(input)?;
635    Ok(prettyplease::unparse(&syntax_tree))
636}