1use std::cmp::Ordering;
2use std::collections::HashSet;
3
4use proc_macro2::{Ident, Span, TokenStream};
5use quote::ToTokens;
6use syn::parse::{Parse, ParseStream};
7use syn::{
8 parenthesized, parse_macro_input, parse_quote, spanned::Spanned, Attribute, Data, DeriveInput,
9 Error, LitInt, Path, Token, Type, TypePath,
10};
11use syn::{Field, Fields, LitStr, Variant};
12
13#[macro_use]
14extern crate quote;
15
16mod read;
17mod sized;
18mod write;
19
20#[proc_macro_derive(SerryWrite, attributes(serry))]
21pub fn derive_write(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
22 let item = parse_macro_input!(item as DeriveInput);
23 match write::derive_write_impl(item) {
24 Ok(output) => output,
25 Err(e) => e.to_compile_error(),
26 }
27 .into()
28}
29
30#[proc_macro_derive(SerryRead, attributes(serry))]
31pub fn derive_read(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
32 let item = parse_macro_input!(item as DeriveInput);
33 match read::derive_read_impl(item) {
34 Ok(output) => output,
35 Err(e) => e.to_compile_error(),
36 }
37 .into()
38}
39
40#[proc_macro_derive(SerrySized, attributes(serry))]
41pub fn derive_sized(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
42 let item = parse_macro_input!(item as DeriveInput);
43 match sized::derive_sized_impl(item) {
44 Ok(output) => output,
45 Err(e) => e.to_compile_error(),
46 }
47 .into()
48}
49
50type ProcessedFields<'a> = Vec<(FieldName, &'a Field)>;
51fn process_fields(fields: &Fields, field_order: FieldOrder) -> Option<ProcessedFields> {
52 let fields: Vec<_> = match fields {
53 Fields::Unit => return None,
54 Fields::Named(named) => named.named.iter().collect(),
55 Fields::Unnamed(unnamed) => unnamed.unnamed.iter().collect(),
56 };
57
58 let mut vec: Vec<(FieldName, &Field)> = vec![];
59 for (i, field) in fields.into_iter().enumerate() {
60 vec.push((
61 match &field.ident {
62 Some(ident) => FieldName::Ident(ident.clone()),
63 None => FieldName::Index(LitInt::new(i.to_string().as_str(), Span::call_site())),
64 },
65 field,
66 ));
67 }
68
69 if field_order.do_sort() {
70 vec.sort_by(|a, b| field_order.cmp(a.1, b.1));
71 }
72
73 Some(vec)
74}
75
76fn create_pattern_match<'a, I>(iter: I, unnamed: bool) -> TokenStream
77where
78 I: Iterator<Item = &'a FieldName>,
79{
80 if unnamed {
81 let names = iter.map(FieldName::output_ident);
82 quote!((#(#names),*))
83 } else {
84 let names = iter.map(|name| {
85 let output = name.output_ident();
86 quote!(#name: #output)
87 });
88 quote!({ #(#names),* })
89 }
90}
91
/// Version information declared with `version(...)` in a `#[serry(...)]` attribute
/// on a struct or enum variant (see `parse_serry_attr`).
struct RootVersionInfo {
    /// Lower bound of the declared range (`since` in `version(since..until)`).
    minimum_supported_version: usize,
    /// Upper bound of the declared range, or the lower bound when no upper bound is given.
    current_version: usize,
    /// Type written after `as` in the attribute; `u8` when omitted.
    /// Presumably the integer type the version number is encoded as — confirm in read/write.
    version_type: Type,
}
97
98#[derive(Default, Copy, Clone, Debug)]
99struct VersionRange {
100 since: usize,
101 until: Option<usize>,
102}
103
104impl Parse for VersionRange {
105 fn parse(input: ParseStream) -> syn::Result<Self> {
106 let since: LitInt = input.parse()?;
107 let since = since.base10_parse()?;
108
109 let until = if input.peek(Token![..]) {
110 let _ = input.parse::<Token![..]>();
111 let until: LitInt = input.parse()?;
112 Some(until.base10_parse()?)
113 } else {
114 None
115 };
116
117 Ok(Self { since, until })
118 }
119}
120
/// Parsed contents of a single `#[serry(...)]` attribute.
///
/// Which fields can be populated depends on the `SerryAttrFields` passed to
/// `parse_serry_attr`; every field is `None` when the item has no `serry` attribute.
struct SerryAttr<'a> {
    /// `version(since..until as Type)` on a struct or enum variant.
    version_info: Option<RootVersionInfo>,
    /// `version = since..until` or `version(since..until)` on a field.
    version_range: Option<VersionRange>,
    /// `extrapolate = path` or `default` on a field.
    extrapolate: Option<Extrapolate>,
    /// `discriminant = N` or `repr = N` on an enum variant.
    discriminant_value: Option<LitInt>,
    /// `discriminate_by(Type)` or `repr(Type)` on an enum.
    discriminant_type: Option<TypePath>,
    /// `field_order = "..."` on a struct.
    field_order: Option<FieldOrder>,
    /// The attribute these values were parsed from; its tokens provide error spans.
    attr: Option<&'a Attribute>,
}
130
131impl<'a> SerryAttr<'a> {
132 fn version_with_range_of<'all>(
133 &'all self,
134 field_attr: &'all SerryAttr,
135 ) -> syn::Result<Option<(&'all RootVersionInfo, &'all VersionRange)>> {
136 Ok(match (&self.version_info, &field_attr.version_range) {
137 (Some(info), Some(range)) => {
138 if let Some(until) = range.until {
139 if info.current_version > until && field_attr.extrapolate.is_none() {
140 return Err(Error::new(
141 field_attr.span(),
142 "extrapolate is required if version has upper limit",
143 ));
144 }
145 }
146 Some((info, range))
147 }
148 (None, Some(_)) => {
149 return Err(Error::new(
150 field_attr.span(),
151 "field has version range, but structure does not",
152 ));
153 }
154 (Some(_), None) => {
155 return Err(Error::new(
156 field_attr.span(),
157 "structure has versioning, but field does not",
158 ));
159 }
160 (None, None) => None,
161 })
162 }
163}
164
165#[derive(Copy, Clone)]
166enum FieldOrder {
167 Alphabetical,
168 AsSpecified,
169}
170
171impl Default for FieldOrder {
172 fn default() -> Self {
173 FieldOrder::Alphabetical
174 }
175}
176
177impl FieldOrder {
178 fn do_sort(&self) -> bool {
179 match self {
180 Self::AsSpecified => false,
181 _ => true,
182 }
183 }
184
185 fn cmp(&self, a: &Field, b: &Field) -> Ordering {
186 match self {
187 Self::AsSpecified => Ordering::Equal,
188 Self::Alphabetical => a.ident.cmp(&b.ident),
189 }
190 }
191}
192
193impl Parse for FieldOrder {
194 fn parse(input: ParseStream) -> syn::Result<Self> {
195 let str: LitStr = input.parse()?;
196 Ok(match str.value().to_lowercase().as_str() {
197 "alphabetical" => Self::Alphabetical,
198 "as_specified" => Self::AsSpecified,
199 _ => {
200 return Err(Error::new_spanned(
201 str,
202 "invalid field order - must be either 'alphabetical' or 'as_specified'",
203 ))
204 }
205 })
206 }
207}
208
/// Fallback configured on a versioned field, via `#[serry(default)]` or
/// `#[serry(extrapolate = path)]`.
///
/// NOTE(review): how the fallback is applied lives in the read/write modules — confirm there.
enum Extrapolate {
    /// Set by the `default` key; presumably means `Default::default()` — confirm.
    Default,
    /// Set by `extrapolate = path`; the path is stored exactly as written.
    Function(Path),
}
213
214impl<'a> ToTokens for SerryAttr<'a> {
215 fn to_tokens(&self, tokens: &mut TokenStream) {
216 self.attr.to_tokens(tokens)
217 }
218 fn to_token_stream(&self) -> TokenStream {
219 self.attr.to_token_stream()
220 }
221 fn into_token_stream(self) -> TokenStream
222 where
223 Self: Sized,
224 {
225 self.attr.into_token_stream()
226 }
227}
228
229impl<'a> Default for SerryAttr<'a> {
230 fn default() -> Self {
231 Self {
232 version_info: None,
233 version_range: None,
234 extrapolate: None,
235 discriminant_value: None,
236 discriminant_type: None,
237 field_order: None,
238 attr: None,
239 }
240 }
241}
242
/// Which keys `parse_serry_attr` accepts at a given attribute position.
#[derive(Clone, Copy)]
struct SerryAttrFields {
    version: SerryAttrVersionField,
    extrapolate: bool,
    discriminate_by: bool,
    discriminator: bool,
    field_order: bool,
}

impl Default for SerryAttrFields {
    /// Accepts no keys at all.
    fn default() -> Self {
        SerryAttrFields {
            version: SerryAttrVersionField::None,
            extrapolate: false,
            discriminate_by: false,
            discriminator: false,
            field_order: false,
        }
    }
}

impl SerryAttrFields {
    /// Keys allowed on a struct: `version(...)` and `field_order`.
    pub fn struct_def() -> Self {
        let mut allowed = Self::default();
        allowed.version = SerryAttrVersionField::Init;
        allowed.field_order = true;
        allowed
    }

    /// Keys allowed on a field: a version range plus `extrapolate`/`default`.
    pub fn field() -> Self {
        let mut allowed = Self::default();
        allowed.version = SerryAttrVersionField::Range;
        allowed.extrapolate = true;
        allowed
    }

    /// Keys allowed on an enum: `discriminate_by`/`repr`.
    pub fn enum_def() -> Self {
        let mut allowed = Self::default();
        allowed.discriminate_by = true;
        allowed
    }

    /// Keys allowed on an enum variant: `version(...)` and `discriminant`/`repr`.
    pub fn enum_variant() -> Self {
        let mut allowed = Self::default();
        allowed.version = SerryAttrVersionField::Init;
        allowed.discriminator = true;
        allowed
    }
}

/// How the `version` key is interpreted, if it is accepted at all.
#[derive(Clone, Copy, PartialEq, Eq)]
enum SerryAttrVersionField {
    /// `version` is not accepted.
    None,
    /// `version(since..until as Type)` declares the item's version info.
    Init,
    /// `version = since..until` / `version(since..until)` restricts a field.
    Range,
}
301
/// Parses one `#[serry(...)]` attribute, accepting only the keys enabled in `fields`.
///
/// Every arm's guard requires its target slot to still be `None`. A repeated key
/// (or `extrapolate` together with `default`) therefore falls through to the
/// catch-all arm and is reported as "unexpected attribute". The same happens
/// for keys not enabled in `fields`.
///
/// # Errors
///
/// Returns the first syntax error or unexpected key encountered.
fn parse_serry_attr(attr: &Attribute, fields: SerryAttrFields) -> Result<SerryAttr, Error> {
    let mut version_info = None;
    let mut version_range = None;
    let mut extrapolate = None;
    let mut discriminant_type = None;
    let mut discriminant_value = None;
    let mut field_order = None;
    attr.parse_nested_meta(|meta| match &meta.path {
        // `version` is accepted at most once, in whichever form `fields.version` selects.
        path if fields.version != SerryAttrVersionField::None
            && version_range.is_none()
            && version_info.is_none()
            && path.is_ident("version") =>
        {
            match fields.version {
                // Excluded by the guard above.
                SerryAttrVersionField::None => panic!("Logical impossibility has occurred"),
                // `version(since..until as Type)` on a struct / enum variant.
                SerryAttrVersionField::Init => {
                    let version_meta;
                    parenthesized!(version_meta in meta.input);
                    let value: VersionRange = version_meta.parse()?;

                    // Optional `as Type`; defaults to `u8`.
                    let ty = if version_meta.peek(Token![as]) {
                        version_meta.parse::<Token![as]>()?;
                        version_meta.parse()?
                    } else {
                        parse_quote!(u8)
                    };

                    // Without an upper bound the single version is both minimum and current.
                    let current_version = value.until.unwrap_or(value.since);
                    let minimum_supported_version = value.since;

                    version_info = Some(RootVersionInfo {
                        minimum_supported_version,
                        current_version,
                        version_type: ty,
                    });

                    Ok(())
                }
                // Field range, either `version = a..b` or `version(a..b)`.
                SerryAttrVersionField::Range => {
                    version_range = Some(if meta.input.peek(Token![=]) {
                        let value = meta.value()?;
                        value.parse()?
                    } else {
                        let version_meta;
                        parenthesized!(version_meta in meta.input);
                        version_meta.parse()?
                    });
                    Ok(())
                }
            }
        }
        // `extrapolate = some::path`
        path if fields.extrapolate && extrapolate.is_none() && path.is_ident("extrapolate") => {
            let value = meta.value()?;
            extrapolate = Some(Extrapolate::Function(value.parse()?));
            Ok(())
        }
        // Bare `default` flag; shares the `extrapolate` slot.
        path if fields.extrapolate && extrapolate.is_none() && path.is_ident("default") => {
            extrapolate = Some(Extrapolate::Default);
            Ok(())
        }
        // Enum-level `discriminate_by(Type)` / `repr(Type)`. `repr` is also a variant key
        // below; the `fields` flags decide which arm applies.
        path if fields.discriminate_by
            && discriminant_type.is_none()
            && (path.is_ident("discriminate_by") || path.is_ident("repr")) =>
        {
            let value;
            parenthesized!(value in meta.input);

            let path: TypePath = value.parse()?;
            discriminant_type = Some(path);

            Ok(())
        }
        // Variant-level `discriminant = N` / `repr = N` (integer literal).
        path if fields.discriminator
            && discriminant_value.is_none()
            && (path.is_ident("discriminant") || path.is_ident("repr")) =>
        {
            let value = meta.value()?;

            let type_path = value.parse()?;
            discriminant_value = Some(type_path);

            Ok(())
        }
        // `field_order = "alphabetical" | "as_specified"`
        path if fields.field_order && field_order.is_none() && path.is_ident("field_order") => {
            let value = meta.value()?;
            field_order = Some(value.parse()?);

            Ok(())
        }
        // Unknown, disallowed here, or repeated key.
        other => {
            return Err(meta.error(format_args!(
                "unexpected attribute '{}'",
                other.to_token_stream()
            )));
        }
    })?;
    Ok(SerryAttr {
        version_info,
        version_range,
        extrapolate,
        discriminant_type,
        discriminant_value,
        field_order,
        attr: Some(attr),
    })
}
408
409fn find_and_parse_serry_attr(
410 attrs: &Vec<Attribute>,
411 fields: SerryAttrFields,
412) -> Result<SerryAttr, Error> {
413 let serry_attr: Vec<_> = attrs
414 .iter()
415 .filter(|v| v.path().is_ident("serry"))
416 .collect();
417 if serry_attr.len() > 1 {
418 return Err(Error::new(
423 attrs.first().map_or_else(Span::call_site, Attribute::span),
424 "more than one serry attribute",
425 ));
426 }
427 let serry_attr = serry_attr.into_iter().nth(0);
428 serry_attr
429 .map(|v| parse_serry_attr(v, fields))
430 .unwrap_or(Ok(SerryAttr::default()))
431}
432
433fn find_and_parse_serry_attr_auto<'a>(
434 attrs: &'a Vec<Attribute>,
435 type_data: &'_ Data,
436) -> Result<SerryAttr<'a>, Error> {
437 find_and_parse_serry_attr(
438 attrs,
439 match type_data {
440 Data::Struct(_) => SerryAttrFields::struct_def(),
441 Data::Enum(_) => SerryAttrFields::enum_def(),
442 _ => {
443 return Err(Error::new(
444 Span::call_site(),
445 "cannot derive for types other than structs and enums",
446 ));
447 }
448 },
449 )
450}
451
452fn default_discriminant_type() -> TypePath {
453 parse_quote!(u16)
454}
455
/// An enum variant paired with its parsed `serry` attribute and resolved discriminant,
/// as produced by `enumerate_variants`.
struct AnnotatedVariant<'a> {
    pub variant: &'a Variant,
    pub attr: SerryAttr<'a>,
    /// The explicit `discriminant`/`repr` value, or the smallest unused value assigned
    /// in declaration order.
    pub discriminant: usize,
}
461
462fn enumerate_variants<'a, I>(variants: I) -> Result<Vec<AnnotatedVariant<'a>>, Error>
463where
464 I: Iterator<Item = &'a Variant>,
465{
466 let mut reserved_nums = HashSet::new();
467
468 let mut preprocessed = Vec::new();
469 for variant in variants {
470 let attr = find_and_parse_serry_attr(&variant.attrs, SerryAttrFields::enum_variant())?;
471
472 let discriminant = match &attr.discriminant_value {
473 Some(value) => {
474 let parsed_value: usize = value.base10_parse()?;
475 if reserved_nums.contains(&parsed_value) {
476 return Err(Error::new_spanned(
477 value,
478 "multiple variants can not have the same value",
479 ));
480 }
481 reserved_nums.insert(parsed_value);
482 Some(parsed_value)
483 }
484 None => None,
485 };
486
487 preprocessed.push((variant, attr, discriminant))
488 }
489
490 let mut vec = Vec::new();
491 let mut next = 0usize;
492
493 for (variant, attr, discriminant) in preprocessed {
494 let discriminant = if let Some(discriminant) = discriminant {
495 discriminant
496 } else {
497 let value = loop {
498 if reserved_nums.contains(&next) {
499 next += 1;
500 continue;
501 }
502 break next;
503 };
504 next += 1;
505 value
506 };
507
508 vec.push(AnnotatedVariant {
509 variant,
510 attr,
511 discriminant,
512 })
513 }
514
515 vec.sort_by_key(|v| v.discriminant);
516
517 Ok(vec)
518}
519
520enum FieldName {
521 Ident(Ident),
522 Index(LitInt),
523}
524impl ToTokens for FieldName {
525 fn to_tokens(&self, tokens: &mut TokenStream) {
526 match &self {
527 Self::Ident(ident) => ident.to_tokens(tokens),
528 Self::Index(index) => index.to_tokens(tokens),
529 }
530 }
531}
532impl FieldName {
533 fn output_ident(&self) -> Ident {
534 let name = match &self {
535 Self::Ident(ident) => ident.to_string(),
536 Self::Index(int) => int.to_string(),
537 };
538 Ident::new(
539 ["__field_", name.as_str()].join("").as_str(),
540 Span::call_site(),
541 )
542 }
543}