wasm_minimal_protocol/
lib.rs

1//! Minimal protocol for sending/receiving messages from and to a wasm host.
2//!
3//! If you define a function accepting `n` arguments of type `&[u8]`, it will
4//! internally be exported as a function accepting `n` integers.
5//!
6//! # Example
7//!
8//! ```
9//! use wasm_minimal_protocol::wasm_func;
10//!
11//! #[cfg(target_arch = "wasm32")]
12//! wasm_minimal_protocol::initiate_protocol!();
13//!
14//! #[cfg_attr(target_arch = "wasm32", wasm_func)]
15//! fn concatenate(arg1: &[u8], arg2: &[u8]) -> Vec<u8> {
16//!     [arg1, arg2].concat()
17//! }
18//! ```
19//!
20//! # Protocol
21//!
22//! The specification of the protocol can be found in the typst documentation:
23//! <https://typst.app/docs/reference/foundations/plugin/#protocol>
24
25use proc_macro::TokenStream;
26use quote::{format_ident, quote, ToTokens};
27use venial::*;
28
29/// Macro that sets up the correct imports and traits to be used by [`macro@wasm_func`].
30///
31/// This macro should be called only once, preferably at the root of the crate. It does
32/// not take any arguments.
33#[proc_macro]
34pub fn initiate_protocol(stream: TokenStream) -> TokenStream {
35    if !stream.is_empty() {
36        return quote!(
37            compile_error!("This macro does not take any arguments");
38        )
39        .into();
40    }
41    quote!(
42        // #[cfg(not(target_arch = "wasm32"))]
43        // compile_error!("Error: this protocol may only be used when compiling to wasm architectures");
44
45        #[link(wasm_import_module = "typst_env")]
46        extern "C" {
47            #[link_name = "wasm_minimal_protocol_send_result_to_host"]
48            fn __send_result_to_host(ptr: *const u8, len: usize);
49            #[link_name = "wasm_minimal_protocol_write_args_to_buffer"]
50            fn __write_args_to_buffer(ptr: *mut u8);
51        }
52
53        trait __BytesOrResultBytes {
54            type Err;
55            fn convert(self) -> ::std::result::Result<Vec<u8>, Self::Err>;
56        }
57        impl __BytesOrResultBytes for Vec<u8> {
58            type Err = i32;
59            fn convert(self) -> ::std::result::Result<Vec<u8>, <Self as __BytesOrResultBytes>::Err> {
60                Ok(self)
61            }
62        }
63        impl<E> __BytesOrResultBytes for ::std::result::Result<Vec<u8>, E> {
64            type Err = E;
65            fn convert(self) -> ::std::result::Result<Vec<u8>, <Self as __BytesOrResultBytes>::Err> {
66                self
67            }
68        }
69    ).into()
70}
71
72/// Wrap the function to be used with the [protocol](https://typst.app/docs/reference/foundations/plugin/#protocol).
73///
74/// # Arguments
75///
76/// All the arguments of the function should be `&[u8]`, no lifetime needed.
77///
78/// # Return type
79///
80/// The return type of the function should be `Vec<u8>` or `Result<Vec<u8>, E>` where
81/// `E: ToString`.
82///
83/// If the function return `Vec<u8>`, it will be implicitely wrapped in `Ok`.
84///
85/// # Example
86///
87/// ```
88/// use wasm_minimal_protocol::wasm_func;
89///
90/// #[cfg(target_arch = "wasm32")]
91/// wasm_minimal_protocol::initiate_protocol!();
92///
93/// #[cfg_attr(target_arch = "wasm32", wasm_func)]
94/// fn function_one() -> Vec<u8> {
95///     Vec::new()
96/// }
97///
98/// #[cfg_attr(target_arch = "wasm32", wasm_func)]
99/// fn function_two(arg1: &[u8], arg2: &[u8]) -> Result<Vec<u8>, i32> {
100///     Ok(b"Normal message".to_vec())
101/// }
102///
103/// #[cfg_attr(target_arch = "wasm32", wasm_func)]
104/// fn function_three(arg1: &[u8]) -> Result<Vec<u8>, String> {
105///     Err(String::from("Error message"))
106/// }
107/// ```
108#[proc_macro_attribute]
109pub fn wasm_func(_: TokenStream, item: TokenStream) -> TokenStream {
110    let mut item = proc_macro2::TokenStream::from(item);
111    let decl = parse_declaration(item.clone()).expect("invalid declaration");
112    let func = match decl.as_function() {
113        Some(func) => func.clone(),
114        None => {
115            let error = venial::Error::new_at_tokens(
116                &item,
117                "#[wasm_func] can only be applied to a function",
118            );
119            item.extend(error.to_compile_error());
120            return item.into();
121        }
122    };
123    let Function {
124        name,
125        params,
126        vis_marker,
127        ..
128    } = func.clone();
129
130    let mut error = None;
131
132    let p = params
133        .items()
134        .filter_map(|x| match x {
135            FnParam::Receiver(_p) => {
136                let x = x.to_token_stream();
137                error = Some(venial::Error::new_at_tokens(
138                    &x,
139                    format!("the {x} argument is not allowed by the protocol"),
140                ));
141                None
142            }
143            FnParam::Typed(p) => {
144                if p.ty.tokens.len() != 2
145                    || p.ty.tokens[0].to_string() != "&"
146                    || p.ty.tokens[1].to_string() != "[u8]"
147                {
148                    let p_to_string = p.ty.to_token_stream();
149                    error = Some(venial::Error::new_at_tokens(
150                        &p_to_string,
151                        format!("only parameters of type &[u8] are allowed, not {p_to_string}"),
152                    ));
153                    None
154                } else {
155                    Some(p.name.clone())
156                }
157            }
158        })
159        .collect::<Vec<_>>();
160    let p_idx = p
161        .iter()
162        .map(|name| format_ident!("__{}_idx", name))
163        .collect::<Vec<_>>();
164
165    let mut get_unsplit_params = quote!(
166        let __total_len = #(#p_idx + )* 0;
167        let mut __unsplit_params = vec![0u8; __total_len];
168        unsafe { __write_args_to_buffer(__unsplit_params.as_mut_ptr()); }
169    );
170    let mut set_args = quote!(
171        let start: usize = 0;
172    );
173    match p.len() {
174        0 => get_unsplit_params = quote!(),
175        1 => {
176            let arg = p.first().unwrap();
177            set_args = quote!(
178                let #arg: &[u8] = &__unsplit_params;
179            )
180        }
181        _ => {
182            // ignore last arg, rest used to split unsplit_param
183            let args = &p;
184            let mut args_idx = p
185                .iter()
186                .map(|name| format_ident!("__{}_idx", &name))
187                .collect::<Vec<_>>();
188            args_idx.pop();
189            let mut sets = vec![];
190            let mut start = quote!(0usize);
191            let mut end = quote!(0usize);
192            for (idx, arg_idx) in args_idx.iter().enumerate() {
193                end = quote!(#end + #arg_idx);
194                let arg_name = &args[idx];
195                sets.push(quote!(
196                    let #arg_name: &[u8] = &__unsplit_params[#start..#end];
197                ));
198                start = quote!(#start + #arg_idx)
199            }
200            let last = args.last().unwrap();
201            sets.push(quote!(
202                let #last = &__unsplit_params[#end..];
203            ));
204            set_args = quote!(
205                #(
206                    #sets
207                )*
208            );
209        }
210    }
211
212    let inner_name = format_ident!("__wasm_minimal_protocol_internal_function_{}", name);
213    let export_name = proc_macro2::Literal::string(&name.to_string());
214
215    let mut result = quote!(#func);
216    if let Some(error) = error {
217        result.extend(error.to_compile_error());
218    } else {
219        result.extend(quote!(
220            #[export_name = #export_name]
221            #vis_marker extern "C" fn #inner_name(#(#p_idx: usize),*) -> i32 {
222                #get_unsplit_params
223                #set_args
224
225                let result = __BytesOrResultBytes::convert(#name(#(#p),*));
226                let (message, code) = match result {
227                    Ok(s) => (s.into_boxed_slice(), 0),
228                    Err(err) => (err.to_string().into_bytes().into_boxed_slice(), 1),
229                };
230                unsafe { __send_result_to_host(message.as_ptr(), message.len()); }
231                code
232            }
233        ))
234    }
235    result.into()
236}