1use darling::{FromDeriveInput, FromField, FromVariant};
4use proc_macro::TokenStream;
5use proc_macro2::Ident;
6use quote::{format_ident, quote};
7use std::collections::HashMap;
8use syn::{parse_macro_input, DeriveInput};
9use tui_dispatch_shared::{infer_action_category, pascal_to_snake_case};
10
/// Container-level options for `#[derive(Action)]`, parsed by darling from the
/// `#[action(...)]` attributes placed on the enum itself.
#[derive(Debug, FromDeriveInput)]
#[darling(attributes(action), supports(enum_any))]
struct ActionOpts {
    /// Name of the enum the derive is applied to.
    ident: syn::Ident,
    /// The enum's variants; darling rejects non-enum inputs because of
    /// `supports(enum_any)`.
    data: darling::ast::Data<ActionVariant, ()>,

    /// `#[action(infer_categories)]`: when set, every variant is assigned a
    /// category and the `<Enum>Category` enum, the `category()` /
    /// `category_enum()` / `is_<category>()` methods and the
    /// `ActionCategory` impl are generated. `false` when omitted.
    #[darling(default)]
    infer_categories: bool,

    /// `#[action(generate_dispatcher)]`: additionally generates the
    /// `<Enum>Dispatcher` trait. Only honoured together with
    /// `infer_categories`. `false` when omitted.
    #[darling(default)]
    generate_dispatcher: bool,
}
26
/// Per-variant options for `#[derive(Action)]`, parsed by darling from the
/// `#[action(...)]` attributes placed on individual enum variants.
#[derive(Debug, FromVariant)]
#[darling(attributes(action))]
struct ActionVariant {
    /// Variant name.
    ident: syn::Ident,
    /// Variant fields; only their shape (unit / tuple / struct) is used, to
    /// pick the match pattern for `Action::name`.
    fields: darling::ast::Fields<()>,

    /// `#[action(category = "...")]`: explicit category that replaces the one
    /// inferred from the variant name.
    #[darling(default)]
    category: Option<String>,

