Skip to main content

wasm_split_macro/
lib.rs

1use proc_macro::TokenStream;
2
3use digest::Digest;
4use quote::{format_ident, quote};
5use syn::{FnArg, Ident, ItemFn, ReturnType, Signature, parse_macro_input, parse_quote};
6
7#[proc_macro_attribute]
8pub fn wasm_split(args: TokenStream, input: TokenStream) -> TokenStream {
9    let module_ident = parse_macro_input!(args as Ident);
10    let item_fn = parse_macro_input!(input as ItemFn);
11
12    if item_fn.sig.asyncness.is_none() {
13        panic!(
14            "wasm_split functions must be async. Use a LazyLoader with synchronous functions instead."
15        );
16    }
17
18    let LoaderNames {
19        split_loader_ident,
20        impl_import_ident,
21        impl_export_ident,
22        load_module_ident,
23        ..
24    } = LoaderNames::new(item_fn.sig.ident.clone(), module_ident.to_string());
25
26    let mut desugard_async_sig = item_fn.sig.clone();
27    desugard_async_sig.asyncness = None;
28    desugard_async_sig.output = match &desugard_async_sig.output {
29        ReturnType::Default => {
30            parse_quote! { -> ::std::pin::Pin<Box<dyn ::std::future::Future<Output = ()>>> }
31        }
32        ReturnType::Type(_, ty) => {
33            parse_quote! { -> ::std::pin::Pin<Box<dyn ::std::future::Future<Output = #ty>>> }
34        }
35    };
36
37    let import_sig = Signature {
38        ident: impl_import_ident.clone(),
39        ..desugard_async_sig.clone()
40    };
41
42    let export_sig = Signature {
43        ident: impl_export_ident.clone(),
44        ..desugard_async_sig.clone()
45    };
46
47    let default_item = item_fn.clone();
48
49    let mut wrapper_sig = item_fn.sig;
50    wrapper_sig.asyncness = Some(Default::default());
51
52    let mut args = Vec::new();
53    for (i, param) in wrapper_sig.inputs.iter_mut().enumerate() {
54        match param {
55            syn::FnArg::Receiver(_) => args.push(format_ident!("self")),
56            syn::FnArg::Typed(pat_type) => {
57                let param_ident = format_ident!("__wasm_split_arg_{i}");
58                args.push(param_ident.clone());
59                *pat_type.pat = syn::Pat::Ident(syn::PatIdent {
60                    attrs: vec![],
61                    by_ref: None,
62                    mutability: None,
63                    ident: param_ident,
64                    subpat: None,
65                });
66            }
67        }
68    }
69
70    let attrs = &item_fn.attrs;
71    let stmts = &item_fn.block.stmts;
72
73    quote! {
74        #[cfg(target_arch = "wasm32")]
75        #wrapper_sig {
76            #(#attrs)*
77            #[allow(improper_ctypes_definitions)]
78            #[unsafe(no_mangle)]
79            pub unsafe extern "C" #export_sig {
80                Box::pin(async move { #(#stmts)* })
81            }
82
83            #[link(wasm_import_module = "./__wasm_split.js")]
84            unsafe extern "C" {
85                #[unsafe(no_mangle)]
86                fn #load_module_ident (
87                    callback: unsafe extern "C" fn(*const ::std::ffi::c_void, bool),
88                    data: *const ::std::ffi::c_void
89                );
90
91                #[allow(improper_ctypes)]
92                #[unsafe(no_mangle)]
93                #import_sig;
94            }
95
96            thread_local! {
97                static #split_loader_ident: wasm_split::LazySplitLoader = unsafe {
98                    wasm_split::LazySplitLoader::new(#load_module_ident)
99                };
100            }
101
102            // Initiate the download by calling the load_module_ident function which will kick-off the loader
103            if !wasm_split::LazySplitLoader::ensure_loaded(&#split_loader_ident).await {
104                panic!("Failed to load wasm-split module");
105            }
106
107            unsafe { #impl_import_ident( #(#args),* ) }.await
108        }
109
110        #[cfg(not(target_arch = "wasm32"))]
111        #default_item
112    }
113    .into()
114}
115
116/// Create a lazy loader for a given function. Meant to be used in statics. Designed for libraries to
117/// integrate with.
118///
119/// ```rust, ignore
120/// fn SomeFunction(args: Args) -> Ret {}
121///
122/// static LOADER: wasm_split::LazyLoader<Args, Ret> = lazy_loader!(SomeFunction);
123///
124/// LOADER.load().await.call(args)
125/// ```
126#[proc_macro]
127pub fn lazy_loader(input: TokenStream) -> TokenStream {
128    // We can only accept idents/paths that will be the source function
129    let sig = parse_macro_input!(input as Signature);
130    let params = sig.inputs.clone();
131    let outputs = sig.output.clone();
132    let Some(FnArg::Typed(arg)) = params.first().cloned() else {
133        panic!(
134            "Lazy Loader must define a single input argument to satisfy the LazyLoader signature"
135        )
136    };
137    let arg_ty = arg.ty.clone();
138    let LoaderNames {
139        name,
140        split_loader_ident,
141        impl_import_ident,
142        impl_export_ident,
143        load_module_ident,
144        ..
145    } = LoaderNames::new(
146        sig.ident.clone(),
147        sig.abi
148            .as_ref()
149            .and_then(|abi| abi.name.as_ref().map(|f| f.value()))
150            .expect("abi to be module name")
151            .to_string(),
152    );
153
154    quote! {
155        {
156            #[cfg(target_arch = "wasm32")]
157            {
158                #[link(wasm_import_module = "./__wasm_split.js")]
159                unsafe extern "C" {
160                    // The function we'll use to initiate the download of the module
161                    #[unsafe(no_mangle)]
162                    fn #load_module_ident(
163                        callback: unsafe extern "C" fn(*const ::std::ffi::c_void, bool),
164                        data: *const ::std::ffi::c_void,
165                    );
166
167                    #[allow(improper_ctypes)]
168                    #[unsafe(no_mangle)]
169                    fn #impl_import_ident(arg: #arg_ty) #outputs;
170                }
171
172
173                #[allow(improper_ctypes_definitions)]
174                #[unsafe(no_mangle)]
175                pub unsafe extern "C" fn #impl_export_ident(arg: #arg_ty) #outputs {
176                    #name(arg)
177                }
178
179                thread_local! {
180                    static #split_loader_ident: wasm_split::LazySplitLoader = unsafe {
181                        wasm_split::LazySplitLoader::new(#load_module_ident)
182                    };
183                };
184
185                unsafe {
186                    wasm_split::LazyLoader::new(#impl_import_ident, &#split_loader_ident)
187                }
188            }
189
190            #[cfg(not(target_arch = "wasm32"))]
191            {
192                wasm_split::LazyLoader::preloaded(#name)
193            }
194        }
195    }
196    .into()
197}
198
199struct LoaderNames {
200    name: Ident,
201    split_loader_ident: Ident,
202    impl_import_ident: Ident,
203    impl_export_ident: Ident,
204    load_module_ident: Ident,
205}
206
207impl LoaderNames {
208    fn new(name: Ident, module: String) -> Self {
209        let unique_identifier = base16::encode_lower(
210            &sha2::Sha256::digest(format!("{name} {span:?}", name = name, span = name.span()))
211                [..16],
212        );
213
214        Self {
215            split_loader_ident: format_ident!("__wasm_split_loader_{module}"),
216            impl_export_ident: format_ident!(
217                "__wasm_split_00___{module}___00_export_{unique_identifier}_{name}"
218            ),
219            impl_import_ident: format_ident!(
220                "__wasm_split_00___{module}___00_import_{unique_identifier}_{name}"
221            ),
222            load_module_ident: format_ident!(
223                "__wasm_split_load_{module}_{unique_identifier}_{name}"
224            ),
225            name,
226        }
227    }
228}