Skip to main content

union_error_derive/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use std::collections::{BTreeMap, BTreeSet};
4use std::fs;
5use std::path::{Path, PathBuf};
6use syn::{
7    parse_macro_input, parse_quote, Data, DataEnum, DeriveInput, Fields, GenericArgument, Ident,
8    Item, ItemEnum, PathArguments, Type, TypePath,
9};
10
11#[proc_macro_derive(ErrorUnion)]
12pub fn derive_union_error(input: TokenStream) -> TokenStream {
13    let input = parse_macro_input!(input as DeriveInput);
14    expand_error_union_enum(input).into()
15}
16
17#[proc_macro_attribute]
18pub fn located_error(_attr: TokenStream, item: TokenStream) -> TokenStream {
19    // `#[located_error]` is applied to module-local enums like:
20    //
21    //   #[located_error]
22    //   enum LocalErrors { Parse(ParseIntError), ... }
23    //
24    // It rewrites each single-field tuple variant from `T` to
25    // `::union_error::Located<T>`, then generates:
26    // - `From<T> for LocalErrors` with `#[track_caller]`
27    // - `Display` and `Error`
28    //
29    // The local enum remains local; this macro does NOT create the app-wide union.
30    let mut item_enum = parse_macro_input!(item as ItemEnum);
31
32    let mut seen_leaf_types = BTreeSet::new();
33    let mut leaves = Vec::new();
34
35    for variant in &mut item_enum.variants {
36        // Validate variant shape and extract the source leaf error type.
37        let original_ty = match &variant.fields {
38            Fields::Unnamed(fields) if fields.unnamed.len() == 1 => fields.unnamed[0].ty.clone(),
39            Fields::Named(_) => {
40                return syn::Error::new_spanned(
41                    &variant.ident,
42                    "located_error variants must use a single unnamed field",
43                )
44                .to_compile_error()
45                .into();
46            }
47            Fields::Unit => {
48                return syn::Error::new_spanned(
49                    &variant.ident,
50                    "located_error variants must use a single unnamed field",
51                )
52                .to_compile_error()
53                .into();
54            }
55            Fields::Unnamed(fields) => {
56                return syn::Error::new_spanned(
57                    fields,
58                    "located_error variants must use exactly one field",
59                )
60                .to_compile_error()
61                .into();
62            }
63        };
64
65        // Support both `T` and `Located<T>` inputs, but normalize to leaf `T`.
66        let leaf_ty = normalized_leaf_type(&original_ty);
67        let leaf_key = type_key(&leaf_ty);
68        if !seen_leaf_types.insert(leaf_key.clone()) {
69            return syn::Error::new_spanned(
70                &variant.ident,
71                "duplicate leaf error type in this located_error enum",
72            )
73            .to_compile_error()
74            .into();
75        }
76
77        // Rewrite field type in-place so users do not have to write `Located<T>`.
78        if let Fields::Unnamed(fields) = &mut variant.fields {
79            fields.unnamed[0].ty = parse_quote!(::union_error::Located<#leaf_ty>);
80        }
81
82        leaves.push((variant.ident.clone(), leaf_ty));
83    }
84
85    let enum_ident = &item_enum.ident;
86    // Leaf conversion used by `?` inside module functions.
87    let from_impls = leaves.iter().map(|(variant, ty)| {
88        quote! {
89            impl ::core::convert::From<#ty> for #enum_ident {
90                #[track_caller]
91                fn from(source: #ty) -> Self {
92                    Self::#variant(::union_error::Located::new(source))
93                }
94            }
95        }
96    });
97
98    // `Display` delegates to `Located<T>` formatting.
99    let display_arms = leaves.iter().map(|(variant, _)| display_arm(variant));
100
101    // `Error::source` exposes the wrapped `Located<T>` and then `T`.
102    let source_arms = leaves.iter().map(|(variant, _)| source_arm(variant));
103
104    // Internal metadata scaffold (hidden API) produced for each local enum.
105    let metadata_entries = leaves.iter().map(|(variant, ty)| {
106        let variant_name = variant.to_string();
107        let leaf_name = quote!(#ty).to_string();
108        quote! {
109            ::union_error::__private::LeafSpec {
110                variant_name: #variant_name,
111                leaf_type_name: #leaf_name,
112            }
113        }
114    });
115
116    let expanded = quote! {
117        #item_enum
118
119        impl ::core::fmt::Display for #enum_ident {
120            fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
121                match self {
122                    #(#display_arms)*
123                }
124            }
125        }
126
127        impl ::std::error::Error for #enum_ident {
128            fn source(&self) -> Option<&(dyn ::std::error::Error + 'static)> {
129                match self {
130                    #(#source_arms)*
131                }
132            }
133        }
134
135        #(#from_impls)*
136
137        impl ::union_error::__private::LocatedErrorMetadata for #enum_ident {
138            const LEAVES: &'static [::union_error::__private::LeafSpec] = &[
139                #(#metadata_entries),*
140            ];
141        }
142    };
143
144    expanded.into()
145}
146
147#[proc_macro_attribute]
148pub fn error_union(_attr: TokenStream, item: TokenStream) -> TokenStream {
149    // `#[error_union]` is applied only to the app/root enum in `error.rs`.
150    // It auto-flattens listed local enums into one top-level enum with
151    // direct leaf variants (`Parse`, `Io`, ...).
152    let input = parse_macro_input!(item as DeriveInput);
153    expand_error_union_enum(input).into()
154}
155
156fn expand_error_union_enum(input: DeriveInput) -> proc_macro2::TokenStream {
157    // Parse/validate outer enum and then resolve flattened leaves.
158    let enum_name = input.ident;
159    let data = match input.data {
160        Data::Enum(e) => e,
161        _ => {
162            return syn::Error::new_spanned(enum_name, "error_union only supports enums")
163                .to_compile_error();
164        }
165    };
166
167    if let Some(where_clause) = input.generics.where_clause {
168        return syn::Error::new_spanned(where_clause, "error_union does not support generics")
169            .to_compile_error();
170    }
171
172    let attrs = input.attrs;
173    let vis = input.vis;
174
175    match resolve_union_leaves(&data) {
176        Ok(leaves) => build_union_tokens(attrs, vis, enum_name, leaves),
177        Err(err) => err.to_compile_error(),
178    }
179}
180
181#[derive(Clone)]
182struct Leaf {
183    variant_ident: Ident,
184    leaf_ty: Type,
185    local_enum_ty: Type,
186    local_variant_ident: Ident,
187}
188
189fn build_union_tokens(
190    attrs: Vec<syn::Attribute>,
191    vis: syn::Visibility,
192    enum_name: Ident,
193    leaves: Vec<Leaf>,
194) -> proc_macro2::TokenStream {
195    // The generated union is flat:
196    // AppError::{LeafVariant}(Located<LeafType>)
197    //
198    // No intermediate module wrappers at runtime.
199    let union_variants = leaves.iter().map(|leaf| {
200        let v = &leaf.variant_ident;
201        let ty = &leaf.leaf_ty;
202        quote! { #v(::union_error::Located<#ty>) }
203    });
204
205    let display_arms = leaves.iter().map(|leaf| display_arm(&leaf.variant_ident));
206
207    let source_arms = leaves.iter().map(|leaf| source_arm(&leaf.variant_ident));
208
209    // Primary conversion path for `?` from leaf errors into AppError.
210    let from_leaf_impls = leaves.iter().map(|leaf| {
211        let v = &leaf.variant_ident;
212        let ty = &leaf.leaf_ty;
213        quote! {
214            impl ::core::convert::From<#ty> for #enum_name {
215                #[track_caller]
216                fn from(source: #ty) -> Self {
217                    Self::#v(::union_error::Located::new(source))
218                }
219            }
220        }
221    });
222
223    // Compatibility conversion from each local module enum into AppError.
224    // Group leaves by local enum so we emit one impl per enum.
225    let mut by_local_enum = BTreeMap::<String, (Type, Vec<(Ident, Ident)>)>::new();
226    for leaf in &leaves {
227        let (_, variants) = by_local_enum
228            .entry(type_key(&leaf.local_enum_ty))
229            .or_insert_with(|| (leaf.local_enum_ty.clone(), Vec::new()));
230        variants.push((leaf.local_variant_ident.clone(), leaf.variant_ident.clone()));
231    }
232
233    let from_local_impls = by_local_enum
234        .into_values()
235        .map(|(local_enum_ty, variants)| {
236            let arms = variants.iter().map(|(local_variant, union_variant)| {
237                quote! { #local_enum_ty::#local_variant(inner) => Self::#union_variant(inner), }
238            });
239            quote! {
240                impl ::core::convert::From<#local_enum_ty> for #enum_name {
241                    fn from(source: #local_enum_ty) -> Self {
242                        match source {
243                            #(#arms)*
244                        }
245                    }
246                }
247            }
248        });
249
250    quote! {
251        #(#attrs)*
252        #vis enum #enum_name {
253            #(#union_variants),*
254        }
255
256        impl ::core::fmt::Display for #enum_name {
257            fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
258                match self {
259                    #(#display_arms)*
260                }
261            }
262        }
263
264        impl ::std::error::Error for #enum_name {
265            fn source(&self) -> Option<&(dyn ::std::error::Error + 'static)> {
266                match self {
267                    #(#source_arms)*
268                }
269            }
270        }
271
272        #(#from_leaf_impls)*
273        #(#from_local_impls)*
274    }
275}
276
277fn resolve_union_leaves(data: &DataEnum) -> syn::Result<Vec<Leaf>> {
278    // Resolve each listed `crate::module::LocalErrors` and extract leaf variants.
279    //
280    // NOTE: current implementation reads/parses module source files to discover
281    // enum variants. This is why we resolve the module path to a file below.
282    let mut leaves = Vec::new();
283    let mut by_leaf_type = BTreeMap::<String, proc_macro2::Span>::new();
284    let manifest_dir = std::env::var("CARGO_MANIFEST_DIR").unwrap_or_else(|_| ".".to_string());
285
286    for variant in &data.variants {
287        let local_enum_ty = single_field_type(variant)?;
288        let local_enum_path = extract_type_path(local_enum_ty)?;
289        let local_enum_name = local_enum_path
290            .path
291            .segments
292            .last()
293            .map(|s| s.ident.to_string())
294            .ok_or_else(|| {
295                syn::Error::new_spanned(local_enum_ty, "invalid local enum type path")
296            })?;
297
298        let module_path = module_path_for_local_enum(&local_enum_path.path, &local_enum_name)?;
299        let module_file =
300            find_module_file(Path::new(&manifest_dir), &module_path).ok_or_else(|| {
301                syn::Error::new_spanned(
302                    local_enum_ty,
303                    format!(
304                        "could not find module source for `{}` at `src/{}`",
305                        local_enum_name,
306                        module_path.join("/")
307                    ),
308                )
309            })?;
310
311        // Parse local enum source and flatten each local leaf into the union.
312        let local_enum = parse_local_enum(&module_file, &local_enum_name)?;
313
314        for local_variant in local_enum.variants {
315            let leaf_ty = normalized_leaf_type(single_field_type(&local_variant)?);
316            let key = type_key(&leaf_ty);
317            // Enforce one unique leaf type across the entire union.
318            if let Some(_prev) = by_leaf_type.insert(key.clone(), local_variant.ident.span()) {
319                return Err(syn::Error::new_spanned(
320                    &local_variant.ident,
321                    format!(
322                        "duplicate leaf error type across unioned local enums: `{}`",
323                        key
324                    ),
325                ));
326            }
327
328            leaves.push(Leaf {
329                variant_ident: local_variant.ident.clone(),
330                leaf_ty,
331                local_enum_ty: local_enum_ty.clone(),
332                local_variant_ident: local_variant.ident,
333            });
334        }
335    }
336
337    Ok(leaves)
338}
339
340fn parse_local_enum(path: &Path, enum_name: &str) -> syn::Result<ItemEnum> {
341    // Lightweight source loader used by `#[error_union]` flattening.
342    let content = fs::read_to_string(path).map_err(|e| {
343        syn::Error::new(
344            proc_macro2::Span::call_site(),
345            format!("failed reading {}: {}", path.display(), e),
346        )
347    })?;
348    let file = syn::parse_file(&content)?;
349    for item in file.items {
350        if let Item::Enum(item_enum) = item {
351            if item_enum.ident == enum_name {
352                return Ok(item_enum);
353            }
354        }
355    }
356
357    Err(syn::Error::new(
358        proc_macro2::Span::call_site(),
359        format!("could not find enum `{}` in {}", enum_name, path.display()),
360    ))
361}
362
363fn module_path_for_local_enum(path: &syn::Path, enum_name: &str) -> syn::Result<Vec<String>> {
364    // Convert type path like `crate::file1::LocalErrors` -> ["file1"].
365    let mut segments: Vec<String> = path.segments.iter().map(|s| s.ident.to_string()).collect();
366    if segments.last().map(|s| s.as_str()) != Some(enum_name) {
367        return Err(syn::Error::new_spanned(
368            path,
369            "union variant must reference a local enum type",
370        ));
371    }
372    segments.pop();
373    if segments.first().map(|s| s.as_str()) == Some("crate") {
374        segments.remove(0);
375    }
376    if segments.is_empty() {
377        return Err(syn::Error::new_spanned(
378            path,
379            "could not resolve module path",
380        ));
381    }
382    Ok(segments)
383}
384
/// Locates the source file for `module_path` (e.g. `["a", "b"]`).
///
/// Candidates are tried in order, returning the first that exists:
/// `<manifest>/src/a/b.rs`, `<manifest>/src/a/b/mod.rs`, then the same two
/// shapes directly under the manifest directory. Returns `None` if none exist.
fn find_module_file(manifest_dir: &Path, module_path: &[String]) -> Option<PathBuf> {
    let roots = [manifest_dir.join("src"), manifest_dir.to_path_buf()];
    roots
        .iter()
        .map(|root| module_path.iter().fold(root.clone(), |acc, seg| acc.join(seg)))
        .flat_map(|stem| [stem.with_extension("rs"), stem.join("mod.rs")])
        .find(|candidate| candidate.exists())
}
406
407fn single_field_type(variant: &syn::Variant) -> syn::Result<&Type> {
408    // Shared validator: all supported enum variants are exactly one tuple field.
409    match &variant.fields {
410        Fields::Unnamed(fields) if fields.unnamed.len() == 1 => Ok(&fields.unnamed[0].ty),
411        _ => Err(syn::Error::new_spanned(
412            &variant.ident,
413            "each variant must have exactly one unnamed field",
414        )),
415    }
416}
417
418fn extract_type_path(ty: &Type) -> syn::Result<&TypePath> {
419    // Shared validator: union variant field must be a path type (`crate::x::Y`).
420    match ty {
421        Type::Path(path) => Ok(path),
422        _ => Err(syn::Error::new_spanned(
423            ty,
424            "variant field must be a path type",
425        )),
426    }
427}
428
429fn display_arm(variant: &Ident) -> proc_macro2::TokenStream {
430    quote! { Self::#variant(inner) => ::core::fmt::Display::fmt(inner, f), }
431}
432
433fn source_arm(variant: &Ident) -> proc_macro2::TokenStream {
434    quote! { Self::#variant(inner) => Some(inner as &(dyn ::std::error::Error + 'static)), }
435}
436
437fn normalized_leaf_type(ty: &Type) -> Type {
438    unwrap_located(ty).unwrap_or_else(|| ty.clone())
439}
440
441fn unwrap_located(ty: &Type) -> Option<Type> {
442    // Helper to turn `Located<T>` into `T` when needed.
443    let Type::Path(path) = ty else {
444        return None;
445    };
446    let segment = path.path.segments.last()?;
447    if segment.ident != "Located" {
448        return None;
449    }
450    let PathArguments::AngleBracketed(args) = &segment.arguments else {
451        return None;
452    };
453    let GenericArgument::Type(inner) = args.args.first()? else {
454        return None;
455    };
456    Some(inner.clone())
457}
458
459fn type_key(ty: &Type) -> String {
460    // Canonicalized type string used for duplicate detection.
461    quote!(#ty).to_string().replace(' ', "")
462}