yew_callbacks/
lib.rs

1//! Yet another crate nobody asked for.
2//!
//! This crate provides a derive macro, `Callbacks`, for Yew message enums that helps you manage
//! their callbacks.
5//!
6//! # But why
7//!
8//! Callbacks in Yew's components are easy to create but hard to manage. To avoid duplication you
9//! should create them preemptively in the `create()` method of your component, store them in the
10//! state of your component, then pass clones to the children. Unfortunately this creates a lot of
11//! bloat.
12//!
//! To address this, `yew-callbacks` provides a derive macro that generates a cache for your
//! callbacks. Create this cache once in your component's `create()` method, then call its
//! methods to get your callbacks.
16//!
17//! ## Example
18//!
19//! ```
20//! use yew::prelude::*;
21//! use yew_callbacks::Callbacks;
22//!
23//! #[derive(Debug, Callbacks)]
24//! enum Msg {
25//!     OnClick(MouseEvent),
26//! }
27//!
28//! #[derive(Debug)]
29//! struct App {
30//!     cb: MsgCallbacks<Self>,
31//! }
32//!
33//! impl Component for App {
34//!     type Message = Msg;
35//!     type Properties = ();
36//!
37//!     fn create(ctx: &Context<Self>) -> Self {
38//!         Self {
39//!             cb: ctx.link().into(),
40//!         }
41//!     }
42//!
43//!     fn view(&self, ctx: &Context<Self>) -> Html {
44//!         html! {
45//!             <button onclick={self.cb.on_click()}>
46//!                 { "Hello World!" }
47//!             </button>
48//!         }
49//!     }
50//! }
51//! ```
52//!
53//! # Why care
54//!
55//! Not perf.
56//!
//! Child components are updated whenever their properties change. If you write
//! `onclick={ctx.link().callback(Msg::OnClick)}`, the child component will see a changed property
//! every time the parent component updates, because `ctx.link().callback(Msg::OnClick)` creates a
//! new callback on every call.
61//!
62//! # Handling multiple child components
63//!
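//! This crate also lets you curry some of your callback's arguments: each field marked
//! `#[curry]` becomes a parameter of the generated method, and one callback is cached per
//! distinct set of curried values.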
65//!
66//! ## Example
67//!
68//! ```
69//! use yew::prelude::*;
70//! use yew_callbacks::Callbacks;
71//!
72//! #[derive(Debug, Callbacks)]
73//! enum Msg {
74//!     OnClick(#[curry] usize, MouseEvent),
75//! }
76//!
77//! #[derive(Debug)]
78//! struct App {
79//!     games: Vec<AttrValue>,
80//!     cb: MsgCallbacks<Self>,
81//! }
82//!
83//! impl Component for App {
84//!     type Message = Msg;
85//!     type Properties = ();
86//!
87//!     fn create(ctx: &Context<Self>) -> Self {
88//!         Self {
89//!             games: vec![
90//!                 "Freedom Planet 2".into(),
91//!                 "Asterigos: Curse of the Stars".into(),
92//!                 "Fran Bow".into(),
93//!                 "Cats in Time".into(),
94//!                 "Ittle Dew 2+".into(),
95//!                 "Inscryption".into(),
96//!             ],
97//!             cb: ctx.link().into(),
98//!         }
99//!     }
100//!
101//!     fn view(&self, _ctx: &Context<Self>) -> Html {
102//!         self
103//!             .games
104//!             .iter()
105//!             .enumerate()
106//!             .map(|(i, game)| html! {
107//!                 <button onclick={self.cb.on_click(i)}>
108//!                     { format!("You should try {game}") }
109//!                 </button>
110//!             })
111//!             .collect()
112//!     }
113//! }
114//! ```
115
116use heck::ToSnakeCase;
117use proc_macro2::{Ident, Span, TokenStream};
118use proc_macro_error::abort_call_site;
119use quote::quote;
120
121#[proc_macro_derive(Callbacks, attributes(curry))]
122pub fn main(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
123    let input = syn::parse_macro_input!(input as syn::DeriveInput);
124
125    derive_callbacks(&input).into()
126}
127
128fn derive_callbacks(input: &syn::DeriveInput) -> TokenStream {
129    let enum_name = &input.ident;
130    let vis = &input.vis;
131    let e = match &input.data {
132        syn::Data::Enum(e) => e,
133        _ => abort_call_site!("`#[derive(Callbacks)]` only supports enums"),
134    };
135
136    let name = Ident::new(&format!("{enum_name}Callbacks"), Span::call_site());
137
138    let field_names = e
139        .variants
140        .iter()
141        .map(|variant| {
142            Ident::new(
143                &format!("callback_{}", variant.ident.to_string().to_snake_case()),
144                Span::call_site(),
145            )
146        })
147        .collect::<Vec<_>>();
148
149    let inits = field_names
150        .iter()
151        .map(|field_name| {
152            quote! {
153                #field_name: Default::default(),
154            }
155        })
156        .collect::<Vec<_>>();
157
158    let curried_tys = e
159        .variants
160        .iter()
161        .map(|variant| match &variant.fields {
162            syn::Fields::Unit => None,
163            syn::Fields::Unnamed(syn::FieldsUnnamed {
164                unnamed: fields, ..
165            })
166            | syn::Fields::Named(syn::FieldsNamed { named: fields, .. }) => {
167                let tys = fields
168                    .iter()
169                    .filter(|field| is_curried(field))
170                    .map(|field| &field.ty)
171                    .collect::<Vec<_>>();
172
173                if tys.is_empty() {
174                    None
175                } else {
176                    Some(quote! {
177                        (#(#tys),*)
178                    })
179                }
180            }
181        })
182        .collect::<Vec<_>>();
183
184    let tys = e
185        .variants
186        .iter()
187        .map(|variant| match &variant.fields {
188            syn::Fields::Unit => {
189                quote! {
190                    ()
191                }
192            }
193            syn::Fields::Unnamed(syn::FieldsUnnamed {
194                unnamed: fields, ..
195            })
196            | syn::Fields::Named(syn::FieldsNamed { named: fields, .. }) => {
197                let tys = fields
198                    .iter()
199                    .filter(|field| !is_curried(field))
200                    .map(|field| &field.ty)
201                    .collect::<Vec<_>>();
202
203                quote! {
204                    (#(#tys),*)
205                }
206            }
207        })
208        .collect::<Vec<_>>();
209
210    let callbacks = field_names
211        .iter()
212        .zip(tys.iter())
213        .zip(curried_tys.iter())
214        .map(|((field_name, ty), curried_ty)| {
215            if let Some(curried_ty) = curried_ty {
216                quote! {
217                    #field_name: ::std::cell::RefCell<
218                        ::std::collections::HashMap<#curried_ty, ::yew::callback::Callback<#ty>>
219                    >,
220                }
221            } else {
222                quote! {
223                    #field_name: ::std::cell::RefCell<Option<::yew::callback::Callback<#ty>>>,
224                }
225            }
226        })
227        .collect::<Vec<_>>();
228
229    let constructors = e
230        .variants
231        .iter()
232        .zip(tys.iter())
233        .zip(field_names.iter())
234        .zip(curried_tys.iter())
235        .map(|(((variant, ty), field_name), curried_ty)| {
236            let name = &variant.ident;
237            let fn_name = Ident::new(&name.to_string().to_snake_case(), Span::call_site());
238
239            match &variant.fields {
240                syn::Fields::Unit => {
241                    quote! {
242                        fn #fn_name(&self) -> ::yew::callback::Callback<#ty> {
243                            if self.#field_name.borrow().is_none() {
244                                self.#field_name.replace(
245                                    Some(self.link.callback(|_| #enum_name::#name))
246                                );
247                            }
248                            self.#field_name.borrow().clone().unwrap()
249                        }
250                    }
251                }
252                syn::Fields::Unnamed(syn::FieldsUnnamed {
253                    unnamed: fields, ..
254                })
255                | syn::Fields::Named(syn::FieldsNamed { named: fields, .. }) => {
256                    let is_named = fields.iter().any(|field| field.ident.is_some());
257                    let idents = fields
258                        .iter()
259                        .enumerate()
260                        .map(|(i, field)| {
261                            field.ident.clone().unwrap_or_else(|| {
262                                Ident::new(&format!("arg_{i}"), Span::call_site())
263                            })
264                        })
265                        .collect::<Vec<_>>();
266
267                    if curried_ty.is_some() {
268                        let args = fields
269                            .iter()
270                            .zip(idents.iter())
271                            .filter_map(|(field, ident)| is_curried(field).then_some(ident))
272                            .collect::<Vec<_>>();
273                        let args_sig = fields
274                            .iter()
275                            .zip(idents.iter())
276                            .filter(|(field, _)| is_curried(field))
277                            .map(|(field, ident)| {
278                                let ty = &field.ty;
279
280                                quote! {
281                                    #ident: #ty
282                                }
283                            })
284                            .collect::<Vec<_>>();
285                        let ins = fields
286                            .iter()
287                            .zip(idents.iter())
288                            .filter_map(|(field, ident)| (!is_curried(field)).then_some(ident))
289                            .collect::<Vec<_>>();
290                        let keys = args
291                            .iter()
292                            .map(|arg| {
293                                quote! {
294                                    let #arg = #arg.clone();
295                                }
296                            })
297                            .collect::<Vec<_>>();
298                        let constructor = if is_named {
299                            let cloned_args = fields
300                                .iter()
301                                .zip(idents.iter())
302                                .map(|(field, ident)| {
303                                    if is_curried(field) {
304                                        quote! {
305                                            #ident: #ident.clone()
306                                        }
307                                    } else {
308                                        quote! {
309                                            #ident
310                                        }
311                                    }
312                                })
313                                .collect::<Vec<_>>();
314
315                            quote! {
316                                #enum_name::#name { #(#cloned_args),* }
317                            }
318                        } else {
319                            let cloned_args = fields
320                                .iter()
321                                .zip(idents.iter())
322                                .map(|(field, ident)| {
323                                    if is_curried(field) {
324                                        quote! {
325                                            #ident.clone()
326                                        }
327                                    } else {
328                                        quote! {
329                                            #ident
330                                        }
331                                    }
332                                })
333                                .collect::<Vec<_>>();
334
335                            quote! {
336                                #enum_name::#name(#(#cloned_args),*)
337                            }
338                        };
339
340                        quote! {
341                            #vis fn #fn_name(&self #(, #args_sig )* )
342                                -> ::yew::callback::Callback<#ty>
343                            {
344                                self.#field_name
345                                    .borrow_mut()
346                                    .entry((#(#args),*))
347                                    .or_insert_with_key(|(#(#args),*)| {
348                                        #(#keys)*
349                                        self.link.callback(move |(#(#ins),*)| #constructor)
350                                    })
351                                    .clone()
352                            }
353                        }
354                    } else {
355                        let constructor = if is_named {
356                            quote! {
357                                #enum_name::#name { #(#idents),* }
358                            }
359                        } else {
360                            quote! {
361                                #enum_name::#name(#(#idents),*)
362                            }
363                        };
364
365                        quote! {
366                            #vis fn #fn_name(&self) -> ::yew::callback::Callback<#ty> {
367                                if self.#field_name.borrow().is_none() {
368                                    self.#field_name.replace(Some(self
369                                        .link
370                                        .callback(|(#(#idents),*)| #constructor)
371                                    ));
372                                }
373                                self.#field_name.borrow().clone().unwrap()
374                            }
375                        }
376                    }
377                }
378            }
379        })
380        .collect::<Vec<_>>();
381
382    quote! {
383        #[derive(Debug)]
384        #vis struct #name<C: ::yew::html::BaseComponent> {
385            link: ::yew::html::Scope<C>,
386            #(#callbacks)*
387        }
388
389        impl<C: ::yew::html::BaseComponent<Message = #enum_name>> #name<C> {
390            #vis fn new(link: ::yew::html::Scope<C>) -> Self {
391                Self {
392                    link,
393                    #(#inits)*
394                }
395            }
396
397            #(#constructors)*
398        }
399
400        impl<C: ::yew::html::BaseComponent<Message = #enum_name>> From<::yew::html::Scope<C>>
401            for #name<C>
402        {
403            fn from(link: ::yew::html::Scope<C>) -> Self {
404                Self::new(link)
405            }
406        }
407
408        impl<C: ::yew::html::BaseComponent<Message = #enum_name>> From<&::yew::html::Scope<C>>
409            for #name<C>
410        {
411            fn from(link: &::yew::html::Scope<C>) -> Self {
412                Self::new(link.to_owned())
413            }
414        }
415    }
416}
417
418fn is_curried(field: &syn::Field) -> bool {
419    field
420        .attrs
421        .iter()
422        .any(|x| x.path.get_ident().map(|x| x == "curry").unwrap_or(false))
423}