1extern crate proc_macro;
2use proc_macro::TokenStream;
3use quote::quote;
4use std::collections::HashMap;
5use syn::{
6 parse_macro_input, punctuated::Punctuated, token::Comma, ItemEnum, ItemStruct, Meta, Variant,
7};
8
9fn generate_target_enum(input: TokenStream) -> TokenStream {
13 let item = parse_macro_input!(input as ItemEnum);
14 let enum_name = &item.ident;
15 let original_variants = &item.variants;
16
17 let output = quote! {
18 #[derive(strum::EnumString, strum::EnumIter, Default, strum::Display, Clone, PartialEq, clap::ValueEnum)]
19 #[strum(serialize_all = "lowercase")]
20 pub enum #enum_name {
21 #[doc = r"Targets all crates and examples using cargo --package."]
22 AllPackages,
23 #[doc = r"Targets all binary and library crates."]
24 Crates,
25 #[doc = r"Targets all example crates."]
26 Examples,
27 #[default]
28 #[doc = r"Targets the whole workspace using cargo --workspace."]
29 Workspace,
30 #original_variants
31 }
32 };
33 TokenStream::from(output)
34}
35
36fn generate_target_tryinto(_args: TokenStream, input: TokenStream) -> TokenStream {
37 let item = parse_macro_input!(input as ItemEnum);
38 let item_ident = &item.ident;
39 let tryinto = quote! {
40 impl std::convert::TryInto<tracel_xtask::commands::Target> for #item_ident {
41 type Error = anyhow::Error;
42 fn try_into(self) -> Result<tracel_xtask::commands::Target, Self::Error> {
43 match self {
44 #item_ident::AllPackages => Ok(tracel_xtask::commands::Target::AllPackages),
45 #item_ident::Crates => Ok(tracel_xtask::commands::Target::Crates),
46 #item_ident::Examples => Ok(tracel_xtask::commands::Target::Examples),
47 #item_ident::Workspace => Ok(tracel_xtask::commands::Target::Workspace),
48 _ => Err(anyhow::anyhow!("{} target is not supported.", self))
49 }
50 }
51 }
52 };
53 TokenStream::from(tryinto)
54}
55
56#[proc_macro_attribute]
57pub fn declare_targets(_args: TokenStream, input: TokenStream) -> TokenStream {
58 generate_target_enum(input)
59}
60
61#[proc_macro_attribute]
62pub fn extend_targets(args: TokenStream, input: TokenStream) -> TokenStream {
63 let mut output = generate_target_enum(input);
64 output.extend(generate_target_tryinto(args, output.clone()));
65 output
66}
67
68fn generate_dispatch_function(
72 enum_ident: &syn::Ident,
73 args: &Punctuated<Meta, Comma>,
74) -> TokenStream {
75 let arms: Vec<proc_macro2::TokenStream> = args.iter().map(|meta| {
76 let cmd_ident = meta.path().get_ident().unwrap();
77 let cmd_ident_string = cmd_ident.to_string();
78 let module_ident = syn::Ident::new(cmd_ident_string.to_lowercase().as_str(), cmd_ident.span());
79 match cmd_ident_string.as_str() {
80 "Fix" => quote! {
81 #enum_ident::#cmd_ident(cmd_args) => base_commands::#module_ident::handle_command(cmd_args, args.environment, args.context, None),
82 },
83 _ => quote! {
84 #enum_ident::#cmd_ident(cmd_args) => base_commands::#module_ident::handle_command(cmd_args, args.environment, args.context),
85 }
86 }
87 }).collect();
88 let func = quote! {
89 fn dispatch_base_commands(args: XtaskArgs<Command>) -> anyhow::Result<()> {
90 match args.command {
91 #(#arms)*
92 _ => Err(anyhow::anyhow!("Unknown command")),
93 }
94 }
95 };
96 TokenStream::from(func)
97}
98
99#[proc_macro_attribute]
100pub fn base_commands(args: TokenStream, input: TokenStream) -> TokenStream {
101 let item = parse_macro_input!(input as ItemEnum);
103 let args = parse_macro_input!(args with Punctuated::<Meta, Comma>::parse_terminated);
104
105 let mut variant_map: HashMap<&str, proc_macro2::TokenStream> = HashMap::new();
107 variant_map.insert(
108 "Build",
109 quote! {
110 #[doc = r"Build the code."]
111 Build(tracel_xtask::commands::build::BuildCmdArgs)
112 },
113 );
114 variant_map.insert(
115 "Bump",
116 quote! {
117 #[doc = r"Bump the version of all crates to be published."]
118 Bump(tracel_xtask::commands::bump::BumpCmdArgs)
119 },
120 );
121 variant_map.insert(
122 "Check",
123 quote! {
124 #[doc = r"Run checks like formatting, linting etc... This command only reports issues, use the 'fix' command to auto-fix issues."]
125 Check(tracel_xtask::commands::check::CheckCmdArgs)
126 },
127 );
128 variant_map.insert(
129 "Compile",
130 quote! {
131 #[doc = r"Compile check the code (does not write binaries to disk)."]
132 Compile(tracel_xtask::commands::compile::CompileCmdArgs)
133 },
134 );
135 variant_map.insert(
136 "Coverage",
137 quote! {
138 #[doc = r"Install and run coverage tools."]
139 Coverage(tracel_xtask::commands::coverage::CoverageCmdArgs)
140 },
141 );
142 variant_map.insert(
143 "Dependencies",
144 quote! {
145 #[doc = r"Run the specified dependencies check locally."]
146 Dependencies(tracel_xtask::commands::dependencies::DependenciesCmdArgs)
147 },
148 );
149 variant_map.insert(
150 "Doc",
151 quote! {
152 #[doc = r"Build documentation."]
153 Doc(tracel_xtask::commands::doc::DocCmdArgs)
154 },
155 );
156 variant_map.insert(
157 "Docker",
158 quote! {
159 #[doc = r"Manage docker compose stacks."]
160 Docker(tracel_xtask::commands::docker::DockerCmdArgs)
161 },
162 );
163 variant_map.insert(
164 "Fix",
165 quote! {
166 #[doc = r"Fix issues found with the 'check' command."]
167 Fix(tracel_xtask::commands::fix::FixCmdArgs)
168 },
169 );
170 variant_map.insert(
171 "Publish",
172 quote! {
173 #[doc = r"Publish a crate to crates.io."]
174 Publish(tracel_xtask::commands::publish::PublishCmdArgs)
175 },
176 );
177 variant_map.insert(
178 "Test",
179 quote! {
180 #[doc = r"Runs tests."]
181 Test(tracel_xtask::commands::test::TestCmdArgs)
182 },
183 );
184 variant_map.insert(
185 "Validate",
186 quote! {
187 #[doc = r"Validate the code base by running all the relevant checks and tests. Use this command before creating a new pull-request."]
188 Validate(tracel_xtask::commands::validate::ValidateCmdArgs)
189 },
190 );
191 variant_map.insert("Vulnerabilities", quote! {
192 #[doc = r"Run the specified vulnerability check locally. These commands must be called with 'cargo +nightly'."]
193 Vulnerabilities(tracel_xtask::commands::vulnerabilities::VulnerabilitiesCmdArgs)
194 });
195
196 let mut variants = vec![];
198 for arg in &args {
199 if let Meta::Path(path) = arg {
200 if let Some(ident) = path.get_ident() {
201 let ident_string = ident.to_string();
202 if let Some(variant) = variant_map.get(ident_string.as_str()) {
203 variants.push(variant.clone());
204 } else {
205 let err_msg = format!(
206 "Unknown command: {}\nPossible commands are:\n {}",
207 ident_string,
208 variant_map
209 .keys()
210 .cloned()
211 .collect::<Vec<&str>>()
212 .join("\n "),
213 );
214 return TokenStream::from(quote! {
215 compile_error!(#err_msg);
216 });
217 }
218 }
219 }
220 }
221
222 let enum_name = &item.ident;
224 let other_variants = &item.variants;
225 let mut output = TokenStream::from(quote! {
226 #[derive(clap::Subcommand)]
227 pub enum #enum_name {
228 #(#variants,)*
229 #other_variants
230 }
231 });
232 output.extend(generate_dispatch_function(enum_name, &args));
233 output
234}
235
236fn get_additional_cmd_args_map() -> HashMap<&'static str, proc_macro2::TokenStream> {
240 HashMap::from([
241 (
242 "BuildCmdArgs",
243 quote! {
244 #[doc = r"Build artifacs in release mode."]
245 #[arg(short, long, required = false)]
246 pub release: bool,
247 },
248 ),
249 (
250 "CheckCmdArgs",
251 quote! {
252 #[doc = r"Ignore audit errors."]
253 #[arg(long = "ignore-audit", required = false)]
254 pub ignore_audit: bool,
255 },
256 ),
257 (
258 "DockerCmdArgs",
259 quote! {
260 #[doc = r"Build images before starting containers."]
261 #[arg(short, long, required = false)]
262 pub build: bool,
263 #[doc = r"Project name."]
264 #[arg(short, long, default_value = "xtask")]
265 pub project: String,
266 #[doc = r"Space separated list of service subset to start. If empty then launch all the services in the stack."]
267 #[arg(short, long, num_args(1..), required = false)]
268 pub services: Vec<String>,
269 },
270 ),
271 (
272 "TestCmdArgs",
273 quote! {
274 #[doc = r"Execute only the test whose name matches the passed string."]
275 #[arg(
276 long = "test",
277 value_name = "TEST",
278 required = false
279 )]
280 pub test: Option<String>,
281 #[doc = r"Maximum number of parallel test crate compilations."]
282 #[arg(
283 long = "compilation-jobs",
284 value_name = "NUMBER OF THREADS",
285 required = false
286 )]
287 pub jobs: Option<u16>,
288 #[doc = r"Maximum number of parallel test within a test crate execution."]
289 #[arg(
290 long = "test-threads",
291 value_name = "NUMBER OF THREADS",
292 required = false
293 )]
294 pub threads: Option<u16>,
295 #[doc = r"Comma-separated list of features to enable during tests."]
296 #[arg(
297 long,
298 value_name = "FEATURE,FEATURE,...",
299 value_delimiter = ',',
300 required = false
301 )]
302 pub features: Option<Vec<String>>,
303 #[doc = r"If set, ignore default features."]
304 #[arg(
305 long = "no-default-features",
306 required = false
307 )]
308 pub no_default_features: bool,
309 #[doc = r"Force execution of tests no matter the environment (i.e. authorize to execute tests in prod)."]
310 #[arg(
311 short = 'f',
312 long = "force",
313 required = false
314 )]
315 pub force: bool,
316 #[doc = r"If set, test logs are sent to output."]
317 #[arg(long = "nocapture", required = false)]
318 pub no_capture: bool,
319 #[doc = r"Build test in release mode."]
320 #[arg(short = 'r', long = "release", required = false)]
321 pub release: bool,
322 },
323 ),
324 (
325 "ValidateCmdArgs",
326 quote! {
327 #[doc = r"Ignore audit errors."]
328 #[arg(long = "ignore-audit", required = false)]
329 pub ignore_audit: bool,
330 #[doc = r"Build in release mode."]
331 #[arg(short = 'r', long = "release", required = false)]
332 pub release: bool,
333 },
334 ),
335 ])
336}
337
338fn generate_command_args_struct(
340 args: TokenStream,
341 input: TokenStream,
342) -> (TokenStream, TokenStream) {
343 let item = match syn::parse::<ItemStruct>(input) {
344 Ok(data) => data,
345 Err(e) => return (TokenStream::from(e.to_compile_error()), TokenStream::new()),
346 };
347 let args = match syn::parse::Parser::parse(Punctuated::<Meta, Comma>::parse_terminated, args) {
348 Ok(data) => data,
349 Err(e) => return (TokenStream::from(e.to_compile_error()), TokenStream::new()),
350 };
351 let struct_name = &item.ident;
352 let original_fields = item.fields.iter().map(|f| {
353 let attrs = &f.attrs;
354 let vis = &f.vis;
355 let ident = &f.ident;
356 let ty = &f.ty;
357 quote! {
358 #(#attrs)*
359 #vis #ident: #ty
360 }
361 });
362
363 if args.is_empty() {
364 let struct_output = TokenStream::from(quote! {
365 #[derive(clap::Args, Clone)]
366 pub struct #struct_name {
367 #(#original_fields,)*
368 }
369 });
370 (struct_output, TokenStream::new())
371 } else {
372 let mut target_type: Option<Meta> = None;
373 let mut subcommand_type: Option<Meta> = None;
374 if args.len() == 2 {
375 let ty = args.get(0).unwrap();
377 if ty.path().get_ident().unwrap().to_string().as_str() != "None" {
378 target_type = Some(ty.clone());
379 }
380 let ty = args.get(1).unwrap();
381 if ty.path().get_ident().unwrap().to_string().as_str() != "None" {
382 subcommand_type = Some(ty.clone());
383 }
384 } else if args.len() == 3 {
385 let ty = args.get(1).unwrap();
387 if ty.path().get_ident().unwrap().to_string().as_str() != "None" {
388 target_type = Some(ty.clone());
389 }
390 let ty = args.get(2).unwrap();
391 if ty.path().get_ident().unwrap().to_string().as_str() != "None" {
392 subcommand_type = Some(ty.clone());
393 }
394 } else {
395 return (
396 TokenStream::from(quote! {
397 compile_error!("Error expanding macro.");
398 }),
399 TokenStream::new(),
400 );
401 };
402
403 let target_fields = if let Some(target) = target_type {
404 quote! {
405 #[doc = r"The target on which executing the command."]
406 #[arg(short, long, value_enum, default_value_t = #target::default())]
407 pub target: #target,
408 #[doc = r"Comma-separated list of excluded crates."]
409 #[arg(
410 short = 'x',
411 long,
412 value_name = "CRATE,CRATE,...",
413 value_delimiter = ',',
414 required = false
415 )]
416 pub exclude: Vec<String>,
417 #[doc = r"Comma-separated list of crates to include exclusively."]
418 #[arg(
419 short = 'n',
420 long,
421 value_name = "CRATE,CRATE,...",
422 value_delimiter = ',',
423 required = false
424 )]
425 pub only: Vec<String>,
426 }
427 } else {
428 quote! {}
429 };
430
431 let additional_cmd_args_map = get_additional_cmd_args_map();
432 let mut base_command_type = struct_name.to_string();
433 if args.len() == 3 {
434 base_command_type = args.get(0).unwrap().path().get_ident().unwrap().to_string();
435 }
436 let additional_fields = match additional_cmd_args_map.get(base_command_type.as_str()) {
437 Some(fields) => fields.clone(),
438 None => quote! {},
439 };
440
441 let (subcommand_field, subcommand_impl) = if let Some(subcommand) = subcommand_type.clone()
442 {
443 (
444 quote! {
445 #[command(subcommand)]
446 pub command: Option<#subcommand>,
447 },
448 quote! {
449 impl #struct_name {
450 pub fn get_command(&self) -> #subcommand {
451 self.command.clone().unwrap_or_default()
452 }
453 }
454 },
455 )
456 } else {
457 (quote! {}, quote! {})
458 };
459
460 let struct_output = TokenStream::from(quote! {
461 #[derive(clap::Args, Clone)]
462 pub struct #struct_name {
463 #target_fields
464 #additional_fields
465 #subcommand_field
466 #(#original_fields,)*
467 }
468 });
469 let mut additional_output = TokenStream::from(quote! {
470 #subcommand_impl
471 });
472 if args.len() == 2 {
474 if let Some(subcommand) = subcommand_type {
475 let subcommand_ident = subcommand.path().get_ident().unwrap();
476 let subcommand_string = subcommand_ident.to_string();
477 let original_variants = Punctuated::<Variant, Comma>::new();
478 additional_output.extend(generate_subcommand_enum(
479 subcommand_string,
480 subcommand_ident,
481 &original_variants,
482 ));
483 }
484 }
485 (struct_output, additional_output)
486 }
487}
488
489fn generate_command_args_tryinto(args: TokenStream, input: TokenStream) -> TokenStream {
490 let args = parse_macro_input!(args with Punctuated::<Meta, Comma>::parse_terminated);
491 let base_type = args.get(0).unwrap();
492 let base_type_string = base_type.path().get_ident().unwrap().to_string();
493 let item = parse_macro_input!(input as ItemStruct);
494 let item_ident = &item.ident;
495 let has_target = item.fields.iter().any(|f| {
496 if let Some(ident) = &f.ident {
497 *ident == "target"
498 } else {
499 false
500 }
501 });
502 let subcommand_variant_map = get_subcommand_variant_map();
504 let base_subcommand_type_string = base_type_string.replace("CmdArgs", "SubCommand");
505 let has_subcommand = subcommand_variant_map.contains_key(base_subcommand_type_string.as_str())
506 && item.fields.iter().any(|f| {
507 if let Some(ident) = &f.ident {
508 *ident == "command"
509 } else {
510 false
511 }
512 });
513
514 let target = if has_target {
516 quote! {
517 target: self.target.try_into()?,
518 }
519 } else {
520 quote! {}
521 };
522 let (subcommand_let, subcommand_assign) = if has_subcommand {
523 (
524 quote! {
525 let cmd = self.get_command().try_into()?;
526 },
527 quote! {
528 command: Some(cmd),
529 },
530 )
531 } else {
532 (quote! {}, quote! {})
533 };
534 let fields: Vec<_> = item
535 .fields
536 .iter()
537 .filter_map(|f| {
538 f.ident.as_ref().map(|ident| {
539 let ident_str = ident.to_string();
540 if ident_str != "target"
542 && (ident_str == "exclude"
543 || ident_str == "features"
544 || ident_str == "force"
545 || ident_str == "ignore_audit"
546 || ident_str == "jobs"
547 || ident_str == "no_default_features"
548 || ident_str == "no_capture"
549 || ident_str == "only"
550 || ident_str == "release"
551 || ident_str == "test"
552 || ident_str == "threads")
553 {
554 quote! { #ident: self.#ident, }
555 } else {
556 quote! {}
557 }
558 })
559 })
560 .collect();
561
562 let tryinto = quote! {
563 impl std::convert::TryInto<#base_type> for #item_ident {
564 type Error = anyhow::Error;
565 fn try_into(self) -> Result<#base_type, Self::Error> {
566 #subcommand_let
567 Ok(#base_type {
568 #target
569 #subcommand_assign
570 #(#fields)*
571 })
572 }
573 }
574 };
575 TokenStream::from(tryinto)
576}
577
578#[proc_macro_attribute]
579pub fn declare_command_args(args: TokenStream, input: TokenStream) -> TokenStream {
580 let args_clone = args.clone();
581 let parsed_args =
582 parse_macro_input!(args_clone with Punctuated::<Meta, Comma>::parse_terminated);
583 if parsed_args.len() == 2 {
584 let mut output: TokenStream = quote! {}.into();
585 let (struct_output, additional_output) = generate_command_args_struct(args, input);
586 output.extend(struct_output);
587 output.extend(additional_output);
588 output
589 } else {
590 let error_msg = r#"declare_commands_args macro takes 2 arguments.
591 First argument is the target type (None if there is no target).
592 Second argument is the subcommand type (None if there is no subcommand)."#;
593 TokenStream::from(quote! {compile_error!(#error_msg)})
594 }
595}
596
597#[proc_macro_attribute]
598pub fn extend_command_args(args: TokenStream, input: TokenStream) -> TokenStream {
599 let args_clone = args.clone();
600 let parsed_args =
601 parse_macro_input!(args_clone with Punctuated::<Meta, Comma>::parse_terminated);
602 if parsed_args.len() != 3 {
603 let error_msg = r#"extend_command_args takes three arguments.
604 First argument is the type of the base command arguments struct to extend.
605 Second argument is the target type (None if there is no target).
606 Third argument is the subcommand type (None if there is no subcommand)"#;
607 return TokenStream::from(quote! {compile_error!(#error_msg);});
608 }
609 let mut output: TokenStream = quote! {}.into();
610 let (struct_output, additional_output) = generate_command_args_struct(args.clone(), input);
611 let tryinto = generate_command_args_tryinto(args, struct_output.clone());
612 output.extend(struct_output);
613 output.extend(additional_output);
614 output.extend(tryinto);
615 output
616}
617
618fn get_subcommand_variant_map() -> HashMap<&'static str, proc_macro2::TokenStream> {
622 HashMap::from([
623 (
624 "BumpSubCommand",
625 quote! {
626 #[doc = r"Bump the major version (x.0.0)."]
627 Major,
628 #[doc = r"Bump the minor version (0.x.0)."]
629 Minor,
630 #[default]
631 #[doc = r"Bump the patch version (0.0.x)."]
632 Patch,
633 },
634 ),
635 (
636 "CheckSubCommand",
637 quote! {
638 #[default]
639 #[doc = r"Run all the checks."]
640 All,
641 #[doc = r"Run audit command."]
642 Audit,
643 #[doc = r"Run format command."]
644 Format,
645 #[doc = r"Run lint command."]
646 Lint,
647 #[doc = r"Report typos in source code."]
648 Typos,
649 },
650 ),
651 (
652 "CoverageSubCommand",
654 quote! {
655 #[doc = r"Install grcov and its dependencies."]
656 Install,
657 #[doc = r"Generate lcov.info file. [default with default debug profile]"]
658 Generate(GenerateCmdArgs),
659 },
660 ),
661 (
662 "DependenciesSubCommand",
663 quote! {
664 #[doc = r"Run all dependency checks."]
665 #[default]
666 All,
667 #[doc = r"Run cargo-deny Lint dependency graph to ensure all dependencies meet requirements `<https://crates.io/crates/cargo-deny>`. [default]"]
668 Deny,
669 #[doc = r"Run cargo-machete to find unused dependencies `<https://crates.io/crates/cargo-machete>`"]
670 Unused,
671 },
672 ),
673 (
674 "DocSubCommand",
675 quote! {
676 #[default]
677 #[doc = r"Build documentation."]
678 Build,
679 #[doc = r"Run documentation tests."]
680 Tests,
681 },
682 ),
683 (
684 "DockerSubCommand",
685 quote! {
686 #[default]
687 #[doc = r"Start docker compose stack."]
688 Up,
689 #[doc = r"Stop docker compose stack."]
690 Down,
691 },
692 ),
693 (
694 "FixSubCommand",
695 quote! {
696 #[default]
697 #[doc = r"Run all the checks."]
698 All,
699 #[doc = r"Run audit command."]
700 Audit,
701 #[doc = r"Run format command and fix formatting."]
702 Format,
703 #[doc = r"Run lint command and fix issues."]
704 Lint,
705 #[doc = r"Find typos in source code and fix them."]
706 Typos,
707 },
708 ),
709 (
710 "TestSubCommand",
711 quote! {
712 #[default]
713 #[doc = r"Run all the checks."]
714 All,
715 #[doc = r"Run unit tests."]
716 Unit,
717 #[doc = r"Run integration tests."]
718 Integration,
719 },
720 ),
721 (
722 "VulnerabilitiesSubCommand",
723 quote! {
724 #[default]
725 #[doc = r"Run all most useful vulnerability checks. [default]"]
726 All,
727 #[doc = r"Run Address sanitizer (memory error detector)"]
728 AddressSanitizer,
729 #[doc = r"Run LLVM Control Flow Integrity (CFI) (provides forward-edge control flow protection)"]
730 ControlFlowIntegrity,
731 #[doc = r"Run newer variant of Address sanitizer (memory error detector similar to AddressSanitizer, but based on partial hardware assistance)"]
732 HWAddressSanitizer,
733 #[doc = r"Run Kernel LLVM Control Flow Integrity (KCFI) (provides forward-edge control flow protection for operating systems kernels)"]
734 KernelControlFlowIntegrity,
735 #[doc = r"Run Leak sanitizer (run-time memory leak detector)"]
736 LeakSanitizer,
737 #[doc = r"Run memory sanitizer (detector of uninitialized reads)"]
738 MemorySanitizer,
739 #[doc = r"Run another address sanitizer (like AddressSanitizer and HardwareAddressSanitizer but with lower overhead suitable for use as hardening for production binaries)"]
740 MemTagSanitizer,
741 #[doc = r"Run nightly-only checks through cargo-careful `<https://crates.io/crates/cargo-careful>`"]
742 NightlyChecks,
743 #[doc = r"Run SafeStack check (provides backward-edge control flow protection by separating stack into safe and unsafe regions"]
744 SafeStack,
745 #[doc = r"Run ShadowCall check (provides backward-edge control flow protection - aarch64 only)"]
746 ShadowCallStack,
747 #[doc = r"Run Thread sanitizer (data race detector)"]
748 ThreadSanitizer,
749 },
750 ),
751 ])
752}
753
754fn generate_subcommand_enum(
755 subcommand: String,
756 enum_name: &syn::Ident,
757 original_variants: &Punctuated<Variant, Comma>,
758) -> TokenStream {
759 let variant_map = get_subcommand_variant_map();
760 let output = if let Some(variants) = variant_map.get(subcommand.as_str()) {
761 let variants_tokens = TokenStream::from(variants.clone());
763 let parsed_variants =
764 parse_macro_input!(variants_tokens with Punctuated::<Variant, Comma>::parse_terminated);
765 let default = if parsed_variants
766 .iter()
767 .any(|v| v.attrs.iter().any(|a| a.path().is_ident("default")))
768 {
769 quote! { Default }
770 } else {
771 quote! {}
772 };
773 quote! {
774 #[derive(strum::EnumString, strum::EnumIter, strum::Display, Clone, PartialEq, clap::Subcommand, #default)]
775 #[strum(serialize_all = "lowercase")]
776 pub enum #enum_name {
777 #variants
778 #original_variants
779 }
780 }
781 } else {
782 quote! {}
784 };
785 TokenStream::from(output)
786}
787
788fn generate_subcomand_tryinto(
789 base_subcommand: &syn::Ident,
790 subcommand: &syn::Ident,
791) -> TokenStream {
792 let variant_map = get_subcommand_variant_map();
793 let variants = variant_map
795 .get(base_subcommand.to_string().as_str())
796 .unwrap();
797 let variants_tokens = TokenStream::from(variants.clone());
799 let parsed_variants =
800 parse_macro_input!(variants_tokens with Punctuated::<Variant, Comma>::parse_terminated);
801 let arms = parsed_variants.iter().map(|v| {
802 let variant_ident = &v.ident;
803 quote! {
804 #subcommand::#variant_ident => Ok(#base_subcommand::#variant_ident),
805 }
806 });
807 let tryinto = quote! {
808 impl std::convert::TryInto<#base_subcommand> for #subcommand {
809 type Error = anyhow::Error;
810 fn try_into(self) -> Result<#base_subcommand, Self::Error> {
811 match self {
812 #(#arms)*
813 _ => Err(anyhow::anyhow!("{} target is not supported.", self))
814 }
815 }
816 }
817 };
818 TokenStream::from(tryinto)
819}
820
821#[proc_macro_attribute]
822pub fn extend_subcommands(args: TokenStream, input: TokenStream) -> TokenStream {
823 let item = parse_macro_input!(input as ItemEnum);
824 let args_clone = args.clone();
825 let parsed_args =
826 parse_macro_input!(args_clone with Punctuated::<Meta, Comma>::parse_terminated);
827 if parsed_args.len() != 1 {
828 return TokenStream::from(quote! {
829 compile_error!("extend_subcommand takes one argument which is the type of the subcommand enum.");
830 });
831 }
832 let base_subcommand = parsed_args.get(0).unwrap();
833 let base_subcommand_ident = base_subcommand.path().get_ident().unwrap();
834 let base_subcommand_string = base_subcommand_ident.to_string();
835 let subcommand_ident = &item.ident;
836 let original_variants = &item.variants;
837
838 let variant_map = get_subcommand_variant_map();
839 if !variant_map.contains_key(base_subcommand_string.as_str()) {
840 let err_msg = format!(
841 "Unknown command: {}\nPossible commands are:\n {}",
842 base_subcommand_string,
843 variant_map
844 .keys()
845 .cloned()
846 .collect::<Vec<&str>>()
847 .join("\n "),
848 );
849 return TokenStream::from(quote! { compile_error!(#err_msg); });
850 }
851 let mut output = generate_subcommand_enum(
852 base_subcommand_string.clone(),
853 subcommand_ident,
854 original_variants,
855 );
856 output.extend(generate_subcomand_tryinto(
857 base_subcommand_ident,
858 subcommand_ident,
859 ));
860 output
861}