Skip to main content

stalwart_lite_event_macro/
lib.rs

1/*
2 * SPDX-FileCopyrightText: 2020 Stalwart Labs LLC <hello@stalw.art>
3 *
4 * SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-SEL
5 */
6
7use proc_macro::TokenStream;
8use quote::quote;
9use syn::{
10    Data, DeriveInput, Expr, ExprPath, Fields, Ident, Token, parse::Parse, parse_macro_input,
11};
12
/// Next event id to hand out. `#[event_type]` assigns one id per variant from this
/// counter and `total_event_count!` expands to its current value.
///
/// It is a plain, unsynchronized `static mut`, accessed only inside `unsafe` blocks in this
/// file, and it keeps its value for as long as this proc-macro library stays loaded in the
/// compiler process.
static mut GLOBAL_ID_COUNTER: usize = 0;
14
15#[proc_macro_attribute]
16pub fn event_type(_attr: TokenStream, item: TokenStream) -> TokenStream {
17    let input = parse_macro_input!(item as DeriveInput);
18    let name = &input.ident;
19    let name_str = name.to_string();
20    let prefix = to_snake_case(name_str.strip_suffix("Event").unwrap_or(&name_str));
21
22    let enum_variants = match &input.data {
23        Data::Enum(data_enum) => &data_enum.variants,
24        _ => panic!("This macro only works with enums"),
25    };
26
27    let mut variant_ids = Vec::new();
28    let mut variant_names = Vec::new();
29    let mut event_names = Vec::new();
30
31    for variant in enum_variants {
32        unsafe {
33            variant_ids.push(GLOBAL_ID_COUNTER);
34            GLOBAL_ID_COUNTER += 1;
35        }
36        let variant_name = &variant.ident;
37        event_names.push(format!(
38            "{prefix}.{}",
39            to_snake_case(&variant_name.to_string())
40        ));
41        variant_names.push(variant_name);
42    }
43
44    let id_fn = quote! {
45        pub const fn id(&self) -> usize {
46            match self {
47                #(Self::#variant_names => #variant_ids,)*
48            }
49        }
50    };
51
52    let name_fn = quote! {
53        pub fn name(&self) -> &'static str {
54            match self {
55                #(Self::#variant_names => #event_names,)*
56            }
57        }
58    };
59
60    let parse_fn = quote! {
61        pub fn try_parse(name: &str) -> Option<Self> {
62            match name {
63                #(#event_names => Some(Self::#variant_names),)*
64                _ => None,
65            }
66        }
67    };
68
69    let variants_fn = quote! {
70        pub const fn variants() -> &'static [Self] {
71            &[
72                #(#name::#variant_names,)*
73            ]
74        }
75    };
76
77    let expanded = quote! {
78        #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
79        pub enum #name {
80            #(#variant_names),*
81        }
82
83        impl #name {
84            #id_fn
85            #name_fn
86            #parse_fn
87            #variants_fn
88        }
89    };
90
91    TokenStream::from(expanded)
92}
93
94#[proc_macro_attribute]
95pub fn event_family(_attr: TokenStream, item: TokenStream) -> TokenStream {
96    let input = parse_macro_input!(item as DeriveInput);
97    let name = &input.ident;
98
99    let variants = match &input.data {
100        Data::Enum(data_enum) => &data_enum.variants,
101        _ => panic!("EventType must be an enum"),
102    };
103
104    let variant_idents: Vec<_> = variants.iter().map(|v| &v.ident).collect();
105
106    let event_types: Vec<_> = variants
107        .iter()
108        .map(|v| match &v.fields {
109            Fields::Unnamed(fields) => &fields.unnamed[0],
110            _ => panic!("EventType variants must be unnamed and contain a single type"),
111        })
112        .map(|f| &f.ty)
113        .collect();
114
115    let variant_names: Vec<_> = variant_idents
116        .iter()
117        .map(|ident| {
118            let name_str = ident.to_string();
119            to_snake_case(name_str.strip_suffix("Event").unwrap_or(&name_str))
120        })
121        .collect();
122
123    let expanded = quote! {
124        pub enum #name {
125            #(#variant_idents(#event_types)),*
126        }
127
128        impl #name {
129            pub const fn id(&self) -> usize {
130                match self {
131                    #(#name::#variant_idents(e) => e.id()),*
132                }
133            }
134
135            pub fn name(&self) -> &'static str {
136                match self {
137                    #(#name::#variant_idents(e) => e.name()),*
138                }
139            }
140
141            pub fn try_parse(name: &str) -> Option<Self> {
142                match name.trim().split_once('.')?.0 {
143                #(
144                    #variant_names =>  <#event_types>::try_parse(&name).map(#name::#variant_idents),
145                )*
146                    _ => None,
147                }
148            }
149
150            pub const fn variants() -> [#name; crate::TOTAL_EVENT_COUNT] {
151                let mut variants = [crate::EventType::Eval(crate::EvalEvent::Error); crate::TOTAL_EVENT_COUNT];
152                #(
153                    {
154                        let sub_variants = <#event_types>::variants();
155                        let mut i = 0;
156                        while i < sub_variants.len() {
157                            variants[sub_variants[i].id()] = #name::#variant_idents(sub_variants[i]);
158                            i += 1;
159                        }
160
161                    }
162                )*
163                variants
164            }
165        }
166    };
167
168    TokenStream::from(expanded)
169}
170
171#[proc_macro_attribute]
172pub fn key_names(_attr: TokenStream, item: TokenStream) -> TokenStream {
173    let input = parse_macro_input!(item as DeriveInput);
174    let name = &input.ident;
175
176    let enum_variants = match &input.data {
177        Data::Enum(data_enum) => &data_enum.variants,
178        _ => panic!("This macro only works with enums"),
179    };
180
181    let mut variant_names = Vec::new();
182    let mut camel_case_names = Vec::new();
183    let mut snake_case_names = Vec::new();
184
185    for variant in enum_variants.iter() {
186        let variant_name = &variant.ident;
187        variant_names.push(variant_name);
188        snake_case_names.push(to_snake_case(&variant_name.to_string()));
189        camel_case_names.push(
190            variant_name
191                .to_string()
192                .char_indices()
193                .map(|(i, c)| if i == 0 { c.to_ascii_lowercase() } else { c })
194                .collect::<String>(),
195        );
196    }
197
198    let id_fn = quote! {
199        pub fn id(&self) -> &'static str {
200            match self {
201                #(Self::#variant_names => #snake_case_names,)*
202            }
203        }
204    };
205
206    let name_fn = quote! {
207        pub fn name(&self) -> &'static str {
208            match self {
209                #(Self::#variant_names => #camel_case_names,)*
210            }
211        }
212    };
213
214    let parse_fn = quote! {
215        pub fn try_parse(name: &str) -> Option<Self> {
216            match name {
217                #(#snake_case_names => Some(Self::#variant_names),)*
218                _ => None,
219            }
220        }
221    };
222
223    let expanded = quote! {
224        #input
225
226        impl #name {
227            #name_fn
228            #id_fn
229            #parse_fn
230        }
231    };
232
233    TokenStream::from(expanded)
234}
235
236#[proc_macro]
237pub fn total_event_count(_item: TokenStream) -> TokenStream {
238    let count = unsafe { GLOBAL_ID_COUNTER };
239    let expanded = quote! {
240        #count
241    };
242    TokenStream::from(expanded)
243}
244
/// Converts an `UpperCamelCase` identifier into the lowercase, hyphen-separated form used for
/// event and key names.
///
/// Despite its name, the separator is `-` (kebab-case): `IngestError` -> `ingest-error`.
/// Every ASCII uppercase letter after the first character starts a new word, so acronyms are
/// split per letter (`HTTPError` -> `h-t-t-p-error`). All other characters, including
/// non-ASCII letters, are copied unchanged.
fn to_snake_case(name: &str) -> String {
    name.chars().enumerate().fold(
        String::with_capacity(name.len()),
        |mut kebab, (position, ch)| {
            match ch {
                upper if upper.is_ascii_uppercase() => {
                    if position != 0 {
                        kebab.push('-');
                    }
                    kebab.push(upper.to_ascii_lowercase());
                }
                other => kebab.push(other),
            }
            kebab
        },
    )
}
259
/// Parsed input of the `event!` macro: `Event(param)` optionally followed by
/// comma-separated `Key = value` pairs.
struct EventMacroInput {
    /// Identifier emitted after `trc::EventType::` in the expansion.
    event: Ident,
    /// Expression inside the parentheses, passed as the payload of that `trc::EventType` variant.
    param: Expr,
    /// `(key, value)` pairs in source order; each key is emitted as `trc::Key::<key>` and each
    /// value is wrapped in `trc::Value::from`.
    key_values: Vec<(Ident, Expr)>,
}
265
266impl Parse for EventMacroInput {
267    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
268        let event: Ident = input.parse()?;
269        let content;
270        syn::parenthesized!(content in input);
271        let param: Expr = content.parse()?;
272
273        let mut key_values = Vec::new();
274        while !input.is_empty() {
275            input.parse::<Token![,]>()?;
276            if input.is_empty() {
277                break;
278            }
279            let key: Ident = input.parse()?;
280            input.parse::<Token![=]>()?;
281            let value: Expr = input.parse()?;
282            key_values.push((key, value));
283        }
284
285        Ok(EventMacroInput {
286            event,
287            param,
288            key_values,
289        })
290    }
291}
292
293#[proc_macro]
294pub fn event(input: TokenStream) -> TokenStream {
295    let EventMacroInput {
296        event,
297        param,
298        key_values,
299    } = parse_macro_input!(input as EventMacroInput);
300
301    let key_value_tokens = key_values.iter().map(|(key, value)| {
302        quote! {
303            (trc::Key::#key, trc::Value::from(#value))
304        }
305    });
306    // This avoids having to evaluate expensive values when we know we are not interested in the event
307    let key_value_metric_tokens = key_values.iter().filter_map(|(key, value)| {
308        if key.is_metric_key() {
309            Some(quote! {
310                (trc::Key::#key, trc::Value::from(#value))
311            })
312        } else {
313            None
314        }
315    });
316
317    let expanded = if matches!(&param, Expr::Path(ExprPath { path, .. })  if path.segments.len() > 1 && path.segments.last().unwrap().arguments.is_empty() )
318    {
319        quote! {{
320            const ET: trc::EventType = trc::EventType::#event(#param);
321            const ET_ID: usize = ET.id();
322            if trc::Collector::has_interest(ET_ID) {
323                let keys = vec![#(#key_value_tokens),*];
324                if trc::Collector::is_metric(ET_ID) {
325                    trc::Collector::record_metric(ET, ET_ID, &keys);
326                }
327                trc::Event::with_keys(ET, keys).send();
328            } else if trc::Collector::is_metric(ET_ID) {
329                trc::Collector::record_metric(ET, ET_ID, &[#(#key_value_metric_tokens),*]);
330            }
331        }}
332    } else {
333        quote! {{
334            let et = trc::EventType::#event(#param);
335            let et_id = et.id();
336            if trc::Collector::has_interest(et_id) {
337                let keys = vec![#(#key_value_tokens),*];
338                if trc::Collector::is_metric(et_id) {
339                    trc::Collector::record_metric(et, et_id, &keys);
340                }
341                trc::Event::with_keys(et, keys).send();
342            } else if trc::Collector::is_metric(et_id) {
343                trc::Collector::record_metric(et, et_id, &[#(#key_value_metric_tokens),*]);
344            }
345        }}
346    };
347
348    TokenStream::from(expanded)
349}
350
/// Classifies `event!` keys as metric keys.
///
/// In the `event!` expansion, when nothing is interested in an event but it is still recorded
/// as a metric, only the pairs whose key passes this check are evaluated.
trait IsMetricKey {
    /// Returns `true` if this key's name is one of a fixed set of metric key names.
    fn is_metric_key(&self) -> bool;
}
354
355impl IsMetricKey for Ident {
356    fn is_metric_key(&self) -> bool {
357        matches!(
358            self.to_string().as_ref(),
359            "Total"
360                | "Elapsed"
361                | "Size"
362                | "TotalSuccesses"
363                | "TotalFailures"
364                | "DmarcPass"
365                | "DmarcQuarantine"
366                | "DmarcReject"
367                | "DmarcNone"
368                | "DkimPass"
369                | "DkimFail"
370                | "DkimNone"
371                | "SpfPass"
372                | "SpfFail"
373                | "SpfNone"
374                | "Protocol"
375                | "Code"
376        )
377    }
378}