1#![doc = include_str!("../README.md")]
2use std::collections::HashMap;
3
4use proc_macro::TokenStream;
5use quote::{quote, ToTokens};
6use syn::{
7 braced,
8 parse::{Parse, ParseStream},
9 parse_macro_input,
10 punctuated::Punctuated,
11 token, GenericArgument, PathArguments, Token,
12};
13
14struct TypeMatch {
15 #[allow(unused)]
16 match_token: Token![match],
17 match_type: syn::Type,
18 #[allow(unused)]
19 brace_token: token::Brace,
20 arms: Punctuated<TypeMatchArm, Token![,]>,
21}
22
23struct TypeMatchArm {
24 pattern: syn::Type,
25 #[allow(unused)]
26 fat_arrow: Token![=>],
27 result: syn::Type,
28}
29
30impl Parse for TypeMatch {
31 fn parse(input: ParseStream) -> syn::Result<Self> {
32 let content;
33 Ok(TypeMatch {
34 match_token: input.parse()?,
35 match_type: input.parse()?,
36 brace_token: braced!(content in input),
37 arms: content.parse_terminated(TypeMatchArm::parse, Token![,])?,
38 })
39 }
40}
41
42impl Parse for TypeMatchArm {
43 fn parse(input: ParseStream) -> syn::Result<Self> {
44 Ok(TypeMatchArm {
45 pattern: input.parse()?,
46 fat_arrow: input.parse()?,
47 result: input.parse()?,
48 })
49 }
50}
51
52#[derive(Default, Clone)]
53struct Wildcards {
54 wildcards: HashMap<String, syn::Type>,
55 lifetimes: HashMap<String, syn::Lifetime>,
56}
57
58impl Wildcards {
59 fn track_wildcard(&mut self, arg: &str, input: &impl ToTokens) {
60 self.wildcards.insert(
61 arg.to_string(),
62 syn::parse2(input.to_token_stream()).expect("Failed to parse a type"),
63 );
64 }
65
66 fn track_lifetime(&mut self, arg: &str, input: &impl ToTokens) {
67 self.wildcards.insert(
68 arg.to_string(),
69 syn::parse2(input.to_token_stream()).expect("Failed to parse a lifetime"),
70 );
71 }
72}
73fn match_type(input: &syn::Type, pattern: &syn::Type) -> Result<Wildcards, &'static str> {
79 match_type_recursive(input, pattern, &mut Wildcards::default())
80}
81
82fn match_type_recursive(
88 mut input: &syn::Type,
89 mut pattern: &syn::Type,
90 wildcards: &mut Wildcards,
91) -> Result<Wildcards, &'static str> {
92 #![allow(unknown_lints)]
93 #![cfg_attr(test, deny(non_exhaustive_omitted_patterns))]
95
96 while let Group(grouped_input) = input {
97 input = &grouped_input.elem;
98 }
99
100 while let Group(grouped_pattern) = pattern {
101 pattern = &grouped_pattern.elem;
102 }
103
104 use syn::Type::*;
105 match (input, pattern) {
106 (input, Infer(_)) => {
107 wildcards.track_wildcard("_", input);
108 Ok(wildcards.clone())
109 }
110
111 (Path(input_path), Path(pattern_path)) => {
112 match_type_path(&input_path.path, &pattern_path.path, wildcards)
113 }
114 (input, Path(pattern_path)) => {
115 if pattern_path.path.segments.len() == 1 {
116 if let Some(first) = pattern_path.path.segments.first() {
117 if first.ident.to_string().starts_with('_') {
118 let last_args = pattern_path.path.segments.last().unwrap();
119 if matches!(&last_args.arguments, PathArguments::None)
120 || matches!(&last_args.arguments, PathArguments::AngleBracketed(args) if args.args.len() == 0)
121 {
122 wildcards.track_wildcard(&first.ident.to_string(), &input);
123 return Ok(wildcards.clone());
124 }
125 }
126 }
127 }
128 Err("Type shapes are not the same")
129 }
130
131 (Array(input), Array(pattern)) => {
132 if input.len.to_token_stream().to_string() != pattern.len.to_token_stream().to_string()
133 {
134 Err("Array length mismatch")
135 } else {
136 match_type_recursive(&input.elem, &pattern.elem, wildcards)
137 }
138 }
139 (BareFn(input), BareFn(pattern)) => {
140 if input.inputs.len() != pattern.inputs.len() {
141 Err("Function argument length mismatch")
142 } else {
143 for (input_arg, pattern_arg) in input.inputs.iter().zip(pattern.inputs.iter()) {
144 match_type_recursive(&input_arg.ty, &pattern_arg.ty, wildcards)?;
145 }
146 Ok(wildcards.clone())
147 }
148 }
149 (Group(_), Group(_)) => {
150 panic!("Groups should not exist at this point");
151 }
152 (ImplTrait(input), ImplTrait(pattern)) => {
153 panic!(
154 "ImplTrait types not supported: {:?} {:?}",
155 input.to_token_stream().to_string(),
156 pattern.to_token_stream().to_string()
157 );
158 }
159 (Macro(input), Macro(pattern)) => {
160 panic!(
161 "Macro types not supported: {:?} {:?}",
162 input.to_token_stream().to_string(),
163 pattern.to_token_stream().to_string()
164 );
165 }
166 (Never(_), Never(_)) => Ok(wildcards.clone()),
167 (Paren(input), Paren(pattern)) => {
168 match_type_recursive(&input.elem, &pattern.elem, wildcards)
169 }
170 (Ptr(input), Ptr(pattern)) => match_type_recursive(&input.elem, &pattern.elem, wildcards),
171 (Reference(input), Reference(pattern)) => {
172 match_type_recursive(&input.elem, &pattern.elem, wildcards)
173 }
174 (Slice(input), Slice(pattern)) => {
175 match_type_recursive(&input.elem, &pattern.elem, wildcards)
176 }
177 (TraitObject(input), TraitObject(pattern)) => {
178 panic!(
179 "TraitObject types not supported: {:?} {:?}",
180 input.to_token_stream().to_string(),
181 pattern.to_token_stream().to_string()
182 );
183 }
184 (Tuple(input), Tuple(pattern)) => {
185 if input.elems.len() != pattern.elems.len() {
186 Err("Tuple length mismatch")
187 } else {
188 for (input_arg, pattern_arg) in input.elems.iter().zip(pattern.elems.iter()) {
189 match_type_recursive(&input_arg, &pattern_arg, wildcards)?;
190 }
191 Ok(wildcards.clone())
192 }
193 }
194 (Verbatim(input), Verbatim(pattern)) => {
195 panic!(
196 "Verbatim types not supported: {:?} {:?}",
197 input.to_token_stream().to_string(),
198 pattern.to_token_stream().to_string()
199 );
200 }
201 _ => Err("Type shapes are not the same"),
202 }
203}
204
205fn match_type_path(
216 input: &syn::Path,
217 pattern: &syn::Path,
218 wildcards: &mut Wildcards,
219) -> Result<Wildcards, &'static str> {
220 let mut is_wildcard = false;
221 if pattern.segments.len() == 1 {
222 if let Some(first) = pattern.segments.first() {
223 if first.ident.to_string().starts_with('_') {
224 is_wildcard = true;
225 }
226 }
227 }
228
229 if is_wildcard {
230 let input_args = &input.segments.last().as_ref().unwrap().arguments;
232 let pattern_args = &pattern.segments.last().as_ref().unwrap().arguments;
233
234 let mut input = input.clone();
235
236 if !matches!(pattern_args, PathArguments::None) {
237 input.segments.last_mut().unwrap().arguments = PathArguments::None;
238 }
239
240 wildcards.track_wildcard(&pattern.segments.first().unwrap().ident.to_string(), &input);
241
242 match_type_path_args(input_args, pattern_args, wildcards)
243 } else {
244 if input.segments.len() != pattern.segments.len() {
245 Err("Path segment lengths are not the same")
246 } else {
247 for (input_segment, pattern_segment) in
248 input.segments.iter().zip(pattern.segments.iter())
249 {
250 if input_segment.ident.to_string() != pattern_segment.ident.to_string() {
251 return Err("Path segment identifiers are not the same");
252 }
253 match_type_path_args(
254 &input_segment.arguments,
255 &pattern_segment.arguments,
256 wildcards,
257 )?;
258 }
259 Ok(wildcards.clone())
260 }
261 }
262}
263
264fn match_type_path_args(
266 input: &PathArguments,
267 pattern: &PathArguments,
268 wildcards: &mut Wildcards,
269) -> Result<Wildcards, &'static str> {
270 match (&input, &pattern) {
271 (_, PathArguments::None) => {}
273 (PathArguments::None, PathArguments::AngleBracketed(args)) if args.args.len() == 0 => {}
275
276 (
277 PathArguments::AngleBracketed(input_args),
278 PathArguments::AngleBracketed(pattern_args),
279 ) => {
280 if input_args.args.len() != pattern_args.args.len() {
281 return Err("Path argument lengths are not the same");
282 }
283 for (input_arg, pattern_arg) in input_args.args.iter().zip(pattern_args.args.iter()) {
284 match (input_arg, pattern_arg) {
285 (GenericArgument::Type(input_arg), GenericArgument::Type(pattern_arg)) => {
286 match_type_recursive(&input_arg, &pattern_arg, wildcards)?;
287 }
288 (
289 GenericArgument::Lifetime(input_arg),
290 GenericArgument::Lifetime(pattern_arg),
291 ) => {
292 if pattern_arg.ident.to_string() != "_" {
293 if input_arg.ident.to_string() != pattern_arg.ident.to_string() {
294 return Err("Lifetime mismatch");
295 }
296 } else {
297 wildcards
298 .track_lifetime(&pattern_arg.ident.to_string(), &pattern_arg.ident);
299 }
300 }
301 _ => {
302 if input_arg.to_token_stream().to_string()
303 != pattern_arg.to_token_stream().to_string()
304 {
305 return Err("Path argument types are not the same");
306 }
307 }
308 }
309 }
310 }
311 (_, PathArguments::Parenthesized(..)) => panic!(
312 "Unsupported parenthesized arguments: {:?}",
313 input.to_token_stream().to_string()
314 ),
315 _ => {
316 return Err("Path arguments are not the same");
317 }
318 }
319 Ok(wildcards.clone())
320}
321
322fn render(mut result: syn::Type, matched: &Wildcards, input: &TypeMatch) -> syn::Type {
324 use syn::Type::*;
326 match &mut result {
327 Path(path) => {
332 if path.path.segments.len() == 1 {
334 if let Some(first) = path.path.segments.first() {
335 if first.ident.to_string().starts_with('_') {
336 let Some(wildcard) = matched.wildcards.get(&first.ident.to_string()) else {
337 panic!("Unknown wildcard: {:?}", first.ident.to_string());
338 };
339
340 if let Some(args) = path.path.segments.last_mut() {
342 if !matches!(args.arguments, PathArguments::None) {
343 match wildcard {
344 Path(wildcard_path) => {
345 let last_segment =
346 path.path.segments.last_mut().unwrap().clone();
347
348 *path = wildcard_path.clone();
349
350 path.path.segments.last_mut().unwrap().arguments =
351 last_segment.arguments;
352
353 for segment in &mut path.path.segments {
354 segment.arguments = render_path_args(
355 segment.arguments.clone(),
356 matched,
357 input,
358 );
359 }
360 return result;
361 }
362 _ => {
363 panic!(
364 "Wildcard is not a Path type,: {:?}",
365 wildcard.to_token_stream().to_string()
366 );
367 }
368 }
369 }
370 }
371
372 return wildcard.clone();
373 }
374 }
375 }
376
377 for segment in &mut path.path.segments {
378 segment.arguments = render_path_args(segment.arguments.clone(), matched, input);
379 }
380
381 return result;
382 }
383 Reference(reference) => {
384 if let Some(lifetime) = &mut reference.lifetime {
385 if lifetime.ident.to_string().starts_with("_") && lifetime.ident != "_" {
386 *lifetime = matched
387 .lifetimes
388 .get(&lifetime.ident.to_string())
389 .expect("Unknown lifetime")
390 .clone();
391 }
392 }
393 reference.elem = Box::new(render(*reference.elem.clone(), matched, input));
394 return result;
395 }
396 Slice(slice) => {
397 slice.elem = Box::new(render(*slice.elem.clone(), matched, input));
398 return result;
399 }
400 Macro(macro_type) => {
401 if macro_type.mac.path.segments.len() == 1 {
402 if let Some(first) = macro_type.mac.path.segments.first() {
403 if first.ident == "recurse" {
404 let recurse_input_type =
405 syn::parse2::<syn::Type>(macro_type.mac.tokens.clone())
406 .expect("Recursive call failed");
407 let recurse_type = render(recurse_input_type, &matched, input);
408
409 for arm in &input.arms {
410 if let Ok(matched) = match_type(&recurse_type, &arm.pattern) {
411 return render(arm.result.clone(), &matched, input);
412 }
413 }
414 panic!(
415 "No recursive match found for {:?}",
416 recurse_type.to_token_stream().to_string()
417 );
418 }
419 }
420 }
421 panic!(
422 "Unhandled macro: {:?}",
423 macro_type.mac.path.to_token_stream().to_string()
424 );
425 }
426 _ => {
427 panic!("Unhandled type: {:?}", result.to_token_stream().to_string());
428 }
429 }
430}
431
432fn render_path_args(
433 mut args: PathArguments,
434 matched: &Wildcards,
435 input: &TypeMatch,
436) -> PathArguments {
437 match &mut args {
438 PathArguments::None => {}
439 PathArguments::AngleBracketed(args) => {
440 for arg in &mut args.args {
441 match arg {
442 GenericArgument::Type(arg) => {
443 *arg = render(arg.clone(), matched, input);
444 }
445 GenericArgument::Lifetime(arg) => {
446 if arg.ident.to_string().starts_with("_") && arg.ident != "_" {
447 *arg = matched
448 .lifetimes
449 .get(&arg.ident.to_string())
450 .expect("Unknown lifetime")
451 .clone();
452 }
453 }
454 _ => {}
455 }
456 }
457 }
458 _ => {
459 panic!(
460 "Unhandled path arguments: {:?}",
461 args.to_token_stream().to_string()
462 );
463 }
464 }
465 args
466}
467
468#[proc_macro]
481pub fn map_types(input: TokenStream) -> TokenStream {
482 let input = parse_macro_input!(input as TypeMatch);
483
484 let mut out = String::new();
485 for arm in &input.arms {
486 out.push_str(&arm.pattern.to_token_stream().to_string());
487
488 match match_type(&input.match_type, &arm.pattern) {
489 Ok(matched) => {
490 let result: proc_macro2::TokenStream =
491 render(arm.result.clone(), &matched, &input).into_token_stream();
492 return TokenStream::from(quote! { #result });
493 }
494 Err(e) => {
495 out.push_str(&format!(": No match: {e}\n"));
496 }
497 }
498 }
499
500 panic!(
501 "No match found for {:?}\n{}",
502 input.match_type.to_token_stream().to_string(),
503 out
504 );
505}
506
507struct AssertTypeMatches {
508 input_type: syn::Type,
509 #[allow(unused)]
510 comma: Token![,],
511 expected_type: syn::Type,
512 message: Option<syn::LitStr>,
513}
514
515impl Parse for AssertTypeMatches {
516 fn parse(input: ParseStream) -> syn::Result<Self> {
517 Ok(AssertTypeMatches {
518 input_type: input.parse()?,
519 comma: input.parse()?,
520 expected_type: input.parse()?,
521 message: input.parse()?,
522 })
523 }
524}
525
526#[proc_macro]
527pub fn assert_type_matches(input: TokenStream) -> TokenStream {
528 let input = parse_macro_input!(input as AssertTypeMatches);
529
530 match match_type(&input.input_type, &input.expected_type) {
531 Err(e) => {
532 if let Some(message) = input.message {
533 panic!("{}", message.value());
534 } else {
535 panic!(
536 "Type mismatch: {:?} !~ {:?}: {e}",
537 input.input_type.to_token_stream().to_string(),
538 input.expected_type.to_token_stream().to_string()
539 );
540 }
541 }
542 Ok(_) => TokenStream::new(),
543 }
544}
545
546#[proc_macro]
547pub fn assert_type_not_matches(input: TokenStream) -> TokenStream {
548 let input = parse_macro_input!(input as AssertTypeMatches);
549
550 match match_type(&input.input_type, &input.expected_type) {
551 Err(_) => TokenStream::new(),
552 Ok(_) => {
553 panic!(
554 "Type matches when it should not: {:?} ~ {:?}",
555 input.input_type.to_token_stream().to_string(),
556 input.expected_type.to_token_stream().to_string()
557 );
558 }
559 }
560}
561
562#[proc_macro]
563pub fn recurse(_: TokenStream) -> TokenStream {
564 panic!("Don't use this macro directly, use `map_types!` instead");
565}