traitify_core/
lib.rs

1use proc_macro2::{Span, TokenStream};
2use quote::{quote, ToTokens};
3use syn::{
4    AngleBracketedGenericArguments, GenericArgument, Ident, ImplItemFn, PathArguments, PathSegment,
5    TypePath,
6};
7
8#[derive(Debug)]
9struct Args {
10    trait_name: Ident,
11    dyn_generics: Vec<String>,
12}
13
14impl syn::parse::Parse for Args {
15    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
16        let trait_name = input.parse()?;
17        input.parse::<syn::Token!(,)>()?;
18        input.parse::<syn::Token!(dyn)>()?;
19        input.parse::<syn::Token!(=)>()?;
20
21        let generics_group;
22        syn::bracketed!(generics_group in input);
23        let dyn_generics = generics_group
24            .parse_terminated(Ident::parse, syn::Token!(,))?
25            .into_iter()
26            .map(|ident| ident.to_string())
27            .collect();
28
29        Ok(Args {
30            trait_name,
31            dyn_generics,
32        })
33    }
34}
35
36pub fn traitify(args: TokenStream, input: TokenStream) -> TokenStream {
37    let args = syn::parse2::<Args>(args).expect("Parsing args");
38    let input = syn::parse2::<syn::ItemImpl>(input).unwrap();
39
40    let functions = input
41        .items
42        .iter()
43        .filter_map(|item| match item {
44            syn::ImplItem::Fn(function) => {
45                if matches!(function.vis, syn::Visibility::Public(_))
46                    && function.sig.constness.is_none()
47                    && function.sig.abi.is_none()
48                    // No arguments may of a generic type we're dyn over
49                    && function.sig.inputs.iter().all(|arg| match arg {
50                        syn::FnArg::Receiver(_) => true,
51                        syn::FnArg::Typed(t) => !args.dyn_generics.contains(&t.ty.to_token_stream().to_string()),
52                    })
53                {
54                    let mut function_signature = function.sig.clone();
55
56                    // No &self makes the trait not object safe. So add a where clause making it `Self: Sized`.
57                    if function.sig.receiver().is_none() {
58                        let function_where_clause = function_signature.generics.make_where_clause();
59                        function_where_clause
60                            .predicates
61                            .push(syn::parse_quote!(Self: Sized));
62                    }
63                    Some(function_signature)
64                } else {
65                    None
66                }
67            }
68            _ => None,
69        })
70        .collect::<Vec<_>>();
71
72    let trait_generics = input
73        .generics
74        .params
75        .iter()
76        .filter(|param| match param {
77            syn::GenericParam::Type(t) => !args.dyn_generics.contains(&t.ident.to_string()),
78            syn::GenericParam::Const(t) => !args.dyn_generics.contains(&t.ident.to_string()),
79            _ => true,
80        })
81        .collect::<Vec<_>>();
82
83    let trait_definition = {
84        let trait_name = args.trait_name.clone();
85
86        let mut trait_where = input.generics.clone();
87        let trait_where =
88            trait_where
89                .make_where_clause()
90                .predicates
91                .iter()
92                .filter(|pred| match pred {
93                    syn::WherePredicate::Lifetime(_) => true,
94                    syn::WherePredicate::Type(t) => match &t.bounded_ty {
95                        syn::Type::Path(p) => !args
96                            .dyn_generics
97                            .contains(&p.path.to_token_stream().to_string()),
98                        _ => todo!(),
99                    },
100                    _ => true,
101                });
102
103        quote!(
104            pub trait #trait_name<#(#trait_generics,)*> where #(#trait_where,)* {
105                #(#functions;)*
106            }
107        )
108    };
109
110    let trait_impl = {
111        // We're gonna take the original impl, strip out all functions, make it implement the exact functions of the trait
112        let mut trait_impl = input.clone();
113        trait_impl.items.clear();
114
115        trait_impl.attrs.clear();
116
117        let trait_generic_arguments = if trait_generics.is_empty() {
118            PathArguments::None
119        } else {
120            PathArguments::AngleBracketed(AngleBracketedGenericArguments {
121                colon2_token: None,
122                lt_token: syn::Token![<](Span::call_site()),
123                args: trait_generics
124                    .iter()
125                    .map(|param| match param {
126                        syn::GenericParam::Lifetime(lt) => {
127                            GenericArgument::Lifetime(lt.lifetime.clone())
128                        }
129                        syn::GenericParam::Type(t) => {
130                            GenericArgument::Type(syn::Type::Path(TypePath {
131                                qself: None,
132                                path: PathSegment {
133                                    ident: t.ident.clone(),
134                                    arguments: Default::default(),
135                                }
136                                .into(),
137                            }))
138                        }
139                        syn::GenericParam::Const(c) => {
140                            GenericArgument::Type(syn::Type::Path(TypePath {
141                                qself: None,
142                                path: PathSegment {
143                                    ident: c.ident.clone(),
144                                    arguments: Default::default(),
145                                }
146                                .into(),
147                            }))
148                        }
149                    })
150                    .collect(),
151                gt_token: syn::Token![>](Span::call_site()),
152            })
153        };
154        trait_impl.trait_ = Some((
155            None,
156            PathSegment {
157                ident: args.trait_name,
158                arguments: trait_generic_arguments,
159            }
160            .into(),
161            syn::token::For(Span::call_site()),
162        ));
163
164        trait_impl.items = functions
165            .iter()
166            .map(|signature| {
167                let function_name = signature.ident.clone();
168                let function_params = signature.inputs.iter().map(|arg| match arg {
169                    syn::FnArg::Receiver(_) => quote!(self),
170                    syn::FnArg::Typed(t) => t.pat.to_token_stream(),
171                });
172
173                syn::ImplItem::Fn(ImplItemFn {
174                    attrs: Vec::new(),
175                    vis: syn::Visibility::Inherited,
176                    defaultness: None,
177                    sig: signature.clone(),
178                    block: syn::parse_quote!({
179                        Self::#function_name(#(#function_params,)*)
180                    }),
181                })
182            })
183            .collect();
184
185        trait_impl
186    };
187
188    quote!(
189        #input
190
191        #trait_definition
192
193        #trait_impl
194    )
195}