1use darling::{FromDeriveInput, FromField, FromVariant};
4use proc_macro::TokenStream;
5use proc_macro2::Ident;
6use quote::{format_ident, quote};
7use std::collections::HashMap;
8use syn::{parse_macro_input, DeriveInput};
9
/// Enum-level options for `#[derive(Action)]`, parsed by darling from
/// `#[action(...)]` attributes placed on the enum itself.
#[derive(Debug, FromDeriveInput)]
#[darling(attributes(action), supports(enum_any))]
struct ActionOpts {
    /// Name of the enum the derive is applied to.
    ident: syn::Ident,
    /// The enum's variants (non-enum input is rejected via `supports(enum_any)`).
    data: darling::ast::Data<ActionVariant, ()>,

    /// `#[action(infer_categories)]`: assign categories to variants and generate
    /// the `<Name>Category` enum, category accessors, `is_<category>()`
    /// predicates and the `ActionCategory` impl. Defaults to `false`.
    #[darling(default)]
    infer_categories: bool,

    /// `#[action(generate_dispatcher)]`: generate a `<Name>Dispatcher` trait that
    /// routes actions by category. It is built on the category machinery enabled
    /// by `infer_categories`. Defaults to `false`.
    #[darling(default)]
    generate_dispatcher: bool,
}
25
/// Per-variant options for `#[derive(Action)]`, parsed from `#[action(...)]`
/// attributes on individual enum variants.
#[derive(Debug, FromVariant)]
#[darling(attributes(action))]
struct ActionVariant {
    /// Variant name.
    ident: syn::Ident,
    /// Variant fields; only the shape (unit / tuple / struct) is used when
    /// building match patterns, the field contents are ignored (`()`).
    fields: darling::ast::Fields<()>,

    /// `#[action(category = "...")]`: explicit category that takes precedence
    /// over name-based inference.
    #[darling(default)]
    category: Option<String>,

