// tauri_plugin_auditaur_macros/lib.rs
use proc_macro::TokenStream;
use proc_macro2::{TokenStream as TokenStream2, TokenTree};
use proc_macro_crate::{crate_name, FoundCrate};
use quote::{format_ident, quote};
use syn::{
    parse::Parse, parse::ParseStream, punctuated::Punctuated, FnArg, ItemFn, Meta, Pat, Token,
};
8
9struct InstrumentArgs {
10 metas: Punctuated<Meta, Token![,]>,
11}
12
13impl Parse for InstrumentArgs {
14 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
15 Ok(Self {
16 metas: Punctuated::parse_terminated(input)?,
17 })
18 }
19}
20
21#[proc_macro_attribute]
26pub fn instrument_ipc(attr: TokenStream, item: TokenStream) -> TokenStream {
27 expand_instrument_ipc(attr.into(), item.into()).into()
28}
29
30#[proc_macro_attribute]
35pub fn auditaur_command(attr: TokenStream, item: TokenStream) -> TokenStream {
36 expand_auditaur_command(attr.into(), item.into()).into()
37}
38
39fn expand_instrument_ipc(attr: TokenStream2, item: TokenStream2) -> TokenStream2 {
40 let args = match syn::parse2::<InstrumentArgs>(attr) {
41 Ok(args) => args,
42 Err(error) => return error.to_compile_error(),
43 };
44
45 let auditaur_crate = dependency_crate_path("tauri-plugin-auditaur", "tauri_plugin_auditaur");
46 let tracing_crate = dependency_crate_path("tracing", "tracing");
47 let traceparent_field = quote! {
48 traceparent = #auditaur_crate::ipc_traceparent(auditaur_trace_context.as_ref())
49 };
50 let instrument_args = match merge_instrument_args(
51 args,
52 quote!(auditaur_trace_context),
53 quote!(#traceparent_field),
54 ) {
55 Ok(args) => args,
56 Err(error) => return error.to_compile_error(),
57 };
58
59 quote! {
60 #[#tracing_crate::instrument(#(#instrument_args),*)]
61 #item
62 }
63}
64
65fn expand_auditaur_command(attr: TokenStream2, item: TokenStream2) -> TokenStream2 {
66 let args = match syn::parse2::<InstrumentArgs>(attr) {
67 Ok(args) => args,
68 Err(error) => return error.to_compile_error(),
69 };
70 let mut function = match syn::parse2::<ItemFn>(item) {
71 Ok(function) => function,
72 Err(error) => return error.to_compile_error(),
73 };
74
75 let auditaur_crate = dependency_crate_path("tauri-plugin-auditaur", "tauri_plugin_auditaur");
76 let tauri_crate = dependency_crate_path("tauri", "tauri");
77 let tracing_crate = dependency_crate_path("tracing", "tracing");
78 if function_has_argument(&function, "auditaur_trace_context") {
79 return syn::Error::new_spanned(
80 &function.sig.ident,
81 "`#[tauri_plugin_auditaur::auditaur_command]` reserves the `auditaur_trace_context` argument; remove it or use `#[tauri::command]` with `#[tauri_plugin_auditaur::instrument_ipc]` instead",
82 )
83 .to_compile_error();
84 }
85 let request_ident = unique_argument_ident(&function, "auditaur_request");
86 let request_arg: FnArg = syn::parse_quote! {
87 #request_ident: #tauri_crate::ipc::Request<'_>
88 };
89 function.sig.inputs.push(request_arg);
90 let context_arg: FnArg = syn::parse_quote! {
91 auditaur_trace_context: Option<#auditaur_crate::IpcTraceContext>
92 };
93 function.sig.inputs.push(context_arg);
94
95 let traceparent_field = quote! {
96 traceparent = #auditaur_crate::ipc_traceparent_from_request_or_context(
97 &#request_ident,
98 auditaur_trace_context.as_ref()
99 )
100 };
101 let injected_skip_args = quote!(#request_ident, auditaur_trace_context);
102 let instrument_args = match merge_instrument_args(
103 args,
104 quote!(#injected_skip_args),
105 quote!(#traceparent_field),
106 ) {
107 Ok(args) => args,
108 Err(error) => return error.to_compile_error(),
109 };
110
111 quote! {
112 #[#tauri_crate::command]
113 #[#tracing_crate::instrument(#(#instrument_args),*)]
114 #function
115 }
116}
117
118fn merge_instrument_args(
119 args: InstrumentArgs,
120 injected_skip_arg: TokenStream2,
121 traceparent_field: TokenStream2,
122) -> syn::Result<Vec<TokenStream2>> {
123 let mut instrument_args = Vec::new();
124 let mut skip_args = None;
125 let mut fields_args = None;
126 let mut skip_all = false;
127
128 for meta in args.metas {
129 if meta.path().is_ident("skip") {
130 let Meta::List(list) = meta else {
131 return Err(syn::Error::new_spanned(meta, "expected skip(...)"));
132 };
133 let existing = list.tokens;
134 skip_args = Some(if existing.is_empty() {
135 quote!(#injected_skip_arg)
136 } else if token_stream_mentions(&existing, &injected_skip_arg.to_string()) {
137 quote!(#existing)
138 } else {
139 quote!(#existing, #injected_skip_arg)
140 });
141 } else if meta.path().is_ident("skip_all") {
142 skip_all = true;
143 instrument_args.push(quote!(#meta));
144 } else if meta.path().is_ident("fields") {
145 let Meta::List(list) = meta else {
146 return Err(syn::Error::new_spanned(meta, "expected fields(...)"));
147 };
148 let existing = list.tokens;
149 fields_args = Some(if existing.is_empty() {
150 traceparent_field.clone()
151 } else if token_stream_mentions(&existing, "traceparent") {
152 quote!(#existing)
153 } else {
154 quote!(#existing, #traceparent_field)
155 });
156 } else {
157 instrument_args.push(quote!(#meta));
158 }
159 }
160
161 let skip_args = skip_args.unwrap_or_else(|| quote!(#injected_skip_arg));
162 let fields_args = fields_args.unwrap_or_else(|| quote!(#traceparent_field));
163 if !skip_all {
164 instrument_args.push(quote!(skip(#skip_args)));
165 }
166 instrument_args.push(quote!(fields(#fields_args)));
167
168 Ok(instrument_args)
169}
170
171fn unique_argument_ident(function: &ItemFn, base: &str) -> syn::Ident {
172 let mut candidate = format_ident!("{base}");
173 let mut suffix = 2;
174 while function_has_argument(function, &candidate.to_string()) {
175 candidate = format_ident!("{base}_{suffix}");
176 suffix += 1;
177 }
178 candidate
179}
180
181fn function_has_argument(function: &ItemFn, name: &str) -> bool {
182 function.sig.inputs.iter().any(|arg| {
183 matches!(
184 arg,
185 FnArg::Typed(pat_type)
186 if matches!(pat_type.pat.as_ref(), Pat::Ident(ident) if ident.ident == name)
187 )
188 })
189}
190
191fn token_stream_mentions(tokens: &TokenStream2, needle: &str) -> bool {
192 tokens
193 .to_string()
194 .split_whitespace()
195 .any(|part| part == needle)
196}
197
198fn dependency_crate_path(package_name: &str, fallback_name: &str) -> TokenStream2 {
199 match crate_name(package_name) {
200 Ok(FoundCrate::Itself) => quote!(crate),
201 Ok(FoundCrate::Name(name)) => {
202 let ident = format_ident!("{}", name);
203 quote!(::#ident)
204 }
205 Err(_) => {
206 let ident = format_ident!("{}", fallback_name);
207 quote!(::#ident)
208 }
209 }
210}
211
#[cfg(test)]
mod tests {
    use super::{expand_auditaur_command, expand_instrument_ipc};
    use proc_macro2::TokenStream as TokenStream2;
    use quote::quote;

    /// Expands `#[instrument_ipc(attr)]` over `item` and renders the result.
    fn render_ipc(attr: TokenStream2, item: TokenStream2) -> String {
        expand_instrument_ipc(attr, item).to_string()
    }

    /// Expands `#[auditaur_command(attr)]` over `item` and renders the result.
    fn render_command(attr: TokenStream2, item: TokenStream2) -> String {
        expand_auditaur_command(attr, item).to_string()
    }

    /// A command that already declares the trace-context parameter, as
    /// `#[instrument_ipc]` expects.
    fn load_user_with_context() -> TokenStream2 {
        quote! {
            fn load_user(id: String, auditaur_trace_context: Option<IpcTraceContext>) {}
        }
    }

    #[test]
    fn adds_traceparent_field_and_context_skip() {
        let rendered = render_ipc(quote!(), load_user_with_context());

        for needle in [
            "skip (auditaur_trace_context)",
            "fields (traceparent =",
            "ipc_traceparent",
        ] {
            assert!(rendered.contains(needle));
        }
    }

    #[test]
    fn preserves_tracing_options_and_merges_skip_fields() {
        let rendered = render_ipc(
            quote!(err, skip(app), fields(command = "emit_backend_event")),
            quote! {
                fn emit_backend_event(
                    app: tauri::AppHandle,
                    auditaur_trace_context: Option<IpcTraceContext>,
                ) {}
            },
        );

        for needle in [
            "err",
            "skip (app , auditaur_trace_context)",
            "command = \"emit_backend_event\"",
            "traceparent =",
        ] {
            assert!(rendered.contains(needle));
        }
    }

    #[test]
    fn respects_skip_all_without_adding_skip() {
        let rendered = render_ipc(quote!(skip_all), load_user_with_context());

        assert!(rendered.contains("skip_all"));
        assert!(!rendered.contains("skip (auditaur_trace_context)"));
        assert!(rendered.contains("fields (traceparent ="));
    }

    #[test]
    fn avoids_duplicate_injected_arguments() {
        let rendered = render_ipc(
            quote!(skip(auditaur_trace_context), fields(traceparent = "custom")),
            load_user_with_context(),
        );

        for needle in ["skip (auditaur_trace_context)", "traceparent", "custom"] {
            assert!(rendered.contains(needle));
        }
        for absent in ["auditaur_trace_context , auditaur_trace_context", "ipc_traceparent"] {
            assert!(!rendered.contains(absent));
        }
    }

    #[test]
    fn auditaur_command_wraps_tauri_command_and_injects_request_traceparent() {
        let rendered = render_command(
            quote!(err, skip(app), fields(command = "load_user")),
            quote! {
                async fn load_user(app: tauri::AppHandle, id: String) -> Result<String, String> {
                    Ok(id)
                }
            },
        );

        for needle in [
            "tauri :: command",
            "tracing :: instrument",
            "auditaur_request :",
            "tauri :: ipc :: Request < '_ >",
            "auditaur_trace_context : Option",
            "skip (app , auditaur_request , auditaur_trace_context)",
            "ipc_traceparent_from_request_or_context",
            "& auditaur_request",
            "auditaur_trace_context . as_ref",
            "command = \"load_user\"",
            "err",
        ] {
            assert!(rendered.contains(needle));
        }
    }

    #[test]
    fn auditaur_command_avoids_request_argument_name_collision() {
        let rendered = render_command(
            quote!(),
            quote! {
                fn load_user(auditaur_request: String) -> String {
                    auditaur_request
                }
            },
        );

        for needle in [
            "auditaur_request_2 :",
            "skip (auditaur_request_2 , auditaur_trace_context)",
            "ipc_traceparent_from_request_or_context",
            "& auditaur_request_2",
        ] {
            assert!(rendered.contains(needle));
        }
    }

    #[test]
    fn auditaur_command_rejects_reserved_trace_context_argument() {
        let rendered = render_command(
            quote!(),
            quote! {
                fn load_user(auditaur_trace_context: Option<IpcTraceContext>) {}
            },
        );

        assert!(rendered.contains("reserves the `auditaur_trace_context` argument"));
    }
}