1use convert_case::Case;
2use proc_macro2::{Ident, TokenStream};
3use quote::quote;
4use syn::{parse_quote, GenericParam, ItemTrait, TraitItem, Type};
5
6use crate::crate_module;
7use crate::parser::attributes::msg::MsgType;
8use crate::parser::variant_descs::AsVariantDescs;
9use crate::types::associated_types::AssociatedTypes;
10use crate::types::msg_variant::{MsgVariant, MsgVariants};
11use crate::utils::SvCasing;
12
13pub struct MtHelpers<'a> {
15 source: &'a ItemTrait,
16 error_type: Type,
17 associated_types: &'a AssociatedTypes<'a>,
18 exec_variants: MsgVariants<'a, GenericParam>,
19 query_variants: MsgVariants<'a, GenericParam>,
20 sudo_variants: MsgVariants<'a, GenericParam>,
21 where_clause: &'a Option<syn::WhereClause>,
22}
23
24impl<'a> MtHelpers<'a> {
25 pub fn new(source: &'a ItemTrait, associated_types: &'a AssociatedTypes) -> Self {
26 let where_clause = &source.generics.where_clause;
27 let exec_variants =
28 MsgVariants::new(source.as_variants(), MsgType::Exec, &[], where_clause);
29 let query_variants =
30 MsgVariants::new(source.as_variants(), MsgType::Query, &[], where_clause);
31 let sudo_variants =
32 MsgVariants::new(source.as_variants(), MsgType::Sudo, &[], where_clause);
33 let associated_error = source.items.iter().find_map(|item| match item {
34 TraitItem::Type(ty) if ty.ident == "Error" => Some(&ty.ident),
35 _ => None,
36 });
37 let error_type: Type = match associated_error {
38 Some(error) => parse_quote!(#error),
39 None => unreachable!(),
41 };
42
43 Self {
44 error_type,
45 source,
46 associated_types,
47 where_clause,
48 exec_variants,
49 query_variants,
50 sudo_variants,
51 }
52 }
53
54 pub fn emit(&self) -> TokenStream {
55 let Self {
56 error_type,
57 source,
58 associated_types,
59 where_clause,
60 exec_variants,
61 query_variants,
62 sudo_variants,
63 } = self;
64
65 let sylvia = crate_module();
66
67 let interface_name = &source.ident;
68 let trait_name = Ident::new(&format!("{}Proxy", interface_name), interface_name.span());
69
70 let custom_msg: Type = parse_quote! { CustomMsgT };
71 let prefixed_error_type: Type = parse_quote! { Self:: #error_type };
72
73 let mt_app = parse_quote! {
74 #sylvia ::cw_multi_test::App<
75 BankT,
76 ApiT,
77 StorageT,
78 CustomT,
79 WasmT,
80 StakingT,
81 DistrT,
82 IbcT,
83 GovT,
84 >
85 };
86
87 let associated_args: Vec<_> = associated_types
88 .without_error()
89 .map(|associated| &associated.ident)
90 .collect();
91
92 let api = quote! {
93 < dyn #interface_name < Error = (), #(#associated_args = Self:: #associated_args,)* > as InterfaceMessagesApi >
94 };
95
96 let associated_types_declaration = associated_types.without_error();
97
98 let exec_methods = exec_variants.variants().map(|variant| {
99 variant.emit_mt_method_definition(&custom_msg, &mt_app, &prefixed_error_type, &api)
100 });
101 let query_methods = query_variants.variants().map(|variant| {
102 variant.emit_mt_method_definition(&custom_msg, &mt_app, &prefixed_error_type, &api)
103 });
104 let sudo_methods = sudo_variants.variants().map(|variant| {
105 variant.emit_mt_method_definition(&custom_msg, &mt_app, &prefixed_error_type, &api)
106 });
107
108 let exec_methods_declarations = exec_variants.variants().map(|variant| {
109 variant.emit_mt_method_declaration(&custom_msg, &prefixed_error_type, &api)
110 });
111 let query_methods_declarations = query_variants.variants().map(|variant| {
112 variant.emit_mt_method_declaration(&custom_msg, &prefixed_error_type, &api)
113 });
114 let sudo_methods_declarations = sudo_variants.variants().map(|variant| {
115 variant.emit_mt_method_declaration(&custom_msg, &prefixed_error_type, &api)
116 });
117
118 let where_predicates = where_clause
119 .as_ref()
120 .map(|where_clause| &where_clause.predicates);
121
122 quote! {
123 pub mod mt {
124 use super::*;
125
126 pub trait #trait_name <MtApp, #custom_msg > #where_clause {
127 type #error_type: std::fmt::Debug + std::fmt::Display + Send + Sync + 'static;
128 #(#associated_types_declaration)*
129
130 #(#query_methods_declarations)*
131 #(#exec_methods_declarations)*
132 #(#sudo_methods_declarations)*
133 }
134
135 impl<BankT, ApiT, StorageT, CustomT, WasmT, StakingT, DistrT, IbcT, GovT, #custom_msg, ContractT: super:: #interface_name > #trait_name < #mt_app, #custom_msg > for #sylvia ::multitest::Proxy<'_, #mt_app, ContractT >
136 where
137 ContractT:: #error_type : std::fmt::Debug + std::fmt::Display + Send + Sync + 'static,
138 #custom_msg: #sylvia ::types::CustomMsg + 'static,
139 CustomT: #sylvia ::cw_multi_test::Module,
140 WasmT: #sylvia ::cw_multi_test::Wasm<CustomT::ExecT, CustomT::QueryT>,
141 BankT: #sylvia ::cw_multi_test::Bank,
142 ApiT: #sylvia ::cw_std::Api,
143 StorageT: #sylvia ::cw_std::Storage,
144 CustomT: #sylvia ::cw_multi_test::Module,
145 StakingT: #sylvia ::cw_multi_test::Staking,
146 DistrT: #sylvia ::cw_multi_test::Distribution,
147 IbcT: #sylvia ::cw_multi_test::Ibc,
148 GovT: #sylvia ::cw_multi_test::Gov,
149 CustomT::ExecT: #sylvia ::types::CustomMsg + 'static,
150 CustomT::QueryT: #sylvia:: types::CustomQuery + 'static,
151 #mt_app : #sylvia ::cw_multi_test::Executor< #custom_msg >,
152 #where_predicates
153 {
154 type #error_type = <ContractT as super:: #interface_name>:: #error_type ;
155 #(type #associated_args = <ContractT as super:: #interface_name>:: #associated_args ;)*
156
157 #(#query_methods)*
158 #(#exec_methods)*
159 #(#sudo_methods)*
160 }
161 }
162 }
163 }
164}
165
166trait EmitMethods {
167 fn emit_mt_method_definition(
168 &self,
169 custom_msg: &Type,
170 mt_app: &Type,
171 error_type: &Type,
172 api: &TokenStream,
173 ) -> TokenStream;
174
175 fn emit_mt_method_declaration(
176 &self,
177 custom_msg: &Type,
178 error_type: &Type,
179 api: &TokenStream,
180 ) -> TokenStream;
181}
182
183impl EmitMethods for MsgVariant<'_> {
184 fn emit_mt_method_definition(
185 &self,
186 custom_msg: &Type,
187 mt_app: &Type,
188 error_type: &Type,
189 api: &TokenStream,
190 ) -> TokenStream {
191 let sylvia = crate_module();
192
193 let name = self.name();
194 let return_type = self.return_type();
195
196 let params: Vec<_> = self
197 .fields()
198 .iter()
199 .map(|field| field.emit_method_field_folded())
200 .collect();
201 let arguments = self.as_fields_names();
202 let type_name = self.msg_attr().msg_type().as_accessor_name();
203 let name = name.to_case(Case::Snake);
204
205 match self.msg_attr().msg_type() {
206 MsgType::Exec => quote! {
207 #[track_caller]
208 fn #name (&self, #(#params,)* ) -> #sylvia ::multitest::ExecProxy::< #error_type, #api :: #type_name, #mt_app, #custom_msg> {
209 let msg = #api :: #type_name :: #name ( #(#arguments),* );
210
211 #sylvia ::multitest::ExecProxy::new(&self.contract_addr, msg, &self.app)
212 }
213 },
214 MsgType::Query => {
215 quote! {
216 fn #name (&self, #(#params,)* ) -> Result<#return_type, #error_type> {
217 let msg = #api :: #type_name :: #name ( #(#arguments),* );
218
219 (*self.app)
220 .querier()
221 .query_wasm_smart(self.contract_addr.clone(), &msg)
222 .map_err(Into::into)
223 }
224 }
225 }
226 MsgType::Sudo => quote! {
227 fn #name (&self, #(#params,)* ) -> Result< #sylvia ::cw_multi_test::AppResponse, #error_type> {
228 let msg = #api :: #type_name :: #name ( #(#arguments),* );
229
230 (*self.app)
231 .app_mut()
232 .wasm_sudo(self.contract_addr.clone(), &msg)
233 .map_err(|err| err.downcast().unwrap())
234 }
235 },
236 MsgType::Migrate => quote! {
237 #[track_caller]
238 fn #name (&self, #(#params,)* ) -> #sylvia ::multitest::MigrateProxy::< #error_type, #api :: #type_name , #mt_app, #custom_msg> {
239 let msg = #api :: #type_name ::new( #(#arguments),* );
240
241 #sylvia ::multitest::MigrateProxy::new(&self.contract_addr, msg, &self.app)
242 }
243 },
244 _ => quote! {},
245 }
246 }
247
248 fn emit_mt_method_declaration(
249 &self,
250 custom_msg: &Type,
251 error_type: &Type,
252 api: &TokenStream,
253 ) -> TokenStream {
254 let sylvia = crate_module();
255
256 let name = self.name();
257 let return_type = self.return_type();
258
259 let params: Vec<_> = self
260 .fields()
261 .iter()
262 .map(|field| field.emit_method_field_folded())
263 .collect();
264 let type_name = self.msg_attr().msg_type().as_accessor_name();
265 let name = name.to_case(Case::Snake);
266
267 match self.msg_attr().msg_type() {
268 MsgType::Exec => quote! {
269 fn #name (&self, #(#params,)* ) -> #sylvia ::multitest::ExecProxy::< #error_type, #api:: #type_name, MtApp, #custom_msg>;
270 },
271 MsgType::Query => quote! {
272 fn #name (&self, #(#params,)* ) -> Result<#return_type, #error_type>;
273 },
274 MsgType::Sudo => quote! {
275 fn #name (&self, #(#params,)* ) -> Result< #sylvia ::cw_multi_test::AppResponse, #error_type>;
276 },
277 MsgType::Migrate => quote! {
278 #[track_caller]
279 fn #name (&self, #(#params,)* ) -> #sylvia ::multitest::MigrateProxy::< #error_type, #api :: #type_name, MtApp, #custom_msg>;
280 },
281 _ => quote! {},
282 }
283 }
284}