1use proc_macro2::{Span, TokenStream};
2use quote::{quote, ToTokens};
3use syn::{
4 AngleBracketedGenericArguments, GenericArgument, Ident, ImplItemFn, PathArguments, PathSegment,
5 TypePath,
6};
7
8#[derive(Debug)]
9struct Args {
10 trait_name: Ident,
11 dyn_generics: Vec<String>,
12}
13
14impl syn::parse::Parse for Args {
15 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
16 let trait_name = input.parse()?;
17 input.parse::<syn::Token!(,)>()?;
18 input.parse::<syn::Token!(dyn)>()?;
19 input.parse::<syn::Token!(=)>()?;
20
21 let generics_group;
22 syn::bracketed!(generics_group in input);
23 let dyn_generics = generics_group
24 .parse_terminated(Ident::parse, syn::Token!(,))?
25 .into_iter()
26 .map(|ident| ident.to_string())
27 .collect();
28
29 Ok(Args {
30 trait_name,
31 dyn_generics,
32 })
33 }
34}
35
36pub fn traitify(args: TokenStream, input: TokenStream) -> TokenStream {
37 let args = syn::parse2::<Args>(args).expect("Parsing args");
38 let input = syn::parse2::<syn::ItemImpl>(input).unwrap();
39
40 let functions = input
41 .items
42 .iter()
43 .filter_map(|item| match item {
44 syn::ImplItem::Fn(function) => {
45 if matches!(function.vis, syn::Visibility::Public(_))
46 && function.sig.constness.is_none()
47 && function.sig.abi.is_none()
48 && function.sig.inputs.iter().all(|arg| match arg {
50 syn::FnArg::Receiver(_) => true,
51 syn::FnArg::Typed(t) => !args.dyn_generics.contains(&t.ty.to_token_stream().to_string()),
52 })
53 {
54 let mut function_signature = function.sig.clone();
55
56 if function.sig.receiver().is_none() {
58 let function_where_clause = function_signature.generics.make_where_clause();
59 function_where_clause
60 .predicates
61 .push(syn::parse_quote!(Self: Sized));
62 }
63 Some(function_signature)
64 } else {
65 None
66 }
67 }
68 _ => None,
69 })
70 .collect::<Vec<_>>();
71
72 let trait_generics = input
73 .generics
74 .params
75 .iter()
76 .filter(|param| match param {
77 syn::GenericParam::Type(t) => !args.dyn_generics.contains(&t.ident.to_string()),
78 syn::GenericParam::Const(t) => !args.dyn_generics.contains(&t.ident.to_string()),
79 _ => true,
80 })
81 .collect::<Vec<_>>();
82
83 let trait_definition = {
84 let trait_name = args.trait_name.clone();
85
86 let mut trait_where = input.generics.clone();
87 let trait_where =
88 trait_where
89 .make_where_clause()
90 .predicates
91 .iter()
92 .filter(|pred| match pred {
93 syn::WherePredicate::Lifetime(_) => true,
94 syn::WherePredicate::Type(t) => match &t.bounded_ty {
95 syn::Type::Path(p) => !args
96 .dyn_generics
97 .contains(&p.path.to_token_stream().to_string()),
98 _ => todo!(),
99 },
100 _ => true,
101 });
102
103 quote!(
104 pub trait #trait_name<#(#trait_generics,)*> where #(#trait_where,)* {
105 #(#functions;)*
106 }
107 )
108 };
109
110 let trait_impl = {
111 let mut trait_impl = input.clone();
113 trait_impl.items.clear();
114
115 trait_impl.attrs.clear();
116
117 let trait_generic_arguments = if trait_generics.is_empty() {
118 PathArguments::None
119 } else {
120 PathArguments::AngleBracketed(AngleBracketedGenericArguments {
121 colon2_token: None,
122 lt_token: syn::Token),
123 args: trait_generics
124 .iter()
125 .map(|param| match param {
126 syn::GenericParam::Lifetime(lt) => {
127 GenericArgument::Lifetime(lt.lifetime.clone())
128 }
129 syn::GenericParam::Type(t) => {
130 GenericArgument::Type(syn::Type::Path(TypePath {
131 qself: None,
132 path: PathSegment {
133 ident: t.ident.clone(),
134 arguments: Default::default(),
135 }
136 .into(),
137 }))
138 }
139 syn::GenericParam::Const(c) => {
140 GenericArgument::Type(syn::Type::Path(TypePath {
141 qself: None,
142 path: PathSegment {
143 ident: c.ident.clone(),
144 arguments: Default::default(),
145 }
146 .into(),
147 }))
148 }
149 })
150 .collect(),
151 gt_token: syn::Token),
152 })
153 };
154 trait_impl.trait_ = Some((
155 None,
156 PathSegment {
157 ident: args.trait_name,
158 arguments: trait_generic_arguments,
159 }
160 .into(),
161 syn::token::For(Span::call_site()),
162 ));
163
164 trait_impl.items = functions
165 .iter()
166 .map(|signature| {
167 let function_name = signature.ident.clone();
168 let function_params = signature.inputs.iter().map(|arg| match arg {
169 syn::FnArg::Receiver(_) => quote!(self),
170 syn::FnArg::Typed(t) => t.pat.to_token_stream(),
171 });
172
173 syn::ImplItem::Fn(ImplItemFn {
174 attrs: Vec::new(),
175 vis: syn::Visibility::Inherited,
176 defaultness: None,
177 sig: signature.clone(),
178 block: syn::parse_quote!({
179 Self::#function_name(#(#function_params,)*)
180 }),
181 })
182 })
183 .collect();
184
185 trait_impl
186 };
187
188 quote!(
189 #input
190
191 #trait_definition
192
193 #trait_impl
194 )
195}