1#![doc(html_root_url = "https://docs.rs/prost-derive/0.10.2")]
2#![recursion_limit = "4096"]
4
5extern crate alloc;
6extern crate proc_macro;
7
8use anyhow::{bail, Error};
9use itertools::Itertools;
10use proc_macro::TokenStream;
11use proc_macro2::Span;
12use quote::quote;
13use syn::{
14 punctuated::Punctuated, Data, DataEnum, DataStruct, DeriveInput, Expr, Fields, FieldsNamed,
15 FieldsUnnamed, Ident, Index, Variant,
16};
17
18mod field;
19use crate::field::Field;
20
21fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
22 let input: DeriveInput = syn::parse(input)?;
23
24 let ident = input.ident;
25
26 let variant_data = match input.data {
27 Data::Struct(variant_data) => variant_data,
28 Data::Enum(..) => bail!("Message can not be derived for an enum"),
29 Data::Union(..) => bail!("Message can not be derived for a union"),
30 };
31
32 let generics = &input.generics;
33 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
34
35 let (is_struct, fields) = match variant_data {
36 DataStruct {
37 fields: Fields::Named(FieldsNamed { named: fields, .. }),
38 ..
39 } => (true, fields.into_iter().collect()),
40 DataStruct {
41 fields:
42 Fields::Unnamed(FieldsUnnamed {
43 unnamed: fields, ..
44 }),
45 ..
46 } => (false, fields.into_iter().collect()),
47 DataStruct {
48 fields: Fields::Unit,
49 ..
50 } => (false, Vec::new()),
51 };
52
53 let mut next_tag: u32 = 1;
54 let mut fields = fields
55 .into_iter()
56 .enumerate()
57 .flat_map(|(i, field)| {
58 let field_ident = field.ident.map(|x| quote!(#x)).unwrap_or_else(|| {
59 let index = Index {
60 index: i as u32,
61 span: Span::call_site(),
62 };
63 quote!(#index)
64 });
65 match Field::new(field.attrs, Some(next_tag)) {
66 Ok(Some(field)) => {
67 next_tag = field.tags().iter().max().map(|t| t + 1).unwrap_or(next_tag);
68 Some(Ok((field_ident, field)))
69 }
70 Ok(None) => None,
71 Err(err) => Some(Err(
72 err.context(format!("invalid message field {}.{}", ident, field_ident))
73 )),
74 }
75 })
76 .collect::<Result<Vec<_>, _>>()?;
77
78 let unsorted_fields = fields.clone();
80
81 fields.sort_by_key(|&(_, ref field)| field.tags().into_iter().min().unwrap());
86 let fields = fields;
87
88 let mut tags = fields
89 .iter()
90 .flat_map(|&(_, ref field)| field.tags())
91 .collect::<Vec<_>>();
92 let num_tags = tags.len();
93 tags.sort_unstable();
94 tags.dedup();
95 if tags.len() != num_tags {
96 bail!("message {} has fields with duplicate tags", ident);
97 }
98
99 let encoded_len = fields
100 .iter()
101 .map(|&(ref field_ident, ref field)| field.encoded_len(quote!(self.#field_ident)));
102
103 let encode = fields
104 .iter()
105 .map(|&(ref field_ident, ref field)| field.encode(quote!(self.#field_ident)));
106
107 let merge = fields.iter().map(|&(ref field_ident, ref field)| {
108 let merge = field.merge(quote!(value));
109 let tags = field.tags().into_iter().map(|tag| quote!(#tag));
110 let tags = Itertools::intersperse(tags, quote!(|));
111
112 quote! {
113 #(#tags)* => {
114 let mut value = &mut self.#field_ident;
115 #merge.map_err(|mut error| {
116 error.push(STRUCT_NAME, stringify!(#field_ident));
117 error
118 })
119 },
120 }
121 });
122
123 let struct_name = if fields.is_empty() {
124 quote!()
125 } else {
126 quote!(
127 const STRUCT_NAME: &'static str = stringify!(#ident);
128 )
129 };
130
131 let clear = fields
132 .iter()
133 .map(|&(ref field_ident, ref field)| field.clear(quote!(self.#field_ident)));
134
135 let default = if is_struct {
136 let default = fields.iter().map(|(field_ident, field)| {
137 let value = field.default();
138 quote!(#field_ident: #value,)
139 });
140 quote! {#ident {
141 #(#default)*
142 }}
143 } else {
144 let default = fields.iter().map(|(_, field)| {
145 let value = field.default();
146 quote!(#value,)
147 });
148 quote! {#ident (
149 #(#default)*
150 )}
151 };
152
153 let methods = fields
154 .iter()
155 .flat_map(|&(ref field_ident, ref field)| field.methods(field_ident))
156 .collect::<Vec<_>>();
157 let methods = if methods.is_empty() {
158 quote!()
159 } else {
160 quote! {
161 #[allow(dead_code)]
162 impl #impl_generics #ident #ty_generics #where_clause {
163 #(#methods)*
164 }
165 }
166 };
167
168 let debugs = unsorted_fields.iter().map(|&(ref field_ident, ref field)| {
169 let wrapper = field.debug(quote!(self.#field_ident));
170 let call = if is_struct {
171 quote!(builder.field(stringify!(#field_ident), &wrapper))
172 } else {
173 quote!(builder.field(&wrapper))
174 };
175 quote! {
176 let builder = {
177 let wrapper = #wrapper;
178 #call
179 };
180 }
181 });
182 let debug_builder = if is_struct {
183 quote!(f.debug_struct(stringify!(#ident)))
184 } else {
185 quote!(f.debug_tuple(stringify!(#ident)))
186 };
187
188 let expanded = quote! {
189 impl #impl_generics ::prost::Message for #ident #ty_generics #where_clause {
190 #[allow(unused_variables)]
191 fn encode_raw<B>(&self, buf: &mut B) where B: ::prost::bytes::BufMut {
192 #(#encode)*
193 }
194
195 #[allow(unused_variables)]
196 fn merge_field<B>(
197 &mut self,
198 tag: u32,
199 wire_type: ::prost::encoding::WireType,
200 buf: &mut B,
201 ctx: ::prost::encoding::DecodeContext,
202 ) -> ::core::result::Result<(), ::prost::DecodeError>
203 where B: ::prost::bytes::Buf {
204 #struct_name
205 match tag {
206 #(#merge)*
207 _ => ::prost::encoding::skip_field(wire_type, tag, buf, ctx),
208 }
209 }
210
211 #[inline]
212 fn encoded_len(&self) -> usize {
213 0 #(+ #encoded_len)*
214 }
215
216 fn clear(&mut self) {
217 #(#clear;)*
218 }
219 }
220
221 impl #impl_generics ::core::default::Default for #ident #ty_generics #where_clause {
222 fn default() -> Self {
223 #default
224 }
225 }
226
227 impl #impl_generics ::core::fmt::Debug for #ident #ty_generics #where_clause {
228 fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
229 let mut builder = #debug_builder;
230 #(#debugs;)*
231 builder.finish()
232 }
233 }
234
235 #methods
236 };
237
238 Ok(expanded.into())
239}
240
241#[proc_macro_derive(Message, attributes(prost))]
242pub fn message(input: TokenStream) -> TokenStream {
243 try_message(input).unwrap()
244}
245
246fn try_enumeration(input: TokenStream) -> Result<TokenStream, Error> {
247 let input: DeriveInput = syn::parse(input)?;
248 let ident = input.ident;
249
250 let generics = &input.generics;
251 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
252
253 let punctuated_variants = match input.data {
254 Data::Enum(DataEnum { variants, .. }) => variants,
255 Data::Struct(_) => bail!("Enumeration can not be derived for a struct"),
256 Data::Union(..) => bail!("Enumeration can not be derived for a union"),
257 };
258
259 let mut variants: Vec<(Ident, Expr)> = Vec::new();
261 for Variant {
262 ident,
263 fields,
264 discriminant,
265 ..
266 } in punctuated_variants
267 {
268 match fields {
269 Fields::Unit => (),
270 Fields::Named(_) | Fields::Unnamed(_) => {
271 bail!("Enumeration variants may not have fields")
272 }
273 }
274
275 match discriminant {
276 Some((_, expr)) => variants.push((ident, expr)),
277 None => bail!("Enumeration variants must have a disriminant"),
278 }
279 }
280
281 if variants.is_empty() {
282 panic!("Enumeration must have at least one variant");
283 }
284
285 let default = variants[0].0.clone();
286
287 let is_valid = variants
288 .iter()
289 .map(|&(_, ref value)| quote!(#value => true));
290 let from = variants.iter().map(
291 |&(ref variant, ref value)| quote!(#value => ::core::option::Option::Some(#ident::#variant)),
292 );
293
294 let is_valid_doc = format!("Returns `true` if `value` is a variant of `{}`.", ident);
295 let from_i32_doc = format!(
296 "Converts an `i32` to a `{}`, or `None` if `value` is not a valid variant.",
297 ident
298 );
299
300 let expanded = quote! {
301 impl #impl_generics #ident #ty_generics #where_clause {
302 #[doc=#is_valid_doc]
303 pub fn is_valid(value: i32) -> bool {
304 match value {
305 #(#is_valid,)*
306 _ => false,
307 }
308 }
309
310 #[doc=#from_i32_doc]
311 pub fn from_i32(value: i32) -> ::core::option::Option<#ident> {
312 match value {
313 #(#from,)*
314 _ => ::core::option::Option::None,
315 }
316 }
317 }
318
319 impl #impl_generics ::core::default::Default for #ident #ty_generics #where_clause {
320 fn default() -> #ident {
321 #ident::#default
322 }
323 }
324
325 impl #impl_generics ::core::convert::From::<#ident> for i32 #ty_generics #where_clause {
326 fn from(value: #ident) -> i32 {
327 value as i32
328 }
329 }
330 };
331
332 Ok(expanded.into())
333}
334
335#[proc_macro_derive(Enumeration, attributes(prost))]
336pub fn enumeration(input: TokenStream) -> TokenStream {
337 try_enumeration(input).unwrap()
338}
339
340fn try_oneof(input: TokenStream) -> Result<TokenStream, Error> {
341 let input: DeriveInput = syn::parse(input)?;
342
343 let ident = input.ident;
344
345 let variants = match input.data {
346 Data::Enum(DataEnum { variants, .. }) => variants,
347 Data::Struct(..) => bail!("Oneof can not be derived for a struct"),
348 Data::Union(..) => bail!("Oneof can not be derived for a union"),
349 };
350
351 let generics = &input.generics;
352 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
353
354 let mut fields: Vec<(Ident, Field)> = Vec::new();
356 for Variant {
357 attrs,
358 ident: variant_ident,
359 fields: variant_fields,
360 ..
361 } in variants
362 {
363 let variant_fields = match variant_fields {
364 Fields::Unit => Punctuated::new(),
365 Fields::Named(FieldsNamed { named: fields, .. })
366 | Fields::Unnamed(FieldsUnnamed {
367 unnamed: fields, ..
368 }) => fields,
369 };
370 if variant_fields.len() != 1 {
371 bail!("Oneof enum variants must have a single field");
372 }
373 match Field::new_oneof(attrs)? {
374 Some(field) => fields.push((variant_ident, field)),
375 None => bail!("invalid oneof variant: oneof variants may not be ignored"),
376 }
377 }
378
379 let mut tags = fields
380 .iter()
381 .flat_map(|&(ref variant_ident, ref field)| -> Result<u32, Error> {
382 if field.tags().len() > 1 {
383 bail!(
384 "invalid oneof variant {}::{}: oneof variants may only have a single tag",
385 ident,
386 variant_ident
387 );
388 }
389 Ok(field.tags()[0])
390 })
391 .collect::<Vec<_>>();
392 tags.sort_unstable();
393 tags.dedup();
394 if tags.len() != fields.len() {
395 panic!("invalid oneof {}: variants have duplicate tags", ident);
396 }
397
398 let encode = fields.iter().map(|&(ref variant_ident, ref field)| {
399 let encode = field.encode(quote!(*value));
400 quote!(#ident::#variant_ident(ref value) => { #encode })
401 });
402
403 let merge = fields.iter().map(|&(ref variant_ident, ref field)| {
404 let tag = field.tags()[0];
405 let merge = field.merge(quote!(value));
406 quote! {
407 #tag => {
408 match field {
409 ::core::option::Option::Some(#ident::#variant_ident(ref mut value)) => {
410 #merge
411 },
412 _ => {
413 let mut owned_value = ::core::default::Default::default();
414 let value = &mut owned_value;
415 #merge.map(|_| *field = ::core::option::Option::Some(#ident::#variant_ident(owned_value)))
416 },
417 }
418 }
419 }
420 });
421
422 let encoded_len = fields.iter().map(|&(ref variant_ident, ref field)| {
423 let encoded_len = field.encoded_len(quote!(*value));
424 quote!(#ident::#variant_ident(ref value) => #encoded_len)
425 });
426
427 let debug = fields.iter().map(|&(ref variant_ident, ref field)| {
428 let wrapper = field.debug(quote!(*value));
429 quote!(#ident::#variant_ident(ref value) => {
430 let wrapper = #wrapper;
431 f.debug_tuple(stringify!(#variant_ident))
432 .field(&wrapper)
433 .finish()
434 })
435 });
436
437 let expanded = quote! {
438 impl #impl_generics #ident #ty_generics #where_clause {
439 pub fn encode<B>(&self, buf: &mut B) where B: ::prost::bytes::BufMut {
441 match *self {
442 #(#encode,)*
443 }
444 }
445
446 pub fn merge<B>(
448 field: &mut ::core::option::Option<#ident #ty_generics>,
449 tag: u32,
450 wire_type: ::prost::encoding::WireType,
451 buf: &mut B,
452 ctx: ::prost::encoding::DecodeContext,
453 ) -> ::core::result::Result<(), ::prost::DecodeError>
454 where B: ::prost::bytes::Buf {
455 match tag {
456 #(#merge,)*
457 _ => unreachable!(concat!("invalid ", stringify!(#ident), " tag: {}"), tag),
458 }
459 }
460
461 #[inline]
463 pub fn encoded_len(&self) -> usize {
464 match *self {
465 #(#encoded_len,)*
466 }
467 }
468 }
469
470 impl #impl_generics ::core::fmt::Debug for #ident #ty_generics #where_clause {
471 fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
472 match *self {
473 #(#debug,)*
474 }
475 }
476 }
477 };
478
479 Ok(expanded.into())
480}
481
482#[proc_macro_derive(Oneof, attributes(prost))]
483pub fn oneof(input: TokenStream) -> TokenStream {
484 try_oneof(input).unwrap()
485}