1use crate::ast::{Enum, Field, Input, Struct};
2use crate::attr::Trait;
3use crate::generics::InferredBounds;
4use crate::span::MemberSpan;
5use proc_macro2::TokenStream;
6use quote::{format_ident, quote, quote_spanned, ToTokens};
7use std::collections::BTreeSet as Set;
8use syn::{
9 Data, DeriveInput, GenericArgument, Member, PathArguments, Result, Token, Type, Visibility,
10};
11
12pub fn derive(node: &DeriveInput) -> Result<TokenStream> {
13 let input = Input::from_syn(node)?;
14 input.validate()?;
15 Ok(match input {
16 Input::Struct(input) => impl_struct(input),
17 Input::Enum(input) => impl_enum(input),
18 })
19}
20
21fn impl_struct(input: Struct) -> TokenStream {
22 let ty = &input.ident;
23 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
24 let mut error_inferred_bounds = InferredBounds::new();
25
26 let source_body = if let Some(transparent_attr) = &input.attrs.transparent {
27 let only_field = &input.fields[0];
28 if only_field.contains_generic {
29 error_inferred_bounds.insert(only_field.ty, quote!(thiserror::__private::Error));
30 }
31 let member = &only_field.member;
32 Some(quote_spanned! {transparent_attr.span=>
33 thiserror::__private::Error::source(self.#member.as_dyn_error())
34 })
35 } else if let Some(source_field) = input.source_field() {
36 let source = &source_field.member;
37 if source_field.contains_generic {
38 let ty = unoptional_type(source_field.ty);
39 error_inferred_bounds.insert(ty, quote!(thiserror::__private::Error + 'static));
40 }
41 let asref = if type_is_option(source_field.ty) {
42 Some(quote_spanned!(source.member_span()=> .as_ref()?))
43 } else {
44 None
45 };
46 let dyn_error = quote_spanned! {source_field.source_span()=>
47 self.#source #asref.as_dyn_error()
48 };
49 Some(quote! {
50 ::core::option::Option::Some(#dyn_error)
51 })
52 } else {
53 None
54 };
55 let source_method = source_body.map(|body| {
56 quote! {
57 fn source(&self) -> ::core::option::Option<&(dyn thiserror::__private::Error + 'static)> {
58 use thiserror::__private::AsDynError;
59 #body
60 }
61 }
62 });
63
64 let provide_method = input.backtrace_field().map(|backtrace_field| {
65 let request = quote!(request);
66 let backtrace = &backtrace_field.member;
67 let body = if let Some(source_field) = input.source_field() {
68 let source = &source_field.member;
69 let source_provide = if type_is_option(source_field.ty) {
70 quote_spanned! {source.member_span()=>
71 if let ::core::option::Option::Some(source) = &self.#source {
72 source.thiserror_provide(#request);
73 }
74 }
75 } else {
76 quote_spanned! {source.member_span()=>
77 self.#source.thiserror_provide(#request);
78 }
79 };
80 let self_provide = if source == backtrace {
81 None
82 } else if type_is_option(backtrace_field.ty) {
83 Some(quote! {
84 if let ::core::option::Option::Some(backtrace) = &self.#backtrace {
85 #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
86 }
87 })
88 } else {
89 Some(quote! {
90 #request.provide_ref::<std::backtrace::Backtrace>(&self.#backtrace);
91 })
92 };
93 quote! {
94 use thiserror::__private::ThiserrorProvide;
95 #source_provide
96 #self_provide
97 }
98 } else if type_is_option(backtrace_field.ty) {
99 quote! {
100 if let ::core::option::Option::Some(backtrace) = &self.#backtrace {
101 #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
102 }
103 }
104 } else {
105 quote! {
106 #request.provide_ref::<std::backtrace::Backtrace>(&self.#backtrace);
107 }
108 };
109 quote! {
110 fn provide<'_request>(&'_request self, #request: &mut std::error::Request<'_request>) {
111 #body
112 }
113 }
114 });
115
116 let mut display_implied_bounds = Set::new();
117 let display_body = if input.attrs.transparent.is_some() {
118 let only_field = &input.fields[0].member;
119 display_implied_bounds.insert((0, Trait::Display));
120 Some(quote! {
121 ::core::fmt::Display::fmt(&self.#only_field, __formatter)
122 })
123 } else if let Some(display) = &input.attrs.display {
124 display_implied_bounds = display.implied_bounds.clone();
125 let use_as_display = use_as_display(display.has_bonus_display);
126 let pat = fields_pat(&input.fields);
127 Some(quote! {
128 #use_as_display
129 #[allow(unused_variables, deprecated)]
130 let Self #pat = self;
131 #display
132 })
133 } else {
134 None
135 };
136 let display_impl = display_body.map(|body| {
137 let mut display_inferred_bounds = InferredBounds::new();
138 for (field, bound) in display_implied_bounds {
139 let field = &input.fields[field];
140 if field.contains_generic {
141 display_inferred_bounds.insert(field.ty, bound);
142 }
143 }
144 let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics);
145 quote! {
146 #[allow(unused_qualifications)]
147 impl #impl_generics ::core::fmt::Display for #ty #ty_generics #display_where_clause {
148 #[allow(clippy::used_underscore_binding)]
149 fn fmt(&self, __formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
150 #body
151 }
152 }
153 }
154 });
155
156 let from_impl = input.from_field().map(|from_field| {
157 let backtrace_field = input.distinct_backtrace_field();
158 let from = unoptional_type(from_field.ty);
159 let body = from_initializer(from_field, backtrace_field);
160 quote! {
161 #[allow(unused_qualifications)]
162 impl #impl_generics ::core::convert::From<#from> for #ty #ty_generics #where_clause {
163 #[allow(deprecated)]
164 fn from(source: #from) -> Self {
165 #ty #body
166 }
167 }
168 }
169 });
170
171 let error_trait = spanned_error_trait(input.original);
172 if input.generics.type_params().next().is_some() {
173 let self_token = <Token![Self]>::default();
174 error_inferred_bounds.insert(self_token, Trait::Debug);
175 error_inferred_bounds.insert(self_token, Trait::Display);
176 }
177 let error_where_clause = error_inferred_bounds.augment_where_clause(input.generics);
178
179 quote! {
180 #[allow(unused_qualifications)]
181 impl #impl_generics #error_trait for #ty #ty_generics #error_where_clause {
182 #source_method
183 #provide_method
184 }
185 #display_impl
186 #from_impl
187 }
188}
189
190fn impl_enum(input: Enum) -> TokenStream {
191 let ty = &input.ident;
192 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
193 let mut error_inferred_bounds = InferredBounds::new();
194
195 let source_method = if input.has_source() {
196 let arms = input.variants.iter().map(|variant| {
197 let ident = &variant.ident;
198 if let Some(transparent_attr) = &variant.attrs.transparent {
199 let only_field = &variant.fields[0];
200 if only_field.contains_generic {
201 error_inferred_bounds.insert(only_field.ty, quote!(thiserror::__private::Error));
202 }
203 let member = &only_field.member;
204 let source = quote_spanned! {transparent_attr.span=>
205 thiserror::__private::Error::source(transparent.as_dyn_error())
206 };
207 quote! {
208 #ty::#ident {#member: transparent} => #source,
209 }
210 } else if let Some(source_field) = variant.source_field() {
211 let source = &source_field.member;
212 if source_field.contains_generic {
213 let ty = unoptional_type(source_field.ty);
214 error_inferred_bounds.insert(ty, quote!(thiserror::__private::Error + 'static));
215 }
216 let asref = if type_is_option(source_field.ty) {
217 Some(quote_spanned!(source.member_span()=> .as_ref()?))
218 } else {
219 None
220 };
221 let varsource = quote!(source);
222 let dyn_error = quote_spanned! {source_field.source_span()=>
223 #varsource #asref.as_dyn_error()
224 };
225 quote! {
226 #ty::#ident {#source: #varsource, ..} => ::core::option::Option::Some(#dyn_error),
227 }
228 } else {
229 quote! {
230 #ty::#ident {..} => ::core::option::Option::None,
231 }
232 }
233 });
234 Some(quote! {
235 fn source(&self) -> ::core::option::Option<&(dyn thiserror::__private::Error + 'static)> {
236 use thiserror::__private::AsDynError;
237 #[allow(deprecated)]
238 match self {
239 #(#arms)*
240 }
241 }
242 })
243 } else {
244 None
245 };
246
247 let provide_method = if input.has_backtrace() {
248 let request = quote!(request);
249 let arms = input.variants.iter().map(|variant| {
250 let ident = &variant.ident;
251 match (variant.backtrace_field(), variant.source_field()) {
252 (Some(backtrace_field), Some(source_field))
253 if backtrace_field.attrs.backtrace.is_none() =>
254 {
255 let backtrace = &backtrace_field.member;
256 let source = &source_field.member;
257 let varsource = quote!(source);
258 let source_provide = if type_is_option(source_field.ty) {
259 quote_spanned! {source.member_span()=>
260 if let ::core::option::Option::Some(source) = #varsource {
261 source.thiserror_provide(#request);
262 }
263 }
264 } else {
265 quote_spanned! {source.member_span()=>
266 #varsource.thiserror_provide(#request);
267 }
268 };
269 let self_provide = if type_is_option(backtrace_field.ty) {
270 quote! {
271 if let ::core::option::Option::Some(backtrace) = backtrace {
272 #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
273 }
274 }
275 } else {
276 quote! {
277 #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
278 }
279 };
280 quote! {
281 #ty::#ident {
282 #backtrace: backtrace,
283 #source: #varsource,
284 ..
285 } => {
286 use thiserror::__private::ThiserrorProvide;
287 #source_provide
288 #self_provide
289 }
290 }
291 }
292 (Some(backtrace_field), Some(source_field))
293 if backtrace_field.member == source_field.member =>
294 {
295 let backtrace = &backtrace_field.member;
296 let varsource = quote!(source);
297 let source_provide = if type_is_option(source_field.ty) {
298 quote_spanned! {backtrace.member_span()=>
299 if let ::core::option::Option::Some(source) = #varsource {
300 source.thiserror_provide(#request);
301 }
302 }
303 } else {
304 quote_spanned! {backtrace.member_span()=>
305 #varsource.thiserror_provide(#request);
306 }
307 };
308 quote! {
309 #ty::#ident {#backtrace: #varsource, ..} => {
310 use thiserror::__private::ThiserrorProvide;
311 #source_provide
312 }
313 }
314 }
315 (Some(backtrace_field), _) => {
316 let backtrace = &backtrace_field.member;
317 let body = if type_is_option(backtrace_field.ty) {
318 quote! {
319 if let ::core::option::Option::Some(backtrace) = backtrace {
320 #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
321 }
322 }
323 } else {
324 quote! {
325 #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
326 }
327 };
328 quote! {
329 #ty::#ident {#backtrace: backtrace, ..} => {
330 #body
331 }
332 }
333 }
334 (None, _) => quote! {
335 #ty::#ident {..} => {}
336 },
337 }
338 });
339 Some(quote! {
340 fn provide<'_request>(&'_request self, #request: &mut std::error::Request<'_request>) {
341 #[allow(deprecated)]
342 match self {
343 #(#arms)*
344 }
345 }
346 })
347 } else {
348 None
349 };
350
351 let display_impl = if input.has_display() {
352 let mut display_inferred_bounds = InferredBounds::new();
353 let has_bonus_display = input.variants.iter().any(|v| {
354 v.attrs
355 .display
356 .as_ref()
357 .map_or(false, |display| display.has_bonus_display)
358 });
359 let use_as_display = use_as_display(has_bonus_display);
360 let void_deref = if input.variants.is_empty() {
361 Some(quote!(*))
362 } else {
363 None
364 };
365 let arms = input.variants.iter().map(|variant| {
366 let mut display_implied_bounds = Set::new();
367 let display = match &variant.attrs.display {
368 Some(display) => {
369 display_implied_bounds = display.implied_bounds.clone();
370 display.to_token_stream()
371 }
372 None => {
373 let only_field = match &variant.fields[0].member {
374 Member::Named(ident) => ident.clone(),
375 Member::Unnamed(index) => format_ident!("_{}", index),
376 };
377 display_implied_bounds.insert((0, Trait::Display));
378 quote!(::core::fmt::Display::fmt(#only_field, __formatter))
379 }
380 };
381 for (field, bound) in display_implied_bounds {
382 let field = &variant.fields[field];
383 if field.contains_generic {
384 display_inferred_bounds.insert(field.ty, bound);
385 }
386 }
387 let ident = &variant.ident;
388 let pat = fields_pat(&variant.fields);
389 quote! {
390 #ty::#ident #pat => #display
391 }
392 });
393 let arms = arms.collect::<Vec<_>>();
394 let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics);
395 Some(quote! {
396 #[allow(unused_qualifications)]
397 impl #impl_generics ::core::fmt::Display for #ty #ty_generics #display_where_clause {
398 fn fmt(&self, __formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
399 #use_as_display
400 #[allow(unused_variables, deprecated, clippy::used_underscore_binding)]
401 match #void_deref self {
402 #(#arms,)*
403 }
404 }
405 }
406 })
407 } else {
408 None
409 };
410
411 let from_impls = input.variants.iter().filter_map(|variant| {
412 let from_field = variant.from_field()?;
413 let backtrace_field = variant.distinct_backtrace_field();
414 let variant = &variant.ident;
415 let from = unoptional_type(from_field.ty);
416 let body = from_initializer(from_field, backtrace_field);
417 Some(quote! {
418 #[allow(unused_qualifications)]
419 impl #impl_generics ::core::convert::From<#from> for #ty #ty_generics #where_clause {
420 #[allow(deprecated)]
421 fn from(source: #from) -> Self {
422 #ty::#variant #body
423 }
424 }
425 })
426 });
427
428 let error_trait = spanned_error_trait(input.original);
429 if input.generics.type_params().next().is_some() {
430 let self_token = <Token![Self]>::default();
431 error_inferred_bounds.insert(self_token, Trait::Debug);
432 error_inferred_bounds.insert(self_token, Trait::Display);
433 }
434 let error_where_clause = error_inferred_bounds.augment_where_clause(input.generics);
435
436 quote! {
437 #[allow(unused_qualifications)]
438 impl #impl_generics #error_trait for #ty #ty_generics #error_where_clause {
439 #source_method
440 #provide_method
441 }
442 #display_impl
443 #(#from_impls)*
444 }
445}
446
447fn fields_pat(fields: &[Field]) -> TokenStream {
448 let mut members = fields.iter().map(|field| &field.member).peekable();
449 match members.peek() {
450 Some(Member::Named(_)) => quote!({ #(#members),* }),
451 Some(Member::Unnamed(_)) => {
452 let vars = members.map(|member| match member {
453 Member::Unnamed(member) => format_ident!("_{}", member),
454 Member::Named(_) => unreachable!(),
455 });
456 quote!((#(#vars),*))
457 }
458 None => quote!({}),
459 }
460}
461
462fn use_as_display(needs_as_display: bool) -> Option<TokenStream> {
463 if needs_as_display {
464 Some(quote! {
465 use thiserror::__private::AsDisplay as _;
466 })
467 } else {
468 None
469 }
470}
471
472fn from_initializer(from_field: &Field, backtrace_field: Option<&Field>) -> TokenStream {
473 let from_member = &from_field.member;
474 let some_source = if type_is_option(from_field.ty) {
475 quote!(::core::option::Option::Some(source))
476 } else {
477 quote!(source)
478 };
479 let backtrace = backtrace_field.map(|backtrace_field| {
480 let backtrace_member = &backtrace_field.member;
481 if type_is_option(backtrace_field.ty) {
482 quote! {
483 #backtrace_member: ::core::option::Option::Some(std::backtrace::Backtrace::capture()),
484 }
485 } else {
486 quote! {
487 #backtrace_member: ::core::convert::From::from(std::backtrace::Backtrace::capture()),
488 }
489 }
490 });
491 quote!({
492 #from_member: #some_source,
493 #backtrace
494 })
495}
496
497fn type_is_option(ty: &Type) -> bool {
498 type_parameter_of_option(ty).is_some()
499}
500
501fn unoptional_type(ty: &Type) -> TokenStream {
502 let unoptional = type_parameter_of_option(ty).unwrap_or(ty);
503 quote!(#unoptional)
504}
505
506fn type_parameter_of_option(ty: &Type) -> Option<&Type> {
507 let path = match ty {
508 Type::Path(ty) => &ty.path,
509 _ => return None,
510 };
511
512 let last = path.segments.last().unwrap();
513 if last.ident != "Option" {
514 return None;
515 }
516
517 let bracketed = match &last.arguments {
518 PathArguments::AngleBracketed(bracketed) => bracketed,
519 _ => return None,
520 };
521
522 if bracketed.args.len() != 1 {
523 return None;
524 }
525
526 match &bracketed.args[0] {
527 GenericArgument::Type(arg) => Some(arg),
528 _ => None,
529 }
530}
531
532fn spanned_error_trait(input: &DeriveInput) -> TokenStream {
533 let vis_span = match &input.vis {
534 Visibility::Public(vis) => Some(vis.span),
535 Visibility::Restricted(vis) => Some(vis.pub_token.span),
536 Visibility::Inherited => None,
537 };
538 let data_span = match &input.data {
539 Data::Struct(data) => data.struct_token.span,
540 Data::Enum(data) => data.enum_token.span,
541 Data::Union(data) => data.union_token.span,
542 };
543 let first_span = vis_span.unwrap_or(data_span);
544 let last_span = input.ident.span();
545 let path = quote_spanned!(first_span=> thiserror::__private::);
546 let error = quote_spanned!(last_span=> Error);
547 quote!(#path #error)
548}