    /// `#[action(skip_category)]`: leave this variant uncategorized, even when a
    /// category was given explicitly or could be inferred.
    #[darling(default)]
    skip_category: bool,
}
41
/// Verb words recognised by `infer_category` (exact, case-sensitive match on a
/// PascalCase word). A variant name whose first word is in this list gets no
/// inferred category; otherwise the words before the first listed verb that
/// appears after the first word form the category.
const ACTION_VERBS: &[&str] = &[
    "Start", "End", "Open", "Close", "Submit", "Confirm", "Cancel", "Next", "Prev", "Up", "Down", "Left", "Right", "Enter", "Exit", "Escape",
    "Add", "Remove", "Clear", "Update", "Set", "Get", "Load", "Save", "Delete", "Create",
    "Show", "Hide", "Enable", "Disable", "Toggle", "Focus", "Blur", "Select", "Move", "Copy", "Cycle", "Reset", "Scroll",
];
56
/// Splits a PascalCase identifier into its words.
///
/// A new word begins at the first character and at every uppercase character
/// after it, so runs of capitals become single-letter words (`"HTTPGet"` ->
/// `["H", "T", "T", "P", "Get"]`) and digits stay attached to the preceding
/// word. An empty input yields an empty vector.
fn split_pascal_case(s: &str) -> Vec<String> {
    // Byte offsets at which a word starts.
    let word_starts = s
        .char_indices()
        .filter(|&(offset, c)| offset == 0 || c.is_uppercase())
        .map(|(offset, _)| offset);
    // Each word ends where the next one starts; the last one ends at `s.len()`.
    let word_ends = word_starts
        .clone()
        .skip(1)
        .chain(std::iter::once(s.len()));

    word_starts
        .zip(word_ends)
        .map(|(start, end)| s[start..end].to_string())
        .collect()
}
74
/// Converts a PascalCase / camelCase identifier to snake_case.
///
/// An underscore is inserted before every uppercase character except the
/// first, and each uppercase character is replaced by its full lowercase
/// mapping. Consecutive capitals are not grouped (`"HTTPGet"` ->
/// `"h_t_t_p_get"`), and existing underscores are copied unchanged.
fn to_snake_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len() + s.len() / 2);
    for (i, ch) in s.chars().enumerate() {
        if ch.is_uppercase() {
            if i > 0 {
                result.push('_');
            }
            // `to_lowercase` can yield several chars (e.g. 'İ' -> "i\u{307}");
            // keep all of them rather than only the first.
            result.extend(ch.to_lowercase());
        } else {
            result.push(ch);
        }
    }
    result
}
90
/// Converts a snake_case name to PascalCase.
///
/// The input is split on `_`; each non-empty segment has its first character
/// uppercased (using the full uppercase mapping) and the rest copied as-is.
/// Empty segments, e.g. from repeated underscores, contribute nothing.
fn to_pascal_case(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for word in s.split('_') {
        let mut rest = word.chars();
        if let Some(head) = rest.next() {
            out.extend(head.to_uppercase());
            out.push_str(rest.as_str());
        }
    }
    out
}
103
104fn infer_category(name: &str) -> Option<String> {
106 let parts = split_pascal_case(name);
107 if parts.is_empty() {
108 return None;
109 }
110
111 if parts[0] == "Did" {
113 return Some("async_result".to_string());
114 }
115
116 if parts.len() < 2 {
118 return None;
119 }
120
121 let first_is_verb = ACTION_VERBS.contains(&parts[0].as_str());
127
128 let mut prefix_end = parts.len();
129 let mut found_verb = false;
130 for (i, part) in parts.iter().enumerate().skip(1) {
131 if ACTION_VERBS.contains(&part.as_str()) {
132 prefix_end = i;
133 found_verb = true;
134 break;
135 }
136 }
137
138 if first_is_verb {
142 return None;
143 }
144
145 if !found_verb {
147 return None;
148 }
149
150 if prefix_end == 0 {
151 return None;
152 }
153
154 let prefix_parts: Vec<&str> = parts[..prefix_end].iter().map(|s| s.as_str()).collect();
155 let prefix = prefix_parts.join("");
156
157 Some(to_snake_case(&prefix))
158}
159
160#[proc_macro_derive(Action, attributes(action))]
192pub fn derive_action(input: TokenStream) -> TokenStream {
193 let input = parse_macro_input!(input as DeriveInput);
194
195 let opts = match ActionOpts::from_derive_input(&input) {
197 Ok(opts) => opts,
198 Err(e) => return e.write_errors().into(),
199 };
200
201 let name = &opts.ident;
202
203 let variants = match &opts.data {
204 darling::ast::Data::Enum(variants) => variants,
205 _ => {
206 return syn::Error::new_spanned(&input, "Action can only be derived for enums")
207 .to_compile_error()
208 .into();
209 }
210 };
211
212 let name_arms = variants.iter().map(|v| {
214 let variant_name = &v.ident;
215 let variant_str = variant_name.to_string();
216
217 match &v.fields.style {
218 darling::ast::Style::Unit => quote! {
219 #name::#variant_name => #variant_str
220 },
221 darling::ast::Style::Tuple => quote! {
222 #name::#variant_name(..) => #variant_str
223 },
224 darling::ast::Style::Struct => quote! {
225 #name::#variant_name { .. } => #variant_str
226 },
227 }
228 });
229
230 let mut expanded = quote! {
231 impl tui_dispatch::Action for #name {
232 fn name(&self) -> &'static str {
233 match self {
234 #(#name_arms),*
235 }
236 }
237 }
238 };
239
240 if opts.infer_categories {
242 let mut categories: HashMap<String, Vec<&Ident>> = HashMap::new();
244 let mut variant_categories: Vec<(&Ident, Option<String>)> = Vec::new();
245
246 for v in variants.iter() {
247 let cat = if v.skip_category {
248 None
249 } else if let Some(ref explicit_cat) = v.category {
250 Some(explicit_cat.clone())
251 } else {
252 infer_category(&v.ident.to_string())
253 };
254
255 variant_categories.push((&v.ident, cat.clone()));
256
257 if let Some(ref category) = cat {
258 categories
259 .entry(category.clone())
260 .or_default()
261 .push(&v.ident);
262 }
263 }
264
265 let mut sorted_categories: Vec<_> = categories.keys().cloned().collect();
267 sorted_categories.sort();
268
269 let category_arms_dedup: Vec<_> = variant_categories
271 .iter()
272 .map(|(variant, cat)| {
273 let cat_expr = match cat {
274 Some(c) => quote! { ::core::option::Option::Some(#c) },
275 None => quote! { ::core::option::Option::None },
276 };
277 quote! { #name::#variant { .. } => #cat_expr }
279 })
280 .collect();
281
282 let category_enum_name = format_ident!("{}Category", name);
284 let category_variants: Vec<_> = sorted_categories
285 .iter()
286 .map(|c| format_ident!("{}", to_pascal_case(c)))
287 .collect();
288 let category_variant_names: Vec<_> = sorted_categories.clone();
289
290 let category_enum_arms: Vec<_> = variant_categories
292 .iter()
293 .map(|(variant, cat)| {
294 let cat_variant = match cat {
295 Some(c) => format_ident!("{}", to_pascal_case(c)),
296 None => format_ident!("Uncategorized"),
297 };
298 quote! { #name::#variant { .. } => #category_enum_name::#cat_variant }
299 })
300 .collect();
301
302 let predicates: Vec<_> = sorted_categories
304 .iter()
305 .map(|cat| {
306 let predicate_name = format_ident!("is_{}", cat);
307 let cat_variants = categories.get(cat).unwrap();
308 let patterns: Vec<_> = cat_variants
309 .iter()
310 .map(|v| quote! { #name::#v { .. } })
311 .collect();
312 let doc = format!(
313 "Returns true if this action belongs to the `{}` category.",
314 cat
315 );
316
317 quote! {
318 #[doc = #doc]
319 pub fn #predicate_name(&self) -> bool {
320 matches!(self, #(#patterns)|*)
321 }
322 }
323 })
324 .collect();
325
326 let category_enum_doc = format!(
328 "Action categories for [`{}`].\n\n\
329 Use [`{}::category_enum()`] to get the category of an action.",
330 name, name
331 );
332
333 expanded = quote! {
334 #expanded
335
336 #[doc = #category_enum_doc]
337 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
338 pub enum #category_enum_name {
339 #(#category_variants,)*
340 Uncategorized,
342 }
343
344 impl #category_enum_name {
345 pub fn all() -> &'static [Self] {
347 &[#(Self::#category_variants,)* Self::Uncategorized]
348 }
349
350 pub fn name(&self) -> &'static str {
352 match self {
353 #(Self::#category_variants => #category_variant_names,)*
354 Self::Uncategorized => "uncategorized",
355 }
356 }
357 }
358
359 impl #name {
360 pub fn category(&self) -> ::core::option::Option<&'static str> {
362 match self {
363 #(#category_arms_dedup,)*
364 }
365 }
366
367 pub fn category_enum(&self) -> #category_enum_name {
369 match self {
370 #(#category_enum_arms,)*
371 }
372 }
373
374 #(#predicates)*
375 }
376
377 impl tui_dispatch::ActionCategory for #name {
378 type Category = #category_enum_name;
379
380 fn category(&self) -> ::core::option::Option<&'static str> {
381 #name::category(self)
382 }
383
384 fn category_enum(&self) -> Self::Category {
385 #name::category_enum(self)
386 }
387 }
388 };
389
390 if opts.generate_dispatcher {
392 let dispatcher_trait_name = format_ident!("{}Dispatcher", name);
393
394 let dispatch_methods: Vec<_> = sorted_categories
395 .iter()
396 .map(|cat| {
397 let method_name = format_ident!("dispatch_{}", cat);
398 let doc = format!("Handle actions in the `{}` category.", cat);
399 quote! {
400 #[doc = #doc]
401 fn #method_name(&mut self, action: &#name) -> bool {
402 false
403 }
404 }
405 })
406 .collect();
407
408 let dispatch_arms: Vec<_> = sorted_categories
409 .iter()
410 .map(|cat| {
411 let method_name = format_ident!("dispatch_{}", cat);
412 let cat_variant = format_ident!("{}", to_pascal_case(cat));
413 quote! {
414 #category_enum_name::#cat_variant => self.#method_name(action)
415 }
416 })
417 .collect();
418
419 let dispatcher_doc = format!(
420 "Dispatcher trait for [`{}`].\n\n\
421 Implement the `dispatch_*` methods for each category you want to handle.\n\
422 The [`dispatch()`](Self::dispatch) method automatically routes to the correct handler.",
423 name
424 );
425
426 expanded = quote! {
427 #expanded
428
429 #[doc = #dispatcher_doc]
430 pub trait #dispatcher_trait_name {
431 #(#dispatch_methods)*
432
433 fn dispatch_uncategorized(&mut self, action: &#name) -> bool {
435 false
436 }
437
438 fn dispatch(&mut self, action: &#name) -> bool {
440 match action.category_enum() {
441 #(#dispatch_arms,)*
442 #category_enum_name::Uncategorized => self.dispatch_uncategorized(action),
443 }
444 }
445 }
446 };
447 }
448 }
449
450 TokenStream::from(expanded)
451}
452
453#[proc_macro_derive(BindingContext)]
472pub fn derive_binding_context(input: TokenStream) -> TokenStream {
473 let input = parse_macro_input!(input as DeriveInput);
474 let name = &input.ident;
475
476 let expanded = match &input.data {
477 syn::Data::Enum(data) => {
478 for variant in &data.variants {
480 if !matches!(variant.fields, syn::Fields::Unit) {
481 return syn::Error::new_spanned(
482 variant,
483 "BindingContext can only be derived for enums with unit variants",
484 )
485 .to_compile_error()
486 .into();
487 }
488 }
489
490 let variant_names: Vec<_> = data.variants.iter().map(|v| &v.ident).collect();
491 let variant_strings: Vec<_> = variant_names
492 .iter()
493 .map(|v| to_snake_case(&v.to_string()))
494 .collect();
495
496 let name_arms = variant_names
497 .iter()
498 .zip(variant_strings.iter())
499 .map(|(v, s)| {
500 quote! { #name::#v => #s }
501 });
502
503 let from_name_arms = variant_names
504 .iter()
505 .zip(variant_strings.iter())
506 .map(|(v, s)| {
507 quote! { #s => ::core::option::Option::Some(#name::#v) }
508 });
509
510 let all_variants = variant_names.iter().map(|v| quote! { #name::#v });
511
512 quote! {
513 impl tui_dispatch::BindingContext for #name {
514 fn name(&self) -> &'static str {
515 match self {
516 #(#name_arms),*
517 }
518 }
519
520 fn from_name(name: &str) -> ::core::option::Option<Self> {
521 match name {
522 #(#from_name_arms,)*
523 _ => ::core::option::Option::None,
524 }
525 }
526
527 fn all() -> &'static [Self] {
528 static ALL: &[#name] = &[#(#all_variants),*];
529 ALL
530 }
531 }
532 }
533 }
534 _ => {
535 return syn::Error::new_spanned(input, "BindingContext can only be derived for enums")
536 .to_compile_error()
537 .into();
538 }
539 };
540
541 TokenStream::from(expanded)
542}
543
544#[proc_macro_derive(ComponentId)]
560pub fn derive_component_id(input: TokenStream) -> TokenStream {
561 let input = parse_macro_input!(input as DeriveInput);
562 let name = &input.ident;
563
564 let expanded = match &input.data {
565 syn::Data::Enum(data) => {
566 for variant in &data.variants {
568 if !matches!(variant.fields, syn::Fields::Unit) {
569 return syn::Error::new_spanned(
570 variant,
571 "ComponentId can only be derived for enums with unit variants",
572 )
573 .to_compile_error()
574 .into();
575 }
576 }
577
578 let variant_names: Vec<_> = data.variants.iter().map(|v| &v.ident).collect();
579 let variant_strings: Vec<_> = variant_names.iter().map(|v| v.to_string()).collect();
580
581 let name_arms = variant_names
582 .iter()
583 .zip(variant_strings.iter())
584 .map(|(v, s)| {
585 quote! { #name::#v => #s }
586 });
587
588 quote! {
589 impl tui_dispatch::ComponentId for #name {
590 fn name(&self) -> &'static str {
591 match self {
592 #(#name_arms),*
593 }
594 }
595 }
596 }
597 }
598 _ => {
599 return syn::Error::new_spanned(input, "ComponentId can only be derived for enums")
600 .to_compile_error()
601 .into();
602 }
603 };
604
605 TokenStream::from(expanded)
606}
607
/// Container-level input for `#[derive(DebugState)]`. Only structs with named
/// fields are accepted (`supports(struct_named)`); the `#[debug_state(...)]`
/// container attribute is recognised but declares no options of its own here.
#[derive(Debug, FromDeriveInput)]
#[darling(attributes(debug_state), supports(struct_named))]
struct DebugStateOpts {
    /// Name of the struct; also used as the default section title.
    ident: syn::Ident,
    /// The struct's fields together with their `#[debug(...)]` options.
    data: darling::ast::Data<(), DebugStateField>,
}
619
/// Per-field options for `#[derive(DebugState)]`, parsed from `#[debug(...)]`.
#[derive(Debug, FromField)]
#[darling(attributes(debug))]
struct DebugStateField {
    /// Field name; only `None` for tuple fields, which `DebugStateOpts` rejects.
    ident: Option<syn::Ident>,

    /// `section = "..."`: section to place the field in (defaults to the struct name).
    #[darling(default)]
    section: Option<String>,

    /// `skip`: leave the field out of the debug output entirely.
    #[darling(default)]
    skip: bool,

    /// `format = "..."`: format string applied to the field with `format!`;
    /// takes precedence over `debug_fmt`.
    #[darling(default)]
    format: Option<String>,

    /// `label = "..."`: entry label (defaults to the field name).
    #[darling(default)]
    label: Option<String>,

    /// `debug_fmt`: render the value with `{:?}` instead of `to_string()`.
    #[darling(default)]
    debug_fmt: bool,
}
646
647#[proc_macro_derive(DebugState, attributes(debug, debug_state))]
685pub fn derive_debug_state(input: TokenStream) -> TokenStream {
686 let input = parse_macro_input!(input as DeriveInput);
687
688 let opts = match DebugStateOpts::from_derive_input(&input) {
689 Ok(opts) => opts,
690 Err(e) => return e.write_errors().into(),
691 };
692
693 let name = &opts.ident;
694 let default_section = name.to_string();
695
696 let fields = match &opts.data {
697 darling::ast::Data::Struct(fields) => fields,
698 _ => {
699 return syn::Error::new_spanned(&input, "DebugState can only be derived for structs")
700 .to_compile_error()
701 .into();
702 }
703 };
704
705 let mut sections: HashMap<String, Vec<&DebugStateField>> = HashMap::new();
707 let mut section_order: Vec<String> = Vec::new();
708
709 for field in fields.iter() {
710 if field.skip {
711 continue;
712 }
713
714 let section_name = field
715 .section
716 .clone()
717 .unwrap_or_else(|| default_section.clone());
718
719 if !section_order.contains(§ion_name) {
720 section_order.push(section_name.clone());
721 }
722
723 sections.entry(section_name).or_default().push(field);
724 }
725
726 let section_code: Vec<_> = section_order
728 .iter()
729 .map(|section_name| {
730 let fields_in_section = sections.get(section_name).unwrap();
731
732 let entry_calls: Vec<_> = fields_in_section
733 .iter()
734 .filter_map(|field| {
735 let field_ident = field.ident.as_ref()?;
736 let label = field
737 .label
738 .clone()
739 .unwrap_or_else(|| field_ident.to_string());
740
741 let value_expr = if let Some(ref fmt) = field.format {
742 quote! { format!(#fmt, self.#field_ident) }
743 } else if field.debug_fmt {
744 quote! { format!("{:?}", self.#field_ident) }
745 } else {
746 quote! { self.#field_ident.to_string() }
747 };
748
749 Some(quote! {
750 .entry(#label, #value_expr)
751 })
752 })
753 .collect();
754
755 quote! {
756 tui_dispatch::debug::DebugSection::new(#section_name)
757 #(#entry_calls)*
758 }
759 })
760 .collect();
761
762 let expanded = quote! {
763 impl tui_dispatch::debug::DebugState for #name {
764 fn debug_sections(&self) -> ::std::vec::Vec<tui_dispatch::debug::DebugSection> {
765 ::std::vec![
766 #(#section_code),*
767 ]
768 }
769 }
770 };
771
772 TokenStream::from(expanded)
773}