sg_prost_derive/
lib.rs

1#![doc(html_root_url = "https://docs.rs/prost-derive/0.10.2")]
2// The `quote!` macro requires deep recursion.
3#![recursion_limit = "4096"]
4
5extern crate alloc;
6extern crate proc_macro;
7
8use anyhow::{bail, Error};
9use itertools::Itertools;
10use proc_macro::TokenStream;
11use proc_macro2::Span;
12use quote::quote;
13use syn::{
14    punctuated::Punctuated, Data, DataEnum, DataStruct, DeriveInput, Expr, Fields, FieldsNamed,
15    FieldsUnnamed, Ident, Index, Variant,
16};
17
18mod field;
19use crate::field::Field;
20
21fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
22    let input: DeriveInput = syn::parse(input)?;
23
24    let ident = input.ident;
25
26    let variant_data = match input.data {
27        Data::Struct(variant_data) => variant_data,
28        Data::Enum(..) => bail!("Message can not be derived for an enum"),
29        Data::Union(..) => bail!("Message can not be derived for a union"),
30    };
31
32    let generics = &input.generics;
33    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
34
35    let (is_struct, fields) = match variant_data {
36        DataStruct {
37            fields: Fields::Named(FieldsNamed { named: fields, .. }),
38            ..
39        } => (true, fields.into_iter().collect()),
40        DataStruct {
41            fields:
42                Fields::Unnamed(FieldsUnnamed {
43                    unnamed: fields, ..
44                }),
45            ..
46        } => (false, fields.into_iter().collect()),
47        DataStruct {
48            fields: Fields::Unit,
49            ..
50        } => (false, Vec::new()),
51    };
52
53    let mut next_tag: u32 = 1;
54    let mut fields = fields
55        .into_iter()
56        .enumerate()
57        .flat_map(|(i, field)| {
58            let field_ident = field.ident.map(|x| quote!(#x)).unwrap_or_else(|| {
59                let index = Index {
60                    index: i as u32,
61                    span: Span::call_site(),
62                };
63                quote!(#index)
64            });
65            match Field::new(field.attrs, Some(next_tag)) {
66                Ok(Some(field)) => {
67                    next_tag = field.tags().iter().max().map(|t| t + 1).unwrap_or(next_tag);
68                    Some(Ok((field_ident, field)))
69                }
70                Ok(None) => None,
71                Err(err) => Some(Err(
72                    err.context(format!("invalid message field {}.{}", ident, field_ident))
73                )),
74            }
75        })
76        .collect::<Result<Vec<_>, _>>()?;
77
78    // We want Debug to be in declaration order
79    let unsorted_fields = fields.clone();
80
81    // Sort the fields by tag number so that fields will be encoded in tag order.
82    // TODO: This encodes oneof fields in the position of their lowest tag,
83    // regardless of the currently occupied variant, is that consequential?
84    // See: https://developers.google.com/protocol-buffers/docs/encoding#order
85    fields.sort_by_key(|&(_, ref field)| field.tags().into_iter().min().unwrap());
86    let fields = fields;
87
88    let mut tags = fields
89        .iter()
90        .flat_map(|&(_, ref field)| field.tags())
91        .collect::<Vec<_>>();
92    let num_tags = tags.len();
93    tags.sort_unstable();
94    tags.dedup();
95    if tags.len() != num_tags {
96        bail!("message {} has fields with duplicate tags", ident);
97    }
98
99    let encoded_len = fields
100        .iter()
101        .map(|&(ref field_ident, ref field)| field.encoded_len(quote!(self.#field_ident)));
102
103    let encode = fields
104        .iter()
105        .map(|&(ref field_ident, ref field)| field.encode(quote!(self.#field_ident)));
106
107    let merge = fields.iter().map(|&(ref field_ident, ref field)| {
108        let merge = field.merge(quote!(value));
109        let tags = field.tags().into_iter().map(|tag| quote!(#tag));
110        let tags = Itertools::intersperse(tags, quote!(|));
111
112        quote! {
113            #(#tags)* => {
114                let mut value = &mut self.#field_ident;
115                #merge.map_err(|mut error| {
116                    error.push(STRUCT_NAME, stringify!(#field_ident));
117                    error
118                })
119            },
120        }
121    });
122
123    let struct_name = if fields.is_empty() {
124        quote!()
125    } else {
126        quote!(
127            const STRUCT_NAME: &'static str = stringify!(#ident);
128        )
129    };
130
131    let clear = fields
132        .iter()
133        .map(|&(ref field_ident, ref field)| field.clear(quote!(self.#field_ident)));
134
135    let default = if is_struct {
136        let default = fields.iter().map(|(field_ident, field)| {
137            let value = field.default();
138            quote!(#field_ident: #value,)
139        });
140        quote! {#ident {
141            #(#default)*
142        }}
143    } else {
144        let default = fields.iter().map(|(_, field)| {
145            let value = field.default();
146            quote!(#value,)
147        });
148        quote! {#ident (
149            #(#default)*
150        )}
151    };
152
153    let methods = fields
154        .iter()
155        .flat_map(|&(ref field_ident, ref field)| field.methods(field_ident))
156        .collect::<Vec<_>>();
157    let methods = if methods.is_empty() {
158        quote!()
159    } else {
160        quote! {
161            #[allow(dead_code)]
162            impl #impl_generics #ident #ty_generics #where_clause {
163                #(#methods)*
164            }
165        }
166    };
167
168    let debugs = unsorted_fields.iter().map(|&(ref field_ident, ref field)| {
169        let wrapper = field.debug(quote!(self.#field_ident));
170        let call = if is_struct {
171            quote!(builder.field(stringify!(#field_ident), &wrapper))
172        } else {
173            quote!(builder.field(&wrapper))
174        };
175        quote! {
176             let builder = {
177                 let wrapper = #wrapper;
178                 #call
179             };
180        }
181    });
182    let debug_builder = if is_struct {
183        quote!(f.debug_struct(stringify!(#ident)))
184    } else {
185        quote!(f.debug_tuple(stringify!(#ident)))
186    };
187
188    let expanded = quote! {
189        impl #impl_generics ::prost::Message for #ident #ty_generics #where_clause {
190            #[allow(unused_variables)]
191            fn encode_raw<B>(&self, buf: &mut B) where B: ::prost::bytes::BufMut {
192                #(#encode)*
193            }
194
195            #[allow(unused_variables)]
196            fn merge_field<B>(
197                &mut self,
198                tag: u32,
199                wire_type: ::prost::encoding::WireType,
200                buf: &mut B,
201                ctx: ::prost::encoding::DecodeContext,
202            ) -> ::core::result::Result<(), ::prost::DecodeError>
203            where B: ::prost::bytes::Buf {
204                #struct_name
205                match tag {
206                    #(#merge)*
207                    _ => ::prost::encoding::skip_field(wire_type, tag, buf, ctx),
208                }
209            }
210
211            #[inline]
212            fn encoded_len(&self) -> usize {
213                0 #(+ #encoded_len)*
214            }
215
216            fn clear(&mut self) {
217                #(#clear;)*
218            }
219        }
220
221        impl #impl_generics ::core::default::Default for #ident #ty_generics #where_clause {
222            fn default() -> Self {
223                #default
224            }
225        }
226
227        impl #impl_generics ::core::fmt::Debug for #ident #ty_generics #where_clause {
228            fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
229                let mut builder = #debug_builder;
230                #(#debugs;)*
231                builder.finish()
232            }
233        }
234
235        #methods
236    };
237
238    Ok(expanded.into())
239}
240
241#[proc_macro_derive(Message, attributes(prost))]
242pub fn message(input: TokenStream) -> TokenStream {
243    try_message(input).unwrap()
244}
245
246fn try_enumeration(input: TokenStream) -> Result<TokenStream, Error> {
247    let input: DeriveInput = syn::parse(input)?;
248    let ident = input.ident;
249
250    let generics = &input.generics;
251    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
252
253    let punctuated_variants = match input.data {
254        Data::Enum(DataEnum { variants, .. }) => variants,
255        Data::Struct(_) => bail!("Enumeration can not be derived for a struct"),
256        Data::Union(..) => bail!("Enumeration can not be derived for a union"),
257    };
258
259    // Map the variants into 'fields'.
260    let mut variants: Vec<(Ident, Expr)> = Vec::new();
261    for Variant {
262        ident,
263        fields,
264        discriminant,
265        ..
266    } in punctuated_variants
267    {
268        match fields {
269            Fields::Unit => (),
270            Fields::Named(_) | Fields::Unnamed(_) => {
271                bail!("Enumeration variants may not have fields")
272            }
273        }
274
275        match discriminant {
276            Some((_, expr)) => variants.push((ident, expr)),
277            None => bail!("Enumeration variants must have a disriminant"),
278        }
279    }
280
281    if variants.is_empty() {
282        panic!("Enumeration must have at least one variant");
283    }
284
285    let default = variants[0].0.clone();
286
287    let is_valid = variants
288        .iter()
289        .map(|&(_, ref value)| quote!(#value => true));
290    let from = variants.iter().map(
291        |&(ref variant, ref value)| quote!(#value => ::core::option::Option::Some(#ident::#variant)),
292    );
293
294    let is_valid_doc = format!("Returns `true` if `value` is a variant of `{}`.", ident);
295    let from_i32_doc = format!(
296        "Converts an `i32` to a `{}`, or `None` if `value` is not a valid variant.",
297        ident
298    );
299
300    let expanded = quote! {
301        impl #impl_generics #ident #ty_generics #where_clause {
302            #[doc=#is_valid_doc]
303            pub fn is_valid(value: i32) -> bool {
304                match value {
305                    #(#is_valid,)*
306                    _ => false,
307                }
308            }
309
310            #[doc=#from_i32_doc]
311            pub fn from_i32(value: i32) -> ::core::option::Option<#ident> {
312                match value {
313                    #(#from,)*
314                    _ => ::core::option::Option::None,
315                }
316            }
317        }
318
319        impl #impl_generics ::core::default::Default for #ident #ty_generics #where_clause {
320            fn default() -> #ident {
321                #ident::#default
322            }
323        }
324
325        impl #impl_generics ::core::convert::From::<#ident> for i32 #ty_generics #where_clause {
326            fn from(value: #ident) -> i32 {
327                value as i32
328            }
329        }
330    };
331
332    Ok(expanded.into())
333}
334
335#[proc_macro_derive(Enumeration, attributes(prost))]
336pub fn enumeration(input: TokenStream) -> TokenStream {
337    try_enumeration(input).unwrap()
338}
339
340fn try_oneof(input: TokenStream) -> Result<TokenStream, Error> {
341    let input: DeriveInput = syn::parse(input)?;
342
343    let ident = input.ident;
344
345    let variants = match input.data {
346        Data::Enum(DataEnum { variants, .. }) => variants,
347        Data::Struct(..) => bail!("Oneof can not be derived for a struct"),
348        Data::Union(..) => bail!("Oneof can not be derived for a union"),
349    };
350
351    let generics = &input.generics;
352    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
353
354    // Map the variants into 'fields'.
355    let mut fields: Vec<(Ident, Field)> = Vec::new();
356    for Variant {
357        attrs,
358        ident: variant_ident,
359        fields: variant_fields,
360        ..
361    } in variants
362    {
363        let variant_fields = match variant_fields {
364            Fields::Unit => Punctuated::new(),
365            Fields::Named(FieldsNamed { named: fields, .. })
366            | Fields::Unnamed(FieldsUnnamed {
367                unnamed: fields, ..
368            }) => fields,
369        };
370        if variant_fields.len() != 1 {
371            bail!("Oneof enum variants must have a single field");
372        }
373        match Field::new_oneof(attrs)? {
374            Some(field) => fields.push((variant_ident, field)),
375            None => bail!("invalid oneof variant: oneof variants may not be ignored"),
376        }
377    }
378
379    let mut tags = fields
380        .iter()
381        .flat_map(|&(ref variant_ident, ref field)| -> Result<u32, Error> {
382            if field.tags().len() > 1 {
383                bail!(
384                    "invalid oneof variant {}::{}: oneof variants may only have a single tag",
385                    ident,
386                    variant_ident
387                );
388            }
389            Ok(field.tags()[0])
390        })
391        .collect::<Vec<_>>();
392    tags.sort_unstable();
393    tags.dedup();
394    if tags.len() != fields.len() {
395        panic!("invalid oneof {}: variants have duplicate tags", ident);
396    }
397
398    let encode = fields.iter().map(|&(ref variant_ident, ref field)| {
399        let encode = field.encode(quote!(*value));
400        quote!(#ident::#variant_ident(ref value) => { #encode })
401    });
402
403    let merge = fields.iter().map(|&(ref variant_ident, ref field)| {
404        let tag = field.tags()[0];
405        let merge = field.merge(quote!(value));
406        quote! {
407            #tag => {
408                match field {
409                    ::core::option::Option::Some(#ident::#variant_ident(ref mut value)) => {
410                        #merge
411                    },
412                    _ => {
413                        let mut owned_value = ::core::default::Default::default();
414                        let value = &mut owned_value;
415                        #merge.map(|_| *field = ::core::option::Option::Some(#ident::#variant_ident(owned_value)))
416                    },
417                }
418            }
419        }
420    });
421
422    let encoded_len = fields.iter().map(|&(ref variant_ident, ref field)| {
423        let encoded_len = field.encoded_len(quote!(*value));
424        quote!(#ident::#variant_ident(ref value) => #encoded_len)
425    });
426
427    let debug = fields.iter().map(|&(ref variant_ident, ref field)| {
428        let wrapper = field.debug(quote!(*value));
429        quote!(#ident::#variant_ident(ref value) => {
430            let wrapper = #wrapper;
431            f.debug_tuple(stringify!(#variant_ident))
432                .field(&wrapper)
433                .finish()
434        })
435    });
436
437    let expanded = quote! {
438        impl #impl_generics #ident #ty_generics #where_clause {
439            /// Encodes the message to a buffer.
440            pub fn encode<B>(&self, buf: &mut B) where B: ::prost::bytes::BufMut {
441                match *self {
442                    #(#encode,)*
443                }
444            }
445
446            /// Decodes an instance of the message from a buffer, and merges it into self.
447            pub fn merge<B>(
448                field: &mut ::core::option::Option<#ident #ty_generics>,
449                tag: u32,
450                wire_type: ::prost::encoding::WireType,
451                buf: &mut B,
452                ctx: ::prost::encoding::DecodeContext,
453            ) -> ::core::result::Result<(), ::prost::DecodeError>
454            where B: ::prost::bytes::Buf {
455                match tag {
456                    #(#merge,)*
457                    _ => unreachable!(concat!("invalid ", stringify!(#ident), " tag: {}"), tag),
458                }
459            }
460
461            /// Returns the encoded length of the message without a length delimiter.
462            #[inline]
463            pub fn encoded_len(&self) -> usize {
464                match *self {
465                    #(#encoded_len,)*
466                }
467            }
468        }
469
470        impl #impl_generics ::core::fmt::Debug for #ident #ty_generics #where_clause {
471            fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
472                match *self {
473                    #(#debug,)*
474                }
475            }
476        }
477    };
478
479    Ok(expanded.into())
480}
481
482#[proc_macro_derive(Oneof, attributes(prost))]
483pub fn oneof(input: TokenStream) -> TokenStream {
484    try_oneof(input).unwrap()
485}