Skip to main content

telemetry_safe_derive/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::parse_macro_input;
4use syn::spanned::Spanned;
5use syn::{
6    Attribute, Data, DataEnum, DataStruct, DeriveInput, Error, Expr, Fields, LitStr, Result, Token,
7};
8
9#[proc_macro_derive(ToTelemetry, attributes(telemetry))]
10pub fn derive_to_telemetry(input: TokenStream) -> TokenStream {
11    let input = parse_macro_input!(input as DeriveInput);
12    match expand_derive(&input) {
13        Ok(tokens) => tokens.into(),
14        Err(err) => err.to_compile_error().into(),
15    }
16}
17
18fn expand_derive(input: &DeriveInput) -> Result<proc_macro2::TokenStream> {
19    let ident = &input.ident;
20    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
21
22    let body = match &input.data {
23        Data::Struct(data) => expand_struct(ident, data)?,
24        Data::Enum(data) => expand_enum(data)?,
25        Data::Union(data) => {
26            return Err(Error::new(
27                data.union_token.span(),
28                "ToTelemetry cannot be derived for unions",
29            ));
30        }
31    };
32
33    Ok(quote! {
34        impl #impl_generics ::telemetry_safe::ToTelemetry for #ident #ty_generics #where_clause {
35            fn fmt_telemetry(
36                &self,
37                f: &mut ::std::fmt::Formatter<'_>,
38            ) -> ::std::fmt::Result {
39                #body
40            }
41        }
42    })
43}
44
45fn expand_struct(ident: &syn::Ident, data: &DataStruct) -> Result<proc_macro2::TokenStream> {
46    match &data.fields {
47        Fields::Named(fields) => {
48            let mut field_exprs = Vec::new();
49            for field in &fields.named {
50                let attr = parse_field_attr(&field.attrs).transpose()?;
51                if matches!(attr, Some(FieldAttr::Skip)) {
52                    continue;
53                }
54
55                let name = field.ident.as_ref().expect("named field");
56                let key = LitStr::new(&name.to_string(), name.span());
57                let value = field_expr(field, quote! { self.#name }, attr)?;
58                field_exprs.push(quote! {
59                    ds.field(#key, &#value);
60                });
61            }
62
63            Ok(quote! {
64                let mut ds = f.debug_struct(stringify!(#ident));
65                #(#field_exprs)*
66                ds.finish()
67            })
68        }
69        Fields::Unnamed(fields) => {
70            let mut field_exprs = Vec::new();
71            for (index, field) in fields.unnamed.iter().enumerate() {
72                let attr = parse_field_attr(&field.attrs).transpose()?;
73                if matches!(attr, Some(FieldAttr::Skip)) {
74                    continue;
75                }
76
77                let accessor = syn::Index::from(index);
78                let value = field_expr(field, quote! { self.#accessor }, attr)?;
79                field_exprs.push(quote! {
80                    ds.field(&#value);
81                });
82            }
83
84            Ok(quote! {
85                let mut ds = f.debug_tuple(stringify!(#ident));
86                #(#field_exprs)*
87                ds.finish()
88            })
89        }
90        Fields::Unit => Ok(quote! {
91            f.write_str(stringify!(#ident))
92        }),
93    }
94}
95
96fn expand_enum(data: &DataEnum) -> Result<proc_macro2::TokenStream> {
97    let arms = data
98        .variants
99        .iter()
100        .map(|variant| {
101            let ident = &variant.ident;
102            match &variant.fields {
103                Fields::Named(fields) => {
104                    let mut bindings = Vec::new();
105                    let mut formatter = Vec::new();
106                    for field in &fields.named {
107                        let attr = parse_field_attr(&field.attrs).transpose()?;
108                        let name = field.ident.as_ref().expect("named field");
109
110                        if matches!(attr, Some(FieldAttr::Skip)) {
111                            // Skipped fields must not bind a local name, otherwise enum
112                            // patterns trigger `unused variable` warnings in downstream crates.
113                            bindings.push(quote! { #name: _ });
114                            continue;
115                        }
116
117                        bindings.push(quote! { #name });
118                        let key = LitStr::new(&name.to_string(), name.span());
119                        let value = field_expr(field, quote! { #name }, attr)?;
120                        formatter.push(quote! {
121                            ds.field(#key, &#value);
122                        });
123                    }
124
125                    Ok(quote! {
126                        Self::#ident { #(#bindings),* } => {
127                            let mut ds = f.debug_struct(stringify!(#ident));
128                            #(#formatter)*
129                            ds.finish()
130                        }
131                    })
132                }
133                Fields::Unnamed(fields) => {
134                    let mut bindings = Vec::new();
135                    let mut formatter = Vec::new();
136                    for (index, field) in fields.unnamed.iter().enumerate() {
137                        let attr = parse_field_attr(&field.attrs).transpose()?;
138                        if matches!(attr, Some(FieldAttr::Skip)) {
139                            bindings.push(quote! { _ });
140                            continue;
141                        }
142
143                        let binding = syn::Ident::new(&format!("field_{index}"), ident.span());
144                        bindings.push(quote! { #binding });
145                        let value = field_expr(field, quote! { #binding }, attr)?;
146                        formatter.push(quote! {
147                            ds.field(&#value);
148                        });
149                    }
150
151                    Ok(quote! {
152                        Self::#ident(#(#bindings),*) => {
153                            let mut ds = f.debug_tuple(stringify!(#ident));
154                            #(#formatter)*
155                            ds.finish()
156                        }
157                    })
158                }
159                Fields::Unit => Ok(quote! {
160                    Self::#ident => f.write_str(stringify!(#ident))
161                }),
162            }
163        })
164        .collect::<Result<Vec<_>>>()?;
165
166    Ok(quote! {
167        match self {
168            #(#arms),*
169        }
170    })
171}
172
173fn field_expr(
174    field: &syn::Field,
175    accessor: proc_macro2::TokenStream,
176    attr: Option<FieldAttr>,
177) -> Result<proc_macro2::TokenStream> {
178    match attr {
179        Some(FieldAttr::Display(format)) => {
180            if format.value() != "{}" {
181                return Err(Error::new(
182                    format.span(),
183                    "only #[telemetry(\"{}\")] is currently supported",
184                ));
185            }
186
187            // `format_args!` keeps the derive path allocation-free while still allowing
188            // explicit escape hatches for types whose Display output is already curated.
189            Ok(quote! {
190                ::std::format_args!("{}", #accessor)
191            })
192        }
193        Some(FieldAttr::Skip) | None => {
194            let ty = &field.ty;
195            Ok(quote! {{
196                let value: &#ty = &#accessor;
197                ::telemetry_safe::telemetry_debug(value)
198            }})
199        }
200    }
201}
202
203enum FieldAttr {
204    Display(LitStr),
205    Skip,
206}
207
208fn parse_field_attr(attrs: &[Attribute]) -> Option<Result<FieldAttr>> {
209    attrs
210        .iter()
211        .find(|attr| attr.path().is_ident("telemetry"))
212        .map(parse_single_field_attr)
213}
214
215fn parse_single_field_attr(attr: &Attribute) -> Result<FieldAttr> {
216    attr.parse_args_with(|input: syn::parse::ParseStream<'_>| {
217        if input.peek(syn::Ident) {
218            let ident: syn::Ident = input.parse()?;
219            if ident == "skip" {
220                if !input.is_empty() {
221                    return Err(input.error("unexpected tokens after skip"));
222                }
223                return Ok(FieldAttr::Skip);
224            }
225
226            return Err(Error::new(ident.span(), "unsupported telemetry attribute"));
227        }
228
229        let format: Expr = input.parse()?;
230        if !input.is_empty() {
231            let _comma: Token![,] = input.parse()?;
232            if !input.is_empty() {
233                return Err(input.error("expected a single format string or `skip`"));
234            }
235        }
236
237        match format {
238            Expr::Lit(expr_lit) => match expr_lit.lit {
239                syn::Lit::Str(lit) => Ok(FieldAttr::Display(lit)),
240                other => Err(Error::new(other.span(), "expected string literal")),
241            },
242            other => Err(Error::new(other.span(), "expected string literal")),
243        }
244    })
245}