sylvia_derive/contract/communication/
reply.rs1use convert_case::{Case, Casing};
9use proc_macro2::TokenStream;
10use proc_macro_error::emit_error;
11use quote::quote;
12use syn::{parse_quote, GenericParam, Ident, ItemImpl, Type};
13
14use crate::crate_module;
15use crate::parser::attributes::msg::ReplyOn;
16use crate::parser::{MsgType, ParsedSylviaAttributes};
17use crate::types::msg_field::MsgField;
18use crate::types::msg_variant::{MsgVariant, MsgVariants};
19use crate::utils::emit_turbofish;
20
/// How many leading handler parameters are set aside for the reply data / result
/// before the payload parameters begin.
const NUMBER_OF_ALLOWED_DATA_FIELDS: usize = 1;
/// Payload length at which `assert_no_redundant_params` accepts the list without
/// any further checks.
const NUMBER_OF_ALLOWED_RAW_PAYLOAD_FIELDS: usize = 1;
23
24fn assert_no_redundant_params(payload: &[&MsgField]) {
27 let payload_param = payload.iter().enumerate().find(|(_, field)| {
28 ParsedSylviaAttributes::new(field.attrs().iter())
29 .payload
30 .is_some()
31 });
32
33 if payload.len() == NUMBER_OF_ALLOWED_RAW_PAYLOAD_FIELDS {
34 return;
35 }
36
37 let Some((index, payload_param)) = payload_param else {
38 return;
39 };
40
41 if index == 0 {
42 emit_error!(payload[1].name().span(), "Redundant payload parameter.";
43 note = payload_param.name().span() => "Expected no parameters after the parameter marked with `#[sv::payload(raw)]`."
44 )
45 } else {
46 emit_error!(payload[0].name().span(), "Redundant payload parameter.";
47 note = payload_param.name().span() => "Expected no parameters between the parameter marked with `#[sv::data]` and `#[sv::payload(raw)]`."
48 )
49 }
50}
51
52pub struct Reply<'a> {
53 source: &'a ItemImpl,
54 generics: &'a [&'a GenericParam],
55 reply_data: Vec<ReplyData<'a>>,
56 error: Type,
57}
58
59impl<'a> Reply<'a> {
60 pub fn new(
61 source: &'a ItemImpl,
62 generics: &'a [&'a GenericParam],
63 variants: &'a MsgVariants<'a, GenericParam>,
64 ) -> Self {
65 let reply_data = variants.as_reply_data();
66 let parsed_attrs = ParsedSylviaAttributes::new(source.attrs.iter());
67 let error = parsed_attrs.error_attrs.unwrap_or_default().error;
68
69 Self {
70 source,
71 generics,
72 reply_data,
73 error,
74 }
75 }
76
77 pub fn emit(&self) -> TokenStream {
78 let unique_handlers: Vec<_> = self.emit_reply_ids().collect();
79 let dispatch = self.emit_dispatch();
80 let sub_msg_trait = self.emit_sub_msg_trait();
81
82 quote! {
83 #(#unique_handlers)*
84
85 #dispatch
86
87 #sub_msg_trait
88 }
89 }
90
91 pub fn emit_dispatch(&self) -> TokenStream {
94 let sylvia = crate_module();
95 let Self {
96 source,
97 generics,
98 reply_data,
99 error,
100 ..
101 } = self;
102
103 let msg_ty = MsgType::Reply;
104 let contract = &source.self_ty;
105 let where_clause = &source.generics.where_clause;
106
107 let custom_query = parse_quote!( < #contract as #sylvia ::types::ContractApi>::CustomQuery);
108 let custom_msg = parse_quote!( < #contract as #sylvia ::types::ContractApi>::CustomMsg);
109 let ctx_params = msg_ty.emit_ctx_params(&custom_query);
110 let ret_type = msg_ty.emit_result_type(&custom_msg, error);
111
112 let match_arms = reply_data
113 .iter()
114 .map(|data| data.emit_match_arms(contract, generics));
115
116 quote! {
117 pub fn dispatch_reply < #(#generics),* >( #ctx_params , msg: #sylvia ::cw_std::Reply, contract: #contract ) -> #ret_type #where_clause {
118 let #sylvia ::cw_std::Reply {
119 id,
120 payload,
121 gas_used,
122 result,
123 } = msg;
124
125 match id {
126 #(#match_arms,)*
127 _ => {
128 let err_msg = format!("Unknown reply id: {}.", id);
129 Err( #sylvia ::cw_std::StdError::generic_err(err_msg)).map_err(Into::into)
130 }
131 }
132 }
133 }
134 }
135
136 fn emit_reply_ids(&'a self) -> impl Iterator<Item = TokenStream> + 'a {
141 self.reply_data.iter().enumerate().map(|(id, data)| {
142 let id = id as u64;
143 let reply_id = &data.reply_id;
144 quote! {
145 pub const #reply_id : u64 = #id ;
146 }
147 })
148 }
149
150 fn emit_sub_msg_trait(&self) -> TokenStream {
152 let Self { reply_data, .. } = self;
153
154 let sylvia = crate_module();
155
156 let methods_declaration = reply_data.iter().map(ReplyData::emit_submsg_trait_method);
157 let submsg_reply_setters = reply_data.iter().map(ReplyData::emit_submsg_setter);
158 let submsg_converters: Vec<_> = reply_data
159 .iter()
160 .map(ReplyData::emit_submsg_converter)
161 .collect();
162
163 quote! {
164 pub trait SubMsgMethods<CustomMsgT> {
165 #(#methods_declaration)*
166 }
167
168 impl<CustomMsgT> SubMsgMethods<CustomMsgT> for #sylvia ::cw_std::SubMsg<CustomMsgT> {
169 #(#submsg_reply_setters)*
170 }
171
172 impl<CustomMsgT> SubMsgMethods<CustomMsgT> for #sylvia ::cw_std::WasmMsg {
173 #(#submsg_converters)*
174 }
175
176 impl<CustomMsgT> SubMsgMethods<CustomMsgT> for #sylvia ::cw_std::CosmosMsg<CustomMsgT> {
177 #(#submsg_converters)*
178 }
179 }
180 }
181}
182
183trait ReplyVariants<'a> {
184 fn as_reply_data(&self) -> Vec<ReplyData>;
187}
188
189impl<'a> ReplyVariants<'a> for MsgVariants<'a, GenericParam> {
190 fn as_reply_data(&self) -> Vec<ReplyData> {
191 let mut reply_data: Vec<ReplyData> = vec![];
192
193 self.variants()
194 .flat_map(ReplyVariant::as_variant_handlers_pair)
195 .for_each(|(handler, handler_id)| {
196 let reply_on = handler.msg_attr().reply_on();
197 let reply_id = handler_id.as_reply_id();
198 match reply_data
199 .iter_mut()
200 .find(|existing_data| existing_data.reply_id == reply_id)
201 {
202 Some(existing_data)
203 if existing_data
204 .handlers
205 .iter()
206 .any(|(_, existing_reply_on)| existing_reply_on.excludes(&reply_on)) =>
207 {
208 existing_data.handlers.iter().for_each(
209 |(existing_function_name, existing_reply_on)| {
210 let existing_reply_id = &existing_data.reply_id;
211
212 emit_error!(reply_id.span(), "Duplicated reply handler.";
213 note = existing_data.reply_id.span() => format!("Previous definition of handler=`{}` for reply_on=`{}` defined on `fn {}()`", existing_reply_id, existing_reply_on, existing_function_name);
214 )
215 },
216 )
217 }
218 Some(existing_data) => existing_data.merge(handler),
219 None => reply_data.push(ReplyData::new(reply_id, handler, handler_id)),
220 }
221 });
222
223 reply_data
224 }
225}
226
227struct ReplyData<'a> {
229 pub reply_id: Ident,
231 pub handler_id: &'a Ident,
233 pub handlers: Vec<(&'a Ident, ReplyOn)>,
235 pub data: Option<&'a MsgField<'a>>,
237 pub payload: Vec<&'a MsgField<'a>>,
239}
240
241impl<'a> ReplyData<'a> {
242 pub fn new(reply_id: Ident, variant: &'a MsgVariant<'a>, handler_id: &'a Ident) -> Self {
243 let data = variant.as_data_field();
244 variant.validate_fields_attributes();
245 let payload = variant.fields().iter();
246 let payload = if data.is_some() || variant.msg_attr().reply_on() != ReplyOn::Success {
247 payload
248 .skip(NUMBER_OF_ALLOWED_DATA_FIELDS)
249 .collect::<Vec<_>>()
250 } else {
251 payload.collect::<Vec<_>>()
252 };
253
254 if payload.is_empty() {
255 emit_error!(variant.name().span(), "Missing payload parameter.";
256 note = "Expected at least one payload parameter at the end of parameter list."
257 )
258 }
259
260 assert_no_redundant_params(&payload);
261 let method_name = variant.function_name();
262 let reply_on = variant.msg_attr().reply_on();
263
264 Self {
265 reply_id,
266 handler_id,
267 handlers: vec![(method_name, reply_on)],
268 data,
269 payload,
270 }
271 }
272
273 pub fn merge(&mut self, new_handler: &'a MsgVariant<'a>) {
275 let (current_method_name, _) = match self.handlers.first() {
276 Some(handler) => handler,
277 _ => return,
278 };
279
280 let new_reply_data = ReplyData::new(self.reply_id.clone(), new_handler, self.handler_id);
281
282 if self.payload.len() != new_reply_data.payload.len() {
283 emit_error!(current_method_name.span(), "Mismatched quantity of method parameters.";
284 note = self.handler_id.span() => format!("Both `{}` handlers should have the same number of parameters.", self.handler_id);
285 note = new_handler.function_name().span() => format!("Previous definition of {} handler.", self.handler_id)
286 );
287 }
288
289 self.payload
290 .iter()
291 .zip(new_reply_data.payload.iter())
292 .for_each(|(current_field, new_field)|
293 {
294 if current_field.ty() != new_field.ty() {
295 emit_error!(current_field.name().span(), "Mismatched parameter in reply handlers.";
296 note = current_field.name().span() => format!("Parameters for the `{}` handler have to be the same.", self.handler_id);
297 note = new_field.name().span() => format!("Previous parameter defined for the `{}` handler.", self.handler_id)
298 )
299 }
300 });
301
302 let new_function_name = new_handler.function_name();
303 let new_reply_on = new_handler.msg_attr().reply_on();
304 self.handlers.push((new_function_name, new_reply_on));
305 }
306
307 fn emit_match_arms(&self, contract: &Type, generics: &[&GenericParam]) -> TokenStream {
309 let reply_id = &self.reply_id;
310 let contract_turbofish = emit_turbofish(contract, generics);
311 let success_match_arm = self.emit_success_match_arm(&contract_turbofish);
312 let error_match_arm = self.emit_error_match_arm(&contract_turbofish);
313
314 quote! {
315 #reply_id => {
316 match result {
317 #success_match_arm
318 #error_match_arm
319 }
320 }
321 }
322 }
323
324 fn emit_cw_reply_on(&self) -> TokenStream {
327 let sylvia = crate_module();
328 let is_always = self
329 .handlers
330 .iter()
331 .any(|(_, reply_on)| reply_on == &ReplyOn::Always);
332 let is_success = self
333 .handlers
334 .iter()
335 .any(|(_, reply_on)| reply_on == &ReplyOn::Success);
336 let is_error = self
337 .handlers
338 .iter()
339 .any(|(_, reply_on)| reply_on == &ReplyOn::Error);
340
341 if is_always || (is_success && is_error) {
342 quote! { #sylvia ::cw_std::ReplyOn::Always }
343 } else if is_success {
344 quote! { #sylvia ::cw_std::ReplyOn::Success }
345 } else {
346 quote! { #sylvia ::cw_std::ReplyOn::Error }
347 }
348 }
349
350 fn emit_submsg_setter(&self) -> TokenStream {
352 let sylvia = crate_module();
353 let Self {
354 reply_id,
355 handler_id,
356 payload,
357 ..
358 } = self;
359
360 let method_name = handler_id;
361 let reply_on = self.emit_cw_reply_on();
362 let payload_parameters = payload.iter().map(|field| field.emit_method_field());
363 let payload_serialization = payload.emit_payload_serialization();
364
365 quote! {
366 fn #method_name (self, #(#payload_parameters),* ) -> #sylvia ::cw_std::StdResult< #sylvia ::cw_std::SubMsg<CustomMsgT>> {
367 #payload_serialization
368
369 Ok( #sylvia ::cw_std::SubMsg {
370 reply_on: #reply_on ,
371 id: #reply_id ,
372 payload,
373 ..self
374 })
375 }
376 }
377 }
378
379 fn emit_submsg_converter(&self) -> TokenStream {
381 let sylvia = crate_module();
382 let Self {
383 reply_id,
384 handler_id,
385 payload,
386 ..
387 } = self;
388
389 let method_name = handler_id;
390 let reply_on = self.emit_cw_reply_on();
391 let payload_parameters = payload.iter().map(|field| field.emit_method_field());
392 let payload_serialization = payload.emit_payload_serialization();
393
394 quote! {
395 fn #method_name (self, #(#payload_parameters),* ) -> #sylvia ::cw_std::StdResult< #sylvia ::cw_std::SubMsg<CustomMsgT>> {
396 #payload_serialization
397
398 Ok( #sylvia ::cw_std::SubMsg {
399 reply_on: #reply_on ,
400 id: #reply_id ,
401 msg: self.into(),
402 payload,
403 gas_limit: None,
404 })
405 }
406 }
407 }
408
409 fn emit_submsg_trait_method(&self) -> TokenStream {
410 let sylvia = crate_module();
411 let method_name = &self.handler_id;
412 let payload_parameters = self.payload.iter().map(|field| field.emit_method_field());
413
414 quote! {
415 fn #method_name (self, #(#payload_parameters),* ) -> #sylvia ::cw_std::StdResult< #sylvia ::cw_std::SubMsg<CustomMsgT>>;
416 }
417 }
418
419 fn emit_success_match_arm(&self, contract_turbofish: &Type) -> TokenStream {
423 let sylvia = crate_module();
424
425 match self
426 .handlers
427 .iter()
428 .find(|(_, reply_on)| reply_on == &ReplyOn::Success || reply_on == &ReplyOn::Always)
429 {
430 Some((method_name, reply_on)) if reply_on == &ReplyOn::Success => {
431 let payload_values = self.payload.iter().map(|field| field.name());
432 let payload_deserialization = self.payload.emit_payload_deserialization();
433 let data_deserialization = self.data.map(DataField::emit_data_deserialization);
434 let data = self.data.map(|_| quote! { data, });
435
436 quote! {
437 #sylvia ::cw_std::SubMsgResult::Ok(sub_msg_resp) => {
438 #[allow(deprecated)]
439 let #sylvia ::cw_std::SubMsgResponse { events, data, msg_responses} = sub_msg_resp;
440 #payload_deserialization
441 #data_deserialization
442
443 #contract_turbofish ::new(). #method_name ((deps, env, gas_used, events, msg_responses).into(), #data #(#payload_values),* )
444 }
445 }
446 }
447 Some((method_name, reply_on)) if reply_on == &ReplyOn::Always => {
448 let payload_values = self.payload.iter().map(|field| field.name());
449 let payload_deserialization = self.payload.emit_payload_deserialization();
450
451 quote! {
452 #sylvia ::cw_std::SubMsgResult::Ok(_) => {
453 #payload_deserialization
454
455 #contract_turbofish ::new(). #method_name ((deps, env, gas_used, vec![], vec![]).into(), result, #(#payload_values),* )
456 }
457 }
458 }
459 _ => quote! {
460 #sylvia ::cw_std::SubMsgResult::Ok(sub_msg_resp) => {
461 let mut resp = sylvia::cw_std::Response::new().add_events(sub_msg_resp.events);
462
463 #[allow(deprecated)]
464 if sub_msg_resp.data.is_some() {
465 resp = resp.set_data(sub_msg_resp.data.unwrap());
466 }
467
468 Ok(resp)
469 }
470 },
471 }
472 }
473
474 fn emit_error_match_arm(&self, contract_turbofish: &Type) -> TokenStream {
478 let sylvia = crate_module();
479
480 match self
481 .handlers
482 .iter()
483 .find(|(_, reply_on)| reply_on == &ReplyOn::Error || reply_on == &ReplyOn::Always)
484 {
485 Some((method_name, reply_on)) if reply_on == &ReplyOn::Error => {
486 let payload_values = self.payload.iter().map(|field| field.name());
487 let payload_deserialization = self.payload.emit_payload_deserialization();
488
489 quote! {
490 #sylvia ::cw_std::SubMsgResult::Err(error) => {
491 #payload_deserialization
492
493 #contract_turbofish ::new(). #method_name ((deps, env, gas_used, vec![], vec![]).into(), error, #(#payload_values),* )
494 }
495 }
496 }
497 Some((method_name, reply_on)) if reply_on == &ReplyOn::Always => {
498 let payload_values = self.payload.iter().map(|field| field.name());
499 let payload_deserialization = self.payload.emit_payload_deserialization();
500
501 quote! {
502 #sylvia ::cw_std::SubMsgResult::Err(_) => {
503 #payload_deserialization
504
505 #contract_turbofish ::new(). #method_name ((deps, env, gas_used, vec![], vec![]).into(), result, #(#payload_values),* )
506 }
507 }
508 }
509 _ => quote! {
510 #sylvia ::cw_std::SubMsgResult::Err(error) => {
511 Err(sylvia::cw_std::StdError::generic_err(error)).map_err(Into::into)
512 }
513 },
514 }
515 }
516}
517
518trait ReplyVariant<'a> {
519 fn as_variant_handlers_pair(&'a self) -> Vec<(&'a MsgVariant<'a>, &'a Ident)>;
520 fn as_data_field(&'a self) -> Option<&'a MsgField<'a>>;
521 fn validate_fields_attributes(&'a self);
522}
523
524impl<'a> ReplyVariant<'a> for MsgVariant<'a> {
525 fn as_variant_handlers_pair(&'a self) -> Vec<(&'a MsgVariant<'a>, &'a Ident)> {
526 let variant_handler_id_pair: Vec<_> = self
527 .msg_attr()
528 .handlers()
529 .iter()
530 .map(|handler| (self, handler))
531 .collect();
532
533 if variant_handler_id_pair.is_empty() {
534 return vec![(self, self.function_name())];
535 }
536
537 variant_handler_id_pair
538 }
539
540 fn as_data_field(&'a self) -> Option<&'a MsgField<'a>> {
543 let data_param = self.fields().iter().enumerate().find(|(_, field)| {
544 ParsedSylviaAttributes::new(field.attrs().iter())
545 .data
546 .is_some()
547 });
548 match data_param {
549 Some((index, field))
550 if self.msg_attr().reply_on() == ReplyOn::Success && index == 0 =>
551 {
552 Some(field)
553 }
554 Some((index, field))
555 if self.msg_attr().reply_on() == ReplyOn::Success && index != 0 =>
556 {
557 emit_error!(field.name().span(), "Wrong usage of `#[sv::data]` attribute.";
558 note = "The `#[sv::data]` attribute can only be used on the first parameter after the `ReplyCtx`."
559 );
560 None
561 }
562 Some((_, field)) if self.msg_attr().reply_on() != ReplyOn::Success => {
563 emit_error!(field.name().span(), "Wrong usage of `#[sv::data]` attribute.";
564 note = "The `#[sv::data]` attribute can only be used in `success` scenario.";
565 note = format!("Found usage in `{}` scenario.", self.msg_attr().reply_on())
566 );
567 None
568 }
569 _ => None,
570 }
571 }
572
573 fn validate_fields_attributes(&'a self) {
575 let field_attrs = self.fields().iter().flat_map(|field| field.attrs());
576 ParsedSylviaAttributes::new(field_attrs);
577 }
578}
579
580pub trait DataField {
581 fn emit_data_deserialization(&self) -> TokenStream;
582}
583
584impl DataField for MsgField<'_> {
585 fn emit_data_deserialization(&self) -> TokenStream {
586 let sylvia = crate_module();
587 let data = ParsedSylviaAttributes::new(self.attrs().iter()).data;
588 let missing_data_err = "Missing reply data field.";
589 let transaction_id = quote! {
590 env
591 .transaction
592 .as_ref()
593 .map(|tx| format!("{}", &tx.index))
594 .unwrap_or_else(|| "Missing".to_string())
595 };
596 let invalid_reply_data_err = quote! {
597 format! {"Invalid reply data at block height: {}, transaction id: {}.\nSerde error while deserializing {}",
598 env.block.height,
599 #transaction_id,
600 err}
601 };
602 let execute_data_deserialization = quote! {
603 let deserialized_data =
604 #sylvia ::cw_utils::parse_execute_response_data(data.as_slice())
605 .map_err(|err| #sylvia ::cw_std::StdError::generic_err(
606 format!("Failed deserializing protobuf data: {}", err)
607 ))?;
608 let deserialized_data = match deserialized_data.data {
609 Some(data) => #sylvia ::cw_std::from_json(&data).map_err(|err| #sylvia ::cw_std::StdError::generic_err( #invalid_reply_data_err ))?,
610 None => return Err(Into::into( #sylvia ::cw_std::StdError::generic_err( #missing_data_err ))),
611 };
612 };
613
614 let instantiate_data_deserialization = quote! {
615 let deserialized_data =
616 #sylvia ::cw_utils::parse_instantiate_response_data(data.as_slice())
617 .map_err(|err| #sylvia ::cw_std::StdError::generic_err(
618 format!("Failed deserializing protobuf data: {}", err)
619 ))?;
620 };
621
622 match data {
623 Some(data) if data.raw && data.opt => quote! {},
624 Some(data) if data.raw => quote! {
625 let data = match data {
626 Some(data) => data,
627 None => return Err(Into::into( #sylvia ::cw_std::StdError::generic_err( #missing_data_err ))),
628 };
629 },
630 Some(data) if data.instantiate && data.opt => quote! {
631 let data = match data {
632 Some(data) => {
633 #instantiate_data_deserialization
634
635 Some(deserialized_data)
636 },
637 None => None,
638 };
639 },
640 Some(data) if data.instantiate => quote! {
641 let data = match data {
642 Some(data) => {
643 #instantiate_data_deserialization
644
645 deserialized_data
646 },
647 None => return Err(Into::into( #sylvia ::cw_std::StdError::generic_err( #missing_data_err ))),
648 };
649 },
650 Some(data) if data.opt => quote! {
651 let data = match data {
652 Some(data) => {
653 #execute_data_deserialization
654
655 Some(deserialized_data)
656 },
657 None => None,
658 };
659 },
660 _ => quote! {
661 let data = match data {
662 Some(data) => {
663 #execute_data_deserialization
664
665 deserialized_data
666 },
667 None => return Err(Into::into( #sylvia ::cw_std::StdError::generic_err( #missing_data_err ))),
668 };
669 },
670 }
671 }
672}
673
674pub trait PayloadFields {
675 fn emit_payload_deserialization(&self) -> TokenStream;
676 fn emit_payload_serialization(&self) -> TokenStream;
677 fn is_payload_marked(&self) -> bool;
678}
679
680impl PayloadFields for Vec<&MsgField<'_>> {
681 fn emit_payload_deserialization(&self) -> TokenStream {
682 let sylvia = crate_module();
683 if self.is_payload_marked() {
684 let payload_value = self.first().unwrap().name();
686 return quote! {
687 let #payload_value = payload ;
688 };
689 }
690
691 let deserialized_payload_names = self.iter().map(|field| field.name());
692 quote! {
693 let ( #(#deserialized_payload_names),* ) = #sylvia ::cw_std::from_json(&payload)?;
694 }
695 }
696
697 fn emit_payload_serialization(&self) -> TokenStream {
698 let sylvia = crate_module();
699 if self.is_payload_marked() {
700 let payload_value = self.first().unwrap().name();
702 return quote! {
703 let payload = #payload_value ;
704 };
705 }
706
707 let payload_values = self.iter().map(|field| field.name());
708 quote! {
709 let payload = #sylvia ::cw_std::to_json_binary(&( #(#payload_values),* ))?;
710 }
711 }
712
713 fn is_payload_marked(&self) -> bool {
714 self.iter().any(|field| {
715 ParsedSylviaAttributes::new(field.attrs().iter())
716 .payload
717 .is_some()
718 })
719 }
720}
721
722trait AsReplyId {
724 fn as_reply_id(&self) -> Ident;
725}
726
727impl AsReplyId for Ident {
728 fn as_reply_id(&self) -> Ident {
729 let reply_id = format! {"{}_REPLY_ID", self.to_string().to_case(Case::UpperSnake)};
730 Ident::new(&reply_id, self.span())
731 }
732}