thiserror_nostd_notrait_impl/
expand.rs

1use crate::ast::{Enum, Field, Input, Struct};
2use crate::attr::Trait;
3use crate::generics::InferredBounds;
4#[cfg(feature = "std")]
5use crate::span::MemberSpan;
6extern crate alloc;
7use alloc::collections::BTreeSet as Set;
8use proc_macro2::TokenStream;
9#[cfg(feature = "std")]
10use quote::quote_spanned;
11use quote::{format_ident, quote, ToTokens};
12#[cfg(feature = "std")]
13use syn::Token;
14use syn::{DeriveInput, GenericArgument, Member, PathArguments, Result, Type};
15
16pub fn derive(input: &DeriveInput) -> TokenStream {
17    match try_expand(input) {
18        Ok(expanded) => expanded,
19        // If there are invalid attributes in the input, expand to an Error impl
20        // anyway to minimize spurious knock-on errors in other code that uses
21        // this type as an Error.
22        Err(error) => fallback(input, error),
23    }
24}
25
26fn try_expand(input: &DeriveInput) -> Result<TokenStream> {
27    let input = Input::from_syn(input)?;
28    input.validate()?;
29    Ok(match input {
30        Input::Struct(input) => impl_struct(input),
31        Input::Enum(input) => impl_enum(input),
32    })
33}
34
35fn fallback(input: &DeriveInput, error: syn::Error) -> TokenStream {
36    let ty = &input.ident;
37    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
38
39    let error = error.to_compile_error();
40
41    quote! {
42        #error
43
44        #[allow(unused_qualifications)]
45        impl #impl_generics std::error::Error for #ty #ty_generics #where_clause
46        where
47            // Work around trivial bounds being unstable.
48            // https://github.com/rust-lang/rust/issues/48214
49            for<'workaround> #ty #ty_generics: ::core::fmt::Debug,
50        {}
51
52        #[allow(unused_qualifications)]
53        impl #impl_generics ::core::fmt::Display for #ty #ty_generics #where_clause {
54            fn fmt(&self, __formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
55                ::core::unreachable!()
56            }
57        }
58    }
59}
60
/// Expands the derive for a struct.
///
/// Emits:
/// - a `Display` impl when `attrs.transparent` is set (delegating to the only
///   field) or `attrs.display` is present;
/// - a `From` impl when `input.from_field()` returns a field;
/// - with the `std` feature only, a `std::error::Error` impl. Its `source`
///   method is generated for a transparent struct or one with a source field.
///   Its `provide` method is generated when a backtrace field exists.
///
/// The `Display` and `Error` impls get where-clauses augmented with bounds
/// inferred for fields whose type contains a generic parameter. The `From`
/// impl reuses the type's own where-clause unchanged.
fn impl_struct(input: Struct) -> TokenStream {
    let ty = &input.ident;
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
    // Bounds the `Error` impl needs on generic field types. Filled while
    // building `source` below and completed just before the impl is emitted.
    #[cfg(feature = "std")]
    let mut error_inferred_bounds = InferredBounds::new();

    // Body of `Error::source`, or `None` when the struct has no source.
    #[cfg(feature = "std")]
    let source_body = if let Some(transparent_attr) = &input.attrs.transparent {
        // Transparent: return the inner field's own `source()`, not the field
        // itself. NOTE(review): indexing `fields[0]` assumes validation
        // guaranteed exactly one field for transparent structs; confirm in
        // `validate`.
        let only_field = &input.fields[0];
        if only_field.contains_generic {
            error_inferred_bounds.insert(only_field.ty, quote!(std::error::Error));
        }
        let member = &only_field.member;
        // Spanned at the attribute so type errors in this call point there.
        Some(quote_spanned! {transparent_attr.span=>
            std::error::Error::source(self.#member.as_dyn_error())
        })
    } else if let Some(source_field) = input.source_field() {
        let source = &source_field.member;
        if source_field.contains_generic {
            let ty = unoptional_type(source_field.ty);
            error_inferred_bounds.insert(ty, quote!(std::error::Error + 'static));
        }
        // An `Option<_>` source is unwrapped with `.as_ref()?`, so a `None`
        // field makes `source()` return `None`.
        let asref = if type_is_option(source_field.ty) {
            Some(quote_spanned!(source.member_span()=> .as_ref()?))
        } else {
            None
        };
        let dyn_error = quote_spanned! {source_field.source_span()=>
            self.#source #asref.as_dyn_error()
        };
        Some(quote! {
            ::core::option::Option::Some(#dyn_error)
        })
    } else {
        None
    };
    #[cfg(feature = "std")]
    let source_method = source_body.map(|body| {
        quote! {
            fn source(&self) -> ::core::option::Option<&(dyn std::error::Error + 'static)> {
                use thiserror_nostd_notrait::__private::AsDynError as _;
                #body
            }
        }
    });

    // `Error::provide`, generated only when there is a backtrace field. If the
    // struct also has a source, the request is first forwarded to the source
    // via the `ThiserrorProvide` helper. Then this struct's backtrace is
    // provided, unless the backtrace field is the source field itself.
    #[cfg(feature = "std")]
    let provide_method = input.backtrace_field().map(|backtrace_field| {
        let request = quote!(request);
        let backtrace = &backtrace_field.member;
        let body = if let Some(source_field) = input.source_field() {
            let source = &source_field.member;
            let source_provide = if type_is_option(source_field.ty) {
                quote_spanned! {source.member_span()=>
                    if let ::core::option::Option::Some(source) = &self.#source {
                        source.thiserror_provide(#request);
                    }
                }
            } else {
                quote_spanned! {source.member_span()=>
                    self.#source.thiserror_provide(#request);
                }
            };
            // Same field for source and backtrace: forwarding above suffices.
            let self_provide = if source == backtrace {
                None
            } else if type_is_option(backtrace_field.ty) {
                Some(quote! {
                    if let ::core::option::Option::Some(backtrace) = &self.#backtrace {
                        #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                    }
                })
            } else {
                Some(quote! {
                    #request.provide_ref::<std::backtrace::Backtrace>(&self.#backtrace);
                })
            };
            quote! {
                use thiserror_nostd_notrait::__private::ThiserrorProvide as _;
                #source_provide
                #self_provide
            }
        } else if type_is_option(backtrace_field.ty) {
            quote! {
                if let ::core::option::Option::Some(backtrace) = &self.#backtrace {
                    #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                }
            }
        } else {
            quote! {
                #request.provide_ref::<std::backtrace::Backtrace>(&self.#backtrace);
            }
        };
        quote! {
            fn provide<'_request>(&'_request self, #request: &mut std::error::Request<'_request>) {
                #body
            }
        }
    });

    // (field index, trait) pairs the display output requires. Bounds are only
    // emitted for fields whose type contains a generic parameter.
    let mut display_implied_bounds = Set::new();
    let display_body = if input.attrs.transparent.is_some() {
        // Transparent: delegate `Display` to the only field.
        let only_field = &input.fields[0].member;
        display_implied_bounds.insert((0, Trait::Display));
        Some(quote! {
            ::core::fmt::Display::fmt(&self.#only_field, __formatter)
        })
    } else if let Some(display) = &input.attrs.display {
        display_implied_bounds = display.implied_bounds.clone();
        let use_as_display = use_as_display(display.has_bonus_display);
        // Destructure `self` so the display tokens can name fields directly.
        // Unused bindings are allowed because not every field need appear.
        let pat = fields_pat(&input.fields);
        Some(quote! {
            #use_as_display
            #[allow(unused_variables, deprecated)]
            let Self #pat = self;
            #display
        })
    } else {
        None
    };
    let display_impl = display_body.map(|body| {
        let mut display_inferred_bounds = InferredBounds::new();
        for (field, bound) in display_implied_bounds {
            let field = &input.fields[field];
            if field.contains_generic {
                display_inferred_bounds.insert(field.ty, bound);
            }
        }
        let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics);
        quote! {
            #[allow(unused_qualifications)]
            impl #impl_generics ::core::fmt::Display for #ty #ty_generics #display_where_clause {
                #[allow(clippy::used_underscore_binding)]
                fn fmt(&self, __formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
                    #body
                }
            }
        }
    });

    // `From<T>`, where `T` is the from-field type with any `Option` stripped.
    // A distinct backtrace field, if any, is initialized by `from_initializer`.
    let from_impl = input.from_field().map(|from_field| {
        let backtrace_field = input.distinct_backtrace_field();
        let from = unoptional_type(from_field.ty);
        let body = from_initializer(from_field, backtrace_field);
        quote! {
            #[allow(unused_qualifications)]
            impl #impl_generics ::core::convert::From<#from> for #ty #ty_generics #where_clause {
                #[allow(deprecated)]
                fn from(source: #from) -> Self {
                    #ty #body
                }
            }
        }
    });

    // `Error` has `Debug + Display` supertraits. For generic types, state them
    // as bounds on `Self` so the impl only applies where they hold.
    #[cfg(feature = "std")]
    if input.generics.type_params().next().is_some() {
        let self_token = <Token![Self]>::default();
        error_inferred_bounds.insert(self_token, Trait::Debug);
        error_inferred_bounds.insert(self_token, Trait::Display);
    }
    #[cfg(feature = "std")]
    let error_where_clause = error_inferred_bounds.augment_where_clause(input.generics);

    #[cfg(feature = "std")]
    quote! {
        #[allow(unused_qualifications)]
        impl #impl_generics std::error::Error for #ty #ty_generics #error_where_clause {
            #source_method
            #provide_method
        }
        #display_impl
        #from_impl
    }
    // Without `std` no `Error` impl is generated; only `Display` and `From`.
    #[cfg(not(feature = "std"))]
    quote! {
        #display_impl
        #from_impl
    }
}
240
/// Expands the derive for an enum.
///
/// Emits:
/// - a `Display` impl (when `input.has_display()`) matching on every variant;
/// - one `From` impl per variant that has a from field;
/// - with the `std` feature only, a `std::error::Error` impl. Its `source`
///   method is generated when `input.has_source()`, and its `provide` method
///   when `input.has_backtrace()`.
///
/// As in `impl_struct`, the `Display` and `Error` impls get where-clauses
/// augmented with bounds inferred for generic field types. The `From` impls
/// reuse the enum's own where-clause.
fn impl_enum(input: Enum) -> TokenStream {
    let ty = &input.ident;
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
    // Bounds the `Error` impl needs on generic field types. Note that the
    // `arms` iterators below are lazy: their closures insert into this set
    // only when `quote!` consumes them.
    #[cfg(feature = "std")]
    let mut error_inferred_bounds = InferredBounds::new();

    // `Error::source` as a match over all variants:
    // - transparent variants forward to the inner field's own `source()`;
    // - variants with a source field return it (`Option` unwrapped via
    //   `.as_ref()?`);
    // - all other variants return `None`.
    // Brace patterns are used for every variant; `Variant { 0: x, .. }` is
    // also valid Rust for tuple variants.
    #[cfg(feature = "std")]
    let source_method = if input.has_source() {
        let arms = input.variants.iter().map(|variant| {
            let ident = &variant.ident;
            if let Some(transparent_attr) = &variant.attrs.transparent {
                // NOTE(review): assumes validation guaranteed exactly one
                // field for transparent variants; confirm in `validate`.
                let only_field = &variant.fields[0];
                if only_field.contains_generic {
                    error_inferred_bounds.insert(only_field.ty, quote!(std::error::Error));
                }
                let member = &only_field.member;
                let source = quote_spanned! {transparent_attr.span=>
                    std::error::Error::source(transparent.as_dyn_error())
                };
                quote! {
                    #ty::#ident {#member: transparent} => #source,
                }
            } else if let Some(source_field) = variant.source_field() {
                let source = &source_field.member;
                if source_field.contains_generic {
                    let ty = unoptional_type(source_field.ty);
                    error_inferred_bounds.insert(ty, quote!(std::error::Error + 'static));
                }
                let asref = if type_is_option(source_field.ty) {
                    Some(quote_spanned!(source.member_span()=> .as_ref()?))
                } else {
                    None
                };
                let varsource = quote!(source);
                let dyn_error = quote_spanned! {source_field.source_span()=>
                    #varsource #asref.as_dyn_error()
                };
                quote! {
                    #ty::#ident {#source: #varsource, ..} => ::core::option::Option::Some(#dyn_error),
                }
            } else {
                quote! {
                    #ty::#ident {..} => ::core::option::Option::None,
                }
            }
        });
        Some(quote! {
            fn source(&self) -> ::core::option::Option<&(dyn std::error::Error + 'static)> {
                use thiserror_nostd_notrait::__private::AsDynError as _;
                #[allow(deprecated)]
                match self {
                    #(#arms)*
                }
            }
        })
    } else {
        None
    };

    // `Error::provide` as a match over all variants. The arms are:
    // 1. backtrace + source, where the backtrace field has no `backtrace`
    //    attribute: forward to the source, then provide the own backtrace;
    // 2. backtrace field *is* the source field: only forward to the source;
    // 3. any other variant with a backtrace field: provide it directly;
    // 4. no backtrace field: do nothing.
    // NOTE(review): unlike `impl_struct`, a variant whose distinct backtrace
    // field carries an explicit `backtrace` attribute reaches arm 3 and does
    // not forward `provide` to its source. Confirm this asymmetry is intended.
    #[cfg(feature = "std")]
    let provide_method = if input.has_backtrace() {
        let request = quote!(request);
        let arms = input.variants.iter().map(|variant| {
            let ident = &variant.ident;
            match (variant.backtrace_field(), variant.source_field()) {
                (Some(backtrace_field), Some(source_field))
                    if backtrace_field.attrs.backtrace.is_none() =>
                {
                    let backtrace = &backtrace_field.member;
                    let source = &source_field.member;
                    let varsource = quote!(source);
                    let source_provide = if type_is_option(source_field.ty) {
                        quote_spanned! {source.member_span()=>
                            if let ::core::option::Option::Some(source) = #varsource {
                                source.thiserror_provide(#request);
                            }
                        }
                    } else {
                        quote_spanned! {source.member_span()=>
                            #varsource.thiserror_provide(#request);
                        }
                    };
                    let self_provide = if type_is_option(backtrace_field.ty) {
                        quote! {
                            if let ::core::option::Option::Some(backtrace) = backtrace {
                                #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                            }
                        }
                    } else {
                        quote! {
                            #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                        }
                    };
                    quote! {
                        #ty::#ident {
                            #backtrace: backtrace,
                            #source: #varsource,
                            ..
                        } => {
                            use thiserror_nostd_notrait::__private::ThiserrorProvide as _;
                            #source_provide
                            #self_provide
                        }
                    }
                }
                (Some(backtrace_field), Some(source_field))
                    if backtrace_field.member == source_field.member =>
                {
                    let backtrace = &backtrace_field.member;
                    let varsource = quote!(source);
                    let source_provide = if type_is_option(source_field.ty) {
                        quote_spanned! {backtrace.member_span()=>
                            if let ::core::option::Option::Some(source) = #varsource {
                                source.thiserror_provide(#request);
                            }
                        }
                    } else {
                        quote_spanned! {backtrace.member_span()=>
                            #varsource.thiserror_provide(#request);
                        }
                    };
                    quote! {
                        #ty::#ident {#backtrace: #varsource, ..} => {
                            use thiserror_nostd_notrait::__private::ThiserrorProvide as _;
                            #source_provide
                        }
                    }
                }
                (Some(backtrace_field), _) => {
                    let backtrace = &backtrace_field.member;
                    let body = if type_is_option(backtrace_field.ty) {
                        quote! {
                            if let ::core::option::Option::Some(backtrace) = backtrace {
                                #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                            }
                        }
                    } else {
                        quote! {
                            #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                        }
                    };
                    quote! {
                        #ty::#ident {#backtrace: backtrace, ..} => {
                            #body
                        }
                    }
                }
                (None, _) => quote! {
                    #ty::#ident {..} => {}
                },
            }
        });
        Some(quote! {
            fn provide<'_request>(&'_request self, #request: &mut std::error::Request<'_request>) {
                #[allow(deprecated)]
                match self {
                    #(#arms)*
                }
            }
        })
    } else {
        None
    };

    let display_impl = if input.has_display() {
        let mut display_inferred_bounds = InferredBounds::new();
        // The `AsDisplay` helper import is emitted once for the whole match
        // if any variant's display needs it.
        let has_bonus_display = input.variants.iter().any(|v| {
            v.attrs
                .display
                .as_ref()
                .map_or(false, |display| display.has_bonus_display)
        });
        let use_as_display = use_as_display(has_bonus_display);
        // For a variant-less enum, match on `*self` so the empty match is
        // exhaustive. `match self {}` on a reference is not accepted as such.
        let void_deref = if input.variants.is_empty() {
            Some(quote!(*))
        } else {
            None
        };
        let arms = input.variants.iter().map(|variant| {
            let mut display_implied_bounds = Set::new();
            let display = match &variant.attrs.display {
                Some(display) => {
                    display_implied_bounds = display.implied_bounds.clone();
                    display.to_token_stream()
                }
                None => {
                    // No display attribute: delegate to the first field, bound
                    // under the same name `fields_pat` uses (`_N` for tuple
                    // fields). NOTE(review): presumably only transparent
                    // variants reach here after validation; confirm.
                    let only_field = match &variant.fields[0].member {
                        Member::Named(ident) => ident.clone(),
                        Member::Unnamed(index) => format_ident!("_{}", index),
                    };
                    display_implied_bounds.insert((0, Trait::Display));
                    quote!(::core::fmt::Display::fmt(#only_field, __formatter))
                }
            };
            for (field, bound) in display_implied_bounds {
                let field = &variant.fields[field];
                if field.contains_generic {
                    display_inferred_bounds.insert(field.ty, bound);
                }
            }
            let ident = &variant.ident;
            let pat = fields_pat(&variant.fields);
            quote! {
                #ty::#ident #pat => #display
            }
        });
        // Collect eagerly: the closure above fills `display_inferred_bounds`,
        // which must be complete before the where-clause is computed.
        let arms = arms.collect::<Vec<_>>();
        let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics);
        Some(quote! {
            #[allow(unused_qualifications)]
            impl #impl_generics ::core::fmt::Display for #ty #ty_generics #display_where_clause {
                fn fmt(&self, __formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
                    #use_as_display
                    #[allow(unused_variables, deprecated, clippy::used_underscore_binding)]
                    match #void_deref self {
                        #(#arms,)*
                    }
                }
            }
        })
    } else {
        None
    };

    // One `From<T>` impl per variant with a from field (`T` with any `Option`
    // stripped). A distinct backtrace field is initialized by
    // `from_initializer`.
    let from_impls = input.variants.iter().filter_map(|variant| {
        let from_field = variant.from_field()?;
        let backtrace_field = variant.distinct_backtrace_field();
        let variant = &variant.ident;
        let from = unoptional_type(from_field.ty);
        let body = from_initializer(from_field, backtrace_field);
        Some(quote! {
            #[allow(unused_qualifications)]
            impl #impl_generics ::core::convert::From<#from> for #ty #ty_generics #where_clause {
                #[allow(deprecated)]
                fn from(source: #from) -> Self {
                    #ty::#variant #body
                }
            }
        })
    });

    // `Error` has `Debug + Display` supertraits. For generic types, state them
    // as bounds on `Self` so the impl only applies where they hold.
    #[cfg(feature = "std")]
    if input.generics.type_params().next().is_some() {
        let self_token = <Token![Self]>::default();
        error_inferred_bounds.insert(self_token, Trait::Debug);
        error_inferred_bounds.insert(self_token, Trait::Display);
    }
    #[cfg(feature = "std")]
    let error_where_clause = error_inferred_bounds.augment_where_clause(input.generics);

    #[cfg(feature = "std")]
    quote! {
        #[allow(unused_qualifications)]
        impl #impl_generics std::error::Error for #ty #ty_generics #error_where_clause {
            #source_method
            #provide_method
        }
        #display_impl
        #(#from_impls)*
    }
    // Without `std` no `Error` impl is generated; only `Display` and `From`.
    #[cfg(not(feature = "std"))]
    quote! {
        #display_impl
        #(#from_impls)*
    }
}
507
508fn fields_pat(fields: &[Field]) -> TokenStream {
509    let mut members = fields.iter().map(|field| &field.member).peekable();
510    match members.peek() {
511        Some(Member::Named(_)) => quote!({ #(#members),* }),
512        Some(Member::Unnamed(_)) => {
513            let vars = members.map(|member| match member {
514                Member::Unnamed(member) => format_ident!("_{}", member),
515                Member::Named(_) => unreachable!(),
516            });
517            quote!((#(#vars),*))
518        }
519        None => quote!({}),
520    }
521}
522
523fn use_as_display(needs_as_display: bool) -> Option<TokenStream> {
524    if needs_as_display {
525        Some(quote! {
526            use thiserror_nostd_notrait::__private::AsDisplay as _;
527        })
528    } else {
529        None
530    }
531}
532
533fn from_initializer(from_field: &Field, backtrace_field: Option<&Field>) -> TokenStream {
534    let from_member = &from_field.member;
535    let some_source = if type_is_option(from_field.ty) {
536        quote!(::core::option::Option::Some(source))
537    } else {
538        quote!(source)
539    };
540    let backtrace = backtrace_field.map(|backtrace_field| {
541        let backtrace_member = &backtrace_field.member;
542        if type_is_option(backtrace_field.ty) {
543            quote! {
544                #backtrace_member: ::core::option::Option::Some(std::backtrace::Backtrace::capture()),
545            }
546        } else {
547            quote! {
548                #backtrace_member: ::core::convert::From::from(std::backtrace::Backtrace::capture()),
549            }
550        }
551    });
552    quote!({
553        #from_member: #some_source,
554        #backtrace
555    })
556}
557
558fn type_is_option(ty: &Type) -> bool {
559    type_parameter_of_option(ty).is_some()
560}
561
562fn unoptional_type(ty: &Type) -> TokenStream {
563    let unoptional = type_parameter_of_option(ty).unwrap_or(ty);
564    quote!(#unoptional)
565}
566
567fn type_parameter_of_option(ty: &Type) -> Option<&Type> {
568    let path = match ty {
569        Type::Path(ty) => &ty.path,
570        _ => return None,
571    };
572
573    let last = path.segments.last().unwrap();
574    if last.ident != "Option" {
575        return None;
576    }
577
578    let bracketed = match &last.arguments {
579        PathArguments::AngleBracketed(bracketed) => bracketed,
580        _ => return None,
581    };
582
583    if bracketed.args.len() != 1 {
584        return None;
585    }
586
587    match &bracketed.args[0] {
588        GenericArgument::Type(arg) => Some(arg),
589        _ => None,
590    }
591}