1use proc_macro::TokenStream;
2use proc_macro2::{Span, TokenStream as TokenStream2};
3use proc_macro_error::{proc_macro_error, SpanRange};
4use quote::{quote, ToTokens};
5use syn::{
6 parse_macro_input, parse_quote, punctuated::Punctuated, Attribute, AttributeArgs, DeriveInput,
7 FnArg, Ident, Item, ItemFn, Signature, Token,
8};
9
/// Evaluates an expression producing a `Result` and yields its `Ok` value.
///
/// On `Err(error)` it executes `return error.to_compile_error().into()`, i.e. it
/// returns early from the enclosing function *or closure*, so inside iterator
/// closures only that closure's output is replaced by the error tokens.
macro_rules! unwrap_or_compile_error {
    ($result:expr) => {
        match $result {
            Err(error) => return error.to_compile_error().into(),
            Ok(value) => value,
        }
    };
}
19
20fn default_tarantool_crate_path() -> syn::Path {
21 parse_quote! { tarantool }
22}
23
24mod test;
25
26#[proc_macro_attribute]
30pub fn test(attr: TokenStream, item: TokenStream) -> TokenStream {
31 test::impl_macro_attribute(attr, item)
32}
33
34mod msgpack {
35 use darling::FromDeriveInput;
36 use proc_macro2::TokenStream;
37 use proc_macro_error::{abort, SpanRange};
38 use quote::{format_ident, quote, quote_spanned, ToTokens};
39 use syn::{
40 parse_quote, spanned::Spanned, Data, Field, Fields, FieldsNamed, FieldsUnnamed,
41 GenericParam, Generics, Ident, Index, Path, Type,
42 };
43
    /// Container-level options parsed by darling from `#[encode(...)]` on the
    /// deriving type. Every option is optional (`default`).
    #[derive(Default, FromDeriveInput)]
    #[darling(attributes(encode), default)]
    pub struct Args {
        /// `#[encode(as_map)]`: encode/decode a struct with named fields as a
        /// MessagePack map keyed by field name instead of an array. This is only the
        /// default; the generated code lets `context.struct_style()` force either
        /// representation. Rejected for tuple structs and enums.
        pub as_map: bool,
        /// `#[encode(tarantool = "path")]`: path to the tarantool crate used in the
        /// generated code. When absent, `tarantool` is used.
        pub tarantool: Option<String>,
        /// `#[encode(allow_array_optionals)]`: when decoding, a trailing `Option`
        /// field whose decode fails because of a missing MessagePack marker is left
        /// `None` instead of being an error. With this option, optional fields must
        /// come after all required ones.
        pub allow_array_optionals: bool,
        /// `#[encode(untagged)]`: enums only. Variants are encoded without their
        /// name, and decoding tries every variant in declaration order.
        pub untagged: bool,
    }
56
57 pub fn add_trait_bounds(mut generics: Generics, tarantool_crate: &Path) -> Generics {
58 for param in &mut generics.params {
59 if let GenericParam::Type(ref mut type_param) = *param {
60 type_param
61 .bounds
62 .push(parse_quote!(#tarantool_crate::msgpack::Encode));
63 }
64 }
65 generics
66 }
67
68 trait TypeExt {
69 fn is_option(&self) -> bool;
70 }
71
72 impl TypeExt for Type {
73 fn is_option(&self) -> bool {
74 if let Type::Path(ref typepath) = self {
75 typepath
76 .path
77 .segments
78 .last()
79 .map(|segment| segment.ident == "Option")
80 .unwrap_or(false)
81 } else {
82 false
83 }
84 }
85 }
86
87 enum FieldAttr {
89 Raw,
91 Map,
93 Vec,
95 }
96
97 impl FieldAttr {
98 #[inline]
101 fn from_field(field: &Field) -> Result<Option<Self>, syn::Error> {
102 let attrs = &field.attrs;
103
104 let mut encode_attr = None;
105
106 for attr in attrs.iter().filter(|attr| attr.path.is_ident("encode")) {
107 if encode_attr.is_some() {
108 return Err(syn::Error::new(
109 attr.span(),
110 "multiple encoding types are not allowed",
111 ));
112 }
113
114 encode_attr = Some(attr);
115 }
116
117 match encode_attr {
118 Some(attr) => attr.parse_args_with(|input: syn::parse::ParseStream| {
119 if input.is_empty() {
120 return Err(syn::Error::new(
121 input.span(),
122 "empty encoding type is not allowed",
123 ));
124 }
125
126 let ident: Ident = input.parse()?;
127
128 if !input.is_empty() {
129 return Err(syn::Error::new(
130 ident.span(),
131 "multiple encoding types are not allowed",
132 ));
133 }
134
135 if ident == "as_raw" {
136 let mut field_type_name = proc_macro2::TokenStream::new();
137 field.ty.to_tokens(&mut field_type_name);
138 if field_type_name.to_string() != "Vec < u8 >" {
139 Err(syn::Error::new(
140 ident.span(),
141 "only `Vec<u8>` is supported for `as_raw`",
142 ))
143 } else {
144 Ok(Some(Self::Raw))
145 }
146 } else if ident == "as_map" {
147 Ok(Some(Self::Map))
148 } else if ident == "as_vec" {
149 Ok(Some(Self::Vec))
150 } else {
151 Err(syn::Error::new(ident.span(), "unknown encoding type"))
152 }
153 }),
154 None => Ok(None),
155 }
156 }
157 }
158
159 fn encode_named_fields(
160 fields: &FieldsNamed,
161 tarantool_crate: &Path,
162 add_self: bool,
163 ) -> proc_macro2::TokenStream {
164 fields
165 .named
166 .iter()
167 .flat_map(|f| {
168 let field_name = f.ident.as_ref().expect("only named fields here");
169 let field_repr = format_ident!("{}", field_name).to_string();
170 let field_attr = unwrap_or_compile_error!(FieldAttr::from_field(f));
171
172 let s = if add_self {
173 quote! {&self.}
174 } else {
175 quote! {}
176 };
177
178 let write_key = quote_spanned! {f.span()=>
179 if as_map {
180 #tarantool_crate::msgpack::rmp::encode::write_str(w, #field_repr)?;
181 }
182 };
183 if let Some(attr) = field_attr {
184 match attr {
185 FieldAttr::Raw => quote_spanned! {f.span()=>
186 #write_key
187 w.write_all(#s #field_name)?;
188 },
189 FieldAttr::Map => {
191 syn::Error::new(f.span(), "`as_map` is not currently supported")
192 .to_compile_error()
193 }
194 FieldAttr::Vec => {
195 syn::Error::new(f.span(), "`as_vec` is not currently supported")
196 .to_compile_error()
197 }
198 }
199 } else {
200 quote_spanned! {f.span()=>
201 #write_key
202 #tarantool_crate::msgpack::Encode::encode(#s #field_name, w, context)?;
203 }
204 }
205 })
206 .collect()
207 }
208
209 fn encode_unnamed_fields(
210 fields: &FieldsUnnamed,
211 tarantool_crate: &Path,
212 ) -> proc_macro2::TokenStream {
213 fields
214 .unnamed
215 .iter()
216 .enumerate()
217 .flat_map(|(i, f)| {
218 let index = Index::from(i);
219 let field_attr = unwrap_or_compile_error!(FieldAttr::from_field(f));
220
221 if let Some(field) = field_attr {
222 match field {
223 FieldAttr::Raw => quote_spanned! {f.span()=>
224 w.write_all(&self.#index)?;
225 },
226 FieldAttr::Map => {
228 syn::Error::new(f.span(), "`as_map` is not currently supported")
229 .to_compile_error()
230 }
231 FieldAttr::Vec => {
232 syn::Error::new(f.span(), "`as_vec` is not currently supported")
233 .to_compile_error()
234 }
235 }
236 } else {
237 quote_spanned! {f.span()=>
238 #tarantool_crate::msgpack::Encode::encode(&self.#index, w, context)?;
239 }
240 }
241 })
242 .collect()
243 }
244
245 pub fn encode_fields(
246 data: &Data,
247 tarantool_crate: &Path,
248 attrs_span: impl Fn() -> SpanRange,
249 args: &Args,
250 ) -> proc_macro2::TokenStream {
251 let as_map = args.as_map;
252 let is_untagged = args.untagged;
253 match *data {
254 Data::Struct(ref data) => {
255 if is_untagged {
256 abort!(
257 attrs_span(),
258 "untagged encode representation is allowed only for enums"
259 );
260 }
261 match data.fields {
262 Fields::Named(ref fields) => {
263 let field_count = fields.named.len() as u32;
264 let fields = encode_named_fields(fields, tarantool_crate, true);
265 quote! {
266 let as_map = match context.struct_style() {
267 StructStyle::Default => #as_map,
268 StructStyle::ForceAsMap => true,
269 StructStyle::ForceAsArray => false,
270 };
271 if as_map {
272 #tarantool_crate::msgpack::rmp::encode::write_map_len(w, #field_count)?;
273 } else {
274 #tarantool_crate::msgpack::rmp::encode::write_array_len(w, #field_count)?;
275 }
276 #fields
277 }
278 }
279 Fields::Unnamed(ref fields) => {
280 if as_map {
281 abort!(
282 attrs_span(),
283 "`as_map` attribute can be specified only for structs with named fields"
284 );
285 }
286 let field_count = fields.unnamed.len() as u32;
287 let fields = encode_unnamed_fields(fields, tarantool_crate);
288 quote! {
289 #tarantool_crate::msgpack::rmp::encode::write_array_len(w, #field_count)?;
290 #fields
291 }
292 }
293 Fields::Unit => {
294 quote!(#tarantool_crate::msgpack::Encode::encode(&(), w, context)?;)
295 }
296 }
297 }
298 Data::Enum(ref variants) => {
299 if as_map {
300 abort!(
301 attrs_span(),
302 "`as_map` attribute can be specified only for structs"
303 );
304 }
305 let variants: proc_macro2::TokenStream = variants
306 .variants
307 .iter()
308 .flat_map(|variant| {
309 let variant_name = &variant.ident;
310 let variant_repr = format_ident!("{}", variant_name).to_string();
311 match variant.fields {
312 Fields::Named(ref fields) => {
313 let field_count = fields.named.len() as u32;
314 let field_names = fields.named.iter().map(|field| field.ident.clone());
315 let fields = encode_named_fields(fields, tarantool_crate, false);
316 if is_untagged {
318 quote! {
319 Self::#variant_name { #(#field_names),*} => {
320 #tarantool_crate::msgpack::rmp::encode::write_array_len(w, #field_count)?;
321 let as_map = false;
322 #fields
323 }
324 }
325 } else {
326 quote! {
327 Self::#variant_name { #(#field_names),*} => {
328 #tarantool_crate::msgpack::rmp::encode::write_str(w, #variant_repr)?;
329 #tarantool_crate::msgpack::rmp::encode::write_array_len(w, #field_count)?;
330 let as_map = false;
331 #fields
332 }
333 }
334 }
335 },
336 Fields::Unnamed(ref fields) => {
337 let field_count = fields.unnamed.len() as u32;
338 let field_names = fields.unnamed.iter().enumerate().map(|(i, _)| format_ident!("_field_{}", i));
339 let fields: proc_macro2::TokenStream = field_names.clone()
340 .flat_map(|field_name| quote! {
341 #tarantool_crate::msgpack::Encode::encode(#field_name, w, context)?;
342 })
343 .collect();
344 if is_untagged {
345 quote! {
346 Self::#variant_name ( #(#field_names),*) => {
347 #tarantool_crate::msgpack::rmp::encode::write_array_len(w, #field_count)?;
348 #fields
349 }
350 }
351 } else {
352 quote! {
353 Self::#variant_name ( #(#field_names),*) => {
354 #tarantool_crate::msgpack::rmp::encode::write_str(w, #variant_repr)?;
355 #tarantool_crate::msgpack::rmp::encode::write_array_len(w, #field_count)?;
356 #fields
357 }
358 }
359 }
360 }
361 Fields::Unit => {
362 if is_untagged {
363 quote! {
364 Self::#variant_name => #tarantool_crate::msgpack::Encode::encode(&(), w, context)?,
365 }
366 } else {
367 quote! {
368 Self::#variant_name => {
369 #tarantool_crate::msgpack::rmp::encode::write_str(w, #variant_repr)?;
370 #tarantool_crate::msgpack::Encode::encode(&(), w, context)?;
371 }
372 }
373 }
374 },
375 }
376 })
377 .collect();
378 if is_untagged {
379 quote! {
380 match self {
381 #variants
382 }
383 }
384 } else {
385 quote! {
386 #tarantool_crate::msgpack::rmp::encode::write_map_len(w, 1)?;
387 match self {
388 #variants
389 }
390 }
391 }
392 }
393 Data::Union(_) => unimplemented!(),
394 }
395 }
396
397 fn decode_named_fields(
398 fields: &FieldsNamed,
399 tarantool_crate: &Path,
400 enum_variant: Option<&syn::Ident>,
401 args: &Args,
402 ) -> TokenStream {
403 let allow_array_optionals = args.allow_array_optionals;
404
405 let mut var_names = Vec::with_capacity(fields.named.len());
406 let mut met_option = false;
407 let fields_amount = fields.named.len();
408 let mut fields_passed = fields_amount;
409 let code: TokenStream = fields
410 .named
411 .iter()
412 .map(|f| {
413 if f.ty.is_option() {
414 met_option = true;
415 fields_passed -= 1;
416 decode_named_optional_field(f, tarantool_crate, &mut var_names, allow_array_optionals, fields_amount, fields_passed)
417 } else {
418 if met_option && allow_array_optionals {
419 return syn::Error::new(
420 f.span(),
421 "optional fields must be the last in the parameter list if allow_array_optionals is enabled",
422 )
423 .to_compile_error();
424 }
425 fields_passed -= 1;
426 decode_named_required_field(f, tarantool_crate, &mut var_names)
427 }
428 })
429 .collect();
430 let field_names = fields.named.iter().map(|f| &f.ident);
431 let enum_variant = if let Some(variant) = enum_variant {
432 quote! { ::#variant }
433 } else {
434 quote! {}
435 };
436 quote! {
437 #code
438 Ok(Self #enum_variant {
439 #(#field_names: #var_names),*
440 })
441 }
442 }
443
444 #[inline]
445 fn decode_named_optional_field(
446 field: &Field,
447 tarantool_crate: &Path,
448 names: &mut Vec<Ident>,
449 allow_array_optionals: bool,
450 fields_amount: usize,
451 fields_passed: usize,
452 ) -> TokenStream {
453 let field_type = &field.ty;
454 let field_attr = unwrap_or_compile_error!(FieldAttr::from_field(field));
455
456 let field_ident = field.ident.as_ref().expect("only named fields here");
457 let field_repr = format_ident!("{}", field_ident).to_string();
458 let field_name = proc_macro2::Literal::byte_string(field_repr.as_bytes());
459 let var_name = format_ident!("_field_{}", field_ident);
460
461 let read_key = quote_spanned! {field.span()=>
462 if as_map {
463 use #tarantool_crate::msgpack::str_bounds;
464
465 let (byte_len, field_name_len_spaced) = str_bounds(r)
466 .map_err(|err| #tarantool_crate::msgpack::DecodeError::new::<Self>(err).with_part("field name"))?;
467 let decoded_field_name = r.get(byte_len..field_name_len_spaced).unwrap();
468 if decoded_field_name != #field_name {
469 is_none = true;
470 } else {
471 let len = rmp::decode::read_str_len(r).unwrap();
472 *r = &r[(len as usize)..]; }
474 }
475 };
476
477 let out = match field_attr {
479 Some(FieldAttr::Map) => unimplemented!("`as_map` is not currently supported"),
480 Some(FieldAttr::Vec) => unimplemented!("`as_vec` is not currently supported"),
481 Some(FieldAttr::Raw) => quote_spanned! {field.span()=>
482 let mut #var_name: #field_type = None;
483 let mut is_none = false;
484
485 #read_key
486 if !is_none {
487 #var_name = Some(#tarantool_crate::msgpack::preserve_read(r).expect("only valid msgpack here"));
488 }
489 },
490 None => quote_spanned! {field.span()=>
491 let mut #var_name: #field_type = None;
492 let mut is_none = false;
493
494 #read_key
495 if !is_none {
496 match #tarantool_crate::msgpack::Decode::decode(r, context) {
497 Ok(val) => #var_name = Some(val),
498 Err(err) => {
499 let markered = err.source.get(err.source.len() - 33..).unwrap_or("") == "failed to read MessagePack marker";
500 let nulled = if err.part.is_some() {
501 err.part.as_ref().expect("Can't fail after a conditional check") == "got Null"
502 } else {
503 false
504 };
505
506 if !nulled && !#allow_array_optionals && !as_map {
507 let message = format!("not enough fields, expected {}, got {} (note: optional fields must be explicitly null unless `allow_array_optionals` attribute is passed)", #fields_amount, #fields_passed);
508 Err(#tarantool_crate::msgpack::DecodeError::new::<Self>(message))?;
509 } else if !nulled && !markered && #allow_array_optionals {
510 Err(err)?;
511 }
512 },
513 }
514 }
515 },
516 };
517
518 names.push(var_name);
519 out
520 }
521
522 #[inline]
523 fn decode_named_required_field(
524 field: &Field,
525 tarantool_crate: &Path,
526 names: &mut Vec<Ident>,
527 ) -> TokenStream {
528 let field_attr = unwrap_or_compile_error!(FieldAttr::from_field(field));
529
530 let field_ident = field.ident.as_ref().expect("only named fields here");
531 let field_repr = format_ident!("{}", field_ident).to_string();
532 let field_name = proc_macro2::Literal::byte_string(field_repr.as_bytes());
533 let var_name = format_ident!("_field_{}", field_ident);
534
535 let read_key = quote_spanned! {field.span()=>
536 if as_map {
537 let len = rmp::decode::read_str_len(r)
538 .map_err(|err| #tarantool_crate::msgpack::DecodeError::from_vre::<Self>(err).with_part("field name"))?;
539 let decoded_field_name = r.get(0..(len as usize))
540 .ok_or_else(|| #tarantool_crate::msgpack::DecodeError::new::<Self>("not enough data").with_part("field name"))?;
541 *r = &r[(len as usize)..]; if decoded_field_name != #field_name {
543 let field_name = String::from_utf8(#field_name.to_vec()).expect("is valid utf8");
544 let err = if let Ok(decoded_field_name) = String::from_utf8(decoded_field_name.to_vec()) {
545 format!("expected field {}, got {}", field_name, decoded_field_name)
546 } else {
547 format!("expected field {}, got invalid utf8 {:?}", field_name, decoded_field_name)
548 };
549 return Err(#tarantool_crate::msgpack::DecodeError::new::<Self>(err));
550 }
551 }
552 };
553
554 let out = if let Some(FieldAttr::Raw) = field_attr {
556 quote_spanned! {field.span()=>
557 #read_key
558 let #var_name = #tarantool_crate::msgpack::preserve_read(r).expect("only valid msgpack here");
559 }
560 } else if let Some(FieldAttr::Map) = field_attr {
561 unimplemented!("`as_map` is not currently supported");
562 } else if let Some(FieldAttr::Vec) = field_attr {
563 unimplemented!("`as_vec` is not currently supported");
564 } else {
565 quote_spanned! {field.span()=>
566 #read_key
567 let #var_name = #tarantool_crate::msgpack::Decode::decode(r, context)
568 .map_err(|err| #tarantool_crate::msgpack::DecodeError::new::<Self>(err).with_part(format!("field {}", stringify!(#field_ident))))?;
569 }
570 };
571
572 names.push(var_name);
573 out
574 }
575
576 fn decode_unnamed_fields(
577 fields: &FieldsUnnamed,
578 tarantool_crate: &Path,
579 enum_variant: Option<&syn::Ident>,
580 args: &Args,
581 ) -> proc_macro2::TokenStream {
582 let allow_array_optionals = args.allow_array_optionals;
583
584 let mut var_names = Vec::with_capacity(fields.unnamed.len());
585 let mut met_option = false;
586 let code: proc_macro2::TokenStream = fields
587 .unnamed
588 .iter()
589 .enumerate()
590 .map(|(i, f)| {
591 let is_option = f.ty.is_option();
592 if is_option {
593 met_option = true;
594 decode_unnamed_optional_field(f, i, tarantool_crate, &mut var_names)
595 } else if met_option && allow_array_optionals {
596 return syn::Error::new(
597 f.span(),
598 "optional fields must be the last in the parameter list with `allow_array_optionals` attribute",
599 )
600 .to_compile_error();
601 } else {
602 decode_unnamed_required_field(f, i, tarantool_crate, &mut var_names)
603 }
604 })
605 .collect();
606 let enum_variant = if let Some(variant) = enum_variant {
607 quote! { ::#variant }
608 } else {
609 quote! {}
610 };
611 quote! {
612 #code
613 Ok(Self #enum_variant (
614 #(#var_names),*
615 ))
616 }
617 }
618
619 fn decode_unnamed_optional_field(
620 field: &Field,
621 index: usize,
622 tarantool_crate: &Path,
623 names: &mut Vec<Ident>,
624 ) -> TokenStream {
625 let field_attr = unwrap_or_compile_error!(FieldAttr::from_field(field));
626 let field_type = &field.ty;
627
628 let field_index = Index::from(index);
629 let var_name = quote::format_ident!("_field_{}", field_index);
630
631 let out = match field_attr {
632 Some(FieldAttr::Map) => unimplemented!("`as_map` is not currently supported"),
633 Some(FieldAttr::Vec) => unimplemented!("`as_vec` is not currently supported"),
634 Some(FieldAttr::Raw) => quote_spanned! {field.span()=>
635 let #var_name = #tarantool_crate::msgpack::preserve_read(r).expect("only valid msgpack here");
636 },
637 None => quote_spanned! {field.span()=>
638 let mut #var_name: #field_type = None;
639 match #tarantool_crate::msgpack::Decode::decode(r, context) {
640 Ok(val) => #var_name = Some(val),
641 Err(err) => {
642 let markered = err.source.get(err.source.len() - 33..).unwrap_or("")== "failed to read MessagePack marker";
643 let nulled = if err.part.is_some() {
644 err.part.as_ref().expect("Can't fail after a conditional check") == "got Null"
645 } else {
646 false
647 };
648
649 if !nulled && !markered {
650 Err(#tarantool_crate::msgpack::DecodeError::new::<Self>(err).with_part(format!("{}", stringify!(#field_index))))?;
651 }
652 },
653 }
654 },
655 };
656
657 names.push(var_name);
658 out
659 }
660
661 fn decode_unnamed_required_field(
662 field: &Field,
663 index: usize,
664 tarantool_crate: &Path,
665 names: &mut Vec<Ident>,
666 ) -> TokenStream {
667 let field_attr = unwrap_or_compile_error!(FieldAttr::from_field(field));
668
669 let field_index = Index::from(index);
670 let var_name = quote::format_ident!("_field_{}", field_index);
671
672 let out = if let Some(FieldAttr::Raw) = field_attr {
673 quote_spanned! {field.span()=>
674 let #var_name = #tarantool_crate::msgpack::preserve_read(r).expect("only valid msgpack here");
675 }
676 } else if let Some(FieldAttr::Map) = field_attr {
677 unimplemented!("`as_map` is not currently supported");
678 } else if let Some(FieldAttr::Vec) = field_attr {
679 unimplemented!("`as_vec` is not currently supported");
680 } else {
681 quote_spanned! {field.span()=>
682 let #var_name = #tarantool_crate::msgpack::Decode::decode(r, context)
683 .map_err(|err| #tarantool_crate::msgpack::DecodeError::new::<Self>(err).with_part(format!("field {}", #index)))?;
684 }
685 };
686
687 names.push(var_name);
688 out
689 }
690
691 pub fn decode_fields(
692 data: &Data,
693 tarantool_crate: &Path,
694 attrs_span: impl Fn() -> SpanRange,
695 args: &Args,
696 ) -> TokenStream {
697 let as_map = args.as_map;
698 let is_untagged = args.untagged;
699
700 if is_untagged {
701 return decode_untagged(data, tarantool_crate, attrs_span);
702 }
703
704 match *data {
705 Data::Struct(ref data) => {
706 match data.fields {
707 Fields::Named(ref fields) => {
708 let first_field_name = fields
709 .named
710 .first()
711 .expect("not a unit struct")
712 .ident
713 .as_ref()
714 .expect("not an unnamed struct")
715 .to_string();
716 let fields = decode_named_fields(fields, tarantool_crate, None, args);
717 quote! {
718 let as_map = match context.struct_style() {
719 StructStyle::Default => #as_map,
720 StructStyle::ForceAsMap => true,
721 StructStyle::ForceAsArray => false,
722 };
723 if as_map {
725 #tarantool_crate::msgpack::rmp::decode::read_map_len(r)
726 .map_err(|err| #tarantool_crate::msgpack::DecodeError::from_vre::<Self>(err))?;
727 } else {
728 #tarantool_crate::msgpack::rmp::decode::read_array_len(r)
729 .map_err(|err| #tarantool_crate::msgpack::DecodeError::from_vre_with_field::<Self>(err, #first_field_name))?;
730 }
731 #fields
732 }
733 }
734 Fields::Unnamed(ref fields) => {
735 if as_map {
736 abort!(
737 attrs_span(),
738 "`as_map` attribute can be specified only for structs with named fields"
739 );
740 }
741
742 let mut option_key = TokenStream::new();
743 if fields.unnamed.len() == 1 {
744 let first_field = fields.unnamed.first().expect("len is sufficient");
745 let is_option = first_field.ty.is_option();
746 if is_option {
747 option_key = quote! {
748 if r.is_empty() {
749 return Ok(Self(None));
750 }
751 };
752 }
753 }
754
755 let fields = decode_unnamed_fields(fields, tarantool_crate, None, args);
756 quote! {
757 #option_key
758 #tarantool_crate::msgpack::rmp::decode::read_array_len(r)
759 .map_err(|err| #tarantool_crate::msgpack::DecodeError::from_vre::<Self>(err))?;
760 #fields
761 }
762 }
763 Fields::Unit => {
764 quote! {
765 let () = #tarantool_crate::msgpack::Decode::decode(r, context)?;
766 Ok(Self)
767 }
768 }
769 }
770 }
771 Data::Enum(ref variants) => {
772 if as_map {
773 abort!(
774 attrs_span(),
775 "`as_map` attribute can be specified only for structs"
776 );
777 }
778 let mut variant_reprs = Vec::new();
779 let variants: proc_macro2::TokenStream = variants
780 .variants
781 .iter()
782 .flat_map(|variant| {
783 let variant_ident = &variant.ident;
784 let variant_repr = format_ident!("{}", variant_ident).to_string();
785 variant_reprs.push(variant_repr.clone());
786 let variant_repr = proc_macro2::Literal::byte_string(variant_repr.as_bytes());
787
788 match variant.fields {
789 Fields::Named(ref fields) => {
790 let fields = decode_named_fields(fields, tarantool_crate, Some(&variant.ident), args);
791 quote! {
793 #variant_repr => {
794 #tarantool_crate::msgpack::rmp::decode::read_array_len(r)
795 .map_err(|err| #tarantool_crate::msgpack::DecodeError::from_vre::<Self>(err))?;
796 let as_map = false;
797 #fields
798 }
799 }
800 },
801 Fields::Unnamed(ref fields) => {
802 let fields = decode_unnamed_fields(fields, tarantool_crate, Some(&variant.ident), args);
803 quote! {
804 #variant_repr => {
805 #tarantool_crate::msgpack::rmp::decode::read_array_len(r)
806 .map_err(|err| #tarantool_crate::msgpack::DecodeError::from_vre::<Self>(err))?;
807 let as_map = false;
808 #fields
809 }
810 }
811 }
812 Fields::Unit => {
813 quote! {
814 #variant_repr => {
815 let () = #tarantool_crate::msgpack::Decode::decode(r, context)
816 .map_err(|err| #tarantool_crate::msgpack::DecodeError::new::<Self>(err))?;
817 Ok(Self::#variant_ident)
818 }
819 }
820 },
821 }
822 })
823 .collect();
824 quote! {
825 #tarantool_crate::msgpack::rmp::decode::read_map_len(r)
827 .map_err(|err| #tarantool_crate::msgpack::DecodeError::from_vre::<Self>(err))?;
828 let len = rmp::decode::read_str_len(r)
829 .map_err(|err| #tarantool_crate::msgpack::DecodeError::from_vre::<Self>(err).with_part("variant name"))?;
830 let variant_name = r.get(0..(len as usize))
831 .ok_or_else(|| #tarantool_crate::msgpack::DecodeError::new::<Self>("not enough data").with_part("variant name"))?;
832 *r = &r[(len as usize)..]; match variant_name {
834 #variants
835 other => {
836 let err = if let Ok(other) = String::from_utf8(other.to_vec()) {
837 format!("enum variant {} does not exist", other)
838 } else {
839 format!("enum variant {:?} is invalid utf8", other)
840 };
841 return Err(#tarantool_crate::msgpack::DecodeError::new::<Self>(err));
842 }
843 }
844 }
845 }
846 Data::Union(_) => unimplemented!(),
847 }
848 }
849
850 pub fn decode_untagged(
851 data: &Data,
852 tarantool_crate: &Path,
853 attrs_span: impl Fn() -> SpanRange,
854 ) -> TokenStream {
855 let out = match *data {
856 Data::Struct(_) => abort!(
857 attrs_span(),
858 "untagged decode representation is allowed only for enums"
859 ),
860 Data::Union(_) => unimplemented!(),
861 Data::Enum(ref variants) => {
862 let variants = variants.variants.iter();
863 let variants_amount = variants.len();
864 if variants_amount == 0 {
865 abort!(
866 attrs_span(),
867 "deserialization of enum with no variants is not possible"
868 );
869 }
870
871 variants
872 .flat_map(|variant| {
873 let variant_ident = &variant.ident;
874
875 match variant.fields {
876 Fields::Unit => {
877 quote! {
878 let mut r_try = *r;
880 let mut try_unit = || -> Result<(), #tarantool_crate::msgpack::DecodeError> {
881 let () = #tarantool_crate::msgpack::Decode::decode(&mut r_try, context)
882 .map_err(|err| #tarantool_crate::msgpack::DecodeError::new::<Self>(err))?;
883 checker = Some(Self::#variant_ident);
884 Ok(())
885 };
886
887 if try_unit().is_ok() {
888 *r = r_try;
889 return Result::<Self, #tarantool_crate::msgpack::DecodeError>::Ok(checker.unwrap());
890 }
891 }
892 },
893 Fields::Unnamed(ref fields) => {
894 let fields = &fields.unnamed;
895 let fields_amount = fields.len();
896 let mut var_names = Vec::with_capacity(fields.len());
897 let code: TokenStream = fields
898 .iter()
899 .enumerate()
900 .map(|(index, field)| {
901 let field_index = Index::from(index);
902 let var_name = quote::format_ident!("_field_{}", field_index);
903 let var_type = &field.ty;
904
905 let out = quote_spanned! {field.span()=>
906 let #var_name: #var_type = #tarantool_crate::msgpack::Decode::decode(&mut r_try, context)
907 .map_err(|err| #tarantool_crate::msgpack::DecodeError::new::<Self>(err).with_part(format!("field {}", #index)))?;
908 };
909
910 var_names.push(var_name);
911 out
912 })
913 .collect();
914 quote! {
915 let mut r_try = *r;
917 let mut try_unnamed = || -> Result<(), #tarantool_crate::msgpack::DecodeError> {
918 let amount = #tarantool_crate::msgpack::rmp::decode::read_array_len(&mut r_try)
919 .map_err(|err| #tarantool_crate::msgpack::DecodeError::new::<Self>(err))?;
920 if amount as usize != #fields_amount {
921 Err(#tarantool_crate::msgpack::DecodeError::new::<Self>("non-equal amount of type fields"))?;
922 }
923 #code
924 checker = Some(Self::#variant_ident(
925 #(#var_names),*
926 ));
927 Ok(())
928 };
929
930 if try_unnamed().is_ok() {
931 *r = r_try;
932 return Result::<Self, #tarantool_crate::msgpack::DecodeError>::Ok(checker.unwrap());
933 }
934 }
935 },
936 Fields::Named(ref fields) => {
937 let fields = &fields.named;
938 let fields_amount = fields.len();
939 let field_names = fields.iter().map(|field| &field.ident);
940 let mut var_names = Vec::with_capacity(fields.len());
941 let code: TokenStream = fields
942 .iter()
943 .map(|field| {
944 let field_ident = field.ident.as_ref().expect("only named fields here");
945 let var_name = format_ident!("_field_{}", field_ident);
946 let var_type = &field.ty;
947
948 let out = quote_spanned! {field.span()=>
949 let #var_name: #var_type = #tarantool_crate::msgpack::Decode::decode(&mut r_try, context)
950 .map_err(|err| #tarantool_crate::msgpack::DecodeError::new::<Self>(err).with_part(format!("field {}", stringify!(#field_ident))))?;
951 };
952
953 var_names.push(var_name);
954 out
955 })
956 .collect();
957 quote! {
958 let mut r_try = *r;
960 let mut try_named = || -> Result<(), #tarantool_crate::msgpack::DecodeError> {
961 let amount = #tarantool_crate::msgpack::rmp::decode::read_array_len(&mut r_try)
962 .map_err(|err| #tarantool_crate::msgpack::DecodeError::new::<Self>(err))?;
963 if amount as usize != #fields_amount {
964 Err(#tarantool_crate::msgpack::DecodeError::new::<Self>("non-equal amount of type fields"))?;
965 }
966 #code
967 checker = Some(Self::#variant_ident {
968 #(#field_names: #var_names),*
969 });
970 Ok(())
971 };
972
973 if try_named().is_ok() {
974 *r = r_try;
975 return Result::<Self, #tarantool_crate::msgpack::DecodeError>::Ok(checker.unwrap());
976 }
977 }
978 },
979 }
980 })
981 .collect::<TokenStream>()
982 }
983 };
984 quote! {
985 let mut checker: Option<Self> = None;
986 #out
987 Result::<Self, #tarantool_crate::msgpack::DecodeError>::Err(#tarantool_crate::msgpack::DecodeError::new::<Self>("received stream didn't match any enum variant"))
988 }
989 }
990}
991
992fn attrs_span<'a>(attrs: impl IntoIterator<Item = &'a Attribute>) -> SpanRange {
994 SpanRange::from_tokens(
995 &attrs
996 .into_iter()
997 .flat_map(ToTokens::into_token_stream)
998 .collect::<TokenStream2>(),
999 )
1000}
1001
1002#[inline]
1005fn collect_lifetimes(generics: &syn::Generics) -> Punctuated<syn::Lifetime, Token![+]> {
1006 let mut lifetimes = Punctuated::new();
1007 let mut unique_lifetimes = std::collections::HashSet::new();
1008
1009 for param in &generics.params {
1010 if let syn::GenericParam::Lifetime(lifetime_def) = param {
1011 if unique_lifetimes.insert(lifetime_def.lifetime.clone()) {
1012 lifetimes.push(lifetime_def.lifetime.clone());
1013 }
1014 }
1015 }
1016
1017 lifetimes
1018}
1019
1020#[proc_macro_error]
1027#[proc_macro_derive(Encode, attributes(encode))]
1028pub fn derive_encode(input: TokenStream) -> TokenStream {
1029 let input = parse_macro_input!(input as DeriveInput);
1030 let name = &input.ident;
1031
1032 let args: msgpack::Args = darling::FromDeriveInput::from_derive_input(&input).unwrap();
1034 let tarantool_crate = args
1035 .tarantool
1036 .as_deref()
1037 .map(syn::parse_str)
1038 .transpose()
1039 .unwrap()
1040 .unwrap_or_else(default_tarantool_crate_path);
1041
1042 let generics = msgpack::add_trait_bounds(input.generics, &tarantool_crate);
1044 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
1045 let encode_fields = msgpack::encode_fields(
1046 &input.data,
1047 &tarantool_crate,
1048 || attrs_span(&input.attrs),
1051 &args,
1052 );
1053 let expanded = quote! {
1054 impl #impl_generics #tarantool_crate::msgpack::Encode for #name #ty_generics #where_clause {
1056 fn encode(&self, w: &mut impl ::std::io::Write, context: &#tarantool_crate::msgpack::Context)
1057 -> Result<(), #tarantool_crate::msgpack::EncodeError>
1058 {
1059 use #tarantool_crate::msgpack::StructStyle;
1060 #encode_fields
1061 Ok(())
1062 }
1063 }
1064 };
1065
1066 expanded.into()
1067}
1068
1069#[proc_macro_error]
1076#[proc_macro_derive(Decode, attributes(encode))]
1077pub fn derive_decode(input: TokenStream) -> TokenStream {
1078 let input = parse_macro_input!(input as DeriveInput);
1079 let name = &input.ident;
1080
1081 let args: msgpack::Args = darling::FromDeriveInput::from_derive_input(&input).unwrap();
1083 let tarantool_crate = args.tarantool.as_deref().unwrap_or("tarantool");
1084 let tarantool_crate = Ident::new(tarantool_crate, Span::call_site()).into();
1085
1086 let generics = msgpack::add_trait_bounds(input.generics.clone(), &tarantool_crate);
1088 let mut impl_generics = input.generics;
1089 impl_generics.params.insert(
1090 0,
1091 syn::GenericParam::Lifetime(syn::LifetimeDef {
1092 attrs: vec![],
1093 lifetime: syn::Lifetime::new("'de", Span::call_site()),
1094 colon_token: Some(syn::token::Colon::default()),
1095 bounds: collect_lifetimes(&generics),
1096 }),
1097 );
1098 let (impl_generics, _, where_clause) = impl_generics.split_for_impl();
1100 let (_, ty_generics, _) = generics.split_for_impl();
1101 let decode_fields = msgpack::decode_fields(
1102 &input.data,
1103 &tarantool_crate,
1104 || attrs_span(&input.attrs),
1107 &args,
1108 );
1109 let expanded = quote! {
1110 impl #impl_generics #tarantool_crate::msgpack::Decode<'de> for #name #ty_generics #where_clause {
1112 fn decode(r: &mut &'de [u8], context: &#tarantool_crate::msgpack::Context)
1113 -> std::result::Result<Self, #tarantool_crate::msgpack::DecodeError>
1114 {
1115 use #tarantool_crate::msgpack::StructStyle;
1116 #decode_fields
1117 }
1118 }
1119 };
1120
1121 expanded.into()
1122}
1123
1124#[proc_macro_attribute]
1128pub fn stored_proc(attr: TokenStream, item: TokenStream) -> TokenStream {
1129 let args = parse_macro_input!(attr as AttributeArgs);
1130 let ctx = Context::from_args(args);
1131
1132 let input = parse_macro_input!(item as Item);
1133
1134 #[rustfmt::skip]
1135 let ItemFn { vis, sig, block, attrs, .. } = match input {
1136 Item::Fn(f) => f,
1137 _ => panic!("only `fn` items can be stored procedures"),
1138 };
1139
1140 let (ident, inputs, output, generics) = match sig {
1141 Signature {
1142 asyncness: Some(_), ..
1143 } => {
1144 panic!("async stored procedures are not supported yet")
1145 }
1146 Signature {
1147 variadic: Some(_), ..
1148 } => {
1149 panic!("variadic stored procedures are not supported yet")
1150 }
1151 Signature {
1152 ident,
1153 inputs,
1154 output,
1155 generics,
1156 ..
1157 } => (ident, inputs, output, generics),
1158 };
1159
1160 let Inputs {
1161 inputs,
1162 input_pattern,
1163 input_idents,
1164 inject_inputs,
1165 n_actual_arguments,
1166 } = Inputs::parse(&ctx, inputs);
1167
1168 if ctx.is_packed && n_actual_arguments > 1 {
1169 panic!("proc with 'packed_args' can only have a single parameter")
1170 }
1171
1172 let Context {
1173 tarantool,
1174 linkme,
1175 section,
1176 debug_tuple,
1177 wrap_ret,
1178 ..
1179 } = ctx;
1180
1181 let inner_fn_name = syn::Ident::new("__tp_inner", ident.span());
1182 let desc_name = ident.to_string();
1183 let desc_ident = syn::Ident::new(&desc_name.to_uppercase(), ident.span());
1184 let mut public = matches!(vis, syn::Visibility::Public(_));
1185 if let Some(override_public) = ctx.public {
1186 public = override_public;
1187 }
1188
1189 let attrs_distributed_slice = if cfg!(feature = "stored_procs_slice") {
1193 quote! {
1194 #[#linkme::distributed_slice(#section)]
1195 #[linkme(crate = #linkme)]
1196 }
1197 } else {
1198 quote! {}
1199 };
1200
1201 quote! {
1202 #attrs_distributed_slice
1203 #[cfg(not(test))]
1204 static #desc_ident: #tarantool::proc::Proc = #tarantool::proc::Proc::new(
1205 #desc_name,
1206 #ident,
1207 ).with_public(#public);
1208
1209 #(#attrs)*
1210 #[no_mangle]
1211 pub unsafe extern "C" fn #ident (
1212 __tp_ctx: #tarantool::tuple::FunctionCtx,
1213 __tp_args: #tarantool::tuple::FunctionArgs,
1214 ) -> ::std::os::raw::c_int {
1215 #debug_tuple
1216 let #input_pattern =
1217 match __tp_args.decode() {
1218 ::std::result::Result::Ok(__tp_args) => __tp_args,
1219 ::std::result::Result::Err(__tp_err) => {
1220 #tarantool::set_error!(
1221 #tarantool::error::TarantoolErrorCode::ProcC,
1222 "{}",
1223 __tp_err
1224 );
1225 return -1;
1226 }
1227 };
1228
1229 #inject_inputs
1230
1231 fn #inner_fn_name #generics (#inputs) #output {
1232 #block
1233 }
1234
1235 let __tp_res = __tp_inner(#(#input_idents),*);
1236
1237 #wrap_ret
1238
1239 #tarantool::proc::Return::ret(__tp_res, __tp_ctx)
1240 }
1241 }
1242 .into()
1243}
1244
1245struct Context {
1246 tarantool: syn::Path,
1247 section: syn::Path,
1248 linkme: syn::Path,
1249 debug_tuple: TokenStream2,
1250 is_packed: bool,
1251 public: Option<bool>,
1252 wrap_ret: TokenStream2,
1253}
1254
1255impl Context {
1256 fn from_args(args: AttributeArgs) -> Self {
1257 let mut tarantool: syn::Path = default_tarantool_crate_path();
1258 let mut linkme = None;
1259 let mut section = None;
1260 let mut debug_tuple_needed = false;
1261 let mut is_packed = false;
1262 let mut public = None;
1263 let mut wrap_ret = quote! {};
1264
1265 for arg in args {
1266 if let Some(path) = imp::parse_lit_str_with_key(&arg, "tarantool") {
1267 tarantool = path;
1268 continue;
1269 }
1270 if let Some(path) = imp::parse_lit_str_with_key(&arg, "linkme") {
1271 linkme = Some(path);
1272 continue;
1273 }
1274 if let Some(path) = imp::parse_lit_str_with_key(&arg, "section") {
1275 section = Some(path);
1276 continue;
1277 }
1278 if imp::is_path_eq_to(&arg, "custom_ret") {
1279 wrap_ret = quote! {
1280 let __tp_res = #tarantool::proc::ReturnMsgpack(__tp_res);
1281 };
1282 continue;
1283 }
1284 if imp::is_path_eq_to(&arg, "packed_args") {
1285 is_packed = true;
1286 continue;
1287 }
1288 if imp::is_path_eq_to(&arg, "debug") {
1289 debug_tuple_needed = true;
1290 continue;
1291 }
1292 if let Some(v) = imp::parse_bool_with_key(&arg, "public") {
1293 public = Some(v);
1294 continue;
1295 }
1296 panic!("unsuported attribute argument `{}`", quote!(#arg))
1297 }
1298
1299 let section = section.unwrap_or_else(|| {
1300 imp::path_from_ts2(quote! { #tarantool::proc::TARANTOOL_MODULE_STORED_PROCS })
1301 });
1302 let linkme = linkme.unwrap_or_else(|| imp::path_from_ts2(quote! { #tarantool::linkme }));
1303
1304 let debug_tuple = if debug_tuple_needed {
1305 quote! {
1306 ::std::dbg!(#tarantool::tuple::Tuple::from(&__tp_args));
1307 }
1308 } else {
1309 quote! {}
1310 };
1311 Self {
1312 tarantool,
1313 linkme,
1314 section,
1315 debug_tuple,
1316 is_packed,
1317 wrap_ret,
1318 public,
1319 }
1320 }
1321}
1322
1323struct Inputs {
1324 inputs: Punctuated<FnArg, Token![,]>,
1325 input_pattern: TokenStream2,
1326 input_idents: Vec<syn::Pat>,
1327 inject_inputs: TokenStream2,
1328 n_actual_arguments: usize,
1329}
1330
1331impl Inputs {
1332 fn parse(ctx: &Context, mut inputs: Punctuated<FnArg, Token![,]>) -> Self {
1333 let mut input_idents = vec![];
1334 let mut actual_inputs = vec![];
1335 let mut injected_inputs = vec![];
1336 let mut injected_exprs = vec![];
1337 for i in &mut inputs {
1338 let syn::PatType {
1339 ref pat,
1340 ref mut attrs,
1341 ..
1342 } = match i {
1343 FnArg::Receiver(_) => {
1344 panic!("`self` receivers aren't supported in stored procedures")
1345 }
1346 FnArg::Typed(pat_ty) => pat_ty,
1347 };
1348 let mut inject_expr = None;
1349 attrs.retain(|attr| {
1350 let path = &attr.path;
1351 if path.is_ident("inject") {
1352 match attr.parse_args() {
1353 Ok(AttrInject { expr, .. }) => {
1354 inject_expr = Some(expr);
1355 false
1356 }
1357 Err(e) => panic!("attribute argument error: {}", e),
1358 }
1359 } else {
1360 !path.is_ident("doc")
1362 }
1363 });
1364 if let Some(expr) = inject_expr {
1365 injected_inputs.push(pat.clone());
1366 injected_exprs.push(expr);
1367 } else {
1368 actual_inputs.push(pat.clone());
1369 }
1370 input_idents.push((**pat).clone());
1371 }
1372
1373 let input_pattern = if inputs.is_empty() {
1374 quote! { []: [(); 0] }
1375 } else if ctx.is_packed {
1376 quote! { #(#actual_inputs)* }
1377 } else {
1378 quote! { ( #(#actual_inputs,)* ) }
1379 };
1380
1381 let inject_inputs = quote! {
1382 #( let #injected_inputs = #injected_exprs; )*
1383 };
1384
1385 Self {
1386 inputs,
1387 input_pattern,
1388 input_idents,
1389 inject_inputs,
1390 n_actual_arguments: actual_inputs.len(),
1391 }
1392 }
1393}
1394
1395#[derive(Debug)]
1396struct AttrInject {
1397 expr: syn::Expr,
1398}
1399
1400impl syn::parse::Parse for AttrInject {
1401 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
1402 Ok(AttrInject {
1403 expr: input.parse()?,
1404 })
1405 }
1406}
1407
1408mod kw {
1409 syn::custom_keyword! {inject}
1410}
1411
1412mod imp {
1413 use proc_macro2::{Group, Span, TokenStream, TokenTree};
1414 use syn::parse::{self, Parse};
1415
1416 #[track_caller]
1417 pub(crate) fn parse_lit_str_with_key<T>(nm: &syn::NestedMeta, key: &str) -> Option<T>
1418 where
1419 T: Parse,
1420 {
1421 match nm {
1422 syn::NestedMeta::Meta(syn::Meta::NameValue(syn::MetaNameValue {
1423 path, lit, ..
1424 })) if path.is_ident(key) => match &lit {
1425 syn::Lit::Str(s) => Some(crate::imp::parse_lit_str(s).unwrap()),
1426 _ => panic!("{key} value must be a string literal"),
1427 },
1428 _ => None,
1429 }
1430 }
1431
1432 #[track_caller]
1433 pub(crate) fn parse_bool_with_key(nm: &syn::NestedMeta, key: &str) -> Option<bool> {
1434 match nm {
1435 syn::NestedMeta::Meta(syn::Meta::NameValue(syn::MetaNameValue {
1436 path, lit, ..
1437 })) if path.is_ident(key) => match &lit {
1438 syn::Lit::Bool(b) => Some(b.value),
1439 _ => panic!("value for attribute '{key}' must be a bool literal (true | false)"),
1440 },
1441 syn::NestedMeta::Meta(syn::Meta::Path(path)) if path.is_ident(key) => {
1442 panic!("expected ({key} = true|false), got just {key}");
1443 }
1444 _ => None,
1445 }
1446 }
1447
1448 #[track_caller]
1449 pub(crate) fn is_path_eq_to(nm: &syn::NestedMeta, expected: &str) -> bool {
1450 matches!(
1451 nm,
1452 syn::NestedMeta::Meta(syn::Meta::Path(path)) if path.is_ident(expected)
1453 )
1454 }
1455
1456 pub(crate) fn path_from_ts2(ts: TokenStream) -> syn::Path {
1457 syn::parse2(ts).unwrap()
1458 }
1459
1460 pub(crate) fn parse_lit_str<T>(s: &syn::LitStr) -> parse::Result<T>
1463 where
1464 T: Parse,
1465 {
1466 let tokens = spanned_tokens(s)?;
1467 syn::parse2(tokens)
1468 }
1469
1470 fn spanned_tokens(s: &syn::LitStr) -> parse::Result<TokenStream> {
1471 let stream = syn::parse_str(&s.value())?;
1472 Ok(respan(stream, s.span()))
1473 }
1474
1475 fn respan(stream: TokenStream, span: Span) -> TokenStream {
1476 stream
1477 .into_iter()
1478 .map(|token| respan_token(token, span))
1479 .collect()
1480 }
1481
1482 fn respan_token(mut token: TokenTree, span: Span) -> TokenTree {
1483 if let TokenTree::Group(g) = &mut token {
1484 *g = Group::new(g.delimiter(), respan(g.stream(), span));
1485 }
1486 token.set_span(span);
1487 token
1488 }
1489}