sylvia_derive/interface/mt.rs

1use convert_case::Case;
2use proc_macro2::{Ident, TokenStream};
3use quote::quote;
4use syn::{parse_quote, GenericParam, ItemTrait, TraitItem, Type};
5
6use crate::crate_module;
7use crate::parser::attributes::msg::MsgType;
8use crate::parser::variant_descs::AsVariantDescs;
9use crate::types::associated_types::AssociatedTypes;
10use crate::types::msg_variant::{MsgVariant, MsgVariants};
11use crate::utils::SvCasing;
12
13/// Emits helpers for testing interface messages using MultiTest.
14pub struct MtHelpers<'a> {
15    source: &'a ItemTrait,
16    error_type: Type,
17    associated_types: &'a AssociatedTypes<'a>,
18    exec_variants: MsgVariants<'a, GenericParam>,
19    query_variants: MsgVariants<'a, GenericParam>,
20    sudo_variants: MsgVariants<'a, GenericParam>,
21    where_clause: &'a Option<syn::WhereClause>,
22}
23
24impl<'a> MtHelpers<'a> {
25    pub fn new(source: &'a ItemTrait, associated_types: &'a AssociatedTypes) -> Self {
26        let where_clause = &source.generics.where_clause;
27        let exec_variants =
28            MsgVariants::new(source.as_variants(), MsgType::Exec, &[], where_clause);
29        let query_variants =
30            MsgVariants::new(source.as_variants(), MsgType::Query, &[], where_clause);
31        let sudo_variants =
32            MsgVariants::new(source.as_variants(), MsgType::Sudo, &[], where_clause);
33        let associated_error = source.items.iter().find_map(|item| match item {
34            TraitItem::Type(ty) if ty.ident == "Error" => Some(&ty.ident),
35            _ => None,
36        });
37        let error_type: Type = match associated_error {
38            Some(error) => parse_quote!(#error),
39            // This should never happen as the `interface` macro requires the trait to have an associated `Error` type
40            None => unreachable!(),
41        };
42
43        Self {
44            error_type,
45            source,
46            associated_types,
47            where_clause,
48            exec_variants,
49            query_variants,
50            sudo_variants,
51        }
52    }
53
54    pub fn emit(&self) -> TokenStream {
55        let Self {
56            error_type,
57            source,
58            associated_types,
59            where_clause,
60            exec_variants,
61            query_variants,
62            sudo_variants,
63        } = self;
64
65        let sylvia = crate_module();
66
67        let interface_name = &source.ident;
68        let trait_name = Ident::new(&format!("{}Proxy", interface_name), interface_name.span());
69
70        let custom_msg: Type = parse_quote! { CustomMsgT };
71        let prefixed_error_type: Type = parse_quote! { Self:: #error_type };
72
73        let mt_app = parse_quote! {
74            #sylvia ::cw_multi_test::App<
75                BankT,
76                ApiT,
77                StorageT,
78                CustomT,
79                WasmT,
80                StakingT,
81                DistrT,
82                IbcT,
83                GovT,
84            >
85        };
86
87        let associated_args: Vec<_> = associated_types
88            .without_error()
89            .map(|associated| &associated.ident)
90            .collect();
91
92        let api = quote! {
93            < dyn #interface_name < Error = (), #(#associated_args = Self:: #associated_args,)* > as InterfaceMessagesApi >
94        };
95
96        let associated_types_declaration = associated_types.without_error();
97
98        let exec_methods = exec_variants.variants().map(|variant| {
99            variant.emit_mt_method_definition(&custom_msg, &mt_app, &prefixed_error_type, &api)
100        });
101        let query_methods = query_variants.variants().map(|variant| {
102            variant.emit_mt_method_definition(&custom_msg, &mt_app, &prefixed_error_type, &api)
103        });
104        let sudo_methods = sudo_variants.variants().map(|variant| {
105            variant.emit_mt_method_definition(&custom_msg, &mt_app, &prefixed_error_type, &api)
106        });
107
108        let exec_methods_declarations = exec_variants.variants().map(|variant| {
109            variant.emit_mt_method_declaration(&custom_msg, &prefixed_error_type, &api)
110        });
111        let query_methods_declarations = query_variants.variants().map(|variant| {
112            variant.emit_mt_method_declaration(&custom_msg, &prefixed_error_type, &api)
113        });
114        let sudo_methods_declarations = sudo_variants.variants().map(|variant| {
115            variant.emit_mt_method_declaration(&custom_msg, &prefixed_error_type, &api)
116        });
117
118        let where_predicates = where_clause
119            .as_ref()
120            .map(|where_clause| &where_clause.predicates);
121
122        quote! {
123            pub mod mt {
124                use super::*;
125
126                pub trait #trait_name <MtApp, #custom_msg > #where_clause {
127                    type #error_type: std::fmt::Debug + std::fmt::Display + Send + Sync + 'static;
128                    #(#associated_types_declaration)*
129
130                    #(#query_methods_declarations)*
131                    #(#exec_methods_declarations)*
132                    #(#sudo_methods_declarations)*
133                }
134
135                impl<BankT, ApiT, StorageT, CustomT, WasmT, StakingT, DistrT, IbcT, GovT, #custom_msg, ContractT: super:: #interface_name > #trait_name < #mt_app, #custom_msg > for #sylvia ::multitest::Proxy<'_, #mt_app, ContractT >
136                where
137                    ContractT:: #error_type : std::fmt::Debug + std::fmt::Display + Send + Sync + 'static,
138                    #custom_msg: #sylvia ::types::CustomMsg + 'static,
139                    CustomT: #sylvia ::cw_multi_test::Module,
140                    WasmT: #sylvia ::cw_multi_test::Wasm<CustomT::ExecT, CustomT::QueryT>,
141                    BankT: #sylvia ::cw_multi_test::Bank,
142                    ApiT: #sylvia ::cw_std::Api,
143                    StorageT: #sylvia ::cw_std::Storage,
144                    CustomT: #sylvia ::cw_multi_test::Module,
145                    StakingT: #sylvia ::cw_multi_test::Staking,
146                    DistrT: #sylvia ::cw_multi_test::Distribution,
147                    IbcT: #sylvia ::cw_multi_test::Ibc,
148                    GovT: #sylvia ::cw_multi_test::Gov,
149                    CustomT::ExecT: #sylvia ::types::CustomMsg + 'static,
150                    CustomT::QueryT: #sylvia:: types::CustomQuery + 'static,
151                    #mt_app : #sylvia ::cw_multi_test::Executor< #custom_msg >,
152                    #where_predicates
153                {
154                    type #error_type = <ContractT as super:: #interface_name>:: #error_type ;
155                    #(type #associated_args = <ContractT as super:: #interface_name>:: #associated_args ;)*
156
157                    #(#query_methods)*
158                    #(#exec_methods)*
159                    #(#sudo_methods)*
160                }
161            }
162        }
163    }
164}
165
166trait EmitMethods {
167    fn emit_mt_method_definition(
168        &self,
169        custom_msg: &Type,
170        mt_app: &Type,
171        error_type: &Type,
172        api: &TokenStream,
173    ) -> TokenStream;
174
175    fn emit_mt_method_declaration(
176        &self,
177        custom_msg: &Type,
178        error_type: &Type,
179        api: &TokenStream,
180    ) -> TokenStream;
181}
182
183impl EmitMethods for MsgVariant<'_> {
184    fn emit_mt_method_definition(
185        &self,
186        custom_msg: &Type,
187        mt_app: &Type,
188        error_type: &Type,
189        api: &TokenStream,
190    ) -> TokenStream {
191        let sylvia = crate_module();
192
193        let name = self.name();
194        let return_type = self.return_type();
195
196        let params: Vec<_> = self
197            .fields()
198            .iter()
199            .map(|field| field.emit_method_field_folded())
200            .collect();
201        let arguments = self.as_fields_names();
202        let type_name = self.msg_attr().msg_type().as_accessor_name();
203        let name = name.to_case(Case::Snake);
204
205        match self.msg_attr().msg_type() {
206            MsgType::Exec => quote! {
207                #[track_caller]
208                fn #name (&self, #(#params,)* ) -> #sylvia ::multitest::ExecProxy::< #error_type, #api :: #type_name, #mt_app, #custom_msg> {
209                    let msg = #api :: #type_name :: #name ( #(#arguments),* );
210
211                    #sylvia ::multitest::ExecProxy::new(&self.contract_addr, msg, &self.app)
212                }
213            },
214            MsgType::Query => {
215                quote! {
216                    fn #name (&self, #(#params,)* ) -> Result<#return_type, #error_type> {
217                        let msg = #api :: #type_name :: #name ( #(#arguments),* );
218
219                        (*self.app)
220                            .querier()
221                            .query_wasm_smart(self.contract_addr.clone(), &msg)
222                            .map_err(Into::into)
223                    }
224                }
225            }
226            MsgType::Sudo => quote! {
227                fn #name (&self, #(#params,)* ) -> Result< #sylvia ::cw_multi_test::AppResponse, #error_type> {
228                    let msg = #api :: #type_name :: #name ( #(#arguments),* );
229
230                    (*self.app)
231                        .app_mut()
232                        .wasm_sudo(self.contract_addr.clone(), &msg)
233                        .map_err(|err| err.downcast().unwrap())
234                }
235            },
236            MsgType::Migrate => quote! {
237                #[track_caller]
238                fn #name (&self, #(#params,)* ) -> #sylvia ::multitest::MigrateProxy::< #error_type, #api :: #type_name , #mt_app, #custom_msg> {
239                    let msg = #api :: #type_name ::new( #(#arguments),* );
240
241                    #sylvia ::multitest::MigrateProxy::new(&self.contract_addr, msg, &self.app)
242                }
243            },
244            _ => quote! {},
245        }
246    }
247
248    fn emit_mt_method_declaration(
249        &self,
250        custom_msg: &Type,
251        error_type: &Type,
252        api: &TokenStream,
253    ) -> TokenStream {
254        let sylvia = crate_module();
255
256        let name = self.name();
257        let return_type = self.return_type();
258
259        let params: Vec<_> = self
260            .fields()
261            .iter()
262            .map(|field| field.emit_method_field_folded())
263            .collect();
264        let type_name = self.msg_attr().msg_type().as_accessor_name();
265        let name = name.to_case(Case::Snake);
266
267        match self.msg_attr().msg_type() {
268            MsgType::Exec => quote! {
269                fn #name (&self, #(#params,)* ) -> #sylvia ::multitest::ExecProxy::< #error_type, #api:: #type_name, MtApp, #custom_msg>;
270            },
271            MsgType::Query => quote! {
272                fn #name (&self, #(#params,)* ) -> Result<#return_type, #error_type>;
273            },
274            MsgType::Sudo => quote! {
275                fn #name (&self, #(#params,)* ) -> Result< #sylvia ::cw_multi_test::AppResponse, #error_type>;
276            },
277            MsgType::Migrate => quote! {
278                #[track_caller]
279                fn #name (&self, #(#params,)* ) -> #sylvia ::multitest::MigrateProxy::< #error_type, #api :: #type_name, MtApp, #custom_msg>;
280            },
281            _ => quote! {},
282        }
283    }
284}