telety_impl/visitor/apply_generic_arguments.rs

1use std::collections::HashMap;
2
3use syn::{
4    Expr, GenericArgument, GenericParam, Generics, Ident, Lifetime, Type, parse_quote_spanned,
5    spanned::Spanned as _,
6};
7
8pub struct ApplyGenericArguments<'p> {
9    lifetimes: HashMap<&'p Lifetime, Option<Lifetime>>,
10    types: HashMap<&'p Ident, Type>,
11    consts: HashMap<&'p Ident, Expr>,
12}
13
14impl<'p> ApplyGenericArguments<'p> {
15    pub fn new<'a>(
16        params: &'p Generics,
17        args: impl IntoIterator<Item = &'a GenericArgument>,
18    ) -> syn::Result<Self> {
19        // Default values can refer to preceding arguments (e.g. `<T, U = T>`),
20        // so we need to run replacement on those defaults when
21        // we encounter them, using the mapping we have built so far.
22        let mut v = Self {
23            lifetimes: HashMap::new(),
24            types: HashMap::new(),
25            consts: HashMap::new(),
26        };
27
28        let mut args_iter = args.into_iter().peekable();
29        for param in &params.params {
30            match param {
31                GenericParam::Lifetime(param_lifetime) => {
32                    if let Some(GenericArgument::Lifetime(arg_lifetime)) = args_iter.peek() {
33                        v.lifetimes
34                            .insert(&param_lifetime.lifetime, Some(arg_lifetime.clone()));
35                    } else {
36                        // TODO This isn't correct in locations where the parameter is used multiple times
37                        v.lifetimes.insert(&param_lifetime.lifetime, None);
38                    }
39                }
40                GenericParam::Type(param_type) => {
41                    if let Some(arg) = args_iter.next() {
42                        if let GenericArgument::Type(arg_type) = arg {
43                            v.types.insert(&param_type.ident, arg_type.clone());
44                        } else {
45                            return Err(syn::Error::new(arg.span(), "Expected a type argument"));
46                        }
47                    } else if let Some(param_default) = &param_type.default {
48                        let mut param_default = param_default.clone();
49                        directed_visit::visit_mut(
50                            &mut directed_visit::syn::direct::FullDefault,
51                            &mut v,
52                            &mut param_default,
53                        );
54                        v.types.insert(&param_type.ident, param_default);
55                    } else {
56                        return Err(syn::Error::new(
57                            param_type.span(),
58                            "Expected an argument for parameter",
59                        ));
60                    }
61                }
62                GenericParam::Const(param_const) => {
63                    if let Some(arg) = args_iter.next() {
64                        if let GenericArgument::Const(arg_const) = arg {
65                            v.consts.insert(&param_const.ident, arg_const.clone());
66                        } else {
67                            return Err(syn::Error::new(arg.span(), "Expected a const argument"));
68                        }
69                    } else if let Some(param_default) = &param_const.default {
70                        let mut param_default = param_default.clone();
71                        directed_visit::visit_mut(
72                            &mut directed_visit::syn::direct::FullDefault,
73                            &mut v,
74                            &mut param_default,
75                        );
76                        v.consts.insert(&param_const.ident, param_default);
77                    } else {
78                        return Err(syn::Error::new(
79                            param_const.span(),
80                            "Expected an argument for parameter",
81                        ));
82                    }
83                }
84            }
85        }
86
87        Ok(v)
88    }
89}
90
91impl<'p> directed_visit::syn::visit::FullMut for ApplyGenericArguments<'p> {
92    fn visit_lifetime_mut<D>(
93        visitor: directed_visit::Visitor<'_, D, Self>,
94        node: &mut syn::Lifetime,
95    ) where
96        D: directed_visit::DirectMut<Self, syn::Lifetime> + ?Sized,
97    {
98        if let Some(lifetime_arg) = visitor.lifetimes.get(node) {
99            if let Some(lifetime_arg) = lifetime_arg {
100                *node = lifetime_arg.clone();
101            } else {
102                let span = node.span();
103                *node = parse_quote_spanned! { span => '_ };
104            }
105            return;
106        }
107
108        directed_visit::Visitor::visit_mut(visitor, node);
109    }
110
111    fn visit_type_mut<D>(visitor: directed_visit::Visitor<'_, D, Self>, node: &mut syn::Type)
112    where
113        D: directed_visit::DirectMut<Self, syn::Type> + ?Sized,
114    {
115        if let Type::Path(path) = node {
116            // TODO should check first segment to support some associated types
117            if let Some(ident) = path.path.get_ident()
118                && let Some(value) = visitor.types.get(ident)
119            {
120                *node = value.clone();
121                return;
122            }
123        }
124
125        directed_visit::Visitor::visit_mut(visitor, node);
126    }
127
128    fn visit_expr_mut<D>(visitor: directed_visit::Visitor<'_, D, Self>, node: &mut syn::Expr)
129    where
130        D: directed_visit::DirectMut<Self, syn::Expr> + ?Sized,
131    {
132        if let Expr::Path(path) = node
133            && let Some(ident) = path.path.get_ident()
134            && let Some(value) = visitor.consts.get(ident)
135        {
136            *node = value.clone();
137            return;
138        }
139
140        directed_visit::Visitor::visit_mut(visitor, node);
141    }
142}