1use proc_macro::TokenStream;
2use proc_macro2::{Span, TokenStream as TokenStream2};
3use quote::{format_ident, quote, ToTokens, TokenStreamExt};
4use syn::{
5 self, braced,
6 parse::Parse,
7 parse_macro_input, parse_quote,
8 punctuated::{Pair, Punctuated},
9 token::{self, Comma},
10 Field, FieldMutability, Fields, FnArg, Generics, Ident, ItemEnum, ItemFn, ItemTrait, LitStr,
11 Pat, Signature, Token, TraitItem, Type, Variant, Visibility,
12};
13
14#[derive(Default)]
15struct InvokeBindingAttrs {
16 cmd_prefix: Option<String>,
17}
18
19impl Parse for InvokeBindingAttrs {
20 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
21 let mut attrs: Self = Default::default();
22 while !input.is_empty() {
23 let kv: KeyValuePair = input.parse()?;
24 if kv.key.as_str() == "cmd_prefix" {
25 attrs.cmd_prefix = Some(kv.value)
26 }
27 }
28 Ok(attrs)
29 }
30}
31
32struct KeyValuePair {
33 key: String,
34 value: String,
35}
36
37impl Parse for KeyValuePair {
38 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
39 let key: Ident = input.parse()?;
40 let _: Token![=] = input.parse()?;
41 let value: LitStr = input.parse()?;
42 Ok(Self {
43 key: key.to_string(),
44 value: value.value(),
45 })
46 }
47}
48
49#[proc_macro_attribute]
66pub fn invoke_bindings(attrs: TokenStream, tokens: TokenStream) -> TokenStream {
67 let attrs = parse_macro_input!(attrs as InvokeBindingAttrs);
68 let trait_item = parse_macro_input!(tokens as ItemTrait);
69 let fn_items = trait_item.items.iter().fold(Vec::new(), |mut m, item| {
70 if let TraitItem::Fn(fn_item) = item {
71 let fields: Punctuated<Field, Token![,]> =
72 Punctuated::from_iter(fn_item.sig.inputs.iter().fold(Vec::new(), |mut m, arg| {
73 let pt = match arg {
74 FnArg::Typed(pt) => pt,
75 FnArg::Receiver(_) => {
76 panic!("receiver arguments not supported");
77 }
78 };
79 let ident = match pt.pat.as_ref() {
80 Pat::Ident(pi) => Some(pi.ident.clone()),
81 _ => panic!("argument not supported"),
82 };
83 let colon_token = Some(pt.colon_token);
84 let ty = pt.ty.as_ref().clone();
85 m.push(Field {
86 attrs: Vec::new(),
87 vis: Visibility::Inherited,
88 mutability: FieldMutability::None,
89 ident,
90 colon_token,
91 ty,
92 });
93 m
94 }));
95 let field_names: Punctuated<Ident, Token![,]> =
96 Punctuated::from_iter(fields.iter().map(|field| field.ident.clone().unwrap()));
97 let fn_name = fn_item.sig.ident.to_string();
98 let fn_name = attrs
99 .cmd_prefix
100 .clone()
101 .map_or(fn_name.clone(), |prefix| prefix + fn_name.as_str());
102 m.push(ItemFn {
103 attrs: Vec::new(),
104 vis: trait_item.vis.clone(),
105 sig: fn_item.sig.clone(),
106 block: parse_quote!({
107 #[derive(::serde::Serialize)]
108 #[serde(rename_all = "camelCase")]
109 struct Args {
110 #fields
111 }
112 let args = Args { #field_names };
113 let args: JsValue = ::serde_wasm_bindgen::to_value(&args).unwrap();
114 match invoke(#fn_name, args).await {
115 Ok(value) => Ok(::serde_wasm_bindgen::from_value(value).unwrap()),
116 Err(err) => Err(::serde_wasm_bindgen::from_value(err).unwrap()),
117 }
118 }),
119 });
120 }
121 m
122 });
123 let fn_items = ItemList { list: fn_items };
124 let ret = quote! {
125 #trait_item
126
127 use wasm_bindgen::prelude::*;
128
129 #[wasm_bindgen]
130 extern "C" {
131 #[wasm_bindgen(js_namespace = ["window", "__TAURI__", "core"], catch)]
132 async fn invoke(cmd: &str, args: JsValue) -> Result<JsValue, JsValue>;
133 }
134
135 #fn_items
136 };
137
138 TokenStream::from(ret)
139}
140
141#[proc_macro_derive(Events)]
163pub fn derive_event(tokens: TokenStream) -> TokenStream {
164 let item_enum = parse_macro_input!(tokens as ItemEnum);
165 let ItemEnum {
166 attrs: _,
167 vis,
168 enum_token: _,
169 ident,
170 generics,
171 brace_token: _,
172 variants,
173 } = item_enum;
174
175 fn derive_impl_display(
176 vis: Visibility,
177 _generics: Generics, ident: Ident,
179 variants: Punctuated<Variant, Comma>,
180 ) -> TokenStream2 {
181 let match_arms: Punctuated<TokenStream2, Comma> = variants
182 .iter()
183 .map(|v| -> TokenStream2 {
184 let ident = ident.clone();
185 let v_ident = &v.ident;
186 let v_ident_str = v_ident.to_string();
187 let fields: TokenStream2 = match &v.fields {
188 Fields::Unit => quote! {},
189 Fields::Unnamed(fields) => {
190 let placeholders: Punctuated<TokenStream2, Comma> = fields
191 .unnamed
192 .iter()
193 .map(|_| -> TokenStream2 {
194 quote! { _ }
195 })
196 .collect();
197 quote! { (#placeholders) }
198 }
199 Fields::Named(fields) => {
200 let placeholders: Punctuated<TokenStream2, Comma> = fields
201 .named
202 .iter()
203 .map(|f| -> TokenStream2 {
204 let ident = f.ident.as_ref().unwrap();
205 quote! { #ident: _ }
206 })
207 .collect();
208 quote! { {#placeholders} }
209 }
210 };
211 quote! {
212 #ident::#v_ident #fields => #v_ident_str
213 }
214 })
215 .collect();
216 let ret = quote! {
217 impl #ident {
218 #vis fn event_name(&self) -> &'static str {
219 match self {
220 #match_arms
221 }
222 }
223 }
224 };
225 ret
226 }
227
228 fn derive_event_binding(
229 _generics: Generics, ident: Ident,
231 variants: Punctuated<Variant, Comma>,
232 ) -> TokenStream2 {
233 let event_binding_ident = Ident::new(&format!("{}Binding", ident), Span::call_site());
234 let variant_names: Punctuated<Ident, Comma> =
235 variants.iter().map(|v| v.ident.clone()).collect();
236 let variant_to_str_match_arms: Punctuated<TokenStream2, Comma> = variants
237 .iter()
238 .map(|v| -> TokenStream2 {
239 let ident = &v.ident;
240 let ident_str = ident.to_string();
241 quote! {
242 #event_binding_ident::#ident => #ident_str
243 }
244 })
245 .collect();
246 let ret = quote! {
247 pub enum #event_binding_ident {
248 #variant_names
249 }
250
251 impl #event_binding_ident {
252 pub async fn listen<F>(&self, handler: F) -> Result<EventListener, JsValue>
253 where
254 F: Fn(#ident) + 'static,
255 {
256 let event_name = self.as_str();
257 EventListener::new(event_name, move |event| {
258 let event: TauriEvent<#ident> = ::serde_wasm_bindgen::from_value(event).unwrap();
259 handler(event.payload);
260 })
261 .await
262 }
263
264 fn as_str(&self) -> &str {
265 match self {
266 #variant_to_str_match_arms
267 }
268 }
269 }
270 };
271 ret
272 }
273
274 fn events_mod(vis: Visibility) -> TokenStream2 {
276 quote! {
277 use wasm_bindgen::prelude::*;
278
279 #[wasm_bindgen]
280 extern "C" {
281 #[wasm_bindgen(js_namespace = ["window", "__TAURI__", "event"], catch)]
282 async fn listen(
283 event_name: &str,
284 handler: &Closure<dyn FnMut(JsValue)>,
285 ) -> Result<JsValue, JsValue>;
286 }
287
288 #vis struct EventListener {
289 event_name: String,
290 _closure: Closure<dyn FnMut(JsValue)>,
291 unlisten: js_sys::Function,
292 }
293
294 impl EventListener {
295 pub async fn new<F>(event_name: &str, handler: F) -> Result<Self, JsValue>
296 where
297 F: Fn(JsValue) + 'static,
298 {
299 let closure = Closure::new(handler);
300 let unlisten = listen(event_name, &closure).await?;
301 let unlisten = js_sys::Function::from(unlisten);
302
303 tracing::trace!("EventListener created for {event_name}");
304
305 Ok(Self {
306 event_name: event_name.to_string(),
307 _closure: closure,
308 unlisten,
309 })
310 }
311 }
312
313 impl Drop for EventListener {
314 fn drop(&mut self) {
315 tracing::trace!("EventListener dropped for {}", self.event_name);
316 let context = JsValue::null();
317 self.unlisten.call0(&context).unwrap();
318 }
319 }
320
321 #[derive(::serde::Deserialize)]
322 struct TauriEvent<T> {
323 pub payload: T,
324 }
325 }
326 }
327
328 let impl_display = derive_impl_display(
329 vis.clone(),
330 generics.clone(),
331 ident.clone(),
332 variants.clone(),
333 );
334 let event_binding = derive_event_binding(generics, ident, variants);
335 let events_mod = events_mod(vis);
336
337 let ret = quote! {
338 #impl_display
339
340 #event_binding
341
342 #events_mod
343 };
344 TokenStream::from(ret)
345}
346
347struct ImplTrait {
348 trait_ident: Ident,
349 fns: ItemList<ItemFn>,
350}
351
352impl Parse for ImplTrait {
353 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
354 let fns;
355 let trait_ident = input.parse()?;
356 let _: Token![,] = input.parse()?;
357 let _: token::Brace = braced!(fns in input);
358 let fns = fns.parse()?;
359 Ok(ImplTrait { trait_ident, fns })
360 }
361}
362
363struct ItemList<I: ToTokens> {
364 list: Vec<I>,
365}
366
367impl<I: Parse + ToTokens> Parse for ItemList<I> {
368 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
369 let mut list = Vec::new();
370
371 while !input.is_empty() {
372 let item: I = input.parse()?;
373 list.push(item);
374 }
375
376 Ok(ItemList { list })
377 }
378}
379
380impl<I: ToTokens> ToTokens for ItemList<I> {
381 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
382 tokens.append_all(self.list.iter());
383 }
384}
385
386#[proc_macro]
413pub fn impl_trait(tokens: TokenStream) -> TokenStream {
414 let ImplTrait { trait_ident, fns } = parse_macro_input!(tokens as ImplTrait);
415
416 let mut fn_idents = Vec::new();
417 let mut trait_fns = Vec::new();
418
419 fn map_fn_input(mut item: Pair<FnArg, Comma>) -> Pair<FnArg, Comma> {
420 let value = item.value_mut();
421 if let FnArg::Typed(pt) = value {
422 if let Pat::Ident(pi) = pt.pat.as_mut() {
423 pi.ident = Ident::new(
424 { "_".to_string() + pi.ident.to_string().as_str() }.as_str(),
426 pi.ident.span(),
427 );
428 }
429 }
430 item
431 }
432
433 fn filter_map_fn_inputs(inputs: Punctuated<FnArg, Comma>) -> Punctuated<FnArg, Comma> {
434 let tauri_ident = Ident::new("tauri", Span::call_site());
435 Punctuated::from_iter(inputs.into_pairs().fold(Vec::new(), |mut m, item| {
436 if let Some(tp) = match item.value() {
437 FnArg::Typed(pt) => match pt.ty.as_ref() {
438 Type::Path(path) => Some(path),
439 _ => None,
440 },
441 _ => None,
442 } {
443 if let Some(s) = tp.path.segments.first() {
444 if s.ident == tauri_ident {
445 return m;
446 }
447 }
448 }
449 m.push(map_fn_input(item));
450 m
451 }))
452 }
453
454 fns.list.iter().for_each(|func| {
455 let sig = &func.sig;
456
457 fn_idents.push(sig.ident.clone());
458
459 trait_fns.push(ItemFn {
460 attrs: Vec::new(),
461 vis: func.vis.clone(),
462 sig: Signature {
463 constness: None,
464 asyncness: sig.asyncness,
465 unsafety: None,
466 abi: None,
467 fn_token: sig.fn_token,
468 generics: Default::default(),
469 ident: sig.ident.clone(),
470 paren_token: sig.paren_token,
471 inputs: filter_map_fn_inputs(sig.inputs.clone()),
472 variadic: None,
473 output: sig.output.clone(),
474 },
475 block: parse_quote!({ todo!() }),
476 });
477 });
478
479 let struct_name = format_ident!("__Impl{}", trait_ident);
480 let trait_fns = ItemList { list: trait_fns };
481 let generate_handler_macro_name = format_ident!(
482 "generate_{}_handler",
483 camel_to_snake_case(trait_ident.clone())
484 );
485 let generate_handler_macro_doc = format!("Expands to call [`::tauri::generate_handler`] with a list of all the fns defined in [`{}`]", trait_ident);
486
487 let ret = quote! {
488 struct #struct_name {}
489
490 impl #trait_ident for #struct_name {
491 #trait_fns
492 }
493
494 #fns
495
496 #[allow(unused)]
497 #[doc = #generate_handler_macro_doc]
498 macro_rules! #generate_handler_macro_name {
499 () => {
500 ::tauri::generate_handler![#(#fn_idents),*]
501 };
502 }
503 };
504
505 TokenStream::from(ret)
506}
507
508fn camel_to_snake_case(ident: Ident) -> Ident {
509 let snake_case: String = ident
510 .to_string()
511 .chars()
512 .enumerate()
513 .flat_map(|(i, c)| {
514 if c.is_uppercase() && i > 0 {
515 let mut ret = Vec::with_capacity(c.len_utf8() + 1);
516 ret.push('_');
517 ret.extend(c.to_lowercase());
518 ret
519 } else {
520 Vec::from_iter(c.to_lowercase())
521 }
522 })
523 .collect();
524 Ident::new(snake_case.as_str(), Span::call_site())
525}