sylvia_derive/types/
msg_variant.rs

1use crate::crate_module;
2use crate::fold::StripSelfPath;
3use crate::parser::attributes::VariantAttrForwarding;
4use crate::parser::check_generics::{CheckGenerics, GetPath};
5use crate::parser::variant_descs::VariantDescs;
6use crate::parser::{process_fields, MsgAttr, MsgType};
7use crate::utils::{extract_return_type, filter_wheres, SvCasing};
8use convert_case::{Case, Casing};
9use proc_macro2::TokenStream;
10use quote::{quote, ToTokens};
11use syn::fold::Fold;
12use syn::visit::Visit;
13use syn::{parse_quote, Ident, Signature, Type, WhereClause, WherePredicate};
14
15use super::msg_field::MsgField;
16
17/// Representation of whole message variant
18#[derive(Debug)]
19pub struct MsgVariant<'a> {
20    name: Ident,
21    function_name: &'a Ident,
22    fields: Vec<MsgField<'a>>,
23    /// Type extracted only in case of `Query` and used in `cosmwasm_schema::QueryResponses`
24    /// `returns` attribute.
25    return_type: Option<Type>,
26    msg_attr: MsgAttr,
27    attrs_to_forward: Vec<VariantAttrForwarding>,
28}
29
30impl<'a> MsgVariant<'a> {
31    /// Creates new message variant from trait method
32    pub fn new<Generic>(
33        sig: &'a Signature,
34        generics_checker: &mut CheckGenerics<Generic>,
35        msg_attr: MsgAttr,
36        attrs_to_forward: Vec<VariantAttrForwarding>,
37    ) -> MsgVariant<'a>
38    where
39        Generic: GetPath + PartialEq,
40    {
41        let function_name = &sig.ident;
42
43        let name = function_name.to_case(Case::UpperCamel);
44        let fields = process_fields(sig, generics_checker);
45
46        let return_type = if msg_attr.msg_type() == MsgType::Query {
47            let resp_type = &msg_attr.resp_type();
48            match resp_type {
49                Some(resp_type) => {
50                    let resp_type = parse_quote! { #resp_type };
51                    generics_checker.visit_type(&resp_type);
52                    Some(resp_type)
53                }
54                None => {
55                    let return_type = extract_return_type(&sig.output);
56                    let stripped_return_type = StripSelfPath.fold_path(return_type.clone());
57                    generics_checker.visit_path(&stripped_return_type);
58                    Some(parse_quote! { #return_type })
59                }
60            }
61        } else {
62            None
63        };
64
65        Self {
66            name,
67            function_name,
68            fields,
69            return_type,
70            msg_attr,
71            attrs_to_forward,
72        }
73    }
74
75    /// Emits message variant
76    pub fn emit(&self) -> TokenStream {
77        let Self {
78            name,
79            fields,
80            msg_attr,
81            return_type,
82            attrs_to_forward,
83            ..
84        } = self;
85        let fields = fields.iter().map(MsgField::emit);
86        let returns_attribute = msg_attr.msg_type().emit_returns_attribute(return_type);
87        let attrs_to_forward = attrs_to_forward.iter().map(|attr| &attr.attrs);
88
89        quote! {
90            #returns_attribute
91            #( #[ #attrs_to_forward ] )*
92            #name {
93                #(#fields,)*
94            }
95        }
96    }
97
98    /// Emits match leg dispatching against this variant. Assumes enum variants are imported into the
99    /// scope. Dispatching is performed by calling the function this variant is build from on the
100    /// `contract` variable, with `ctx` as its first argument - both of them should be in scope.
101    pub fn emit_dispatch_leg(&self) -> TokenStream {
102        let Self {
103            name,
104            fields,
105            function_name,
106            msg_attr,
107            ..
108        } = self;
109
110        let args: Vec<_> = fields
111            .iter()
112            .zip(1..)
113            .map(|(field, num)| Ident::new(&format!("field{}", num), field.name().span()))
114            .collect();
115
116        let fields = fields
117            .iter()
118            .map(MsgField::name)
119            .zip(args.clone())
120            .map(|(field, num_field)| quote!(#field : #num_field));
121
122        let method_call = msg_attr.msg_type().emit_dispatch_leg(function_name, &args);
123
124        quote! {
125            #name {
126                #(#fields,)*
127            } => #method_call
128        }
129    }
130
131    /// Emits variants constructors. Constructors names are variants names in snake_case.
132    pub fn emit_variants_constructors(&self) -> TokenStream {
133        let Self { name, fields, .. } = self;
134
135        let method_name = name.to_case(Case::Snake);
136        let parameters = fields.iter().map(MsgField::emit_method_field);
137        let arguments = fields.iter().map(MsgField::name);
138
139        quote! {
140            pub fn #method_name( #(#parameters),*) -> Self {
141                Self :: #name { #(#arguments),* }
142            }
143        }
144    }
145
146    pub fn as_fields_names(&self) -> Vec<&Ident> {
147        self.fields.iter().map(MsgField::name).collect()
148    }
149
150    pub fn emit_method_field(&self) -> Vec<TokenStream> {
151        self.fields
152            .iter()
153            .map(MsgField::emit_method_field)
154            .collect()
155    }
156
157    pub fn name(&self) -> &Ident {
158        &self.name
159    }
160
161    pub fn function_name(&self) -> &Ident {
162        self.function_name
163    }
164
165    pub fn fields(&self) -> &Vec<MsgField> {
166        &self.fields
167    }
168
169    pub fn msg_attr(&self) -> &MsgAttr {
170        &self.msg_attr
171    }
172
173    pub fn return_type(&self) -> &Option<Type> {
174        &self.return_type
175    }
176}
177
178#[derive(Debug)]
179pub struct MsgVariants<'a, Generic> {
180    variants: Vec<MsgVariant<'a>>,
181    used_generics: Vec<&'a Generic>,
182    unused_generics: Vec<&'a Generic>,
183    where_predicates: Vec<&'a WherePredicate>,
184    msg_ty: MsgType,
185}
186
187impl<'a, Generic> MsgVariants<'a, Generic>
188where
189    Generic: GetPath + PartialEq + ToTokens,
190{
191    pub fn new(
192        source: VariantDescs<'a>,
193        msg_ty: MsgType,
194        all_generics: &'a [&'a Generic],
195        unfiltered_where_clause: &'a Option<WhereClause>,
196    ) -> Self {
197        let mut generics_checker = CheckGenerics::new(all_generics);
198        let variants: Vec<_> = source
199            .filter_map(|variant_desc| {
200                let msg_attr: MsgAttr = variant_desc.attr_msg()?;
201                let attrs_to_forward = variant_desc.attrs_to_forward();
202
203                if msg_attr.msg_type() != msg_ty {
204                    return None;
205                }
206
207                Some(MsgVariant::new(
208                    variant_desc.into_sig(),
209                    &mut generics_checker,
210                    msg_attr,
211                    attrs_to_forward,
212                ))
213            })
214            .collect();
215
216        let (used_generics, unused_generics) = generics_checker.used_unused();
217        let where_predicates = filter_wheres(unfiltered_where_clause, all_generics, &used_generics);
218
219        Self {
220            variants,
221            used_generics,
222            unused_generics,
223            where_predicates,
224            msg_ty,
225        }
226    }
227
228    pub fn where_clause(&self) -> Option<WhereClause> {
229        let where_predicates = &self.where_predicates;
230        if !where_predicates.is_empty() {
231            Some(parse_quote! { where #(#where_predicates),* })
232        } else {
233            None
234        }
235    }
236
237    pub fn variants(&self) -> impl Iterator<Item = &MsgVariant> {
238        self.variants.iter()
239    }
240
241    pub fn used_generics(&self) -> &Vec<&'a Generic> {
242        &self.used_generics
243    }
244
245    pub fn unused_generics(&self) -> &Vec<&'a Generic> {
246        &self.unused_generics
247    }
248
249    pub fn msg_ty(&self) -> MsgType {
250        self.msg_ty
251    }
252
253    pub fn emit_phantom_match_arm(&self) -> TokenStream {
254        let sylvia = crate_module();
255        let Self { used_generics, .. } = self;
256        if used_generics.is_empty() {
257            return quote! {};
258        }
259        quote! {
260            _Phantom(_) => Err(#sylvia ::cw_std::StdError::generic_err("Phantom message should not be constructed.")).map_err(Into::into),
261        }
262    }
263
264    pub fn emit_dispatch_legs(&self) -> impl Iterator<Item = TokenStream> + '_ {
265        self.variants
266            .iter()
267            .map(|variant| variant.emit_dispatch_leg())
268    }
269
270    pub fn as_names_snake_cased(&self) -> Vec<String> {
271        self.variants
272            .iter()
273            .map(|variant| variant.name.to_string().to_case(Case::Snake))
274            .collect()
275    }
276
277    pub fn emit_constructors(&self) -> impl Iterator<Item = TokenStream> + '_ {
278        self.variants
279            .iter()
280            .map(MsgVariant::emit_variants_constructors)
281    }
282
283    pub fn emit(&self) -> impl Iterator<Item = TokenStream> + '_ {
284        self.variants.iter().map(MsgVariant::emit)
285    }
286
287    pub fn get_only_variant(&self) -> Option<&MsgVariant> {
288        self.variants.first()
289    }
290
291    pub fn emit_phantom_variant(&self) -> TokenStream {
292        let Self {
293            msg_ty,
294            used_generics,
295            ..
296        } = self;
297
298        if used_generics.is_empty() {
299            return quote! {};
300        }
301
302        let return_attr = match msg_ty {
303            MsgType::Query => quote! { #[returns((#(#used_generics,)*))] },
304            _ => quote! {},
305        };
306
307        quote! {
308            #[serde(skip)]
309            #return_attr
310            _Phantom(std::marker::PhantomData<( #(#used_generics,)* )>),
311        }
312    }
313}