wherror_impl/
expand.rs

1use crate::ast::{Enum, Field, Input, Struct};
2use crate::attr::Trait;
3use crate::fallback;
4use crate::generics::InferredBounds;
5use crate::unraw::MemberUnraw;
6use proc_macro2::{Ident, Span, TokenStream};
7use quote::{format_ident, quote, quote_spanned, ToTokens};
8use std::collections::BTreeSet as Set;
9use syn::{DeriveInput, GenericArgument, PathArguments, Result, Token, Type};
10
11pub fn derive(input: &DeriveInput) -> TokenStream {
12    match try_expand(input) {
13        Ok(expanded) => expanded,
14        // If there are invalid attributes in the input, expand to an Error impl
15        // anyway to minimize spurious secondary errors in other code that uses
16        // this type as an Error.
17        Err(error) => fallback::expand(input, error),
18    }
19}
20
21fn try_expand(input: &DeriveInput) -> Result<TokenStream> {
22    let input = Input::from_syn(input)?;
23    input.validate()?;
24    Ok(match input {
25        Input::Struct(input) => impl_struct(input),
26        Input::Enum(input) => impl_enum(input),
27    })
28}
29
30fn impl_struct(input: Struct) -> TokenStream {
31    let ty = call_site_ident(&input.ident);
32    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
33    let mut error_inferred_bounds = InferredBounds::new();
34
35    let source_body = if let Some(transparent_attr) = &input.attrs.transparent {
36        let only_field = &input.fields[0];
37        if only_field.contains_generic {
38            error_inferred_bounds.insert(only_field.ty, quote!(::wherror::__private::Error));
39        }
40        let member = &only_field.member;
41        Some(quote_spanned! {transparent_attr.span=>
42            ::wherror::__private::Error::source(self.#member.as_dyn_error())
43        })
44    } else if let Some(source_field) = input.source_field() {
45        let source = &source_field.member;
46        if source_field.contains_generic {
47            let ty = unoptional_type(source_field.ty);
48            error_inferred_bounds.insert(ty, quote!(::wherror::__private::Error + 'static));
49        }
50        let asref = if type_is_option(source_field.ty) {
51            Some(quote_spanned!(source.span()=> .as_ref()?))
52        } else {
53            None
54        };
55        let dyn_error = quote_spanned! {source_field.source_span()=>
56            self.#source #asref.as_dyn_error()
57        };
58        Some(quote! {
59            ::core::option::Option::Some(#dyn_error)
60        })
61    } else {
62        None
63    };
64    let source_method = source_body.map(|body| {
65        quote! {
66            fn source(&self) -> ::core::option::Option<&(dyn ::wherror::__private::Error + 'static)> {
67                use ::wherror::__private::AsDynError as _;
68                #body
69            }
70        }
71    });
72
73    let provide_method = input.backtrace_field().map(|backtrace_field| {
74        let request = quote!(request);
75        let backtrace = &backtrace_field.member;
76        let body = if let Some(source_field) = input.source_field() {
77            let source = &source_field.member;
78            let source_provide = if type_is_option(source_field.ty) {
79                quote_spanned! {source.span()=>
80                    if let ::core::option::Option::Some(source) = &self.#source {
81                        source.thiserror_provide(#request);
82                    }
83                }
84            } else {
85                quote_spanned! {source.span()=>
86                    self.#source.thiserror_provide(#request);
87                }
88            };
89            let self_provide = if source == backtrace {
90                None
91            } else if type_is_option(backtrace_field.ty) {
92                Some(quote! {
93                    if let ::core::option::Option::Some(backtrace) = &self.#backtrace {
94                        #request.provide_ref::<::wherror::__private::Backtrace>(backtrace);
95                    }
96                })
97            } else {
98                Some(quote! {
99                    #request.provide_ref::<::wherror::__private::Backtrace>(&self.#backtrace);
100                })
101            };
102            quote! {
103                use ::wherror::__private::ThiserrorProvide as _;
104                #source_provide
105                #self_provide
106            }
107        } else if type_is_option(backtrace_field.ty) {
108            quote! {
109                if let ::core::option::Option::Some(backtrace) = &self.#backtrace {
110                    #request.provide_ref::<::wherror::__private::Backtrace>(backtrace);
111                }
112            }
113        } else {
114            quote! {
115                #request.provide_ref::<::wherror::__private::Backtrace>(&self.#backtrace);
116            }
117        };
118        quote! {
119            fn provide<'_request>(&'_request self, #request: &mut ::core::error::Request<'_request>) {
120                #body
121            }
122        }
123    });
124
125    let mut display_implied_bounds = Set::new();
126    let display_body = if input.attrs.transparent.is_some() {
127        let only_field = &input.fields[0].member;
128        display_implied_bounds.insert((0, Trait::Display));
129        Some(quote! {
130            ::core::fmt::Display::fmt(&self.#only_field, __formatter)
131        })
132    } else if let Some(display) = &input.attrs.display {
133        display_implied_bounds.clone_from(&display.implied_bounds);
134        let use_as_display = use_as_display(display.has_bonus_display);
135        let pat = fields_pat(&input.fields);
136        Some(quote! {
137            #use_as_display
138            #[allow(unused_variables, deprecated)]
139            let Self #pat = self;
140            #display
141        })
142    } else {
143        None
144    };
145    let display_impl = display_body.map(|body| {
146        let mut display_inferred_bounds = InferredBounds::new();
147        for (field, bound) in display_implied_bounds {
148            let field = &input.fields[field];
149            if field.contains_generic {
150                display_inferred_bounds.insert(field.ty, bound);
151            }
152        }
153        let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics);
154        quote! {
155            #[allow(unused_qualifications)]
156            #[automatically_derived]
157            impl #impl_generics ::core::fmt::Display for #ty #ty_generics #display_where_clause {
158                #[allow(clippy::used_underscore_binding)]
159                fn fmt(&self, __formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
160                    #body
161                }
162            }
163        }
164    });
165
166    let from_impl = input.from_field().map(|from_field| {
167        let span = from_field.attrs.from.unwrap().span;
168        let backtrace_field = input.distinct_backtrace_field();
169        let from = unoptional_type(from_field.ty);
170        let track_caller = input.location_field().map(|_| quote!(#[track_caller]));
171        let source_var = Ident::new("source", span);
172        let body = from_initializer(
173            from_field,
174            backtrace_field,
175            &source_var,
176            input.location_field(),
177        );
178        let from_function = quote! {
179            #track_caller
180            fn from(#source_var: #from) -> Self {
181                #ty #body
182            }
183        };
184        let from_impl = quote_spanned! {span=>
185            #[automatically_derived]
186            impl #impl_generics ::core::convert::From<#from> for #ty #ty_generics #where_clause {
187                #from_function
188            }
189        };
190        Some(quote! {
191            #[allow(
192                deprecated,
193                unused_qualifications,
194                clippy::elidable_lifetime_names,
195                clippy::needless_lifetimes,
196            )]
197            #from_impl
198        })
199    });
200
201    let location_impl = input.location_field().map(|location_field| {
202        let location = &location_field.member;
203        let body = if type_is_option(location_field.ty) {
204            quote! {
205                self.#location
206            }
207        } else {
208            quote! {
209                Some(self.#location)
210            }
211        };
212        quote! {
213            #[allow(unused_qualifications)]
214            #[automatically_derived]
215            impl #impl_generics #ty #ty_generics #where_clause {
216                pub fn location(&self) -> Option<&'static ::core::panic::Location<'static>> {
217                    #body
218                }
219            }
220        }
221    });
222
223    if input.generics.type_params().next().is_some() {
224        let self_token = <Token![Self]>::default();
225        error_inferred_bounds.insert(self_token, Trait::Debug);
226        error_inferred_bounds.insert(self_token, Trait::Display);
227    }
228    let error_where_clause = error_inferred_bounds.augment_where_clause(input.generics);
229
230    quote! {
231        #[allow(unused_qualifications)]
232        #[automatically_derived]
233        impl #impl_generics ::wherror::__private::Error for #ty #ty_generics #error_where_clause {
234            #source_method
235            #provide_method
236        }
237        #display_impl
238        #from_impl
239        #location_impl
240    }
241}
242
243fn impl_enum(input: Enum) -> TokenStream {
244    let ty = call_site_ident(&input.ident);
245    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
246    let mut error_inferred_bounds = InferredBounds::new();
247
248    let source_method = if input.has_source() {
249        let arms = input.variants.iter().map(|variant| {
250            let ident = &variant.ident;
251            if let Some(transparent_attr) = &variant.attrs.transparent {
252                let only_field = &variant.fields[0];
253                if only_field.contains_generic {
254                    error_inferred_bounds.insert(only_field.ty, quote!(::wherror::__private::Error));
255                }
256                let member = &only_field.member;
257                let source = quote_spanned! {transparent_attr.span=>
258                    ::wherror::__private::Error::source(transparent.as_dyn_error())
259                };
260                quote! {
261                    #ty::#ident {#member: transparent} => #source,
262                }
263            } else if let Some(source_field) = variant.source_field() {
264                let source = &source_field.member;
265                if source_field.contains_generic {
266                    let ty = unoptional_type(source_field.ty);
267                    error_inferred_bounds.insert(ty, quote!(::wherror::__private::Error + 'static));
268                }
269                let asref = if type_is_option(source_field.ty) {
270                    Some(quote_spanned!(source.span()=> .as_ref()?))
271                } else {
272                    None
273                };
274                let varsource = quote!(source);
275                let dyn_error = quote_spanned! {source_field.source_span()=>
276                    #varsource #asref.as_dyn_error()
277                };
278                quote! {
279                    #ty::#ident {#source: #varsource, ..} => ::core::option::Option::Some(#dyn_error),
280                }
281            } else {
282                quote! {
283                    #ty::#ident {..} => ::core::option::Option::None,
284                }
285            }
286        });
287        Some(quote! {
288            fn source(&self) -> ::core::option::Option<&(dyn ::wherror::__private::Error + 'static)> {
289                use ::wherror::__private::AsDynError as _;
290                #[allow(deprecated)]
291                match self {
292                    #(#arms)*
293                }
294            }
295        })
296    } else {
297        None
298    };
299
300    let provide_method = if input.has_backtrace() {
301        let request = quote!(request);
302        let arms = input.variants.iter().map(|variant| {
303            let ident = &variant.ident;
304            match (variant.backtrace_field(), variant.source_field()) {
305                (Some(backtrace_field), Some(source_field))
306                    if backtrace_field.attrs.backtrace.is_none() =>
307                {
308                    let backtrace = &backtrace_field.member;
309                    let source = &source_field.member;
310                    let varsource = quote!(source);
311                    let source_provide = if type_is_option(source_field.ty) {
312                        quote_spanned! {source.span()=>
313                            if let ::core::option::Option::Some(source) = #varsource {
314                                source.thiserror_provide(#request);
315                            }
316                        }
317                    } else {
318                        quote_spanned! {source.span()=>
319                            #varsource.thiserror_provide(#request);
320                        }
321                    };
322                    let self_provide = if type_is_option(backtrace_field.ty) {
323                        quote! {
324                            if let ::core::option::Option::Some(backtrace) = backtrace {
325                                #request.provide_ref::<::wherror::__private::Backtrace>(backtrace);
326                            }
327                        }
328                    } else {
329                        quote! {
330                            #request.provide_ref::<::wherror::__private::Backtrace>(backtrace);
331                        }
332                    };
333                    quote! {
334                        #ty::#ident {
335                            #backtrace: backtrace,
336                            #source: #varsource,
337                            ..
338                        } => {
339                            use ::wherror::__private::ThiserrorProvide as _;
340                            #source_provide
341                            #self_provide
342                        }
343                    }
344                }
345                (Some(backtrace_field), Some(source_field))
346                    if backtrace_field.member == source_field.member =>
347                {
348                    let backtrace = &backtrace_field.member;
349                    let varsource = quote!(source);
350                    let source_provide = if type_is_option(source_field.ty) {
351                        quote_spanned! {backtrace.span()=>
352                            if let ::core::option::Option::Some(source) = #varsource {
353                                source.thiserror_provide(#request);
354                            }
355                        }
356                    } else {
357                        quote_spanned! {backtrace.span()=>
358                            #varsource.thiserror_provide(#request);
359                        }
360                    };
361                    quote! {
362                        #ty::#ident {#backtrace: #varsource, ..} => {
363                            use ::wherror::__private::ThiserrorProvide as _;
364                            #source_provide
365                        }
366                    }
367                }
368                (Some(backtrace_field), _) => {
369                    let backtrace = &backtrace_field.member;
370                    let body = if type_is_option(backtrace_field.ty) {
371                        quote! {
372                            if let ::core::option::Option::Some(backtrace) = backtrace {
373                                #request.provide_ref::<::wherror::__private::Backtrace>(backtrace);
374                            }
375                        }
376                    } else {
377                        quote! {
378                            #request.provide_ref::<::wherror::__private::Backtrace>(backtrace);
379                        }
380                    };
381                    quote! {
382                        #ty::#ident {#backtrace: backtrace, ..} => {
383                            #body
384                        }
385                    }
386                }
387                (None, _) => quote! {
388                    #ty::#ident {..} => {}
389                },
390            }
391        });
392        Some(quote! {
393            fn provide<'_request>(&'_request self, #request: &mut ::core::error::Request<'_request>) {
394                #[allow(deprecated)]
395                match self {
396                    #(#arms)*
397                }
398            }
399        })
400    } else {
401        None
402    };
403
404    let display_impl = if input.has_display() {
405        let mut display_inferred_bounds = InferredBounds::new();
406        let has_bonus_display = input.variants.iter().any(|v| {
407            v.attrs
408                .display
409                .as_ref()
410                .map_or(false, |display| display.has_bonus_display)
411        });
412        let use_as_display = use_as_display(has_bonus_display);
413        let void_deref = if input.variants.is_empty() {
414            Some(quote!(*))
415        } else {
416            None
417        };
418        let arms = input.variants.iter().map(|variant| {
419            let mut display_implied_bounds = Set::new();
420            let display = if let Some(display) = &variant.attrs.display {
421                display_implied_bounds.clone_from(&display.implied_bounds);
422                display.to_token_stream()
423            } else if let Some(fmt) = &variant.attrs.fmt {
424                let fmt_path = &fmt.path;
425                let vars = variant.fields.iter().map(|field| match &field.member {
426                    MemberUnraw::Named(ident) => ident.to_local(),
427                    MemberUnraw::Unnamed(index) => format_ident!("_{}", index),
428                });
429                quote!(#fmt_path(#(#vars,)* __formatter))
430            } else {
431                let only_field = match &variant.fields[0].member {
432                    MemberUnraw::Named(ident) => ident.to_local(),
433                    MemberUnraw::Unnamed(index) => format_ident!("_{}", index),
434                };
435                display_implied_bounds.insert((0, Trait::Display));
436                quote!(::core::fmt::Display::fmt(#only_field, __formatter))
437            };
438            for (field, bound) in display_implied_bounds {
439                let field = &variant.fields[field];
440                if field.contains_generic {
441                    display_inferred_bounds.insert(field.ty, bound);
442                }
443            }
444            let ident = &variant.ident;
445            let pat = fields_pat(&variant.fields);
446            quote! {
447                #ty::#ident #pat => #display
448            }
449        });
450        let arms = arms.collect::<Vec<_>>();
451        let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics);
452        Some(quote! {
453            #[allow(unused_qualifications)]
454            #[automatically_derived]
455            impl #impl_generics ::core::fmt::Display for #ty #ty_generics #display_where_clause {
456                fn fmt(&self, __formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
457                    #use_as_display
458                    #[allow(unused_variables, deprecated, clippy::used_underscore_binding)]
459                    match #void_deref self {
460                        #(#arms,)*
461                    }
462                }
463            }
464        })
465    } else {
466        None
467    };
468
469    let from_impls = input.variants.iter().filter_map(|variant| {
470        let from_field = variant.from_field()?;
471        let span = from_field.attrs.from.unwrap().span;
472        let backtrace_field = variant.distinct_backtrace_field();
473        let location_field = variant.location_field();
474        let variant = &variant.ident;
475        let from = unoptional_type(from_field.ty);
476        let source_var = Ident::new("source", span);
477        let body = from_initializer(from_field, backtrace_field, &source_var, location_field);
478        let track_caller = location_field.map(|_| quote!(#[track_caller]));
479        let from_function = quote! {
480            #track_caller
481            fn from(#source_var: #from) -> Self {
482                #ty::#variant #body
483            }
484        };
485        let from_impl = quote_spanned! {span=>
486            #[automatically_derived]
487            impl #impl_generics ::core::convert::From<#from> for #ty #ty_generics #where_clause {
488                #from_function
489            }
490        };
491        Some(quote! {
492            #[allow(
493                deprecated,
494                unused_qualifications,
495                clippy::elidable_lifetime_names,
496                clippy::needless_lifetimes,
497            )]
498            #from_impl
499        })
500    });
501
502    let location_impl = if input.has_location() {
503        let arms = input.variants.iter().map(|variant| {
504            let ident = &variant.ident;
505            if let Some(location_field) = variant.location_field() {
506                let location = &location_field.member;
507                let var_location = quote!(location);
508                let body = if type_is_option(location_field.ty) {
509                    quote! {
510                        #var_location
511                    }
512                } else {
513                    quote! {
514                        Some(#var_location)
515                    }
516                };
517                quote! {
518                    #ty::#ident {#location: #var_location, ..} => #body,
519                }
520            } else {
521                quote! {
522                    #ty::#ident {..} => None,
523                }
524            }
525        });
526        Some(quote! {
527            #[allow(unused_qualifications)]
528            #[automatically_derived]
529            impl #impl_generics #ty #ty_generics #where_clause {
530                pub fn location(&self) -> Option<&'static ::core::panic::Location<'static>> {
531                    #[allow(deprecated)]
532                    match self {
533                        #(#arms)*
534                    }
535                }
536            }
537        })
538    } else {
539        None
540    };
541
542    if input.generics.type_params().next().is_some() {
543        let self_token = <Token![Self]>::default();
544        error_inferred_bounds.insert(self_token, Trait::Debug);
545        error_inferred_bounds.insert(self_token, Trait::Display);
546    }
547    let error_where_clause = error_inferred_bounds.augment_where_clause(input.generics);
548
549    quote! {
550        #[allow(unused_qualifications)]
551        #[automatically_derived]
552        impl #impl_generics ::wherror::__private::Error for #ty #ty_generics #error_where_clause {
553            #source_method
554            #provide_method
555        }
556        #display_impl
557        #(#from_impls)*
558        #location_impl
559    }
560}
561
562// Create an ident with which we can expand `impl Trait for #ident {}` on a
563// deprecated type without triggering deprecation warning on the generated impl.
564pub(crate) fn call_site_ident(ident: &Ident) -> Ident {
565    let mut ident = ident.clone();
566    ident.set_span(ident.span().resolved_at(Span::call_site()));
567    ident
568}
569
570fn fields_pat(fields: &[Field]) -> TokenStream {
571    let mut members = fields.iter().map(|field| &field.member).peekable();
572    match members.peek() {
573        Some(MemberUnraw::Named(_)) => quote!({ #(#members),* }),
574        Some(MemberUnraw::Unnamed(_)) => {
575            let vars = members.map(|member| match member {
576                MemberUnraw::Unnamed(index) => format_ident!("_{}", index),
577                MemberUnraw::Named(_) => unreachable!(),
578            });
579            quote!((#(#vars),*))
580        }
581        None => quote!({}),
582    }
583}
584
585fn use_as_display(needs_as_display: bool) -> Option<TokenStream> {
586    if needs_as_display {
587        Some(quote! {
588            use ::wherror::__private::AsDisplay as _;
589        })
590    } else {
591        None
592    }
593}
594
595fn from_initializer(
596    from_field: &Field,
597    backtrace_field: Option<&Field>,
598    source_var: &Ident,
599    location_field: Option<&Field>,
600) -> TokenStream {
601    let from_member = &from_field.member;
602    let some_source = if type_is_option(from_field.ty) {
603        quote!(::core::option::Option::Some(#source_var))
604    } else {
605        quote!(#source_var)
606    };
607    let backtrace = backtrace_field.map(|backtrace_field| {
608        let backtrace_member = &backtrace_field.member;
609        if type_is_option(backtrace_field.ty) {
610            quote! {
611                #backtrace_member: ::core::option::Option::Some(::wherror::__private::Backtrace::capture()),
612            }
613        } else {
614            quote! {
615                #backtrace_member: ::core::convert::From::from(::wherror::__private::Backtrace::capture()),
616            }
617        }
618    });
619    let location = location_field.map(|location_field| {
620        let location_member = &location_field.member;
621
622        if type_is_option(location_field.ty) {
623            quote! {
624                #location_member: ::core::option::Option::Some(::core::panic::Location::caller()),
625            }
626        } else {
627            quote! {
628                #location_member: ::core::convert::From::from(::core::panic::Location::caller()),
629            }
630        }
631    });
632    quote!({
633        #from_member: #some_source,
634        #backtrace
635        #location
636    })
637}
638
639fn type_is_option(ty: &Type) -> bool {
640    type_parameter_of_option(ty).is_some()
641}
642
643fn unoptional_type(ty: &Type) -> TokenStream {
644    let unoptional = type_parameter_of_option(ty).unwrap_or(ty);
645    quote!(#unoptional)
646}
647
648fn type_parameter_of_option(ty: &Type) -> Option<&Type> {
649    let path = match ty {
650        Type::Path(ty) => &ty.path,
651        _ => return None,
652    };
653
654    let last = path.segments.last().unwrap();
655    if last.ident != "Option" {
656        return None;
657    }
658
659    let bracketed = match &last.arguments {
660        PathArguments::AngleBracketed(bracketed) => bracketed,
661        _ => return None,
662    };
663
664    if bracketed.args.len() != 1 {
665        return None;
666    }
667
668    match &bracketed.args[0] {
669        GenericArgument::Type(arg) => Some(arg),
670        _ => None,
671    }
672}