Skip to main content

thistrace_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{parse_macro_input, spanned::Spanned, Fields, ItemEnum, Variant};
4
5#[proc_macro_attribute]
6pub fn traceable(_attr: TokenStream, item: TokenStream) -> TokenStream {
7    let input = parse_macro_input!(item as ItemEnum);
8    expand_traceable(input).unwrap_or_else(|e| e.to_compile_error()).into()
9}
10
11fn expand_traceable(mut item: ItemEnum) -> syn::Result<proc_macro2::TokenStream> {
12    let enum_ident = item.ident.clone();
13    let generics = item.generics.clone();
14    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
15
16    let mut from_impls = Vec::new();
17
18    let mut seen_from_sources: std::collections::HashMap<String, proc_macro2::Span> =
19        std::collections::HashMap::new();
20
21    for variant in &mut item.variants {
22        let from_info = extract_from_source(variant)?;
23        let Some(from_info) = from_info else {
24            continue;
25        };
26        // Reserved for future: variant-level trace merging.
27
28        let source_ty = from_info.source_ty.clone();
29        let source_ty_key = quote!(#source_ty).to_string();
30        if let Some(prev_span) = seen_from_sources.get(&source_ty_key) {
31            let mut err = syn::Error::new(
32                variant.span(),
33                format!(
34                    "duplicate #[from] source type `{}`; this would create conflicting `From<{}>` impls",
35                    source_ty_key, source_ty_key
36                ),
37            );
38            err.combine(syn::Error::new(*prev_span, "previous #[from] source type seen here"));
39            return Err(err);
40        }
41        seen_from_sources.insert(source_ty_key, variant.span());
42
43        rewrite_from_variant(variant, &from_info)?;
44
45        let variant_ident = variant.ident.clone();
46        let source_field = from_info.source_field.clone();
47        let extra_fields = extra_default_inits(variant, &source_field)?;
48        let merge_origin = is_thistrace_origin(&source_ty);
49        let merge_bubbled = is_thistrace_bubbled(&source_ty);
50        let from_impl = if merge_origin {
51            quote! {
52                impl #impl_generics ::core::convert::From<#source_ty> for #enum_ident #ty_generics #where_clause {
53                    #[track_caller]
54                    fn from(source: #source_ty) -> Self {
55                        let __loc = ::core::panic::Location::caller();
56                        let __frame = ::thistrace::Frame::from_location(__loc);
57                        let mut __trace = ::thistrace::HasTrace::trace(&source)
58                            .cloned()
59                            .unwrap_or_else(::thistrace::Trace::empty);
60                        __trace.push(__frame);
61
62                        #enum_ident::#variant_ident {
63                            #source_field: source,
64                            #(#extra_fields,)*
65                            trace: __trace,
66                        }
67                    }
68                }
69            }
70        } else if merge_bubbled {
71            quote! {
72                impl #impl_generics ::core::convert::From<#source_ty> for #enum_ident #ty_generics #where_clause {
73                    #[track_caller]
74                    fn from(source: #source_ty) -> Self {
75                        let __trace = ::thistrace::HasTrace::trace(&source)
76                            .cloned()
77                            .unwrap_or_else(::thistrace::Trace::empty);
78
79                        #enum_ident::#variant_ident {
80                            #source_field: source,
81                            #(#extra_fields,)*
82                            trace: __trace,
83                        }
84                    }
85                }
86            }
87        } else {
88            quote! {
89                impl #impl_generics ::core::convert::From<#source_ty> for #enum_ident #ty_generics #where_clause {
90                    #[track_caller]
91                    fn from(source: #source_ty) -> Self {
92                        let __loc = ::core::panic::Location::caller();
93                        let __frame = ::thistrace::Frame::from_location(__loc);
94                        #enum_ident::#variant_ident {
95                            #source_field: source,
96                            #(#extra_fields,)*
97                            trace: ::thistrace::Trace::from_frame(__frame),
98                        }
99                    }
100                }
101            }
102        };
103        from_impls.push(from_impl);
104    }
105
106    // Generate HasTrace impl that returns the variant trace if present.
107    let match_arms = item.variants.iter().map(|v| {
108        let vident = &v.ident;
109        match &v.fields {
110            Fields::Named(named) => {
111                let has_trace = named.named.iter().any(|f| {
112                    f.ident
113                        .as_ref()
114                        .is_some_and(|id| id == "trace")
115                });
116                if has_trace {
117                    quote! { Self::#vident { trace, .. } => ::core::option::Option::Some(trace), }
118                } else {
119                    quote! { Self::#vident { .. } => ::core::option::Option::None, }
120                }
121            }
122            Fields::Unnamed(_) => quote! { Self::#vident ( .. ) => ::core::option::Option::None, },
123            Fields::Unit => quote! { Self::#vident => ::core::option::Option::None, },
124        }
125    });
126
127    let has_trace_impl = quote! {
128        impl #impl_generics ::thistrace::HasTrace for #enum_ident #ty_generics #where_clause {
129            fn trace(&self) -> ::core::option::Option<&::thistrace::Trace> {
130                match self {
131                    #(#match_arms)*
132                }
133            }
134        }
135    };
136
137    Ok(quote! {
138        #item
139        #(#from_impls)*
140        #has_trace_impl
141    })
142}
143
144struct FromInfo {
145    source_ty: syn::Type,
146    source_field: syn::Ident,
147    shape: FromShape,
148    tuple_ctx_tys: Vec<syn::Type>,
149}
150
/// Syntactic form in which a `#[from]` variant was declared.
enum FromShape {
    /// Struct form, e.g. `Foo { #[from] source: E, .. }`.
    Struct,
    /// Tuple form, e.g. `Foo(#[from] E, Ctx0, ..)`; rewritten into a named-field variant.
    Tuple,
}
155
156fn extract_from_source(variant: &Variant) -> syn::Result<Option<FromInfo>> {
157    // tuple form: Foo(#[from] io::Error, Ctx0, Ctx1, ...)
158    if let Fields::Unnamed(fields) = &variant.fields {
159        let from_indices: Vec<usize> = fields
160            .unnamed
161            .iter()
162            .enumerate()
163            .filter(|(_, f)| f.attrs.iter().any(|a| a.path().is_ident("from")))
164            .map(|(i, _)| i)
165            .collect();
166        if from_indices.len() > 1 {
167            return Err(syn::Error::new(
168                variant.span(),
169                "multiple #[from] fields in a single tuple variant are not supported",
170            ));
171        }
172        if from_indices.len() == 1 {
173            let from_index = from_indices[0];
174            let from_field = &fields.unnamed[from_index];
175            let ctx_tys = fields
176                .unnamed
177                .iter()
178                .enumerate()
179                .filter(|(i, _)| *i != from_index)
180                .map(|(_, f)| f.ty.clone())
181                .collect::<Vec<_>>();
182            if !ctx_tys.is_empty() || from_field.attrs.iter().any(|a| a.path().is_ident("from")) {
183                return Ok(Some(FromInfo {
184                    source_ty: from_field.ty.clone(),
185                    source_field: format_ident!("source"),
186                    shape: FromShape::Tuple,
187                    tuple_ctx_tys: ctx_tys,
188                }));
189            }
190        }
191    }
192
193    // struct form: Foo { #[from] source: io::Error }
194    if let Fields::Named(fields) = &variant.fields {
195        let from_fields: Vec<_> = fields
196            .named
197            .iter()
198            .filter(|f| f.attrs.iter().any(|a| a.path().is_ident("from")))
199            .collect();
200        if from_fields.len() > 1 {
201            return Err(syn::Error::new(
202                variant.span(),
203                "multiple #[from] fields in a single struct variant are not supported",
204            ));
205        }
206        if from_fields.len() == 1 {
207            let field = from_fields[0];
208            let ident = field.ident.clone().ok_or_else(|| {
209                syn::Error::new(field.span(), "expected a named field for struct #[from] variant")
210            })?;
211            return Ok(Some(FromInfo {
212                source_ty: field.ty.clone(),
213                source_field: ident,
214                shape: FromShape::Struct,
215                tuple_ctx_tys: Vec::new(),
216            }));
217        }
218    }
219
220    Ok(None)
221}
222
223fn rewrite_from_variant(variant: &mut Variant, info: &FromInfo) -> syn::Result<()> {
224    match info.shape {
225        FromShape::Tuple => rewrite_tuple_from_variant(variant, &info.source_ty, &info.tuple_ctx_tys),
226        FromShape::Struct => rewrite_struct_from_variant(variant, info),
227    }
228}
229
230fn rewrite_tuple_from_variant(
231    variant: &mut Variant,
232    source_ty: &syn::Type,
233    ctx_tys: &[syn::Type],
234) -> syn::Result<()> {
235    let variant_ident = variant.ident.clone();
236    match &variant.fields {
237        Fields::Unnamed(_) => {
238            let mut named = syn::punctuated::Punctuated::new();
239            named.push(syn::Field {
240                attrs: vec![syn::parse_quote!(#[source])],
241                vis: syn::Visibility::Inherited,
242                mutability: syn::FieldMutability::None,
243                ident: Some(format_ident!("source")),
244                colon_token: Some(Default::default()),
245                ty: source_ty.clone(),
246            });
247
248            for (i, ty) in ctx_tys.iter().enumerate() {
249                named.push(syn::Field {
250                    attrs: vec![],
251                    vis: syn::Visibility::Inherited,
252                    mutability: syn::FieldMutability::None,
253                    ident: Some(format_ident!("ctx{i}")),
254                    colon_token: Some(Default::default()),
255                    ty: ty.clone(),
256                });
257            }
258
259            named.push(syn::Field {
260                attrs: vec![],
261                vis: syn::Visibility::Inherited,
262                mutability: syn::FieldMutability::None,
263                ident: Some(format_ident!("trace")),
264                colon_token: Some(Default::default()),
265                ty: syn::parse_quote!(::thistrace::Trace),
266            });
267
268            variant.fields = Fields::Named(syn::FieldsNamed {
269                brace_token: Default::default(),
270                named,
271            });
272            Ok(())
273        }
274        _ => Err(syn::Error::new(
275            variant_ident.span(),
276            "only tuple variants can be rewritten for #[from]",
277        )),
278    }
279}
280
281fn rewrite_struct_from_variant(variant: &mut Variant, info: &FromInfo) -> syn::Result<()> {
282    let Fields::Named(fields) = &mut variant.fields else {
283        return Err(syn::Error::new(variant.span(), "expected struct variant"));
284    };
285
286    // Remove #[from] from the field to avoid thiserror generating a conflicting From impl.
287    for field in fields.named.iter_mut() {
288        if field.ident.as_ref() == Some(&info.source_field) {
289            field.attrs.retain(|a| !a.path().is_ident("from"));
290            // Ensure #[source] so thiserror's source() chain works.
291            let has_source = field.attrs.iter().any(|a| a.path().is_ident("source"));
292            if !has_source {
293                field.attrs.push(syn::parse_quote!(#[source]));
294            }
295        }
296    }
297
298    let has_trace = fields
299        .named
300        .iter()
301        .any(|f| f.ident.as_ref().is_some_and(|id| id == "trace"));
302    if !has_trace {
303        fields.named.push(syn::Field {
304            attrs: vec![],
305            vis: syn::Visibility::Inherited,
306            mutability: syn::FieldMutability::None,
307            ident: Some(format_ident!("trace")),
308            colon_token: Some(Default::default()),
309            ty: syn::parse_quote!(::thistrace::Trace),
310        });
311    }
312
313    Ok(())
314}
315
316fn extra_default_inits(
317    variant: &Variant,
318    source_field: &syn::Ident,
319) -> syn::Result<Vec<proc_macro2::TokenStream>> {
320    let mut inits = Vec::new();
321    let Fields::Named(fields) = &variant.fields else {
322        return Ok(inits);
323    };
324
325    for field in fields.named.iter() {
326        let Some(ident) = field.ident.as_ref() else {
327            continue;
328        };
329        if ident == source_field {
330            continue;
331        }
332        if ident == "trace" {
333            continue;
334        }
335        inits.push(quote! { #ident: ::core::default::Default::default() });
336    }
337
338    Ok(inits)
339}
340
341fn is_thistrace_origin(ty: &syn::Type) -> bool {
342    let syn::Type::Path(p) = ty else {
343        return false;
344    };
345    let Some(seg) = p.path.segments.last() else {
346        return false;
347    };
348    if seg.ident != "Origin" {
349        return false;
350    }
351    // If it is `Origin<T>` we treat it as our wrapper.
352    matches!(seg.arguments, syn::PathArguments::AngleBracketed(_))
353}
354
355fn is_thistrace_bubbled(ty: &syn::Type) -> bool {
356    let syn::Type::Path(p) = ty else {
357        return false;
358    };
359    let Some(seg) = p.path.segments.last() else {
360        return false;
361    };
362    if seg.ident != "Bubbled" {
363        return false;
364    }
365    matches!(seg.arguments, syn::PathArguments::AngleBracketed(_))
366}
367