1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use proc_macro2::Span;
5use quote::quote;
6use syn::Expr;
7use syn::Ident;
8use syn::Result;
9use syn::Token;
10use syn::parse::Parse;
11use syn::parse::ParseStream;
12use syn::parse_macro_input;
13use syn::punctuated::Punctuated;
14
15enum PathPattern {
16 Path {
17 segments: Vec<SegmentPattern>,
18 fully_qualified: bool,
19 },
20 Wildcard,
21}
22
23enum SegmentPattern {
24 Required(SegmentMatcher),
25 Optional(SegmentMatcher),
26 Binding {
27 span: Span,
28 name: Ident,
29 multi: bool,
30 },
31}
32
33struct SegmentMatcher {
34 ident: Ident,
35 args: Option<ArgumentPattern>,
36}
37
38enum ArgumentPattern {
39 AngleBracketed(Vec<GenericArgumentPattern>),
40}
41
42enum GenericArgumentPattern {
43 Argument(Box<SegmentMatcher>),
44 Wildcard,
45 Binding(Span, Ident),
46}
47
48struct PathMatchArms {
49 path: Expr,
50 arms: Vec<MatchArm>,
51}
52
53struct MatchArm {
54 patterns: Punctuated<PathPattern, Token![|]>,
55 _fat_arrow: Token![=>],
56 body: Expr,
57 _comma: Option<Token![,]>,
58}
59
60impl Parse for PathMatchArms {
61 fn parse(input: ParseStream) -> Result<Self> {
62 let path = input.parse()?;
63 input.parse::<Token![,]>()?;
64
65 let mut arms = Vec::new();
66 while !input.is_empty() {
67 arms.push(input.parse()?);
68 }
69
70 Ok(PathMatchArms { path, arms })
71 }
72}
73
74impl Parse for MatchArm {
75 fn parse(input: ParseStream) -> Result<Self> {
76 let patterns = Punctuated::parse_separated_nonempty(input)?;
77 let _fat_arrow = input.parse()?;
78 let body = input.parse()?;
79 let _comma = input.parse().ok();
80
81 Ok(MatchArm {
82 patterns,
83 _fat_arrow,
84 body,
85 _comma,
86 })
87 }
88}
89
90impl Parse for PathPattern {
91 fn parse(input: ParseStream) -> Result<Self> {
92 if input.peek(Token![_]) {
93 input.parse::<Token![_]>()?;
94 return Ok(PathPattern::Wildcard);
95 }
96
97 let mut segments = Vec::new();
98 let mut fully_qualified = false;
99
100 if input.peek(Token![::]) {
101 input.parse::<Token![::]>()?;
102 fully_qualified = true;
103 }
104
105 loop {
106 if input.peek(Token![$]) {
107 input.parse::<Token![$]>()?;
108 let mut name: Ident = input.parse()?;
109 let multi = if input.peek(Token![+]) {
110 input.parse::<Token![+]>()?;
111 true
112 } else {
113 false
114 };
115
116 let span = name.span();
117 name.set_span(Span::call_site());
118
119 segments.push(SegmentPattern::Binding { span, name, multi });
120
121 if input.peek(Token![::]) {
122 input.parse::<Token![::]>()?;
123 } else {
124 break;
125 }
126 continue;
127 }
128
129 let ident: Ident = input.parse()?;
130
131 let args = if input.peek(Token![<]) {
132 Some(input.parse()?)
133 } else {
134 None
135 };
136
137 let optional = if input.peek(Token![?]) {
138 input.parse::<Token![?]>()?;
139 true
140 } else {
141 false
142 };
143
144 let matcher = SegmentMatcher { ident, args };
145
146 segments.push(if optional {
147 SegmentPattern::Optional(matcher)
148 } else {
149 SegmentPattern::Required(matcher)
150 });
151
152 if input.peek(Token![::]) {
153 input.parse::<Token![::]>()?;
154 } else {
155 break;
156 }
157 }
158
159 Ok(PathPattern::Path {
160 segments,
161 fully_qualified,
162 })
163 }
164}
165
166impl Parse for ArgumentPattern {
167 fn parse(input: ParseStream) -> Result<Self> {
168 input.parse::<Token![<]>()?;
169 let mut args = Vec::new();
170
171 loop {
172 if input.peek(Token![_]) {
173 input.parse::<Token![_]>()?;
174 args.push(GenericArgumentPattern::Wildcard);
175 } else if input.peek(Token![$]) {
176 input.parse::<Token![$]>()?;
177 let mut name: Ident = input.parse()?;
178 let span = name.span();
179 name.set_span(Span::call_site());
180 args.push(GenericArgumentPattern::Binding(span, name));
181 } else {
182 let ident: Ident = input.parse()?;
183
184 let matcher = SegmentMatcher {
185 ident,
186 args: if input.peek(Token![<]) {
187 Some(input.parse()?)
188 } else {
189 None
190 },
191 };
192 args.push(GenericArgumentPattern::Argument(Box::new(matcher)));
193 }
194
195 if input.peek(Token![,]) {
196 input.parse::<Token![,]>()?;
197 } else {
198 break;
199 }
200 }
201
202 input.parse::<Token![>]>()?;
203 Ok(ArgumentPattern::AngleBracketed(args))
204 }
205}
206
207struct PathMatcher {
208 path_expr: Expr,
209 fully_qualified_check: proc_macro2::TokenStream,
210 length_check: proc_macro2::TokenStream,
211 segment_checks: Vec<proc_macro2::TokenStream>,
212 binding_names: Vec<(Span, Ident)>,
213}
214
215fn generate_path_matcher(
216 path_expr: &Expr,
217 patterns: &Punctuated<PathPattern, Token![|]>,
218) -> Vec<PathMatcher> {
219 let mut out = Vec::new();
220
221 for pattern in patterns {
222 match pattern {
223 PathPattern::Wildcard => {}
224 PathPattern::Path {
225 segments,
226 fully_qualified,
227 } => {
228 let mut binding_names = Vec::new();
229 for seg in segments {
230 match seg {
231 SegmentPattern::Binding { span, name, .. } => {
232 binding_names.push((span.clone(), name.clone()));
233 }
234 SegmentPattern::Required(matcher)
235 | SegmentPattern::Optional(matcher) => {
236 fn handle_segment_matcher(
237 binding_names: &mut Vec<(Span, Ident)>,
238 matcher: &SegmentMatcher,
239 ) {
240 if let Some(ArgumentPattern::AngleBracketed(args)) =
241 &matcher.args
242 {
243 for arg in args {
244 match arg {
245 GenericArgumentPattern::Binding(span, name) => {
246 binding_names.push((span.clone(), name.clone()));
247 }
248 GenericArgumentPattern::Argument(arg) => {
249 handle_segment_matcher(binding_names, &*arg);
250 }
251 GenericArgumentPattern::Wildcard => {}
252 }
253 }
254 }
255 }
256
257 handle_segment_matcher(&mut binding_names, matcher);
258 }
259 }
260 }
261
262 let mut required_segments = Vec::new();
263 let mut optional_segments = Vec::new();
264 let mut has_multi_binding = false;
265
266 for seg in segments {
267 match seg {
268 SegmentPattern::Required(matcher) => {
269 required_segments.push(matcher)
270 }
271 SegmentPattern::Optional(matcher) => {
272 optional_segments.push(matcher)
273 }
274 SegmentPattern::Binding { multi, .. } => {
275 if *multi {
276 has_multi_binding = true;
277 }
278 }
279 }
280 }
281
282 let min_len = segments
283 .iter()
284 .filter(|s| {
285 matches!(
286 s,
287 SegmentPattern::Required(_)
288 | SegmentPattern::Binding { multi: false, .. }
289 )
290 })
291 .count();
292
293 let max_len = if has_multi_binding {
294 None
295 } else {
296 Some(segments.len())
297 };
298
299 let mut segment_checks = Vec::new();
300
301 for (seg_idx, seg) in segments.iter().enumerate() {
302 match seg {
303 SegmentPattern::Binding { name, multi, .. } => {
304 if *multi {
305 let required_after = segments[seg_idx + 1..]
306 .iter()
307 .filter(|s| {
308 matches!(
309 s,
310 SegmentPattern::Required(_)
311 | SegmentPattern::Binding { multi: false, .. }
312 )
313 })
314 .count();
315
316 let check = quote! {
317 let __end_idx = __segments.len() - #required_after;
318 if __idx > __end_idx {
319 break false;
320 }
321 #name = Some(__segments.iter().skip(__idx).take(__end_idx - __idx).cloned().collect::<syn::punctuated::Punctuated<_, syn::Token![::]>>());
322 __idx = __end_idx;
323 };
324 segment_checks.push(check);
325 } else {
326 let check = quote! {
327 if __idx >= __segments.len() {
328 break false;
329 }
330 #name = Some(&__segments[__idx]);
331 __idx += 1;
332 };
333 segment_checks.push(check);
334 }
335 continue;
336 }
337 _ => {}
338 }
339
340 let (matcher, is_optional) = match seg {
341 SegmentPattern::Required(m) => (m, false),
342 SegmentPattern::Optional(m) => (m, true),
343 SegmentPattern::Binding { .. } => unreachable!(),
344 };
345
346 fn handle_segment_matcher(
347 matcher: &SegmentMatcher,
348 is_optional: bool,
349 ) -> proc_macro2::TokenStream {
350 let seg_ident = &matcher.ident;
351 let seg_ident_str = seg_ident.to_string();
352
353 let name_check = quote! {
354 __seg.ident == #seg_ident_str
355 };
356
357 let args_check = if let Some(args) = &matcher.args {
358 match args {
359 ArgumentPattern::AngleBracketed(arg_patterns) => {
360 let mut arg_checks = Vec::new();
361 for (arg_idx, arg_pattern) in arg_patterns.iter().enumerate()
362 {
363 match arg_pattern {
364 GenericArgumentPattern::Wildcard => {}
365 GenericArgumentPattern::Argument(segment_matcher) => {
366 fn generate_nested_arg_check(
367 segment_matcher: &SegmentMatcher,
368 arg_var: &str,
369 depth: usize,
370 ) -> proc_macro2::TokenStream {
371 let arg_ident = &segment_matcher.ident;
372 let arg_ident_str = arg_ident.to_string();
373 let arg_var_ident =
374 Ident::new(arg_var, Span::call_site());
375
376 if let Some(ArgumentPattern::AngleBracketed(
377 nested_arg_patterns,
378 )) = &segment_matcher.args
379 {
380 let mut nested_arg_checks = Vec::new();
381 for (nested_arg_idx, nested_arg_pattern) in
382 nested_arg_patterns.iter().enumerate()
383 {
384 match nested_arg_pattern {
385 GenericArgumentPattern::Binding(
386 _span,
387 name,
388 ) => {
389 let nested_arg_var =
390 format!("__nested_arg_{}", depth);
391 let nested_arg_var_ident = Ident::new(
392 &nested_arg_var,
393 Span::call_site(),
394 );
395 nested_arg_checks.push(quote! {
396 if let Some(#nested_arg_var_ident) = __nested_args.get(#nested_arg_idx) {
397 #name = Some(#nested_arg_var_ident);
398 } else {
399 break false;
400 }
401 });
402 }
403 GenericArgumentPattern::Wildcard => {}
404 GenericArgumentPattern::Argument(
405 inner_segment_matcher,
406 ) => {
407 let nested_arg_var =
408 format!("__nested_arg_{}", depth);
409 let inner_check = generate_nested_arg_check(
410 inner_segment_matcher,
411 &nested_arg_var,
412 depth + 1,
413 );
414 let nested_arg_var_ident = Ident::new(
415 &nested_arg_var,
416 Span::call_site(),
417 );
418 nested_arg_checks.push(quote! {
419 if let Some(#nested_arg_var_ident) = __nested_args.get(#nested_arg_idx) {
420 #inner_check
421 } else {
422 break false;
423 }
424 });
425 }
426 }
427 }
428
429 let nested_arg_count = nested_arg_patterns.len();
430 quote! {
431 if
432 let syn::GenericArgument::Type(syn::Type::Path(__nested_type_path)) = #arg_var_ident
433 && let Some(__nested_seg) = __nested_type_path.path.segments.last()
434 && __nested_seg.ident == #arg_ident_str
435 && let syn::PathArguments::AngleBracketed(__nested_angle_args) = &__nested_seg.arguments
436 {
437 let __nested_args = &__nested_angle_args.args;
438 if __nested_args.len() != #nested_arg_count {
439 break false;
440 }
441 #(#nested_arg_checks)*
442 } else {
443 break false;
444 }
445 }
446 } else {
447 quote! {
448 if
449 let syn::GenericArgument::Type(syn::Type::Path(__nested_type_path)) = #arg_var_ident
450 && let Some(__nested_seg) = __nested_type_path.path.segments.last()
451 && __nested_seg.ident == #arg_ident_str
452 {
453 } else {
455 break false;
456 }
457 }
458 }
459 }
460
461 let nested_check = generate_nested_arg_check(
462 segment_matcher,
463 "__arg",
464 0,
465 );
466
467 arg_checks.push(quote! {
468 if let Some(__arg) = __args.get(#arg_idx) {
469 #nested_check
470 } else {
471 break false;
472 }
473 });
474 }
475 GenericArgumentPattern::Binding(_span, name) => {
476 arg_checks.push(quote! {
477 if let Some(__arg) = __args.get(#arg_idx) {
478 #name = Some(__arg);
479 } else {
480 break false;
481 }
482 });
483 }
484 }
485 }
486
487 let arg_count = arg_patterns.len();
488 quote! {
489 if let syn::PathArguments::AngleBracketed(__angle_args) = &__seg.arguments {
490 let __args = &__angle_args.args;
491 if __args.len() != #arg_count {
492 break false;
493 }
494 #(#arg_checks)*
495 } else {
496 break false;
497 }
498 }
499 }
500 }
501 } else {
502 quote!()
503 };
504
505 let check = if is_optional {
506 quote! {
507 if __idx < __segments.len() {
508 let __seg = &__segments[__idx];
509 if #name_check {
510 #args_check
511 __idx += 1;
512 }
513 }
514 }
515 } else {
516 quote! {
517 if __idx >= __segments.len() {
518 break false;
519 }
520 let __seg = &__segments[__idx];
521 if !(#name_check) {
522 break false;
523 }
524 #args_check
525 __idx += 1;
526 }
527 };
528
529 check
530 }
531
532 segment_checks.push(handle_segment_matcher(matcher, is_optional));
533 }
534
535 let fully_qualified_check = if *fully_qualified {
536 quote! {
537 if !__path.leading_colon.is_some() {
538 __matched = false;
539 }
540 }
541 } else {
542 quote!()
543 };
544
545 let length_check = match max_len {
546 Some(max) if min_len == max => {
547 quote! {
548 if __segments.len() != #min_len {
549 __matched = false;
550 }
551 }
552 }
553 Some(max) => {
554 quote! {
555 if __segments.len() < #min_len || __segments.len() > #max {
556 __matched = false;
557 }
558 }
559 }
560 None => {
561 quote! {
562 if __segments.len() < #min_len {
563 __matched = false;
564 }
565 }
566 }
567 };
568
569 out.push(PathMatcher {
570 path_expr: path_expr.clone(),
571 fully_qualified_check,
572 length_check,
573 segment_checks,
574 binding_names,
575 })
576 }
577 }
578 }
579
580 out
581}
582
583#[proc_macro]
584pub fn path_match(input: TokenStream) -> TokenStream {
585 let PathMatchArms { path, arms } = parse_macro_input!(input as PathMatchArms);
586
587 let wildcard_arm = arms.last().filter(|arm| {
588 arm
589 .patterns
590 .first()
591 .is_some_and(|pattern| matches!(pattern, PathPattern::Wildcard))
592 });
593
594 if wildcard_arm.is_none() {
595 return syn::Error::new(
596 Span::call_site(),
597 "path_match! requires a wildcard arm `_ => ...` as the last arm",
598 )
599 .to_compile_error()
600 .into();
601 }
602
603 let wildcard_body = &wildcard_arm.unwrap().body;
604 let non_wildcard_arms = &arms[..arms.len() - 1];
605
606 for arm in non_wildcard_arms {
607 if arm
608 .patterns
609 .iter()
610 .any(|pattern| matches!(pattern, PathPattern::Wildcard))
611 {
612 return syn::Error::new(
613 Span::call_site(),
614 "wildcard pattern `_` must be the last arm",
615 )
616 .to_compile_error()
617 .into();
618 }
619 }
620
621 let mut match_checks = Vec::new();
622
623 for arm in non_wildcard_arms {
624 let path_matchers = generate_path_matcher(&path, &arm.patterns);
625 for PathMatcher {
626 path_expr,
627 fully_qualified_check,
628 length_check,
629 segment_checks,
630 binding_names,
631 } in path_matchers
632 {
633 match_checks.push((
634 path_expr,
635 fully_qualified_check,
636 length_check,
637 segment_checks,
638 binding_names,
639 &arm.body,
640 ));
641 }
642 }
643
644 let arms_code = match_checks.into_iter().map(
645 |(path_expr, fq_check, len_check, seg_checks, binding_names, body)| {
646 let spanless_binding_names = binding_names
647 .iter()
648 .map(|(_, name)| name.clone())
649 .collect::<Vec<_>>();
650 let binding_extractions =
651 binding_names.into_iter().map(|(span, name)| {
652 let mut name_in_some = name.clone();
653 name_in_some.set_span(span);
654
655 quote! {
656 let #name_in_some = #name.unwrap();
657 }
658 });
659
660 quote! {
661 {
662 let __path = #path_expr;
663 let __segments = &__path.segments;
664 let mut __idx = 0;
665 let mut __matched = true;
666
667 #(let mut #spanless_binding_names: Option<_> = None;)*
668
669 #fq_check
670
671 if __matched {
672 #len_check
673 }
674
675 if __matched {
676 __matched = loop {
677 #(#seg_checks)*
678 break __matched;
679 }
680 }
681
682 if __matched && __idx != __segments.len() {
683 __matched = false;
684 }
685
686 if __matched {
687 #(#binding_extractions)*
688 return #body;
689 }
690 }
691 }
692 },
693 );
694
695 let expanded = quote! {
696 (|| {
697 #(#arms_code)*
698
699 #wildcard_body
700 })()
701 };
702
703 TokenStream::from(expanded)
704}