variants_struct/
lib.rs

1//! A derive macro to convert enums into a struct where the variants are members.
//! Effectively, it's like using a `HashMap<MyEnum, MyData>`, but it generates a hard-coded struct instead
//! of a `HashMap` to reduce overhead.
4//!
5//! # Basic Example
6//!
7//! Applying the macro to a basic enum (i.e. one without tuple variants or struct variants) like this:
8//!
9//! ```
10//! use variants_struct::VariantsStruct;
11//!
12//! #[derive(VariantsStruct)]
13//! enum Hello {
14//!     World,
15//!     There
16//! }
17//! ```
18//!
19//! would produce the following code:
20//!
21//! ```
22//! # enum Hello {
23//! #     World,
24//! #     There
25//! # }
26//! struct HelloStruct<T> {
27//!     pub world: T,
28//!     pub there: T
29//! }
30//!
31//! impl<T> HelloStruct<T> {
32//!     pub fn new(world: T, there: T) -> HelloStruct<T> {
33//!         HelloStruct {
34//!             world,
35//!             there
36//!         }
37//!     }
38//!
39//!     pub fn get_unchecked(&self, var: &Hello) -> &T {
40//!         match var {
41//!             &Hello::World => &self.world,
42//!             &Hello::There => &self.there
43//!         }
44//!     }
45//!
46//!     pub fn get_mut_unchecked(&mut self, var: &Hello) -> &mut T {
47//!         match var {
48//!             &Hello::World => &mut self.world,
49//!             &Hello::There => &mut self.there
50//!         }
51//!     }
52//!
53//!     pub fn get(&self, var: &Hello) -> Option<&T> {
54//!         match var {
55//!             &Hello::World => Some(&self.world),
56//!             &Hello::There => Some(&self.there)
57//!         }
58//!     }
59//!
60//!     pub fn get_mut(&mut self, var: &Hello) -> Option<&mut T> {
61//!         match var {
62//!             &Hello::World => Some(&mut self.world),
63//!             &Hello::There => Some(&mut self.there)
64//!         }
65//!     }
66//! }
67//! ```
68//!
69//! The members can be accessed either directly (like `hello.world`) or by using the getter methods, like:
70//!
71//! ```
72//! # use variants_struct::VariantsStruct;
73//! # #[derive(VariantsStruct)]
74//! # enum Hello {
75//! #     World,
76//! #     There
77//! # }
78//! let mut hello = HelloStruct::new(2, 3);
79//! *hello.get_mut_unchecked(&Hello::World) = 5;
80//!
81//! assert_eq!(hello.world, 5);
82//! assert_eq!(hello.world, *hello.get_unchecked(&Hello::World));
83//! ```
84//!
85//! The getters can be particularly useful with the [enum-iterator](https://docs.rs/crate/enum-iterator/) crate. For basic enums,
//! the checked getters will always return `Some(...)`, so using `get_unchecked` is recommended, *but this is not the case when the enum contains tuple or struct variants*.
87//!
88//! Keep in mind that the enum variants are renamed from CamelCase to snake_case, to be consistent with Rust's naming conventions.
89//!
90//! # Visibility
91//!
92//! The struct fields are always `pub`, and the struct shares the same visibility as the enum.
93//!
94//! # Customizing the struct
95//!
96//! ## Renaming
97//!
98//! By default, the struct's name is `<OriginalEnumName>Struct`. You can set it to something else with the `struct_name` attribute. For example, this:
99//!
100//! ```
101//! # use variants_struct::VariantsStruct;
102//! #[derive(VariantsStruct)]
103//! #[struct_name = "SomeOtherName"]
104//! enum NotThisName {
105//!     Variant
106//! }
107//! ```
108//!
109//! will produce a struct with name `SomeOtherName`.
110//!
111//! You can also rename the individual fields manually with the `field_name` attribute. For example, this:
112//!
113//! ```
114//! # use variants_struct::VariantsStruct;
115//! #[derive(VariantsStruct)]
116//! enum ChangeMyVariantName {
117//!     #[field_name = "this_name"] NotThisName
118//! }
119//! ```
120//!
//! will produce a struct whose field is named `this_name` instead of `not_this_name`:
//!
//! ```
//! struct ChangeMyVariantNameStruct<T> {
//!     pub this_name: T
//! }
127//! ```
128//!
129//! ## Derives
130//!
131//! By default no derives are applied to the generated struct. You can add derive macro invocations with the `struct_derive` attribute. For example, this:
132//!
133//! ```
134//! # use variants_struct::VariantsStruct;
135//! use serde::{Serialize, Deserialize};
136//!
137//! #[derive(VariantsStruct)]
138//! #[struct_derive(Debug, Default, Serialize, Deserialize)]
139//! enum Hello {
140//!     World,
141//!     There
142//! }
143//! ```
144//!
145//! would produce the following code:
146//!
147//! ```
148//! # use serde::{Serialize, Deserialize};
149//! #[derive(Debug, Default, Serialize, Deserialize)]
150//! struct HelloStruct<T> {
151//!     pub world: T,
152//!     pub there: T
153//! }
154//!
155//! // impl block omitted
156//! ```
157//!
158//! ## Trait Bounds
159//!
160//! By default the struct's type argument `T` has no trait bounds, but you can add them with the `struct_bounds` attribute. For example, this:
161//!
162//! ```
163//! # use variants_struct::VariantsStruct;
164//! #[derive(VariantsStruct)]
165//! #[struct_bounds(Clone)]
166//! enum Hello {
167//!     World,
168//!     There
169//! }
170//! ```
171//!
172//! would produce the following code:
173//!
174//! ```
175//! struct HelloStruct<T: Clone> {
176//!     # go_away: T,
177//!     // fields omitted
178//! }
179//!
180//! impl<T: Clone> HelloStruct<T> {
181//!     // methods omitted
182//! }
183//! ```
184//!
185//! ## Combinations
186//!
187//! Note that many derives don't require that the type argument `T` fulfills any trait bounds. For example, applying the `Clone`
188//! derive to the struct only makes the struct cloneable if `T` is cloneable, and still allows un-cloneable types to be used with the struct.
189//!
190//! So if you want the struct to *always* be cloneable, you have to use both the derive and the trait bound:
191//!
192//! ```
193//! # use variants_struct::VariantsStruct;
194//! #[derive(VariantsStruct)]
195//! #[struct_derive(Clone)]
196//! #[struct_bounds(Clone)]
197//! enum Hello {
198//!     // variants omitted
199//! }
200//! ```
201//!
202//! These two attributes, and the `struct_name` attribute, can be used in any order, or even multiple times (although that wouldn't be very readable).
203//!
204//! # Tuple and Struct Variants
205//!
//! Tuple variants are turned into a `HashMap`, where the data stored in the tuple is the key (so the data must implement `Hash` and `Eq`).
207//! Unfortunately, variants with more than one field in them are not supported.
208//!
209//! Tuple variants are omitted from the struct's `new` function. For example, this:
210//!
211//! ```
212//! # use variants_struct::VariantsStruct;
213//! #[derive(VariantsStruct)]
214//! enum Hello {
215//!     World,
216//!     There(i32)
217//! }
218//! ```
219//!
220//! produces the following code:
221//!
222//! ```
223//! # enum Hello {
224//! #     World,
225//! #     There(i32)
226//! # }
227//! struct HelloStruct<T> {
228//!     pub world: T,
229//!     pub there: std::collections::HashMap<i32, T>
230//! }
231//!
232//! impl<T> HelloStruct<T> {
//!     pub fn new(world: T) -> HelloStruct<T> {
234//!         HelloStruct {
235//!             world,
236//!             there: std::collections::HashMap::new()
237//!         }
238//!     }
239//!
240//!     pub fn get_unchecked(&self, var: &Hello) -> &T {
241//!         match var {
242//!             &Hello::World => &self.world,
243//!             &Hello::There(key) => self.there.get(&key)
244//!                 .expect("tuple variant key not found in hashmap")
245//!         }
246//!     }
247//!
248//!     pub fn get_mut_unchecked(&mut self, var: &Hello) -> &mut T {
249//!         match var {
250//!             &Hello::World => &mut self.world,
251//!             &Hello::There(key) => self.there.get_mut(&key)
252//!                 .expect("tuple variant key not found in hashmap")
253//!         }
254//!     }
255//!
256//!     pub fn get(&self, var: &Hello) -> Option<&T> {
257//!         match var {
258//!             &Hello::World => Some(&self.world),
259//!             &Hello::There(key) => self.there.get(&key)
260//!         }
261//!     }
262//!
263//!     pub fn get_mut(&mut self, var: &Hello) -> Option<&mut T> {
264//!         match var {
265//!             &Hello::World => Some(&mut self.world),
266//!             &Hello::There(key) => self.there.get_mut(&key)
267//!         }
268//!     }
269//! }
270//! ```
271//!
272//! Notice that the `new` function now only takes the `world` argument, and the unchecked getter methods query the hashmap and unwrap the result.
273//!
//! The same applies to struct variants with exactly one field: that field's type becomes the `HashMap` key.
275
276use proc_macro::TokenStream;
277use syn::{Ident, parse_macro_input, ItemEnum, Fields};
278use quote::{quote, format_ident};
279use inflector::Inflector;
280use proc_macro_error::{proc_macro_error, emit_error, abort};
281use check_keyword::CheckKeyword;
282
283/// Stores basic information about variants.
284struct VariantInfo {
285    normal: Ident,
286    snake: Ident,
287    fields: Fields
288}
289
290/// Derives the variants struct and impl.
291#[proc_macro_error]
292#[proc_macro_derive(VariantsStruct, attributes(struct_bounds, struct_derive, struct_name, field_name))]
293pub fn variants_struct(input: TokenStream) -> TokenStream {
294    let input = parse_macro_input!(input as ItemEnum);
295    let enum_ident = input.ident.clone();
296    let mut struct_ident = format_ident!("{}Struct", input.ident);
297    let visibility = input.vis.clone();
298
299    // read the `struct_bounds`, `struct_derive`, and `struct_name` attributes. (ignore any others)
300    let mut bounds = vec![];
301    let mut derives = vec![];
302    for attr in input.clone().attrs {
303        match attr.parse_meta() {
304            Ok(syn::Meta::List(syn::MetaList {path, nested, ..})) => {
305                if let Some(ident) = path.get_ident() {
306                    let attr_name = ident.to_string();
307                    if attr_name == "struct_bounds" || attr_name == "struct_derive" {
308                        let mut paths = vec![];
309                        for meta in nested {
310                            match meta {
311                                syn::NestedMeta::Meta(syn::Meta::Path(path)) => {
312                                    paths.push(path.clone());
313                                }
314                                _ => emit_error!(path, "only path arguments are accepted")
315                            }
316                        }
317                        if attr_name == "struct_bounds" {
318                            bounds.extend(paths);
319                        } else {
320                            derives.extend(paths);
321                        }
322                    }
323                }
324            }
325            Ok(syn::Meta::NameValue(syn::MetaNameValue {path, lit, ..})) => {
326                if let Some(ident) = path.get_ident() {
327                    let attr_name = ident.to_string();
328                    if attr_name == "struct_name" {
329                        if let syn::Lit::Str(lit_str) = lit {
330                            struct_ident = format_ident!("{}", lit_str.value());
331                        } else {
332                            emit_error!(lit, "must be a str literal");
333                        }
334                    }
335                }
336            }
337            _ => {}
338        }
339    }
340
341    if input.variants.len() == 0 {
342        return (quote! {
343            #[derive(#(#derives),*)]
344            #visibility struct #struct_ident;
345        }).into()
346    }
347
348    let vars: Vec<_> = input.clone().variants.iter().map(
349        |var| {
350            let snake = {
351                let names: Vec<_> = var.attrs.iter().filter_map(
352                    |attr| {
353                        match attr.parse_meta() {
354                            Ok(syn::Meta::NameValue(syn::MetaNameValue {path, lit, ..})) => {
355                                if let Some(ident) = path.get_ident() {
356                                    if ident.to_string() == "field_name" {
357                                        if let syn::Lit::Str(lit_str) = lit {
358                                            Some(lit_str.value())
359                                        } else {
360                                            abort!(lit, "must be a string literal");
361                                        }
362                                    } else {
363                                        None
364                                    }
365                                } else {
366                                    None
367                                }
368                            }
369                            _ => None
370                        }
371                    }
372                ).collect();
373                if names.is_empty() {
374                    let name = var.ident.to_string().to_snake_case();
375                    format_ident!("{}", name.into_safe())
376                } else {
377                    format_ident!("{}", names.first().unwrap().to_safe())
378                }
379            };
380            VariantInfo {
381                normal: var.ident.clone(),
382                snake,
383                fields: var.fields.clone()
384            }
385        }
386    ).collect();
387
388    // generate the fields and impl code
389    let mut field_idents = vec![];
390    let mut field_names = vec![];
391    let mut struct_fields = vec![];
392    let mut get_uncheckeds = vec![];
393    let mut get_mut_uncheckeds = vec![];
394    let mut gets = vec![];
395    let mut get_muts = vec![];
396    let mut new_args = vec![];
397    let mut new_fields = vec![];
398    for VariantInfo { normal, snake, fields } in &vars {
399        field_idents.push(snake.clone());
400        field_names.push(snake.to_string());
401        match fields {
402            Fields::Unit => {
403                struct_fields.push(quote! { pub #snake: T });
404                gets.push(quote! { &#enum_ident::#normal => Some(&self.#snake) });
405                get_muts.push(quote! { &#enum_ident::#normal => Some(&mut self.#snake) });
406                get_uncheckeds.push(quote! { &#enum_ident::#normal => &self.#snake });
407                get_mut_uncheckeds.push(quote! { &#enum_ident::#normal => &mut self.#snake });
408                new_args.push(quote! {#snake: T});
409                new_fields.push(quote! {#snake});
410            }
411            Fields::Unnamed(syn::FieldsUnnamed { unnamed, .. }) => {
412                if unnamed.len() == 1 {
413                    let ty = unnamed.first().unwrap().clone().ty;
414                    struct_fields.push(quote! {
415                        pub #snake: std::collections::HashMap<#ty, T>
416                    });
417                    gets.push(quote! {
418                        &#enum_ident::#normal(key) => self.#snake.get(&key)
419                    });
420                    get_muts.push(quote! {
421                        &#enum_ident::#normal(key) => self.#snake.get_mut(&key)
422                    });
423                    get_uncheckeds.push(quote! {
424                        &#enum_ident::#normal(key) => self.#snake.get(&key)
425                            .expect("tuple variant key not found in hashmap")
426                    });
427                    get_mut_uncheckeds.push(quote! {
428                        &#enum_ident::#normal(key) => self.#snake.get_mut(&key)
429                            .expect("tuple variant key not found in hashmap")
430                    });
431                    new_fields.push(quote! {#snake: std::collections::HashMap::new()});
432                } else {
433                    emit_error!(unnamed, "only tuples with one value are allowed");
434                }
435            }
436            Fields::Named(syn::FieldsNamed { named, .. }) => {
437                if named.len() == 1 {
438                    let ty = named.first().unwrap().clone().ty;
439                    let ident = named.first().unwrap().ident.clone().unwrap();
440                    struct_fields.push(quote! {
441                        pub #snake: std::collections::HashMap<#ty, T>
442                    });
443                    gets.push(quote! {
444                        &#enum_ident::#normal {#ident}  => self.#snake.get(&#ident)
445                    });
446                    get_muts.push(quote! {
447                        &#enum_ident::#normal {#ident}  => self.#snake.get_mut(&#ident)
448                    });
449                    get_uncheckeds.push(quote! {
450                        &#enum_ident::#normal {#ident} => self.#snake.get(&#ident)
451                            .expect("tuple variant key not found in hashmap")
452                    });
453                    get_mut_uncheckeds.push(quote! {
454                        &#enum_ident::#normal {#ident} => self.#snake.get_mut(&#ident)
455                            .expect("tuple variant key not found in hashmap")
456                    });
457                    new_fields.push(quote! {#snake: std::collections::HashMap::new()});
458                } else {
459                    emit_error!(named, "only structs with one field are allowed");
460                }
461            }
462        }
463    }
464
465    // combine it all together
466    (quote! {
467        #[derive(#(#derives),*)]
468        #visibility struct #struct_ident<T: #(#bounds)+*> {
469            #(#struct_fields),*
470        }
471
472        impl<T: #(#bounds)+*> #struct_ident<T> {
473            pub fn new(#(#new_args),*) -> #struct_ident<T> {
474                #struct_ident {
475                    #(#new_fields),*
476                }
477            }
478
479            pub fn get_unchecked(&self, var: &#enum_ident) -> &T {
480                match var {
481                    #(#get_uncheckeds),*
482                }
483            }
484
485            pub fn get_mut_unchecked(&mut self, var: &#enum_ident) -> &mut T {
486                match var {
487                    #(#get_mut_uncheckeds),*
488                }
489            }
490
491            pub fn get(&self, var: &#enum_ident) -> Option<&T> {
492                match var {
493                    #(#gets),*
494                }
495            }
496
497            pub fn get_mut(&mut self, var: &#enum_ident) -> Option<&mut T> {
498                match var {
499                    #(#get_muts),*
500                }
501            }
502        }
503    }).into()
504}