// union_error_derive/lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, Data, DeriveInput, Fields};
4
5#[proc_macro_derive(ErrorUnion)]
6pub fn derive_union_error(input: TokenStream) -> TokenStream {
7    // Parse the incoming Rust item as a `syn::DeriveInput`.
8    //
9    // This is the standard entry type for derive macros and contains:
10    // - the item name
11    // - visibility
12    // - attributes
13    // - generics
14    // - and the actual item data (enum / struct / union)
15    let input = parse_macro_input!(input as DeriveInput);
16
17    // Save the enum name, e.g. `AppError`.
18    let enum_name = input.ident;
19
20    // We only support enums.
21    //
22    // If the user writes `#[derive(ErrorUnion)]` on a struct or union,
23    // emit a compile error.
24    let data = match input.data {
25        Data::Enum(e) => e,
26        _ => {
27            return syn::Error::new_spanned(enum_name, "ErrorUnion only supports enums")
28                .to_compile_error()
29                .into();
30        }
31    };
32
33    // These hold the generated pieces that will later be assembled into:
34    //
35    // - `impl From<T> for Enum`
36    // - `impl Display for Enum`
37    // - `impl Error for Enum`
38    let mut from_impls = Vec::new();
39    let mut display_arms = Vec::new();
40    let mut source_arms = Vec::new();
41
42    // Process each enum variant independently.
43    for variant in data.variants {
44        let variant_name = variant.ident;
45
46        // Require tuple variants with exactly one field:
47        //
48        // Good:
49        //   Parse(Located<ParseIntError>)
50        //
51        // Rejected:
52        //   Parse
53        //   Parse { source: ... }
54        //   Parse(A, B)
55        let field_ty = match variant.fields {
56            Fields::Unnamed(f) if f.unnamed.len() == 1 => f.unnamed.first().unwrap().ty.clone(),
57            _ => {
58                return syn::Error::new_spanned(
59                    variant_name,
60                    "Each variant must have exactly one unnamed field",
61                )
62                .to_compile_error()
63                .into();
64            }
65        };
66
67        // Try to extract the inner type from `Located<T>`.
68        //
69        // Example:
70        //   field_ty = Located<std::io::Error>
71        //   inner_ty = std::io::Error
72        //
73        // If the field is not `Located<T>`, we fall back to using the field type
74        // directly. That keeps the macro behavior somewhat flexible, though the
75        // intended design is to use `Located<T>` for all variants.
76        let inner_ty = extract_inner_type(&field_ty).unwrap_or(field_ty.clone());
77
78        // Generate:
79        //
80        // impl From<T> for Enum {
81        //     #[track_caller]
82        //     fn from(source: T) -> Self {
83        //         Self::Variant(union_error::Located::new(source))
84        //     }
85        // }
86        //
87        // This is the critical conversion used by `?`.
88        from_impls.push(quote! {
89            impl From<#inner_ty> for #enum_name {
90                #[track_caller]
91                fn from(source: #inner_ty) -> Self {
92                    Self::#variant_name(union_error::Located::new(source))
93                }
94            }
95        });
96
97        // Generate a `Display` match arm that delegates to the stored inner value.
98        //
99        // Example:
100        //   Self::Parse(inner) => Display::fmt(inner, f)
101        //
102        // Since `Located<T>` implements `Display`, this prints both:
103        // - the inner source error message
104        // - the stored source location
105        display_arms.push(quote! {
106            Self::#variant_name(inner) => std::fmt::Display::fmt(inner, f),
107        });
108
109        // Generate an `Error::source()` match arm.
110        //
111        // Example:
112        //   Self::Parse(inner) => Some(inner as &(dyn Error + 'static))
113        //
114        // Since `Located<T>` also implements `Error`, the chain becomes:
115        //   AppError -> Located<T> -> T
116        source_arms.push(quote! {
117            Self::#variant_name(inner) => Some(inner as &(dyn std::error::Error + 'static)),
118        });
119    }
120
121    // Assemble the final generated impls.
122    let expanded = quote! {
123        impl std::fmt::Display for #enum_name {
124            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
125                match self {
126                    #(#display_arms)*
127                }
128            }
129        }
130
131        impl std::error::Error for #enum_name {
132            fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
133                match self {
134                    #(#source_arms)*
135                }
136            }
137        }
138
139        #(#from_impls)*
140    };
141
142    expanded.into()
143}
144
145/// Extract the inner type `T` from `Located<T>`.
146///
147/// # Example
148///
149/// Input:
150///
151/// ```ignore
152/// Located<std::io::Error>
153/// ```
154///
155/// Output:
156///
157/// ```ignore
158/// Some(std::io::Error)
159/// ```
160///
161/// If the input is not a path type ending in `Located<...>`, this returns `None`.
162///
163/// # Why this helper exists
164///
165/// The derive macro wants to generate:
166///
167/// ```ignore
168/// impl From<T> for AppError
169/// ```
170///
171/// not:
172///
173/// ```ignore
174/// impl From<Located<T>> for AppError
175/// ```
176///
177/// because the `?` operator naturally converts from the original leaf error type.
178fn extract_inner_type(ty: &syn::Type) -> Option<syn::Type> {
179    // We only care about path types like:
180    // - Located<T>
181    // - union_error::Located<T>
182    if let syn::Type::Path(type_path) = ty {
183        // Look at the last path segment.
184        //
185        // Examples:
186        // - Located<T>                  -> last segment = Located
187        // - union_error::Located<T>     -> last segment = Located
188        let seg = type_path.path.segments.last()?;
189
190        // Only match types whose last path segment is literally `Located`.
191        if seg.ident == "Located" {
192            // Require angle-bracket generic arguments: `Located<T>`
193            if let syn::PathArguments::AngleBracketed(args) = &seg.arguments {
194                // Take the first generic argument if it is a type.
195                if let Some(syn::GenericArgument::Type(inner)) = args.args.first() {
196                    return Some(inner.clone());
197                }
198            }
199        }
200    }
201
202    None
203}