1use proc_macro::TokenStream;
2use quote::quote;
3use std::collections::{BTreeMap, BTreeSet};
4use std::fs;
5use std::path::{Path, PathBuf};
6use syn::{
7 parse_macro_input, parse_quote, Data, DataEnum, DeriveInput, Fields, GenericArgument, Ident,
8 Item, ItemEnum, PathArguments, Type, TypePath,
9};
10
11#[proc_macro_derive(ErrorUnion)]
12pub fn derive_union_error(input: TokenStream) -> TokenStream {
13 let input = parse_macro_input!(input as DeriveInput);
14 expand_error_union_enum(input).into()
15}
16
17#[proc_macro_attribute]
18pub fn located_error(_attr: TokenStream, item: TokenStream) -> TokenStream {
19 let mut item_enum = parse_macro_input!(item as ItemEnum);
31
32 let mut seen_leaf_types = BTreeSet::new();
33 let mut leaves = Vec::new();
34
35 for variant in &mut item_enum.variants {
36 let original_ty = match &variant.fields {
38 Fields::Unnamed(fields) if fields.unnamed.len() == 1 => fields.unnamed[0].ty.clone(),
39 Fields::Named(_) => {
40 return syn::Error::new_spanned(
41 &variant.ident,
42 "located_error variants must use a single unnamed field",
43 )
44 .to_compile_error()
45 .into();
46 }
47 Fields::Unit => {
48 return syn::Error::new_spanned(
49 &variant.ident,
50 "located_error variants must use a single unnamed field",
51 )
52 .to_compile_error()
53 .into();
54 }
55 Fields::Unnamed(fields) => {
56 return syn::Error::new_spanned(
57 fields,
58 "located_error variants must use exactly one field",
59 )
60 .to_compile_error()
61 .into();
62 }
63 };
64
65 let leaf_ty = normalized_leaf_type(&original_ty);
67 let leaf_key = type_key(&leaf_ty);
68 if !seen_leaf_types.insert(leaf_key.clone()) {
69 return syn::Error::new_spanned(
70 &variant.ident,
71 "duplicate leaf error type in this located_error enum",
72 )
73 .to_compile_error()
74 .into();
75 }
76
77 if let Fields::Unnamed(fields) = &mut variant.fields {
79 fields.unnamed[0].ty = parse_quote!(::union_error::Located<#leaf_ty>);
80 }
81
82 leaves.push((variant.ident.clone(), leaf_ty));
83 }
84
85 let enum_ident = &item_enum.ident;
86 let from_impls = leaves.iter().map(|(variant, ty)| {
88 quote! {
89 impl ::core::convert::From<#ty> for #enum_ident {
90 #[track_caller]
91 fn from(source: #ty) -> Self {
92 Self::#variant(::union_error::Located::new(source))
93 }
94 }
95 }
96 });
97
98 let display_arms = leaves.iter().map(|(variant, _)| display_arm(variant));
100
101 let source_arms = leaves.iter().map(|(variant, _)| source_arm(variant));
103
104 let metadata_entries = leaves.iter().map(|(variant, ty)| {
106 let variant_name = variant.to_string();
107 let leaf_name = quote!(#ty).to_string();
108 quote! {
109 ::union_error::__private::LeafSpec {
110 variant_name: #variant_name,
111 leaf_type_name: #leaf_name,
112 }
113 }
114 });
115
116 let expanded = quote! {
117 #item_enum
118
119 impl ::core::fmt::Display for #enum_ident {
120 fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
121 match self {
122 #(#display_arms)*
123 }
124 }
125 }
126
127 impl ::std::error::Error for #enum_ident {
128 fn source(&self) -> Option<&(dyn ::std::error::Error + 'static)> {
129 match self {
130 #(#source_arms)*
131 }
132 }
133 }
134
135 #(#from_impls)*
136
137 impl ::union_error::__private::LocatedErrorMetadata for #enum_ident {
138 const LEAVES: &'static [::union_error::__private::LeafSpec] = &[
139 #(#metadata_entries),*
140 ];
141 }
142 };
143
144 expanded.into()
145}
146
147#[proc_macro_attribute]
148pub fn error_union(_attr: TokenStream, item: TokenStream) -> TokenStream {
149 let input = parse_macro_input!(item as DeriveInput);
153 expand_error_union_enum(input).into()
154}
155
156fn expand_error_union_enum(input: DeriveInput) -> proc_macro2::TokenStream {
157 let enum_name = input.ident;
159 let data = match input.data {
160 Data::Enum(e) => e,
161 _ => {
162 return syn::Error::new_spanned(enum_name, "error_union only supports enums")
163 .to_compile_error();
164 }
165 };
166
167 if let Some(where_clause) = input.generics.where_clause {
168 return syn::Error::new_spanned(where_clause, "error_union does not support generics")
169 .to_compile_error();
170 }
171
172 let attrs = input.attrs;
173 let vis = input.vis;
174
175 match resolve_union_leaves(&data) {
176 Ok(leaves) => build_union_tokens(attrs, vis, enum_name, leaves),
177 Err(err) => err.to_compile_error(),
178 }
179}
180
181#[derive(Clone)]
182struct Leaf {
183 variant_ident: Ident,
184 leaf_ty: Type,
185 local_enum_ty: Type,
186 local_variant_ident: Ident,
187}
188
189fn build_union_tokens(
190 attrs: Vec<syn::Attribute>,
191 vis: syn::Visibility,
192 enum_name: Ident,
193 leaves: Vec<Leaf>,
194) -> proc_macro2::TokenStream {
195 let union_variants = leaves.iter().map(|leaf| {
200 let v = &leaf.variant_ident;
201 let ty = &leaf.leaf_ty;
202 quote! { #v(::union_error::Located<#ty>) }
203 });
204
205 let display_arms = leaves.iter().map(|leaf| display_arm(&leaf.variant_ident));
206
207 let source_arms = leaves.iter().map(|leaf| source_arm(&leaf.variant_ident));
208
209 let from_leaf_impls = leaves.iter().map(|leaf| {
211 let v = &leaf.variant_ident;
212 let ty = &leaf.leaf_ty;
213 quote! {
214 impl ::core::convert::From<#ty> for #enum_name {
215 #[track_caller]
216 fn from(source: #ty) -> Self {
217 Self::#v(::union_error::Located::new(source))
218 }
219 }
220 }
221 });
222
223 let mut by_local_enum = BTreeMap::<String, (Type, Vec<(Ident, Ident)>)>::new();
226 for leaf in &leaves {
227 let (_, variants) = by_local_enum
228 .entry(type_key(&leaf.local_enum_ty))
229 .or_insert_with(|| (leaf.local_enum_ty.clone(), Vec::new()));
230 variants.push((leaf.local_variant_ident.clone(), leaf.variant_ident.clone()));
231 }
232
233 let from_local_impls = by_local_enum
234 .into_values()
235 .map(|(local_enum_ty, variants)| {
236 let arms = variants.iter().map(|(local_variant, union_variant)| {
237 quote! { #local_enum_ty::#local_variant(inner) => Self::#union_variant(inner), }
238 });
239 quote! {
240 impl ::core::convert::From<#local_enum_ty> for #enum_name {
241 fn from(source: #local_enum_ty) -> Self {
242 match source {
243 #(#arms)*
244 }
245 }
246 }
247 }
248 });
249
250 quote! {
251 #(#attrs)*
252 #vis enum #enum_name {
253 #(#union_variants),*
254 }
255
256 impl ::core::fmt::Display for #enum_name {
257 fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
258 match self {
259 #(#display_arms)*
260 }
261 }
262 }
263
264 impl ::std::error::Error for #enum_name {
265 fn source(&self) -> Option<&(dyn ::std::error::Error + 'static)> {
266 match self {
267 #(#source_arms)*
268 }
269 }
270 }
271
272 #(#from_leaf_impls)*
273 #(#from_local_impls)*
274 }
275}
276
277fn resolve_union_leaves(data: &DataEnum) -> syn::Result<Vec<Leaf>> {
278 let mut leaves = Vec::new();
283 let mut by_leaf_type = BTreeMap::<String, proc_macro2::Span>::new();
284 let manifest_dir = std::env::var("CARGO_MANIFEST_DIR").unwrap_or_else(|_| ".".to_string());
285
286 for variant in &data.variants {
287 let local_enum_ty = single_field_type(variant)?;
288 let local_enum_path = extract_type_path(local_enum_ty)?;
289 let local_enum_name = local_enum_path
290 .path
291 .segments
292 .last()
293 .map(|s| s.ident.to_string())
294 .ok_or_else(|| {
295 syn::Error::new_spanned(local_enum_ty, "invalid local enum type path")
296 })?;
297
298 let module_path = module_path_for_local_enum(&local_enum_path.path, &local_enum_name)?;
299 let module_file =
300 find_module_file(Path::new(&manifest_dir), &module_path).ok_or_else(|| {
301 syn::Error::new_spanned(
302 local_enum_ty,
303 format!(
304 "could not find module source for `{}` at `src/{}`",
305 local_enum_name,
306 module_path.join("/")
307 ),
308 )
309 })?;
310
311 let local_enum = parse_local_enum(&module_file, &local_enum_name)?;
313
314 for local_variant in local_enum.variants {
315 let leaf_ty = normalized_leaf_type(single_field_type(&local_variant)?);
316 let key = type_key(&leaf_ty);
317 if let Some(_prev) = by_leaf_type.insert(key.clone(), local_variant.ident.span()) {
319 return Err(syn::Error::new_spanned(
320 &local_variant.ident,
321 format!(
322 "duplicate leaf error type across unioned local enums: `{}`",
323 key
324 ),
325 ));
326 }
327
328 leaves.push(Leaf {
329 variant_ident: local_variant.ident.clone(),
330 leaf_ty,
331 local_enum_ty: local_enum_ty.clone(),
332 local_variant_ident: local_variant.ident,
333 });
334 }
335 }
336
337 Ok(leaves)
338}
339
340fn parse_local_enum(path: &Path, enum_name: &str) -> syn::Result<ItemEnum> {
341 let content = fs::read_to_string(path).map_err(|e| {
343 syn::Error::new(
344 proc_macro2::Span::call_site(),
345 format!("failed reading {}: {}", path.display(), e),
346 )
347 })?;
348 let file = syn::parse_file(&content)?;
349 for item in file.items {
350 if let Item::Enum(item_enum) = item {
351 if item_enum.ident == enum_name {
352 return Ok(item_enum);
353 }
354 }
355 }
356
357 Err(syn::Error::new(
358 proc_macro2::Span::call_site(),
359 format!("could not find enum `{}` in {}", enum_name, path.display()),
360 ))
361}
362
363fn module_path_for_local_enum(path: &syn::Path, enum_name: &str) -> syn::Result<Vec<String>> {
364 let mut segments: Vec<String> = path.segments.iter().map(|s| s.ident.to_string()).collect();
366 if segments.last().map(|s| s.as_str()) != Some(enum_name) {
367 return Err(syn::Error::new_spanned(
368 path,
369 "union variant must reference a local enum type",
370 ));
371 }
372 segments.pop();
373 if segments.first().map(|s| s.as_str()) == Some("crate") {
374 segments.remove(0);
375 }
376 if segments.is_empty() {
377 return Err(syn::Error::new_spanned(
378 path,
379 "could not resolve module path",
380 ));
381 }
382 Ok(segments)
383}
384
/// Finds the source file for a module, given its path segments below the crate root.
///
/// Candidates are tried under `<manifest_dir>/src` first and then under `<manifest_dir>`
/// itself. Within each base the flat layout (`a/b.rs`) is tried before the nested layout
/// (`a/b/mod.rs`). Only regular files are accepted, so a directory that happens to be named
/// `b.rs` is skipped.
///
/// Returns `None` for an empty `module_path` or when no candidate file exists.
fn find_module_file(manifest_dir: &Path, module_path: &[String]) -> Option<PathBuf> {
    // An empty path would otherwise probe meaningless files such as `src.rs` or `<dir>.rs`.
    if module_path.is_empty() {
        return None;
    }

    for base in [manifest_dir.join("src"), manifest_dir.to_path_buf()] {
        let mut stem = base;
        stem.extend(module_path);

        let flat = stem.with_extension("rs");
        if flat.is_file() {
            return Some(flat);
        }

        let nested = stem.join("mod.rs");
        if nested.is_file() {
            return Some(nested);
        }
    }

    None
}
406
407fn single_field_type(variant: &syn::Variant) -> syn::Result<&Type> {
408 match &variant.fields {
410 Fields::Unnamed(fields) if fields.unnamed.len() == 1 => Ok(&fields.unnamed[0].ty),
411 _ => Err(syn::Error::new_spanned(
412 &variant.ident,
413 "each variant must have exactly one unnamed field",
414 )),
415 }
416}
417
418fn extract_type_path(ty: &Type) -> syn::Result<&TypePath> {
419 match ty {
421 Type::Path(path) => Ok(path),
422 _ => Err(syn::Error::new_spanned(
423 ty,
424 "variant field must be a path type",
425 )),
426 }
427}
428
429fn display_arm(variant: &Ident) -> proc_macro2::TokenStream {
430 quote! { Self::#variant(inner) => ::core::fmt::Display::fmt(inner, f), }
431}
432
433fn source_arm(variant: &Ident) -> proc_macro2::TokenStream {
434 quote! { Self::#variant(inner) => Some(inner as &(dyn ::std::error::Error + 'static)), }
435}
436
437fn normalized_leaf_type(ty: &Type) -> Type {
438 unwrap_located(ty).unwrap_or_else(|| ty.clone())
439}
440
441fn unwrap_located(ty: &Type) -> Option<Type> {
442 let Type::Path(path) = ty else {
444 return None;
445 };
446 let segment = path.path.segments.last()?;
447 if segment.ident != "Located" {
448 return None;
449 }
450 let PathArguments::AngleBracketed(args) = &segment.arguments else {
451 return None;
452 };
453 let GenericArgument::Type(inner) = args.args.first()? else {
454 return None;
455 };
456 Some(inner.clone())
457}
458
459fn type_key(ty: &Type) -> String {
460 quote!(#ty).to_string().replace(' ', "")
462}