// tusks_lib/codegen/handle_matches/arms/function.rs

use proc_macro2::{Span, TokenStream};
use quote::quote;

use crate::codegen::module_path::ModulePath;
use crate::codegen::util::enum_util::to_variant_ident;
use crate::codegen::util::field_util::is_generated_field;
use crate::{TusksModule, models::Tusk};

9impl TusksModule {
10    /// Generates a match arm for a command function.
11    ///
12    /// For `fn my_func(arg1: String, arg2: i32)` this produces:
13    /// ```ignore
14    /// Some(cli::Commands::my_func { arg1: p1, arg2: p2 }) => {
15    ///     super::my_func(p1.clone(), p2.clone());
16    /// }
17    /// ```
18    pub fn build_function_match_arm(
19        &self,
20        tusk: &Tusk,
21        cli_path: &TokenStream,
22        path: &ModulePath,
23    ) -> TokenStream {
24        let variant_ident = to_variant_ident(&tusk.func.sig.ident);
25        let pattern_bindings = self.build_pattern_bindings(tusk);
26        let pattern_fields = self.build_pattern_fields(&pattern_bindings);
27        let function_call = self.build_function_call(tusk, &pattern_bindings, path, false, false);
28
29        quote! {
30            Some(#cli_path::Commands::#variant_ident { #(#pattern_fields),* }) => {
31                #function_call
32            }
33        }
34    }
35
36    pub fn build_default_function_match_arm(
37        &self,
38        tusk: &Tusk,
39        path: &ModulePath,
40        is_external_subcommand_case: bool,
41    ) -> TokenStream {
42        let pattern_bindings = self.build_pattern_bindings(tusk);
43        let function_call = self.build_function_call(
44            tusk, &pattern_bindings, path, true, is_external_subcommand_case,
45        );
46
47        quote! {
48            None => {
49                #function_call
50            }
51        }
52    }
53
54    pub fn build_external_subcommand_match_arm(
55        &self,
56        tusk: &Tusk,
57        path: &ModulePath,
58    ) -> TokenStream {
59        let pattern_bindings = self.build_pattern_bindings(tusk);
60        let function_call = self.build_function_call(
61            tusk, &pattern_bindings, path, false, true,
62        );
63
64        let cli_path = path.cli_path();
65
66        quote! {
67            Some(#cli_path::Commands::ClapExternalSubcommand(external_subcommand_args)) => {
68                #function_call
69            }
70        }
71    }
72
73    fn build_function_call(
74        &self,
75        tusk: &Tusk,
76        pattern_bindings: &[(syn::Ident, syn::Ident)],
77        path: &ModulePath,
78        is_default_case: bool,
79        is_external_subcommand_case: bool,
80    ) -> TokenStream {
81        let func_args = self.build_function_arguments(
82            tusk, pattern_bindings, is_default_case, is_external_subcommand_case,
83        );
84        let func_name = &tusk.func.sig.ident;
85        let func_path = path.super_path_to(func_name);
86        let maybe_await = if tusk.is_async { quote! { .await } } else { quote! {} };
87
88        match &tusk.func.sig.output {
89            syn::ReturnType::Default => {
90                quote! { #func_path(#(#func_args),*)#maybe_await; None }
91            }
92            syn::ReturnType::Type(_, ty) => {
93                if Tusk::is_u8_type(ty) {
94                    quote! { Some(#func_path(#(#func_args),*)#maybe_await) }
95                } else if Tusk::is_option_u8_type(ty) {
96                    quote! { #func_path(#(#func_args),*)#maybe_await }
97                } else {
98                    quote! { None }
99                }
100            }
101        }
102    }
103
104    /// Creates bindings `[(field_name, p1), (field_name, p2), ...]`
105    /// for function parameters, skipping the first if it's `&Parameters`.
106    fn build_pattern_bindings(&self, tusk: &Tusk) -> Vec<(syn::Ident, syn::Ident)> {
107        let skip = if self.tusk_has_parameters_arg(tusk) { 1 } else { 0 };
108        let mut bindings = Vec::new();
109        let mut counter = 1;
110
111        for param in tusk.func.sig.inputs.iter().skip(skip) {
112            if let syn::FnArg::Typed(pat_type) = param {
113                if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
114                    let binding = syn::Ident::new(&format!("p{}", counter), Span::call_site());
115                    bindings.push((pat_ident.ident.clone(), binding));
116                    counter += 1;
117                }
118            }
119        }
120
121        bindings
122    }
123
124    /// Converts bindings to pattern fields `[field: p1, field: p2, ...]`,
125    /// filtering out generated fields.
126    pub fn build_pattern_fields(
127        &self,
128        pattern_bindings: &[(syn::Ident, syn::Ident)],
129    ) -> Vec<TokenStream> {
130        pattern_bindings
131            .iter()
132            .filter(|(name, _)| !is_generated_field(&name.to_string()))
133            .map(|(name, binding)| quote! { #name: #binding })
134            .collect()
135    }
136
137    fn build_function_arguments(
138        &self,
139        tusk: &Tusk,
140        pattern_bindings: &[(syn::Ident, syn::Ident)],
141        is_default_case: bool,
142        is_external_subcommand_case: bool,
143    ) -> Vec<TokenStream> {
144        let has_params_arg = self.tusk_has_parameters_arg(tusk);
145        let mut func_args = Vec::new();
146
147        let mut number_of_non_params_args = tusk.func.sig.inputs.len();
148        if has_params_arg {
149            func_args.push(quote! { &parameters });
150            number_of_non_params_args -= 1;
151        }
152
153        if is_default_case {
154            if is_external_subcommand_case {
155                func_args.push(quote! { Vec::new() });
156            }
157            return func_args;
158        }
159
160        if is_external_subcommand_case && number_of_non_params_args > 0 {
161            func_args.push(quote! { external_subcommand_args.clone() });
162            return func_args;
163        }
164
165        for (_, binding_name) in pattern_bindings {
166            func_args.push(quote! { #binding_name.clone() });
167        }
168
169        func_args
170    }
171}