tauri_ipc_macros/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::{Span, TokenStream as TokenStream2};
3use quote::{format_ident, quote, ToTokens, TokenStreamExt};
4use syn::{
5    self, braced,
6    parse::Parse,
7    parse_macro_input, parse_quote,
8    punctuated::{Pair, Punctuated},
9    token::{self, Comma},
10    Field, FieldMutability, Fields, FnArg, Generics, Ident, ItemEnum, ItemFn, ItemTrait, LitStr,
11    Pat, Signature, Token, TraitItem, Type, Variant, Visibility,
12};
13
14#[derive(Default)]
15struct InvokeBindingAttrs {
16    cmd_prefix: Option<String>,
17}
18
19impl Parse for InvokeBindingAttrs {
20    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
21        let mut attrs: Self = Default::default();
22        while !input.is_empty() {
23            let kv: KeyValuePair = input.parse()?;
24            if kv.key.as_str() == "cmd_prefix" {
25                attrs.cmd_prefix = Some(kv.value)
26            }
27        }
28        Ok(attrs)
29    }
30}
31
32struct KeyValuePair {
33    key: String,
34    value: String,
35}
36
37impl Parse for KeyValuePair {
38    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
39        let key: Ident = input.parse()?;
40        let _: Token![=] = input.parse()?;
41        let value: LitStr = input.parse()?;
42        Ok(Self {
43            key: key.to_string(),
44            value: value.value(),
45        })
46    }
47}
48
49/// Apply this to a trait, and generate an implementation for it's fns in the
50/// same scope that call `invoke` using the fn name as the command
51///
52/// # Examples
53///
54/// ```ignore
55/// #[allow(async_fn_in_trait)]
56/// #[tauri_bindgen_rs_macros::invoke_bindings]
57/// pub trait Commands {
58///     async hello(name: String) -> Result<String, String>;
59/// }
60///
61/// async fn hello_world() -> Result<String, String> {
62///     hello("world".into())
63/// }
64/// ```
65#[proc_macro_attribute]
66pub fn invoke_bindings(attrs: TokenStream, tokens: TokenStream) -> TokenStream {
67    let attrs = parse_macro_input!(attrs as InvokeBindingAttrs);
68    let trait_item = parse_macro_input!(tokens as ItemTrait);
69    let fn_items = trait_item.items.iter().fold(Vec::new(), |mut m, item| {
70        if let TraitItem::Fn(fn_item) = item {
71            let fields: Punctuated<Field, Token![,]> =
72                Punctuated::from_iter(fn_item.sig.inputs.iter().fold(Vec::new(), |mut m, arg| {
73                    let pt = match arg {
74                        FnArg::Typed(pt) => pt,
75                        FnArg::Receiver(_) => {
76                            panic!("receiver arguments not supported");
77                        }
78                    };
79                    let ident = match pt.pat.as_ref() {
80                        Pat::Ident(pi) => Some(pi.ident.clone()),
81                        _ => panic!("argument not supported"),
82                    };
83                    let colon_token = Some(pt.colon_token);
84                    let ty = pt.ty.as_ref().clone();
85                    m.push(Field {
86                        attrs: Vec::new(),
87                        vis: Visibility::Inherited,
88                        mutability: FieldMutability::None,
89                        ident,
90                        colon_token,
91                        ty,
92                    });
93                    m
94                }));
95            let field_names: Punctuated<Ident, Token![,]> =
96                Punctuated::from_iter(fields.iter().map(|field| field.ident.clone().unwrap()));
97            let fn_name = fn_item.sig.ident.to_string();
98            let fn_name = attrs
99                .cmd_prefix
100                .clone()
101                .map_or(fn_name.clone(), |prefix| prefix + fn_name.as_str());
102            m.push(ItemFn {
103                attrs: Vec::new(),
104                vis: trait_item.vis.clone(),
105                sig: fn_item.sig.clone(),
106                block: parse_quote!({
107                    #[derive(::serde::Serialize)]
108                    #[serde(rename_all = "camelCase")]
109                    struct Args {
110                        #fields
111                    }
112                    let args = Args { #field_names };
113                    let args: JsValue = ::serde_wasm_bindgen::to_value(&args).unwrap();
114                    match invoke(#fn_name, args).await {
115                        Ok(value) => Ok(::serde_wasm_bindgen::from_value(value).unwrap()),
116                        Err(err) => Err(::serde_wasm_bindgen::from_value(err).unwrap()),
117                    }
118                }),
119            });
120        }
121        m
122    });
123    let fn_items = ItemList { list: fn_items };
124    let ret = quote! {
125        #trait_item
126
127        use wasm_bindgen::prelude::*;
128
129        #[wasm_bindgen]
130        extern "C" {
131            #[wasm_bindgen(js_namespace = ["window", "__TAURI__", "core"], catch)]
132            async fn invoke(cmd: &str, args: JsValue) -> Result<JsValue, JsValue>;
133        }
134
135        #fn_items
136    };
137
138    TokenStream::from(ret)
139}
140
141/// # Examples
142///
143/// ```ignore
144/// #[derive(Events, Debug, Clone, ::serde::Serialize, ::serde::Deserialize)]
145/// enum Event {
146///     SomethingHappened { payload: Vec<u8> },
147///     SomeoneSaidHello(String),
148///     NoPayload,
149/// }
150///
151/// fn emit_event(app_handle: tauri::AppHandle, event: Event) -> anyhow::Result<()> {
152///     Ok(app_handle.emit(event.event_name(), event)?)
153/// }
154///
155/// // ...
156///
157/// let listener = EventBinding::SomethingHappened.listen(|event: Event| {
158///     // ...
159/// }).await;
160/// drop(listener); // unlisten
161/// ```
162#[proc_macro_derive(Events)]
163pub fn derive_event(tokens: TokenStream) -> TokenStream {
164    let item_enum = parse_macro_input!(tokens as ItemEnum);
165    let ItemEnum {
166        attrs: _,
167        vis,
168        enum_token: _,
169        ident,
170        generics,
171        brace_token: _,
172        variants,
173    } = item_enum;
174
175    fn derive_impl_display(
176        vis: Visibility,
177        _generics: Generics, // TODO: support generics
178        ident: Ident,
179        variants: Punctuated<Variant, Comma>,
180    ) -> TokenStream2 {
181        let match_arms: Punctuated<TokenStream2, Comma> = variants
182            .iter()
183            .map(|v| -> TokenStream2 {
184                let ident = ident.clone();
185                let v_ident = &v.ident;
186                let v_ident_str = v_ident.to_string();
187                let fields: TokenStream2 = match &v.fields {
188                    Fields::Unit => quote! {},
189                    Fields::Unnamed(fields) => {
190                        let placeholders: Punctuated<TokenStream2, Comma> = fields
191                            .unnamed
192                            .iter()
193                            .map(|_| -> TokenStream2 {
194                                quote! { _ }
195                            })
196                            .collect();
197                        quote! { (#placeholders) }
198                    }
199                    Fields::Named(fields) => {
200                        let placeholders: Punctuated<TokenStream2, Comma> = fields
201                            .named
202                            .iter()
203                            .map(|f| -> TokenStream2 {
204                                let ident = f.ident.as_ref().unwrap();
205                                quote! { #ident: _ }
206                            })
207                            .collect();
208                        quote! { {#placeholders} }
209                    }
210                };
211                quote! {
212                    #ident::#v_ident #fields => #v_ident_str
213                }
214            })
215            .collect();
216        let ret = quote! {
217            impl #ident {
218                #vis fn event_name(&self) -> &'static str {
219                    match self {
220                        #match_arms
221                    }
222                }
223            }
224        };
225        ret
226    }
227
228    fn derive_event_binding(
229        _generics: Generics, // TODO: support generics
230        ident: Ident,
231        variants: Punctuated<Variant, Comma>,
232    ) -> TokenStream2 {
233        let event_binding_ident = Ident::new(&format!("{}Binding", ident), Span::call_site());
234        let variant_names: Punctuated<Ident, Comma> =
235            variants.iter().map(|v| v.ident.clone()).collect();
236        let variant_to_str_match_arms: Punctuated<TokenStream2, Comma> = variants
237            .iter()
238            .map(|v| -> TokenStream2 {
239                let ident = &v.ident;
240                let ident_str = ident.to_string();
241                quote! {
242                    #event_binding_ident::#ident => #ident_str
243                }
244            })
245            .collect();
246        let ret = quote! {
247            pub enum #event_binding_ident {
248                #variant_names
249            }
250
251            impl #event_binding_ident {
252                pub async fn listen<F>(&self, handler: F) -> Result<EventListener, JsValue>
253                where
254                    F: Fn(#ident) + 'static,
255                {
256                    let event_name = self.as_str();
257                    EventListener::new(event_name, move |event| {
258                        let event: TauriEvent<#ident> = ::serde_wasm_bindgen::from_value(event).unwrap();
259                        handler(event.payload);
260                    })
261                    .await
262                }
263
264                fn as_str(&self) -> &str {
265                    match self {
266                        #variant_to_str_match_arms
267                    }
268                }
269            }
270        };
271        ret
272    }
273
274    // TODO: break this out into another crate (it doesn't need to be in a macro)
275    fn events_mod(vis: Visibility) -> TokenStream2 {
276        quote! {
277            use wasm_bindgen::prelude::*;
278
279            #[wasm_bindgen]
280            extern "C" {
281                #[wasm_bindgen(js_namespace = ["window", "__TAURI__", "event"], catch)]
282                async fn listen(
283                    event_name: &str,
284                    handler: &Closure<dyn FnMut(JsValue)>,
285                ) -> Result<JsValue, JsValue>;
286            }
287
288            #vis struct EventListener {
289                event_name: String,
290                _closure: Closure<dyn FnMut(JsValue)>,
291                unlisten: js_sys::Function,
292            }
293
294            impl EventListener {
295                pub async fn new<F>(event_name: &str, handler: F) -> Result<Self, JsValue>
296                where
297                    F: Fn(JsValue) + 'static,
298                {
299                    let closure = Closure::new(handler);
300                    let unlisten = listen(event_name, &closure).await?;
301                    let unlisten = js_sys::Function::from(unlisten);
302
303                    tracing::trace!("EventListener created for {event_name}");
304
305                    Ok(Self {
306                        event_name: event_name.to_string(),
307                        _closure: closure,
308                        unlisten,
309                    })
310                }
311            }
312
313            impl Drop for EventListener {
314                fn drop(&mut self) {
315                    tracing::trace!("EventListener dropped for {}", self.event_name);
316                    let context = JsValue::null();
317                    self.unlisten.call0(&context).unwrap();
318                }
319            }
320
321            #[derive(::serde::Deserialize)]
322            struct TauriEvent<T> {
323                pub payload: T,
324            }
325        }
326    }
327
328    let impl_display = derive_impl_display(
329        vis.clone(),
330        generics.clone(),
331        ident.clone(),
332        variants.clone(),
333    );
334    let event_binding = derive_event_binding(generics, ident, variants);
335    let events_mod = events_mod(vis);
336
337    let ret = quote! {
338        #impl_display
339
340        #event_binding
341
342        #events_mod
343    };
344    TokenStream::from(ret)
345}
346
347struct ImplTrait {
348    trait_ident: Ident,
349    fns: ItemList<ItemFn>,
350}
351
352impl Parse for ImplTrait {
353    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
354        let fns;
355        let trait_ident = input.parse()?;
356        let _: Token![,] = input.parse()?;
357        let _: token::Brace = braced!(fns in input);
358        let fns = fns.parse()?;
359        Ok(ImplTrait { trait_ident, fns })
360    }
361}
362
363struct ItemList<I: ToTokens> {
364    list: Vec<I>,
365}
366
367impl<I: Parse + ToTokens> Parse for ItemList<I> {
368    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
369        let mut list = Vec::new();
370
371        while !input.is_empty() {
372            let item: I = input.parse()?;
373            list.push(item);
374        }
375
376        Ok(ItemList { list })
377    }
378}
379
380impl<I: ToTokens> ToTokens for ItemList<I> {
381    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
382        tokens.append_all(self.list.iter());
383    }
384}
385
386/// Takes the name of a trait and an impl block, and emits a ghost struct that
387/// implements that trait using the provided fn signatures—stripping away any
388/// generics and arguments with `tauri` as the first path segment.
389///
390/// TODO: accept a list of arguments to ignore vs relying on the `tauri::` prefix.
391///
392/// # Examples
393///
394/// ```ignore
395/// trait Commands {
396///     async foo(bar: String) -> Result<(), String>;
397///     async bar(foo: String) -> Result<(), String>;
398/// }
399///
400/// tauri_bindgen_rs_macros::impl_trait!(Commands, {
401///     #[tauri::command]
402///     async foo(state: tauri::State, bar: String) -> Result<(), String> {
403///         Ok(())
404///     }
405///
406///     #[tauri::command]
407///     async bar(state: tauri::State, foo: String) -> Result<(), String> {
408///         Ok(())
409///     }
410/// });
411/// ```
412#[proc_macro]
413pub fn impl_trait(tokens: TokenStream) -> TokenStream {
414    let ImplTrait { trait_ident, fns } = parse_macro_input!(tokens as ImplTrait);
415
416    let mut fn_idents = Vec::new();
417    let mut trait_fns = Vec::new();
418
419    fn map_fn_input(mut item: Pair<FnArg, Comma>) -> Pair<FnArg, Comma> {
420        let value = item.value_mut();
421        if let FnArg::Typed(pt) = value {
422            if let Pat::Ident(pi) = pt.pat.as_mut() {
423                pi.ident = Ident::new(
424                    // add an _ prefix to all fn arguments so we don't trigger unused variable warnings
425                    { "_".to_string() + pi.ident.to_string().as_str() }.as_str(),
426                    pi.ident.span(),
427                );
428            }
429        }
430        item
431    }
432
433    fn filter_map_fn_inputs(inputs: Punctuated<FnArg, Comma>) -> Punctuated<FnArg, Comma> {
434        let tauri_ident = Ident::new("tauri", Span::call_site());
435        Punctuated::from_iter(inputs.into_pairs().fold(Vec::new(), |mut m, item| {
436            if let Some(tp) = match item.value() {
437                FnArg::Typed(pt) => match pt.ty.as_ref() {
438                    Type::Path(path) => Some(path),
439                    _ => None,
440                },
441                _ => None,
442            } {
443                if let Some(s) = tp.path.segments.first() {
444                    if s.ident == tauri_ident {
445                        return m;
446                    }
447                }
448            }
449            m.push(map_fn_input(item));
450            m
451        }))
452    }
453
454    fns.list.iter().for_each(|func| {
455        let sig = &func.sig;
456
457        fn_idents.push(sig.ident.clone());
458
459        trait_fns.push(ItemFn {
460            attrs: Vec::new(),
461            vis: func.vis.clone(),
462            sig: Signature {
463                constness: None,
464                asyncness: sig.asyncness,
465                unsafety: None,
466                abi: None,
467                fn_token: sig.fn_token,
468                generics: Default::default(),
469                ident: sig.ident.clone(),
470                paren_token: sig.paren_token,
471                inputs: filter_map_fn_inputs(sig.inputs.clone()),
472                variadic: None,
473                output: sig.output.clone(),
474            },
475            block: parse_quote!({ todo!() }),
476        });
477    });
478
479    let struct_name = format_ident!("__Impl{}", trait_ident);
480    let trait_fns = ItemList { list: trait_fns };
481    let generate_handler_macro_name = format_ident!(
482        "generate_{}_handler",
483        camel_to_snake_case(trait_ident.clone())
484    );
485    let generate_handler_macro_doc = format!("Expands to call [`::tauri::generate_handler`] with a list of all the fns defined in [`{}`]", trait_ident);
486
487    let ret = quote! {
488        struct #struct_name {}
489
490        impl #trait_ident for #struct_name {
491            #trait_fns
492        }
493
494        #fns
495
496        #[allow(unused)]
497        #[doc = #generate_handler_macro_doc]
498        macro_rules! #generate_handler_macro_name {
499            () => {
500                ::tauri::generate_handler![#(#fn_idents),*]
501            };
502        }
503    };
504
505    TokenStream::from(ret)
506}
507
508fn camel_to_snake_case(ident: Ident) -> Ident {
509    let snake_case: String = ident
510        .to_string()
511        .chars()
512        .enumerate()
513        .flat_map(|(i, c)| {
514            if c.is_uppercase() && i > 0 {
515                let mut ret = Vec::with_capacity(c.len_utf8() + 1);
516                ret.push('_');
517                ret.extend(c.to_lowercase());
518                ret
519            } else {
520                Vec::from_iter(c.to_lowercase())
521            }
522        })
523        .collect();
524    Ident::new(snake_case.as_str(), Span::call_site())
525}