serry_derive/
lib.rs

1use std::cmp::Ordering;
2use std::collections::HashSet;
3
4use proc_macro2::{Ident, Span, TokenStream};
5use quote::ToTokens;
6use syn::parse::{Parse, ParseStream};
7use syn::{
8    parenthesized, parse_macro_input, parse_quote, spanned::Spanned, Attribute, Data, DeriveInput,
9    Error, LitInt, Path, Token, Type, TypePath,
10};
11use syn::{Field, Fields, LitStr, Variant};
12
13#[macro_use]
14extern crate quote;
15
16mod read;
17mod sized;
18mod write;
19
20#[proc_macro_derive(SerryWrite, attributes(serry))]
21pub fn derive_write(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
22    let item = parse_macro_input!(item as DeriveInput);
23    match write::derive_write_impl(item) {
24        Ok(output) => output,
25        Err(e) => e.to_compile_error(),
26    }
27    .into()
28}
29
30#[proc_macro_derive(SerryRead, attributes(serry))]
31pub fn derive_read(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
32    let item = parse_macro_input!(item as DeriveInput);
33    match read::derive_read_impl(item) {
34        Ok(output) => output,
35        Err(e) => e.to_compile_error(),
36    }
37    .into()
38}
39
40#[proc_macro_derive(SerrySized, attributes(serry))]
41pub fn derive_sized(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
42    let item = parse_macro_input!(item as DeriveInput);
43    match sized::derive_sized_impl(item) {
44        Ok(output) => output,
45        Err(e) => e.to_compile_error(),
46    }
47    .into()
48}
49
50type ProcessedFields<'a> = Vec<(FieldName, &'a Field)>;
51fn process_fields(fields: &Fields, field_order: FieldOrder) -> Option<ProcessedFields> {
52    let fields: Vec<_> = match fields {
53        Fields::Unit => return None,
54        Fields::Named(named) => named.named.iter().collect(),
55        Fields::Unnamed(unnamed) => unnamed.unnamed.iter().collect(),
56    };
57
58    let mut vec: Vec<(FieldName, &Field)> = vec![];
59    for (i, field) in fields.into_iter().enumerate() {
60        vec.push((
61            match &field.ident {
62                Some(ident) => FieldName::Ident(ident.clone()),
63                None => FieldName::Index(LitInt::new(i.to_string().as_str(), Span::call_site())),
64            },
65            field,
66        ));
67    }
68
69    if field_order.do_sort() {
70        vec.sort_by(|a, b| field_order.cmp(a.1, b.1));
71    }
72
73    Some(vec)
74}
75
76fn create_pattern_match<'a, I>(iter: I, unnamed: bool) -> TokenStream
77where
78    I: Iterator<Item = &'a FieldName>,
79{
80    if unnamed {
81        let names = iter.map(FieldName::output_ident);
82        quote!((#(#names),*))
83    } else {
84        let names = iter.map(|name| {
85            let output = name.output_ident();
86            quote!(#name: #output)
87        });
88        quote!({ #(#names),* })
89    }
90}
91
92struct RootVersionInfo {
93    minimum_supported_version: usize,
94    current_version: usize,
95    version_type: Type,
96}
97
98#[derive(Default, Copy, Clone, Debug)]
99struct VersionRange {
100    since: usize,
101    until: Option<usize>,
102}
103
104impl Parse for VersionRange {
105    fn parse(input: ParseStream) -> syn::Result<Self> {
106        let since: LitInt = input.parse()?;
107        let since = since.base10_parse()?;
108
109        let until = if input.peek(Token![..]) {
110            let _ = input.parse::<Token![..]>();
111            let until: LitInt = input.parse()?;
112            Some(until.base10_parse()?)
113        } else {
114            None
115        };
116
117        Ok(Self { since, until })
118    }
119}
120
121struct SerryAttr<'a> {
122    version_info: Option<RootVersionInfo>,
123    version_range: Option<VersionRange>,
124    extrapolate: Option<Extrapolate>,
125    discriminant_value: Option<LitInt>,
126    discriminant_type: Option<TypePath>,
127    field_order: Option<FieldOrder>,
128    attr: Option<&'a Attribute>,
129}
130
131impl<'a> SerryAttr<'a> {
132    fn version_with_range_of<'all>(
133        &'all self,
134        field_attr: &'all SerryAttr,
135    ) -> syn::Result<Option<(&'all RootVersionInfo, &'all VersionRange)>> {
136        Ok(match (&self.version_info, &field_attr.version_range) {
137            (Some(info), Some(range)) => {
138                if let Some(until) = range.until {
139                    if info.current_version > until && field_attr.extrapolate.is_none() {
140                        return Err(Error::new(
141                            field_attr.span(),
142                            "extrapolate is required if version has upper limit",
143                        ));
144                    }
145                }
146                Some((info, range))
147            }
148            (None, Some(_)) => {
149                return Err(Error::new(
150                    field_attr.span(),
151                    "field has version range, but structure does not",
152                ));
153            }
154            (Some(_), None) => {
155                return Err(Error::new(
156                    field_attr.span(),
157                    "structure has versioning, but field does not",
158                ));
159            }
160            (None, None) => None,
161        })
162    }
163}
164
165#[derive(Copy, Clone)]
166enum FieldOrder {
167    Alphabetical,
168    AsSpecified,
169}
170
171impl Default for FieldOrder {
172    fn default() -> Self {
173        FieldOrder::Alphabetical
174    }
175}
176
177impl FieldOrder {
178    fn do_sort(&self) -> bool {
179        match self {
180            Self::AsSpecified => false,
181            _ => true,
182        }
183    }
184
185    fn cmp(&self, a: &Field, b: &Field) -> Ordering {
186        match self {
187            Self::AsSpecified => Ordering::Equal,
188            Self::Alphabetical => a.ident.cmp(&b.ident),
189        }
190    }
191}
192
193impl Parse for FieldOrder {
194    fn parse(input: ParseStream) -> syn::Result<Self> {
195        let str: LitStr = input.parse()?;
196        Ok(match str.value().to_lowercase().as_str() {
197            "alphabetical" => Self::Alphabetical,
198            "as_specified" => Self::AsSpecified,
199            _ => {
200                return Err(Error::new_spanned(
201                    str,
202                    "invalid field order - must be either 'alphabetical' or 'as_specified'",
203                ))
204            }
205        })
206    }
207}
208
209enum Extrapolate {
210    Default,
211    Function(Path),
212}
213
214impl<'a> ToTokens for SerryAttr<'a> {
215    fn to_tokens(&self, tokens: &mut TokenStream) {
216        self.attr.to_tokens(tokens)
217    }
218    fn to_token_stream(&self) -> TokenStream {
219        self.attr.to_token_stream()
220    }
221    fn into_token_stream(self) -> TokenStream
222    where
223        Self: Sized,
224    {
225        self.attr.into_token_stream()
226    }
227}
228
229impl<'a> Default for SerryAttr<'a> {
230    fn default() -> Self {
231        Self {
232            version_info: None,
233            version_range: None,
234            extrapolate: None,
235            discriminant_value: None,
236            discriminant_type: None,
237            field_order: None,
238            attr: None,
239        }
240    }
241}
242
/// The set of `#[serry(...)]` keys accepted at a particular position.
#[derive(Copy, Clone, Default)]
struct SerryAttrFields {
    /// Which form of the `version` key is accepted, if any.
    version: SerryAttrVersionField,
    /// `extrapolate = path` / `default`.
    extrapolate: bool,
    /// `discriminate_by(Type)` / `repr(Type)`.
    discriminate_by: bool,
    /// `discriminant = N` / `repr = N`.
    discriminator: bool,
    /// `field_order = "..."`.
    field_order: bool,
}

impl SerryAttrFields {
    /// Keys accepted on a struct: `version(...)` and `field_order`.
    pub fn struct_def() -> Self {
        Self {
            field_order: true,
            version: SerryAttrVersionField::Init,
            ..Default::default()
        }
    }

    /// Keys accepted on a field: a `version` range and `extrapolate` / `default`.
    pub fn field() -> Self {
        Self {
            extrapolate: true,
            version: SerryAttrVersionField::Range,
            ..Default::default()
        }
    }

    /// Keys accepted on an enum: `discriminate_by` / `repr`.
    pub fn enum_def() -> Self {
        Self {
            // TODO: Right now enums don't support versioning in favour of being able to version each variant separately.
            discriminate_by: true,
            ..Default::default()
        }
    }

    /// Keys accepted on an enum variant: `version(...)` and `discriminant` / `repr`.
    pub fn enum_variant() -> Self {
        Self {
            discriminator: true,
            version: SerryAttrVersionField::Init,
            ..Default::default()
        }
    }
}

/// Which form of the `version` key is accepted, if any.
#[derive(Copy, Clone, Eq, PartialEq)]
enum SerryAttrVersionField {
    /// `version` is rejected.
    None,
    /// `version(since[..until] [as Type])`, producing a `RootVersionInfo`.
    Init,
    /// `version = since[..until]` or `version(since[..until])`, producing a `VersionRange`.
    Range,
}

impl Default for SerryAttrVersionField {
    fn default() -> Self {
        SerryAttrVersionField::None
    }
}
301
302fn parse_serry_attr(attr: &Attribute, fields: SerryAttrFields) -> Result<SerryAttr, Error> {
303    let mut version_info = None;
304    let mut version_range = None;
305    let mut extrapolate = None;
306    let mut discriminant_type = None;
307    let mut discriminant_value = None;
308    let mut field_order = None;
309    attr.parse_nested_meta(|meta| match &meta.path {
310        path if fields.version != SerryAttrVersionField::None
311            && version_range.is_none()
312            && version_info.is_none()
313            && path.is_ident("version") =>
314        {
315            match fields.version {
316                SerryAttrVersionField::None => panic!("Logical impossibility has occurred"),
317                SerryAttrVersionField::Init => {
318                    let version_meta;
319                    parenthesized!(version_meta in meta.input);
320                    let value: VersionRange = version_meta.parse()?;
321
322                    let ty = if version_meta.peek(Token![as]) {
323                        version_meta.parse::<Token![as]>()?;
324                        version_meta.parse()?
325                    } else {
326                        parse_quote!(u8)
327                    };
328
329                    let current_version = value.until.unwrap_or(value.since);
330                    let minimum_supported_version = value.since;
331
332                    version_info = Some(RootVersionInfo {
333                        minimum_supported_version,
334                        current_version,
335                        version_type: ty,
336                    });
337
338                    Ok(())
339                }
340                SerryAttrVersionField::Range => {
341                    version_range = Some(if meta.input.peek(Token![=]) {
342                        let value = meta.value()?;
343                        value.parse()?
344                    } else {
345                        let version_meta;
346                        parenthesized!(version_meta in meta.input);
347                        version_meta.parse()?
348                    });
349                    Ok(())
350                }
351            }
352        }
353        path if fields.extrapolate && extrapolate.is_none() && path.is_ident("extrapolate") => {
354            let value = meta.value()?;
355            extrapolate = Some(Extrapolate::Function(value.parse()?));
356            Ok(())
357        }
358        path if fields.extrapolate && extrapolate.is_none() && path.is_ident("default") => {
359            extrapolate = Some(Extrapolate::Default);
360            Ok(())
361        }
362        path if fields.discriminate_by
363            && discriminant_type.is_none()
364            && (path.is_ident("discriminate_by") || path.is_ident("repr")) =>
365        {
366            let value;
367            parenthesized!(value in meta.input);
368
369            let path: TypePath = value.parse()?;
370            discriminant_type = Some(path);
371
372            Ok(())
373        }
374        path if fields.discriminator
375            && discriminant_value.is_none()
376            && (path.is_ident("discriminant") || path.is_ident("repr")) =>
377        {
378            let value = meta.value()?;
379
380            let type_path = value.parse()?;
381            discriminant_value = Some(type_path);
382
383            Ok(())
384        }
385        path if fields.field_order && field_order.is_none() && path.is_ident("field_order") => {
386            let value = meta.value()?;
387            field_order = Some(value.parse()?);
388
389            Ok(())
390        }
391        other => {
392            return Err(meta.error(format_args!(
393                "unexpected attribute '{}'",
394                other.to_token_stream()
395            )));
396        }
397    })?;
398    Ok(SerryAttr {
399        version_info,
400        version_range,
401        extrapolate,
402        discriminant_type,
403        discriminant_value,
404        field_order,
405        attr: Some(attr),
406    })
407}
408
409fn find_and_parse_serry_attr(
410    attrs: &Vec<Attribute>,
411    fields: SerryAttrFields,
412) -> Result<SerryAttr, Error> {
413    let serry_attr: Vec<_> = attrs
414        .iter()
415        .filter(|v| v.path().is_ident("serry"))
416        .collect();
417    if serry_attr.len() > 1 {
418        /*for i in 1..serry_attr.len() {
419            let attr = serry_attr[i];
420            errors.extend(quote_spanned!(attr.span() => compile_error!("Only one Serry attribute per item")));
421        }*/
422        return Err(Error::new(
423            attrs.first().map_or_else(Span::call_site, Attribute::span),
424            "more than one serry attribute",
425        ));
426    }
427    let serry_attr = serry_attr.into_iter().nth(0);
428    serry_attr
429        .map(|v| parse_serry_attr(v, fields))
430        .unwrap_or(Ok(SerryAttr::default()))
431}
432
433fn find_and_parse_serry_attr_auto<'a>(
434    attrs: &'a Vec<Attribute>,
435    type_data: &'_ Data,
436) -> Result<SerryAttr<'a>, Error> {
437    find_and_parse_serry_attr(
438        attrs,
439        match type_data {
440            Data::Struct(_) => SerryAttrFields::struct_def(),
441            Data::Enum(_) => SerryAttrFields::enum_def(),
442            _ => {
443                return Err(Error::new(
444                    Span::call_site(),
445                    "cannot derive for types other than structs and enums",
446                ));
447            }
448        },
449    )
450}
451
452fn default_discriminant_type() -> TypePath {
453    parse_quote!(u16)
454}
455
456struct AnnotatedVariant<'a> {
457    pub variant: &'a Variant,
458    pub attr: SerryAttr<'a>,
459    pub discriminant: usize,
460}
461
462fn enumerate_variants<'a, I>(variants: I) -> Result<Vec<AnnotatedVariant<'a>>, Error>
463where
464    I: Iterator<Item = &'a Variant>,
465{
466    let mut reserved_nums = HashSet::new();
467
468    let mut preprocessed = Vec::new();
469    for variant in variants {
470        let attr = find_and_parse_serry_attr(&variant.attrs, SerryAttrFields::enum_variant())?;
471
472        let discriminant = match &attr.discriminant_value {
473            Some(value) => {
474                let parsed_value: usize = value.base10_parse()?;
475                if reserved_nums.contains(&parsed_value) {
476                    return Err(Error::new_spanned(
477                        value,
478                        "multiple variants can not have the same value",
479                    ));
480                }
481                reserved_nums.insert(parsed_value);
482                Some(parsed_value)
483            }
484            None => None,
485        };
486
487        preprocessed.push((variant, attr, discriminant))
488    }
489
490    let mut vec = Vec::new();
491    let mut next = 0usize;
492
493    for (variant, attr, discriminant) in preprocessed {
494        let discriminant = if let Some(discriminant) = discriminant {
495            discriminant
496        } else {
497            let value = loop {
498                if reserved_nums.contains(&next) {
499                    next += 1;
500                    continue;
501                }
502                break next;
503            };
504            next += 1;
505            value
506        };
507
508        vec.push(AnnotatedVariant {
509            variant,
510            attr,
511            discriminant,
512        })
513    }
514
515    vec.sort_by_key(|v| v.discriminant);
516
517    Ok(vec)
518}
519
520enum FieldName {
521    Ident(Ident),
522    Index(LitInt),
523}
524impl ToTokens for FieldName {
525    fn to_tokens(&self, tokens: &mut TokenStream) {
526        match &self {
527            Self::Ident(ident) => ident.to_tokens(tokens),
528            Self::Index(index) => index.to_tokens(tokens),
529        }
530    }
531}
532impl FieldName {
533    fn output_ident(&self) -> Ident {
534        let name = match &self {
535            Self::Ident(ident) => ident.to_string(),
536            Self::Index(int) => int.to_string(),
537        };
538        Ident::new(
539            ["__field_", name.as_str()].join("").as_str(),
540            Span::call_site(),
541        )
542    }
543}