typechain_macros/
lib.rs

1#![deny(missing_docs)]
2
3//! # `typechain-macros`
4//! 
//! This crate contains the procedural macros behind
//! `typechain`. They generate "chainlink" traits whose
//! methods expose the fields of a struct, `dyn` type
//! aliases for those traits, and structs implementing
//! them, so that struct fields can be accessed through
//! dynamic dispatch.
10//! 
11//! The macros in this crate use user-defined traits
12//! and structs to generate an easy-to-use chain. See
13//! the [`typechain`](https://crates.io/crates/typechain)
14//! crate for more information.
15
16extern crate proc_macro;
17
18use std::collections::{HashMap, hash_map::Entry};
19
20use parse::{ChainlinkField, ChainFieldData};
21use proc_macro::TokenStream;
22use proc_macro2::Span;
23use proc_macro_error::{proc_macro_error, emit_error, abort_if_dirty};
24use quote::{quote, ToTokens, quote_spanned};
25use syn::{Path, spanned::Spanned, Visibility};
26
27mod parse;
28
29
30/// Create a chainlink trait.
31/// 
32/// The trait will be renamed to `{{name}}Chainlink`,
33/// and the original name will be used for the
34/// associated type (dyn `{{name}}Chainlink`).
35#[proc_macro_error]
36#[proc_macro]
37pub fn chainlink(input: TokenStream) -> TokenStream {
38    let ast = syn::parse_macro_input!(input as parse::Chainlink);
39
40    let name = ast.name.clone();
41    let generics = ast.generics.clone();
42    let generics = quote! {
43        < #( #generics ),* >
44    };
45
46    let fields = ast.fields.iter().map(|f| {
47        match f {
48            ChainlinkField::Const(name, ty) => {
49                quote! {
50                    fn #name(&self) -> & #ty;
51                }
52            },
53            ChainlinkField::Mut(name, ty) => {
54                quote! {
55                    fn #name(&mut self) -> &mut #ty;
56                }
57            },
58            ChainlinkField::Static(name, ty) => {
59                quote! {
60                    fn #name(&self) -> #ty;
61                }
62            },
63            ChainlinkField::Fn(func) => {
64                let name = func.sig.ident.clone();
65                let generics = func.sig.generics.clone();
66                let inputs = func.sig.inputs.clone();
67                let output = func.sig.output.clone();
68                let where_clause = func.sig.generics.where_clause.clone();
69
70                quote! {
71                    #generics
72                    fn #name(#inputs) #output #where_clause;
73                }
74            }
75        }
76    });
77
78    let trait_name = syn::Ident::new(&format!("{}Chainlink", name), Span::call_site());
79
80    let expanded = quote! {
81        #[allow(missing_docs)]
82        pub trait #trait_name #generics {
83            #(#fields)*
84        }
85
86        #[allow(missing_docs)]
87        pub type #name #generics = dyn #trait_name #generics;
88    };
89
90    expanded.into()
91}
92
93/// Create a chain.
94#[proc_macro_error]
95#[proc_macro]
96pub fn chain(input: TokenStream) -> TokenStream {
97    let ast = syn::parse_macro_input!(input as parse::Chain);
98
99    let name = ast.name.clone();
100    let generics = ast.generics.clone();
101
102    let generics = quote! {
103        < #( #generics ),* >
104    };
105
106    let fields = ast.fields.iter().filter_map(|f| {
107        match f.field.clone() {
108            ChainFieldData::Const(vis, name, ty) => {
109                Some(quote! {
110                    #vis #name: #ty
111                })
112            },
113            ChainFieldData::Mut(name, ty) => {
114                Some(quote! {
115                    #name: #ty
116                })
117            },
118            ChainFieldData::Static(..) => {
119                None
120            }
121        }
122    });
123
124    let trait_funcs: HashMap<Path, Vec<proc_macro2::TokenStream>> = ast.fields.iter().fold(HashMap::new(), |mut map, f| {
125        let parents = f.parents.clone();
126
127        for parent in parents {
128            if let Entry::Vacant(_) = map.entry(parent.clone()) {
129                map.insert(parent.clone(), vec![]);
130            }
131
132            let tokens = match f.field.clone() {
133                ChainFieldData::Const(vis, name, ty) => {
134                    if !matches!(vis, Visibility::Inherited) {
135                        emit_error!(vis, "Chainlink fields must be of inherited visibility");
136                    }
137
138                    quote! {
139                        fn #name(&self) -> & #ty {
140                            &self.#name
141                        }
142                    }
143                },
144                ChainFieldData::Mut(name, ty) => {
145                    quote! {
146                        fn #name(&mut self) -> &mut #ty {
147                            &mut self.#name
148                        }
149                    }
150                },
151                ChainFieldData::Static(name, ty, expr) => {
152                    quote_spanned! { expr.span() =>
153                        #[allow(clippy::needless_borrow)]
154                        fn #name(&self) -> #ty {
155                            #expr
156                        }
157                    }
158                }
159            };
160
161            abort_if_dirty();
162
163            map.get_mut(&parent).unwrap().push(tokens);
164        }
165
166        map
167    });
168
169    let trait_impls = trait_funcs.iter().map(|(trait_, tokens)| {
170        let mut trait_ = trait_.clone();
171        trait_.segments.last_mut().unwrap().ident = syn::Ident::new(&format!("{}Chainlink", trait_.segments.last().unwrap().ident), trait_.span());
172
173        let tokens = tokens.clone();
174
175        quote! {
176            impl #generics #trait_ for #name #generics {
177                #(#tokens)*
178            }
179        }
180    });
181
182    let expanded = quote! {
183        pub struct #name #generics {
184            #(#fields),*
185        }
186
187        #(#trait_impls)*
188    };
189
190    expanded.into()
191}
192
193/// Import chainlink traits.
194/// 
195/// This is a helper macro for importing chainlink
196/// traits and their associated types.
197#[proc_macro_error]
198#[proc_macro]
199pub fn use_chains(input: TokenStream) -> TokenStream {
200    let paths = syn::parse_macro_input!(input as parse::UseChains);
201
202    let paths = paths.0.iter().map(|p| {
203        let mut path = p.clone();
204        path.segments.last_mut().unwrap().ident = syn::Ident::new(&format!("{}Chainlink", path.segments.last().unwrap().ident), p.span());
205
206        quote! {
207            #[allow(unused_imports)]
208            use #path;
209            #[allow(unused_imports)]
210            use #p;
211        }
212    }).collect::<Vec<_>>();
213
214    let expanded = quote! {
215        #(#paths)*
216    };
217
218    expanded.into()
219}
220
221/// Manually implement chains.
222/// 
223/// This macro will generate chain implementations
224/// manually. This is useful when you want to implement
225/// chains for a type that you don't own.
226#[proc_macro_error]
227#[proc_macro]
228pub fn impl_chains(input: TokenStream) -> TokenStream {
229    let ast = syn::parse_macro_input!(input as parse::ImplChains);
230
231    let ty = ast.ty.clone();
232    let where_clause = ast.where_clause.clone();
233
234    let mut impls: HashMap<Path, Vec<proc_macro2::TokenStream>> = HashMap::new();
235
236    for impl_ in ast.impls {
237        let tokens = impl_.func.to_token_stream();
238
239        if let Entry::Vacant(_) = impls.entry(impl_.chain.clone()) {
240            impls.insert(impl_.chain.clone(), vec![]);
241        }
242
243        impls.get_mut(&impl_.chain).unwrap().push(tokens);
244    }
245
246    let impls = impls.iter().map(|(trait_, tokens)| {
247        let mut trait_ = trait_.clone();
248        trait_.segments.last_mut().unwrap().ident = syn::Ident::new(&format!("{}Chainlink", trait_.segments.last().unwrap().ident), trait_.span());
249
250        let tokens = tokens.clone();
251
252        quote! {
253            impl #where_clause #trait_ for #ty {
254                #(#tokens)*
255            }
256        }
257    }).collect::<Vec<_>>();
258
259    let expanded = quote! {
260        #(#impls)*
261    };
262
263    expanded.into()
264}