variants_struct/
lib.rs

1//! A derive macro to convert enums into a struct where the variants are members.
//! Effectively, it's like using a `HashMap<MyEnum, MyData>`, but it generates a hard-coded struct instead
3//! of a HashMap to reduce overhead.
4//!
5//! # Basic Example
6//!
7//! Applying the macro to a basic enum (i.e. one without tuple variants or struct variants) like this:
8//!
9//! ```
10//! use variants_struct::VariantsStruct;
11//!
12//! #[derive(VariantsStruct)]
13//! enum Hello {
14//!     World,
15//!     There
16//! }
17//! ```
18//!
19//! would produce the following code:
20//!
21//! ```
22//! # enum Hello {
23//! #     World,
24//! #     There
25//! # }
26//! struct HelloStruct<T> {
27//!     pub world: T,
28//!     pub there: T
29//! }
30//!
31//! impl<T> HelloStruct<T> {
32//!     pub fn new(world: T, there: T) -> HelloStruct<T> {
33//!         HelloStruct {
34//!             world,
35//!             there
36//!         }
37//!     }
38//!
39//!     pub fn get_unchecked(&self, var: &Hello) -> &T {
40//!         match var {
41//!             &Hello::World => &self.world,
42//!             &Hello::There => &self.there
43//!         }
44//!     }
45//!
46//!     pub fn get_mut_unchecked(&mut self, var: &Hello) -> &mut T {
47//!         match var {
48//!             &Hello::World => &mut self.world,
49//!             &Hello::There => &mut self.there
50//!         }
51//!     }
52//!
53//!     pub fn get(&self, var: &Hello) -> Option<&T> {
54//!         match var {
55//!             &Hello::World => Some(&self.world),
56//!             &Hello::There => Some(&self.there)
57//!         }
58//!     }
59//!
60//!     pub fn get_mut(&mut self, var: &Hello) -> Option<&mut T> {
61//!         match var {
62//!             &Hello::World => Some(&mut self.world),
63//!             &Hello::There => Some(&mut self.there)
64//!         }
65//!     }
66//! }
67//! ```
68//!
//! The members can be accessed directly (like `hello.world`), by indexing with the enum (like `hello[Hello::World]`, via the generated `Index`/`IndexMut` impls), or by using the getter methods, like:
70//!
71//! ```
72//! # use variants_struct::VariantsStruct;
73//! # #[derive(VariantsStruct)]
74//! # enum Hello {
75//! #     World,
76//! #     There
77//! # }
78//! let mut hello = HelloStruct::new(2, 3);
79//! *hello.get_mut_unchecked(&Hello::World) = 5;
80//!
81//! assert_eq!(hello.world, 5);
82//! assert_eq!(hello.world, *hello.get_unchecked(&Hello::World));
83//! ```
84//!
85//! The getters can be particularly useful with the [enum-iterator](https://docs.rs/crate/enum-iterator/) crate. For basic enums,
86//! the checked-getters will always return `Some(...)`, so using `get_unchecked` is recommended, *but this is not the case when the enum contains tuple variants*.
87//!
88//! Keep in mind that the enum variants are renamed from CamelCase to snake_case, to be consistent with Rust's naming conventions.
89//!
90//! # Visibility
91//!
92//! The struct fields are always `pub`, and the struct shares the same visibility as the enum.
93//!
94//! # Customizing the struct
95//!
96//! ## Renaming
97//!
98//! By default, the struct's name is `<OriginalEnumName>Struct`. You can set it to something else with the `struct_name` attribute. For example, this:
99//!
100//! ```
101//! # use variants_struct::VariantsStruct;
102//! #[derive(VariantsStruct)]
103//! #[struct_name = "SomeOtherName"]
104//! enum NotThisName {
105//!     Variant
106//! }
107//! ```
108//!
109//! will produce a struct with name `SomeOtherName`.
110//!
111//! You can also rename the individual fields manually with the `field_name` attribute. For example, this:
112//!
113//! ```
114//! # use variants_struct::VariantsStruct;
115//! #[derive(VariantsStruct)]
116//! enum ChangeMyVariantName {
117//!     #[field_name = "this_name"] NotThisName
118//! }
119//! ```
120//!
//! will produce the following struct:
//!
//! ```
//! struct ChangeMyVariantNameStruct<T> {
//!     pub this_name: T
//! }
127//! ```
128//!
129//! ## Derives
130//!
131//! By default no derives are applied to the generated struct. You can add derive macro invocations with the `struct_derive` attribute. For example, this:
132//!
133//! ```
134//! # use variants_struct::VariantsStruct;
135//! use serde::{Serialize, Deserialize};
136//!
137//! #[derive(VariantsStruct)]
138//! #[struct_derive(Debug, Default, Serialize, Deserialize)]
139//! enum Hello {
140//!     World,
141//!     There
142//! }
143//! ```
144//!
145//! would produce the following code:
146//!
147//! ```
148//! # use serde::{Serialize, Deserialize};
149//! #[derive(Debug, Default, Serialize, Deserialize)]
150//! struct HelloStruct<T> {
151//!     pub world: T,
152//!     pub there: T
153//! }
154//!
155//! // impl block omitted
156//! ```
157//!
158//! ## Trait Bounds
159//!
160//! By default the struct's type argument `T` has no trait bounds, but you can add them with the `struct_bounds` attribute. For example, this:
161//!
162//! ```
163//! # use variants_struct::VariantsStruct;
164//! #[derive(VariantsStruct)]
165//! #[struct_bounds(Copy + Clone)]
166//! enum Hello {
167//!     World,
168//!     There
169//! }
170//! ```
171//!
172//! would produce the following code:
173//!
174//! ```
175//! struct HelloStruct<T: Copy + Clone> {
176//!     # go_away: T,
177//!     // fields omitted
178//! }
179//!
180//! impl<T: Copy + Clone> HelloStruct<T> {
181//!     // methods omitted
182//! }
183//! ```
184//!
185//! ## Arbitrary attributes
186//!
187//! To apply other arbitrary attributes to the struct, use `#[struct_attr(...)]`. For example, if you apply
188//! `serde::Serialize` to the struct, and your bounds already include a trait that requires `T: Serialize`,
189//! serde will give an error. Serde documentation tells you to add `#[serde(bound(serialize = ...))]`,
190//! and you can pass that along with `struct_attr`.
191//!
192//! ```
193//! # use variants_struct::VariantsStruct;
194//! # use serde::Serialize;
195//! trait MyTrait: Serialize {}
196//!
197//! #[derive(VariantsStruct)]
198//! #[struct_derive(Serialize)]
199//! #[struct_bounds(MyTrait)]
200//! #[struct_attr(serde(bound(serialize = "T: MyTrait")))]
201//! enum MyEnum {
202//!     MyVariant
203//! }
204//! ```
205//!
206//! ## Combinations
207//!
208//! Note that many derives don't require that the type argument `T` fulfills any trait bounds. For example, applying the `Clone`
209//! derive to the struct only makes the struct cloneable if `T` is cloneable, and still allows un-cloneable types to be used with the struct.
210//!
211//! So if you want the struct to *always* be cloneable, you have to use both the derive and the trait bound:
212//!
213//! ```
214//! # use variants_struct::VariantsStruct;
215//! #[derive(VariantsStruct)]
216//! #[struct_derive(Clone)]
217//! #[struct_bounds(Clone)]
218//! enum MyEnum {
219//!     MyVariant
220//! }
221//! ```
222//!
223//! These two attributes, and the `struct_name` attribute, can be used in any order, or even multiple times (although that wouldn't be very readable).
224//!
225//! # Tuple and Struct Variants
226//!
//! Tuple variants are turned into a `HashMap`, where the data stored in the tuple is the key (so the data must implement `Hash` and `Eq`).
228//! Unfortunately, variants with more than one field in them are not supported.
229//!
230//! Tuple variants are omitted from the struct's `new` function. For example, this:
231//!
232//! ```
233//! # use variants_struct::VariantsStruct;
234//! #[derive(VariantsStruct)]
235//! enum Hello {
236//!     World,
237//!     There(i32)
238//! }
239//! ```
240//!
241//! produces the following code:
242//!
243//! ```
244//! # enum Hello {
245//! #     World,
246//! #     There(i32)
247//! # }
248//! struct HelloStruct<T> {
249//!     pub world: T,
250//!     pub there: std::collections::HashMap<i32, T>
251//! }
252//!
253//! impl<T> HelloStruct<T> {
//!     pub fn new(world: T) -> HelloStruct<T> {
255//!         HelloStruct {
256//!             world,
257//!             there: std::collections::HashMap::new()
258//!         }
259//!     }
260//!
261//!     pub fn get_unchecked(&self, var: &Hello) -> &T {
262//!         match var {
263//!             &Hello::World => &self.world,
264//!             &Hello::There(key) => self.there.get(&key)
265//!                 .expect("tuple variant key not found in hashmap")
266//!         }
267//!     }
268//!
269//!     pub fn get_mut_unchecked(&mut self, var: &Hello) -> &mut T {
270//!         match var {
271//!             &Hello::World => &mut self.world,
272//!             &Hello::There(key) => self.there.get_mut(&key)
273//!                 .expect("tuple variant key not found in hashmap")
274//!         }
275//!     }
276//!
277//!     pub fn get(&self, var: &Hello) -> Option<&T> {
278//!         match var {
279//!             &Hello::World => Some(&self.world),
280//!             &Hello::There(key) => self.there.get(&key)
281//!         }
282//!     }
283//!
284//!     pub fn get_mut(&mut self, var: &Hello) -> Option<&mut T> {
285//!         match var {
286//!             &Hello::World => Some(&mut self.world),
287//!             &Hello::There(key) => self.there.get_mut(&key)
288//!         }
289//!     }
290//! }
291//! ```
292//!
293//! Notice that the `new` function now only takes the `world` argument, and the unchecked getter methods query the hashmap and unwrap the result.
294//!
//! Struct variants with exactly one field are handled the same way, using that field's value as the key.
296
297use check_keyword::CheckKeyword;
298use heck::ToSnekCase;
299use proc_macro::TokenStream;
300use proc_macro_error2::{emit_error, proc_macro_error};
301use quote::{format_ident, quote};
302use syn::{Fields, Ident, ItemEnum, parse_macro_input};
303
304/// Stores basic information about variants.
305struct VariantInfo {
306    normal: Ident,
307    snake: Ident,
308    fields: Fields,
309}
310
311/// Derives the variants struct and impl.
312#[proc_macro_error]
313#[proc_macro_derive(
314    VariantsStruct,
315    attributes(struct_bounds, struct_derive, struct_name, field_name, struct_attr)
316)]
317pub fn variants_struct(input: TokenStream) -> TokenStream {
318    let input = parse_macro_input!(input as ItemEnum);
319    let enum_ident = input.ident.clone();
320    let mut struct_ident = format_ident!("{}Struct", input.ident);
321    let visibility = input.vis.clone();
322
323    // read the `struct_bounds`, `struct_derive`, and `struct_name` attributes. (ignore any others)
324    let mut bounds = quote! {};
325    let mut derives = vec![];
326    let mut attrs = vec![];
327    for attr in input.clone().attrs {
328        if attr.path().is_ident("struct_bounds") {
329            let syn::Meta::List(l) = attr.meta else {
330                emit_error!(
331                    attr,
332                    "struct_bounds must be of the form #[struct_bounds(Bound)]"
333                );
334                return quote! {}.into();
335            };
336            bounds = l.tokens;
337        } else if attr.path().is_ident("struct_derive") {
338            attr.parse_nested_meta(|meta| {
339                derives.push(meta.path);
340                Ok(())
341            })
342            .unwrap();
343        } else if attr.path().is_ident("struct_name") {
344            if let syn::Meta::NameValue(syn::MetaNameValue { value, .. }) = attr.meta {
345                if let syn::Expr::Lit(syn::ExprLit {
346                    lit: syn::Lit::Str(lit_str),
347                    ..
348                }) = value
349                {
350                    struct_ident = format_ident!("{}", lit_str.value());
351                } else {
352                    emit_error!(value, "must be a str literal");
353                }
354            }
355        } else if attr.path().is_ident("struct_attr") {
356            let syn::Meta::List(l) = attr.meta else {
357                emit_error!(attr, "struct_attr must be of the form #[struct_attr(attr)]");
358                return quote! {}.into();
359            };
360            attrs.push(l.tokens);
361        }
362    }
363
364    if input.variants.is_empty() {
365        return (quote! {
366            #[derive(#(#derives),*)]
367            #visibility struct #struct_ident;
368        })
369        .into();
370    }
371
372    let vars: Vec<_> = input
373        .clone()
374        .variants
375        .iter()
376        .map(|var| {
377            let mut names = vec![];
378            for attr in &var.attrs {
379                if attr.path().is_ident("field_name") {
380                    if let syn::Meta::NameValue(syn::MetaNameValue { value, .. }) = &attr.meta {
381                        if let syn::Expr::Lit(syn::ExprLit {
382                            lit: syn::Lit::Str(lit_str),
383                            ..
384                        }) = value
385                        {
386                            names.push(lit_str.value());
387                        } else {
388                            emit_error!(value, "must be a str literal");
389                        }
390                    }
391                }
392            }
393
394            let snake = if names.is_empty() {
395                format_ident!("{}", var.ident.to_string().to_snek_case().into_safe())
396            } else {
397                format_ident!("{}", names.first().unwrap().into_safe())
398            };
399            VariantInfo {
400                normal: var.ident.clone(),
401                snake,
402                fields: var.fields.clone(),
403            }
404        })
405        .collect();
406
407    // generate the fields and impl code
408    let mut field_idents = vec![];
409    let mut field_names = vec![];
410    let mut struct_fields = vec![];
411    let mut get_uncheckeds = vec![];
412    let mut get_mut_uncheckeds = vec![];
413    let mut gets = vec![];
414    let mut get_muts = vec![];
415    let mut new_args = vec![];
416    let mut new_fields = vec![];
417    for VariantInfo {
418        normal,
419        snake,
420        fields,
421    } in &vars
422    {
423        field_idents.push(snake.clone());
424        field_names.push(snake.to_string());
425        match fields {
426            Fields::Unit => {
427                struct_fields.push(quote! { pub #snake: T });
428                gets.push(quote! { &#enum_ident::#normal => Some(&self.#snake) });
429                get_muts.push(quote! { &#enum_ident::#normal => Some(&mut self.#snake) });
430                get_uncheckeds.push(quote! { &#enum_ident::#normal => &self.#snake });
431                get_mut_uncheckeds.push(quote! { &#enum_ident::#normal => &mut self.#snake });
432                new_args.push(quote! {#snake: T});
433                new_fields.push(quote! {#snake});
434            }
435            Fields::Unnamed(syn::FieldsUnnamed { unnamed, .. }) => {
436                if unnamed.len() == 1 {
437                    let ty = unnamed.first().unwrap().clone().ty;
438                    struct_fields.push(quote! {
439                        pub #snake: std::collections::HashMap<#ty, T>
440                    });
441                    gets.push(quote! {
442                        &#enum_ident::#normal(key) => self.#snake.get(&key)
443                    });
444                    get_muts.push(quote! {
445                        &#enum_ident::#normal(key) => self.#snake.get_mut(&key)
446                    });
447                    get_uncheckeds.push(quote! {
448                        &#enum_ident::#normal(key) => self.#snake.get(&key)
449                            .expect("tuple variant key not found in hashmap")
450                    });
451                    get_mut_uncheckeds.push(quote! {
452                        &#enum_ident::#normal(key) => self.#snake.get_mut(&key)
453                            .expect("tuple variant key not found in hashmap")
454                    });
455                    new_fields.push(quote! {#snake: std::collections::HashMap::new()});
456                } else {
457                    emit_error!(unnamed, "only tuples with one value are allowed");
458                }
459            }
460            Fields::Named(syn::FieldsNamed { named, .. }) => {
461                if named.len() == 1 {
462                    let ty = named.first().unwrap().clone().ty;
463                    let ident = named.first().unwrap().ident.clone().unwrap();
464                    struct_fields.push(quote! {
465                        pub #snake: std::collections::HashMap<#ty, T>
466                    });
467                    gets.push(quote! {
468                        &#enum_ident::#normal {#ident}  => self.#snake.get(&#ident)
469                    });
470                    get_muts.push(quote! {
471                        &#enum_ident::#normal {#ident}  => self.#snake.get_mut(&#ident)
472                    });
473                    get_uncheckeds.push(quote! {
474                        &#enum_ident::#normal {#ident} => self.#snake.get(&#ident)
475                            .expect("tuple variant key not found in hashmap")
476                    });
477                    get_mut_uncheckeds.push(quote! {
478                        &#enum_ident::#normal {#ident} => self.#snake.get_mut(&#ident)
479                            .expect("tuple variant key not found in hashmap")
480                    });
481                    new_fields.push(quote! {#snake: std::collections::HashMap::new()});
482                } else {
483                    emit_error!(named, "only structs with one field are allowed");
484                }
485            }
486        }
487    }
488
489    // combine it all together
490    (quote! {
491        #[derive(#(#derives),*)]
492        #(#[#attrs])*
493        #visibility struct #struct_ident<T: #bounds> {
494            #(#struct_fields),*
495        }
496
497        impl<T: #bounds> #struct_ident<T> {
498            pub fn new(#(#new_args),*) -> #struct_ident<T> {
499                #struct_ident {
500                    #(#new_fields),*
501                }
502            }
503
504            pub fn get_unchecked(&self, var: &#enum_ident) -> &T {
505                match var {
506                    #(#get_uncheckeds),*
507                }
508            }
509
510            pub fn get_mut_unchecked(&mut self, var: &#enum_ident) -> &mut T {
511                match var {
512                    #(#get_mut_uncheckeds),*
513                }
514            }
515
516            pub fn get(&self, var: &#enum_ident) -> Option<&T> {
517                match var {
518                    #(#gets),*
519                }
520            }
521
522            pub fn get_mut(&mut self, var: &#enum_ident) -> Option<&mut T> {
523                match var {
524                    #(#get_muts),*
525                }
526            }
527        }
528
529        impl<T: #bounds> std::ops::Index<#enum_ident> for #struct_ident<T> {
530            type Output = T;
531            fn index(&self, var: #enum_ident) -> &T {
532                self.get_unchecked(&var)
533            }
534        }
535
536        impl<T: #bounds> std::ops::IndexMut<#enum_ident> for #struct_ident<T> {
537            fn index_mut(&mut self, var: #enum_ident) -> &mut T {
538                self.get_mut_unchecked(&var)
539            }
540        }
541
542        impl<T: #bounds> std::ops::Index<&#enum_ident> for #struct_ident<T> {
543            type Output = T;
544            fn index(&self, var: &#enum_ident) -> &T {
545                self.get_unchecked(var)
546            }
547        }
548
549        impl<T: #bounds> std::ops::IndexMut<&#enum_ident> for #struct_ident<T> {
550            fn index_mut(&mut self, var: &#enum_ident) -> &mut T {
551                self.get_mut_unchecked(var)
552            }
553        }
554    })
555    .into()
556}