zerror_derive/
lib.rs

1#![doc = include_str!("../README.md")]
2#![recursion_limit = "128"]
3
4extern crate proc_macro;
5#[macro_use]
6extern crate quote;
7extern crate syn;
8
9use proc_macro2::{Span, TokenStream};
10use quote::ToTokens;
11use syn::spanned::Spanned;
12use syn::{parse_macro_input, DeriveInput};
13
14use derive_util::EnumVisitor;
15
/////////////////////////////////////////// #[derive(Z)] ///////////////////////////////////////////
17
18/// Derive ZerrorCore for an error.  This assumes a type has a core, and makes the with_* methods
19/// for it.
20#[proc_macro_derive(Z, attributes())]
21pub fn derive_command_line(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
22    let input = parse_macro_input!(input as DeriveInput);
23    // `ty_name` holds the type's identifier.
24    let ty_name = input.ident;
25
26    let data = match input.data {
27        syn::Data::Struct(_) => {
28            panic!("structs are not supported");
29        }
30        syn::Data::Enum(de) => de,
31        syn::Data::Union(_) => {
32            panic!("unions are not supported");
33        }
34    };
35
36    let mut cmv = CoreMethodsVisitor {};
37    let core_methods = cmv.visit_enum(&ty_name, &data);
38    let mut dmv = DisplayMethodVisitor {};
39    let display_method = dmv.visit_enum(&ty_name, &data);
40    let mut pemv = PartialEqMethodVisitor {};
41    let partial_eq_method = pemv.visit_enum(&ty_name, &data);
42    let gen = quote! {
43        impl ::zerror::Z for #ty_name {
44            type Error = Self;
45
46            fn long_form(&self) -> String {
47                format!("{}\n", self) + &self.core().long_form()
48            }
49
50            fn with_info<X: ::std::fmt::Debug>(mut self, name: &str, value: X) -> Self::Error {
51                self.core_mut().set_info(name, value);
52                self
53            }
54
55            fn with_lazy_info<F: FnOnce() -> String>(mut self, name: &str, value: F) -> Self::Error {
56                self.core_mut().set_lazy_info(name, value);
57                self
58            }
59        }
60
61        impl ::std::fmt::Debug for #ty_name {
62            fn fmt(&self, fmt: &mut ::std::fmt::Formatter<'_>) -> Result<(), ::std::fmt::Error> {
63                <Self as ::std::fmt::Display>::fmt(self, fmt)
64            }
65        }
66
67        #core_methods
68        #display_method
69        #partial_eq_method
70    };
71    gen.into()
72}
73
//////////////////////////////////////// CoreMethodsVisitor ////////////////////////////////////////
75
/// Stateless visitor that generates the `core()`/`core_mut()` accessors for the derived enum.
struct CoreMethodsVisitor;
77
78impl EnumVisitor for CoreMethodsVisitor {
79    type Output = TokenStream;
80    type VariantOutput = TokenStream;
81
82    /// Combine the provided variants into one output.
83    fn combine_variants(
84        &mut self,
85        ty_name: &syn::Ident,
86        _de: &syn::DataEnum,
87        variants: &[Self::VariantOutput],
88    ) -> Self::Output {
89        let mut variant_sum = quote! {};
90        for v in variants {
91            let one = quote! {
92                #variant_sum
93                #v
94            };
95            variant_sum = one;
96        }
97        quote! {
98            impl #ty_name {
99                /// Get an immutable reference to this core.
100                pub fn core(&self) -> &::zerror_core::ErrorCore {
101                    match self {
102                        #variant_sum
103                    }
104                }
105
106                /// Get a mutable reference to this core.
107                pub fn core_mut(&mut self) -> &mut ::zerror_core::ErrorCore {
108                    match self {
109                        #variant_sum
110                    }
111                }
112            }
113        }
114    }
115
116    /// Visit an enum with [syn::FieldsNamed].
117    fn visit_enum_variant_named_field(
118        &mut self,
119        ty_name: &syn::Ident,
120        _de: &syn::DataEnum,
121        variant: &syn::Variant,
122        _fields: &syn::FieldsNamed,
123    ) -> Self::VariantOutput {
124        let variant = &variant.ident;
125        quote! {
126            #ty_name::#variant { core, .. } => core,
127        }
128    }
129}
130
131/////////////////////////////////////// DisplayMethodVisitor ///////////////////////////////////////
132
/// Stateless visitor that generates the `Display` impl for the derived enum.
struct DisplayMethodVisitor;
134
135impl EnumVisitor for DisplayMethodVisitor {
136    type Output = TokenStream;
137    type VariantOutput = TokenStream;
138
139    /// Combine the provided variants into one output.
140    fn combine_variants(
141        &mut self,
142        ty_name: &syn::Ident,
143        _de: &syn::DataEnum,
144        variants: &[Self::VariantOutput],
145    ) -> Self::Output {
146        let mut variant_sum = quote! {};
147        for v in variants {
148            variant_sum = quote! {
149                #variant_sum
150                #v
151            };
152        }
153        quote! {
154            impl ::std::fmt::Display for #ty_name {
155                fn fmt(&self, fmt: &mut ::std::fmt::Formatter) -> Result<(), ::std::fmt::Error> {
156                    match self {
157                        #variant_sum
158                    }
159                }
160            }
161        }
162    }
163
164    /// Visit an enum with [syn::FieldsNamed].
165    fn visit_enum_variant_named_field(
166        &mut self,
167        ty_name: &syn::Ident,
168        _de: &syn::DataEnum,
169        variant: &syn::Variant,
170        fields: &syn::FieldsNamed,
171    ) -> Self::VariantOutput {
172        let mut fields_list_quote = quote! {};
173        let mut fields_fmt_quote = quote! {};
174        let mut first_field = true;
175        for field in fields.named.iter() {
176            if field.ident == Some(syn::Ident::new("core", field.span())) {
177                continue;
178            }
179            let field_ident = &field.ident;
180            if first_field {
181                fields_list_quote = quote! {
182                    #field_ident
183                };
184            } else {
185                fields_list_quote = quote! {
186                    #fields_list_quote, #field_ident
187                };
188            }
189            let field_str = field_ident.clone().into_token_stream().to_string();
190            fields_fmt_quote = quote! {
191                #fields_fmt_quote
192                .field(#field_str, #field_ident)
193            };
194            first_field = false;
195        }
196        let variant = &variant.ident;
197        let variant_str = variant.clone().into_token_stream().to_string();
198        quote! {
199            #ty_name::#variant { core: _, #fields_list_quote } => {
200                fmt.debug_struct(#variant_str)
201                #fields_fmt_quote
202                .finish()
203            }
204        }
205    }
206}
207
208////////////////////////////////////// PartialEqMethodVisitor //////////////////////////////////////
209
/// Stateless visitor that generates the `PartialEq`/`Eq` impls for the derived enum.
struct PartialEqMethodVisitor;
211
212impl EnumVisitor for PartialEqMethodVisitor {
213    type Output = TokenStream;
214    type VariantOutput = TokenStream;
215
216    /// Combine the provided variants into one output.
217    fn combine_variants(
218        &mut self,
219        ty_name: &syn::Ident,
220        _de: &syn::DataEnum,
221        variants: &[Self::VariantOutput],
222    ) -> Self::Output {
223        let mut variant_sum = quote! {};
224        for v in variants {
225            variant_sum = quote! {
226                #variant_sum
227                #v
228            };
229        }
230        quote! {
231            impl Eq for #ty_name {}
232
233            impl PartialEq for #ty_name {
234                fn eq(&self, other: &#ty_name) -> bool {
235                    match (self, other) {
236                        #variant_sum
237                        (_, _) => { false },
238                    }
239                }
240            }
241        }
242    }
243
244    /// Visit an enum with [syn::FieldsNamed].
245    fn visit_enum_variant_named_field(
246        &mut self,
247        ty_name: &syn::Ident,
248        _de: &syn::DataEnum,
249        variant: &syn::Variant,
250        fields: &syn::FieldsNamed,
251    ) -> Self::VariantOutput {
252        let mut fields_list_lhs = quote! {};
253        let mut fields_list_rhs = quote! {};
254        let mut fields_compare = quote! {};
255        let mut num_fields = 0;
256        for field in fields.named.iter() {
257            if field.ident == Some(syn::Ident::new("core", field.span())) {
258                continue;
259            }
260            let field_ident = &field.ident;
261            let field_lhs =
262                syn::Ident::new(&format!("zerror_{}_lhs", num_fields), Span::call_site());
263            let field_rhs =
264                syn::Ident::new(&format!("zerror_{}_rhs", num_fields), Span::call_site());
265            if num_fields == 0 {
266                fields_list_lhs = quote! {
267                    #field_ident: #field_lhs
268                };
269                fields_list_rhs = quote! {
270                    #field_ident: #field_rhs
271                };
272                fields_compare = quote! {
273                    #field_lhs == #field_rhs
274                }
275            } else {
276                fields_list_lhs = quote! {
277                    #fields_list_lhs, #field_ident: #field_lhs
278                };
279                fields_list_rhs = quote! {
280                    #fields_list_rhs, #field_ident: #field_rhs
281                };
282                fields_compare = quote! {
283                    #fields_compare && #field_lhs == #field_rhs
284                }
285            }
286            num_fields += 1;
287        }
288        if num_fields == 0 {
289            fields_compare = quote! { true }
290        }
291        let variant_ident = &variant.ident;
292        quote! {
293            (#ty_name::#variant_ident { core: _, #fields_list_lhs },
294             #ty_name::#variant_ident { core: _, #fields_list_rhs }) => {
295                #fields_compare
296            },
297        }
298    }
299}