sylvia_derive/contract/communication/
struct_msg.rs1use crate::parser::attributes::MsgAttrForwarding;
2use crate::parser::variant_descs::AsVariantDescs;
3use crate::parser::{ContractErrorAttr, Custom, MsgType, ParsedSylviaAttributes};
4use crate::types::msg_field::MsgField;
5use crate::types::msg_variant::MsgVariants;
6use crate::utils::{as_where_clause, emit_bracketed_generics, filter_wheres};
7use proc_macro2::TokenStream;
8use proc_macro_error::emit_error;
9use quote::quote;
10use syn::spanned::Spanned;
11use syn::{GenericParam, ItemImpl, Type};
12
13pub struct StructMessage<'a> {
15 source: &'a ItemImpl,
16 contract_type: &'a Type,
17 variants: MsgVariants<'a, GenericParam>,
18 generics: &'a [&'a GenericParam],
19 error: &'a ContractErrorAttr,
20 custom: &'a Custom,
21 msg_attrs_to_forward: Vec<MsgAttrForwarding>,
22}
23
24impl<'a> StructMessage<'a> {
25 pub fn new(
26 source: &'a ItemImpl,
27 msg_ty: MsgType,
28 generics: &'a [&'a GenericParam],
29 error: &'a ContractErrorAttr,
30 custom: &'a Custom,
31 ) -> Option<StructMessage<'a>> {
32 let contract_type = &source.self_ty;
33
34 let variants = MsgVariants::new(
35 source.as_variants(),
36 msg_ty,
37 generics,
38 &source.generics.where_clause,
39 );
40
41 if variants.variants().count() == 0 && variants.msg_ty() == MsgType::Instantiate {
42 emit_error!(
43 source.span(), "Missing instantiation message.";
44 note = source.span() => "`sylvia::contract` requires exactly one method marked with `#[sv::msg(instantiation)]` attribute."
45 );
46 return None;
47 } else if variants.variants().count() > 1 {
48 let mut variants = variants.variants();
49 let first_method = variants.next().map(|v| v.function_name());
50 let obsolete = variants.next().map(|v| v.function_name());
51 emit_error!(
52 first_method.span(), "More than one instantiation or migration message";
53 note = obsolete.span() => "Instantiation/Migration message previously defined here"
54 );
55 return None;
56 }
57
58 let msg_attrs_to_forward = ParsedSylviaAttributes::new(source.attrs.iter())
59 .msg_attrs_forward
60 .into_iter()
61 .filter(|attr| attr.msg_type == msg_ty)
62 .collect();
63
64 Some(Self {
65 source,
66 contract_type,
67 variants,
68 generics,
69 error,
70 custom,
71 msg_attrs_to_forward,
72 })
73 }
74
75 pub fn emit(&self) -> TokenStream {
76 let Self {
77 source,
78 contract_type,
79 variants,
80 generics,
81 error,
82 custom,
83 msg_attrs_to_forward,
84 } = self;
85
86 let Some(variant) = variants.get_only_variant() else {
87 return quote! {};
88 };
89
90 let used_generics = variants.used_generics();
91 let unused_generics = variants.unused_generics();
92 let full_where = &source.generics.where_clause;
93 let wheres = filter_wheres(full_where, generics, used_generics);
94 let where_clause = as_where_clause(&wheres);
95 let bracketed_used_generics = emit_bracketed_generics(used_generics);
96 let bracketed_unused_generics = emit_bracketed_generics(unused_generics);
97
98 let ret_type = variant
99 .msg_attr()
100 .msg_type()
101 .emit_result_type(&custom.msg_or_default(), &error.error);
102 let name = variant.msg_attr().msg_type().emit_msg_name();
103 let function_name = variant.function_name();
104 let mut msg_name = variant.msg_attr().msg_type().emit_msg_name();
105 msg_name.set_span(function_name.span());
106
107 let ctx_type = variant
108 .msg_attr()
109 .msg_type()
110 .emit_ctx_type(&custom.query_or_default());
111 let fields_names: Vec<_> = variant.fields().iter().map(MsgField::name).collect();
112 let parameters = variant.fields().iter().map(MsgField::emit_method_field);
113 let fields = variant.fields().iter().map(MsgField::emit_pub);
114
115 let msg_attrs_to_forward = msg_attrs_to_forward.iter().map(|attr| &attr.attrs);
116 let derive_call = variant.msg_attr().msg_type().emit_derive_call();
117
118 quote! {
119 #[allow(clippy::derive_partial_eq_without_eq)]
120 #derive_call
121 #( #[ #msg_attrs_to_forward ] )*
122 #[serde(rename_all="snake_case")]
123 pub struct #name #bracketed_used_generics {
124 #(#fields,)*
125 }
126
127 impl #bracketed_used_generics #name #bracketed_used_generics #where_clause {
128 pub fn new(#(#parameters,)*) -> Self {
129 Self { #(#fields_names,)* }
130 }
131
132 pub fn dispatch #bracketed_unused_generics (self, contract: &#contract_type, ctx: #ctx_type) -> #ret_type #full_where
133 {
134 let Self { #(#fields_names,)* } = self;
135 contract.#function_name(Into::into(ctx), #(#fields_names,)*).map_err(Into::into)
136 }
137 }
138 }
139 }
140}