variants_struct/lib.rs
1//! A derive macro to convert enums into a struct where the variants are members.
//! Effectively, it's like using a `HashMap<MyEnum, MyData>`, but it generates a hard-coded struct instead
3//! of a HashMap to reduce overhead.
4//!
5//! # Basic Example
6//!
7//! Applying the macro to a basic enum (i.e. one without tuple variants or struct variants) like this:
8//!
9//! ```
10//! use variants_struct::VariantsStruct;
11//!
12//! #[derive(VariantsStruct)]
13//! enum Hello {
14//! World,
15//! There
16//! }
17//! ```
18//!
19//! would produce the following code:
20//!
21//! ```
22//! # enum Hello {
23//! # World,
24//! # There
25//! # }
26//! struct HelloStruct<T> {
27//! pub world: T,
28//! pub there: T
29//! }
30//!
31//! impl<T> HelloStruct<T> {
32//! pub fn new(world: T, there: T) -> HelloStruct<T> {
33//! HelloStruct {
34//! world,
35//! there
36//! }
37//! }
38//!
39//! pub fn get_unchecked(&self, var: &Hello) -> &T {
40//! match var {
41//! &Hello::World => &self.world,
42//! &Hello::There => &self.there
43//! }
44//! }
45//!
46//! pub fn get_mut_unchecked(&mut self, var: &Hello) -> &mut T {
47//! match var {
48//! &Hello::World => &mut self.world,
49//! &Hello::There => &mut self.there
50//! }
51//! }
52//!
53//! pub fn get(&self, var: &Hello) -> Option<&T> {
54//! match var {
55//! &Hello::World => Some(&self.world),
56//! &Hello::There => Some(&self.there)
57//! }
58//! }
59//!
60//! pub fn get_mut(&mut self, var: &Hello) -> Option<&mut T> {
61//! match var {
62//! &Hello::World => Some(&mut self.world),
63//! &Hello::There => Some(&mut self.there)
64//! }
65//! }
66//! }
67//! ```
68//!
69//! The members can be accessed either directly (like `hello.world`) or by using the getter methods, like:
70//!
71//! ```
72//! # use variants_struct::VariantsStruct;
73//! # #[derive(VariantsStruct)]
74//! # enum Hello {
75//! # World,
76//! # There
77//! # }
78//! let mut hello = HelloStruct::new(2, 3);
79//! *hello.get_mut_unchecked(&Hello::World) = 5;
80//!
81//! assert_eq!(hello.world, 5);
82//! assert_eq!(hello.world, *hello.get_unchecked(&Hello::World));
83//! ```
84//!
85//! The getters can be particularly useful with the [enum-iterator](https://docs.rs/crate/enum-iterator/) crate. For basic enums,
//! the checked getters always return `Some(...)`, so using the `*_unchecked` getters is recommended, *but this is not the case when the enum contains tuple or struct variants*.
87//!
88//! Keep in mind that the enum variants are renamed from CamelCase to snake_case, to be consistent with Rust's naming conventions.
89//!
90//! # Visibility
91//!
92//! The struct fields are always `pub`, and the struct shares the same visibility as the enum.
93//!
94//! # Customizing the struct
95//!
96//! ## Renaming
97//!
98//! By default, the struct's name is `<OriginalEnumName>Struct`. You can set it to something else with the `struct_name` attribute. For example, this:
99//!
100//! ```
101//! # use variants_struct::VariantsStruct;
102//! #[derive(VariantsStruct)]
103//! #[struct_name = "SomeOtherName"]
104//! enum NotThisName {
105//! Variant
106//! }
107//! ```
108//!
109//! will produce a struct with name `SomeOtherName`.
110//!
111//! You can also rename the individual fields manually with the `field_name` attribute. For example, this:
112//!
113//! ```
114//! # use variants_struct::VariantsStruct;
115//! #[derive(VariantsStruct)]
116//! enum ChangeMyVariantName {
117//! #[field_name = "this_name"] NotThisName
118//! }
119//! ```
120//!
//! will produce the following struct:
//!
//! ```
//! struct ChangeMyVariantNameStruct<T> {
//! pub this_name: T
//! }
127//! ```
128//!
129//! ## Derives
130//!
131//! By default no derives are applied to the generated struct. You can add derive macro invocations with the `struct_derive` attribute. For example, this:
132//!
133//! ```
134//! # use variants_struct::VariantsStruct;
135//! use serde::{Serialize, Deserialize};
136//!
137//! #[derive(VariantsStruct)]
138//! #[struct_derive(Debug, Default, Serialize, Deserialize)]
139//! enum Hello {
140//! World,
141//! There
142//! }
143//! ```
144//!
145//! would produce the following code:
146//!
147//! ```
148//! # use serde::{Serialize, Deserialize};
149//! #[derive(Debug, Default, Serialize, Deserialize)]
150//! struct HelloStruct<T> {
151//! pub world: T,
152//! pub there: T
153//! }
154//!
155//! // impl block omitted
156//! ```
157//!
158//! ## Trait Bounds
159//!
160//! By default the struct's type argument `T` has no trait bounds, but you can add them with the `struct_bounds` attribute. For example, this:
161//!
162//! ```
163//! # use variants_struct::VariantsStruct;
164//! #[derive(VariantsStruct)]
165//! #[struct_bounds(Clone)]
166//! enum Hello {
167//! World,
168//! There
169//! }
170//! ```
171//!
172//! would produce the following code:
173//!
174//! ```
175//! struct HelloStruct<T: Clone> {
176//! # go_away: T,
177//! // fields omitted
178//! }
179//!
180//! impl<T: Clone> HelloStruct<T> {
181//! // methods omitted
182//! }
183//! ```
184//!
185//! ## Combinations
186//!
187//! Note that many derives don't require that the type argument `T` fulfills any trait bounds. For example, applying the `Clone`
188//! derive to the struct only makes the struct cloneable if `T` is cloneable, and still allows un-cloneable types to be used with the struct.
189//!
190//! So if you want the struct to *always* be cloneable, you have to use both the derive and the trait bound:
191//!
192//! ```
193//! # use variants_struct::VariantsStruct;
194//! #[derive(VariantsStruct)]
195//! #[struct_derive(Clone)]
196//! #[struct_bounds(Clone)]
197//! enum Hello {
198//! // variants omitted
199//! }
200//! ```
201//!
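//! These two attributes and the `struct_name` attribute can be used in any order, or even multiple times (although that wouldn't be very readable). Repeated `struct_derive` and `struct_bounds` attributes accumulate, while for `struct_name` the last one wins.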
203//!
204//! # Tuple and Struct Variants
205//!
//! Tuple variants are turned into a `HashMap`, where the data stored in the tuple is the key (so the data must implement `Hash` and `Eq`).
207//! Unfortunately, variants with more than one field in them are not supported.
208//!
209//! Tuple variants are omitted from the struct's `new` function. For example, this:
210//!
211//! ```
212//! # use variants_struct::VariantsStruct;
213//! #[derive(VariantsStruct)]
214//! enum Hello {
215//! World,
216//! There(i32)
217//! }
218//! ```
219//!
220//! produces the following code:
221//!
222//! ```
223//! # enum Hello {
224//! # World,
225//! # There(i32)
226//! # }
227//! struct HelloStruct<T> {
228//! pub world: T,
229//! pub there: std::collections::HashMap<i32, T>
230//! }
231//!
232//! impl<T> HelloStruct<T> {
//! pub fn new(world: T) -> HelloStruct<T> {
234//! HelloStruct {
235//! world,
236//! there: std::collections::HashMap::new()
237//! }
238//! }
239//!
240//! pub fn get_unchecked(&self, var: &Hello) -> &T {
241//! match var {
242//! &Hello::World => &self.world,
243//! &Hello::There(key) => self.there.get(&key)
244//! .expect("tuple variant key not found in hashmap")
245//! }
246//! }
247//!
248//! pub fn get_mut_unchecked(&mut self, var: &Hello) -> &mut T {
249//! match var {
250//! &Hello::World => &mut self.world,
251//! &Hello::There(key) => self.there.get_mut(&key)
252//! .expect("tuple variant key not found in hashmap")
253//! }
254//! }
255//!
256//! pub fn get(&self, var: &Hello) -> Option<&T> {
257//! match var {
258//! &Hello::World => Some(&self.world),
259//! &Hello::There(key) => self.there.get(&key)
260//! }
261//! }
262//!
263//! pub fn get_mut(&mut self, var: &Hello) -> Option<&mut T> {
264//! match var {
265//! &Hello::World => Some(&mut self.world),
266//! &Hello::There(key) => self.there.get_mut(&key)
267//! }
268//! }
269//! }
270//! ```
271//!
272//! Notice that the `new` function now only takes the `world` argument, and the unchecked getter methods query the hashmap and unwrap the result.
273//!
274//! The same can also be done in struct variants that have only one field.
275
276use proc_macro::TokenStream;
277use syn::{Ident, parse_macro_input, ItemEnum, Fields};
278use quote::{quote, format_ident};
279use inflector::Inflector;
280use proc_macro_error::{proc_macro_error, emit_error, abort};
281use check_keyword::CheckKeyword;
282
283/// Stores basic information about variants.
284struct VariantInfo {
285 normal: Ident,
286 snake: Ident,
287 fields: Fields
288}
289
290/// Derives the variants struct and impl.
291#[proc_macro_error]
292#[proc_macro_derive(VariantsStruct, attributes(struct_bounds, struct_derive, struct_name, field_name))]
293pub fn variants_struct(input: TokenStream) -> TokenStream {
294 let input = parse_macro_input!(input as ItemEnum);
295 let enum_ident = input.ident.clone();
296 let mut struct_ident = format_ident!("{}Struct", input.ident);
297 let visibility = input.vis.clone();
298
299 // read the `struct_bounds`, `struct_derive`, and `struct_name` attributes. (ignore any others)
300 let mut bounds = vec![];
301 let mut derives = vec![];
302 for attr in input.clone().attrs {
303 match attr.parse_meta() {
304 Ok(syn::Meta::List(syn::MetaList {path, nested, ..})) => {
305 if let Some(ident) = path.get_ident() {
306 let attr_name = ident.to_string();
307 if attr_name == "struct_bounds" || attr_name == "struct_derive" {
308 let mut paths = vec![];
309 for meta in nested {
310 match meta {
311 syn::NestedMeta::Meta(syn::Meta::Path(path)) => {
312 paths.push(path.clone());
313 }
314 _ => emit_error!(path, "only path arguments are accepted")
315 }
316 }
317 if attr_name == "struct_bounds" {
318 bounds.extend(paths);
319 } else {
320 derives.extend(paths);
321 }
322 }
323 }
324 }
325 Ok(syn::Meta::NameValue(syn::MetaNameValue {path, lit, ..})) => {
326 if let Some(ident) = path.get_ident() {
327 let attr_name = ident.to_string();
328 if attr_name == "struct_name" {
329 if let syn::Lit::Str(lit_str) = lit {
330 struct_ident = format_ident!("{}", lit_str.value());
331 } else {
332 emit_error!(lit, "must be a str literal");
333 }
334 }
335 }
336 }
337 _ => {}
338 }
339 }
340
341 if input.variants.len() == 0 {
342 return (quote! {
343 #[derive(#(#derives),*)]
344 #visibility struct #struct_ident;
345 }).into()
346 }
347
348 let vars: Vec<_> = input.clone().variants.iter().map(
349 |var| {
350 let snake = {
351 let names: Vec<_> = var.attrs.iter().filter_map(
352 |attr| {
353 match attr.parse_meta() {
354 Ok(syn::Meta::NameValue(syn::MetaNameValue {path, lit, ..})) => {
355 if let Some(ident) = path.get_ident() {
356 if ident.to_string() == "field_name" {
357 if let syn::Lit::Str(lit_str) = lit {
358 Some(lit_str.value())
359 } else {
360 abort!(lit, "must be a string literal");
361 }
362 } else {
363 None
364 }
365 } else {
366 None
367 }
368 }
369 _ => None
370 }
371 }
372 ).collect();
373 if names.is_empty() {
374 let name = var.ident.to_string().to_snake_case();
375 format_ident!("{}", name.into_safe())
376 } else {
377 format_ident!("{}", names.first().unwrap().to_safe())
378 }
379 };
380 VariantInfo {
381 normal: var.ident.clone(),
382 snake,
383 fields: var.fields.clone()
384 }
385 }
386 ).collect();
387
388 // generate the fields and impl code
389 let mut field_idents = vec![];
390 let mut field_names = vec![];
391 let mut struct_fields = vec![];
392 let mut get_uncheckeds = vec![];
393 let mut get_mut_uncheckeds = vec![];
394 let mut gets = vec![];
395 let mut get_muts = vec![];
396 let mut new_args = vec![];
397 let mut new_fields = vec![];
398 for VariantInfo { normal, snake, fields } in &vars {
399 field_idents.push(snake.clone());
400 field_names.push(snake.to_string());
401 match fields {
402 Fields::Unit => {
403 struct_fields.push(quote! { pub #snake: T });
404 gets.push(quote! { &#enum_ident::#normal => Some(&self.#snake) });
405 get_muts.push(quote! { &#enum_ident::#normal => Some(&mut self.#snake) });
406 get_uncheckeds.push(quote! { &#enum_ident::#normal => &self.#snake });
407 get_mut_uncheckeds.push(quote! { &#enum_ident::#normal => &mut self.#snake });
408 new_args.push(quote! {#snake: T});
409 new_fields.push(quote! {#snake});
410 }
411 Fields::Unnamed(syn::FieldsUnnamed { unnamed, .. }) => {
412 if unnamed.len() == 1 {
413 let ty = unnamed.first().unwrap().clone().ty;
414 struct_fields.push(quote! {
415 pub #snake: std::collections::HashMap<#ty, T>
416 });
417 gets.push(quote! {
418 &#enum_ident::#normal(key) => self.#snake.get(&key)
419 });
420 get_muts.push(quote! {
421 &#enum_ident::#normal(key) => self.#snake.get_mut(&key)
422 });
423 get_uncheckeds.push(quote! {
424 &#enum_ident::#normal(key) => self.#snake.get(&key)
425 .expect("tuple variant key not found in hashmap")
426 });
427 get_mut_uncheckeds.push(quote! {
428 &#enum_ident::#normal(key) => self.#snake.get_mut(&key)
429 .expect("tuple variant key not found in hashmap")
430 });
431 new_fields.push(quote! {#snake: std::collections::HashMap::new()});
432 } else {
433 emit_error!(unnamed, "only tuples with one value are allowed");
434 }
435 }
436 Fields::Named(syn::FieldsNamed { named, .. }) => {
437 if named.len() == 1 {
438 let ty = named.first().unwrap().clone().ty;
439 let ident = named.first().unwrap().ident.clone().unwrap();
440 struct_fields.push(quote! {
441 pub #snake: std::collections::HashMap<#ty, T>
442 });
443 gets.push(quote! {
444 &#enum_ident::#normal {#ident} => self.#snake.get(&#ident)
445 });
446 get_muts.push(quote! {
447 &#enum_ident::#normal {#ident} => self.#snake.get_mut(&#ident)
448 });
449 get_uncheckeds.push(quote! {
450 &#enum_ident::#normal {#ident} => self.#snake.get(&#ident)
451 .expect("tuple variant key not found in hashmap")
452 });
453 get_mut_uncheckeds.push(quote! {
454 &#enum_ident::#normal {#ident} => self.#snake.get_mut(&#ident)
455 .expect("tuple variant key not found in hashmap")
456 });
457 new_fields.push(quote! {#snake: std::collections::HashMap::new()});
458 } else {
459 emit_error!(named, "only structs with one field are allowed");
460 }
461 }
462 }
463 }
464
465 // combine it all together
466 (quote! {
467 #[derive(#(#derives),*)]
468 #visibility struct #struct_ident<T: #(#bounds)+*> {
469 #(#struct_fields),*
470 }
471
472 impl<T: #(#bounds)+*> #struct_ident<T> {
473 pub fn new(#(#new_args),*) -> #struct_ident<T> {
474 #struct_ident {
475 #(#new_fields),*
476 }
477 }
478
479 pub fn get_unchecked(&self, var: &#enum_ident) -> &T {
480 match var {
481 #(#get_uncheckeds),*
482 }
483 }
484
485 pub fn get_mut_unchecked(&mut self, var: &#enum_ident) -> &mut T {
486 match var {
487 #(#get_mut_uncheckeds),*
488 }
489 }
490
491 pub fn get(&self, var: &#enum_ident) -> Option<&T> {
492 match var {
493 #(#gets),*
494 }
495 }
496
497 pub fn get_mut(&mut self, var: &#enum_ident) -> Option<&mut T> {
498 match var {
499 #(#get_muts),*
500 }
501 }
502 }
503 }).into()
504}