variants_data_struct/
lib.rs

1#[doc = include_str!("../README.md")]
2use convert_case::Casing as _;
3
4use convert_case::Case;
5use proc_macro::TokenStream;
6use syn::{Token, punctuated::Punctuated};
7
8struct VariantsDataStructAttrMeta {
9    struct_name: Option<syn::Ident>,
10    attrs: Vec<syn::Attribute>,
11    variants_tys_attrs: Vec<syn::Attribute>,
12}
13
14impl VariantsDataStructAttrMeta {
15    fn from_attrs(attrs: Vec<syn::Attribute>) -> syn::Result<Option<Self>> {
16        let variants_data_struct_attr: syn::Attribute = match attrs
17            .into_iter()
18            .find(|attr| attr.path().is_ident("variants_data_struct"))
19        {
20            Some(attr) => attr,
21            None => return Ok(None),
22        };
23
24        let variants_data_struct_attr_meta = variants_data_struct_attr.parse_args()?;
25        Ok(Some(variants_data_struct_attr_meta))
26    }
27}
28
29impl syn::parse::Parse for VariantsDataStructAttrMeta {
30    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
31        let mut struct_name: Option<syn::Ident> = None;
32        let mut attrs: Vec<syn::Attribute> = vec![];
33        let mut variants_tys_attrs: Vec<syn::Attribute> = vec![];
34
35        while !input.is_empty() {
36            let lookahead = input.lookahead1();
37            if !lookahead.peek(syn::Ident) {
38                return Err(lookahead.error());
39            }
40            let ident: syn::Ident = input.parse()?;
41
42            if ident == "name" {
43                let _: syn::Token![=] = input.parse()?;
44                let name: syn::Ident = input.parse()?;
45                struct_name = Some(name);
46            } else if ident == "attrs" {
47                let content;
48                let _paren_token = syn::parenthesized!(content in input);
49                attrs = content.call(syn::Attribute::parse_outer)?;
50            } else if ident == "variants_tys_attrs" {
51                let content;
52                let _paren_token = syn::parenthesized!(content in input);
53                variants_tys_attrs = content.call(syn::Attribute::parse_outer)?;
54            } else {
55                return Err(syn::Error::new_spanned(
56                    ident,
57                    "unexpected identifier in variants_data_struct attribute",
58                ));
59            }
60            let lookahead = input.lookahead1();
61            if lookahead.peek(syn::Token![,]) {
62                let _: syn::Token![,] = input.parse()?;
63            }
64        }
65
66        Ok(VariantsDataStructAttrMeta {
67            variants_tys_attrs,
68            struct_name,
69            attrs,
70        })
71    }
72}
73
74struct VariantsDataStructDefs {
75    derived_struct: syn::ItemStruct,
76    variant_type_structs: Vec<syn::ItemStruct>,
77}
78
79enum FieldType {
80    Unit(syn::TypeTuple),
81    TupleStructType {
82        def: syn::ItemStruct,
83        ty: syn::TypePath,
84    },
85    ProperStructType {
86        def: syn::ItemStruct,
87        ty: syn::TypePath,
88    },
89}
90
91fn unit_type() -> syn::TypeTuple {
92    let span: proc_macro2::extra::DelimSpan = {
93        let group = proc_macro2::Group::new(
94            proc_macro2::Delimiter::Brace,
95            proc_macro2::TokenStream::new(),
96        );
97        group.delim_span()
98    };
99    syn::TypeTuple {
100        paren_token: syn::token::Paren { span },
101        elems: Punctuated::new(),
102    }
103}
104
105fn field_type(
106    variant: syn::Variant,
107    _enum_generics: &syn::Generics,
108    // The visibility for the generated variant type struct and its fields, if they are going to be defined
109    vis: syn::Visibility,
110    variants_tys_attrs: &[syn::Attribute],
111) -> FieldType {
112    let syn::Variant {
113        attrs: _,
114        ident,
115        fields,
116        discriminant: _,
117    } = variant;
118    match fields {
119        syn::Fields::Unit => FieldType::Unit(unit_type()),
120        syn::Fields::Unnamed(fields_unnamed) => {
121            let syn::FieldsUnnamed {
122                paren_token,
123                mut unnamed,
124            } = fields_unnamed;
125
126            if unnamed.len() == 0 {
127                return FieldType::Unit(unit_type());
128            }
129
130            unnamed.pairs_mut().for_each(|pair| {
131                let field = pair.into_tuple().0;
132                field.vis = vis.clone();
133            });
134
135            let variant_struct_ident =
136                syn::Ident::new(format!("{ident}VariantType").as_str(), ident.span());
137
138            // TODO: handle generics properly
139            let variant_type_generics = syn::Generics::default();
140
141            let item_struct = syn::ItemStruct {
142                attrs: Vec::from(variants_tys_attrs),
143                vis,
144                struct_token: syn::token::Struct { span: ident.span() },
145                ident: variant_struct_ident.clone(),
146                generics: variant_type_generics,
147                fields: syn::Fields::Unnamed(syn::FieldsUnnamed {
148                    paren_token,
149                    unnamed,
150                }),
151                semi_token: None,
152            };
153
154            let type_path = syn::TypePath {
155                qself: None,
156                path: syn::Path::from(variant_struct_ident),
157            };
158
159            FieldType::TupleStructType {
160                def: item_struct,
161                ty: type_path,
162            }
163        }
164        syn::Fields::Named(field_named) => {
165            let syn::FieldsNamed {
166                brace_token,
167                mut named,
168            } = field_named;
169
170            if named.len() == 0 {
171                return FieldType::Unit(unit_type());
172            }
173
174            named.pairs_mut().for_each(|pair| {
175                let field = pair.into_tuple().0;
176                field.vis = vis.clone();
177            });
178
179            let variant_struct_ident =
180                syn::Ident::new(format!("{ident}VariantType").as_str(), ident.span());
181
182            // TODO: handle generics properly
183            let variant_type_generics = syn::Generics::default();
184
185            let item_struct = syn::ItemStruct {
186                attrs: Vec::from(variants_tys_attrs),
187                vis,
188                struct_token: syn::token::Struct { span: ident.span() },
189                ident: variant_struct_ident.clone(),
190                generics: variant_type_generics,
191                fields: syn::Fields::Named(syn::FieldsNamed { brace_token, named }),
192                semi_token: None,
193            };
194
195            let type_path = syn::TypePath {
196                qself: None,
197                path: syn::Path::from(variant_struct_ident),
198            };
199
200            FieldType::ProperStructType {
201                def: item_struct,
202                ty: type_path,
203            }
204        }
205    }
206}
207
208fn variants_data_struct_defs(
209    attrs: Vec<syn::Attribute>,
210    variants_tys_attrs: Vec<syn::Attribute>,
211    vis: syn::Visibility,
212    struct_name: syn::Ident,
213    generics: syn::Generics,
214    variants: Punctuated<syn::Variant, Token![,]>,
215) -> VariantsDataStructDefs {
216    let fields = variants.into_iter().map(|variant| {
217        let field_ident = syn::Ident::new(
218            variant
219                .ident
220                .to_string()
221                .from_case(Case::Pascal)
222                .to_case(Case::Snake)
223                .as_str(),
224            variant.ident.span(),
225        );
226        let field_type: FieldType =
227            field_type(variant, &generics, vis.clone(), &variants_tys_attrs);
228        (field_ident, field_type)
229    });
230
231    let mut variant_type_structs: Vec<syn::ItemStruct> = vec![];
232    let mut struct_fields: Vec<syn::Field> = vec![];
233
234    for (field_ident, field_type) in fields {
235        match field_type {
236            FieldType::Unit(unit_type) => {
237                struct_fields.push(syn::Field {
238                    attrs: vec![],
239                    mutability: syn::FieldMutability::None,
240                    vis: vis.clone(),
241                    ident: Some(field_ident),
242                    colon_token: Some(syn::token::Colon {
243                        spans: [proc_macro2::Span::call_site()],
244                    }),
245                    ty: syn::Type::Tuple(unit_type),
246                });
247            }
248            FieldType::TupleStructType { def, ty } => {
249                variant_type_structs.push(def);
250                struct_fields.push(syn::Field {
251                    attrs: vec![],
252                    mutability: syn::FieldMutability::None,
253                    vis: vis.clone(),
254                    ident: Some(field_ident),
255                    colon_token: Some(syn::token::Colon {
256                        spans: [proc_macro2::Span::call_site()],
257                    }),
258                    ty: syn::Type::Path(ty),
259                });
260            }
261            FieldType::ProperStructType { def, ty } => {
262                variant_type_structs.push(def);
263                struct_fields.push(syn::Field {
264                    attrs: vec![],
265                    mutability: syn::FieldMutability::None,
266                    vis: vis.clone(),
267                    ident: Some(field_ident),
268                    colon_token: Some(syn::token::Colon {
269                        spans: [proc_macro2::Span::call_site()],
270                    }),
271                    ty: syn::Type::Path(ty),
272                });
273            }
274        }
275    }
276
277    let delim_span: proc_macro2::extra::DelimSpan = {
278        let group = proc_macro2::Group::new(
279            proc_macro2::Delimiter::Brace,
280            proc_macro2::TokenStream::new(),
281        );
282        group.delim_span()
283    };
284
285    let derived_struct = syn::ItemStruct {
286        attrs,
287        vis,
288        struct_token: syn::token::Struct {
289            span: struct_name.span(),
290        },
291        ident: struct_name,
292        generics,
293        fields: syn::Fields::Named(syn::FieldsNamed {
294            brace_token: syn::token::Brace { span: delim_span },
295            named: Punctuated::from_iter(struct_fields),
296        }),
297        semi_token: None,
298    };
299
300    VariantsDataStructDefs {
301        derived_struct,
302        variant_type_structs,
303    }
304}
305
306/// Derive macro to generate a data struct containing fields for each variant of the enum.
307///
308/// ```rust
309/// use variants_data_struct::VariantsDataStruct;
310///
311/// #[derive(VariantsDataStruct)]
312/// pub enum MyEnum {
313///     UnitEnum,
314///     TupleEnum(i32, String),
315///     StructEnum { id: u32, name: String },
316/// }
317///
318/// // Equivalent to:
319/// // pub struct MyEnumVariantsData {
320/// //     pub unit_enum: (),
321/// //     pub tuple_enum: TupleEnumVariantType,
322/// //     pub struct_enum: StructEnumVariantType,
323/// // }
324/// //
325/// // pub struct TupleEnumVariantType(pub i32, pub String);
326/// //
327/// // pub struct StructEnumVariantType {
328/// //     pub id: u32,
329/// //     pub name: String,
330/// // }
331/// ```
332///
333/// ## Helper attributes
334///
335/// ### `#[variants_data_struct(<meta>)]` customizes the behavior of the derive macro.
336/// The `<meta>` is a comma-separated list that can contain the following items:
337///
338/// - `name = <CustomName>`: Specifies a custom name for the generated data struct.
339///  If not provided, the default name is `<EnumName>VariantsData`.
340/// - `attrs(#[derive(...)] ...)`: Adds the specified attributes to the generated data struct. Notably, you
341/// can use it to add derives like `Debug`, `Clone` to the generated struct.
342/// - `variants_tys_attrs(#[derive(...)] ...)`: Adds the specified attributes to each of the generated variant type structs.
343/// Notably, you can use it to add derives like `Debug`, `Clone` to the generated variant type structs.
344#[proc_macro_derive(VariantsDataStruct, attributes(variants_data_struct))]
345pub fn derive_variants_data_struct(item: TokenStream) -> TokenStream {
346    let input = syn::parse_macro_input!(item as syn::DeriveInput);
347    let syn::DeriveInput {
348        attrs,
349        vis,
350        ident,
351        generics,
352        data,
353    } = input;
354    let syn::Data::Enum(enum_data) = data else {
355        return syn::Error::new_spanned(
356            ident,
357            concat!(
358                stringify!(VariantsDataStruct),
359                " can only be derived for enums"
360            ),
361        )
362        .to_compile_error()
363        .into();
364    };
365
366    let variants_data_struct_attr_meta: VariantsDataStructAttrMeta =
367        match VariantsDataStructAttrMeta::from_attrs(attrs) {
368            Ok(Some(meta)) => meta,
369            Ok(None) => VariantsDataStructAttrMeta {
370                struct_name: None,
371                variants_tys_attrs: vec![],
372                attrs: vec![],
373            },
374            Err(err) => return err.to_compile_error().into(),
375        };
376
377    let struct_name: syn::Ident = match variants_data_struct_attr_meta.struct_name {
378        Some(name) => name,
379        None => proc_macro2::Ident::new(format!("{}VariantsData", ident).as_str(), ident.span()),
380    };
381
382    let VariantsDataStructDefs {
383        derived_struct,
384        variant_type_structs,
385    } = variants_data_struct_defs(
386        variants_data_struct_attr_meta.attrs,
387        variants_data_struct_attr_meta.variants_tys_attrs,
388        vis,
389        struct_name,
390        generics,
391        enum_data.variants,
392    );
393
394    quote::quote! {
395        #derived_struct
396
397        #(#variant_type_structs)*
398    }
399    .into()
400}