wit_bindgen_wrpc_rust_macro/
lib.rs

1use proc_macro2::{Span, TokenStream};
2use quote::ToTokens;
3use std::collections::HashMap;
4use std::path::{Path, PathBuf};
5use std::sync::atomic::{AtomicUsize, Ordering::Relaxed};
6use syn::parse::{Error, Parse, ParseStream, Result};
7use syn::punctuated::Punctuated;
8use syn::spanned::Spanned;
9use syn::{braced, token, LitStr, Token};
10use wit_bindgen_core::wit_parser::{PackageId, Resolve, UnresolvedPackageGroup, WorldId};
11use wit_bindgen_wrpc_rust::{Opts, WithOption};
12
13#[proc_macro]
14pub fn generate(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
15    syn::parse_macro_input!(input as Config)
16        .expand()
17        .unwrap_or_else(Error::into_compile_error)
18        .into()
19}
20
21fn anyhow_to_syn(span: Span, err: anyhow::Error) -> Error {
22    let err = attach_with_context(err);
23    let mut msg = err.to_string();
24    for cause in err.chain().skip(1) {
25        msg.push_str(&format!("\n\nCaused by:\n  {cause}"));
26    }
27    Error::new(span, msg)
28}
29
30fn attach_with_context(err: anyhow::Error) -> anyhow::Error {
31    if let Some(e) = err.downcast_ref::<wit_bindgen_wrpc_rust::MissingWith>() {
32        let option = e.0.clone();
33        return err.context(format!(
34            "missing one of:\n\
35            * `generate_all` option\n\
36            * `with: {{ \"{option}\": path::to::bindings, }}`\n\
37            * `with: {{ \"{option}\": generate, }}`\
38            "
39        ));
40    }
41    err
42}
43
/// Fully parsed and resolved input of the `generate!` macro.
///
/// Built by the `Parse` implementation and consumed by `Config::expand`.
struct Config {
    /// Generator options collected from the macro input.
    opts: Opts,
    /// The resolved WIT packages produced by `parse_source`.
    resolve: Resolve,
    /// The world picked by `select_world`, for which bindings are generated.
    world: WorldId,
    /// Files read while parsing the WIT sources; `expand` emits an
    /// `include_bytes!` for each of them.
    files: Vec<PathBuf>,
}
50
/// Where the WIT package definition comes from.
enum Source {
    /// One or more paths handed to `parse_source`, which joins each of them
    /// onto `CARGO_MANIFEST_DIR` before loading it.
    Paths(Vec<PathBuf>),
    /// WIT text given inline in the macro invocation, optionally accompanied
    /// by paths that are loaded before the inline text is parsed.
    Inline(String, Option<Vec<PathBuf>>),
}
58
59impl Parse for Config {
60    fn parse(input: ParseStream<'_>) -> Result<Self> {
61        let call_site = Span::call_site();
62        let mut opts = Opts::default();
63        let mut world = None;
64        let mut source = None;
65        let mut features = Vec::new();
66
67        if input.peek(token::Brace) {
68            let content;
69            syn::braced!(content in input);
70            let fields = Punctuated::<Opt, Token![,]>::parse_terminated(&content)?;
71            for field in fields.into_pairs() {
72                match field.into_value() {
73                    Opt::Path(span, p) => {
74                        let paths = p.into_iter().map(|f| PathBuf::from(f.value())).collect();
75
76                        source = Some(match source {
77                            Some(Source::Paths(_) | Source::Inline(_, Some(_))) => {
78                                return Err(Error::new(span, "cannot specify second source"));
79                            }
80                            Some(Source::Inline(i, None)) => Source::Inline(i, Some(paths)),
81                            None => Source::Paths(paths),
82                        });
83                    }
84                    Opt::World(s) => {
85                        if world.is_some() {
86                            return Err(Error::new(s.span(), "cannot specify second world"));
87                        }
88                        world = Some(s.value());
89                    }
90                    Opt::Inline(s) => {
91                        source = Some(match source {
92                            Some(Source::Inline(_, _)) => {
93                                return Err(Error::new(s.span(), "cannot specify second source"));
94                            }
95                            Some(Source::Paths(p)) => Source::Inline(s.value(), Some(p)),
96                            None => Source::Inline(s.value(), None),
97                        });
98                    }
99                    Opt::Skip(list) => opts.skip.extend(list.iter().map(syn::LitStr::value)),
100                    Opt::BitflagsPath(path) => opts.bitflags_path = Some(path.value()),
101                    Opt::AdditionalDerives(paths) => {
102                        opts.additional_derive_attributes = paths
103                            .into_iter()
104                            .map(|p| p.into_token_stream().to_string())
105                            .collect();
106                    }
107                    Opt::With(with) => opts.with.extend(with),
108                    Opt::GenerateAll => {
109                        opts.generate_all = true;
110                    }
111                    Opt::GenerateUnusedTypes(enable) => {
112                        opts.generate_unused_types = enable.value();
113                    }
114                    Opt::Features(f) => {
115                        features.extend(f.into_iter().map(|f| f.value()));
116                    }
117                    Opt::AnyhowPath(path) => {
118                        opts.anyhow_path = Some(path.value());
119                    }
120                    Opt::BytesPath(path) => {
121                        opts.bytes_path = Some(path.value());
122                    }
123                    Opt::FuturesPath(path) => {
124                        opts.futures_path = Some(path.value());
125                    }
126                    Opt::TokioPath(path) => {
127                        opts.tokio_path = Some(path.value());
128                    }
129                    Opt::TokioUtilPath(path) => {
130                        opts.tokio_util_path = Some(path.value());
131                    }
132                    Opt::TracingPath(path) => {
133                        opts.tracing_path = Some(path.value());
134                    }
135                    Opt::WasmTokioPath(path) => {
136                        opts.wasm_tokio_path = Some(path.value());
137                    }
138                    Opt::WrpcTransportPath(path) => {
139                        opts.wrpc_transport_path = Some(path.value());
140                    }
141                }
142            }
143        } else {
144            world = input.parse::<Option<syn::LitStr>>()?.map(|s| s.value());
145            if input.parse::<Option<syn::token::In>>()?.is_some() {
146                source = Some(Source::Paths(vec![PathBuf::from(
147                    input.parse::<syn::LitStr>()?.value(),
148                )]));
149            }
150        }
151        let (resolve, pkgs, files) =
152            parse_source(&source, &features).map_err(|err| anyhow_to_syn(call_site, err))?;
153        let world = select_world(&resolve, &pkgs, world.as_deref())
154            .map_err(|e| anyhow_to_syn(call_site, e))?;
155        Ok(Config {
156            opts,
157            resolve,
158            world,
159            files,
160        })
161    }
162}
163
164fn select_world(
165    resolve: &Resolve,
166    pkgs: &[PackageId],
167    world: Option<&str>,
168) -> anyhow::Result<WorldId> {
169    if pkgs.len() == 1 {
170        resolve.select_world(pkgs[0], world)
171    } else {
172        assert!(!pkgs.is_empty());
173        if let Some(name) = world {
174            if !name.contains(':') {
175                anyhow::bail!(
176                    "with multiple packages a fully qualified \
177                     world name must be specified"
178                )
179            }
180
181            // This will ignore the package argument due to the fully
182            // qualified name being used.
183            resolve.select_world(pkgs[0], world)
184        } else {
185            let worlds = pkgs
186                .iter()
187                .filter_map(|p| resolve.select_world(*p, None).ok())
188                .collect::<Vec<_>>();
189            match &worlds[..] {
190                [] => anyhow::bail!("no packages have a world"),
191                [world] => Ok(*world),
192                _ => anyhow::bail!("multiple packages have a world, must specify which to use"),
193            }
194        }
195    }
196}
197
198/// Parse the source
199fn parse_source(
200    source: &Option<Source>,
201    features: &[String],
202) -> anyhow::Result<(Resolve, Vec<PackageId>, Vec<PathBuf>)> {
203    let mut resolve = Resolve::default();
204    resolve.features.extend(features.iter().cloned());
205    let mut files = Vec::new();
206    let mut pkgs = Vec::new();
207    let root = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap());
208    let mut parse = |paths: &[PathBuf]| -> anyhow::Result<()> {
209        for path in paths {
210            let p = root.join(path);
211            // Try to normalize the path to make the error message more understandable when
212            // the path is not correct. Fallback to the original path if normalization fails
213            // (probably return an error somewhere else).
214            let normalized_path = match std::fs::canonicalize(&p) {
215                Ok(p) => p,
216                Err(_) => p.clone(),
217            };
218            let (pkg, sources) = resolve.push_path(normalized_path)?;
219            pkgs.push(pkg);
220            files.extend(sources.paths().map(|p| p.to_owned()));
221        }
222        Ok(())
223    };
224    match source {
225        Some(Source::Inline(s, path)) => {
226            if let Some(p) = path {
227                parse(p)?;
228            }
229            pkgs.push(resolve.push_group(UnresolvedPackageGroup::parse("macro-input", s)?)?);
230        }
231        Some(Source::Paths(p)) => parse(p)?,
232        None => parse(&[root.join("wit")])?,
233    };
234
235    Ok((resolve, pkgs, files))
236}
237
238impl Config {
239    fn expand(self) -> Result<TokenStream> {
240        let mut files = Default::default();
241        let mut generator = self.opts.build();
242        generator
243            .generate(&self.resolve, self.world, &mut files)
244            .map_err(|e| anyhow_to_syn(Span::call_site(), e))?;
245        let (_, src) = files.iter().next().unwrap();
246        let mut src = std::str::from_utf8(src).unwrap().to_string();
247
248        // If a magical `WIT_BINDGEN_DEBUG` environment variable is set then
249        // place a formatted version of the expanded code into a file. This file
250        // will then show up in rustc error messages for any codegen issues and can
251        // be inspected manually.
252        if std::env::var("WIT_BINDGEN_DEBUG").is_ok() {
253            static INVOCATION: AtomicUsize = AtomicUsize::new(0);
254            let root = Path::new(env!("DEBUG_OUTPUT_DIR"));
255            let world_name = &self.resolve.worlds[self.world].name;
256            let n = INVOCATION.fetch_add(1, Relaxed);
257            let path = root.join(format!("{world_name}{n}.rs"));
258
259            // optimistically format the code but don't require success
260            let contents = match fmt(&src) {
261                Ok(formatted) => formatted,
262                Err(_) => src.clone(),
263            };
264            std::fs::write(&path, contents.as_bytes()).unwrap();
265
266            src = format!("include!({path:?});");
267        }
268        let mut contents = src.parse::<TokenStream>().unwrap();
269
270        // Include a dummy `include_bytes!` for any files we read so rustc knows that
271        // we depend on the contents of those files.
272        for file in &self.files {
273            contents.extend(
274                format!(
275                    "const _: &[u8] = include_bytes!(r#\"{}\"#);\n",
276                    file.display()
277                )
278                .parse::<TokenStream>()
279                .unwrap(),
280            );
281        }
282
283        Ok(contents)
284    }
285}
286
/// Custom keywords that name the options accepted inside the braces of the
/// macro input. Each one is recognized by the `Parse` implementation of `Opt`
/// unless noted otherwise.
mod kw {
    syn::custom_keyword!(skip);
    syn::custom_keyword!(world);
    syn::custom_keyword!(path);
    syn::custom_keyword!(inline);
    syn::custom_keyword!(bitflags_path);
    // NOTE(review): `exports` is not matched by the `Opt` parser in this
    // file; it may be a leftover.
    syn::custom_keyword!(exports);
    syn::custom_keyword!(additional_derives);
    syn::custom_keyword!(with);
    syn::custom_keyword!(generate_all);
    syn::custom_keyword!(generate_unused_types);
    syn::custom_keyword!(features);
    // Options carrying a string value that is stored in the `Opts` field of
    // the same name.
    syn::custom_keyword!(anyhow_path);
    syn::custom_keyword!(bytes_path);
    syn::custom_keyword!(futures_path);
    syn::custom_keyword!(tokio_path);
    syn::custom_keyword!(tokio_util_path);
    syn::custom_keyword!(tracing_path);
    syn::custom_keyword!(wasm_tokio_path);
    syn::custom_keyword!(wrpc_transport_path);
}
308
/// A single parsed option from the braced form of the macro input.
enum Opt {
    /// `world: "<name>"`
    World(syn::LitStr),
    /// `path: "<path>"` or `path: ["<path>", ...]`, together with the span
    /// used when reporting a duplicate source.
    Path(Span, Vec<syn::LitStr>),
    /// `inline: "<wit text>"`
    Inline(syn::LitStr),
    /// `skip: ["<name>", ...]`
    Skip(Vec<syn::LitStr>),
    /// `bitflags_path: "<path>"`
    BitflagsPath(syn::LitStr),
    // Parse as paths so we can take the concrete types/macro names rather than raw strings
    /// `additional_derives: [<path>, ...]`
    AdditionalDerives(Vec<syn::Path>),
    /// `with: { "<key>": <path> | generate, ... }`
    With(HashMap<String, WithOption>),
    /// `generate_all` (takes no value)
    GenerateAll,
    /// `generate_unused_types: <bool>`
    GenerateUnusedTypes(syn::LitBool),
    /// `features: ["<name>", ...]`
    Features(Vec<syn::LitStr>),
    /// `anyhow_path: "<path>"`
    AnyhowPath(syn::LitStr),
    /// `bytes_path: "<path>"`
    BytesPath(syn::LitStr),
    /// `futures_path: "<path>"`
    FuturesPath(syn::LitStr),
    /// `tokio_path: "<path>"`
    TokioPath(syn::LitStr),
    /// `tokio_util_path: "<path>"`
    TokioUtilPath(syn::LitStr),
    /// `tracing_path: "<path>"`
    TracingPath(syn::LitStr),
    /// `wasm_tokio_path: "<path>"`
    WasmTokioPath(syn::LitStr),
    /// `wrpc_transport_path: "<path>"`
    WrpcTransportPath(syn::LitStr),
}
330
331impl Parse for Opt {
332    fn parse(input: ParseStream<'_>) -> Result<Self> {
333        let l = input.lookahead1();
334        if l.peek(kw::path) {
335            input.parse::<kw::path>()?;
336            input.parse::<Token![:]>()?;
337            // the `path` supports two forms:
338            // * path: "xxx"
339            // * path: ["aaa", "bbb"]
340            if input.peek(token::Bracket) {
341                let contents;
342                syn::bracketed!(contents in input);
343                let list = Punctuated::<_, Token![,]>::parse_terminated(&contents)?;
344                Ok(Opt::Path(list.span(), list.into_iter().collect()))
345            } else {
346                let path: LitStr = input.parse()?;
347                Ok(Opt::Path(path.span(), vec![path]))
348            }
349        } else if l.peek(kw::inline) {
350            input.parse::<kw::inline>()?;
351            input.parse::<Token![:]>()?;
352            Ok(Opt::Inline(input.parse()?))
353        } else if l.peek(kw::world) {
354            input.parse::<kw::world>()?;
355            input.parse::<Token![:]>()?;
356            Ok(Opt::World(input.parse()?))
357        } else if l.peek(kw::skip) {
358            input.parse::<kw::skip>()?;
359            input.parse::<Token![:]>()?;
360            let contents;
361            syn::bracketed!(contents in input);
362            let list = Punctuated::<_, Token![,]>::parse_terminated(&contents)?;
363            Ok(Opt::Skip(list.iter().cloned().collect()))
364        } else if l.peek(kw::bitflags_path) {
365            input.parse::<kw::bitflags_path>()?;
366            input.parse::<Token![:]>()?;
367            Ok(Opt::BitflagsPath(input.parse()?))
368        } else if l.peek(kw::additional_derives) {
369            input.parse::<kw::additional_derives>()?;
370            input.parse::<Token![:]>()?;
371            let contents;
372            syn::bracketed!(contents in input);
373            let list = Punctuated::<_, Token![,]>::parse_terminated(&contents)?;
374            Ok(Opt::AdditionalDerives(list.iter().cloned().collect()))
375        } else if l.peek(kw::with) {
376            input.parse::<kw::with>()?;
377            input.parse::<Token![:]>()?;
378            let contents;
379            let _lbrace = braced!(contents in input);
380            let fields: Punctuated<_, Token![,]> =
381                contents.parse_terminated(with_field_parse, Token![,])?;
382            Ok(Opt::With(HashMap::from_iter(fields)))
383        } else if l.peek(kw::generate_all) {
384            input.parse::<kw::generate_all>()?;
385            Ok(Opt::GenerateAll)
386        } else if l.peek(kw::generate_unused_types) {
387            input.parse::<kw::generate_unused_types>()?;
388            input.parse::<Token![:]>()?;
389            Ok(Opt::GenerateUnusedTypes(input.parse()?))
390        } else if l.peek(kw::features) {
391            input.parse::<kw::features>()?;
392            input.parse::<Token![:]>()?;
393            let contents;
394            syn::bracketed!(contents in input);
395            let list = Punctuated::<_, Token![,]>::parse_terminated(&contents)?;
396            Ok(Opt::Features(list.into_iter().collect()))
397        } else if l.peek(kw::anyhow_path) {
398            input.parse::<kw::anyhow_path>()?;
399            input.parse::<Token![:]>()?;
400            Ok(Opt::AnyhowPath(input.parse()?))
401        } else if l.peek(kw::bytes_path) {
402            input.parse::<kw::bytes_path>()?;
403            input.parse::<Token![:]>()?;
404            Ok(Opt::BytesPath(input.parse()?))
405        } else if l.peek(kw::futures_path) {
406            input.parse::<kw::futures_path>()?;
407            input.parse::<Token![:]>()?;
408            Ok(Opt::FuturesPath(input.parse()?))
409        } else if l.peek(kw::tokio_path) {
410            input.parse::<kw::tokio_path>()?;
411            input.parse::<Token![:]>()?;
412            Ok(Opt::TokioPath(input.parse()?))
413        } else if l.peek(kw::tokio_util_path) {
414            input.parse::<kw::tokio_util_path>()?;
415            input.parse::<Token![:]>()?;
416            Ok(Opt::TokioUtilPath(input.parse()?))
417        } else if l.peek(kw::tracing_path) {
418            input.parse::<kw::tracing_path>()?;
419            input.parse::<Token![:]>()?;
420            Ok(Opt::TracingPath(input.parse()?))
421        } else if l.peek(kw::wasm_tokio_path) {
422            input.parse::<kw::wasm_tokio_path>()?;
423            input.parse::<Token![:]>()?;
424            Ok(Opt::WasmTokioPath(input.parse()?))
425        } else if l.peek(kw::wrpc_transport_path) {
426            input.parse::<kw::wrpc_transport_path>()?;
427            input.parse::<Token![:]>()?;
428            Ok(Opt::WrpcTransportPath(input.parse()?))
429        } else {
430            Err(l.error())
431        }
432    }
433}
434
435fn with_field_parse(input: ParseStream<'_>) -> Result<(String, WithOption)> {
436    let interface = input.parse::<syn::LitStr>()?.value();
437    input.parse::<Token![:]>()?;
438    let start = input.span();
439    let path = input.parse::<syn::Path>()?;
440
441    // It's not possible for the segments of a path to be empty
442    let span = start
443        .join(path.segments.last().unwrap().ident.span())
444        .unwrap_or(start);
445
446    if path.is_ident("generate") {
447        return Ok((interface, WithOption::Generate));
448    }
449
450    let mut buf = String::new();
451    let append = |buf: &mut String, segment: syn::PathSegment| -> Result<()> {
452        if !segment.arguments.is_none() {
453            return Err(Error::new(
454                span,
455                "Module path must not contain angles or parens",
456            ));
457        }
458
459        buf.push_str(&segment.ident.to_string());
460
461        Ok(())
462    };
463
464    if path.leading_colon.is_some() {
465        buf.push_str("::");
466    }
467
468    let mut segments = path.segments.into_iter();
469
470    if let Some(segment) = segments.next() {
471        append(&mut buf, segment)?;
472    }
473
474    for segment in segments {
475        buf.push_str("::");
476        append(&mut buf, segment)?;
477    }
478
479    Ok((interface, WithOption::Path(buf)))
480}
481
482/// Format a valid Rust string
483fn fmt(input: &str) -> Result<String> {
484    let syntax_tree = syn::parse_file(input)?;
485    Ok(prettyplease::unparse(&syntax_tree))
486}