tracel_xtask_macros/
lib.rs

1extern crate proc_macro;
2use proc_macro::TokenStream;
3use quote::quote;
4use std::collections::HashMap;
5use syn::{
6    parse_macro_input, punctuated::Punctuated, token::Comma, ItemEnum, ItemStruct, Meta, Variant,
7};
8
9// Targets
10// =======
11
12fn generate_target_enum(input: TokenStream) -> TokenStream {
13    let item = parse_macro_input!(input as ItemEnum);
14    let enum_name = &item.ident;
15    let original_variants = &item.variants;
16
17    let output = quote! {
18        #[derive(strum::EnumString, strum::EnumIter, Default, strum::Display, Clone, PartialEq, clap::ValueEnum)]
19        #[strum(serialize_all = "lowercase")]
20        pub enum #enum_name {
21            #[doc = r"Targets all crates and examples using cargo --package."]
22            AllPackages,
23            #[doc = r"Targets all binary and library crates."]
24            Crates,
25            #[doc = r"Targets all example crates."]
26            Examples,
27            #[default]
28            #[doc = r"Targets the whole workspace using cargo --workspace."]
29            Workspace,
30            #original_variants
31        }
32    };
33    TokenStream::from(output)
34}
35
36fn generate_target_tryinto(_args: TokenStream, input: TokenStream) -> TokenStream {
37    let item = parse_macro_input!(input as ItemEnum);
38    let item_ident = &item.ident;
39    let tryinto = quote! {
40        impl std::convert::TryInto<tracel_xtask::commands::Target> for #item_ident {
41            type Error = anyhow::Error;
42            fn try_into(self) -> Result<tracel_xtask::commands::Target, Self::Error> {
43                match self {
44                    #item_ident::AllPackages => Ok(tracel_xtask::commands::Target::AllPackages),
45                    #item_ident::Crates => Ok(tracel_xtask::commands::Target::Crates),
46                    #item_ident::Examples => Ok(tracel_xtask::commands::Target::Examples),
47                    #item_ident::Workspace => Ok(tracel_xtask::commands::Target::Workspace),
48                    _ => Err(anyhow::anyhow!("{} target is not supported.", self))
49                }
50            }
51        }
52    };
53    TokenStream::from(tryinto)
54}
55
56#[proc_macro_attribute]
57pub fn declare_targets(_args: TokenStream, input: TokenStream) -> TokenStream {
58    generate_target_enum(input)
59}
60
61#[proc_macro_attribute]
62pub fn extend_targets(args: TokenStream, input: TokenStream) -> TokenStream {
63    let mut output = generate_target_enum(input);
64    output.extend(generate_target_tryinto(args, output.clone()));
65    output
66}
67
68// Commands
69// ========
70
71fn generate_dispatch_function(
72    enum_ident: &syn::Ident,
73    args: &Punctuated<Meta, Comma>,
74) -> TokenStream {
75    let arms: Vec<proc_macro2::TokenStream> = args.iter().map(|meta| {
76        let cmd_ident = meta.path().get_ident().unwrap();
77        let cmd_ident_string = cmd_ident.to_string();
78        let module_ident = syn::Ident::new(cmd_ident_string.to_lowercase().as_str(), cmd_ident.span());
79        match cmd_ident_string.as_str() {
80            "Fix" => quote! {
81                #enum_ident::#cmd_ident(cmd_args) => base_commands::#module_ident::handle_command(cmd_args, args.environment, args.context, None),
82            },
83            _ => quote! {
84                #enum_ident::#cmd_ident(cmd_args) => base_commands::#module_ident::handle_command(cmd_args, args.environment, args.context),
85            }
86        }
87    }).collect();
88    let func = quote! {
89        fn dispatch_base_commands(args: XtaskArgs<Command>) -> anyhow::Result<()> {
90            match args.command {
91                #(#arms)*
92                _ => Err(anyhow::anyhow!("Unknown command")),
93            }
94        }
95    };
96    TokenStream::from(func)
97}
98
99#[proc_macro_attribute]
100pub fn base_commands(args: TokenStream, input: TokenStream) -> TokenStream {
101    // Parse the input tokens into a syntax tree
102    let item = parse_macro_input!(input as ItemEnum);
103    let args = parse_macro_input!(args with Punctuated::<Meta, Comma>::parse_terminated);
104
105    // Supported commands and their quoted expansions
106    let mut variant_map: HashMap<&str, proc_macro2::TokenStream> = HashMap::new();
107    variant_map.insert(
108        "Build",
109        quote! {
110            #[doc = r"Build the code."]
111            Build(tracel_xtask::commands::build::BuildCmdArgs)
112        },
113    );
114    variant_map.insert(
115        "Bump",
116        quote! {
117            #[doc = r"Bump the version of all crates to be published."]
118            Bump(tracel_xtask::commands::bump::BumpCmdArgs)
119        },
120    );
121    variant_map.insert(
122        "Check",
123        quote! {
124            #[doc = r"Run checks like formatting, linting etc... This command only reports issues, use the 'fix' command to auto-fix issues."]
125            Check(tracel_xtask::commands::check::CheckCmdArgs)
126        },
127    );
128    variant_map.insert(
129        "Compile",
130        quote! {
131            #[doc = r"Compile check the code (does not write binaries to disk)."]
132            Compile(tracel_xtask::commands::compile::CompileCmdArgs)
133        },
134    );
135    variant_map.insert(
136        "Coverage",
137        quote! {
138            #[doc = r"Install and run coverage tools."]
139            Coverage(tracel_xtask::commands::coverage::CoverageCmdArgs)
140        },
141    );
142    variant_map.insert(
143        "Dependencies",
144        quote! {
145            #[doc = r"Run the specified dependencies check locally."]
146            Dependencies(tracel_xtask::commands::dependencies::DependenciesCmdArgs)
147        },
148    );
149    variant_map.insert(
150        "Doc",
151        quote! {
152            #[doc = r"Build documentation."]
153            Doc(tracel_xtask::commands::doc::DocCmdArgs)
154        },
155    );
156    variant_map.insert(
157        "Docker",
158        quote! {
159            #[doc = r"Manage docker compose stacks."]
160            Docker(tracel_xtask::commands::docker::DockerCmdArgs)
161        },
162    );
163    variant_map.insert(
164        "Fix",
165        quote! {
166            #[doc = r"Fix issues found with the 'check' command."]
167            Fix(tracel_xtask::commands::fix::FixCmdArgs)
168        },
169    );
170    variant_map.insert(
171        "Publish",
172        quote! {
173            #[doc = r"Publish a crate to crates.io."]
174            Publish(tracel_xtask::commands::publish::PublishCmdArgs)
175        },
176    );
177    variant_map.insert(
178        "Test",
179        quote! {
180            #[doc = r"Runs tests."]
181            Test(tracel_xtask::commands::test::TestCmdArgs)
182        },
183    );
184    variant_map.insert(
185        "Validate",
186        quote! {
187            #[doc = r"Validate the code base by running all the relevant checks and tests. Use this command before creating a new pull-request."]
188            Validate(tracel_xtask::commands::validate::ValidateCmdArgs)
189        },
190    );
191    variant_map.insert("Vulnerabilities", quote! {
192        #[doc = r"Run the specified vulnerability check locally. These commands must be called with 'cargo +nightly'."]
193        Vulnerabilities(tracel_xtask::commands::vulnerabilities::VulnerabilitiesCmdArgs)
194    });
195
196    // Generate the corresponding enum variant
197    let mut variants = vec![];
198    for arg in &args {
199        if let Meta::Path(path) = arg {
200            if let Some(ident) = path.get_ident() {
201                let ident_string = ident.to_string();
202                if let Some(variant) = variant_map.get(ident_string.as_str()) {
203                    variants.push(variant.clone());
204                } else {
205                    let err_msg = format!(
206                        "Unknown command: {}\nPossible commands are:\n  {}",
207                        ident_string,
208                        variant_map
209                            .keys()
210                            .cloned()
211                            .collect::<Vec<&str>>()
212                            .join("\n  "),
213                    );
214                    return TokenStream::from(quote! {
215                        compile_error!(#err_msg);
216                    });
217                }
218            }
219        }
220    }
221
222    // Generate the xtask commands enum
223    let enum_name = &item.ident;
224    let other_variants = &item.variants;
225    let mut output = TokenStream::from(quote! {
226        #[derive(clap::Subcommand)]
227        pub enum #enum_name {
228            #(#variants,)*
229            #other_variants
230        }
231    });
232    output.extend(generate_dispatch_function(enum_name, &args));
233    output
234}
235
236// Command arguments
237// =================
238
239fn get_additional_cmd_args_map() -> HashMap<&'static str, proc_macro2::TokenStream> {
240    HashMap::from([
241        (
242            "BuildCmdArgs",
243            quote! {
244                #[doc = r"Build artifacs in release mode."]
245                #[arg(short, long, required = false)]
246                pub release: bool,
247            },
248        ),
249        (
250            "CheckCmdArgs",
251            quote! {
252                #[doc = r"Ignore audit errors."]
253                #[arg(long = "ignore-audit", required = false)]
254                pub ignore_audit: bool,
255            },
256        ),
257        (
258            "DockerCmdArgs",
259            quote! {
260                #[doc = r"Build images before starting containers."]
261                #[arg(short, long, required = false)]
262                pub build: bool,
263                #[doc = r"Project name."]
264                #[arg(short, long, default_value = "xtask")]
265                pub project: String,
266                #[doc = r"Space separated list of service subset to start. If empty then launch all the services in the stack."]
267                #[arg(short, long, num_args(1..), required = false)]
268                pub services: Vec<String>,
269            },
270        ),
271        (
272            "TestCmdArgs",
273            quote! {
274                #[doc = r"Execute only the test whose name matches the passed string."]
275                #[arg(
276                    long = "test",
277                    value_name = "TEST",
278                    required = false
279                )]
280                pub test: Option<String>,
281                #[doc = r"Maximum number of parallel test crate compilations."]
282                #[arg(
283                    long = "compilation-jobs",
284                    value_name = "NUMBER OF THREADS",
285                    required = false
286                )]
287                pub jobs: Option<u16>,
288                #[doc = r"Maximum number of parallel test within a test crate execution."]
289                #[arg(
290                    long = "test-threads",
291                    value_name = "NUMBER OF THREADS",
292                    required = false
293                )]
294                pub threads: Option<u16>,
295                #[doc = r"Comma-separated list of features to enable during tests."]
296                #[arg(
297                    long,
298                    value_name = "FEATURE,FEATURE,...",
299                    value_delimiter = ',',
300                    required = false
301                )]
302                pub features: Option<Vec<String>>,
303                #[doc = r"If set, ignore default features."]
304                #[arg(
305                    long = "no-default-features",
306                    required = false
307                )]
308                pub no_default_features: bool,
309                #[doc = r"Force execution of tests no matter the environment (i.e. authorize to execute tests in prod)."]
310                #[arg(
311                    short = 'f',
312                    long = "force",
313                    required = false
314                )]
315                pub force: bool,
316                #[doc = r"If set, test logs are sent to output."]
317                #[arg(long = "nocapture", required = false)]
318                pub no_capture: bool,
319                #[doc = r"Build test in release mode."]
320                #[arg(short = 'r', long = "release", required = false)]
321                pub release: bool,
322            },
323        ),
324        (
325            "ValidateCmdArgs",
326            quote! {
327                #[doc = r"Ignore audit errors."]
328                #[arg(long = "ignore-audit", required = false)]
329                pub ignore_audit: bool,
330                #[doc = r"Build in release mode."]
331                #[arg(short = 'r', long = "release", required = false)]
332                pub release: bool,
333            },
334        ),
335    ])
336}
337
338// Returns a tuple where 0 is the actual struct and 1 is additional implementations
339fn generate_command_args_struct(
340    args: TokenStream,
341    input: TokenStream,
342) -> (TokenStream, TokenStream) {
343    let item = match syn::parse::<ItemStruct>(input) {
344        Ok(data) => data,
345        Err(e) => return (TokenStream::from(e.to_compile_error()), TokenStream::new()),
346    };
347    let args = match syn::parse::Parser::parse(Punctuated::<Meta, Comma>::parse_terminated, args) {
348        Ok(data) => data,
349        Err(e) => return (TokenStream::from(e.to_compile_error()), TokenStream::new()),
350    };
351    let struct_name = &item.ident;
352    let original_fields = item.fields.iter().map(|f| {
353        let attrs = &f.attrs;
354        let vis = &f.vis;
355        let ident = &f.ident;
356        let ty = &f.ty;
357        quote! {
358            #(#attrs)*
359            #vis #ident: #ty
360        }
361    });
362
363    if args.is_empty() {
364        let struct_output = TokenStream::from(quote! {
365            #[derive(clap::Args, Clone)]
366            pub struct #struct_name {
367                #(#original_fields,)*
368            }
369        });
370        (struct_output, TokenStream::new())
371    } else {
372        let mut target_type: Option<Meta> = None;
373        let mut subcommand_type: Option<Meta> = None;
374        if args.len() == 2 {
375            // from declare_command_args
376            let ty = args.get(0).unwrap();
377            if ty.path().get_ident().unwrap().to_string().as_str() != "None" {
378                target_type = Some(ty.clone());
379            }
380            let ty = args.get(1).unwrap();
381            if ty.path().get_ident().unwrap().to_string().as_str() != "None" {
382                subcommand_type = Some(ty.clone());
383            }
384        } else if args.len() == 3 {
385            // from extend_command_args
386            let ty = args.get(1).unwrap();
387            if ty.path().get_ident().unwrap().to_string().as_str() != "None" {
388                target_type = Some(ty.clone());
389            }
390            let ty = args.get(2).unwrap();
391            if ty.path().get_ident().unwrap().to_string().as_str() != "None" {
392                subcommand_type = Some(ty.clone());
393            }
394        } else {
395            return (
396                TokenStream::from(quote! {
397                    compile_error!("Error expanding macro.");
398                }),
399                TokenStream::new(),
400            );
401        };
402
403        let target_fields = if let Some(target) = target_type {
404            quote! {
405                #[doc = r"The target on which executing the command."]
406                #[arg(short, long, value_enum, default_value_t = #target::default())]
407                pub target: #target,
408                #[doc = r"Comma-separated list of excluded crates."]
409                #[arg(
410                    short = 'x',
411                    long,
412                    value_name = "CRATE,CRATE,...",
413                    value_delimiter = ',',
414                    required = false
415                )]
416                pub exclude: Vec<String>,
417                #[doc = r"Comma-separated list of crates to include exclusively."]
418                #[arg(
419                    short = 'n',
420                    long,
421                    value_name = "CRATE,CRATE,...",
422                    value_delimiter = ',',
423                    required = false
424                )]
425                pub only: Vec<String>,
426            }
427        } else {
428            quote! {}
429        };
430
431        let additional_cmd_args_map = get_additional_cmd_args_map();
432        let mut base_command_type = struct_name.to_string();
433        if args.len() == 3 {
434            base_command_type = args.get(0).unwrap().path().get_ident().unwrap().to_string();
435        }
436        let additional_fields = match additional_cmd_args_map.get(base_command_type.as_str()) {
437            Some(fields) => fields.clone(),
438            None => quote! {},
439        };
440
441        let (subcommand_field, subcommand_impl) = if let Some(subcommand) = subcommand_type.clone()
442        {
443            (
444                quote! {
445                    #[command(subcommand)]
446                    pub command: Option<#subcommand>,
447                },
448                quote! {
449                    impl #struct_name {
450                        pub fn get_command(&self) -> #subcommand {
451                            self.command.clone().unwrap_or_default()
452                        }
453                    }
454                },
455            )
456        } else {
457            (quote! {}, quote! {})
458        };
459
460        let struct_output = TokenStream::from(quote! {
461            #[derive(clap::Args, Clone)]
462            pub struct #struct_name {
463                #target_fields
464                #additional_fields
465                #subcommand_field
466                #(#original_fields,)*
467            }
468        });
469        let mut additional_output = TokenStream::from(quote! {
470            #subcommand_impl
471        });
472        // generate the subcommand enum only when it is declared
473        if args.len() == 2 {
474            if let Some(subcommand) = subcommand_type {
475                let subcommand_ident = subcommand.path().get_ident().unwrap();
476                let subcommand_string = subcommand_ident.to_string();
477                let original_variants = Punctuated::<Variant, Comma>::new();
478                additional_output.extend(generate_subcommand_enum(
479                    subcommand_string,
480                    subcommand_ident,
481                    &original_variants,
482                ));
483            }
484        }
485        (struct_output, additional_output)
486    }
487}
488
489fn generate_command_args_tryinto(args: TokenStream, input: TokenStream) -> TokenStream {
490    let args = parse_macro_input!(args with Punctuated::<Meta, Comma>::parse_terminated);
491    let base_type = args.get(0).unwrap();
492    let base_type_string = base_type.path().get_ident().unwrap().to_string();
493    let item = parse_macro_input!(input as ItemStruct);
494    let item_ident = &item.ident;
495    let has_target = item.fields.iter().any(|f| {
496        if let Some(ident) = &f.ident {
497            *ident == "target"
498        } else {
499            false
500        }
501    });
502    // check if the base command has subcommands
503    let subcommand_variant_map = get_subcommand_variant_map();
504    let base_subcommand_type_string = base_type_string.replace("CmdArgs", "SubCommand");
505    let has_subcommand = subcommand_variant_map.contains_key(base_subcommand_type_string.as_str())
506        && item.fields.iter().any(|f| {
507            if let Some(ident) = &f.ident {
508                *ident == "command"
509            } else {
510                false
511            }
512        });
513
514    // expand
515    let target = if has_target {
516        quote! {
517            target: self.target.try_into()?,
518        }
519    } else {
520        quote! {}
521    };
522    let (subcommand_let, subcommand_assign) = if has_subcommand {
523        (
524            quote! {
525                let cmd = self.get_command().try_into()?;
526            },
527            quote! {
528                command: Some(cmd),
529            },
530        )
531    } else {
532        (quote! {}, quote! {})
533    };
534    let fields: Vec<_> = item
535        .fields
536        .iter()
537        .filter_map(|f| {
538            f.ident.as_ref().map(|ident| {
539                let ident_str = ident.to_string();
540                // TODO this hardcoded predicate is awful, find a way to make this better
541                if ident_str != "target"
542                    && (ident_str == "exclude"
543                        || ident_str == "features"
544                        || ident_str == "force"
545                        || ident_str == "ignore_audit"
546                        || ident_str == "jobs"
547                        || ident_str == "no_default_features"
548                        || ident_str == "no_capture"
549                        || ident_str == "only"
550                        || ident_str == "release"
551                        || ident_str == "test"
552                        || ident_str == "threads")
553                {
554                    quote! { #ident: self.#ident, }
555                } else {
556                    quote! {}
557                }
558            })
559        })
560        .collect();
561
562    let tryinto = quote! {
563        impl std::convert::TryInto<#base_type> for #item_ident {
564            type Error = anyhow::Error;
565            fn try_into(self) -> Result<#base_type, Self::Error> {
566                #subcommand_let
567                Ok(#base_type {
568                    #target
569                    #subcommand_assign
570                    #(#fields)*
571                })
572            }
573        }
574    };
575    TokenStream::from(tryinto)
576}
577
578#[proc_macro_attribute]
579pub fn declare_command_args(args: TokenStream, input: TokenStream) -> TokenStream {
580    let args_clone = args.clone();
581    let parsed_args =
582        parse_macro_input!(args_clone with Punctuated::<Meta, Comma>::parse_terminated);
583    if parsed_args.len() == 2 {
584        let mut output: TokenStream = quote! {}.into();
585        let (struct_output, additional_output) = generate_command_args_struct(args, input);
586        output.extend(struct_output);
587        output.extend(additional_output);
588        output
589    } else {
590        let error_msg = r#"declare_commands_args macro takes 2 arguments.
591 First argument is the target type (None if there is no target).
592 Second argument is the subcommand type (None if there is no subcommand)."#;
593        TokenStream::from(quote! {compile_error!(#error_msg)})
594    }
595}
596
597#[proc_macro_attribute]
598pub fn extend_command_args(args: TokenStream, input: TokenStream) -> TokenStream {
599    let args_clone = args.clone();
600    let parsed_args =
601        parse_macro_input!(args_clone with Punctuated::<Meta, Comma>::parse_terminated);
602    if parsed_args.len() != 3 {
603        let error_msg = r#"extend_command_args takes three arguments.
604 First argument is the type of the base command arguments struct to extend.
605 Second argument is the target type (None if there is no target).
606 Third argument is the subcommand type (None if there is no subcommand)"#;
607        return TokenStream::from(quote! {compile_error!(#error_msg);});
608    }
609    let mut output: TokenStream = quote! {}.into();
610    let (struct_output, additional_output) = generate_command_args_struct(args.clone(), input);
611    let tryinto = generate_command_args_tryinto(args, struct_output.clone());
612    output.extend(struct_output);
613    output.extend(additional_output);
614    output.extend(tryinto);
615    output
616}
617
618// Subcommands
619// ===========
620
621fn get_subcommand_variant_map() -> HashMap<&'static str, proc_macro2::TokenStream> {
622    HashMap::from([
623        (
624            "BumpSubCommand",
625            quote! {
626                #[doc = r"Bump the major version (x.0.0)."]
627                Major,
628                #[doc = r"Bump the minor version (0.x.0)."]
629                Minor,
630                #[default]
631                #[doc = r"Bump the patch version (0.0.x)."]
632                Patch,
633            },
634        ),
635        (
636            "CheckSubCommand",
637            quote! {
638                #[default]
639                #[doc = r"Run all the checks."]
640                All,
641                #[doc = r"Run audit command."]
642                Audit,
643                #[doc = r"Run format command."]
644                Format,
645                #[doc = r"Run lint command."]
646                Lint,
647                #[doc = r"Report typos in source code."]
648                Typos,
649            },
650        ),
651        (
652            // note: default is manually implemented for this subcommand as the default variant is not a unit variant.
653            "CoverageSubCommand",
654            quote! {
655                #[doc = r"Install grcov and its dependencies."]
656                Install,
657                #[doc = r"Generate lcov.info file. [default with default debug profile]"]
658                Generate(GenerateCmdArgs),
659            },
660        ),
661        (
662            "DependenciesSubCommand",
663            quote! {
664                #[doc = r"Run all dependency checks."]
665                #[default]
666                All,
667                #[doc = r"Run cargo-deny Lint dependency graph to ensure all dependencies meet requirements `<https://crates.io/crates/cargo-deny>`. [default]"]
668                Deny,
669                #[doc = r"Run cargo-machete to find unused dependencies `<https://crates.io/crates/cargo-machete>`"]
670                Unused,
671            },
672        ),
673        (
674            "DocSubCommand",
675            quote! {
676                #[default]
677                #[doc = r"Build documentation."]
678                Build,
679                #[doc = r"Run documentation tests."]
680                Tests,
681            },
682        ),
683        (
684            "DockerSubCommand",
685            quote! {
686                #[default]
687                #[doc = r"Start docker compose stack."]
688                Up,
689                #[doc = r"Stop docker compose stack."]
690                Down,
691            },
692        ),
693        (
694            "FixSubCommand",
695            quote! {
696                #[default]
697                #[doc = r"Run all the checks."]
698                All,
699                #[doc = r"Run audit command."]
700                Audit,
701                #[doc = r"Run format command and fix formatting."]
702                Format,
703                #[doc = r"Run lint command and fix issues."]
704                Lint,
705                #[doc = r"Find typos in source code and fix them."]
706                Typos,
707            },
708        ),
709        (
710            "TestSubCommand",
711            quote! {
712                #[default]
713                #[doc = r"Run all the checks."]
714                All,
715                #[doc = r"Run unit tests."]
716                Unit,
717                #[doc = r"Run integration tests."]
718                Integration,
719            },
720        ),
721        (
722            "VulnerabilitiesSubCommand",
723            quote! {
724                #[default]
725                #[doc = r"Run all most useful vulnerability checks. [default]"]
726                All,
727                #[doc = r"Run Address sanitizer (memory error detector)"]
728                AddressSanitizer,
729                #[doc = r"Run LLVM Control Flow Integrity (CFI) (provides forward-edge control flow protection)"]
730                ControlFlowIntegrity,
731                #[doc = r"Run newer variant of Address sanitizer (memory error detector similar to AddressSanitizer, but based on partial hardware assistance)"]
732                HWAddressSanitizer,
733                #[doc = r"Run Kernel LLVM Control Flow Integrity (KCFI) (provides forward-edge control flow protection for operating systems kernels)"]
734                KernelControlFlowIntegrity,
735                #[doc = r"Run Leak sanitizer (run-time memory leak detector)"]
736                LeakSanitizer,
737                #[doc = r"Run memory sanitizer (detector of uninitialized reads)"]
738                MemorySanitizer,
739                #[doc = r"Run another address sanitizer (like AddressSanitizer and HardwareAddressSanitizer but with lower overhead suitable for use as hardening for production binaries)"]
740                MemTagSanitizer,
741                #[doc = r"Run nightly-only checks through cargo-careful `<https://crates.io/crates/cargo-careful>`"]
742                NightlyChecks,
743                #[doc = r"Run SafeStack check (provides backward-edge control flow protection by separating stack into safe and unsafe regions"]
744                SafeStack,
745                #[doc = r"Run ShadowCall check (provides backward-edge control flow protection - aarch64 only)"]
746                ShadowCallStack,
747                #[doc = r"Run Thread sanitizer (data race detector)"]
748                ThreadSanitizer,
749            },
750        ),
751    ])
752}
753
754fn generate_subcommand_enum(
755    subcommand: String,
756    enum_name: &syn::Ident,
757    original_variants: &Punctuated<Variant, Comma>,
758) -> TokenStream {
759    let variant_map = get_subcommand_variant_map();
760    let output = if let Some(variants) = variant_map.get(subcommand.as_str()) {
761        // parse the variant and look for a default attribute so that we add the default derive if required
762        let variants_tokens = TokenStream::from(variants.clone());
763        let parsed_variants =
764            parse_macro_input!(variants_tokens with Punctuated::<Variant, Comma>::parse_terminated);
765        let default = if parsed_variants
766            .iter()
767            .any(|v| v.attrs.iter().any(|a| a.path().is_ident("default")))
768        {
769            quote! { Default }
770        } else {
771            quote! {}
772        };
773        quote! {
774            #[derive(strum::EnumString, strum::EnumIter, strum::Display, Clone, PartialEq, clap::Subcommand, #default)]
775            #[strum(serialize_all = "lowercase")]
776            pub enum #enum_name {
777                #variants
778                #original_variants
779            }
780        }
781    } else {
782        // Subcommand not found return no tokens
783        quote! {}
784    };
785    TokenStream::from(output)
786}
787
788fn generate_subcomand_tryinto(
789    base_subcommand: &syn::Ident,
790    subcommand: &syn::Ident,
791) -> TokenStream {
792    let variant_map = get_subcommand_variant_map();
793    // check if variants exist is done by the caller
794    let variants = variant_map
795        .get(base_subcommand.to_string().as_str())
796        .unwrap();
797    // parse the variant and look for a default attribute so that we add the default derive if required
798    let variants_tokens = TokenStream::from(variants.clone());
799    let parsed_variants =
800        parse_macro_input!(variants_tokens with Punctuated::<Variant, Comma>::parse_terminated);
801    let arms = parsed_variants.iter().map(|v| {
802        let variant_ident = &v.ident;
803        quote! {
804            #subcommand::#variant_ident => Ok(#base_subcommand::#variant_ident),
805        }
806    });
807    let tryinto = quote! {
808        impl std::convert::TryInto<#base_subcommand> for #subcommand {
809            type Error = anyhow::Error;
810            fn try_into(self) -> Result<#base_subcommand, Self::Error> {
811                match self {
812                    #(#arms)*
813                    _ => Err(anyhow::anyhow!("{} target is not supported.", self))
814                }
815            }
816        }
817    };
818    TokenStream::from(tryinto)
819}
820
821#[proc_macro_attribute]
822pub fn extend_subcommands(args: TokenStream, input: TokenStream) -> TokenStream {
823    let item = parse_macro_input!(input as ItemEnum);
824    let args_clone = args.clone();
825    let parsed_args =
826        parse_macro_input!(args_clone with Punctuated::<Meta, Comma>::parse_terminated);
827    if parsed_args.len() != 1 {
828        return TokenStream::from(quote! {
829            compile_error!("extend_subcommand takes one argument which is the type of the subcommand enum.");
830        });
831    }
832    let base_subcommand = parsed_args.get(0).unwrap();
833    let base_subcommand_ident = base_subcommand.path().get_ident().unwrap();
834    let base_subcommand_string = base_subcommand_ident.to_string();
835    let subcommand_ident = &item.ident;
836    let original_variants = &item.variants;
837
838    let variant_map = get_subcommand_variant_map();
839    if !variant_map.contains_key(base_subcommand_string.as_str()) {
840        let err_msg = format!(
841            "Unknown command: {}\nPossible commands are:\n  {}",
842            base_subcommand_string,
843            variant_map
844                .keys()
845                .cloned()
846                .collect::<Vec<&str>>()
847                .join("\n  "),
848        );
849        return TokenStream::from(quote! { compile_error!(#err_msg); });
850    }
851    let mut output = generate_subcommand_enum(
852        base_subcommand_string.clone(),
853        subcommand_ident,
854        original_variants,
855    );
856    output.extend(generate_subcomand_tryinto(
857        base_subcommand_ident,
858        subcommand_ident,
859    ));
860    output
861}