    /// `#[action(skip_category)]`: leave this variant uncategorized. Takes
    /// precedence over `category`.
    #[darling(default)]
    skip_category: bool,
}
42
43fn to_snake_case(s: &str) -> String {
45 pascal_to_snake_case(s)
46}
47
/// Converts a snake_case name into PascalCase.
///
/// The input is split on `_`. Each non-empty segment contributes its first
/// character upper-cased (Unicode-aware, so `ß` becomes `SS`) followed by the
/// rest of the segment unchanged. Empty segments and the underscores
/// themselves are dropped: `"search_http"` → `"SearchHttp"`, `"a__b"` → `"AB"`.
fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::new();
    for segment in s.split('_') {
        let mut rest = segment.chars();
        if let Some(head) = rest.next() {
            pascal.extend(head.to_uppercase());
            pascal.push_str(rest.as_str());
        }
    }
    pascal
}
60
61fn infer_category(name: &str) -> Option<String> {
63 infer_action_category(name)
64}
65
66#[proc_macro_derive(Action, attributes(action))]
98pub fn derive_action(input: TokenStream) -> TokenStream {
99 let input = parse_macro_input!(input as DeriveInput);
100
101 let opts = match ActionOpts::from_derive_input(&input) {
103 Ok(opts) => opts,
104 Err(e) => return e.write_errors().into(),
105 };
106
107 let name = &opts.ident;
108
109 let variants = match &opts.data {
110 darling::ast::Data::Enum(variants) => variants,
111 _ => {
112 return syn::Error::new_spanned(&input, "Action can only be derived for enums")
113 .to_compile_error()
114 .into();
115 }
116 };
117
118 let syn_variants = match &input.data {
120 syn::Data::Enum(data) => &data.variants,
121 _ => unreachable!(), };
123
124 let name_arms = variants.iter().map(|v| {
126 let variant_name = &v.ident;
127 let variant_str = variant_name.to_string();
128
129 match &v.fields.style {
130 darling::ast::Style::Unit => quote! {
131 #name::#variant_name => #variant_str
132 },
133 darling::ast::Style::Tuple => quote! {
134 #name::#variant_name(..) => #variant_str
135 },
136 darling::ast::Style::Struct => quote! {
137 #name::#variant_name { .. } => #variant_str
138 },
139 }
140 });
141
142 let params_arms = syn_variants.iter().map(|v| {
144 let variant_name = &v.ident;
145
146 match &v.fields {
147 syn::Fields::Unit => quote! {
148 #name::#variant_name => ::std::string::String::new()
149 },
150 syn::Fields::Unnamed(fields) => {
151 let field_count = fields.unnamed.len();
152 let field_names: Vec<_> =
153 (0..field_count).map(|i| format_ident!("_{}", i)).collect();
154 if field_count == 1 {
155 quote! {
156 #name::#variant_name(#(#field_names),*) => {
157 tui_dispatch::debug::debug_string(&#(#field_names),*)
158 }
159 }
160 } else {
161 let parts = field_names.iter().map(|field| {
162 quote! { tui_dispatch::debug::debug_string(&#field) }
163 });
164 quote! {
165 #name::#variant_name(#(#field_names),*) => {
166 let values = ::std::vec![#(#parts),*];
167 format!("({})", values.join(", "))
168 }
169 }
170 }
171 }
172 syn::Fields::Named(fields) => {
173 let field_names: Vec<_> = fields
174 .named
175 .iter()
176 .filter_map(|f| f.ident.as_ref())
177 .collect();
178 if field_names.is_empty() {
179 quote! {
180 #name::#variant_name { .. } => ::std::string::String::new()
181 }
182 } else {
183 let parts = field_names.iter().map(|field| {
184 let label = field.to_string();
185 quote! {
186 format!("{}: {}", #label, tui_dispatch::debug::debug_string(&#field))
187 }
188 });
189 quote! {
190 #name::#variant_name { #(#field_names),*, .. } => {
191 let values = ::std::vec![#(#parts),*];
192 format!("{{{}}}", values.join(", "))
193 }
194 }
195 }
196 }
197 }
198 });
199
200 let params_pretty_arms = syn_variants.iter().map(|v| {
201 let variant_name = &v.ident;
202
203 match &v.fields {
204 syn::Fields::Unit => quote! {
205 #name::#variant_name => ::std::string::String::new()
206 },
207 syn::Fields::Unnamed(fields) => {
208 let field_count = fields.unnamed.len();
209 let field_names: Vec<_> =
210 (0..field_count).map(|i| format_ident!("_{}", i)).collect();
211 if field_count == 1 {
212 quote! {
213 #name::#variant_name(#(#field_names),*) => {
214 tui_dispatch::debug::debug_string_pretty(&#(#field_names),*)
215 }
216 }
217 } else {
218 let parts = field_names.iter().map(|field| {
219 quote! { tui_dispatch::debug::debug_string_pretty(&#field) }
220 });
221 quote! {
222 #name::#variant_name(#(#field_names),*) => {
223 let values = ::std::vec![#(#parts),*];
224 format!("({})", values.join(", "))
225 }
226 }
227 }
228 }
229 syn::Fields::Named(fields) => {
230 let field_names: Vec<_> = fields
231 .named
232 .iter()
233 .filter_map(|f| f.ident.as_ref())
234 .collect();
235 if field_names.is_empty() {
236 quote! {
237 #name::#variant_name { .. } => ::std::string::String::new()
238 }
239 } else {
240 let parts = field_names.iter().map(|field| {
241 let label = field.to_string();
242 quote! {
243 format!("{}: {}", #label, tui_dispatch::debug::debug_string_pretty(&#field))
244 }
245 });
246 quote! {
247 #name::#variant_name { #(#field_names),*, .. } => {
248 let values = ::std::vec![#(#parts),*];
249 format!("{{{}}}", values.join(", "))
250 }
251 }
252 }
253 }
254 }
255 });
256
257 let mut expanded = quote! {
258 impl tui_dispatch::Action for #name {
259 fn name(&self) -> &'static str {
260 match self {
261 #(#name_arms),*
262 }
263 }
264 }
265
266 impl tui_dispatch::ActionParams for #name {
267 fn params(&self) -> ::std::string::String {
268 match self {
269 #(#params_arms),*
270 }
271 }
272
273 fn params_pretty(&self) -> ::std::string::String {
274 match self {
275 #(#params_pretty_arms),*
276 }
277 }
278 }
279 };
280
281 if opts.infer_categories {
283 let mut categories: HashMap<String, Vec<&Ident>> = HashMap::new();
285 let mut variant_categories: Vec<(&Ident, Option<String>)> = Vec::new();
286
287 for v in variants.iter() {
288 let cat = if v.skip_category {
289 None
290 } else if let Some(ref explicit_cat) = v.category {
291 Some(explicit_cat.clone())
292 } else {
293 infer_category(&v.ident.to_string())
294 };
295
296 variant_categories.push((&v.ident, cat.clone()));
297
298 if let Some(ref category) = cat {
299 categories
300 .entry(category.clone())
301 .or_default()
302 .push(&v.ident);
303 }
304 }
305
306 let mut sorted_categories: Vec<_> = categories.keys().cloned().collect();
308 sorted_categories.sort();
309
310 let category_arms_dedup: Vec<_> = variant_categories
312 .iter()
313 .map(|(variant, cat)| {
314 let cat_expr = match cat {
315 Some(c) => quote! { ::core::option::Option::Some(#c) },
316 None => quote! { ::core::option::Option::None },
317 };
318 quote! { #name::#variant { .. } => #cat_expr }
320 })
321 .collect();
322
323 let category_enum_name = format_ident!("{}Category", name);
325 let category_variants: Vec<_> = sorted_categories
326 .iter()
327 .map(|c| format_ident!("{}", to_pascal_case(c)))
328 .collect();
329 let category_variant_names: Vec<_> = sorted_categories.clone();
330
331 let category_enum_arms: Vec<_> = variant_categories
333 .iter()
334 .map(|(variant, cat)| {
335 let cat_variant = match cat {
336 Some(c) => format_ident!("{}", to_pascal_case(c)),
337 None => format_ident!("Uncategorized"),
338 };
339 quote! { #name::#variant { .. } => #category_enum_name::#cat_variant }
340 })
341 .collect();
342
343 let predicates: Vec<_> = sorted_categories
345 .iter()
346 .map(|cat| {
347 let predicate_name = format_ident!("is_{}", cat);
348 let cat_variants = categories.get(cat).unwrap();
349 let patterns: Vec<_> = cat_variants
350 .iter()
351 .map(|v| quote! { #name::#v { .. } })
352 .collect();
353 let doc = format!(
354 "Returns true if this action belongs to the `{}` category.",
355 cat
356 );
357
358 quote! {
359 #[doc = #doc]
360 pub fn #predicate_name(&self) -> bool {
361 matches!(self, #(#patterns)|*)
362 }
363 }
364 })
365 .collect();
366
367 let category_enum_doc = format!(
369 "Action categories for [`{}`].\n\n\
370 Use [`{}::category_enum()`] to get the category of an action.",
371 name, name
372 );
373
374 expanded = quote! {
375 #expanded
376
377 #[doc = #category_enum_doc]
378 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
379 pub enum #category_enum_name {
380 #(#category_variants,)*
381 Uncategorized,
383 }
384
385 impl #category_enum_name {
386 pub fn all() -> &'static [Self] {
388 &[#(Self::#category_variants,)* Self::Uncategorized]
389 }
390
391 pub fn name(&self) -> &'static str {
393 match self {
394 #(Self::#category_variants => #category_variant_names,)*
395 Self::Uncategorized => "uncategorized",
396 }
397 }
398 }
399
400 impl #name {
401 pub fn category(&self) -> ::core::option::Option<&'static str> {
403 match self {
404 #(#category_arms_dedup,)*
405 }
406 }
407
408 pub fn category_enum(&self) -> #category_enum_name {
410 match self {
411 #(#category_enum_arms,)*
412 }
413 }
414
415 #(#predicates)*
416 }
417
418 impl tui_dispatch::ActionCategory for #name {
419 type Category = #category_enum_name;
420
421 fn category(&self) -> ::core::option::Option<&'static str> {
422 #name::category(self)
423 }
424
425 fn category_enum(&self) -> Self::Category {
426 #name::category_enum(self)
427 }
428 }
429 };
430
431 if opts.generate_dispatcher {
433 let dispatcher_trait_name = format_ident!("{}Dispatcher", name);
434
435 let dispatch_methods: Vec<_> = sorted_categories
436 .iter()
437 .map(|cat| {
438 let method_name = format_ident!("dispatch_{}", cat);
439 let doc = format!("Handle actions in the `{}` category.", cat);
440 quote! {
441 #[doc = #doc]
442 fn #method_name(&mut self, action: &#name) -> bool {
443 false
444 }
445 }
446 })
447 .collect();
448
449 let dispatch_arms: Vec<_> = sorted_categories
450 .iter()
451 .map(|cat| {
452 let method_name = format_ident!("dispatch_{}", cat);
453 let cat_variant = format_ident!("{}", to_pascal_case(cat));
454 quote! {
455 #category_enum_name::#cat_variant => self.#method_name(action)
456 }
457 })
458 .collect();
459
460 let dispatcher_doc = format!(
461 "Dispatcher trait for [`{}`].\n\n\
462 Implement the `dispatch_*` methods for each category you want to handle.\n\
463 The [`dispatch()`](Self::dispatch) method automatically routes to the correct handler.",
464 name
465 );
466
467 expanded = quote! {
468 #expanded
469
470 #[doc = #dispatcher_doc]
471 pub trait #dispatcher_trait_name {
472 #(#dispatch_methods)*
473
474 fn dispatch_uncategorized(&mut self, action: &#name) -> bool {
476 false
477 }
478
479 fn dispatch(&mut self, action: &#name) -> bool {
481 match action.category_enum() {
482 #(#dispatch_arms,)*
483 #category_enum_name::Uncategorized => self.dispatch_uncategorized(action),
484 }
485 }
486 }
487 };
488 }
489 }
490
491 TokenStream::from(expanded)
492}
493
494#[proc_macro_derive(BindingContext)]
513pub fn derive_binding_context(input: TokenStream) -> TokenStream {
514 let input = parse_macro_input!(input as DeriveInput);
515 let name = &input.ident;
516
517 let expanded = match &input.data {
518 syn::Data::Enum(data) => {
519 for variant in &data.variants {
521 if !matches!(variant.fields, syn::Fields::Unit) {
522 return syn::Error::new_spanned(
523 variant,
524 "BindingContext can only be derived for enums with unit variants",
525 )
526 .to_compile_error()
527 .into();
528 }
529 }
530
531 let variant_names: Vec<_> = data.variants.iter().map(|v| &v.ident).collect();
532 let variant_strings: Vec<_> = variant_names
533 .iter()
534 .map(|v| to_snake_case(&v.to_string()))
535 .collect();
536
537 let name_arms = variant_names
538 .iter()
539 .zip(variant_strings.iter())
540 .map(|(v, s)| {
541 quote! { #name::#v => #s }
542 });
543
544 let from_name_arms = variant_names
545 .iter()
546 .zip(variant_strings.iter())
547 .map(|(v, s)| {
548 quote! { #s => ::core::option::Option::Some(#name::#v) }
549 });
550
551 let all_variants = variant_names.iter().map(|v| quote! { #name::#v });
552
553 quote! {
554 impl tui_dispatch::BindingContext for #name {
555 fn name(&self) -> &'static str {
556 match self {
557 #(#name_arms),*
558 }
559 }
560
561 fn from_name(name: &str) -> ::core::option::Option<Self> {
562 match name {
563 #(#from_name_arms,)*
564 _ => ::core::option::Option::None,
565 }
566 }
567
568 fn all() -> &'static [Self] {
569 static ALL: &[#name] = &[#(#all_variants),*];
570 ALL
571 }
572 }
573 }
574 }
575 _ => {
576 return syn::Error::new_spanned(input, "BindingContext can only be derived for enums")
577 .to_compile_error()
578 .into();
579 }
580 };
581
582 TokenStream::from(expanded)
583}
584
585#[proc_macro_derive(ComponentId)]
601pub fn derive_component_id(input: TokenStream) -> TokenStream {
602 let input = parse_macro_input!(input as DeriveInput);
603 let name = &input.ident;
604
605 let expanded = match &input.data {
606 syn::Data::Enum(data) => {
607 for variant in &data.variants {
609 if !matches!(variant.fields, syn::Fields::Unit) {
610 return syn::Error::new_spanned(
611 variant,
612 "ComponentId can only be derived for enums with unit variants",
613 )
614 .to_compile_error()
615 .into();
616 }
617 }
618
619 let variant_names: Vec<_> = data.variants.iter().map(|v| &v.ident).collect();
620 let variant_strings: Vec<_> = variant_names.iter().map(|v| v.to_string()).collect();
621
622 let name_arms = variant_names
623 .iter()
624 .zip(variant_strings.iter())
625 .map(|(v, s)| {
626 quote! { #name::#v => #s }
627 });
628
629 quote! {
630 impl tui_dispatch::ComponentId for #name {
631 fn name(&self) -> &'static str {
632 match self {
633 #(#name_arms),*
634 }
635 }
636 }
637 }
638 }
639 _ => {
640 return syn::Error::new_spanned(input, "ComponentId can only be derived for enums")
641 .to_compile_error()
642 .into();
643 }
644 };
645
646 TokenStream::from(expanded)
647}
648
/// Container-level input for `#[derive(DebugState)]`.
///
/// darling restricts the derive to structs with named fields
/// (`supports(struct_named)`). `#[debug_state(...)]` is registered as the
/// container attribute, but this struct declares no options of its own.
#[derive(Debug, FromDeriveInput)]
#[darling(attributes(debug_state), supports(struct_named))]
struct DebugStateOpts {
    /// Name of the struct the derive is applied to; also the default section name.
    ident: syn::Ident,
    /// The struct's fields together with their `#[debug(...)]` options.
    data: darling::ast::Data<(), DebugStateField>,
}
660
/// Per-field options for `#[derive(DebugState)]`, parsed by darling from
/// `#[debug(...)]` field attributes.
#[derive(Debug, FromField)]
#[darling(attributes(debug))]
struct DebugStateField {
    /// Field name; always present because only named-field structs are accepted.
    ident: Option<syn::Ident>,

    /// `#[debug(section = "...")]`: section the field is listed under.
    /// Defaults to the struct's name.
    #[darling(default)]
    section: Option<String>,

    /// `#[debug(skip)]`: omit the field from the debug output entirely.
    #[darling(default)]
    skip: bool,

    /// `#[debug(format = "...")]`: format string passed to `format!` with the
    /// field value as its only argument. Takes precedence over `debug_fmt`.
    #[darling(default)]
    format: Option<String>,

    /// `#[debug(label = "...")]`: label shown for the entry. Defaults to the
    /// field name.
    #[darling(default)]
    label: Option<String>,

    /// `#[debug(debug_fmt)]`: render the value with `{:?}` instead of
    /// `tui_dispatch::debug::debug_string`.
    #[darling(default)]
    debug_fmt: bool,
}
687
688#[proc_macro_derive(DebugState, attributes(debug, debug_state))]
726pub fn derive_debug_state(input: TokenStream) -> TokenStream {
727 let input = parse_macro_input!(input as DeriveInput);
728
729 let opts = match DebugStateOpts::from_derive_input(&input) {
730 Ok(opts) => opts,
731 Err(e) => return e.write_errors().into(),
732 };
733
734 let name = &opts.ident;
735 let default_section = name.to_string();
736
737 let fields = match &opts.data {
738 darling::ast::Data::Struct(fields) => fields,
739 _ => {
740 return syn::Error::new_spanned(&input, "DebugState can only be derived for structs")
741 .to_compile_error()
742 .into();
743 }
744 };
745
746 let mut sections: HashMap<String, Vec<&DebugStateField>> = HashMap::new();
748 let mut section_order: Vec<String> = Vec::new();
749
750 for field in fields.iter() {
751 if field.skip {
752 continue;
753 }
754
755 let section_name = field
756 .section
757 .clone()
758 .unwrap_or_else(|| default_section.clone());
759
760 if !section_order.contains(§ion_name) {
761 section_order.push(section_name.clone());
762 }
763
764 sections.entry(section_name).or_default().push(field);
765 }
766
767 let section_code: Vec<_> = section_order
769 .iter()
770 .map(|section_name| {
771 let fields_in_section = sections.get(section_name).unwrap();
772
773 let entry_calls: Vec<_> = fields_in_section
774 .iter()
775 .filter_map(|field| {
776 let field_ident = field.ident.as_ref()?;
777 let label = field
778 .label
779 .clone()
780 .unwrap_or_else(|| field_ident.to_string());
781
782 let value_expr = if let Some(ref fmt) = field.format {
783 quote! { format!(#fmt, self.#field_ident) }
784 } else if field.debug_fmt {
785 quote! { format!("{:?}", self.#field_ident) }
786 } else {
787 quote! { tui_dispatch::debug::debug_string(&self.#field_ident) }
788 };
789
790 Some(quote! {
791 .entry(#label, #value_expr)
792 })
793 })
794 .collect();
795
796 quote! {
797 tui_dispatch::debug::DebugSection::new(#section_name)
798 #(#entry_calls)*
799 }
800 })
801 .collect();
802
803 let expanded = quote! {
804 impl tui_dispatch::debug::DebugState for #name {
805 fn debug_sections(&self) -> ::std::vec::Vec<tui_dispatch::debug::DebugSection> {
806 ::std::vec![
807 #(#section_code),*
808 ]
809 }
810 }
811 };
812
813 TokenStream::from(expanded)
814}
815
/// Per-field options for `#[derive(FeatureFlags)]`, parsed by darling from
/// `#[flag(...)]` field attributes.
#[derive(Debug, FromField)]
#[darling(attributes(flag))]
struct FeatureFlagsField {
    /// Field name; also used as the flag's name.
    ident: Option<syn::Ident>,
    /// Field type; only fields whose type is written as `bool` are treated as flags.
    ty: syn::Type,

    /// `#[flag(default = true|false)]`: value of this flag in the generated
    /// `Default` impl; `false` when omitted.
    #[darling(default)]
    default: Option<bool>,
}
831
/// Container-level input for `#[derive(FeatureFlags)]`.
///
/// darling restricts the derive to structs with named fields
/// (`supports(struct_named)`). `#[feature_flags(...)]` is registered as the
/// container attribute, but this struct declares no options of its own.
#[derive(Debug, FromDeriveInput)]
#[darling(attributes(feature_flags), supports(struct_named))]
struct FeatureFlagsOpts {
    /// Name of the struct the derive is applied to.
    ident: syn::Ident,
    /// The struct's fields together with their `#[flag(...)]` options.
    data: darling::ast::Data<(), FeatureFlagsField>,
}
839
840#[proc_macro_derive(FeatureFlags, attributes(flag, feature_flags))]
871pub fn derive_feature_flags(input: TokenStream) -> TokenStream {
872 let input = parse_macro_input!(input as DeriveInput);
873
874 let opts = match FeatureFlagsOpts::from_derive_input(&input) {
875 Ok(opts) => opts,
876 Err(e) => return e.write_errors().into(),
877 };
878
879 let name = &opts.ident;
880
881 let fields = match &opts.data {
882 darling::ast::Data::Struct(fields) => fields,
883 _ => {
884 return syn::Error::new_spanned(
885 &input,
886 "FeatureFlags can only be derived for structs with named fields",
887 )
888 .to_compile_error()
889 .into();
890 }
891 };
892
893 let bool_fields: Vec<_> = fields
895 .iter()
896 .filter_map(|f| {
897 let ident = f.ident.as_ref()?;
898 if let syn::Type::Path(type_path) = &f.ty {
900 if type_path.path.is_ident("bool") {
901 return Some((ident.clone(), f.default.unwrap_or(false)));
902 }
903 }
904 None
905 })
906 .collect();
907
908 if bool_fields.is_empty() {
909 return syn::Error::new_spanned(
910 &input,
911 "FeatureFlags struct must have at least one bool field",
912 )
913 .to_compile_error()
914 .into();
915 }
916
917 let is_enabled_arms: Vec<_> = bool_fields
919 .iter()
920 .map(|(ident, _)| {
921 let name_str = ident.to_string();
922 quote! { #name_str => ::core::option::Option::Some(self.#ident) }
923 })
924 .collect();
925
926 let set_arms: Vec<_> = bool_fields
928 .iter()
929 .map(|(ident, _)| {
930 let name_str = ident.to_string();
931 quote! {
932 #name_str => {
933 self.#ident = enabled;
934 true
935 }
936 }
937 })
938 .collect();
939
940 let flag_names: Vec<_> = bool_fields
942 .iter()
943 .map(|(ident, _)| ident.to_string())
944 .collect();
945
946 let default_fields: Vec<_> = bool_fields
948 .iter()
949 .map(|(ident, default)| {
950 quote! { #ident: #default }
951 })
952 .collect();
953
954 let expanded = quote! {
955 impl tui_dispatch::FeatureFlags for #name {
956 fn is_enabled(&self, name: &str) -> ::core::option::Option<bool> {
957 match name {
958 #(#is_enabled_arms,)*
959 _ => ::core::option::Option::None,
960 }
961 }
962
963 fn set(&mut self, name: &str, enabled: bool) -> bool {
964 match name {
965 #(#set_arms)*
966 _ => false,
967 }
968 }
969
970 fn all_flags() -> &'static [&'static str] {
971 &[#(#flag_names),*]
972 }
973 }
974
975 impl ::core::default::Default for #name {
976 fn default() -> Self {
977 Self {
978 #(#default_fields,)*
979 }
980 }
981 }
982 };
983
984 TokenStream::from(expanded)
985}
986
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_to_snake_case_handles_acronyms() {
        assert_eq!(to_snake_case("APIFetch"), "api_fetch");
        assert_eq!(to_snake_case("HTTPResult"), "http_result");
    }

    #[test]
    fn test_infer_category_handles_acronyms() {
        assert_eq!(infer_category("APIFetchStart"), Some("api".to_string()));
        assert_eq!(
            infer_category("SearchHTTPStart"),
            Some("search_http".to_string())
        );
    }

    /// `to_pascal_case` turns category names into the variant names of the
    /// generated `<Enum>Category` enum, so pin its behaviour down.
    #[test]
    fn test_to_pascal_case_builds_category_variant_names() {
        assert_eq!(to_pascal_case("api"), "Api");
        assert_eq!(to_pascal_case("search_http"), "SearchHttp");
        assert_eq!(to_pascal_case("a__b"), "AB");
        assert_eq!(to_pascal_case(""), "");
    }
}