wasm_bridge_macros/
lib.rs

1use std::{ops::Deref, str::FromStr};
2
3use original::{Style, VariantStyle};
4use quote::ToTokens;
5use regex::{Captures, Regex};
6use syn::{Attribute, ImplItem, ItemImpl};
7
8mod direct_impl;
9mod original;
10
/// Backend a macro expansion is generated for; passed to `original::flags`.
///
/// `Js` is used by the `*_js` macros. `Sys` is used by the `*_sys` macros
/// (presumably the native, non-web backend — confirm).
#[derive(Clone, Debug, Eq, PartialEq)]
enum CompilationTarget {
    Sys,
    Js,
}
16
17#[proc_macro_derive(Lift, attributes(component))]
18pub fn lift(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
19    replace_namespace(original::lift(input))
20}
21
22#[proc_macro_derive(Lower, attributes(component))]
23pub fn lower(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
24    replace_namespace(original::lower(input))
25}
26
27#[proc_macro_derive(ComponentType, attributes(component))]
28pub fn component_type(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
29    replace_namespace(original::component_type(input))
30}
31
32#[proc_macro]
33pub fn flags_sys(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
34    replace_namespace(original::flags(input, CompilationTarget::Sys))
35}
36
37#[proc_macro]
38pub fn flags_js(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
39    replace_namespace(original::flags(input, CompilationTarget::Js))
40}
41
42fn bindgen(input: proc_macro::TokenStream) -> String {
43    let as_string = replace_namespace_str(original::bindgen(input));
44
45    // Add PartialEq derive, so that testing isn't so miserably painful
46    let regex = Regex::new("derive\\(([^\\)]*Clone[^\\)]*)\\)").unwrap();
47    let as_string = regex.replace_all(&as_string, |caps: &Captures| {
48        if caps[0].contains("PartialEq") {
49            caps[0].to_string()
50        } else {
51            format!("derive({}, PartialEq)", &caps[1])
52        }
53    });
54
55    as_string.to_string()
56}
57
58#[proc_macro]
59pub fn bindgen_sys(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
60    let as_string = bindgen(input);
61
62    let as_string = add_safe_instantiation(&as_string);
63
64    // eprintln!("bindgen SYS IMPL: {}", as_string.deref());
65    proc_macro::TokenStream::from_str(&as_string).unwrap()
66}
67
68#[proc_macro]
69pub fn bindgen_js(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
70    let as_string = bindgen(input);
71
72    // Clone exported function
73    let regex = Regex::new("\\*\\s*__exports\\.typed_func([^?]*)\\?\\.func\\(\\)").unwrap();
74    let as_string = regex.replace_all(&as_string, "__exports.typed_func$1?.func().clone()");
75
76    // Clone "inner" function
77    let regex = Regex::new("new_unchecked\\(self\\.([^)]*)\\)").unwrap();
78    let as_string = regex.replace_all(&as_string, "new_unchecked(self.$1.clone())");
79
80    let regex = Regex::new("add_to_linker\\s*<\\s*T").unwrap();
81    let as_string = regex.replace_all(&as_string, "add_to_linker<T: 'static");
82
83    let regex = Regex::new("add_root_to_linker\\s*<\\s*T").unwrap();
84    let as_string = regex.replace_all(&as_string, "add_root_to_linker<T: 'static");
85
86    // Remove the "ComponentType" trait, it's about memory and type safety, we don't need to care about it as much
87    let regex = Regex::new("#\\[derive[^C]*ComponentType\\s*\\)\\s*\\]").unwrap();
88    let as_string = regex.replace_all(&as_string, "");
89
90    let regex =
91        Regex::new("const\\s*_\\s*:\\s*\\(\\)\\s*=[^}]*ComponentType[^}]*\\}\\s*;").unwrap();
92    let as_string = regex.replace_all(&as_string, "");
93
94    // Replace the "Lift" trait with our Lift trait and SizeDescription
95    let regex = Regex::new("#\\[derive\\([^)]*Lift\\)\\]").unwrap();
96    let as_string = regex.replace_all(&as_string, "#[derive(wasm_bridge::component::SizeDescription)]\n#[derive(wasm_bridge::component::LiftJs)]");
97
98    // Replace the "Lower" trait with out Lower trait
99    let regex = Regex::new("#\\[derive\\([^)]*Lower\\)\\]").unwrap();
100    let as_string = regex.replace_all(&as_string, "#[derive(wasm_bridge::component::LowerJs)]");
101
102    let as_string = add_safe_instantiation(&as_string);
103
104    // eprintln!("bindgen JS IMPL: {}", as_string.deref());
105    proc_macro::TokenStream::from_str(&as_string).unwrap()
106}
107
108fn add_safe_instantiation(as_string: &str) -> impl Deref<Target = str> + '_ {
109    let regex = Regex::new("pub\\s+fn\\s+instantiate\\s*<([^{]*)\\{").unwrap();
110
111    regex.replace_all(as_string, r#"
112    pub async fn instantiate_safe<T>(
113        mut store: impl wasm_bridge::AsContextMut<Data = T>,
114        component: &wasm_bridge::component::Component,
115        linker: &wasm_bridge::component::Linker<T>,
116    ) -> wasm_bridge::Result<(Self, wasm_bridge::component::Instance)> {
117        let instance = linker.instantiate_safe(&mut store, component).await?;
118        Ok((Self::new(store, &instance)?, instance))
119    }
120    
121    #[deprecated(
122        since = "0.4.0",
123        note = "Instantiating a component synchronously can panic on the web, please use `instantiate_safe` instead."
124    )]
125    pub fn instantiate< $1 {
126        #[allow(deprecated)]
127        "#)
128}
129
130#[proc_macro_derive(SizeDescription, attributes(component))]
131pub fn derive_size_description(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
132    let derive_input: syn::DeriveInput = syn::parse(input).unwrap();
133
134    let name = derive_input.ident;
135    let struct_style = style_from_attributes(&derive_input.attrs);
136
137    let tokens = match derive_input.data {
138        syn::Data::Struct(data) => direct_impl::size_description_struct(name, data),
139        syn::Data::Enum(data) => match struct_style.expect("cannot find attribute style") {
140            Style::Record => unreachable!("enum is not a record"),
141            Style::Variant(VariantStyle::Enum) => direct_impl::size_description_enum(name, data),
142            Style::Variant(VariantStyle::Variant) => {
143                direct_impl::size_description_variant(name, data)
144            }
145        },
146        syn::Data::Union(_) => unimplemented!("Union type should not be generated by wit bindgen"),
147    };
148
149    // eprintln!("derive_size_description IMPL: {}", tokens);
150    proc_macro::TokenStream::from(tokens)
151}
152
153#[proc_macro_derive(LiftJs, attributes(component))]
154pub fn derive_lift(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
155    let derive_input: syn::DeriveInput = syn::parse(input).unwrap();
156
157    let name = derive_input.ident;
158    let struct_style = style_from_attributes(&derive_input.attrs);
159
160    let tokens = match derive_input.data {
161        syn::Data::Struct(data) => direct_impl::lift_struct(name, data),
162        syn::Data::Enum(data) => match struct_style.expect("cannot find attribute style") {
163            Style::Record => unreachable!("enum is not a record"),
164            Style::Variant(VariantStyle::Enum) => direct_impl::lift_enum(name, data),
165            Style::Variant(VariantStyle::Variant) => direct_impl::lift_variant(name, data),
166        },
167        syn::Data::Union(_) => unimplemented!("Union type should not be generated by wit bindgen"),
168    };
169
170    // eprintln!("derive_lift IMPL: {}", tokens);
171    proc_macro::TokenStream::from(tokens)
172}
173
174#[proc_macro_derive(LowerJs, attributes(component))]
175pub fn derive_lower(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
176    let derive_input: syn::DeriveInput = syn::parse(input).unwrap();
177
178    let name = derive_input.ident;
179    let struct_style = style_from_attributes(&derive_input.attrs);
180
181    let tokens = match derive_input.data {
182        syn::Data::Struct(data) => direct_impl::lower_struct(name, data),
183        syn::Data::Enum(data) => match struct_style.expect("cannot find attribute style") {
184            Style::Record => unreachable!("enum is not a record"),
185            Style::Variant(VariantStyle::Enum) => direct_impl::lower_enum(name, data),
186            Style::Variant(VariantStyle::Variant) => direct_impl::lower_variant(name, data),
187        },
188        syn::Data::Union(_) => unimplemented!("Union type should not be generated by wit bindgen"),
189    };
190
191    // eprintln!("derive_lower IMPL: {}", tokens);
192    proc_macro::TokenStream::from(tokens)
193}
194
195#[proc_macro_attribute]
196pub fn async_trait(
197    _attr: proc_macro::TokenStream,
198    input: proc_macro::TokenStream,
199) -> proc_macro::TokenStream {
200    let mut item_impl: ItemImpl = syn::parse(input).unwrap();
201    for item in item_impl.items.iter_mut() {
202        if let ImplItem::Fn(method) = item {
203            method.sig.asyncness = None;
204        }
205    }
206    item_impl.into_token_stream().into()
207}
208
209fn replace_namespace_str(stream: proc_macro::TokenStream) -> String {
210    let as_string = stream.to_string();
211
212    // Replace wasmtime:: package path with wasm_bridge::
213    let regex = Regex::new("wasmtime[^:]*::").unwrap();
214    let as_string = regex.replace_all(&as_string, "wasm_bridge::");
215
216    as_string.to_string()
217}
218
219fn replace_namespace(stream: proc_macro::TokenStream) -> proc_macro::TokenStream {
220    let as_string = replace_namespace_str(stream);
221
222    proc_macro::TokenStream::from_str(&as_string).unwrap()
223}
224
225fn style_from_attributes(attributes: &[Attribute]) -> Option<Style> {
226    attributes
227        .iter()
228        .find(|attr| attr.path().is_ident("component"))
229        .map(|attr| {
230            attr.parse_args()
231                .expect("Failed to parse Style from Attribute")
232        })
233}
234
/// Function-like macro that forwards its input to `direct_impl::size_description_tuple` and
/// returns that expansion as-is.
///
/// Unlike most macros in this file, no `wasmtime` → `wasm_bridge` path rewriting is applied.
#[proc_macro]
pub fn size_description_tuple(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
    direct_impl::size_description_tuple(tokens)
}