type_mapper/
lib.rs

1#![doc = include_str!("../README.md")]
2use std::collections::HashMap;
3
4use proc_macro::TokenStream;
5use quote::{quote, ToTokens};
6use syn::{
7    braced,
8    parse::{Parse, ParseStream},
9    parse_macro_input,
10    punctuated::Punctuated,
11    token, GenericArgument, PathArguments, Token,
12};
13
14struct TypeMatch {
15    #[allow(unused)]
16    match_token: Token![match],
17    match_type: syn::Type,
18    #[allow(unused)]
19    brace_token: token::Brace,
20    arms: Punctuated<TypeMatchArm, Token![,]>,
21}
22
23struct TypeMatchArm {
24    pattern: syn::Type,
25    #[allow(unused)]
26    fat_arrow: Token![=>],
27    result: syn::Type,
28}
29
30impl Parse for TypeMatch {
31    fn parse(input: ParseStream) -> syn::Result<Self> {
32        let content;
33        Ok(TypeMatch {
34            match_token: input.parse()?,
35            match_type: input.parse()?,
36            brace_token: braced!(content in input),
37            arms: content.parse_terminated(TypeMatchArm::parse, Token![,])?,
38        })
39    }
40}
41
42impl Parse for TypeMatchArm {
43    fn parse(input: ParseStream) -> syn::Result<Self> {
44        Ok(TypeMatchArm {
45            pattern: input.parse()?,
46            fat_arrow: input.parse()?,
47            result: input.parse()?,
48        })
49    }
50}
51
52#[derive(Default, Clone)]
53struct Wildcards {
54    wildcards: HashMap<String, syn::Type>,
55    lifetimes: HashMap<String, syn::Lifetime>,
56}
57
58impl Wildcards {
59    fn track_wildcard(&mut self, arg: &str, input: &impl ToTokens) {
60        self.wildcards.insert(
61            arg.to_string(),
62            syn::parse2(input.to_token_stream()).expect("Failed to parse a type"),
63        );
64    }
65
66    fn track_lifetime(&mut self, arg: &str, input: &impl ToTokens) {
67        self.wildcards.insert(
68            arg.to_string(),
69            syn::parse2(input.to_token_stream()).expect("Failed to parse a lifetime"),
70        );
71    }
72}
73/// Attempts to match the input type with the pattern type. If there's a match, returns the templated generics:
74///
75///  - `_`, `_X` are wildcard types (optionally named)
76///  - `__`, `__X` are multi-generic wildcards (optionally named)
77///  - `'_`, `'_X` are lifetime wildcards (optionally named)
78fn match_type(input: &syn::Type, pattern: &syn::Type) -> Result<Wildcards, &'static str> {
79    match_type_recursive(input, pattern, &mut Wildcards::default())
80}
81
82/// Attempts to match the input type with the pattern type. If there's a match, returns the templated generics:
83///
84///  - `_`, `_X` are wildcard types (optionally named)
85///  - `__`, `__X` are multi-generic wildcards (optionally named)
86///  - `'_`, `'_X` are lifetime wildcards (optionally named)
87fn match_type_recursive(
88    mut input: &syn::Type,
89    mut pattern: &syn::Type,
90    wildcards: &mut Wildcards,
91) -> Result<Wildcards, &'static str> {
92    #![allow(unknown_lints)]
93    // Needs #![feature(non_exhaustive_omitted_patterns_lint)]
94    #![cfg_attr(test, deny(non_exhaustive_omitted_patterns))]
95
96    while let Group(grouped_input) = input {
97        input = &grouped_input.elem;
98    }
99
100    while let Group(grouped_pattern) = pattern {
101        pattern = &grouped_pattern.elem;
102    }
103
104    use syn::Type::*;
105    match (input, pattern) {
106        (input, Infer(_)) => {
107            wildcards.track_wildcard("_", input);
108            Ok(wildcards.clone())
109        }
110
111        (Path(input_path), Path(pattern_path)) => {
112            match_type_path(&input_path.path, &pattern_path.path, wildcards)
113        }
114        (input, Path(pattern_path)) => {
115            if pattern_path.path.segments.len() == 1 {
116                if let Some(first) = pattern_path.path.segments.first() {
117                    if first.ident.to_string().starts_with('_') {
118                        let last_args = pattern_path.path.segments.last().unwrap();
119                        if matches!(&last_args.arguments, PathArguments::None)
120                            || matches!(&last_args.arguments, PathArguments::AngleBracketed(args) if args.args.len() == 0)
121                        {
122                            wildcards.track_wildcard(&first.ident.to_string(), &input);
123                            return Ok(wildcards.clone());
124                        }
125                    }
126                }
127            }
128            Err("Type shapes are not the same")
129        }
130
131        (Array(input), Array(pattern)) => {
132            if input.len.to_token_stream().to_string() != pattern.len.to_token_stream().to_string()
133            {
134                Err("Array length mismatch")
135            } else {
136                match_type_recursive(&input.elem, &pattern.elem, wildcards)
137            }
138        }
139        (BareFn(input), BareFn(pattern)) => {
140            if input.inputs.len() != pattern.inputs.len() {
141                Err("Function argument length mismatch")
142            } else {
143                for (input_arg, pattern_arg) in input.inputs.iter().zip(pattern.inputs.iter()) {
144                    match_type_recursive(&input_arg.ty, &pattern_arg.ty, wildcards)?;
145                }
146                Ok(wildcards.clone())
147            }
148        }
149        (Group(_), Group(_)) => {
150            panic!("Groups should not exist at this point");
151        }
152        (ImplTrait(input), ImplTrait(pattern)) => {
153            panic!(
154                "ImplTrait types not supported: {:?} {:?}",
155                input.to_token_stream().to_string(),
156                pattern.to_token_stream().to_string()
157            );
158        }
159        (Macro(input), Macro(pattern)) => {
160            panic!(
161                "Macro types not supported: {:?} {:?}",
162                input.to_token_stream().to_string(),
163                pattern.to_token_stream().to_string()
164            );
165        }
166        (Never(_), Never(_)) => Ok(wildcards.clone()),
167        (Paren(input), Paren(pattern)) => {
168            match_type_recursive(&input.elem, &pattern.elem, wildcards)
169        }
170        (Ptr(input), Ptr(pattern)) => match_type_recursive(&input.elem, &pattern.elem, wildcards),
171        (Reference(input), Reference(pattern)) => {
172            match_type_recursive(&input.elem, &pattern.elem, wildcards)
173        }
174        (Slice(input), Slice(pattern)) => {
175            match_type_recursive(&input.elem, &pattern.elem, wildcards)
176        }
177        (TraitObject(input), TraitObject(pattern)) => {
178            panic!(
179                "TraitObject types not supported: {:?} {:?}",
180                input.to_token_stream().to_string(),
181                pattern.to_token_stream().to_string()
182            );
183        }
184        (Tuple(input), Tuple(pattern)) => {
185            if input.elems.len() != pattern.elems.len() {
186                Err("Tuple length mismatch")
187            } else {
188                for (input_arg, pattern_arg) in input.elems.iter().zip(pattern.elems.iter()) {
189                    match_type_recursive(&input_arg, &pattern_arg, wildcards)?;
190                }
191                Ok(wildcards.clone())
192            }
193        }
194        (Verbatim(input), Verbatim(pattern)) => {
195            panic!(
196                "Verbatim types not supported: {:?} {:?}",
197                input.to_token_stream().to_string(),
198                pattern.to_token_stream().to_string()
199            );
200        }
201        _ => Err("Type shapes are not the same"),
202    }
203}
204
205/// The core of the type matching logic. Most of the interesting matches happen here.
206///
207/// Some examples:
208///
209///  - _T matches every type, with or without generics
210///  - _T<> matches every type, without generics
211///  - _T<_> matches every type, with a single generic
212///  - _T<__> matches every type with any number of generics (0..infinity)
213///  - _T<'_, _> matches any type with one lifetime parameter and one generic
214///  - _T<___> matches any number of lifetime parameters and generics (0..infinity)
215fn match_type_path(
216    input: &syn::Path,
217    pattern: &syn::Path,
218    wildcards: &mut Wildcards,
219) -> Result<Wildcards, &'static str> {
220    let mut is_wildcard = false;
221    if pattern.segments.len() == 1 {
222        if let Some(first) = pattern.segments.first() {
223            if first.ident.to_string().starts_with('_') {
224                is_wildcard = true;
225            }
226        }
227    }
228
229    if is_wildcard {
230        // In a wildcard match, we only care about the final path segment args.
231        let input_args = &input.segments.last().as_ref().unwrap().arguments;
232        let pattern_args = &pattern.segments.last().as_ref().unwrap().arguments;
233
234        let mut input = input.clone();
235
236        if !matches!(pattern_args, PathArguments::None) {
237            input.segments.last_mut().unwrap().arguments = PathArguments::None;
238        }
239
240        wildcards.track_wildcard(&pattern.segments.first().unwrap().ident.to_string(), &input);
241
242        match_type_path_args(input_args, pattern_args, wildcards)
243    } else {
244        if input.segments.len() != pattern.segments.len() {
245            Err("Path segment lengths are not the same")
246        } else {
247            for (input_segment, pattern_segment) in
248                input.segments.iter().zip(pattern.segments.iter())
249            {
250                if input_segment.ident.to_string() != pattern_segment.ident.to_string() {
251                    return Err("Path segment identifiers are not the same");
252                }
253                match_type_path_args(
254                    &input_segment.arguments,
255                    &pattern_segment.arguments,
256                    wildcards,
257                )?;
258            }
259            Ok(wildcards.clone())
260        }
261    }
262}
263
264/// Matches the arguments of a path ie: `<...>`.
265fn match_type_path_args(
266    input: &PathArguments,
267    pattern: &PathArguments,
268    wildcards: &mut Wildcards,
269) -> Result<Wildcards, &'static str> {
270    match (&input, &pattern) {
271        // Always match if the pattern is empty, but still capture wildcards.
272        (_, PathArguments::None) => {}
273        // If the pattern is empty <>, we match if the input is empty as well.
274        (PathArguments::None, PathArguments::AngleBracketed(args)) if args.args.len() == 0 => {}
275
276        (
277            PathArguments::AngleBracketed(input_args),
278            PathArguments::AngleBracketed(pattern_args),
279        ) => {
280            if input_args.args.len() != pattern_args.args.len() {
281                return Err("Path argument lengths are not the same");
282            }
283            for (input_arg, pattern_arg) in input_args.args.iter().zip(pattern_args.args.iter()) {
284                match (input_arg, pattern_arg) {
285                    (GenericArgument::Type(input_arg), GenericArgument::Type(pattern_arg)) => {
286                        match_type_recursive(&input_arg, &pattern_arg, wildcards)?;
287                    }
288                    (
289                        GenericArgument::Lifetime(input_arg),
290                        GenericArgument::Lifetime(pattern_arg),
291                    ) => {
292                        if pattern_arg.ident.to_string() != "_" {
293                            if input_arg.ident.to_string() != pattern_arg.ident.to_string() {
294                                return Err("Lifetime mismatch");
295                            }
296                        } else {
297                            wildcards
298                                .track_lifetime(&pattern_arg.ident.to_string(), &pattern_arg.ident);
299                        }
300                    }
301                    _ => {
302                        if input_arg.to_token_stream().to_string()
303                            != pattern_arg.to_token_stream().to_string()
304                        {
305                            return Err("Path argument types are not the same");
306                        }
307                    }
308                }
309            }
310        }
311        (_, PathArguments::Parenthesized(..)) => panic!(
312            "Unsupported parenthesized arguments: {:?}",
313            input.to_token_stream().to_string()
314        ),
315        _ => {
316            return Err("Path arguments are not the same");
317        }
318    }
319    Ok(wildcards.clone())
320}
321
322/// Renders a type based on the matched wildcards. We render by updating the type in-place.
323fn render(mut result: syn::Type, matched: &Wildcards, input: &TypeMatch) -> syn::Type {
324    // TODO: This clones a lot, but does it really matter?
325    use syn::Type::*;
326    match &mut result {
327        // _T => _T: copy generics from _T verbatim
328        // _T => _T<>: copy path to _T but remove generics (_T must be a Path type)
329        // _T => _T<X, Y>: copy path to _T but set generics (_T must be a Path type)
330        // _T<X> => _T: copy path to _T only (_T must be a Path type)
331        Path(path) => {
332            // If the path is a wildcard, copy the wildcard type
333            if path.path.segments.len() == 1 {
334                if let Some(first) = path.path.segments.first() {
335                    if first.ident.to_string().starts_with('_') {
336                        let Some(wildcard) = matched.wildcards.get(&first.ident.to_string()) else {
337                            panic!("Unknown wildcard: {:?}", first.ident.to_string());
338                        };
339
340                        // If the pattern contains a final segment with arguments, the wildcard must be a Path type
341                        if let Some(args) = path.path.segments.last_mut() {
342                            if !matches!(args.arguments, PathArguments::None) {
343                                match wildcard {
344                                    Path(wildcard_path) => {
345                                        let last_segment =
346                                            path.path.segments.last_mut().unwrap().clone();
347
348                                        *path = wildcard_path.clone();
349
350                                        path.path.segments.last_mut().unwrap().arguments =
351                                            last_segment.arguments;
352
353                                        for segment in &mut path.path.segments {
354                                            segment.arguments = render_path_args(
355                                                segment.arguments.clone(),
356                                                matched,
357                                                input,
358                                            );
359                                        }
360                                        return result;
361                                    }
362                                    _ => {
363                                        panic!(
364                                            "Wildcard is not a Path type,: {:?}",
365                                            wildcard.to_token_stream().to_string()
366                                        );
367                                    }
368                                }
369                            }
370                        }
371
372                        return wildcard.clone();
373                    }
374                }
375            }
376
377            for segment in &mut path.path.segments {
378                segment.arguments = render_path_args(segment.arguments.clone(), matched, input);
379            }
380
381            return result;
382        }
383        Reference(reference) => {
384            if let Some(lifetime) = &mut reference.lifetime {
385                if lifetime.ident.to_string().starts_with("_") && lifetime.ident != "_" {
386                    *lifetime = matched
387                        .lifetimes
388                        .get(&lifetime.ident.to_string())
389                        .expect("Unknown lifetime")
390                        .clone();
391                }
392            }
393            reference.elem = Box::new(render(*reference.elem.clone(), matched, input));
394            return result;
395        }
396        Slice(slice) => {
397            slice.elem = Box::new(render(*slice.elem.clone(), matched, input));
398            return result;
399        }
400        Macro(macro_type) => {
401            if macro_type.mac.path.segments.len() == 1 {
402                if let Some(first) = macro_type.mac.path.segments.first() {
403                    if first.ident == "recurse" {
404                        let recurse_input_type =
405                            syn::parse2::<syn::Type>(macro_type.mac.tokens.clone())
406                                .expect("Recursive call failed");
407                        let recurse_type = render(recurse_input_type, &matched, input);
408
409                        for arm in &input.arms {
410                            if let Ok(matched) = match_type(&recurse_type, &arm.pattern) {
411                                return render(arm.result.clone(), &matched, input);
412                            }
413                        }
414                        panic!(
415                            "No recursive match found for {:?}",
416                            recurse_type.to_token_stream().to_string()
417                        );
418                    }
419                }
420            }
421            panic!(
422                "Unhandled macro: {:?}",
423                macro_type.mac.path.to_token_stream().to_string()
424            );
425        }
426        _ => {
427            panic!("Unhandled type: {:?}", result.to_token_stream().to_string());
428        }
429    }
430}
431
432fn render_path_args(
433    mut args: PathArguments,
434    matched: &Wildcards,
435    input: &TypeMatch,
436) -> PathArguments {
437    match &mut args {
438        PathArguments::None => {}
439        PathArguments::AngleBracketed(args) => {
440            for arg in &mut args.args {
441                match arg {
442                    GenericArgument::Type(arg) => {
443                        *arg = render(arg.clone(), matched, input);
444                    }
445                    GenericArgument::Lifetime(arg) => {
446                        if arg.ident.to_string().starts_with("_") && arg.ident != "_" {
447                            *arg = matched
448                                .lifetimes
449                                .get(&arg.ident.to_string())
450                                .expect("Unknown lifetime")
451                                .clone();
452                        }
453                    }
454                    _ => {}
455                }
456            }
457        }
458        _ => {
459            panic!(
460                "Unhandled path arguments: {:?}",
461                args.to_token_stream().to_string()
462            );
463        }
464    }
465    args
466}
467
468/// Matches something like this:
469///
470/// ```
471/// use type_mapper::map_types;
472///
473/// let x: map_types!(
474///     match Vec<T> {
475///         Vec<_> => u8,
476///         _ => u16,
477///     }
478/// ) = 1_u8;
479/// ```
480#[proc_macro]
481pub fn map_types(input: TokenStream) -> TokenStream {
482    let input = parse_macro_input!(input as TypeMatch);
483
484    let mut out = String::new();
485    for arm in &input.arms {
486        out.push_str(&arm.pattern.to_token_stream().to_string());
487
488        match match_type(&input.match_type, &arm.pattern) {
489            Ok(matched) => {
490                let result: proc_macro2::TokenStream =
491                    render(arm.result.clone(), &matched, &input).into_token_stream();
492                return TokenStream::from(quote! { #result });
493            }
494            Err(e) => {
495                out.push_str(&format!(": No match: {e}\n"));
496            }
497        }
498    }
499
500    panic!(
501        "No match found for {:?}\n{}",
502        input.match_type.to_token_stream().to_string(),
503        out
504    );
505}
506
507struct AssertTypeMatches {
508    input_type: syn::Type,
509    #[allow(unused)]
510    comma: Token![,],
511    expected_type: syn::Type,
512    message: Option<syn::LitStr>,
513}
514
515impl Parse for AssertTypeMatches {
516    fn parse(input: ParseStream) -> syn::Result<Self> {
517        Ok(AssertTypeMatches {
518            input_type: input.parse()?,
519            comma: input.parse()?,
520            expected_type: input.parse()?,
521            message: input.parse()?,
522        })
523    }
524}
525
526#[proc_macro]
527pub fn assert_type_matches(input: TokenStream) -> TokenStream {
528    let input = parse_macro_input!(input as AssertTypeMatches);
529
530    match match_type(&input.input_type, &input.expected_type) {
531        Err(e) => {
532            if let Some(message) = input.message {
533                panic!("{}", message.value());
534            } else {
535                panic!(
536                    "Type mismatch: {:?} !~ {:?}: {e}",
537                    input.input_type.to_token_stream().to_string(),
538                    input.expected_type.to_token_stream().to_string()
539                );
540            }
541        }
542        Ok(_) => TokenStream::new(),
543    }
544}
545
546#[proc_macro]
547pub fn assert_type_not_matches(input: TokenStream) -> TokenStream {
548    let input = parse_macro_input!(input as AssertTypeMatches);
549
550    match match_type(&input.input_type, &input.expected_type) {
551        Err(_) => TokenStream::new(),
552        Ok(_) => {
553            panic!(
554                "Type matches when it should not: {:?} ~ {:?}",
555                input.input_type.to_token_stream().to_string(),
556                input.expected_type.to_token_stream().to_string()
557            );
558        }
559    }
560}
561
562#[proc_macro]
563pub fn recurse(_: TokenStream) -> TokenStream {
564    panic!("Don't use this macro directly, use `map_types!` instead");
565}