whiskers_derive/
lib.rs

1use convert_case::{Case, Casing};
2use proc_macro::TokenStream;
3use proc_macro2::Ident;
4use quote::{format_ident, quote};
5use syn::punctuated::Punctuated;
6use syn::token::Comma;
7use syn::{
8    parse_macro_input, parse_quote, visit_mut::VisitMut, Attribute, Data, DataEnum, DataStruct,
9    DeriveInput, Expr, ExprPath, Field, Fields, FieldsNamed, FieldsUnnamed, Index, Variant,
10};
11
12fn label_from_ident(ident: &Ident) -> String {
13    format!("{}:", ident.to_string().to_case(Case::Lower))
14}
15
16/// Attribute macro to automatically derive some of the required traits for a sketch app.
17///
18/// This is equivalent to:
19/// ```ignore
20/// #[derive(Sketch, serde::Serialize, serde::Deserialize)]
21/// #[serde(crate = "::whiskers::prelude::serde")]
22/// ```
23#[proc_macro_attribute]
24pub fn sketch_app(_attr: TokenStream, item: TokenStream) -> TokenStream {
25    let ast = parse_macro_input!(item as DeriveInput);
26
27    let expanded = quote! {
28        #[derive(Sketch, serde::Serialize, serde::Deserialize)]
29        #[serde(crate = "::whiskers::prelude::serde")]
30        #ast
31    };
32
33    TokenStream::from(expanded)
34}
35
36/// Attribute macro to automatically derive some of the required traits for a sketch widget.
37///
38/// This is equivalent to:
39/// ```ignore
40/// #[derive(Widget, serde::Serialize, serde::Deserialize)]
41/// #[serde(crate = "::whiskers::prelude::serde")]
42/// ```
43#[proc_macro_attribute]
44pub fn sketch_widget(_attr: TokenStream, item: TokenStream) -> TokenStream {
45    let ast = parse_macro_input!(item as DeriveInput);
46
47    let expanded = quote! {
48        #[derive(Widget, serde::Serialize, serde::Deserialize)]
49        #[serde(crate = "whiskers_widgets::exports::serde")]
50        #ast
51    };
52
53    TokenStream::from(expanded)
54}
55
56#[proc_macro_derive(Sketch, attributes(param, skip))]
57pub fn sketch_derive(input: TokenStream) -> TokenStream {
58    let input: DeriveInput = parse_macro_input!(input);
59
60    let name = input.ident;
61
62    let fields_ui = match input.data {
63        Data::Struct(DataStruct { fields, .. }) => {
64            process_fields(fields, &format_ident!("Self"), &format_ident!("self"))
65        }
66        _ => panic!("The Sketch derive macro only supports structs"),
67    };
68
69    TokenStream::from(quote! {
70        impl whiskers_widgets::WidgetApp for #name {
71            fn name(&self) -> String {
72                stringify!(#name).to_string()
73            }
74
75            fn ui(&mut self, ui: &mut whiskers_widgets::exports::egui::Ui) -> bool {
76                #fields_ui
77            }
78        }
79
80        impl ::whiskers::SketchApp for #name {}
81    })
82}
83
84#[proc_macro_derive(Widget, attributes(param, skip))]
85pub fn sketch_ui_derive(input: TokenStream) -> TokenStream {
86    let input: DeriveInput = parse_macro_input!(input);
87
88    let name = input.ident;
89    let widget_name = format_ident!("{}Widget", name);
90
91    match input.data {
92        Data::Struct(DataStruct { fields, .. }) => process_struct(fields, &name, &widget_name),
93        Data::Enum(DataEnum { variants, .. }) => process_enum(variants, &name, &widget_name),
94        Data::Union(_) => {
95            unimplemented!()
96        }
97    }
98}
99
100fn process_struct(fields: Fields, name: &Ident, widget_name: &Ident) -> TokenStream {
101    let fields_ui = process_fields(fields, name, &format_ident!("value"));
102
103    TokenStream::from(quote! {
104        #[derive(Default)]
105        pub struct #widget_name;
106
107        impl whiskers_widgets::Widget<#name> for #widget_name {
108            fn ui(&self, ui: &mut whiskers_widgets::exports::egui::Ui, label: &str, value: &mut #name) -> bool {
109                ::whiskers_widgets::collapsing_header(ui, label.trim_end_matches(':'), "", true, |ui|{
110                        #fields_ui
111                    })
112                    .unwrap_or(false)
113            }
114
115            fn use_grid() -> bool {
116                false
117            }
118        }
119
120        impl whiskers_widgets::WidgetMapper<#name> for #name {
121            type Type = #widget_name;
122        }
123    })
124}
125
126fn field_defaults<'a>(fields: impl Iterator<Item = &'a Field>) -> proc_macro2::TokenStream {
127    let mut output = proc_macro2::TokenStream::new();
128    for field in fields {
129        let typ_ = &field.ty;
130        if let Some(name) = &field.ident {
131            output.extend(quote! {
132                #name: #typ_::default(),
133            });
134        } else {
135            output.extend(quote! {
136                #typ_::default(),
137            });
138        }
139    }
140
141    output
142}
143
144fn default_function_name_for_variant(variant_ident: &Ident) -> Ident {
145    format_ident!("__default_{}", variant_ident)
146}
147
148fn process_enum(
149    variants: Punctuated<Variant, Comma>,
150    name: &Ident,
151    widget_name: &Ident,
152) -> TokenStream {
153    //
154    // For each variant, create a function that returns the default value for that variant.
155    //
156
157    let mut default_functions = proc_macro2::TokenStream::new();
158    let mut simple_enum = true;
159    for Variant { ident, fields, .. } in variants.iter() {
160        let func_ident = default_function_name_for_variant(ident);
161
162        let fields_defaults = match fields {
163            Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
164                simple_enum = false;
165
166                let fields = field_defaults(unnamed.iter());
167                quote! {( #fields )}
168            }
169            Fields::Named(FieldsNamed { named, .. }) => {
170                simple_enum = false;
171
172                let fields = field_defaults(named.iter());
173                quote! {{ #fields }}
174            }
175            Fields::Unit => {
176                quote! {}
177            }
178        };
179
180        default_functions.extend(quote! {
181            #[allow(non_snake_case)]
182            fn #func_ident() -> Self {
183                #name::#ident #fields_defaults
184            }
185        });
186    }
187
188    let impl_default_functions = quote! {
189        impl #name {
190            #default_functions
191        }
192    };
193
194    //
195    // Create the UI code for the combo box menu. This is done in parts that are combined later, differently
196    // depending on whether the enum is simple or complex.
197    //
198
199    let idents = variants
200        .iter()
201        .map(|Variant { ident, .. }| ident.clone())
202        .collect::<Vec<_>>();
203
204    let field_captures_catch_all: Vec<_> = variants
205        .iter()
206        .map(|variant| match &variant.fields {
207            Fields::Named(FieldsNamed { .. }) => quote! { { .. } },
208            Fields::Unnamed(FieldsUnnamed { .. }) => quote! { ( .. ) },
209            Fields::Unit => quote! {},
210        })
211        .collect();
212
213    let ident_default_functions = idents
214        .iter()
215        .map(default_function_name_for_variant)
216        .collect::<Vec<_>>();
217    let ident_strings = idents
218        .iter()
219        .map(|ident| ident.to_string())
220        .collect::<Vec<_>>();
221
222    let name_string = name.to_string();
223
224    let pre_combo_code = quote! {
225        let mut selected_text = match value {
226            #(
227                #name::#idents #field_captures_catch_all => #ident_strings,
228            )*
229        }.to_owned();
230        let initial_selected_text = selected_text.clone();
231    };
232
233    let combo_code = quote! {
234        whiskers_widgets::exports::egui::ComboBox::from_id_source(#name_string).selected_text(&selected_text).show_ui(ui, |ui| {
235            #(
236                ui.selectable_value(&mut selected_text, #ident_strings.to_owned(), #ident_strings);
237            )*
238        });
239    };
240
241    let post_combo_code = quote! {
242        let mut changed = initial_selected_text != selected_text;
243
244        if changed {
245            *value = match selected_text.as_str() {
246                #(
247                    #ident_strings => #name::#ident_default_functions(),
248                )*
249                _ => unreachable!(),
250            };
251        }
252    };
253
254    //
255    // Simple enum case: build a simple UI and return.
256    //
257
258    if simple_enum {
259        let simple_enum_full_code = quote! {
260            #impl_default_functions
261
262            #[derive(Default)]
263            pub struct #widget_name;
264
265            impl whiskers_widgets::Widget<#name> for #widget_name {
266                fn ui(&self, ui: &mut whiskers_widgets::exports::egui::Ui, label: &str, value: &mut #name) -> bool {
267                    #pre_combo_code
268
269                    ui.label(label);
270                    #combo_code
271
272                    #post_combo_code
273
274                    changed
275                }
276
277                fn use_grid() -> bool {
278                    true
279                }
280            }
281
282            impl whiskers_widgets::WidgetMapper<#name> for #name {
283                type Type = #widget_name;
284            }
285        };
286
287        return TokenStream::from(simple_enum_full_code);
288    }
289
290    //
291    // Complex enum case: use a collapsing header whose body display the variant's UI.
292    //
293
294    // collect things like:
295    // - tuple variant => (field_0, field_1)
296    // - struct variant => { some_field, another_field }
297    // - unit variant => <empty>
298    let field_captures: Vec<_> = variants
299        .iter()
300        .map(|variant| match &variant.fields {
301            Fields::Named(FieldsNamed { named, .. }) => {
302                let fields = named
303                    .iter()
304                    .map(|field| field.ident.clone().unwrap())
305                    .collect::<Vec<_>>();
306
307                quote! {
308                    { #( #fields, )* }
309                }
310            }
311            Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
312                let fields = (0..unnamed.len())
313                    .map(|idx| format_ident!("field_{}", Index::from(idx)))
314                    .collect::<Vec<_>>();
315
316                quote! {
317                    ( #( #fields, )* )
318                }
319            }
320            Fields::Unit => quote! {},
321        })
322        .collect();
323
324    // collect (UI function, grid predicate) list of tuples for each variants
325    let field_tuples: Vec<_> = variants
326        .iter()
327        .map(|variant| match &variant.fields {
328            Fields::Named(FieldsNamed { named: field_list, .. })
329                | Fields::Unnamed(FieldsUnnamed { unnamed: field_list, .. }) => {
330                let field_names = field_list
331                    .iter()
332                    .filter(|variant| !has_skip_attr(&variant.attrs))
333                    .enumerate()
334                    .map(|(idx, field)| field
335                        .ident
336                        .clone()
337                        .unwrap_or(format_ident!("field_{}", Index::from(idx))))
338                    .collect::<Vec<_>>();
339                let field_types = field_list
340                    .iter()
341                    .filter(|variant| !has_skip_attr(&variant.attrs))
342                    .map(|field| field.ty.clone())
343                    .collect::<Vec<_>>();
344                let field_labels = field_names
345                    .iter()
346                    .map(label_from_ident)
347                    .collect::<Vec<_>>();
348                let chained_calls = field_list
349                    .iter()
350                    .filter(|variant| !has_skip_attr(&variant.attrs))
351                    .map(|field| chained_call_for_attrs(&field.attrs))
352                    .collect::<Vec<_>>();
353
354                quote! {
355                    #(
356                        (
357                            &mut |ui| {
358                                <#field_types as whiskers_widgets::WidgetMapper<#field_types>>::Type::default()
359                                    #chained_calls
360                                    .ui(
361                                        ui,
362                                        #field_labels,
363                                        #field_names,
364                                    )
365                            },
366                            &<#field_types as whiskers_widgets::WidgetMapper<#field_types>>::Type::use_grid,
367                        )
368                    ),*
369                }
370            }
371            Fields::Unit => quote!{
372                (
373                    &mut |ui| {
374                        ui.label(whiskers_widgets::exports::egui::RichText::new("no fields for this variant").weak().italics());
375                        false
376                    },
377                    &|| false,
378                )
379            }
380        })
381        .collect();
382
383    //
384    // Final assembly of the complex enum code.
385    //
386
387    TokenStream::from(quote! {
388        #impl_default_functions
389
390        #[derive(Default)]
391        pub struct #widget_name;
392
393        impl whiskers_widgets::Widget<#name> for #widget_name {
394            fn ui(&self, ui: &mut whiskers_widgets::exports::egui::Ui, label: &str, value: &mut #name) -> bool {
395
396                // draw the UI for a bunch of fields, swapping the grid on and off based on grid support
397                fn draw_ui(
398                    ui: &mut whiskers_widgets::exports::egui::Ui,
399                    changed: &mut bool,
400                    array: &mut [(&mut dyn FnMut(&mut egui::Ui) -> bool, &dyn Fn() -> bool)],
401                ) {
402                    let mut cur_index = 0;
403                    while cur_index < array.len() {
404                        if array[cur_index].1() {
405                            whiskers_widgets::exports::egui::Grid::new(cur_index).num_columns(2).show(ui, |ui| {
406                                while cur_index < array.len() && array[cur_index].1() {
407                                    *changed = (array[cur_index].0)(ui) || *changed;
408                                    ui.end_row();
409                                    cur_index += 1;
410                                }
411                            });
412                        }
413                        while cur_index < array.len() && !array[cur_index].1() {
414                            *changed = (array[cur_index].0)(ui) || *changed;
415                            cur_index += 1;
416                        }
417                    }
418                }
419
420                let (header_changed, body_changed) = ::whiskers_widgets::enum_collapsing_header(
421                    ui,
422                    label,
423                    value,
424                    |ui, value| {
425                        #pre_combo_code
426                        #combo_code
427                        #post_combo_code
428
429                        changed
430                    },
431                    true,
432                    |ui, value| {
433
434                        let mut changed = false;
435
436                        match value {
437                            #(
438                                #[allow(unused_variables)]
439                                #name::#idents #field_captures => {
440                                    draw_ui(
441                                        ui,
442                                        &mut changed,
443                                        &mut [ #field_tuples ],
444                                    );
445                                }
446                            )*
447                        };
448
449                        changed
450
451                    },
452                );
453
454                header_changed || body_changed.unwrap_or(false)
455            }
456
457            fn use_grid() -> bool {
458                false
459            }
460        }
461
462        impl whiskers_widgets::WidgetMapper<#name> for #name {
463            type Type = #widget_name;
464        }
465    })
466}
467
468fn has_skip_attr(attrs: &[Attribute]) -> bool {
469    attrs.iter().any(|attr| attr.path().is_ident("skip"))
470}
471fn chained_call_for_attrs(attrs: &[Attribute]) -> proc_macro2::TokenStream {
472    let param_attr = attrs.iter().find(|attr| attr.path().is_ident("param"));
473
474    let mut chained_calls = proc_macro2::TokenStream::new();
475
476    let mut add_chained_call = |meta: syn::meta::ParseNestedMeta, inner: bool| -> syn::Result<()> {
477        let ident = meta.path.get_ident().expect("expected ident");
478        let value = meta.value();
479
480        if value.is_ok() {
481            let mut expr: Expr = meta.input.parse()?;
482
483            // replaces occurrences of self with obj
484            ReplaceSelf.visit_expr_mut(&mut expr);
485
486            if inner {
487                chained_calls.extend(quote! {
488                    .inner(|obj| obj.#ident(#expr))
489                })
490            } else {
491                chained_calls.extend(quote! {
492                    .#ident(#expr)
493                });
494            }
495        } else if inner {
496            chained_calls.extend(quote! {
497                .inner(|obj| obj.#ident(true))
498
499            });
500        } else {
501            chained_calls.extend(quote! {
502                .#ident(true)
503            });
504        }
505
506        Ok(())
507    };
508
509    if let Some(param_attr) = param_attr {
510        let res = param_attr.parse_nested_meta(|meta| {
511            if meta.path.is_ident("inner") {
512                meta.parse_nested_meta(|meta| add_chained_call(meta, true))
513            } else {
514                add_chained_call(meta, false)
515            }
516        });
517
518        match res {
519            Ok(_) => {}
520            Err(err) => {
521                panic!("failed to parse param attribute {err}");
522            }
523        }
524    }
525
526    chained_calls
527}
528
529fn process_fields(
530    fields: Fields,
531    parent_type: &Ident,
532    parent_var: &Ident,
533) -> proc_macro2::TokenStream {
534    let mut output = proc_macro2::TokenStream::new();
535
536    let fields = match fields {
537        Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => unnamed,
538        Fields::Named(FieldsNamed { named, .. }) => named,
539        Fields::Unit => {
540            return quote! { false };
541        }
542    };
543
544    for (idx, field) in fields.into_iter().enumerate() {
545        let (field_name, field_access) = match field.ident {
546            Some(ident) => (ident.clone(), quote!(#ident)),
547            None => {
548                let i = Index::from(idx);
549                (format_ident!("field_{}", idx), quote!(#i))
550            }
551        };
552
553        let field_type = field.ty;
554
555        if has_skip_attr(&field.attrs) {
556            continue;
557        }
558
559        let chained_call = chained_call_for_attrs(&field.attrs);
560        let formatted_label = label_from_ident(&field_name);
561
562        output.extend(quote! {
563            (
564                &|ui, obj| {
565                    <#field_type as whiskers_widgets::WidgetMapper<#field_type>>::Type::default()
566                        #chained_call
567                        .ui(ui, #formatted_label, &mut obj.#field_access)
568
569                },
570                &<#field_type as whiskers_widgets::WidgetMapper<#field_type>>::Type::use_grid,
571            ),
572        });
573    }
574
575    // This is the magic UI code that handles `whiskers_widgets::Widget::use_grid()`. It works as
576    // follows:
577    // - An array of closure tuple are created for all fields. The first closure is the actual UI
578    //   code, the second is a predicate that returns whether the grid should be used.
579    // - The array is then walked, and contiguous stretches of tuple for which the predicate returns
580    //   `true` grouped together and rendered in a grid.
581    quote! {
582        {
583            let array: &[(
584                &dyn Fn(&mut whiskers_widgets::exports::egui::Ui, &mut #parent_type) -> bool, // ui code
585                &dyn Fn() -> bool                                  // use grid predicate
586            )] = &[
587                #output
588            ];
589
590            let mut cur_index = 0;
591            let mut changed = false;
592
593            while cur_index < array.len() {
594                if array[cur_index].1() {
595                    whiskers_widgets::exports::egui::Grid::new(cur_index)
596                        .num_columns(2)
597                        .show(ui, |ui| {
598                            while cur_index < array.len() && array[cur_index].1() {
599                                changed = (array[cur_index].0)(ui, #parent_var) || changed;
600                                ui.end_row();
601                                cur_index += 1;
602                            }
603                        });
604                }
605
606                while cur_index < array.len() && !array[cur_index].1() {
607                    changed = (array[cur_index].0)(ui, #parent_var) || changed;
608                    cur_index += 1;
609                }
610            }
611
612            changed
613        }
614    }
615}
616
617/// Expression visitor to replace `self` with `obj`.
618struct ReplaceSelf;
619
620impl VisitMut for ReplaceSelf {
621    fn visit_expr_path_mut(&mut self, node: &mut ExprPath) {
622        if node.path.is_ident("self") {
623            *node = parse_quote!(obj);
624        }
625    }
626}