unimock_macros/unimock/
mod.rs

1use quote::{quote, quote_spanned, ToTokens};
2use syn::parse_quote;
3
4mod answer_fn;
5mod associated_future;
6mod attr;
7mod method;
8mod output;
9mod trait_info;
10mod util;
11
12use crate::doc::SynDoc;
13use crate::unimock::method::{InputsSyntax, Receiver, SelfReference, SelfToDelegator, Tupled};
14use crate::unimock::util::replace_self_ty_with_path;
15pub use attr::{Attr, MockApi};
16use trait_info::TraitInfo;
17
18use attr::{UnmockFn, UnmockFnParams};
19
20use self::answer_fn::make_answer_fn;
21use self::method::{ArgClass, MockMethod};
22use self::util::{iter_generic_type_params, InferImplTrait};
23
24pub fn generate(attr: Attr, item_trait: syn::ItemTrait) -> syn::Result<proc_macro2::TokenStream> {
25    let trait_info = trait_info::TraitInfo::analyze(&item_trait, &attr)?;
26    attr.validate(&trait_info)?;
27
28    let prefix = &attr.prefix;
29    let trait_path = &trait_info.trait_path;
30    let mirrored_impl_attributes = trait_info
31        .input_trait
32        .attrs
33        .iter()
34        .filter(|attribute| match attribute.style {
35            syn::AttrStyle::Outer => {
36                if let Some(last_segment) = attribute.path().segments.last() {
37                    last_segment.ident == "async_trait"
38                } else {
39                    false
40                }
41            }
42            syn::AttrStyle::Inner(_) => false,
43        })
44        .collect::<Vec<_>>();
45    let impl_allow_lints = impl_allow_lints();
46
47    let mock_fn_defs: Vec<Option<MockFnDef>> = trait_info
48        .methods
49        .iter()
50        .map(|method| def_mock_fn(method.as_ref(), &trait_info, &attr))
51        .collect();
52    let associated_futures = trait_info
53        .methods
54        .iter()
55        .filter_map(|method| associated_future::def_associated_future(method.as_ref()));
56    let method_impls = trait_info
57        .methods
58        .iter()
59        .enumerate()
60        .map(|(index, method)| {
61            def_method_impl(
62                index,
63                method.as_ref(),
64                &trait_info,
65                &attr,
66                MethodImplKind::Mock,
67            )
68        });
69
70    let where_clause = &trait_info.input_trait.generics.where_clause;
71    let mock_fn_struct_items = mock_fn_defs
72        .iter()
73        .filter_map(Option::as_ref)
74        .map(|def| &def.mock_fn_struct_item);
75    let mock_fn_impl_details = mock_fn_defs
76        .iter()
77        .filter_map(Option::as_ref)
78        .map(|def| &def.impl_details);
79    let generic_params = util::Generics::trait_params(&trait_info, None);
80    let generic_args = util::Generics::trait_args(
81        &trait_info.input_trait.generics,
82        None,
83        InferImplTrait(false),
84    );
85
86    let attr_associated_types = trait_info
87        .input_trait
88        .items
89        .iter()
90        .filter_map(|item| match item {
91            syn::TraitItem::Type(trait_item_type) => {
92                let ident = &trait_item_type.ident;
93                let ident_string = ident.to_string();
94                attr.associated_types
95                    .get(&ident_string)
96                    .map(|trait_item_type| {
97                        quote! {
98                            #trait_item_type
99                        }
100                    })
101            }
102            _ => None,
103        })
104        .collect::<Vec<_>>();
105
106    let attr_associated_consts = trait_info
107        .input_trait
108        .items
109        .iter()
110        .filter_map(|item| match item {
111            syn::TraitItem::Const(trait_item_const) => {
112                let ident = &trait_item_const.ident;
113                let ident_string = ident.to_string();
114                attr.associated_consts
115                    .get(&ident_string)
116                    .map(|trait_item_const| {
117                        quote! { #trait_item_const }
118                    })
119            }
120            _ => None,
121        })
122        .collect::<Vec<_>>();
123
124    let (opt_mock_interface_public, opt_mock_interface_private, impl_doc) = match &attr.mock_api {
125        MockApi::Hidden => (
126            None,
127            Some(quote! {
128                #(#mock_fn_struct_items)*
129            }),
130            None,
131        ),
132        MockApi::MockMod(module_ident) => {
133            let path_string = path_to_string(trait_path);
134            let mod_doc_string = format!("Unimock mock API for [{path_string}].");
135            let mod_doc_lit_str = syn::LitStr::new(&mod_doc_string, proc_macro2::Span::call_site());
136
137            let impl_doc_string =
138                format!("Mocked implementation. Mock API is available at [{module_ident}].");
139            let impl_doc_lit_str =
140                syn::LitStr::new(&impl_doc_string, proc_macro2::Span::call_site());
141
142            let vis = &trait_info.input_trait.vis;
143            (
144                Some(quote! {
145                    #[doc = #mod_doc_lit_str]
146                    #[allow(non_snake_case)]
147                    #vis mod #module_ident {
148                        #(#mock_fn_struct_items)*
149                    }
150                }),
151                None,
152                Some(quote! {
153                    #[doc = #impl_doc_lit_str]
154                }),
155            )
156        }
157        MockApi::Flattened(_) => (
158            Some(quote! {
159                #(#mock_fn_struct_items)*
160            }),
161            None,
162            None,
163        ),
164    };
165
166    let default_impl_delegator = if trait_info.has_default_impls {
167        let non_default_methods = trait_info
168            .methods
169            .iter()
170            .enumerate()
171            .filter_map(|(index, opt)| opt.as_ref().map(|method| (index, method)))
172            .filter(|(_, method)| method.method.default.is_none())
173            .map(|(index, method)| {
174                def_method_impl(
175                    index,
176                    Some(method),
177                    &trait_info,
178                    &attr,
179                    MethodImplKind::Delegate0,
180                )
181            });
182
183        Some(quote! {
184            #(#mirrored_impl_attributes)*
185            #impl_allow_lints
186            impl #generic_params #trait_path #generic_args for #prefix::private::DefaultImplDelegator #where_clause {
187                #(#attr_associated_types)*
188                #(#attr_associated_consts)*
189                #(#non_default_methods)*
190            }
191        })
192    } else {
193        None
194    };
195
196    let output_trait = trait_info.output_trait;
197
198    Ok(quote! {
199        #output_trait
200        #opt_mock_interface_public
201
202        // private part:
203        const _: () = {
204            #opt_mock_interface_private
205            #(#mock_fn_impl_details)*
206
207            #impl_doc
208            #(#mirrored_impl_attributes)*
209            #impl_allow_lints
210            impl #generic_params #trait_path #generic_args for #prefix::Unimock #where_clause {
211                #(#attr_associated_types)*
212                #(#attr_associated_consts)*
213                #(#associated_futures)*
214                #(#method_impls)*
215            }
216
217            #default_impl_delegator
218        };
219    })
220}
221
222struct MockFnDef {
223    mock_fn_struct_item: proc_macro2::TokenStream,
224    impl_details: proc_macro2::TokenStream,
225}
226
227fn def_mock_fn(
228    method: Option<&method::MockMethod>,
229    trait_info: &TraitInfo,
230    attr: &Attr,
231) -> Option<MockFnDef> {
232    let method = method?;
233    let prefix = &attr.prefix;
234    let span = method.span();
235    let mirrored_attrs = method.mirrored_attrs();
236    let impl_allow_lints = impl_allow_lints();
237    let mock_fn_ident = &method.mock_fn_ident;
238    let mock_fn_path = method.mock_fn_path(attr);
239    let trait_ident_lit = &trait_info.ident_lit;
240    let method_ident_lit = &method.ident_lit;
241
242    let mock_visibility = match &attr.mock_api {
243        MockApi::MockMod(_) => {
244            syn::Visibility::Public(syn::token::Pub(proc_macro2::Span::call_site()))
245        }
246        _ => trait_info.input_trait.vis.clone(),
247    };
248
249    let input_lifetime = &attr.input_lifetime;
250    let input_types_tuple = InputTypesTuple::new(method, trait_info, attr);
251
252    let generic_params = util::Generics::fn_params(trait_info, Some(method));
253    let generic_args = util::Generics::fn_args(
254        &trait_info.input_trait.generics,
255        Some(method),
256        InferImplTrait(false),
257    );
258    let where_clause = &trait_info.input_trait.generics.where_clause;
259
260    let doc_attrs = if matches!(attr.mock_api, attr::MockApi::Hidden) {
261        vec![]
262    } else {
263        method.mockfn_doc_attrs(&trait_info.trait_path)
264    };
265
266    let output_kind_assoc_type = method
267        .output_structure
268        .output_kind_assoc_type(prefix, trait_info, attr);
269
270    let answer_fn_assoc_type = make_answer_fn(method, trait_info, attr);
271
272    let debug_inputs_fn = method.generate_debug_inputs_fn(attr);
273
274    let gen_mock_fn_struct_item = |non_generic_ident: &syn::Ident| {
275        quote! {
276            #[allow(non_camel_case_types)]
277            #(#doc_attrs)*
278            #mock_visibility struct #non_generic_ident;
279        }
280    };
281
282    let info_set_default_impl = if method.has_default_impl {
283        Some(quote! { .default_impl() })
284    } else {
285        None
286    };
287
288    let impl_block = quote_spanned! { span=>
289        #(#mirrored_attrs)*
290        #impl_allow_lints
291        impl #generic_params #prefix::MockFn for #mock_fn_path #generic_args #where_clause {
292            type Inputs<#input_lifetime> = #input_types_tuple;
293            type OutputKind = #output_kind_assoc_type;
294            type AnswerFn = #answer_fn_assoc_type;
295
296            fn info() -> #prefix::MockFnInfo {
297                #prefix::MockFnInfo::new::<Self>()
298                    .path(&[#trait_ident_lit, #method_ident_lit])
299                    #info_set_default_impl
300            }
301
302            #debug_inputs_fn
303        }
304    };
305
306    let mock_fn_def = if let Some(non_generic_ident) = &method.non_generic_mock_entry_ident {
307        // the trait is generic
308        let phantoms_tuple = util::MockFnPhantomsTuple { trait_info, method };
309        let untyped_phantoms =
310            iter_generic_type_params(trait_info, method).map(util::PhantomDataConstructor);
311        let module_scope = match &attr.mock_api {
312            MockApi::MockMod(ident) => Some(quote_spanned! { span=> #ident:: }),
313            _ => None,
314        };
315        let answer_fn_assoc_type = make_answer_fn(method, trait_info, attr);
316
317        MockFnDef {
318            mock_fn_struct_item: gen_mock_fn_struct_item(non_generic_ident),
319            impl_details: quote! {
320                #impl_allow_lints
321                impl #module_scope #non_generic_ident {
322                    #[doc = "Provide the generic parameters to the mocked method"]
323                    pub fn with_types #generic_params(
324                        self
325                    ) -> impl for<#input_lifetime> #prefix::MockFn<
326                        Inputs<#input_lifetime> = #input_types_tuple,
327                        OutputKind = #output_kind_assoc_type,
328                        AnswerFn = #answer_fn_assoc_type,
329                    >
330                        #where_clause
331                    {
332                        #mock_fn_ident(#(#untyped_phantoms),*)
333                    }
334                }
335
336                #[allow(non_camel_case_types)]
337                struct #mock_fn_ident #generic_args #phantoms_tuple;
338
339                #impl_block
340            },
341        }
342    } else {
343        MockFnDef {
344            mock_fn_struct_item: gen_mock_fn_struct_item(mock_fn_ident),
345            impl_details: impl_block,
346        }
347    };
348
349    Some(mock_fn_def)
350}
351
/// Selects which method body `def_method_impl` generates.
enum MethodImplKind {
    /// Body for the trait impl on `Unimock`: evaluates the call through the
    /// mock machinery.
    Mock,
    /// Body for the trait impl on `DefaultImplDelegator`: forwards the call
    /// to the `Unimock` implementation of the same method.
    Delegate0,
}
356
357fn def_method_impl(
358    index: usize,
359    method: Option<&method::MockMethod>,
360    trait_info: &TraitInfo,
361    attr: &Attr,
362    kind: MethodImplKind,
363) -> proc_macro2::TokenStream {
364    let method = match method {
365        Some(method) => method,
366        None => return quote! {},
367    };
368
369    let span = method.span();
370    let prefix = prefix_with_span(&attr.prefix, span);
371    let method_sig = &method.method.sig;
372    let mirrored_attrs = method.mirrored_attrs();
373    let mock_fn_path = method.mock_fn_path(attr);
374
375    let receiver = method.receiver();
376    let self_ref = SelfReference(&receiver);
377    let self_to_delegator = SelfToDelegator(&receiver);
378    let eval_generic_args = util::Generics::fn_args(
379        &trait_info.input_trait.generics,
380        Some(method),
381        InferImplTrait(true),
382    );
383
384    let must_async_wrap = matches!(
385        method.output_structure.wrapping,
386        output::OutputWrapping::RpitFuture | output::OutputWrapping::AssociatedFuture(_)
387    );
388
389    let trait_path = &trait_info.trait_path;
390    let method_ident = &method_sig.ident;
391    let opt_dot_await = method.opt_dot_await();
392    let track_caller = if method.method.sig.asyncness.is_none() {
393        Some(quote! {
394            #[track_caller]
395        })
396    } else {
397        None
398    };
399
400    let allow_lints: proc_macro2::TokenStream = {
401        let mut lints: Vec<proc_macro2::TokenStream> = vec![quote! { unused }];
402
403        if matches!(
404            method.output_structure.wrapping,
405            output::OutputWrapping::RpitFuture
406        ) {
407            lints.push(quote! { manual_async_fn });
408        }
409
410        quote! { #[allow(#(#lints),*)] }
411    };
412
413    let body = match kind {
414        MethodImplKind::Mock => {
415            let unmock_arm = attr.get_unmock_fn(index).map(
416                |UnmockFn {
417                     path: unmock_path,
418                     params: unmock_params,
419                 }| {
420                    let fn_params =
421                        method.inputs_destructuring(InputsSyntax::FnParams, Tupled(false), attr);
422
423                    let unmock_expr = match unmock_params {
424                        None => quote! {
425                            #unmock_path(self, #fn_params) #opt_dot_await
426                        },
427                        Some(UnmockFnParams { params }) => quote! {
428                            #unmock_path(#params) #opt_dot_await
429                        },
430                    };
431
432                    let eval_pattern = method.inputs_destructuring(
433                        InputsSyntax::EvalPatternMutAsWildcard,
434                        Tupled(true),
435                        attr,
436                    );
437
438                    quote! {
439                        #prefix::private::Eval::Continue(#prefix::private::Continuation::Unmock, #eval_pattern) => #unmock_expr,
440                    }
441                },
442            );
443
444            let inputs_eval_params =
445                method.inputs_destructuring(InputsSyntax::EvalParams, Tupled(true), attr);
446            let fn_params =
447                method.inputs_destructuring(InputsSyntax::FnParams, Tupled(false), attr);
448
449            let default_delegator_call = if method.method.default.is_some() {
450                let delegator_path = quote! {
451                    #prefix::private::DefaultImplDelegator
452                };
453                let delegator_constructor = match method_sig.receiver() {
454                    Some(syn::Receiver {
455                        reference: None,
456                        ty,
457                        ..
458                    }) => {
459                        quote! {
460                            <#ty as #prefix::private::DelegateToDefaultImpl>::to_delegator(#self_to_delegator)
461                        }
462                    }
463                    Some(syn::Receiver {
464                        reference: Some(_),
465                        mutability: None,
466                        ..
467                    }) => quote! {
468                        #prefix::private::as_ref(self)
469                    },
470                    Some(syn::Receiver {
471                        reference: Some(_),
472                        mutability: Some(_),
473                        ..
474                    }) => quote! {
475                        #prefix::private::as_mut(__self)
476                    },
477                    _ => todo!("unhandled DefaultImplDelegator constructor"),
478                };
479
480                let generic_args = util::Generics::trait_args(
481                    &trait_info.input_trait.generics,
482                    None,
483                    InferImplTrait(false),
484                );
485
486                Some(quote! {
487                    <#delegator_path as #trait_path #generic_args>::#method_ident(
488                        #delegator_constructor,
489                        #fn_params
490                    )
491                        #opt_dot_await
492                })
493            } else {
494                None
495            };
496
497            match &receiver {
498                Receiver::MutRef { .. } | Receiver::Pin { .. } => {
499                    let eval_pattern_no_mut = method.inputs_destructuring(
500                        InputsSyntax::EvalPatternMutAsWildcard,
501                        Tupled(true),
502                        attr,
503                    );
504                    let eval_pattern_all = method.inputs_destructuring(
505                        InputsSyntax::EvalPatternAll,
506                        Tupled(true),
507                        attr,
508                    );
509                    let fn_params_tupled =
510                        method.inputs_destructuring(InputsSyntax::FnParams, Tupled(true), attr);
511
512                    let polonius_return_type: syn::Type = match method.method.sig.output.clone() {
513                        syn::ReturnType::Default => syn::parse_quote!(()),
514                        syn::ReturnType::Type(_arrow, ty) => {
515                            util::substitute_lifetimes(*ty, Some(&syn::parse_quote!('polonius)))
516                        }
517                    };
518
519                    let default_impl_input_eval_arm = if default_delegator_call.is_some() {
520                        quote! {
521                            #prefix::private::Continuation::CallDefaultImpl => {
522                                #default_delegator_call
523                            }
524                        }
525                    } else {
526                        quote!()
527                    };
528
529                    quote! {
530                        let (__cont, #eval_pattern_all) = #prefix::polonius::_polonius!(|#self_ref| -> #polonius_return_type {
531                            match #prefix::private::eval::<#mock_fn_path #eval_generic_args>(#self_ref, #inputs_eval_params) {
532                                #prefix::private::Eval::Return(output) => #prefix::polonius::_return!(output),
533                                #prefix::private::Eval::Continue(__cont, #eval_pattern_no_mut) => #prefix::polonius::_exit!((__cont, #fn_params_tupled)),
534                            }
535                        });
536                        match __cont {
537                            #prefix::private::Continuation::Answer(__answer_fn) => {
538                                __answer_fn(__self, #fn_params)
539                            }
540                            #default_impl_input_eval_arm
541                            cont => cont.report(__self)
542                        }
543                    }
544                }
545                _ => {
546                    let eval_pattern_no_mut = method.inputs_destructuring(
547                        InputsSyntax::EvalPatternMutAsWildcard,
548                        Tupled(true),
549                        attr,
550                    );
551
552                    let default_impl_delegate_arm = if method.method.default.is_some() {
553                        Some(quote! {
554                            #prefix::private::Eval::Continue(#prefix::private::Continuation::CallDefaultImpl, #eval_pattern_no_mut) => {
555                                #default_delegator_call
556                            },
557                        })
558                    } else {
559                        None
560                    };
561
562                    quote_spanned! { span=>
563                        match #prefix::private::eval::<#mock_fn_path #eval_generic_args>(#self_ref, #inputs_eval_params) {
564                            #prefix::private::Eval::Return(output) => output,
565                            #prefix::private::Eval::Continue(#prefix::private::Continuation::Answer(__answer_fn), #eval_pattern_no_mut) => {
566                                __answer_fn(self, #fn_params)
567                            }
568                            #unmock_arm
569                            #default_impl_delegate_arm
570                            #prefix::private::Eval::Continue(cont, _) => cont.report(#self_ref),
571                        }
572                    }
573                }
574            }
575        }
576        MethodImplKind::Delegate0 => {
577            let inputs_destructuring =
578                method.inputs_destructuring(InputsSyntax::FnParams, Tupled(false), attr);
579            let unimock_accessor = match method_sig.receiver() {
580                Some(syn::Receiver {
581                    reference: None,
582                    ty,
583                    ..
584                }) => {
585                    let unimock_type = replace_self_ty_with_path(
586                        *ty.clone(),
587                        &parse_quote! {
588                            #prefix::Unimock
589                        },
590                    );
591
592                    quote! {
593                        {
594                            <#unimock_type as #prefix::private::DelegateToDefaultImpl>::from_delegator(self)
595                        }
596                    }
597                }
598                Some(syn::Receiver {
599                    reference: Some(_),
600                    mutability: None,
601                    ..
602                }) => {
603                    quote! { #prefix::private::as_ref(self) }
604                }
605                Some(syn::Receiver {
606                    reference: Some(_),
607                    mutability: Some(_),
608                    ..
609                }) => {
610                    quote! { #prefix::private::as_mut(self) }
611                }
612                _ => panic!("BUG: Incompatible receiver for default delegator"),
613            };
614            let generic_args = util::Generics::trait_args(
615                &trait_info.input_trait.generics,
616                None,
617                InferImplTrait(false),
618            );
619            quote! {
620                <#prefix::Unimock as #trait_path #generic_args>::#method_ident(
621                    #unimock_accessor,
622                    #inputs_destructuring
623                )
624                    #opt_dot_await
625            }
626        }
627    };
628
629    let body = match (kind, &receiver) {
630        (MethodImplKind::Mock, Receiver::MutRef { surrogate_self }) => {
631            quote! {
632                let mut #surrogate_self = self;
633                #body
634            }
635        }
636        (MethodImplKind::Mock, Receiver::Pin { surrogate_self }) => {
637            quote! {
638                let mut #surrogate_self = ::core::pin::Pin::into_inner(self);
639                #body
640            }
641        }
642        _ => body,
643    };
644
645    let body = if must_async_wrap {
646        quote_spanned! { span=>
647            async move { #body }
648        }
649    } else {
650        body
651    };
652
653    quote_spanned! { span=>
654        #(#mirrored_attrs)*
655        #track_caller
656        #allow_lints
657        #method_sig {
658            #body
659        }
660    }
661}
662
663fn prefix_with_span(prefix: &syn::Path, span: proc_macro2::Span) -> syn::Path {
664    let mut prefix = prefix.clone();
665    for segment in &mut prefix.segments {
666        segment.ident.set_span(span);
667    }
668
669    prefix
670}
671
672struct InputTypesTuple(Vec<syn::Type>);
673
674impl InputTypesTuple {
675    fn new(mock_method: &MockMethod, trait_info: &TraitInfo, attr: &Attr) -> Self {
676        let prefix = &attr.prefix;
677        let input_lifetime = &attr.input_lifetime;
678        Self(
679            mock_method
680                .adapted_sig
681                .inputs
682                .iter()
683                .enumerate()
684                .filter_map(
685                    |(index, input)| match mock_method.classify_arg(input, index) {
686                        ArgClass::Receiver => None,
687                        ArgClass::MutImpossible(..) => Some(syn::parse_quote!(
688                            #prefix::Impossible
689                        )),
690                        ArgClass::Other(_, ty) => Some(ty.clone()),
691                        ArgClass::Unprocessable(_) => None,
692                    },
693                )
694                .map(|mut ty| {
695                    ty = util::substitute_lifetimes(ty, Some(input_lifetime));
696                    ty = util::self_type_to_unimock(ty, trait_info, attr);
697                    ty
698                })
699                .collect::<Vec<_>>(),
700        )
701    }
702}
703
704impl ToTokens for InputTypesTuple {
705    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
706        if self.0.len() == 1 {
707            tokens.extend(self.0.first().to_token_stream());
708        } else {
709            let types = &self.0;
710            tokens.extend(quote! {
711                (#(#types),*)
712            });
713        }
714    }
715}
716
717fn path_to_string(path: &syn::Path) -> String {
718    let mut out = String::new();
719    for pair in path.segments.pairs() {
720        out.push_str(&pair.value().ident.to_string());
721        if let Some(sep) = pair.punct() {
722            out.push_str(&sep.doc_string());
723        }
724    }
725    out
726}
727
728fn impl_allow_lints() -> proc_macro2::TokenStream {
729    quote! {
730        #[allow(clippy::multiple_bound_locations)]
731    }
732}