Skip to main content

tauri_plugin_auditaur_macros/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use proc_macro_crate::{crate_name, FoundCrate};
4use quote::{format_ident, quote};
5use syn::{
6    parse::Parse, parse::ParseStream, punctuated::Punctuated, FnArg, ItemFn, Meta, Pat, Token,
7};
8
9struct InstrumentArgs {
10    metas: Punctuated<Meta, Token![,]>,
11}
12
13impl Parse for InstrumentArgs {
14    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
15        Ok(Self {
16            metas: Punctuated::parse_terminated(input)?,
17        })
18    }
19}
20
21/// Instruments a Tauri command span with Auditaur's IPC trace context.
22///
23/// The command remains an explicit opt-in Tauri command: keep `#[tauri::command]`
24/// and keep an `auditaur_trace_context: Option<IpcTraceContext>` parameter.
25#[proc_macro_attribute]
26pub fn instrument_ipc(attr: TokenStream, item: TokenStream) -> TokenStream {
27    expand_instrument_ipc(attr.into(), item.into()).into()
28}
29
30/// Defines and instruments a Tauri command span with Auditaur's IPC trace context.
31///
32/// This wraps `#[tauri::command]`, injects Auditaur IPC carrier arguments,
33/// and reads the frontend `traceparent` sent by `@auditaur/api`.
34#[proc_macro_attribute]
35pub fn auditaur_command(attr: TokenStream, item: TokenStream) -> TokenStream {
36    expand_auditaur_command(attr.into(), item.into()).into()
37}
38
39fn expand_instrument_ipc(attr: TokenStream2, item: TokenStream2) -> TokenStream2 {
40    let args = match syn::parse2::<InstrumentArgs>(attr) {
41        Ok(args) => args,
42        Err(error) => return error.to_compile_error(),
43    };
44
45    let auditaur_crate = dependency_crate_path("tauri-plugin-auditaur", "tauri_plugin_auditaur");
46    let tracing_crate = dependency_crate_path("tracing", "tracing");
47    let traceparent_field = quote! {
48        traceparent = #auditaur_crate::ipc_traceparent(auditaur_trace_context.as_ref())
49    };
50    let instrument_args = match merge_instrument_args(
51        args,
52        quote!(auditaur_trace_context),
53        quote!(#traceparent_field),
54    ) {
55        Ok(args) => args,
56        Err(error) => return error.to_compile_error(),
57    };
58
59    quote! {
60        #[#tracing_crate::instrument(#(#instrument_args),*)]
61        #item
62    }
63}
64
65fn expand_auditaur_command(attr: TokenStream2, item: TokenStream2) -> TokenStream2 {
66    let args = match syn::parse2::<InstrumentArgs>(attr) {
67        Ok(args) => args,
68        Err(error) => return error.to_compile_error(),
69    };
70    let mut function = match syn::parse2::<ItemFn>(item) {
71        Ok(function) => function,
72        Err(error) => return error.to_compile_error(),
73    };
74
75    let auditaur_crate = dependency_crate_path("tauri-plugin-auditaur", "tauri_plugin_auditaur");
76    let tauri_crate = dependency_crate_path("tauri", "tauri");
77    let tracing_crate = dependency_crate_path("tracing", "tracing");
78    if function_has_argument(&function, "auditaur_trace_context") {
79        return syn::Error::new_spanned(
80            &function.sig.ident,
81            "`#[tauri_plugin_auditaur::auditaur_command]` reserves the `auditaur_trace_context` argument; remove it or use `#[tauri::command]` with `#[tauri_plugin_auditaur::instrument_ipc]` instead",
82        )
83        .to_compile_error();
84    }
85    let request_ident = unique_argument_ident(&function, "auditaur_request");
86    let request_arg: FnArg = syn::parse_quote! {
87        #request_ident: #tauri_crate::ipc::Request<'_>
88    };
89    function.sig.inputs.push(request_arg);
90    let context_arg: FnArg = syn::parse_quote! {
91        auditaur_trace_context: Option<#auditaur_crate::IpcTraceContext>
92    };
93    function.sig.inputs.push(context_arg);
94
95    let traceparent_field = quote! {
96        traceparent = #auditaur_crate::ipc_traceparent_from_request_or_context(
97            &#request_ident,
98            auditaur_trace_context.as_ref()
99        )
100    };
101    let injected_skip_args = quote!(#request_ident, auditaur_trace_context);
102    let instrument_args = match merge_instrument_args(
103        args,
104        quote!(#injected_skip_args),
105        quote!(#traceparent_field),
106    ) {
107        Ok(args) => args,
108        Err(error) => return error.to_compile_error(),
109    };
110
111    quote! {
112        #[#tauri_crate::command]
113        #[#tracing_crate::instrument(#(#instrument_args),*)]
114        #function
115    }
116}
117
118fn merge_instrument_args(
119    args: InstrumentArgs,
120    injected_skip_arg: TokenStream2,
121    traceparent_field: TokenStream2,
122) -> syn::Result<Vec<TokenStream2>> {
123    let mut instrument_args = Vec::new();
124    let mut skip_args = None;
125    let mut fields_args = None;
126    let mut skip_all = false;
127
128    for meta in args.metas {
129        if meta.path().is_ident("skip") {
130            let Meta::List(list) = meta else {
131                return Err(syn::Error::new_spanned(meta, "expected skip(...)"));
132            };
133            let existing = list.tokens;
134            skip_args = Some(if existing.is_empty() {
135                quote!(#injected_skip_arg)
136            } else if token_stream_mentions(&existing, &injected_skip_arg.to_string()) {
137                quote!(#existing)
138            } else {
139                quote!(#existing, #injected_skip_arg)
140            });
141        } else if meta.path().is_ident("skip_all") {
142            skip_all = true;
143            instrument_args.push(quote!(#meta));
144        } else if meta.path().is_ident("fields") {
145            let Meta::List(list) = meta else {
146                return Err(syn::Error::new_spanned(meta, "expected fields(...)"));
147            };
148            let existing = list.tokens;
149            fields_args = Some(if existing.is_empty() {
150                traceparent_field.clone()
151            } else if token_stream_mentions(&existing, "traceparent") {
152                quote!(#existing)
153            } else {
154                quote!(#existing, #traceparent_field)
155            });
156        } else {
157            instrument_args.push(quote!(#meta));
158        }
159    }
160
161    let skip_args = skip_args.unwrap_or_else(|| quote!(#injected_skip_arg));
162    let fields_args = fields_args.unwrap_or_else(|| quote!(#traceparent_field));
163    if !skip_all {
164        instrument_args.push(quote!(skip(#skip_args)));
165    }
166    instrument_args.push(quote!(fields(#fields_args)));
167
168    Ok(instrument_args)
169}
170
171fn unique_argument_ident(function: &ItemFn, base: &str) -> syn::Ident {
172    let mut candidate = format_ident!("{base}");
173    let mut suffix = 2;
174    while function_has_argument(function, &candidate.to_string()) {
175        candidate = format_ident!("{base}_{suffix}");
176        suffix += 1;
177    }
178    candidate
179}
180
181fn function_has_argument(function: &ItemFn, name: &str) -> bool {
182    function.sig.inputs.iter().any(|arg| {
183        matches!(
184            arg,
185            FnArg::Typed(pat_type)
186                if matches!(pat_type.pat.as_ref(), Pat::Ident(ident) if ident.ident == name)
187        )
188    })
189}
190
191fn token_stream_mentions(tokens: &TokenStream2, needle: &str) -> bool {
192    tokens
193        .to_string()
194        .split_whitespace()
195        .any(|part| part == needle)
196}
197
198fn dependency_crate_path(package_name: &str, fallback_name: &str) -> TokenStream2 {
199    match crate_name(package_name) {
200        Ok(FoundCrate::Itself) => quote!(crate),
201        Ok(FoundCrate::Name(name)) => {
202            let ident = format_ident!("{}", name);
203            quote!(::#ident)
204        }
205        Err(_) => {
206            let ident = format_ident!("{}", fallback_name);
207            quote!(::#ident)
208        }
209    }
210}
211
#[cfg(test)]
mod tests {
    use super::{expand_auditaur_command, expand_instrument_ipc};
    use quote::quote;

    /// Fails the test unless every entry of `needles` occurs in `expanded`.
    fn assert_contains_all(expanded: &str, needles: &[&str]) {
        for &needle in needles {
            assert!(
                expanded.contains(needle),
                "expected `{needle}` in expansion: {expanded}"
            );
        }
    }

    /// Fails the test if any entry of `needles` occurs in `expanded`.
    fn assert_contains_none(expanded: &str, needles: &[&str]) {
        for &needle in needles {
            assert!(
                !expanded.contains(needle),
                "did not expect `{needle}` in expansion: {expanded}"
            );
        }
    }

    #[test]
    fn adds_traceparent_field_and_context_skip() {
        let item = quote! {
            fn load_user(id: String, auditaur_trace_context: Option<IpcTraceContext>) {}
        };
        let expanded = expand_instrument_ipc(quote!(), item).to_string();

        assert_contains_all(
            &expanded,
            &[
                "skip (auditaur_trace_context)",
                "fields (traceparent =",
                "ipc_traceparent",
            ],
        );
    }

    #[test]
    fn preserves_tracing_options_and_merges_skip_fields() {
        let attr = quote!(err, skip(app), fields(command = "emit_backend_event"));
        let item = quote! {
            fn emit_backend_event(
                app: tauri::AppHandle,
                auditaur_trace_context: Option<IpcTraceContext>,
            ) {}
        };
        let expanded = expand_instrument_ipc(attr, item).to_string();

        assert_contains_all(
            &expanded,
            &[
                "err",
                "skip (app , auditaur_trace_context)",
                "command = \"emit_backend_event\"",
                "traceparent =",
            ],
        );
    }

    #[test]
    fn respects_skip_all_without_adding_skip() {
        let item = quote! {
            fn load_user(id: String, auditaur_trace_context: Option<IpcTraceContext>) {}
        };
        let expanded = expand_instrument_ipc(quote!(skip_all), item).to_string();

        assert_contains_all(&expanded, &["skip_all", "fields (traceparent ="]);
        assert_contains_none(&expanded, &["skip (auditaur_trace_context)"]);
    }

    #[test]
    fn avoids_duplicate_injected_arguments() {
        let attr = quote!(skip(auditaur_trace_context), fields(traceparent = "custom"));
        let item = quote! {
            fn load_user(id: String, auditaur_trace_context: Option<IpcTraceContext>) {}
        };
        let expanded = expand_instrument_ipc(attr, item).to_string();

        assert_contains_all(
            &expanded,
            &["skip (auditaur_trace_context)", "traceparent", "custom"],
        );
        assert_contains_none(
            &expanded,
            &[
                "auditaur_trace_context , auditaur_trace_context",
                "ipc_traceparent",
            ],
        );
    }

    #[test]
    fn auditaur_command_wraps_tauri_command_and_injects_request_traceparent() {
        let attr = quote!(err, skip(app), fields(command = "load_user"));
        let item = quote! {
            async fn load_user(app: tauri::AppHandle, id: String) -> Result<String, String> {
                Ok(id)
            }
        };
        let expanded = expand_auditaur_command(attr, item).to_string();

        assert_contains_all(
            &expanded,
            &[
                "tauri :: command",
                "tracing :: instrument",
                "auditaur_request :",
                "tauri :: ipc :: Request < '_ >",
                "auditaur_trace_context : Option",
                "skip (app , auditaur_request , auditaur_trace_context)",
                "ipc_traceparent_from_request_or_context",
                "& auditaur_request",
                "auditaur_trace_context . as_ref",
                "command = \"load_user\"",
                "err",
            ],
        );
    }

    #[test]
    fn auditaur_command_avoids_request_argument_name_collision() {
        let item = quote! {
            fn load_user(auditaur_request: String) -> String {
                auditaur_request
            }
        };
        let expanded = expand_auditaur_command(quote!(), item).to_string();

        assert_contains_all(
            &expanded,
            &[
                "auditaur_request_2 :",
                "skip (auditaur_request_2 , auditaur_trace_context)",
                "ipc_traceparent_from_request_or_context",
                "& auditaur_request_2",
            ],
        );
    }

    #[test]
    fn auditaur_command_rejects_reserved_trace_context_argument() {
        let item = quote! {
            fn load_user(auditaur_trace_context: Option<IpcTraceContext>) {}
        };
        let expanded = expand_auditaur_command(quote!(), item).to_string();

        assert_contains_all(
            &expanded,
            &["reserves the `auditaur_trace_context` argument"],
        );
    }
}