Skip to main content

tui_dispatch_macros/
lib.rs

1//! Procedural macros for tui-dispatch
2
3use darling::{FromDeriveInput, FromField, FromVariant};
4use proc_macro::TokenStream;
5use proc_macro2::Ident;
6use quote::{format_ident, quote};
7use std::collections::HashMap;
8use syn::{parse_macro_input, DeriveInput};
9use tui_dispatch_shared::{infer_action_category, pascal_to_snake_case};
10
11/// Container-level attributes for #[derive(Action)]
12#[derive(Debug, FromDeriveInput)]
13#[darling(attributes(action), supports(enum_any))]
14struct ActionOpts {
15    ident: syn::Ident,
16    data: darling::ast::Data<ActionVariant, ()>,
17
18    /// Enable automatic category inference from variant name prefixes
19    #[darling(default)]
20    infer_categories: bool,
21
22    /// Generate dispatcher trait
23    #[darling(default)]
24    generate_dispatcher: bool,
25}
26
27/// Variant-level attributes
28#[derive(Debug, FromVariant)]
29#[darling(attributes(action))]
30struct ActionVariant {
31    ident: syn::Ident,
32    fields: darling::ast::Fields<()>,
33
34    /// Explicit category override
35    #[darling(default)]
36    category: Option<String>,
37
38    /// Exclude from category inference
39    #[darling(default)]
40    skip_category: bool,
41}
42
43/// Convert PascalCase to snake_case
44fn to_snake_case(s: &str) -> String {
45    pascal_to_snake_case(s)
46}
47
/// Converts a `snake_case` identifier to `PascalCase`.
///
/// Every `_`-separated segment gets its first character upper-cased (using
/// full Unicode case mapping, so one character may expand to several), the
/// rest of the segment is copied verbatim, and empty segments add nothing.
fn to_pascal_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len());
    for segment in s.split('_') {
        let mut chars = segment.chars();
        if let Some(head) = chars.next() {
            result.extend(head.to_uppercase());
            result.push_str(chars.as_str());
        }
    }
    result
}
60
61/// Infer category from a variant name using naming patterns
62fn infer_category(name: &str) -> Option<String> {
63    infer_action_category(name)
64}
65
66/// Derive macro for the Action trait
67///
68/// Generates a `name()` method that returns the variant name as a static string.
69///
70/// With `#[action(infer_categories)]`, also generates:
71/// - `category() -> Option<&'static str>` - Get action's category
72/// - `category_enum() -> {Name}Category` - Get category as enum
73/// - `is_{category}()` predicates for each category
74/// - `{Name}Category` enum with all discovered categories
75///
76/// With `#[action(generate_dispatcher)]`, also generates:
77/// - `{Name}Dispatcher` trait with category-based dispatch methods
78///
79/// # Example
80/// ```ignore
81/// #[derive(Action, Clone, Debug)]
82/// #[action(infer_categories, generate_dispatcher)]
83/// enum MyAction {
84///     SearchStart,
85///     SearchClear,
86///     ConnectionFormOpen,
87///     ConnectionFormSubmit,
88///     DidConnect,
89///     Tick,  // uncategorized
90/// }
91///
92/// let action = MyAction::SearchStart;
93/// assert_eq!(action.name(), "SearchStart");
94/// assert_eq!(action.category(), Some("search"));
95/// assert!(action.is_search());
96/// ```
97#[proc_macro_derive(Action, attributes(action))]
98pub fn derive_action(input: TokenStream) -> TokenStream {
99    let input = parse_macro_input!(input as DeriveInput);
100
101    // Try to parse with darling for attributes
102    let opts = match ActionOpts::from_derive_input(&input) {
103        Ok(opts) => opts,
104        Err(e) => return e.write_errors().into(),
105    };
106
107    let name = &opts.ident;
108
109    let variants = match &opts.data {
110        darling::ast::Data::Enum(variants) => variants,
111        _ => {
112            return syn::Error::new_spanned(&input, "Action can only be derived for enums")
113                .to_compile_error()
114                .into();
115        }
116    };
117
118    // Get the original syn variants for field info (darling loses field names)
119    let syn_variants = match &input.data {
120        syn::Data::Enum(data) => &data.variants,
121        _ => unreachable!(), // Already checked above
122    };
123
124    // Generate basic name() implementation
125    let name_arms = variants.iter().map(|v| {
126        let variant_name = &v.ident;
127        let variant_str = variant_name.to_string();
128
129        match &v.fields.style {
130            darling::ast::Style::Unit => quote! {
131                #name::#variant_name => #variant_str
132            },
133            darling::ast::Style::Tuple => quote! {
134                #name::#variant_name(..) => #variant_str
135            },
136            darling::ast::Style::Struct => quote! {
137                #name::#variant_name { .. } => #variant_str
138            },
139        }
140    });
141
142    // Generate params() implementation - outputs field values without variant name
143    let params_arms = syn_variants.iter().map(|v| {
144        let variant_name = &v.ident;
145
146        match &v.fields {
147            syn::Fields::Unit => quote! {
148                #name::#variant_name => ::std::string::String::new()
149            },
150            syn::Fields::Unnamed(fields) => {
151                let field_count = fields.unnamed.len();
152                let field_names: Vec<_> =
153                    (0..field_count).map(|i| format_ident!("_{}", i)).collect();
154                if field_count == 1 {
155                    quote! {
156                        #name::#variant_name(#(#field_names),*) => {
157                            tui_dispatch::debug::debug_string(&#(#field_names),*)
158                        }
159                    }
160                } else {
161                    let parts = field_names.iter().map(|field| {
162                        quote! { tui_dispatch::debug::debug_string(&#field) }
163                    });
164                    quote! {
165                        #name::#variant_name(#(#field_names),*) => {
166                            let values = ::std::vec![#(#parts),*];
167                            format!("({})", values.join(", "))
168                        }
169                    }
170                }
171            }
172            syn::Fields::Named(fields) => {
173                let field_names: Vec<_> = fields
174                    .named
175                    .iter()
176                    .filter_map(|f| f.ident.as_ref())
177                    .collect();
178                if field_names.is_empty() {
179                    quote! {
180                        #name::#variant_name { .. } => ::std::string::String::new()
181                    }
182                } else {
183                    let parts = field_names.iter().map(|field| {
184                        let label = field.to_string();
185                        quote! {
186                            format!("{}: {}", #label, tui_dispatch::debug::debug_string(&#field))
187                        }
188                    });
189                    quote! {
190                        #name::#variant_name { #(#field_names),*, .. } => {
191                            let values = ::std::vec![#(#parts),*];
192                            format!("{{{}}}", values.join(", "))
193                        }
194                    }
195                }
196            }
197        }
198    });
199
200    let params_pretty_arms = syn_variants.iter().map(|v| {
201        let variant_name = &v.ident;
202
203        match &v.fields {
204            syn::Fields::Unit => quote! {
205                #name::#variant_name => ::std::string::String::new()
206            },
207            syn::Fields::Unnamed(fields) => {
208                let field_count = fields.unnamed.len();
209                let field_names: Vec<_> =
210                    (0..field_count).map(|i| format_ident!("_{}", i)).collect();
211                if field_count == 1 {
212                    quote! {
213                        #name::#variant_name(#(#field_names),*) => {
214                            tui_dispatch::debug::debug_string_pretty(&#(#field_names),*)
215                        }
216                    }
217                } else {
218                    let parts = field_names.iter().map(|field| {
219                        quote! { tui_dispatch::debug::debug_string_pretty(&#field) }
220                    });
221                    quote! {
222                        #name::#variant_name(#(#field_names),*) => {
223                            let values = ::std::vec![#(#parts),*];
224                            format!("({})", values.join(", "))
225                        }
226                    }
227                }
228            }
229            syn::Fields::Named(fields) => {
230                let field_names: Vec<_> = fields
231                    .named
232                    .iter()
233                    .filter_map(|f| f.ident.as_ref())
234                    .collect();
235                if field_names.is_empty() {
236                    quote! {
237                        #name::#variant_name { .. } => ::std::string::String::new()
238                    }
239                } else {
240                    let parts = field_names.iter().map(|field| {
241                        let label = field.to_string();
242                        quote! {
243                            format!("{}: {}", #label, tui_dispatch::debug::debug_string_pretty(&#field))
244                        }
245                    });
246                    quote! {
247                        #name::#variant_name { #(#field_names),*, .. } => {
248                            let values = ::std::vec![#(#parts),*];
249                            format!("{{{}}}", values.join(", "))
250                        }
251                    }
252                }
253            }
254        }
255    });
256
257    let mut expanded = quote! {
258        impl tui_dispatch::Action for #name {
259            fn name(&self) -> &'static str {
260                match self {
261                    #(#name_arms),*
262                }
263            }
264        }
265
266        impl tui_dispatch::ActionParams for #name {
267            fn params(&self) -> ::std::string::String {
268                match self {
269                    #(#params_arms),*
270                }
271            }
272
273            fn params_pretty(&self) -> ::std::string::String {
274                match self {
275                    #(#params_pretty_arms),*
276                }
277            }
278        }
279    };
280
281    // If category inference is enabled, generate category-related code
282    if opts.infer_categories {
283        // Collect categories and their variants
284        let mut categories: HashMap<String, Vec<&Ident>> = HashMap::new();
285        let mut variant_categories: Vec<(&Ident, Option<String>)> = Vec::new();
286
287        for v in variants.iter() {
288            let cat = if v.skip_category {
289                None
290            } else if let Some(ref explicit_cat) = v.category {
291                Some(explicit_cat.clone())
292            } else {
293                infer_category(&v.ident.to_string())
294            };
295
296            variant_categories.push((&v.ident, cat.clone()));
297
298            if let Some(ref category) = cat {
299                categories
300                    .entry(category.clone())
301                    .or_default()
302                    .push(&v.ident);
303            }
304        }
305
306        // Sort categories for deterministic output
307        let mut sorted_categories: Vec<_> = categories.keys().cloned().collect();
308        sorted_categories.sort();
309
310        // Create deduplicated category match arms
311        let category_arms_dedup: Vec<_> = variant_categories
312            .iter()
313            .map(|(variant, cat)| {
314                let cat_expr = match cat {
315                    Some(c) => quote! { ::core::option::Option::Some(#c) },
316                    None => quote! { ::core::option::Option::None },
317                };
318                // Use wildcard pattern to handle all field types
319                quote! { #name::#variant { .. } => #cat_expr }
320            })
321            .collect();
322
323        // Generate category enum
324        let category_enum_name = format_ident!("{}Category", name);
325        let category_variants: Vec<_> = sorted_categories
326            .iter()
327            .map(|c| format_ident!("{}", to_pascal_case(c)))
328            .collect();
329        let category_variant_names: Vec<_> = sorted_categories.clone();
330
331        // Generate category_enum() method arms
332        let category_enum_arms: Vec<_> = variant_categories
333            .iter()
334            .map(|(variant, cat)| {
335                let cat_variant = match cat {
336                    Some(c) => format_ident!("{}", to_pascal_case(c)),
337                    None => format_ident!("Uncategorized"),
338                };
339                quote! { #name::#variant { .. } => #category_enum_name::#cat_variant }
340            })
341            .collect();
342
343        // Generate is_* predicates
344        let predicates: Vec<_> = sorted_categories
345            .iter()
346            .map(|cat| {
347                let predicate_name = format_ident!("is_{}", cat);
348                let cat_variants = categories.get(cat).unwrap();
349                let patterns: Vec<_> = cat_variants
350                    .iter()
351                    .map(|v| quote! { #name::#v { .. } })
352                    .collect();
353                let doc = format!(
354                    "Returns true if this action belongs to the `{}` category.",
355                    cat
356                );
357
358                quote! {
359                    #[doc = #doc]
360                    pub fn #predicate_name(&self) -> bool {
361                        matches!(self, #(#patterns)|*)
362                    }
363                }
364            })
365            .collect();
366
367        // Add category-related implementations
368        let category_enum_doc = format!(
369            "Action categories for [`{}`].\n\n\
370             Use [`{}::category_enum()`] to get the category of an action.",
371            name, name
372        );
373
374        expanded = quote! {
375            #expanded
376
377            #[doc = #category_enum_doc]
378            #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
379            pub enum #category_enum_name {
380                #(#category_variants,)*
381                /// Actions that don't belong to any specific category.
382                Uncategorized,
383            }
384
385            impl #category_enum_name {
386                /// Get all category values
387                pub fn all() -> &'static [Self] {
388                    &[#(Self::#category_variants,)* Self::Uncategorized]
389                }
390
391                /// Get category name as string
392                pub fn name(&self) -> &'static str {
393                    match self {
394                        #(Self::#category_variants => #category_variant_names,)*
395                        Self::Uncategorized => "uncategorized",
396                    }
397                }
398            }
399
400            impl #name {
401                /// Get the action's category (if categorized)
402                pub fn category(&self) -> ::core::option::Option<&'static str> {
403                    match self {
404                        #(#category_arms_dedup,)*
405                    }
406                }
407
408                /// Get the category as an enum value
409                pub fn category_enum(&self) -> #category_enum_name {
410                    match self {
411                        #(#category_enum_arms,)*
412                    }
413                }
414
415                #(#predicates)*
416            }
417
418            impl tui_dispatch::ActionCategory for #name {
419                type Category = #category_enum_name;
420
421                fn category(&self) -> ::core::option::Option<&'static str> {
422                    #name::category(self)
423                }
424
425                fn category_enum(&self) -> Self::Category {
426                    #name::category_enum(self)
427                }
428            }
429        };
430
431        // Generate dispatcher trait if requested
432        if opts.generate_dispatcher {
433            let dispatcher_trait_name = format_ident!("{}Dispatcher", name);
434
435            let dispatch_methods: Vec<_> = sorted_categories
436                .iter()
437                .map(|cat| {
438                    let method_name = format_ident!("dispatch_{}", cat);
439                    let doc = format!("Handle actions in the `{}` category.", cat);
440                    quote! {
441                        #[doc = #doc]
442                        fn #method_name(&mut self, action: &#name) -> bool {
443                            false
444                        }
445                    }
446                })
447                .collect();
448
449            let dispatch_arms: Vec<_> = sorted_categories
450                .iter()
451                .map(|cat| {
452                    let method_name = format_ident!("dispatch_{}", cat);
453                    let cat_variant = format_ident!("{}", to_pascal_case(cat));
454                    quote! {
455                        #category_enum_name::#cat_variant => self.#method_name(action)
456                    }
457                })
458                .collect();
459
460            let dispatcher_doc = format!(
461                "Dispatcher trait for [`{}`].\n\n\
462                 Implement the `dispatch_*` methods for each category you want to handle.\n\
463                 The [`dispatch()`](Self::dispatch) method automatically routes to the correct handler.",
464                name
465            );
466
467            expanded = quote! {
468                #expanded
469
470                #[doc = #dispatcher_doc]
471                pub trait #dispatcher_trait_name {
472                    #(#dispatch_methods)*
473
474                    /// Handle uncategorized actions.
475                    fn dispatch_uncategorized(&mut self, action: &#name) -> bool {
476                        false
477                    }
478
479                    /// Main dispatch entry point - routes to category-specific handlers.
480                    fn dispatch(&mut self, action: &#name) -> bool {
481                        match action.category_enum() {
482                            #(#dispatch_arms,)*
483                            #category_enum_name::Uncategorized => self.dispatch_uncategorized(action),
484                        }
485                    }
486                }
487            };
488        }
489    }
490
491    TokenStream::from(expanded)
492}
493
494/// Derive macro for the BindingContext trait
495///
496/// Generates implementations for `name()`, `from_name()`, and `all()` methods.
497/// The context name is derived from the variant name converted to snake_case.
498///
499/// # Example
500/// ```ignore
501/// #[derive(BindingContext, Clone, Copy, PartialEq, Eq, Hash)]
502/// enum MyContext {
503///     Default,
504///     Search,
505///     ConnectionForm,
506/// }
507///
508/// // Generated names: "default", "search", "connection_form"
509/// assert_eq!(MyContext::Default.name(), "default");
510/// assert_eq!(MyContext::from_name("search"), Some(MyContext::Search));
511/// ```
512#[proc_macro_derive(BindingContext)]
513pub fn derive_binding_context(input: TokenStream) -> TokenStream {
514    let input = parse_macro_input!(input as DeriveInput);
515    let name = &input.ident;
516
517    let expanded = match &input.data {
518        syn::Data::Enum(data) => {
519            // Check that all variants are unit variants
520            for variant in &data.variants {
521                if !matches!(variant.fields, syn::Fields::Unit) {
522                    return syn::Error::new_spanned(
523                        variant,
524                        "BindingContext can only be derived for enums with unit variants",
525                    )
526                    .to_compile_error()
527                    .into();
528                }
529            }
530
531            let variant_names: Vec<_> = data.variants.iter().map(|v| &v.ident).collect();
532            let variant_strings: Vec<_> = variant_names
533                .iter()
534                .map(|v| to_snake_case(&v.to_string()))
535                .collect();
536
537            let name_arms = variant_names
538                .iter()
539                .zip(variant_strings.iter())
540                .map(|(v, s)| {
541                    quote! { #name::#v => #s }
542                });
543
544            let from_name_arms = variant_names
545                .iter()
546                .zip(variant_strings.iter())
547                .map(|(v, s)| {
548                    quote! { #s => ::core::option::Option::Some(#name::#v) }
549                });
550
551            let all_variants = variant_names.iter().map(|v| quote! { #name::#v });
552
553            quote! {
554                impl tui_dispatch::BindingContext for #name {
555                    fn name(&self) -> &'static str {
556                        match self {
557                            #(#name_arms),*
558                        }
559                    }
560
561                    fn from_name(name: &str) -> ::core::option::Option<Self> {
562                        match name {
563                            #(#from_name_arms,)*
564                            _ => ::core::option::Option::None,
565                        }
566                    }
567
568                    fn all() -> &'static [Self] {
569                        static ALL: &[#name] = &[#(#all_variants),*];
570                        ALL
571                    }
572                }
573            }
574        }
575        _ => {
576            return syn::Error::new_spanned(input, "BindingContext can only be derived for enums")
577                .to_compile_error()
578                .into();
579        }
580    };
581
582    TokenStream::from(expanded)
583}
584
585/// Derive macro for the ComponentId trait
586///
587/// Generates implementations for `name()` method that returns the variant name.
588///
589/// # Example
590/// ```ignore
591/// #[derive(ComponentId, Clone, Copy, PartialEq, Eq, Hash, Debug)]
592/// enum MyComponentId {
593///     Sidebar,
594///     MainContent,
595///     StatusBar,
596/// }
597///
598/// assert_eq!(MyComponentId::Sidebar.name(), "Sidebar");
599/// ```
600#[proc_macro_derive(ComponentId)]
601pub fn derive_component_id(input: TokenStream) -> TokenStream {
602    let input = parse_macro_input!(input as DeriveInput);
603    let name = &input.ident;
604
605    let expanded = match &input.data {
606        syn::Data::Enum(data) => {
607            // Check that all variants are unit variants
608            for variant in &data.variants {
609                if !matches!(variant.fields, syn::Fields::Unit) {
610                    return syn::Error::new_spanned(
611                        variant,
612                        "ComponentId can only be derived for enums with unit variants",
613                    )
614                    .to_compile_error()
615                    .into();
616                }
617            }
618
619            let variant_names: Vec<_> = data.variants.iter().map(|v| &v.ident).collect();
620            let variant_strings: Vec<_> = variant_names.iter().map(|v| v.to_string()).collect();
621
622            let name_arms = variant_names
623                .iter()
624                .zip(variant_strings.iter())
625                .map(|(v, s)| {
626                    quote! { #name::#v => #s }
627                });
628
629            quote! {
630                impl tui_dispatch::ComponentId for #name {
631                    fn name(&self) -> &'static str {
632                        match self {
633                            #(#name_arms),*
634                        }
635                    }
636                }
637            }
638        }
639        _ => {
640            return syn::Error::new_spanned(input, "ComponentId can only be derived for enums")
641                .to_compile_error()
642                .into();
643        }
644    };
645
646    TokenStream::from(expanded)
647}
648
649// ============================================================================
650// DebugState derive macro
651// ============================================================================
652
653/// Container-level attributes for #[derive(DebugState)]
654#[derive(Debug, FromDeriveInput)]
655#[darling(attributes(debug_state), supports(struct_named))]
656struct DebugStateOpts {
657    ident: syn::Ident,
658    data: darling::ast::Data<(), DebugStateField>,
659}
660
661/// Field-level attributes for DebugState
662#[derive(Debug, FromField)]
663#[darling(attributes(debug))]
664struct DebugStateField {
665    ident: Option<syn::Ident>,
666
667    /// Section name for this field (groups fields together)
668    #[darling(default)]
669    section: Option<String>,
670
671    /// Skip this field in debug output
672    #[darling(default)]
673    skip: bool,
674
675    /// Custom display format (e.g., "{:?}" for Debug, "{:#?}" for pretty Debug)
676    #[darling(default)]
677    format: Option<String>,
678
679    /// Custom label for this field (defaults to field name)
680    #[darling(default)]
681    label: Option<String>,
682
683    /// Use Debug trait instead of Display
684    #[darling(default)]
685    debug_fmt: bool,
686}
687
688/// Derive macro for the DebugState trait
689///
690/// Automatically generates `debug_sections()` implementation from struct fields.
691///
692/// # Attributes
693///
694/// - `#[debug(section = "Name")]` - Group field under a section
695/// - `#[debug(skip)]` - Exclude field from debug output
696/// - `#[debug(label = "Custom Label")]` - Use custom label instead of field name
697/// - `#[debug(debug_fmt)]` - Use `{:?}` format instead of `Display`
698/// - `#[debug(format = "{:#?}")]` - Use custom format string
699///
700/// # Example
701///
702/// ```ignore
703/// use tui_dispatch::DebugState;
704///
705/// #[derive(DebugState)]
706/// struct AppState {
707///     #[debug(section = "Connection")]
708///     host: String,
709///     #[debug(section = "Connection")]
710///     port: u16,
711///
712///     #[debug(section = "UI")]
713///     scroll_offset: usize,
714///
715///     #[debug(skip)]
716///     internal_cache: HashMap<String, Data>,
717///
718///     #[debug(section = "Stats", debug_fmt)]
719///     status: ConnectionStatus,
720/// }
721/// ```
722///
723/// Fields without a section attribute are grouped under a section named after
724/// the struct (e.g., "AppState").
725#[proc_macro_derive(DebugState, attributes(debug, debug_state))]
726pub fn derive_debug_state(input: TokenStream) -> TokenStream {
727    let input = parse_macro_input!(input as DeriveInput);
728
729    let opts = match DebugStateOpts::from_derive_input(&input) {
730        Ok(opts) => opts,
731        Err(e) => return e.write_errors().into(),
732    };
733
734    let name = &opts.ident;
735    let default_section = name.to_string();
736
737    let fields = match &opts.data {
738        darling::ast::Data::Struct(fields) => fields,
739        _ => {
740            return syn::Error::new_spanned(&input, "DebugState can only be derived for structs")
741                .to_compile_error()
742                .into();
743        }
744    };
745
746    // Group fields by section
747    let mut sections: HashMap<String, Vec<&DebugStateField>> = HashMap::new();
748    let mut section_order: Vec<String> = Vec::new();
749
750    for field in fields.iter() {
751        if field.skip {
752            continue;
753        }
754
755        let section_name = field
756            .section
757            .clone()
758            .unwrap_or_else(|| default_section.clone());
759
760        if !section_order.contains(&section_name) {
761            section_order.push(section_name.clone());
762        }
763
764        sections.entry(section_name).or_default().push(field);
765    }
766
767    // Generate code for each section
768    let section_code: Vec<_> = section_order
769        .iter()
770        .map(|section_name| {
771            let fields_in_section = sections.get(section_name).unwrap();
772
773            let entry_calls: Vec<_> = fields_in_section
774                .iter()
775                .filter_map(|field| {
776                    let field_ident = field.ident.as_ref()?;
777                    let label = field
778                        .label
779                        .clone()
780                        .unwrap_or_else(|| field_ident.to_string());
781
782                    let value_expr = if let Some(ref fmt) = field.format {
783                        quote! { format!(#fmt, self.#field_ident) }
784                    } else if field.debug_fmt {
785                        quote! { format!("{:?}", self.#field_ident) }
786                    } else {
787                        quote! { tui_dispatch::debug::debug_string(&self.#field_ident) }
788                    };
789
790                    Some(quote! {
791                        .entry(#label, #value_expr)
792                    })
793                })
794                .collect();
795
796            quote! {
797                tui_dispatch::debug::DebugSection::new(#section_name)
798                    #(#entry_calls)*
799            }
800        })
801        .collect();
802
803    let expanded = quote! {
804        impl tui_dispatch::debug::DebugState for #name {
805            fn debug_sections(&self) -> ::std::vec::Vec<tui_dispatch::debug::DebugSection> {
806                ::std::vec![
807                    #(#section_code),*
808                ]
809            }
810        }
811    };
812
813    TokenStream::from(expanded)
814}
815
816// ============================================================================
817// FeatureFlags derive macro
818// ============================================================================
819
820/// Field-level attributes for FeatureFlags
821#[derive(Debug, FromField)]
822#[darling(attributes(flag))]
823struct FeatureFlagsField {
824    ident: Option<syn::Ident>,
825    ty: syn::Type,
826
827    /// Default value for this feature (defaults to false)
828    #[darling(default)]
829    default: Option<bool>,
830}
831
832/// Container-level attributes for #[derive(FeatureFlags)]
833#[derive(Debug, FromDeriveInput)]
834#[darling(attributes(feature_flags), supports(struct_named))]
835struct FeatureFlagsOpts {
836    ident: syn::Ident,
837    data: darling::ast::Data<(), FeatureFlagsField>,
838}
839
840/// Derive macro for the FeatureFlags trait
841///
842/// Generates implementations for `is_enabled()`, `set()`, and `all_flags()` methods.
843/// Also generates a `Default` implementation using the specified defaults.
844///
845/// # Attributes
846///
847/// - `#[flag(default = true)]` - Set default value (defaults to false)
848///
849/// # Example
850///
851/// ```ignore
852/// use tui_dispatch::FeatureFlags;
853///
854/// #[derive(FeatureFlags)]
855/// struct Features {
856///     #[flag(default = false)]
857///     new_search_ui: bool,
858///
859///     #[flag(default = true)]
860///     vim_bindings: bool,
861/// }
862///
863/// let mut features = Features::default();
864/// assert!(!features.new_search_ui);
865/// assert!(features.vim_bindings);
866///
867/// features.enable("new_search_ui");
868/// assert!(features.new_search_ui);
869/// ```
870#[proc_macro_derive(FeatureFlags, attributes(flag, feature_flags))]
871pub fn derive_feature_flags(input: TokenStream) -> TokenStream {
872    let input = parse_macro_input!(input as DeriveInput);
873
874    let opts = match FeatureFlagsOpts::from_derive_input(&input) {
875        Ok(opts) => opts,
876        Err(e) => return e.write_errors().into(),
877    };
878
879    let name = &opts.ident;
880
881    let fields = match &opts.data {
882        darling::ast::Data::Struct(fields) => fields,
883        _ => {
884            return syn::Error::new_spanned(
885                &input,
886                "FeatureFlags can only be derived for structs with named fields",
887            )
888            .to_compile_error()
889            .into();
890        }
891    };
892
893    // Collect bool fields only
894    let bool_fields: Vec<_> = fields
895        .iter()
896        .filter_map(|f| {
897            let ident = f.ident.as_ref()?;
898            // Check if type is bool
899            if let syn::Type::Path(type_path) = &f.ty {
900                if type_path.path.is_ident("bool") {
901                    return Some((ident.clone(), f.default.unwrap_or(false)));
902                }
903            }
904            None
905        })
906        .collect();
907
908    if bool_fields.is_empty() {
909        return syn::Error::new_spanned(
910            &input,
911            "FeatureFlags struct must have at least one bool field",
912        )
913        .to_compile_error()
914        .into();
915    }
916
917    // Generate is_enabled match arms
918    let is_enabled_arms: Vec<_> = bool_fields
919        .iter()
920        .map(|(ident, _)| {
921            let name_str = ident.to_string();
922            quote! { #name_str => ::core::option::Option::Some(self.#ident) }
923        })
924        .collect();
925
926    // Generate set match arms
927    let set_arms: Vec<_> = bool_fields
928        .iter()
929        .map(|(ident, _)| {
930            let name_str = ident.to_string();
931            quote! {
932                #name_str => {
933                    self.#ident = enabled;
934                    true
935                }
936            }
937        })
938        .collect();
939
940    // Generate all_flags array
941    let flag_names: Vec<_> = bool_fields
942        .iter()
943        .map(|(ident, _)| ident.to_string())
944        .collect();
945
946    // Generate Default impl with proper defaults
947    let default_fields: Vec<_> = bool_fields
948        .iter()
949        .map(|(ident, default)| {
950            quote! { #ident: #default }
951        })
952        .collect();
953
954    let expanded = quote! {
955        impl tui_dispatch::FeatureFlags for #name {
956            fn is_enabled(&self, name: &str) -> ::core::option::Option<bool> {
957                match name {
958                    #(#is_enabled_arms,)*
959                    _ => ::core::option::Option::None,
960                }
961            }
962
963            fn set(&mut self, name: &str, enabled: bool) -> bool {
964                match name {
965                    #(#set_arms)*
966                    _ => false,
967                }
968            }
969
970            fn all_flags() -> &'static [&'static str] {
971                &[#(#flag_names),*]
972            }
973        }
974
975        impl ::core::default::Default for #name {
976            fn default() -> Self {
977                Self {
978                    #(#default_fields,)*
979                }
980            }
981        }
982    };
983
984    TokenStream::from(expanded)
985}
986
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_to_snake_case_handles_acronyms() {
        assert_eq!(to_snake_case("APIFetch"), "api_fetch");
        assert_eq!(to_snake_case("HTTPResult"), "http_result");
    }

    #[test]
    fn test_infer_category_handles_acronyms() {
        assert_eq!(infer_category("APIFetchStart"), Some("api".to_string()));
        assert_eq!(
            infer_category("SearchHTTPStart"),
            Some("search_http".to_string())
        );
    }

    /// Category names become `{Name}Category` variants through this helper.
    #[test]
    fn test_to_pascal_case_joins_segments() {
        assert_eq!(to_pascal_case("search"), "Search");
        assert_eq!(to_pascal_case("connection_form"), "ConnectionForm");
        assert_eq!(to_pascal_case("search_http"), "SearchHttp");
    }

    /// Empty segments (leading, trailing or doubled underscores) contribute nothing.
    #[test]
    fn test_to_pascal_case_skips_empty_segments() {
        assert_eq!(to_pascal_case(""), "");
        assert_eq!(to_pascal_case("foo__bar"), "FooBar");
        assert_eq!(to_pascal_case("_leading_"), "Leading");
    }
}