sylvia_derive/contract/communication/
struct_msg.rs

1use crate::parser::attributes::MsgAttrForwarding;
2use crate::parser::variant_descs::AsVariantDescs;
3use crate::parser::{ContractErrorAttr, Custom, MsgType, ParsedSylviaAttributes};
4use crate::types::msg_field::MsgField;
5use crate::types::msg_variant::MsgVariants;
6use crate::utils::{as_where_clause, emit_bracketed_generics, filter_wheres};
7use proc_macro2::TokenStream;
8use proc_macro_error::emit_error;
9use quote::quote;
10use syn::spanned::Spanned;
11use syn::{GenericParam, ItemImpl, Type};
12
13/// Representation of single struct message
14pub struct StructMessage<'a> {
15    source: &'a ItemImpl,
16    contract_type: &'a Type,
17    variants: MsgVariants<'a, GenericParam>,
18    generics: &'a [&'a GenericParam],
19    error: &'a ContractErrorAttr,
20    custom: &'a Custom,
21    msg_attrs_to_forward: Vec<MsgAttrForwarding>,
22}
23
24impl<'a> StructMessage<'a> {
25    pub fn new(
26        source: &'a ItemImpl,
27        msg_ty: MsgType,
28        generics: &'a [&'a GenericParam],
29        error: &'a ContractErrorAttr,
30        custom: &'a Custom,
31    ) -> Option<StructMessage<'a>> {
32        let contract_type = &source.self_ty;
33
34        let variants = MsgVariants::new(
35            source.as_variants(),
36            msg_ty,
37            generics,
38            &source.generics.where_clause,
39        );
40
41        if variants.variants().count() == 0 && variants.msg_ty() == MsgType::Instantiate {
42            emit_error!(
43                source.span(), "Missing instantiation message.";
44                note = source.span() => "`sylvia::contract` requires exactly one method marked with `#[sv::msg(instantiation)]` attribute."
45            );
46            return None;
47        } else if variants.variants().count() > 1 {
48            let mut variants = variants.variants();
49            let first_method = variants.next().map(|v| v.function_name());
50            let obsolete = variants.next().map(|v| v.function_name());
51            emit_error!(
52                first_method.span(), "More than one instantiation or migration message";
53                note = obsolete.span() => "Instantiation/Migration message previously defined here"
54            );
55            return None;
56        }
57
58        let msg_attrs_to_forward = ParsedSylviaAttributes::new(source.attrs.iter())
59            .msg_attrs_forward
60            .into_iter()
61            .filter(|attr| attr.msg_type == msg_ty)
62            .collect();
63
64        Some(Self {
65            source,
66            contract_type,
67            variants,
68            generics,
69            error,
70            custom,
71            msg_attrs_to_forward,
72        })
73    }
74
75    pub fn emit(&self) -> TokenStream {
76        let Self {
77            source,
78            contract_type,
79            variants,
80            generics,
81            error,
82            custom,
83            msg_attrs_to_forward,
84        } = self;
85
86        let Some(variant) = variants.get_only_variant() else {
87            return quote! {};
88        };
89
90        let used_generics = variants.used_generics();
91        let unused_generics = variants.unused_generics();
92        let full_where = &source.generics.where_clause;
93        let wheres = filter_wheres(full_where, generics, used_generics);
94        let where_clause = as_where_clause(&wheres);
95        let bracketed_used_generics = emit_bracketed_generics(used_generics);
96        let bracketed_unused_generics = emit_bracketed_generics(unused_generics);
97
98        let ret_type = variant
99            .msg_attr()
100            .msg_type()
101            .emit_result_type(&custom.msg_or_default(), &error.error);
102        let name = variant.msg_attr().msg_type().emit_msg_name();
103        let function_name = variant.function_name();
104        let mut msg_name = variant.msg_attr().msg_type().emit_msg_name();
105        msg_name.set_span(function_name.span());
106
107        let ctx_type = variant
108            .msg_attr()
109            .msg_type()
110            .emit_ctx_type(&custom.query_or_default());
111        let fields_names: Vec<_> = variant.fields().iter().map(MsgField::name).collect();
112        let parameters = variant.fields().iter().map(MsgField::emit_method_field);
113        let fields = variant.fields().iter().map(MsgField::emit_pub);
114
115        let msg_attrs_to_forward = msg_attrs_to_forward.iter().map(|attr| &attr.attrs);
116        let derive_call = variant.msg_attr().msg_type().emit_derive_call();
117
118        quote! {
119            #[allow(clippy::derive_partial_eq_without_eq)]
120            #derive_call
121            #( #[ #msg_attrs_to_forward ] )*
122            #[serde(rename_all="snake_case")]
123            pub struct #name #bracketed_used_generics {
124                #(#fields,)*
125            }
126
127            impl #bracketed_used_generics #name #bracketed_used_generics #where_clause {
128                pub fn new(#(#parameters,)*) -> Self {
129                    Self { #(#fields_names,)* }
130                }
131
132                pub fn dispatch #bracketed_unused_generics (self, contract: &#contract_type, ctx: #ctx_type) -> #ret_type #full_where
133                {
134                    let Self { #(#fields_names,)* } = self;
135                    contract.#function_name(Into::into(ctx), #(#fields_names,)*).map_err(Into::into)
136                }
137            }
138        }
139    }
140}