1use quote::{quote, quote_spanned, ToTokens};
2use syn::parse_quote;
3
4mod answer_fn;
5mod associated_future;
6mod attr;
7mod method;
8mod output;
9mod trait_info;
10mod util;
11
12use crate::doc::SynDoc;
13use crate::unimock::method::{InputsSyntax, Receiver, SelfReference, SelfToDelegator, Tupled};
14use crate::unimock::util::replace_self_ty_with_path;
15pub use attr::{Attr, MockApi};
16use trait_info::TraitInfo;
17
18use attr::{UnmockFn, UnmockFnParams};
19
20use self::answer_fn::make_answer_fn;
21use self::method::{ArgClass, MockMethod};
22use self::util::{iter_generic_type_params, InferImplTrait};
23
/// Expands the `unimock` attribute for one trait.
///
/// Analyzes `item_trait` under `attr`, validates `attr` against that analysis,
/// and returns tokens consisting of:
/// * `trait_info.output_trait` (the trait output produced by the analysis);
/// * the mock API items, placed according to `attr.mock_api`:
///   - `MockMod`: inside a `#vis mod` named by the attribute;
///   - `Flattened`: directly next to the trait;
///   - `Hidden`: only inside the private `const` block below;
/// * inside `const _: () = { ... };`:
///   - each method's `MockFn` implementation details;
///   - `impl Trait for Unimock`;
///   - when `trait_info.has_default_impls` is set, an
///     `impl Trait for DefaultImplDelegator` for the methods without a
///     default body.
///
/// # Errors
/// Propagates any `syn::Error` returned by `TraitInfo::analyze` or `Attr::validate`.
pub fn generate(attr: Attr, item_trait: syn::ItemTrait) -> syn::Result<proc_macro2::TokenStream> {
    let trait_info = trait_info::TraitInfo::analyze(&item_trait, &attr)?;
    attr.validate(&trait_info)?;

    let prefix = &attr.prefix;
    let trait_path = &trait_info.trait_path;
    // Outer attributes on the trait whose last path segment is `async_trait`
    // are repeated on every generated impl of the trait.
    let mirrored_impl_attributes = trait_info
        .input_trait
        .attrs
        .iter()
        .filter(|attribute| match attribute.style {
            syn::AttrStyle::Outer => {
                if let Some(last_segment) = attribute.path().segments.last() {
                    last_segment.ident == "async_trait"
                } else {
                    false
                }
            }
            syn::AttrStyle::Inner(_) => false,
        })
        .collect::<Vec<_>>();
    let impl_allow_lints = impl_allow_lints();

    // One entry per trait method; `def_mock_fn` yields `None` for `None` methods.
    let mock_fn_defs: Vec<Option<MockFnDef>> = trait_info
        .methods
        .iter()
        .map(|method| def_mock_fn(method.as_ref(), &trait_info, &attr))
        .collect();
    let associated_futures = trait_info
        .methods
        .iter()
        .filter_map(|method| associated_future::def_associated_future(method.as_ref()));
    // Method bodies for `impl Trait for Unimock`.
    let method_impls = trait_info
        .methods
        .iter()
        .enumerate()
        .map(|(index, method)| {
            def_method_impl(
                index,
                method.as_ref(),
                &trait_info,
                &attr,
                MethodImplKind::Mock,
            )
        });

    let where_clause = &trait_info.input_trait.generics.where_clause;
    let mock_fn_struct_items = mock_fn_defs
        .iter()
        .filter_map(Option::as_ref)
        .map(|def| &def.mock_fn_struct_item);
    let mock_fn_impl_details = mock_fn_defs
        .iter()
        .filter_map(Option::as_ref)
        .map(|def| &def.impl_details);
    let generic_params = util::Generics::trait_params(&trait_info, None);
    let generic_args = util::Generics::trait_args(
        &trait_info.input_trait.generics,
        None,
        InferImplTrait(false),
    );

    // For every associated type declared in the trait, emit the definition the
    // user supplied in the attribute (looked up by the type's name). Associated
    // types with no supplied definition produce nothing.
    let attr_associated_types = trait_info
        .input_trait
        .items
        .iter()
        .filter_map(|item| match item {
            syn::TraitItem::Type(trait_item_type) => {
                let ident = &trait_item_type.ident;
                let ident_string = ident.to_string();
                attr.associated_types
                    .get(&ident_string)
                    .map(|trait_item_type| {
                        quote! {
                            #trait_item_type
                        }
                    })
            }
            _ => None,
        })
        .collect::<Vec<_>>();

    // Same as above, for associated constants (`attr.associated_consts`).
    let attr_associated_consts = trait_info
        .input_trait
        .items
        .iter()
        .filter_map(|item| match item {
            syn::TraitItem::Const(trait_item_const) => {
                let ident = &trait_item_const.ident;
                let ident_string = ident.to_string();
                attr.associated_consts
                    .get(&ident_string)
                    .map(|trait_item_const| {
                        quote! { #trait_item_const }
                    })
            }
            _ => None,
        })
        .collect::<Vec<_>>();

    // Decide where the mock fn structs go:
    // * `opt_mock_interface_public` is emitted next to the trait;
    // * `opt_mock_interface_private` is emitted inside the `const _` block;
    // * `impl_doc` documents the `Unimock` impl (only for `MockMod`).
    let (opt_mock_interface_public, opt_mock_interface_private, impl_doc) = match &attr.mock_api {
        MockApi::Hidden => (
            None,
            Some(quote! {
                #(#mock_fn_struct_items)*
            }),
            None,
        ),
        MockApi::MockMod(module_ident) => {
            let path_string = path_to_string(trait_path);
            let mod_doc_string = format!("Unimock mock API for [{path_string}].");
            let mod_doc_lit_str = syn::LitStr::new(&mod_doc_string, proc_macro2::Span::call_site());

            let impl_doc_string =
                format!("Mocked implementation. Mock API is available at [{module_ident}].");
            let impl_doc_lit_str =
                syn::LitStr::new(&impl_doc_string, proc_macro2::Span::call_site());

            // The generated module takes the trait's visibility.
            let vis = &trait_info.input_trait.vis;
            (
                Some(quote! {
                    #[doc = #mod_doc_lit_str]
                    #[allow(non_snake_case)]
                    #vis mod #module_ident {
                        #(#mock_fn_struct_items)*
                    }
                }),
                None,
                Some(quote! {
                    #[doc = #impl_doc_lit_str]
                }),
            )
        }
        MockApi::Flattened(_) => (
            Some(quote! {
                #(#mock_fn_struct_items)*
            }),
            None,
            None,
        ),
    };

    // When the trait has default method bodies, also implement it for
    // `DefaultImplDelegator`. Only the methods WITHOUT a default body are
    // written out (as `MethodImplKind::Delegate0`); the others keep the
    // trait's default body.
    let default_impl_delegator = if trait_info.has_default_impls {
        let non_default_methods = trait_info
            .methods
            .iter()
            .enumerate()
            .filter_map(|(index, opt)| opt.as_ref().map(|method| (index, method)))
            .filter(|(_, method)| method.method.default.is_none())
            .map(|(index, method)| {
                def_method_impl(
                    index,
                    Some(method),
                    &trait_info,
                    &attr,
                    MethodImplKind::Delegate0,
                )
            });

        Some(quote! {
            #(#mirrored_impl_attributes)*
            #impl_allow_lints
            impl #generic_params #trait_path #generic_args for #prefix::private::DefaultImplDelegator #where_clause {
                #(#attr_associated_types)*
                #(#attr_associated_consts)*
                #(#non_default_methods)*
            }
        })
    } else {
        None
    };

    let output_trait = trait_info.output_trait;

    // Everything except the trait itself and the public mock API is placed in
    // an anonymous `const` item, so none of those items can be named from outside.
    Ok(quote! {
        #output_trait
        #opt_mock_interface_public

        const _: () = {
            #opt_mock_interface_private
            #(#mock_fn_impl_details)*

            #impl_doc
            #(#mirrored_impl_attributes)*
            #impl_allow_lints
            impl #generic_params #trait_path #generic_args for #prefix::Unimock #where_clause {
                #(#attr_associated_types)*
                #(#attr_associated_consts)*
                #(#associated_futures)*
                #(#method_impls)*
            }

            #default_impl_delegator
        };
    })
}
221
/// Tokens generated for one mocked method by `def_mock_fn`.
struct MockFnDef {
    /// The mock fn struct declaration. `generate` places it according to the
    /// configured `MockApi` (mock module, flattened, or hidden).
    mock_fn_struct_item: proc_macro2::TokenStream,
    /// The `MockFn` impl and any helper items. `generate` emits these inside
    /// its `const _: () = { ... };` block.
    impl_details: proc_macro2::TokenStream,
}
226
227fn def_mock_fn(
228 method: Option<&method::MockMethod>,
229 trait_info: &TraitInfo,
230 attr: &Attr,
231) -> Option<MockFnDef> {
232 let method = method?;
233 let prefix = &attr.prefix;
234 let span = method.span();
235 let mirrored_attrs = method.mirrored_attrs();
236 let impl_allow_lints = impl_allow_lints();
237 let mock_fn_ident = &method.mock_fn_ident;
238 let mock_fn_path = method.mock_fn_path(attr);
239 let trait_ident_lit = &trait_info.ident_lit;
240 let method_ident_lit = &method.ident_lit;
241
242 let mock_visibility = match &attr.mock_api {
243 MockApi::MockMod(_) => {
244 syn::Visibility::Public(syn::token::Pub(proc_macro2::Span::call_site()))
245 }
246 _ => trait_info.input_trait.vis.clone(),
247 };
248
249 let input_lifetime = &attr.input_lifetime;
250 let input_types_tuple = InputTypesTuple::new(method, trait_info, attr);
251
252 let generic_params = util::Generics::fn_params(trait_info, Some(method));
253 let generic_args = util::Generics::fn_args(
254 &trait_info.input_trait.generics,
255 Some(method),
256 InferImplTrait(false),
257 );
258 let where_clause = &trait_info.input_trait.generics.where_clause;
259
260 let doc_attrs = if matches!(attr.mock_api, attr::MockApi::Hidden) {
261 vec![]
262 } else {
263 method.mockfn_doc_attrs(&trait_info.trait_path)
264 };
265
266 let output_kind_assoc_type = method
267 .output_structure
268 .output_kind_assoc_type(prefix, trait_info, attr);
269
270 let answer_fn_assoc_type = make_answer_fn(method, trait_info, attr);
271
272 let debug_inputs_fn = method.generate_debug_inputs_fn(attr);
273
274 let gen_mock_fn_struct_item = |non_generic_ident: &syn::Ident| {
275 quote! {
276 #[allow(non_camel_case_types)]
277 #(#doc_attrs)*
278 #mock_visibility struct #non_generic_ident;
279 }
280 };
281
282 let info_set_default_impl = if method.has_default_impl {
283 Some(quote! { .default_impl() })
284 } else {
285 None
286 };
287
288 let impl_block = quote_spanned! { span=>
289 #(#mirrored_attrs)*
290 #impl_allow_lints
291 impl #generic_params #prefix::MockFn for #mock_fn_path #generic_args #where_clause {
292 type Inputs<#input_lifetime> = #input_types_tuple;
293 type OutputKind = #output_kind_assoc_type;
294 type AnswerFn = #answer_fn_assoc_type;
295
296 fn info() -> #prefix::MockFnInfo {
297 #prefix::MockFnInfo::new::<Self>()
298 .path(&[#trait_ident_lit, #method_ident_lit])
299 #info_set_default_impl
300 }
301
302 #debug_inputs_fn
303 }
304 };
305
306 let mock_fn_def = if let Some(non_generic_ident) = &method.non_generic_mock_entry_ident {
307 let phantoms_tuple = util::MockFnPhantomsTuple { trait_info, method };
309 let untyped_phantoms =
310 iter_generic_type_params(trait_info, method).map(util::PhantomDataConstructor);
311 let module_scope = match &attr.mock_api {
312 MockApi::MockMod(ident) => Some(quote_spanned! { span=> #ident:: }),
313 _ => None,
314 };
315 let answer_fn_assoc_type = make_answer_fn(method, trait_info, attr);
316
317 MockFnDef {
318 mock_fn_struct_item: gen_mock_fn_struct_item(non_generic_ident),
319 impl_details: quote! {
320 #impl_allow_lints
321 impl #module_scope #non_generic_ident {
322 #[doc = "Provide the generic parameters to the mocked method"]
323 pub fn with_types #generic_params(
324 self
325 ) -> impl for<#input_lifetime> #prefix::MockFn<
326 Inputs<#input_lifetime> = #input_types_tuple,
327 OutputKind = #output_kind_assoc_type,
328 AnswerFn = #answer_fn_assoc_type,
329 >
330 #where_clause
331 {
332 #mock_fn_ident(#(#untyped_phantoms),*)
333 }
334 }
335
336 #[allow(non_camel_case_types)]
337 struct #mock_fn_ident #generic_args #phantoms_tuple;
338
339 #impl_block
340 },
341 }
342 } else {
343 MockFnDef {
344 mock_fn_struct_item: gen_mock_fn_struct_item(mock_fn_ident),
345 impl_details: impl_block,
346 }
347 };
348
349 Some(mock_fn_def)
350}
351
/// Selects which generated trait impl `def_method_impl` builds a method body for.
enum MethodImplKind {
    /// Body for `impl Trait for Unimock`: evaluates the mock via `private::eval`.
    Mock,
    /// Body for `impl Trait for DefaultImplDelegator`: calls the same method on
    /// the `Unimock` impl.
    Delegate0,
}
356
/// Generates one method (signature plus body) for a generated trait impl.
///
/// * `index`: the method's position in `trait_info.methods`. It is used to look
///   up an unmock function through `attr.get_unmock_fn(index)`.
/// * `method`: when `None`, an empty token stream is returned.
/// * `kind`: selects the impl the body is for:
///   - `MethodImplKind::Mock`: the body for `impl Trait for Unimock`. It calls
///     `private::eval` and dispatches on the result.
///   - `MethodImplKind::Delegate0`: the body for
///     `impl Trait for DefaultImplDelegator`. It calls the `Unimock` impl of
///     the same method.
///
/// The returned tokens are the method signature and body, preceded by:
/// * the method's mirrored attributes;
/// * `#[track_caller]`, for non-`async` methods only;
/// * an `#[allow(...)]` attribute.
fn def_method_impl(
    index: usize,
    method: Option<&method::MockMethod>,
    trait_info: &TraitInfo,
    attr: &Attr,
    kind: MethodImplKind,
) -> proc_macro2::TokenStream {
    let method = match method {
        Some(method) => method,
        None => return quote! {},
    };

    let span = method.span();
    // The crate prefix path, with every segment ident re-spanned to this method.
    let prefix = prefix_with_span(&attr.prefix, span);
    let method_sig = &method.method.sig;
    let mirrored_attrs = method.mirrored_attrs();
    let mock_fn_path = method.mock_fn_path(attr);

    let receiver = method.receiver();
    let self_ref = SelfReference(&receiver);
    let self_to_delegator = SelfToDelegator(&receiver);
    let eval_generic_args = util::Generics::fn_args(
        &trait_info.input_trait.generics,
        Some(method),
        InferImplTrait(true),
    );

    // For these output wrappings the final body is wrapped in `async move { .. }`.
    let must_async_wrap = matches!(
        method.output_structure.wrapping,
        output::OutputWrapping::RpitFuture | output::OutputWrapping::AssociatedFuture(_)
    );

    let trait_path = &trait_info.trait_path;
    let method_ident = &method_sig.ident;
    let opt_dot_await = method.opt_dot_await();
    // `#[track_caller]` is only added to methods that are not `async fn`.
    let track_caller = if method.method.sig.asyncness.is_none() {
        Some(quote! {
            #[track_caller]
        })
    } else {
        None
    };

    // `unused` is always allowed. `manual_async_fn` is also allowed for
    // `RpitFuture` outputs, whose body becomes an `async move` block.
    // NOTE(review): `manual_async_fn` is emitted without a `clippy::` tool
    // prefix. Confirm that this allow actually takes effect.
    let allow_lints: proc_macro2::TokenStream = {
        let mut lints: Vec<proc_macro2::TokenStream> = vec![quote! { unused }];

        if matches!(
            method.output_structure.wrapping,
            output::OutputWrapping::RpitFuture
        ) {
            lints.push(quote! { manual_async_fn });
        }

        quote! { #[allow(#(#lints),*)] }
    };

    let body = match kind {
        MethodImplKind::Mock => {
            // Optional match arm that sends `Continuation::Unmock` to the
            // configured unmock function. The function is called as
            // `path(self, <fn params>)` unless explicit params were configured,
            // in which case those params are passed verbatim.
            let unmock_arm = attr.get_unmock_fn(index).map(
                |UnmockFn {
                     path: unmock_path,
                     params: unmock_params,
                 }| {
                    let fn_params =
                        method.inputs_destructuring(InputsSyntax::FnParams, Tupled(false), attr);

                    let unmock_expr = match unmock_params {
                        None => quote! {
                            #unmock_path(self, #fn_params) #opt_dot_await
                        },
                        Some(UnmockFnParams { params }) => quote! {
                            #unmock_path(#params) #opt_dot_await
                        },
                    };

                    let eval_pattern = method.inputs_destructuring(
                        InputsSyntax::EvalPatternMutAsWildcard,
                        Tupled(true),
                        attr,
                    );

                    quote! {
                        #prefix::private::Eval::Continue(#prefix::private::Continuation::Unmock, #eval_pattern) => #unmock_expr,
                    }
                },
            );

            let inputs_eval_params =
                method.inputs_destructuring(InputsSyntax::EvalParams, Tupled(true), attr);
            let fn_params =
                method.inputs_destructuring(InputsSyntax::FnParams, Tupled(false), attr);

            // Only for methods with a default body: a call of that default body
            // through `DefaultImplDelegator`. The receiver is converted to the
            // delegator according to its form.
            let default_delegator_call = if method.method.default.is_some() {
                let delegator_path = quote! {
                    #prefix::private::DefaultImplDelegator
                };
                let delegator_constructor = match method_sig.receiver() {
                    // Receiver without `&`/`&mut` shorthand: convert via
                    // `DelegateToDefaultImpl::to_delegator` on its type.
                    Some(syn::Receiver {
                        reference: None,
                        ty,
                        ..
                    }) => {
                        quote! {
                            <#ty as #prefix::private::DelegateToDefaultImpl>::to_delegator(#self_to_delegator)
                        }
                    }
                    Some(syn::Receiver {
                        reference: Some(_),
                        mutability: None,
                        ..
                    }) => quote! {
                        #prefix::private::as_ref(self)
                    },
                    // `&mut self`: the body has already moved `self` into the
                    // surrogate binding, referenced here as `__self`.
                    Some(syn::Receiver {
                        reference: Some(_),
                        mutability: Some(_),
                        ..
                    }) => quote! {
                        #prefix::private::as_mut(__self)
                    },
                    // No receiver. NOTE(review): this panics during macro
                    // expansion. Confirm that such methods are rejected
                    // earlier (e.g. during trait analysis).
                    _ => todo!("unhandled DefaultImplDelegator constructor"),
                };

                let generic_args = util::Generics::trait_args(
                    &trait_info.input_trait.generics,
                    None,
                    InferImplTrait(false),
                );

                Some(quote! {
                    <#delegator_path as #trait_path #generic_args>::#method_ident(
                        #delegator_constructor,
                        #fn_params
                    )
                    #opt_dot_await
                })
            } else {
                None
            };

            match &receiver {
                // `&mut self` / `Pin` receivers: evaluation runs inside
                // `polonius::_polonius!`.
                // * `Eval::Return` returns from the method via `_return!`.
                // * Otherwise `_exit!` yields `(continuation, inputs)` to the
                //   code after the macro, which then dispatches on the
                //   continuation using the surrogate self `__self`.
                // Presumably this works around borrow-checker limits on
                // conditionally returned `&mut` borrows (TODO confirm).
                // NOTE(review): `unmock_arm` is not used in this branch, so an
                // Unmock continuation reaches `cont.report(__self)`.
                Receiver::MutRef { .. } | Receiver::Pin { .. } => {
                    let eval_pattern_no_mut = method.inputs_destructuring(
                        InputsSyntax::EvalPatternMutAsWildcard,
                        Tupled(true),
                        attr,
                    );
                    let eval_pattern_all = method.inputs_destructuring(
                        InputsSyntax::EvalPatternAll,
                        Tupled(true),
                        attr,
                    );
                    let fn_params_tupled =
                        method.inputs_destructuring(InputsSyntax::FnParams, Tupled(true), attr);

                    // The method's return type with its lifetimes replaced by
                    // `'polonius`; `()` when the method declares no return type.
                    let polonius_return_type: syn::Type = match method.method.sig.output.clone() {
                        syn::ReturnType::Default => syn::parse_quote!(()),
                        syn::ReturnType::Type(_arrow, ty) => {
                            util::substitute_lifetimes(*ty, Some(&syn::parse_quote!('polonius)))
                        }
                    };

                    let default_impl_input_eval_arm = if default_delegator_call.is_some() {
                        quote! {
                            #prefix::private::Continuation::CallDefaultImpl => {
                                #default_delegator_call
                            }
                        }
                    } else {
                        quote!()
                    };

                    quote! {
                        let (__cont, #eval_pattern_all) = #prefix::polonius::_polonius!(|#self_ref| -> #polonius_return_type {
                            match #prefix::private::eval::<#mock_fn_path #eval_generic_args>(#self_ref, #inputs_eval_params) {
                                #prefix::private::Eval::Return(output) => #prefix::polonius::_return!(output),
                                #prefix::private::Eval::Continue(__cont, #eval_pattern_no_mut) => #prefix::polonius::_exit!((__cont, #fn_params_tupled)),
                            }
                        });
                        match __cont {
                            #prefix::private::Continuation::Answer(__answer_fn) => {
                                __answer_fn(__self, #fn_params)
                            }
                            #default_impl_input_eval_arm
                            cont => cont.report(__self)
                        }
                    }
                }
                // All other receivers: one `match` over the `eval` result. It
                // handles Return, Answer, optional Unmock and CallDefaultImpl
                // arms, and reports any other continuation.
                _ => {
                    let eval_pattern_no_mut = method.inputs_destructuring(
                        InputsSyntax::EvalPatternMutAsWildcard,
                        Tupled(true),
                        attr,
                    );

                    let default_impl_delegate_arm = if method.method.default.is_some() {
                        Some(quote! {
                            #prefix::private::Eval::Continue(#prefix::private::Continuation::CallDefaultImpl, #eval_pattern_no_mut) => {
                                #default_delegator_call
                            },
                        })
                    } else {
                        None
                    };

                    quote_spanned! { span=>
                        match #prefix::private::eval::<#mock_fn_path #eval_generic_args>(#self_ref, #inputs_eval_params) {
                            #prefix::private::Eval::Return(output) => output,
                            #prefix::private::Eval::Continue(#prefix::private::Continuation::Answer(__answer_fn), #eval_pattern_no_mut) => {
                                __answer_fn(self, #fn_params)
                            }
                            #unmock_arm
                            #default_impl_delegate_arm
                            #prefix::private::Eval::Continue(cont, _) => cont.report(#self_ref),
                        }
                    }
                }
            }
        }
        MethodImplKind::Delegate0 => {
            let inputs_destructuring =
                method.inputs_destructuring(InputsSyntax::FnParams, Tupled(false), attr);
            // Convert the delegator receiver back into the matching `Unimock`
            // receiver before calling the `Unimock` impl.
            let unimock_accessor = match method_sig.receiver() {
                Some(syn::Receiver {
                    reference: None,
                    ty,
                    ..
                }) => {
                    // The receiver type with `Self` replaced by `Unimock`.
                    let unimock_type = replace_self_ty_with_path(
                        *ty.clone(),
                        &parse_quote! {
                            #prefix::Unimock
                        },
                    );

                    quote! {
                        {
                            <#unimock_type as #prefix::private::DelegateToDefaultImpl>::from_delegator(self)
                        }
                    }
                }
                Some(syn::Receiver {
                    reference: Some(_),
                    mutability: None,
                    ..
                }) => {
                    quote! { #prefix::private::as_ref(self) }
                }
                Some(syn::Receiver {
                    reference: Some(_),
                    mutability: Some(_),
                    ..
                }) => {
                    quote! { #prefix::private::as_mut(self) }
                }
                _ => panic!("BUG: Incompatible receiver for default delegator"),
            };
            let generic_args = util::Generics::trait_args(
                &trait_info.input_trait.generics,
                None,
                InferImplTrait(false),
            );
            quote! {
                <#prefix::Unimock as #trait_path #generic_args>::#method_ident(
                    #unimock_accessor,
                    #inputs_destructuring
                )
                #opt_dot_await
            }
        }
    };

    // In the mock impl, `&mut self` and `Pin` receivers are first rebound to
    // the surrogate-self binding used by the body generated above.
    let body = match (kind, &receiver) {
        (MethodImplKind::Mock, Receiver::MutRef { surrogate_self }) => {
            quote! {
                let mut #surrogate_self = self;
                #body
            }
        }
        (MethodImplKind::Mock, Receiver::Pin { surrogate_self }) => {
            quote! {
                let mut #surrogate_self = ::core::pin::Pin::into_inner(self);
                #body
            }
        }
        _ => body,
    };

    let body = if must_async_wrap {
        quote_spanned! { span=>
            async move { #body }
        }
    } else {
        body
    };

    quote_spanned! { span=>
        #(#mirrored_attrs)*
        #track_caller
        #allow_lints
        #method_sig {
            #body
        }
    }
}
662
663fn prefix_with_span(prefix: &syn::Path, span: proc_macro2::Span) -> syn::Path {
664 let mut prefix = prefix.clone();
665 for segment in &mut prefix.segments {
666 segment.ident.set_span(span);
667 }
668
669 prefix
670}
671
672struct InputTypesTuple(Vec<syn::Type>);
673
674impl InputTypesTuple {
675 fn new(mock_method: &MockMethod, trait_info: &TraitInfo, attr: &Attr) -> Self {
676 let prefix = &attr.prefix;
677 let input_lifetime = &attr.input_lifetime;
678 Self(
679 mock_method
680 .adapted_sig
681 .inputs
682 .iter()
683 .enumerate()
684 .filter_map(
685 |(index, input)| match mock_method.classify_arg(input, index) {
686 ArgClass::Receiver => None,
687 ArgClass::MutImpossible(..) => Some(syn::parse_quote!(
688 #prefix::Impossible
689 )),
690 ArgClass::Other(_, ty) => Some(ty.clone()),
691 ArgClass::Unprocessable(_) => None,
692 },
693 )
694 .map(|mut ty| {
695 ty = util::substitute_lifetimes(ty, Some(input_lifetime));
696 ty = util::self_type_to_unimock(ty, trait_info, attr);
697 ty
698 })
699 .collect::<Vec<_>>(),
700 )
701 }
702}
703
704impl ToTokens for InputTypesTuple {
705 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
706 if self.0.len() == 1 {
707 tokens.extend(self.0.first().to_token_stream());
708 } else {
709 let types = &self.0;
710 tokens.extend(quote! {
711 (#(#types),*)
712 });
713 }
714 }
715}
716
717fn path_to_string(path: &syn::Path) -> String {
718 let mut out = String::new();
719 for pair in path.segments.pairs() {
720 out.push_str(&pair.value().ident.to_string());
721 if let Some(sep) = pair.punct() {
722 out.push_str(&sep.doc_string());
723 }
724 }
725 out
726}
727
728fn impl_allow_lints() -> proc_macro2::TokenStream {
729 quote! {
730 #[allow(clippy::multiple_bound_locations)]
731 }
732}