1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{parse_macro_input, spanned::Spanned, Fields, ItemEnum, Variant};
4
5#[proc_macro_attribute]
6pub fn traceable(_attr: TokenStream, item: TokenStream) -> TokenStream {
7 let input = parse_macro_input!(item as ItemEnum);
8 expand_traceable(input).unwrap_or_else(|e| e.to_compile_error()).into()
9}
10
11fn expand_traceable(mut item: ItemEnum) -> syn::Result<proc_macro2::TokenStream> {
12 let enum_ident = item.ident.clone();
13 let generics = item.generics.clone();
14 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
15
16 let mut from_impls = Vec::new();
17
18 let mut seen_from_sources: std::collections::HashMap<String, proc_macro2::Span> =
19 std::collections::HashMap::new();
20
21 for variant in &mut item.variants {
22 let from_info = extract_from_source(variant)?;
23 let Some(from_info) = from_info else {
24 continue;
25 };
26 let source_ty = from_info.source_ty.clone();
29 let source_ty_key = quote!(#source_ty).to_string();
30 if let Some(prev_span) = seen_from_sources.get(&source_ty_key) {
31 let mut err = syn::Error::new(
32 variant.span(),
33 format!(
34 "duplicate #[from] source type `{}`; this would create conflicting `From<{}>` impls",
35 source_ty_key, source_ty_key
36 ),
37 );
38 err.combine(syn::Error::new(*prev_span, "previous #[from] source type seen here"));
39 return Err(err);
40 }
41 seen_from_sources.insert(source_ty_key, variant.span());
42
43 rewrite_from_variant(variant, &from_info)?;
44
45 let variant_ident = variant.ident.clone();
46 let source_field = from_info.source_field.clone();
47 let extra_fields = extra_default_inits(variant, &source_field)?;
48 let merge_origin = is_thistrace_origin(&source_ty);
49 let merge_bubbled = is_thistrace_bubbled(&source_ty);
50 let from_impl = if merge_origin {
51 quote! {
52 impl #impl_generics ::core::convert::From<#source_ty> for #enum_ident #ty_generics #where_clause {
53 #[track_caller]
54 fn from(source: #source_ty) -> Self {
55 let __loc = ::core::panic::Location::caller();
56 let __frame = ::thistrace::Frame::from_location(__loc);
57 let mut __trace = ::thistrace::HasTrace::trace(&source)
58 .cloned()
59 .unwrap_or_else(::thistrace::Trace::empty);
60 __trace.push(__frame);
61
62 #enum_ident::#variant_ident {
63 #source_field: source,
64 #(#extra_fields,)*
65 trace: __trace,
66 }
67 }
68 }
69 }
70 } else if merge_bubbled {
71 quote! {
72 impl #impl_generics ::core::convert::From<#source_ty> for #enum_ident #ty_generics #where_clause {
73 #[track_caller]
74 fn from(source: #source_ty) -> Self {
75 let __trace = ::thistrace::HasTrace::trace(&source)
76 .cloned()
77 .unwrap_or_else(::thistrace::Trace::empty);
78
79 #enum_ident::#variant_ident {
80 #source_field: source,
81 #(#extra_fields,)*
82 trace: __trace,
83 }
84 }
85 }
86 }
87 } else {
88 quote! {
89 impl #impl_generics ::core::convert::From<#source_ty> for #enum_ident #ty_generics #where_clause {
90 #[track_caller]
91 fn from(source: #source_ty) -> Self {
92 let __loc = ::core::panic::Location::caller();
93 let __frame = ::thistrace::Frame::from_location(__loc);
94 #enum_ident::#variant_ident {
95 #source_field: source,
96 #(#extra_fields,)*
97 trace: ::thistrace::Trace::from_frame(__frame),
98 }
99 }
100 }
101 }
102 };
103 from_impls.push(from_impl);
104 }
105
106 let match_arms = item.variants.iter().map(|v| {
108 let vident = &v.ident;
109 match &v.fields {
110 Fields::Named(named) => {
111 let has_trace = named.named.iter().any(|f| {
112 f.ident
113 .as_ref()
114 .is_some_and(|id| id == "trace")
115 });
116 if has_trace {
117 quote! { Self::#vident { trace, .. } => ::core::option::Option::Some(trace), }
118 } else {
119 quote! { Self::#vident { .. } => ::core::option::Option::None, }
120 }
121 }
122 Fields::Unnamed(_) => quote! { Self::#vident ( .. ) => ::core::option::Option::None, },
123 Fields::Unit => quote! { Self::#vident => ::core::option::Option::None, },
124 }
125 });
126
127 let has_trace_impl = quote! {
128 impl #impl_generics ::thistrace::HasTrace for #enum_ident #ty_generics #where_clause {
129 fn trace(&self) -> ::core::option::Option<&::thistrace::Trace> {
130 match self {
131 #(#match_arms)*
132 }
133 }
134 }
135 };
136
137 Ok(quote! {
138 #item
139 #(#from_impls)*
140 #has_trace_impl
141 })
142}
143
144struct FromInfo {
145 source_ty: syn::Type,
146 source_field: syn::Ident,
147 shape: FromShape,
148 tuple_ctx_tys: Vec<syn::Type>,
149}
150
151enum FromShape {
152 Tuple,
153 Struct,
154}
155
156fn extract_from_source(variant: &Variant) -> syn::Result<Option<FromInfo>> {
157 if let Fields::Unnamed(fields) = &variant.fields {
159 let from_indices: Vec<usize> = fields
160 .unnamed
161 .iter()
162 .enumerate()
163 .filter(|(_, f)| f.attrs.iter().any(|a| a.path().is_ident("from")))
164 .map(|(i, _)| i)
165 .collect();
166 if from_indices.len() > 1 {
167 return Err(syn::Error::new(
168 variant.span(),
169 "multiple #[from] fields in a single tuple variant are not supported",
170 ));
171 }
172 if from_indices.len() == 1 {
173 let from_index = from_indices[0];
174 let from_field = &fields.unnamed[from_index];
175 let ctx_tys = fields
176 .unnamed
177 .iter()
178 .enumerate()
179 .filter(|(i, _)| *i != from_index)
180 .map(|(_, f)| f.ty.clone())
181 .collect::<Vec<_>>();
182 if !ctx_tys.is_empty() || from_field.attrs.iter().any(|a| a.path().is_ident("from")) {
183 return Ok(Some(FromInfo {
184 source_ty: from_field.ty.clone(),
185 source_field: format_ident!("source"),
186 shape: FromShape::Tuple,
187 tuple_ctx_tys: ctx_tys,
188 }));
189 }
190 }
191 }
192
193 if let Fields::Named(fields) = &variant.fields {
195 let from_fields: Vec<_> = fields
196 .named
197 .iter()
198 .filter(|f| f.attrs.iter().any(|a| a.path().is_ident("from")))
199 .collect();
200 if from_fields.len() > 1 {
201 return Err(syn::Error::new(
202 variant.span(),
203 "multiple #[from] fields in a single struct variant are not supported",
204 ));
205 }
206 if from_fields.len() == 1 {
207 let field = from_fields[0];
208 let ident = field.ident.clone().ok_or_else(|| {
209 syn::Error::new(field.span(), "expected a named field for struct #[from] variant")
210 })?;
211 return Ok(Some(FromInfo {
212 source_ty: field.ty.clone(),
213 source_field: ident,
214 shape: FromShape::Struct,
215 tuple_ctx_tys: Vec::new(),
216 }));
217 }
218 }
219
220 Ok(None)
221}
222
223fn rewrite_from_variant(variant: &mut Variant, info: &FromInfo) -> syn::Result<()> {
224 match info.shape {
225 FromShape::Tuple => rewrite_tuple_from_variant(variant, &info.source_ty, &info.tuple_ctx_tys),
226 FromShape::Struct => rewrite_struct_from_variant(variant, info),
227 }
228}
229
230fn rewrite_tuple_from_variant(
231 variant: &mut Variant,
232 source_ty: &syn::Type,
233 ctx_tys: &[syn::Type],
234) -> syn::Result<()> {
235 let variant_ident = variant.ident.clone();
236 match &variant.fields {
237 Fields::Unnamed(_) => {
238 let mut named = syn::punctuated::Punctuated::new();
239 named.push(syn::Field {
240 attrs: vec![syn::parse_quote!(#[source])],
241 vis: syn::Visibility::Inherited,
242 mutability: syn::FieldMutability::None,
243 ident: Some(format_ident!("source")),
244 colon_token: Some(Default::default()),
245 ty: source_ty.clone(),
246 });
247
248 for (i, ty) in ctx_tys.iter().enumerate() {
249 named.push(syn::Field {
250 attrs: vec![],
251 vis: syn::Visibility::Inherited,
252 mutability: syn::FieldMutability::None,
253 ident: Some(format_ident!("ctx{i}")),
254 colon_token: Some(Default::default()),
255 ty: ty.clone(),
256 });
257 }
258
259 named.push(syn::Field {
260 attrs: vec![],
261 vis: syn::Visibility::Inherited,
262 mutability: syn::FieldMutability::None,
263 ident: Some(format_ident!("trace")),
264 colon_token: Some(Default::default()),
265 ty: syn::parse_quote!(::thistrace::Trace),
266 });
267
268 variant.fields = Fields::Named(syn::FieldsNamed {
269 brace_token: Default::default(),
270 named,
271 });
272 Ok(())
273 }
274 _ => Err(syn::Error::new(
275 variant_ident.span(),
276 "only tuple variants can be rewritten for #[from]",
277 )),
278 }
279}
280
281fn rewrite_struct_from_variant(variant: &mut Variant, info: &FromInfo) -> syn::Result<()> {
282 let Fields::Named(fields) = &mut variant.fields else {
283 return Err(syn::Error::new(variant.span(), "expected struct variant"));
284 };
285
286 for field in fields.named.iter_mut() {
288 if field.ident.as_ref() == Some(&info.source_field) {
289 field.attrs.retain(|a| !a.path().is_ident("from"));
290 let has_source = field.attrs.iter().any(|a| a.path().is_ident("source"));
292 if !has_source {
293 field.attrs.push(syn::parse_quote!(#[source]));
294 }
295 }
296 }
297
298 let has_trace = fields
299 .named
300 .iter()
301 .any(|f| f.ident.as_ref().is_some_and(|id| id == "trace"));
302 if !has_trace {
303 fields.named.push(syn::Field {
304 attrs: vec![],
305 vis: syn::Visibility::Inherited,
306 mutability: syn::FieldMutability::None,
307 ident: Some(format_ident!("trace")),
308 colon_token: Some(Default::default()),
309 ty: syn::parse_quote!(::thistrace::Trace),
310 });
311 }
312
313 Ok(())
314}
315
316fn extra_default_inits(
317 variant: &Variant,
318 source_field: &syn::Ident,
319) -> syn::Result<Vec<proc_macro2::TokenStream>> {
320 let mut inits = Vec::new();
321 let Fields::Named(fields) = &variant.fields else {
322 return Ok(inits);
323 };
324
325 for field in fields.named.iter() {
326 let Some(ident) = field.ident.as_ref() else {
327 continue;
328 };
329 if ident == source_field {
330 continue;
331 }
332 if ident == "trace" {
333 continue;
334 }
335 inits.push(quote! { #ident: ::core::default::Default::default() });
336 }
337
338 Ok(inits)
339}
340
341fn is_thistrace_origin(ty: &syn::Type) -> bool {
342 let syn::Type::Path(p) = ty else {
343 return false;
344 };
345 let Some(seg) = p.path.segments.last() else {
346 return false;
347 };
348 if seg.ident != "Origin" {
349 return false;
350 }
351 matches!(seg.arguments, syn::PathArguments::AngleBracketed(_))
353}
354
355fn is_thistrace_bubbled(ty: &syn::Type) -> bool {
356 let syn::Type::Path(p) = ty else {
357 return false;
358 };
359 let Some(seg) = p.path.segments.last() else {
360 return false;
361 };
362 if seg.ident != "Bubbled" {
363 return false;
364 }
365 matches!(seg.arguments, syn::PathArguments::AngleBracketed(_))
366}
367