sylvia_derive/types/
msg_variant.rs1use crate::crate_module;
2use crate::fold::StripSelfPath;
3use crate::parser::attributes::VariantAttrForwarding;
4use crate::parser::check_generics::{CheckGenerics, GetPath};
5use crate::parser::variant_descs::VariantDescs;
6use crate::parser::{process_fields, MsgAttr, MsgType};
7use crate::utils::{extract_return_type, filter_wheres, SvCasing};
8use convert_case::{Case, Casing};
9use proc_macro2::TokenStream;
10use quote::{quote, ToTokens};
11use syn::fold::Fold;
12use syn::visit::Visit;
13use syn::{parse_quote, Ident, Signature, Type, WhereClause, WherePredicate};
14
15use super::msg_field::MsgField;
16
17#[derive(Debug)]
19pub struct MsgVariant<'a> {
20 name: Ident,
21 function_name: &'a Ident,
22 fields: Vec<MsgField<'a>>,
23 return_type: Option<Type>,
26 msg_attr: MsgAttr,
27 attrs_to_forward: Vec<VariantAttrForwarding>,
28}
29
30impl<'a> MsgVariant<'a> {
31 pub fn new<Generic>(
33 sig: &'a Signature,
34 generics_checker: &mut CheckGenerics<Generic>,
35 msg_attr: MsgAttr,
36 attrs_to_forward: Vec<VariantAttrForwarding>,
37 ) -> MsgVariant<'a>
38 where
39 Generic: GetPath + PartialEq,
40 {
41 let function_name = &sig.ident;
42
43 let name = function_name.to_case(Case::UpperCamel);
44 let fields = process_fields(sig, generics_checker);
45
46 let return_type = if msg_attr.msg_type() == MsgType::Query {
47 let resp_type = &msg_attr.resp_type();
48 match resp_type {
49 Some(resp_type) => {
50 let resp_type = parse_quote! { #resp_type };
51 generics_checker.visit_type(&resp_type);
52 Some(resp_type)
53 }
54 None => {
55 let return_type = extract_return_type(&sig.output);
56 let stripped_return_type = StripSelfPath.fold_path(return_type.clone());
57 generics_checker.visit_path(&stripped_return_type);
58 Some(parse_quote! { #return_type })
59 }
60 }
61 } else {
62 None
63 };
64
65 Self {
66 name,
67 function_name,
68 fields,
69 return_type,
70 msg_attr,
71 attrs_to_forward,
72 }
73 }
74
75 pub fn emit(&self) -> TokenStream {
77 let Self {
78 name,
79 fields,
80 msg_attr,
81 return_type,
82 attrs_to_forward,
83 ..
84 } = self;
85 let fields = fields.iter().map(MsgField::emit);
86 let returns_attribute = msg_attr.msg_type().emit_returns_attribute(return_type);
87 let attrs_to_forward = attrs_to_forward.iter().map(|attr| &attr.attrs);
88
89 quote! {
90 #returns_attribute
91 #( #[ #attrs_to_forward ] )*
92 #name {
93 #(#fields,)*
94 }
95 }
96 }
97
98 pub fn emit_dispatch_leg(&self) -> TokenStream {
102 let Self {
103 name,
104 fields,
105 function_name,
106 msg_attr,
107 ..
108 } = self;
109
110 let args: Vec<_> = fields
111 .iter()
112 .zip(1..)
113 .map(|(field, num)| Ident::new(&format!("field{}", num), field.name().span()))
114 .collect();
115
116 let fields = fields
117 .iter()
118 .map(MsgField::name)
119 .zip(args.clone())
120 .map(|(field, num_field)| quote!(#field : #num_field));
121
122 let method_call = msg_attr.msg_type().emit_dispatch_leg(function_name, &args);
123
124 quote! {
125 #name {
126 #(#fields,)*
127 } => #method_call
128 }
129 }
130
131 pub fn emit_variants_constructors(&self) -> TokenStream {
133 let Self { name, fields, .. } = self;
134
135 let method_name = name.to_case(Case::Snake);
136 let parameters = fields.iter().map(MsgField::emit_method_field);
137 let arguments = fields.iter().map(MsgField::name);
138
139 quote! {
140 pub fn #method_name( #(#parameters),*) -> Self {
141 Self :: #name { #(#arguments),* }
142 }
143 }
144 }
145
146 pub fn as_fields_names(&self) -> Vec<&Ident> {
147 self.fields.iter().map(MsgField::name).collect()
148 }
149
150 pub fn emit_method_field(&self) -> Vec<TokenStream> {
151 self.fields
152 .iter()
153 .map(MsgField::emit_method_field)
154 .collect()
155 }
156
157 pub fn name(&self) -> &Ident {
158 &self.name
159 }
160
161 pub fn function_name(&self) -> &Ident {
162 self.function_name
163 }
164
165 pub fn fields(&self) -> &Vec<MsgField> {
166 &self.fields
167 }
168
169 pub fn msg_attr(&self) -> &MsgAttr {
170 &self.msg_attr
171 }
172
173 pub fn return_type(&self) -> &Option<Type> {
174 &self.return_type
175 }
176}
177
178#[derive(Debug)]
179pub struct MsgVariants<'a, Generic> {
180 variants: Vec<MsgVariant<'a>>,
181 used_generics: Vec<&'a Generic>,
182 unused_generics: Vec<&'a Generic>,
183 where_predicates: Vec<&'a WherePredicate>,
184 msg_ty: MsgType,
185}
186
187impl<'a, Generic> MsgVariants<'a, Generic>
188where
189 Generic: GetPath + PartialEq + ToTokens,
190{
191 pub fn new(
192 source: VariantDescs<'a>,
193 msg_ty: MsgType,
194 all_generics: &'a [&'a Generic],
195 unfiltered_where_clause: &'a Option<WhereClause>,
196 ) -> Self {
197 let mut generics_checker = CheckGenerics::new(all_generics);
198 let variants: Vec<_> = source
199 .filter_map(|variant_desc| {
200 let msg_attr: MsgAttr = variant_desc.attr_msg()?;
201 let attrs_to_forward = variant_desc.attrs_to_forward();
202
203 if msg_attr.msg_type() != msg_ty {
204 return None;
205 }
206
207 Some(MsgVariant::new(
208 variant_desc.into_sig(),
209 &mut generics_checker,
210 msg_attr,
211 attrs_to_forward,
212 ))
213 })
214 .collect();
215
216 let (used_generics, unused_generics) = generics_checker.used_unused();
217 let where_predicates = filter_wheres(unfiltered_where_clause, all_generics, &used_generics);
218
219 Self {
220 variants,
221 used_generics,
222 unused_generics,
223 where_predicates,
224 msg_ty,
225 }
226 }
227
228 pub fn where_clause(&self) -> Option<WhereClause> {
229 let where_predicates = &self.where_predicates;
230 if !where_predicates.is_empty() {
231 Some(parse_quote! { where #(#where_predicates),* })
232 } else {
233 None
234 }
235 }
236
237 pub fn variants(&self) -> impl Iterator<Item = &MsgVariant> {
238 self.variants.iter()
239 }
240
241 pub fn used_generics(&self) -> &Vec<&'a Generic> {
242 &self.used_generics
243 }
244
245 pub fn unused_generics(&self) -> &Vec<&'a Generic> {
246 &self.unused_generics
247 }
248
249 pub fn msg_ty(&self) -> MsgType {
250 self.msg_ty
251 }
252
253 pub fn emit_phantom_match_arm(&self) -> TokenStream {
254 let sylvia = crate_module();
255 let Self { used_generics, .. } = self;
256 if used_generics.is_empty() {
257 return quote! {};
258 }
259 quote! {
260 _Phantom(_) => Err(#sylvia ::cw_std::StdError::generic_err("Phantom message should not be constructed.")).map_err(Into::into),
261 }
262 }
263
264 pub fn emit_dispatch_legs(&self) -> impl Iterator<Item = TokenStream> + '_ {
265 self.variants
266 .iter()
267 .map(|variant| variant.emit_dispatch_leg())
268 }
269
270 pub fn as_names_snake_cased(&self) -> Vec<String> {
271 self.variants
272 .iter()
273 .map(|variant| variant.name.to_string().to_case(Case::Snake))
274 .collect()
275 }
276
277 pub fn emit_constructors(&self) -> impl Iterator<Item = TokenStream> + '_ {
278 self.variants
279 .iter()
280 .map(MsgVariant::emit_variants_constructors)
281 }
282
283 pub fn emit(&self) -> impl Iterator<Item = TokenStream> + '_ {
284 self.variants.iter().map(MsgVariant::emit)
285 }
286
287 pub fn get_only_variant(&self) -> Option<&MsgVariant> {
288 self.variants.first()
289 }
290
291 pub fn emit_phantom_variant(&self) -> TokenStream {
292 let Self {
293 msg_ty,
294 used_generics,
295 ..
296 } = self;
297
298 if used_generics.is_empty() {
299 return quote! {};
300 }
301
302 let return_attr = match msg_ty {
303 MsgType::Query => quote! { #[returns((#(#used_generics,)*))] },
304 _ => quote! {},
305 };
306
307 quote! {
308 #[serde(skip)]
309 #return_attr
310 _Phantom(std::marker::PhantomData<( #(#used_generics,)* )>),
311 }
312 }
313}