// thiserror_core_impl/expand.rs

1use crate::ast::{Enum, Field, Input, Struct};
2use crate::attr::Trait;
3use crate::generics::InferredBounds;
4use crate::span::MemberSpan;
5use proc_macro2::TokenStream;
6use quote::{format_ident, quote, quote_spanned, ToTokens};
7use std::collections::BTreeSet as Set;
8use syn::{
9    Data, DeriveInput, GenericArgument, Member, PathArguments, Result, Token, Type, Visibility,
10};
11
12pub fn derive(node: &DeriveInput) -> Result<TokenStream> {
13    let input = Input::from_syn(node)?;
14    input.validate()?;
15    Ok(match input {
16        Input::Struct(input) => impl_struct(input),
17        Input::Enum(input) => impl_enum(input),
18    })
19}
20
21fn impl_struct(input: Struct) -> TokenStream {
22    let ty = &input.ident;
23    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
24    let mut error_inferred_bounds = InferredBounds::new();
25
26    let source_body = if let Some(transparent_attr) = &input.attrs.transparent {
27        let only_field = &input.fields[0];
28        if only_field.contains_generic {
29            error_inferred_bounds.insert(only_field.ty, quote!(thiserror::__private::Error));
30        }
31        let member = &only_field.member;
32        Some(quote_spanned! {transparent_attr.span=>
33            thiserror::__private::Error::source(self.#member.as_dyn_error())
34        })
35    } else if let Some(source_field) = input.source_field() {
36        let source = &source_field.member;
37        if source_field.contains_generic {
38            let ty = unoptional_type(source_field.ty);
39            error_inferred_bounds.insert(ty, quote!(thiserror::__private::Error + 'static));
40        }
41        let asref = if type_is_option(source_field.ty) {
42            Some(quote_spanned!(source.member_span()=> .as_ref()?))
43        } else {
44            None
45        };
46        let dyn_error = quote_spanned! {source_field.source_span()=>
47            self.#source #asref.as_dyn_error()
48        };
49        Some(quote! {
50            ::core::option::Option::Some(#dyn_error)
51        })
52    } else {
53        None
54    };
55    let source_method = source_body.map(|body| {
56        quote! {
57            fn source(&self) -> ::core::option::Option<&(dyn thiserror::__private::Error + 'static)> {
58                use thiserror::__private::AsDynError;
59                #body
60            }
61        }
62    });
63
64    let provide_method = input.backtrace_field().map(|backtrace_field| {
65        let request = quote!(request);
66        let backtrace = &backtrace_field.member;
67        let body = if let Some(source_field) = input.source_field() {
68            let source = &source_field.member;
69            let source_provide = if type_is_option(source_field.ty) {
70                quote_spanned! {source.member_span()=>
71                    if let ::core::option::Option::Some(source) = &self.#source {
72                        source.thiserror_provide(#request);
73                    }
74                }
75            } else {
76                quote_spanned! {source.member_span()=>
77                    self.#source.thiserror_provide(#request);
78                }
79            };
80            let self_provide = if source == backtrace {
81                None
82            } else if type_is_option(backtrace_field.ty) {
83                Some(quote! {
84                    if let ::core::option::Option::Some(backtrace) = &self.#backtrace {
85                        #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
86                    }
87                })
88            } else {
89                Some(quote! {
90                    #request.provide_ref::<std::backtrace::Backtrace>(&self.#backtrace);
91                })
92            };
93            quote! {
94                use thiserror::__private::ThiserrorProvide;
95                #source_provide
96                #self_provide
97            }
98        } else if type_is_option(backtrace_field.ty) {
99            quote! {
100                if let ::core::option::Option::Some(backtrace) = &self.#backtrace {
101                    #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
102                }
103            }
104        } else {
105            quote! {
106                #request.provide_ref::<std::backtrace::Backtrace>(&self.#backtrace);
107            }
108        };
109        quote! {
110            fn provide<'_request>(&'_request self, #request: &mut std::error::Request<'_request>) {
111                #body
112            }
113        }
114    });
115
116    let mut display_implied_bounds = Set::new();
117    let display_body = if input.attrs.transparent.is_some() {
118        let only_field = &input.fields[0].member;
119        display_implied_bounds.insert((0, Trait::Display));
120        Some(quote! {
121            ::core::fmt::Display::fmt(&self.#only_field, __formatter)
122        })
123    } else if let Some(display) = &input.attrs.display {
124        display_implied_bounds = display.implied_bounds.clone();
125        let use_as_display = use_as_display(display.has_bonus_display);
126        let pat = fields_pat(&input.fields);
127        Some(quote! {
128            #use_as_display
129            #[allow(unused_variables, deprecated)]
130            let Self #pat = self;
131            #display
132        })
133    } else {
134        None
135    };
136    let display_impl = display_body.map(|body| {
137        let mut display_inferred_bounds = InferredBounds::new();
138        for (field, bound) in display_implied_bounds {
139            let field = &input.fields[field];
140            if field.contains_generic {
141                display_inferred_bounds.insert(field.ty, bound);
142            }
143        }
144        let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics);
145        quote! {
146            #[allow(unused_qualifications)]
147            impl #impl_generics ::core::fmt::Display for #ty #ty_generics #display_where_clause {
148                #[allow(clippy::used_underscore_binding)]
149                fn fmt(&self, __formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
150                    #body
151                }
152            }
153        }
154    });
155
156    let from_impl = input.from_field().map(|from_field| {
157        let backtrace_field = input.distinct_backtrace_field();
158        let from = unoptional_type(from_field.ty);
159        let body = from_initializer(from_field, backtrace_field);
160        quote! {
161            #[allow(unused_qualifications)]
162            impl #impl_generics ::core::convert::From<#from> for #ty #ty_generics #where_clause {
163                #[allow(deprecated)]
164                fn from(source: #from) -> Self {
165                    #ty #body
166                }
167            }
168        }
169    });
170
171    let error_trait = spanned_error_trait(input.original);
172    if input.generics.type_params().next().is_some() {
173        let self_token = <Token![Self]>::default();
174        error_inferred_bounds.insert(self_token, Trait::Debug);
175        error_inferred_bounds.insert(self_token, Trait::Display);
176    }
177    let error_where_clause = error_inferred_bounds.augment_where_clause(input.generics);
178
179    quote! {
180        #[allow(unused_qualifications)]
181        impl #impl_generics #error_trait for #ty #ty_generics #error_where_clause {
182            #source_method
183            #provide_method
184        }
185        #display_impl
186        #from_impl
187    }
188}
189
190fn impl_enum(input: Enum) -> TokenStream {
191    let ty = &input.ident;
192    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
193    let mut error_inferred_bounds = InferredBounds::new();
194
195    let source_method = if input.has_source() {
196        let arms = input.variants.iter().map(|variant| {
197            let ident = &variant.ident;
198            if let Some(transparent_attr) = &variant.attrs.transparent {
199                let only_field = &variant.fields[0];
200                if only_field.contains_generic {
201                    error_inferred_bounds.insert(only_field.ty, quote!(thiserror::__private::Error));
202                }
203                let member = &only_field.member;
204                let source = quote_spanned! {transparent_attr.span=>
205                    thiserror::__private::Error::source(transparent.as_dyn_error())
206                };
207                quote! {
208                    #ty::#ident {#member: transparent} => #source,
209                }
210            } else if let Some(source_field) = variant.source_field() {
211                let source = &source_field.member;
212                if source_field.contains_generic {
213                    let ty = unoptional_type(source_field.ty);
214                    error_inferred_bounds.insert(ty, quote!(thiserror::__private::Error + 'static));
215                }
216                let asref = if type_is_option(source_field.ty) {
217                    Some(quote_spanned!(source.member_span()=> .as_ref()?))
218                } else {
219                    None
220                };
221                let varsource = quote!(source);
222                let dyn_error = quote_spanned! {source_field.source_span()=>
223                    #varsource #asref.as_dyn_error()
224                };
225                quote! {
226                    #ty::#ident {#source: #varsource, ..} => ::core::option::Option::Some(#dyn_error),
227                }
228            } else {
229                quote! {
230                    #ty::#ident {..} => ::core::option::Option::None,
231                }
232            }
233        });
234        Some(quote! {
235            fn source(&self) -> ::core::option::Option<&(dyn thiserror::__private::Error + 'static)> {
236                use thiserror::__private::AsDynError;
237                #[allow(deprecated)]
238                match self {
239                    #(#arms)*
240                }
241            }
242        })
243    } else {
244        None
245    };
246
247    let provide_method = if input.has_backtrace() {
248        let request = quote!(request);
249        let arms = input.variants.iter().map(|variant| {
250            let ident = &variant.ident;
251            match (variant.backtrace_field(), variant.source_field()) {
252                (Some(backtrace_field), Some(source_field))
253                    if backtrace_field.attrs.backtrace.is_none() =>
254                {
255                    let backtrace = &backtrace_field.member;
256                    let source = &source_field.member;
257                    let varsource = quote!(source);
258                    let source_provide = if type_is_option(source_field.ty) {
259                        quote_spanned! {source.member_span()=>
260                            if let ::core::option::Option::Some(source) = #varsource {
261                                source.thiserror_provide(#request);
262                            }
263                        }
264                    } else {
265                        quote_spanned! {source.member_span()=>
266                            #varsource.thiserror_provide(#request);
267                        }
268                    };
269                    let self_provide = if type_is_option(backtrace_field.ty) {
270                        quote! {
271                            if let ::core::option::Option::Some(backtrace) = backtrace {
272                                #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
273                            }
274                        }
275                    } else {
276                        quote! {
277                            #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
278                        }
279                    };
280                    quote! {
281                        #ty::#ident {
282                            #backtrace: backtrace,
283                            #source: #varsource,
284                            ..
285                        } => {
286                            use thiserror::__private::ThiserrorProvide;
287                            #source_provide
288                            #self_provide
289                        }
290                    }
291                }
292                (Some(backtrace_field), Some(source_field))
293                    if backtrace_field.member == source_field.member =>
294                {
295                    let backtrace = &backtrace_field.member;
296                    let varsource = quote!(source);
297                    let source_provide = if type_is_option(source_field.ty) {
298                        quote_spanned! {backtrace.member_span()=>
299                            if let ::core::option::Option::Some(source) = #varsource {
300                                source.thiserror_provide(#request);
301                            }
302                        }
303                    } else {
304                        quote_spanned! {backtrace.member_span()=>
305                            #varsource.thiserror_provide(#request);
306                        }
307                    };
308                    quote! {
309                        #ty::#ident {#backtrace: #varsource, ..} => {
310                            use thiserror::__private::ThiserrorProvide;
311                            #source_provide
312                        }
313                    }
314                }
315                (Some(backtrace_field), _) => {
316                    let backtrace = &backtrace_field.member;
317                    let body = if type_is_option(backtrace_field.ty) {
318                        quote! {
319                            if let ::core::option::Option::Some(backtrace) = backtrace {
320                                #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
321                            }
322                        }
323                    } else {
324                        quote! {
325                            #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
326                        }
327                    };
328                    quote! {
329                        #ty::#ident {#backtrace: backtrace, ..} => {
330                            #body
331                        }
332                    }
333                }
334                (None, _) => quote! {
335                    #ty::#ident {..} => {}
336                },
337            }
338        });
339        Some(quote! {
340            fn provide<'_request>(&'_request self, #request: &mut std::error::Request<'_request>) {
341                #[allow(deprecated)]
342                match self {
343                    #(#arms)*
344                }
345            }
346        })
347    } else {
348        None
349    };
350
351    let display_impl = if input.has_display() {
352        let mut display_inferred_bounds = InferredBounds::new();
353        let has_bonus_display = input.variants.iter().any(|v| {
354            v.attrs
355                .display
356                .as_ref()
357                .map_or(false, |display| display.has_bonus_display)
358        });
359        let use_as_display = use_as_display(has_bonus_display);
360        let void_deref = if input.variants.is_empty() {
361            Some(quote!(*))
362        } else {
363            None
364        };
365        let arms = input.variants.iter().map(|variant| {
366            let mut display_implied_bounds = Set::new();
367            let display = match &variant.attrs.display {
368                Some(display) => {
369                    display_implied_bounds = display.implied_bounds.clone();
370                    display.to_token_stream()
371                }
372                None => {
373                    let only_field = match &variant.fields[0].member {
374                        Member::Named(ident) => ident.clone(),
375                        Member::Unnamed(index) => format_ident!("_{}", index),
376                    };
377                    display_implied_bounds.insert((0, Trait::Display));
378                    quote!(::core::fmt::Display::fmt(#only_field, __formatter))
379                }
380            };
381            for (field, bound) in display_implied_bounds {
382                let field = &variant.fields[field];
383                if field.contains_generic {
384                    display_inferred_bounds.insert(field.ty, bound);
385                }
386            }
387            let ident = &variant.ident;
388            let pat = fields_pat(&variant.fields);
389            quote! {
390                #ty::#ident #pat => #display
391            }
392        });
393        let arms = arms.collect::<Vec<_>>();
394        let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics);
395        Some(quote! {
396            #[allow(unused_qualifications)]
397            impl #impl_generics ::core::fmt::Display for #ty #ty_generics #display_where_clause {
398                fn fmt(&self, __formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
399                    #use_as_display
400                    #[allow(unused_variables, deprecated, clippy::used_underscore_binding)]
401                    match #void_deref self {
402                        #(#arms,)*
403                    }
404                }
405            }
406        })
407    } else {
408        None
409    };
410
411    let from_impls = input.variants.iter().filter_map(|variant| {
412        let from_field = variant.from_field()?;
413        let backtrace_field = variant.distinct_backtrace_field();
414        let variant = &variant.ident;
415        let from = unoptional_type(from_field.ty);
416        let body = from_initializer(from_field, backtrace_field);
417        Some(quote! {
418            #[allow(unused_qualifications)]
419            impl #impl_generics ::core::convert::From<#from> for #ty #ty_generics #where_clause {
420                #[allow(deprecated)]
421                fn from(source: #from) -> Self {
422                    #ty::#variant #body
423                }
424            }
425        })
426    });
427
428    let error_trait = spanned_error_trait(input.original);
429    if input.generics.type_params().next().is_some() {
430        let self_token = <Token![Self]>::default();
431        error_inferred_bounds.insert(self_token, Trait::Debug);
432        error_inferred_bounds.insert(self_token, Trait::Display);
433    }
434    let error_where_clause = error_inferred_bounds.augment_where_clause(input.generics);
435
436    quote! {
437        #[allow(unused_qualifications)]
438        impl #impl_generics #error_trait for #ty #ty_generics #error_where_clause {
439            #source_method
440            #provide_method
441        }
442        #display_impl
443        #(#from_impls)*
444    }
445}
446
447fn fields_pat(fields: &[Field]) -> TokenStream {
448    let mut members = fields.iter().map(|field| &field.member).peekable();
449    match members.peek() {
450        Some(Member::Named(_)) => quote!({ #(#members),* }),
451        Some(Member::Unnamed(_)) => {
452            let vars = members.map(|member| match member {
453                Member::Unnamed(member) => format_ident!("_{}", member),
454                Member::Named(_) => unreachable!(),
455            });
456            quote!((#(#vars),*))
457        }
458        None => quote!({}),
459    }
460}
461
462fn use_as_display(needs_as_display: bool) -> Option<TokenStream> {
463    if needs_as_display {
464        Some(quote! {
465            use thiserror::__private::AsDisplay as _;
466        })
467    } else {
468        None
469    }
470}
471
472fn from_initializer(from_field: &Field, backtrace_field: Option<&Field>) -> TokenStream {
473    let from_member = &from_field.member;
474    let some_source = if type_is_option(from_field.ty) {
475        quote!(::core::option::Option::Some(source))
476    } else {
477        quote!(source)
478    };
479    let backtrace = backtrace_field.map(|backtrace_field| {
480        let backtrace_member = &backtrace_field.member;
481        if type_is_option(backtrace_field.ty) {
482            quote! {
483                #backtrace_member: ::core::option::Option::Some(std::backtrace::Backtrace::capture()),
484            }
485        } else {
486            quote! {
487                #backtrace_member: ::core::convert::From::from(std::backtrace::Backtrace::capture()),
488            }
489        }
490    });
491    quote!({
492        #from_member: #some_source,
493        #backtrace
494    })
495}
496
497fn type_is_option(ty: &Type) -> bool {
498    type_parameter_of_option(ty).is_some()
499}
500
501fn unoptional_type(ty: &Type) -> TokenStream {
502    let unoptional = type_parameter_of_option(ty).unwrap_or(ty);
503    quote!(#unoptional)
504}
505
506fn type_parameter_of_option(ty: &Type) -> Option<&Type> {
507    let path = match ty {
508        Type::Path(ty) => &ty.path,
509        _ => return None,
510    };
511
512    let last = path.segments.last().unwrap();
513    if last.ident != "Option" {
514        return None;
515    }
516
517    let bracketed = match &last.arguments {
518        PathArguments::AngleBracketed(bracketed) => bracketed,
519        _ => return None,
520    };
521
522    if bracketed.args.len() != 1 {
523        return None;
524    }
525
526    match &bracketed.args[0] {
527        GenericArgument::Type(arg) => Some(arg),
528        _ => None,
529    }
530}
531
532fn spanned_error_trait(input: &DeriveInput) -> TokenStream {
533    let vis_span = match &input.vis {
534        Visibility::Public(vis) => Some(vis.span),
535        Visibility::Restricted(vis) => Some(vis.pub_token.span),
536        Visibility::Inherited => None,
537    };
538    let data_span = match &input.data {
539        Data::Struct(data) => data.struct_token.span,
540        Data::Enum(data) => data.enum_token.span,
541        Data::Union(data) => data.union_token.span,
542    };
543    let first_span = vis_span.unwrap_or(data_span);
544    let last_span = input.ident.span();
545    let path = quote_spanned!(first_span=> thiserror::__private::);
546    let error = quote_spanned!(last_span=> Error);
547    quote!(#path #error)
548}