// therror_impl/expand.rs

1use crate::ast::{Enum, Field, Input, Struct};
2use crate::attr::Trait;
3use crate::generics::InferredBounds;
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote, quote_spanned, ToTokens};
6use std::collections::BTreeSet as Set;
7use syn::spanned::Spanned;
8use syn::{
9    Data, DeriveInput, GenericArgument, Member, PathArguments, Result, Token, Type, Visibility,
10};
11
12pub fn derive(node: &DeriveInput) -> Result<TokenStream> {
13    let input = Input::from_syn(node)?;
14    input.validate()?;
15    Ok(match input {
16        Input::Struct(input) => impl_struct(input),
17        Input::Enum(input) => impl_enum(input),
18    })
19}
20
21fn impl_struct(input: Struct) -> TokenStream {
22    let ty = &input.ident;
23    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
24    let mut error_inferred_bounds = InferredBounds::new();
25
26    let source_body = if input.attrs.transparent.is_some() {
27        let only_field = &input.fields[0];
28        if only_field.contains_generic {
29            error_inferred_bounds.insert(only_field.ty, quote!(std::error::Error));
30        }
31        let member = &only_field.member;
32        Some(quote! {
33            std::error::Error::source(self.#member.as_dyn_error())
34        })
35    } else if let Some(source_field) = input.source_field() {
36        let source = &source_field.member;
37        if source_field.contains_generic {
38            let ty = unoptional_type(source_field.ty);
39            error_inferred_bounds.insert(ty, quote!(std::error::Error + 'static));
40        }
41        let asref = if type_is_option(source_field.ty) {
42            Some(quote_spanned!(source.span()=> .as_ref()?))
43        } else {
44            None
45        };
46        let dyn_error = quote_spanned!(source.span()=> self.#source #asref.as_dyn_error());
47        Some(quote! {
48            std::option::Option::Some(#dyn_error)
49        })
50    } else {
51        None
52    };
53    let source_method = source_body.map(|body| {
54        quote! {
55            fn source(&self) -> std::option::Option<&(dyn std::error::Error + 'static)> {
56                use therror::__private::AsDynError;
57                #body
58            }
59        }
60    });
61
62    let provide_method = input.backtrace_field().map(|backtrace_field| {
63        let request = quote!(request);
64        let backtrace = &backtrace_field.member;
65        let body = if let Some(source_field) = input.source_field() {
66            let source = &source_field.member;
67            let source_provide = if type_is_option(source_field.ty) {
68                quote_spanned! {source.span()=>
69                    if let std::option::Option::Some(source) = &self.#source {
70                        source.thiserror_provide(#request);
71                    }
72                }
73            } else {
74                quote_spanned! {source.span()=>
75                    self.#source.thiserror_provide(#request);
76                }
77            };
78            let self_provide = if source == backtrace {
79                None
80            } else if type_is_option(backtrace_field.ty) {
81                Some(quote! {
82                    if let std::option::Option::Some(backtrace) = &self.#backtrace {
83                        #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
84                    }
85                })
86            } else {
87                Some(quote! {
88                    #request.provide_ref::<std::backtrace::Backtrace>(&self.#backtrace);
89                })
90            };
91            quote! {
92                use therror::__private::ThiserrorProvide;
93                #source_provide
94                #self_provide
95            }
96        } else if type_is_option(backtrace_field.ty) {
97            quote! {
98                if let std::option::Option::Some(backtrace) = &self.#backtrace {
99                    #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
100                }
101            }
102        } else {
103            quote! {
104                #request.provide_ref::<std::backtrace::Backtrace>(&self.#backtrace);
105            }
106        };
107        quote! {
108            fn provide<'_request>(&'_request self, #request: &mut std::error::Request<'_request>) {
109                #body
110            }
111        }
112    });
113
114    let mut display_implied_bounds = Set::new();
115    let display_body = if input.attrs.transparent.is_some() {
116        let only_field = &input.fields[0].member;
117        display_implied_bounds.insert((0, Trait::Display));
118        Some(quote! {
119            std::fmt::Display::fmt(&self.#only_field, __formatter)
120        })
121    } else if let Some(display) = &input.attrs.display {
122        display_implied_bounds = display.implied_bounds.clone();
123        let use_as_display = use_as_display(display.has_bonus_display);
124        let pat = fields_pat(&input.fields);
125        Some(quote! {
126            #use_as_display
127            #[allow(unused_variables, deprecated)]
128            let Self #pat = self;
129            #display
130        })
131    } else {
132        None
133    };
134    let display_impl = display_body.map(|body| {
135        let mut display_inferred_bounds = InferredBounds::new();
136        for (field, bound) in display_implied_bounds {
137            let field = &input.fields[field];
138            if field.contains_generic {
139                display_inferred_bounds.insert(field.ty, bound);
140            }
141        }
142        let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics);
143        quote! {
144            #[allow(unused_qualifications)]
145            impl #impl_generics std::fmt::Display for #ty #ty_generics #display_where_clause {
146                #[allow(clippy::used_underscore_binding)]
147                fn fmt(&self, __formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
148                    #body
149                }
150            }
151        }
152    });
153
154    let from_impl = input.from_field().map(|from_field| {
155        let backtrace_field = input.distinct_backtrace_field();
156        let from = unoptional_type(from_field.ty);
157        let body = from_initializer(from_field, backtrace_field);
158        quote! {
159            #[allow(unused_qualifications)]
160            impl #impl_generics std::convert::From<#from> for #ty #ty_generics #where_clause {
161                #[allow(deprecated)]
162                fn from(source: #from) -> Self {
163                    #ty #body
164                }
165            }
166        }
167    });
168
169    let error_trait = spanned_error_trait(input.original);
170    if input.generics.type_params().next().is_some() {
171        let self_token = <Token![Self]>::default();
172        error_inferred_bounds.insert(self_token, Trait::Debug);
173        error_inferred_bounds.insert(self_token, Trait::Display);
174    }
175    let error_where_clause = error_inferred_bounds.augment_where_clause(input.generics);
176
177    quote! {
178        #[allow(unused_qualifications)]
179        impl #impl_generics #error_trait for #ty #ty_generics #error_where_clause {
180            #source_method
181            #provide_method
182        }
183        #display_impl
184        #from_impl
185    }
186}
187
188fn impl_enum(input: Enum) -> TokenStream {
189    let ty = &input.ident;
190    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
191    let mut error_inferred_bounds = InferredBounds::new();
192
193    let source_method = if input.has_source() {
194        let arms = input.variants.iter().map(|variant| {
195            let ident = &variant.ident;
196            if variant.attrs.transparent.is_some() {
197                let only_field = &variant.fields[0];
198                if only_field.contains_generic {
199                    error_inferred_bounds.insert(only_field.ty, quote!(std::error::Error));
200                }
201                let member = &only_field.member;
202                let source = quote!(std::error::Error::source(transparent.as_dyn_error()));
203                quote! {
204                    #ty::#ident {#member: transparent} => #source,
205                }
206            } else if let Some(source_field) = variant.source_field() {
207                let source = &source_field.member;
208                if source_field.contains_generic {
209                    let ty = unoptional_type(source_field.ty);
210                    error_inferred_bounds.insert(ty, quote!(std::error::Error + 'static));
211                }
212                let asref = if type_is_option(source_field.ty) {
213                    Some(quote_spanned!(source.span()=> .as_ref()?))
214                } else {
215                    None
216                };
217                let varsource = quote!(source);
218                let dyn_error = quote_spanned!(source.span()=> #varsource #asref.as_dyn_error());
219                quote! {
220                    #ty::#ident {#source: #varsource, ..} => std::option::Option::Some(#dyn_error),
221                }
222            } else {
223                quote! {
224                    #ty::#ident {..} => std::option::Option::None,
225                }
226            }
227        });
228        Some(quote! {
229            fn source(&self) -> std::option::Option<&(dyn std::error::Error + 'static)> {
230                use therror::__private::AsDynError;
231                #[allow(deprecated)]
232                match self {
233                    #(#arms)*
234                }
235            }
236        })
237    } else {
238        None
239    };
240
241    let provide_method = if input.has_backtrace() {
242        let request = quote!(request);
243        let arms = input.variants.iter().map(|variant| {
244            let ident = &variant.ident;
245            match (variant.backtrace_field(), variant.source_field()) {
246                (Some(backtrace_field), Some(source_field))
247                    if backtrace_field.attrs.backtrace.is_none() =>
248                {
249                    let backtrace = &backtrace_field.member;
250                    let source = &source_field.member;
251                    let varsource = quote!(source);
252                    let source_provide = if type_is_option(source_field.ty) {
253                        quote_spanned! {source.span()=>
254                            if let std::option::Option::Some(source) = #varsource {
255                                source.thiserror_provide(#request);
256                            }
257                        }
258                    } else {
259                        quote_spanned! {source.span()=>
260                            #varsource.thiserror_provide(#request);
261                        }
262                    };
263                    let self_provide = if type_is_option(backtrace_field.ty) {
264                        quote! {
265                            if let std::option::Option::Some(backtrace) = backtrace {
266                                #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
267                            }
268                        }
269                    } else {
270                        quote! {
271                            #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
272                        }
273                    };
274                    quote! {
275                        #ty::#ident {
276                            #backtrace: backtrace,
277                            #source: #varsource,
278                            ..
279                        } => {
280                            use therror::__private::ThiserrorProvide;
281                            #source_provide
282                            #self_provide
283                        }
284                    }
285                }
286                (Some(backtrace_field), Some(source_field))
287                    if backtrace_field.member == source_field.member =>
288                {
289                    let backtrace = &backtrace_field.member;
290                    let varsource = quote!(source);
291                    let source_provide = if type_is_option(source_field.ty) {
292                        quote_spanned! {backtrace.span()=>
293                            if let std::option::Option::Some(source) = #varsource {
294                                source.thiserror_provide(#request);
295                            }
296                        }
297                    } else {
298                        quote_spanned! {backtrace.span()=>
299                            #varsource.thiserror_provide(#request);
300                        }
301                    };
302                    quote! {
303                        #ty::#ident {#backtrace: #varsource, ..} => {
304                            use therror::__private::ThiserrorProvide;
305                            #source_provide
306                        }
307                    }
308                }
309                (Some(backtrace_field), _) => {
310                    let backtrace = &backtrace_field.member;
311                    let body = if type_is_option(backtrace_field.ty) {
312                        quote! {
313                            if let std::option::Option::Some(backtrace) = backtrace {
314                                #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
315                            }
316                        }
317                    } else {
318                        quote! {
319                            #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
320                        }
321                    };
322                    quote! {
323                        #ty::#ident {#backtrace: backtrace, ..} => {
324                            #body
325                        }
326                    }
327                }
328                (None, _) => quote! {
329                    #ty::#ident {..} => {}
330                },
331            }
332        });
333        Some(quote! {
334            fn provide<'_request>(&'_request self, #request: &mut std::error::Request<'_request>) {
335                #[allow(deprecated)]
336                match self {
337                    #(#arms)*
338                }
339            }
340        })
341    } else {
342        None
343    };
344
345    let display_impl = if input.has_display() {
346        let mut display_inferred_bounds = InferredBounds::new();
347        let has_bonus_display = input.variants.iter().any(|v| {
348            v.attrs
349                .display
350                .as_ref()
351                .map_or(false, |display| display.has_bonus_display)
352        });
353        let use_as_display = use_as_display(has_bonus_display);
354        let void_deref = if input.variants.is_empty() {
355            Some(quote!(*))
356        } else {
357            None
358        };
359        let arms = input.variants.iter().map(|variant| {
360            let mut display_implied_bounds = Set::new();
361            let display = match &variant.attrs.display {
362                Some(display) => {
363                    display_implied_bounds = display.implied_bounds.clone();
364                    display.to_token_stream()
365                }
366                None => {
367                    let only_field = match &variant.fields[0].member {
368                        Member::Named(ident) => ident.clone(),
369                        Member::Unnamed(index) => format_ident!("_{}", index),
370                    };
371                    display_implied_bounds.insert((0, Trait::Display));
372                    quote!(std::fmt::Display::fmt(#only_field, __formatter))
373                }
374            };
375            for (field, bound) in display_implied_bounds {
376                let field = &variant.fields[field];
377                if field.contains_generic {
378                    display_inferred_bounds.insert(field.ty, bound);
379                }
380            }
381            let ident = &variant.ident;
382            let pat = fields_pat(&variant.fields);
383            quote! {
384                #ty::#ident #pat => #display
385            }
386        });
387        let arms = arms.collect::<Vec<_>>();
388        let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics);
389        Some(quote! {
390            #[allow(unused_qualifications)]
391            impl #impl_generics std::fmt::Display for #ty #ty_generics #display_where_clause {
392                fn fmt(&self, __formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
393                    #use_as_display
394                    #[allow(unused_variables, deprecated, clippy::used_underscore_binding)]
395                    match #void_deref self {
396                        #(#arms,)*
397                    }
398                }
399            }
400        })
401    } else {
402        None
403    };
404
405    let from_impls = input.variants.iter().filter_map(|variant| {
406        let from_field = variant.from_field()?;
407        let backtrace_field = variant.distinct_backtrace_field();
408        let variant = &variant.ident;
409        let from = unoptional_type(from_field.ty);
410        let body = from_initializer(from_field, backtrace_field);
411        Some(quote! {
412            #[allow(unused_qualifications)]
413            impl #impl_generics std::convert::From<#from> for #ty #ty_generics #where_clause {
414                #[allow(deprecated)]
415                fn from(source: #from) -> Self {
416                    #ty::#variant #body
417                }
418            }
419        })
420    });
421
422    let error_trait = spanned_error_trait(input.original);
423    if input.generics.type_params().next().is_some() {
424        let self_token = <Token![Self]>::default();
425        error_inferred_bounds.insert(self_token, Trait::Debug);
426        error_inferred_bounds.insert(self_token, Trait::Display);
427    }
428    let error_where_clause = error_inferred_bounds.augment_where_clause(input.generics);
429
430    quote! {
431        #[allow(unused_qualifications)]
432        impl #impl_generics #error_trait for #ty #ty_generics #error_where_clause {
433            #source_method
434            #provide_method
435        }
436        #display_impl
437        #(#from_impls)*
438    }
439}
440
441fn fields_pat(fields: &[Field]) -> TokenStream {
442    let mut members = fields.iter().map(|field| &field.member).peekable();
443    match members.peek() {
444        Some(Member::Named(_)) => quote!({ #(#members),* }),
445        Some(Member::Unnamed(_)) => {
446            let vars = members.map(|member| match member {
447                Member::Unnamed(member) => format_ident!("_{}", member),
448                Member::Named(_) => unreachable!(),
449            });
450            quote!((#(#vars),*))
451        }
452        None => quote!({}),
453    }
454}
455
456fn use_as_display(needs_as_display: bool) -> Option<TokenStream> {
457    if needs_as_display {
458        Some(quote! {
459            use therror::__private::AsDisplay as _;
460        })
461    } else {
462        None
463    }
464}
465
466fn from_initializer(from_field: &Field, backtrace_field: Option<&Field>) -> TokenStream {
467    let from_member = &from_field.member;
468    let some_source = if type_is_option(from_field.ty) {
469        quote!(std::option::Option::Some(source))
470    } else {
471        quote!(source)
472    };
473    let backtrace = backtrace_field.map(|backtrace_field| {
474        let backtrace_member = &backtrace_field.member;
475        if type_is_option(backtrace_field.ty) {
476            quote! {
477                #backtrace_member: std::option::Option::Some(std::backtrace::Backtrace::capture()),
478            }
479        } else {
480            quote! {
481                #backtrace_member: std::convert::From::from(std::backtrace::Backtrace::capture()),
482            }
483        }
484    });
485    quote!({
486        #from_member: #some_source,
487        #backtrace
488    })
489}
490
491fn type_is_option(ty: &Type) -> bool {
492    type_parameter_of_option(ty).is_some()
493}
494
495fn unoptional_type(ty: &Type) -> TokenStream {
496    let unoptional = type_parameter_of_option(ty).unwrap_or(ty);
497    quote!(#unoptional)
498}
499
500fn type_parameter_of_option(ty: &Type) -> Option<&Type> {
501    let path = match ty {
502        Type::Path(ty) => &ty.path,
503        _ => return None,
504    };
505
506    let last = path.segments.last().unwrap();
507    if last.ident != "Option" {
508        return None;
509    }
510
511    let bracketed = match &last.arguments {
512        PathArguments::AngleBracketed(bracketed) => bracketed,
513        _ => return None,
514    };
515
516    if bracketed.args.len() != 1 {
517        return None;
518    }
519
520    match &bracketed.args[0] {
521        GenericArgument::Type(arg) => Some(arg),
522        _ => None,
523    }
524}
525
526fn spanned_error_trait(input: &DeriveInput) -> TokenStream {
527    let vis_span = match &input.vis {
528        Visibility::Public(vis) => Some(vis.span),
529        Visibility::Restricted(vis) => Some(vis.pub_token.span),
530        Visibility::Inherited => None,
531    };
532    let data_span = match &input.data {
533        Data::Struct(data) => data.struct_token.span,
534        Data::Enum(data) => data.enum_token.span,
535        Data::Union(data) => data.union_token.span,
536    };
537    let first_span = vis_span.unwrap_or(data_span);
538    let last_span = input.ident.span();
539    let path = quote_spanned!(first_span=> std::error::);
540    let error = quote_spanned!(last_span=> Error);
541    quote!(#path #error)
542}