union_error_derive/lib.rs
1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, Data, DeriveInput, Fields};
4
5#[proc_macro_derive(ErrorUnion)]
6pub fn derive_union_error(input: TokenStream) -> TokenStream {
7 // Parse the incoming Rust item as a `syn::DeriveInput`.
8 //
9 // This is the standard entry type for derive macros and contains:
10 // - the item name
11 // - visibility
12 // - attributes
13 // - generics
14 // - and the actual item data (enum / struct / union)
15 let input = parse_macro_input!(input as DeriveInput);
16
17 // Save the enum name, e.g. `AppError`.
18 let enum_name = input.ident;
19
20 // We only support enums.
21 //
22 // If the user writes `#[derive(ErrorUnion)]` on a struct or union,
23 // emit a compile error.
24 let data = match input.data {
25 Data::Enum(e) => e,
26 _ => {
27 return syn::Error::new_spanned(enum_name, "ErrorUnion only supports enums")
28 .to_compile_error()
29 .into();
30 }
31 };
32
33 // These hold the generated pieces that will later be assembled into:
34 //
35 // - `impl From<T> for Enum`
36 // - `impl Display for Enum`
37 // - `impl Error for Enum`
38 let mut from_impls = Vec::new();
39 let mut display_arms = Vec::new();
40 let mut source_arms = Vec::new();
41
42 // Process each enum variant independently.
43 for variant in data.variants {
44 let variant_name = variant.ident;
45
46 // Require tuple variants with exactly one field:
47 //
48 // Good:
49 // Parse(Located<ParseIntError>)
50 //
51 // Rejected:
52 // Parse
53 // Parse { source: ... }
54 // Parse(A, B)
55 let field_ty = match variant.fields {
56 Fields::Unnamed(f) if f.unnamed.len() == 1 => f.unnamed.first().unwrap().ty.clone(),
57 _ => {
58 return syn::Error::new_spanned(
59 variant_name,
60 "Each variant must have exactly one unnamed field",
61 )
62 .to_compile_error()
63 .into();
64 }
65 };
66
67 // Try to extract the inner type from `Located<T>`.
68 //
69 // Example:
70 // field_ty = Located<std::io::Error>
71 // inner_ty = std::io::Error
72 //
73 // If the field is not `Located<T>`, we fall back to using the field type
74 // directly. That keeps the macro behavior somewhat flexible, though the
75 // intended design is to use `Located<T>` for all variants.
76 let inner_ty = extract_inner_type(&field_ty).unwrap_or(field_ty.clone());
77
78 // Generate:
79 //
80 // impl From<T> for Enum {
81 // #[track_caller]
82 // fn from(source: T) -> Self {
83 // Self::Variant(union_error::Located::new(source))
84 // }
85 // }
86 //
87 // This is the critical conversion used by `?`.
88 from_impls.push(quote! {
89 impl From<#inner_ty> for #enum_name {
90 #[track_caller]
91 fn from(source: #inner_ty) -> Self {
92 Self::#variant_name(union_error::Located::new(source))
93 }
94 }
95 });
96
97 // Generate a `Display` match arm that delegates to the stored inner value.
98 //
99 // Example:
100 // Self::Parse(inner) => Display::fmt(inner, f)
101 //
102 // Since `Located<T>` implements `Display`, this prints both:
103 // - the inner source error message
104 // - the stored source location
105 display_arms.push(quote! {
106 Self::#variant_name(inner) => std::fmt::Display::fmt(inner, f),
107 });
108
109 // Generate an `Error::source()` match arm.
110 //
111 // Example:
112 // Self::Parse(inner) => Some(inner as &(dyn Error + 'static))
113 //
114 // Since `Located<T>` also implements `Error`, the chain becomes:
115 // AppError -> Located<T> -> T
116 source_arms.push(quote! {
117 Self::#variant_name(inner) => Some(inner as &(dyn std::error::Error + 'static)),
118 });
119 }
120
121 // Assemble the final generated impls.
122 let expanded = quote! {
123 impl std::fmt::Display for #enum_name {
124 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
125 match self {
126 #(#display_arms)*
127 }
128 }
129 }
130
131 impl std::error::Error for #enum_name {
132 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
133 match self {
134 #(#source_arms)*
135 }
136 }
137 }
138
139 #(#from_impls)*
140 };
141
142 expanded.into()
143}
144
145/// Extract the inner type `T` from `Located<T>`.
146///
147/// # Example
148///
149/// Input:
150///
151/// ```ignore
152/// Located<std::io::Error>
153/// ```
154///
155/// Output:
156///
157/// ```ignore
158/// Some(std::io::Error)
159/// ```
160///
161/// If the input is not a path type ending in `Located<...>`, this returns `None`.
162///
163/// # Why this helper exists
164///
165/// The derive macro wants to generate:
166///
167/// ```ignore
168/// impl From<T> for AppError
169/// ```
170///
171/// not:
172///
173/// ```ignore
174/// impl From<Located<T>> for AppError
175/// ```
176///
177/// because the `?` operator naturally converts from the original leaf error type.
178fn extract_inner_type(ty: &syn::Type) -> Option<syn::Type> {
179 // We only care about path types like:
180 // - Located<T>
181 // - union_error::Located<T>
182 if let syn::Type::Path(type_path) = ty {
183 // Look at the last path segment.
184 //
185 // Examples:
186 // - Located<T> -> last segment = Located
187 // - union_error::Located<T> -> last segment = Located
188 let seg = type_path.path.segments.last()?;
189
190 // Only match types whose last path segment is literally `Located`.
191 if seg.ident == "Located" {
192 // Require angle-bracket generic arguments: `Located<T>`
193 if let syn::PathArguments::AngleBracketed(args) = &seg.arguments {
194 // Take the first generic argument if it is a type.
195 if let Some(syn::GenericArgument::Type(inner)) = args.args.first() {
196 return Some(inner.clone());
197 }
198 }
199 }
200 }
201
202 None
203}