1use crate::ast::{Enum, Field, Input, Struct};
2use crate::attr::Trait;
3use crate::fallback;
4use crate::generics::InferredBounds;
5use crate::unraw::MemberUnraw;
6use proc_macro2::{Ident, Span, TokenStream};
7use quote::{format_ident, quote, quote_spanned, ToTokens};
8use std::collections::BTreeSet as Set;
9use syn::{DeriveInput, GenericArgument, PathArguments, Result, Token, Type};
10
11pub fn derive(input: &DeriveInput) -> TokenStream {
12 match try_expand(input) {
13 Ok(expanded) => expanded,
14 Err(error) => fallback::expand(input, error),
18 }
19}
20
21fn try_expand(input: &DeriveInput) -> Result<TokenStream> {
22 let input = Input::from_syn(input)?;
23 input.validate()?;
24 Ok(match input {
25 Input::Struct(input) => impl_struct(input),
26 Input::Enum(input) => impl_enum(input),
27 })
28}
29
30fn impl_struct(input: Struct) -> TokenStream {
31 let ty = call_site_ident(&input.ident);
32 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
33 let mut error_inferred_bounds = InferredBounds::new();
34
35 let source_body = if let Some(transparent_attr) = &input.attrs.transparent {
36 let only_field = &input.fields[0];
37 if only_field.contains_generic {
38 error_inferred_bounds.insert(only_field.ty, quote!(::wherror::__private::Error));
39 }
40 let member = &only_field.member;
41 Some(quote_spanned! {transparent_attr.span=>
42 ::wherror::__private::Error::source(self.#member.as_dyn_error())
43 })
44 } else if let Some(source_field) = input.source_field() {
45 let source = &source_field.member;
46 if source_field.contains_generic {
47 let ty = unoptional_type(source_field.ty);
48 error_inferred_bounds.insert(ty, quote!(::wherror::__private::Error + 'static));
49 }
50 let asref = if type_is_option(source_field.ty) {
51 Some(quote_spanned!(source.span()=> .as_ref()?))
52 } else {
53 None
54 };
55 let dyn_error = quote_spanned! {source_field.source_span()=>
56 self.#source #asref.as_dyn_error()
57 };
58 Some(quote! {
59 ::core::option::Option::Some(#dyn_error)
60 })
61 } else {
62 None
63 };
64 let source_method = source_body.map(|body| {
65 quote! {
66 fn source(&self) -> ::core::option::Option<&(dyn ::wherror::__private::Error + 'static)> {
67 use ::wherror::__private::AsDynError as _;
68 #body
69 }
70 }
71 });
72
73 let provide_method = input.backtrace_field().map(|backtrace_field| {
74 let request = quote!(request);
75 let backtrace = &backtrace_field.member;
76 let body = if let Some(source_field) = input.source_field() {
77 let source = &source_field.member;
78 let source_provide = if type_is_option(source_field.ty) {
79 quote_spanned! {source.span()=>
80 if let ::core::option::Option::Some(source) = &self.#source {
81 source.thiserror_provide(#request);
82 }
83 }
84 } else {
85 quote_spanned! {source.span()=>
86 self.#source.thiserror_provide(#request);
87 }
88 };
89 let self_provide = if source == backtrace {
90 None
91 } else if type_is_option(backtrace_field.ty) {
92 Some(quote! {
93 if let ::core::option::Option::Some(backtrace) = &self.#backtrace {
94 #request.provide_ref::<::wherror::__private::Backtrace>(backtrace);
95 }
96 })
97 } else {
98 Some(quote! {
99 #request.provide_ref::<::wherror::__private::Backtrace>(&self.#backtrace);
100 })
101 };
102 quote! {
103 use ::wherror::__private::ThiserrorProvide as _;
104 #source_provide
105 #self_provide
106 }
107 } else if type_is_option(backtrace_field.ty) {
108 quote! {
109 if let ::core::option::Option::Some(backtrace) = &self.#backtrace {
110 #request.provide_ref::<::wherror::__private::Backtrace>(backtrace);
111 }
112 }
113 } else {
114 quote! {
115 #request.provide_ref::<::wherror::__private::Backtrace>(&self.#backtrace);
116 }
117 };
118 quote! {
119 fn provide<'_request>(&'_request self, #request: &mut ::core::error::Request<'_request>) {
120 #body
121 }
122 }
123 });
124
125 let mut display_implied_bounds = Set::new();
126 let display_body = if input.attrs.transparent.is_some() {
127 let only_field = &input.fields[0].member;
128 display_implied_bounds.insert((0, Trait::Display));
129 Some(quote! {
130 ::core::fmt::Display::fmt(&self.#only_field, __formatter)
131 })
132 } else if let Some(display) = &input.attrs.display {
133 display_implied_bounds.clone_from(&display.implied_bounds);
134 let use_as_display = use_as_display(display.has_bonus_display);
135 let pat = fields_pat(&input.fields);
136 Some(quote! {
137 #use_as_display
138 #[allow(unused_variables, deprecated)]
139 let Self #pat = self;
140 #display
141 })
142 } else {
143 None
144 };
145 let display_impl = display_body.map(|body| {
146 let mut display_inferred_bounds = InferredBounds::new();
147 for (field, bound) in display_implied_bounds {
148 let field = &input.fields[field];
149 if field.contains_generic {
150 display_inferred_bounds.insert(field.ty, bound);
151 }
152 }
153 let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics);
154 quote! {
155 #[allow(unused_qualifications)]
156 #[automatically_derived]
157 impl #impl_generics ::core::fmt::Display for #ty #ty_generics #display_where_clause {
158 #[allow(clippy::used_underscore_binding)]
159 fn fmt(&self, __formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
160 #body
161 }
162 }
163 }
164 });
165
166 let from_impl = input.from_field().map(|from_field| {
167 let span = from_field.attrs.from.unwrap().span;
168 let backtrace_field = input.distinct_backtrace_field();
169 let from = unoptional_type(from_field.ty);
170 let track_caller = input.location_field().map(|_| quote!(#[track_caller]));
171 let source_var = Ident::new("source", span);
172 let body = from_initializer(
173 from_field,
174 backtrace_field,
175 &source_var,
176 input.location_field(),
177 );
178 let from_function = quote! {
179 #track_caller
180 fn from(#source_var: #from) -> Self {
181 #ty #body
182 }
183 };
184 let from_impl = quote_spanned! {span=>
185 #[automatically_derived]
186 impl #impl_generics ::core::convert::From<#from> for #ty #ty_generics #where_clause {
187 #from_function
188 }
189 };
190 Some(quote! {
191 #[allow(
192 deprecated,
193 unused_qualifications,
194 clippy::elidable_lifetime_names,
195 clippy::needless_lifetimes,
196 )]
197 #from_impl
198 })
199 });
200
201 let location_impl = input.location_field().map(|location_field| {
202 let location = &location_field.member;
203 let body = if type_is_option(location_field.ty) {
204 quote! {
205 self.#location
206 }
207 } else {
208 quote! {
209 Some(self.#location)
210 }
211 };
212 quote! {
213 #[allow(unused_qualifications)]
214 #[automatically_derived]
215 impl #impl_generics #ty #ty_generics #where_clause {
216 pub fn location(&self) -> Option<&'static ::core::panic::Location<'static>> {
217 #body
218 }
219 }
220 }
221 });
222
223 if input.generics.type_params().next().is_some() {
224 let self_token = <Token![Self]>::default();
225 error_inferred_bounds.insert(self_token, Trait::Debug);
226 error_inferred_bounds.insert(self_token, Trait::Display);
227 }
228 let error_where_clause = error_inferred_bounds.augment_where_clause(input.generics);
229
230 quote! {
231 #[allow(unused_qualifications)]
232 #[automatically_derived]
233 impl #impl_generics ::wherror::__private::Error for #ty #ty_generics #error_where_clause {
234 #source_method
235 #provide_method
236 }
237 #display_impl
238 #from_impl
239 #location_impl
240 }
241}
242
243fn impl_enum(input: Enum) -> TokenStream {
244 let ty = call_site_ident(&input.ident);
245 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
246 let mut error_inferred_bounds = InferredBounds::new();
247
248 let source_method = if input.has_source() {
249 let arms = input.variants.iter().map(|variant| {
250 let ident = &variant.ident;
251 if let Some(transparent_attr) = &variant.attrs.transparent {
252 let only_field = &variant.fields[0];
253 if only_field.contains_generic {
254 error_inferred_bounds.insert(only_field.ty, quote!(::wherror::__private::Error));
255 }
256 let member = &only_field.member;
257 let source = quote_spanned! {transparent_attr.span=>
258 ::wherror::__private::Error::source(transparent.as_dyn_error())
259 };
260 quote! {
261 #ty::#ident {#member: transparent} => #source,
262 }
263 } else if let Some(source_field) = variant.source_field() {
264 let source = &source_field.member;
265 if source_field.contains_generic {
266 let ty = unoptional_type(source_field.ty);
267 error_inferred_bounds.insert(ty, quote!(::wherror::__private::Error + 'static));
268 }
269 let asref = if type_is_option(source_field.ty) {
270 Some(quote_spanned!(source.span()=> .as_ref()?))
271 } else {
272 None
273 };
274 let varsource = quote!(source);
275 let dyn_error = quote_spanned! {source_field.source_span()=>
276 #varsource #asref.as_dyn_error()
277 };
278 quote! {
279 #ty::#ident {#source: #varsource, ..} => ::core::option::Option::Some(#dyn_error),
280 }
281 } else {
282 quote! {
283 #ty::#ident {..} => ::core::option::Option::None,
284 }
285 }
286 });
287 Some(quote! {
288 fn source(&self) -> ::core::option::Option<&(dyn ::wherror::__private::Error + 'static)> {
289 use ::wherror::__private::AsDynError as _;
290 #[allow(deprecated)]
291 match self {
292 #(#arms)*
293 }
294 }
295 })
296 } else {
297 None
298 };
299
300 let provide_method = if input.has_backtrace() {
301 let request = quote!(request);
302 let arms = input.variants.iter().map(|variant| {
303 let ident = &variant.ident;
304 match (variant.backtrace_field(), variant.source_field()) {
305 (Some(backtrace_field), Some(source_field))
306 if backtrace_field.attrs.backtrace.is_none() =>
307 {
308 let backtrace = &backtrace_field.member;
309 let source = &source_field.member;
310 let varsource = quote!(source);
311 let source_provide = if type_is_option(source_field.ty) {
312 quote_spanned! {source.span()=>
313 if let ::core::option::Option::Some(source) = #varsource {
314 source.thiserror_provide(#request);
315 }
316 }
317 } else {
318 quote_spanned! {source.span()=>
319 #varsource.thiserror_provide(#request);
320 }
321 };
322 let self_provide = if type_is_option(backtrace_field.ty) {
323 quote! {
324 if let ::core::option::Option::Some(backtrace) = backtrace {
325 #request.provide_ref::<::wherror::__private::Backtrace>(backtrace);
326 }
327 }
328 } else {
329 quote! {
330 #request.provide_ref::<::wherror::__private::Backtrace>(backtrace);
331 }
332 };
333 quote! {
334 #ty::#ident {
335 #backtrace: backtrace,
336 #source: #varsource,
337 ..
338 } => {
339 use ::wherror::__private::ThiserrorProvide as _;
340 #source_provide
341 #self_provide
342 }
343 }
344 }
345 (Some(backtrace_field), Some(source_field))
346 if backtrace_field.member == source_field.member =>
347 {
348 let backtrace = &backtrace_field.member;
349 let varsource = quote!(source);
350 let source_provide = if type_is_option(source_field.ty) {
351 quote_spanned! {backtrace.span()=>
352 if let ::core::option::Option::Some(source) = #varsource {
353 source.thiserror_provide(#request);
354 }
355 }
356 } else {
357 quote_spanned! {backtrace.span()=>
358 #varsource.thiserror_provide(#request);
359 }
360 };
361 quote! {
362 #ty::#ident {#backtrace: #varsource, ..} => {
363 use ::wherror::__private::ThiserrorProvide as _;
364 #source_provide
365 }
366 }
367 }
368 (Some(backtrace_field), _) => {
369 let backtrace = &backtrace_field.member;
370 let body = if type_is_option(backtrace_field.ty) {
371 quote! {
372 if let ::core::option::Option::Some(backtrace) = backtrace {
373 #request.provide_ref::<::wherror::__private::Backtrace>(backtrace);
374 }
375 }
376 } else {
377 quote! {
378 #request.provide_ref::<::wherror::__private::Backtrace>(backtrace);
379 }
380 };
381 quote! {
382 #ty::#ident {#backtrace: backtrace, ..} => {
383 #body
384 }
385 }
386 }
387 (None, _) => quote! {
388 #ty::#ident {..} => {}
389 },
390 }
391 });
392 Some(quote! {
393 fn provide<'_request>(&'_request self, #request: &mut ::core::error::Request<'_request>) {
394 #[allow(deprecated)]
395 match self {
396 #(#arms)*
397 }
398 }
399 })
400 } else {
401 None
402 };
403
404 let display_impl = if input.has_display() {
405 let mut display_inferred_bounds = InferredBounds::new();
406 let has_bonus_display = input.variants.iter().any(|v| {
407 v.attrs
408 .display
409 .as_ref()
410 .map_or(false, |display| display.has_bonus_display)
411 });
412 let use_as_display = use_as_display(has_bonus_display);
413 let void_deref = if input.variants.is_empty() {
414 Some(quote!(*))
415 } else {
416 None
417 };
418 let arms = input.variants.iter().map(|variant| {
419 let mut display_implied_bounds = Set::new();
420 let display = if let Some(display) = &variant.attrs.display {
421 display_implied_bounds.clone_from(&display.implied_bounds);
422 display.to_token_stream()
423 } else if let Some(fmt) = &variant.attrs.fmt {
424 let fmt_path = &fmt.path;
425 let vars = variant.fields.iter().map(|field| match &field.member {
426 MemberUnraw::Named(ident) => ident.to_local(),
427 MemberUnraw::Unnamed(index) => format_ident!("_{}", index),
428 });
429 quote!(#fmt_path(#(#vars,)* __formatter))
430 } else {
431 let only_field = match &variant.fields[0].member {
432 MemberUnraw::Named(ident) => ident.to_local(),
433 MemberUnraw::Unnamed(index) => format_ident!("_{}", index),
434 };
435 display_implied_bounds.insert((0, Trait::Display));
436 quote!(::core::fmt::Display::fmt(#only_field, __formatter))
437 };
438 for (field, bound) in display_implied_bounds {
439 let field = &variant.fields[field];
440 if field.contains_generic {
441 display_inferred_bounds.insert(field.ty, bound);
442 }
443 }
444 let ident = &variant.ident;
445 let pat = fields_pat(&variant.fields);
446 quote! {
447 #ty::#ident #pat => #display
448 }
449 });
450 let arms = arms.collect::<Vec<_>>();
451 let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics);
452 Some(quote! {
453 #[allow(unused_qualifications)]
454 #[automatically_derived]
455 impl #impl_generics ::core::fmt::Display for #ty #ty_generics #display_where_clause {
456 fn fmt(&self, __formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
457 #use_as_display
458 #[allow(unused_variables, deprecated, clippy::used_underscore_binding)]
459 match #void_deref self {
460 #(#arms,)*
461 }
462 }
463 }
464 })
465 } else {
466 None
467 };
468
469 let from_impls = input.variants.iter().filter_map(|variant| {
470 let from_field = variant.from_field()?;
471 let span = from_field.attrs.from.unwrap().span;
472 let backtrace_field = variant.distinct_backtrace_field();
473 let location_field = variant.location_field();
474 let variant = &variant.ident;
475 let from = unoptional_type(from_field.ty);
476 let source_var = Ident::new("source", span);
477 let body = from_initializer(from_field, backtrace_field, &source_var, location_field);
478 let track_caller = location_field.map(|_| quote!(#[track_caller]));
479 let from_function = quote! {
480 #track_caller
481 fn from(#source_var: #from) -> Self {
482 #ty::#variant #body
483 }
484 };
485 let from_impl = quote_spanned! {span=>
486 #[automatically_derived]
487 impl #impl_generics ::core::convert::From<#from> for #ty #ty_generics #where_clause {
488 #from_function
489 }
490 };
491 Some(quote! {
492 #[allow(
493 deprecated,
494 unused_qualifications,
495 clippy::elidable_lifetime_names,
496 clippy::needless_lifetimes,
497 )]
498 #from_impl
499 })
500 });
501
502 let location_impl = if input.has_location() {
503 let arms = input.variants.iter().map(|variant| {
504 let ident = &variant.ident;
505 if let Some(location_field) = variant.location_field() {
506 let location = &location_field.member;
507 let var_location = quote!(location);
508 let body = if type_is_option(location_field.ty) {
509 quote! {
510 #var_location
511 }
512 } else {
513 quote! {
514 Some(#var_location)
515 }
516 };
517 quote! {
518 #ty::#ident {#location: #var_location, ..} => #body,
519 }
520 } else {
521 quote! {
522 #ty::#ident {..} => None,
523 }
524 }
525 });
526 Some(quote! {
527 #[allow(unused_qualifications)]
528 #[automatically_derived]
529 impl #impl_generics #ty #ty_generics #where_clause {
530 pub fn location(&self) -> Option<&'static ::core::panic::Location<'static>> {
531 #[allow(deprecated)]
532 match self {
533 #(#arms)*
534 }
535 }
536 }
537 })
538 } else {
539 None
540 };
541
542 if input.generics.type_params().next().is_some() {
543 let self_token = <Token![Self]>::default();
544 error_inferred_bounds.insert(self_token, Trait::Debug);
545 error_inferred_bounds.insert(self_token, Trait::Display);
546 }
547 let error_where_clause = error_inferred_bounds.augment_where_clause(input.generics);
548
549 quote! {
550 #[allow(unused_qualifications)]
551 #[automatically_derived]
552 impl #impl_generics ::wherror::__private::Error for #ty #ty_generics #error_where_clause {
553 #source_method
554 #provide_method
555 }
556 #display_impl
557 #(#from_impls)*
558 #location_impl
559 }
560}
561
562pub(crate) fn call_site_ident(ident: &Ident) -> Ident {
565 let mut ident = ident.clone();
566 ident.set_span(ident.span().resolved_at(Span::call_site()));
567 ident
568}
569
570fn fields_pat(fields: &[Field]) -> TokenStream {
571 let mut members = fields.iter().map(|field| &field.member).peekable();
572 match members.peek() {
573 Some(MemberUnraw::Named(_)) => quote!({ #(#members),* }),
574 Some(MemberUnraw::Unnamed(_)) => {
575 let vars = members.map(|member| match member {
576 MemberUnraw::Unnamed(index) => format_ident!("_{}", index),
577 MemberUnraw::Named(_) => unreachable!(),
578 });
579 quote!((#(#vars),*))
580 }
581 None => quote!({}),
582 }
583}
584
585fn use_as_display(needs_as_display: bool) -> Option<TokenStream> {
586 if needs_as_display {
587 Some(quote! {
588 use ::wherror::__private::AsDisplay as _;
589 })
590 } else {
591 None
592 }
593}
594
595fn from_initializer(
596 from_field: &Field,
597 backtrace_field: Option<&Field>,
598 source_var: &Ident,
599 location_field: Option<&Field>,
600) -> TokenStream {
601 let from_member = &from_field.member;
602 let some_source = if type_is_option(from_field.ty) {
603 quote!(::core::option::Option::Some(#source_var))
604 } else {
605 quote!(#source_var)
606 };
607 let backtrace = backtrace_field.map(|backtrace_field| {
608 let backtrace_member = &backtrace_field.member;
609 if type_is_option(backtrace_field.ty) {
610 quote! {
611 #backtrace_member: ::core::option::Option::Some(::wherror::__private::Backtrace::capture()),
612 }
613 } else {
614 quote! {
615 #backtrace_member: ::core::convert::From::from(::wherror::__private::Backtrace::capture()),
616 }
617 }
618 });
619 let location = location_field.map(|location_field| {
620 let location_member = &location_field.member;
621
622 if type_is_option(location_field.ty) {
623 quote! {
624 #location_member: ::core::option::Option::Some(::core::panic::Location::caller()),
625 }
626 } else {
627 quote! {
628 #location_member: ::core::convert::From::from(::core::panic::Location::caller()),
629 }
630 }
631 });
632 quote!({
633 #from_member: #some_source,
634 #backtrace
635 #location
636 })
637}
638
639fn type_is_option(ty: &Type) -> bool {
640 type_parameter_of_option(ty).is_some()
641}
642
643fn unoptional_type(ty: &Type) -> TokenStream {
644 let unoptional = type_parameter_of_option(ty).unwrap_or(ty);
645 quote!(#unoptional)
646}
647
648fn type_parameter_of_option(ty: &Type) -> Option<&Type> {
649 let path = match ty {
650 Type::Path(ty) => &ty.path,
651 _ => return None,
652 };
653
654 let last = path.segments.last().unwrap();
655 if last.ident != "Option" {
656 return None;
657 }
658
659 let bracketed = match &last.arguments {
660 PathArguments::AngleBracketed(bracketed) => bracketed,
661 _ => return None,
662 };
663
664 if bracketed.args.len() != 1 {
665 return None;
666 }
667
668 match &bracketed.args[0] {
669 GenericArgument::Type(arg) => Some(arg),
670 _ => None,
671 }
672}