variants_struct/lib.rs
1//! A derive macro to convert enums into a struct where the variants are members.
//! Effectively, it's like using a `HashMap<MyEnum, MyData>`, but it generates a hard-coded struct instead
3//! of a HashMap to reduce overhead.
4//!
5//! # Basic Example
6//!
7//! Applying the macro to a basic enum (i.e. one without tuple variants or struct variants) like this:
8//!
9//! ```
10//! use variants_struct::VariantsStruct;
11//!
12//! #[derive(VariantsStruct)]
13//! enum Hello {
14//! World,
15//! There
16//! }
17//! ```
18//!
19//! would produce the following code:
20//!
21//! ```
22//! # enum Hello {
23//! # World,
24//! # There
25//! # }
26//! struct HelloStruct<T> {
27//! pub world: T,
28//! pub there: T
29//! }
30//!
31//! impl<T> HelloStruct<T> {
32//! pub fn new(world: T, there: T) -> HelloStruct<T> {
33//! HelloStruct {
34//! world,
35//! there
36//! }
37//! }
38//!
39//! pub fn get_unchecked(&self, var: &Hello) -> &T {
40//! match var {
41//! &Hello::World => &self.world,
42//! &Hello::There => &self.there
43//! }
44//! }
45//!
46//! pub fn get_mut_unchecked(&mut self, var: &Hello) -> &mut T {
47//! match var {
48//! &Hello::World => &mut self.world,
49//! &Hello::There => &mut self.there
50//! }
51//! }
52//!
53//! pub fn get(&self, var: &Hello) -> Option<&T> {
54//! match var {
55//! &Hello::World => Some(&self.world),
56//! &Hello::There => Some(&self.there)
57//! }
58//! }
59//!
60//! pub fn get_mut(&mut self, var: &Hello) -> Option<&mut T> {
61//! match var {
62//! &Hello::World => Some(&mut self.world),
63//! &Hello::There => Some(&mut self.there)
64//! }
65//! }
66//! }
67//! ```
68//!
69//! The members can be accessed either directly (like `hello.world`) or by using the getter methods, like:
70//!
71//! ```
72//! # use variants_struct::VariantsStruct;
73//! # #[derive(VariantsStruct)]
74//! # enum Hello {
75//! # World,
76//! # There
77//! # }
78//! let mut hello = HelloStruct::new(2, 3);
79//! *hello.get_mut_unchecked(&Hello::World) = 5;
80//!
81//! assert_eq!(hello.world, 5);
82//! assert_eq!(hello.world, *hello.get_unchecked(&Hello::World));
83//! ```
84//!
85//! The getters can be particularly useful with the [enum-iterator](https://docs.rs/crate/enum-iterator/) crate. For basic enums,
86//! the checked-getters will always return `Some(...)`, so using `get_unchecked` is recommended, *but this is not the case when the enum contains tuple variants*.
87//!
88//! Keep in mind that the enum variants are renamed from CamelCase to snake_case, to be consistent with Rust's naming conventions.
89//!
90//! # Visibility
91//!
92//! The struct fields are always `pub`, and the struct shares the same visibility as the enum.
93//!
94//! # Customizing the struct
95//!
96//! ## Renaming
97//!
98//! By default, the struct's name is `<OriginalEnumName>Struct`. You can set it to something else with the `struct_name` attribute. For example, this:
99//!
100//! ```
101//! # use variants_struct::VariantsStruct;
102//! #[derive(VariantsStruct)]
103//! #[struct_name = "SomeOtherName"]
104//! enum NotThisName {
105//! Variant
106//! }
107//! ```
108//!
109//! will produce a struct with name `SomeOtherName`.
110//!
111//! You can also rename the individual fields manually with the `field_name` attribute. For example, this:
112//!
113//! ```
114//! # use variants_struct::VariantsStruct;
115//! #[derive(VariantsStruct)]
116//! enum ChangeMyVariantName {
117//! #[field_name = "this_name"] NotThisName
118//! }
119//! ```
120//!
//! will produce the following struct:
//!
//! ```
//! struct ChangeMyVariantNameStruct<T> {
//!     pub this_name: T
//! }
//! ```
128//!
129//! ## Derives
130//!
131//! By default no derives are applied to the generated struct. You can add derive macro invocations with the `struct_derive` attribute. For example, this:
132//!
133//! ```
134//! # use variants_struct::VariantsStruct;
135//! use serde::{Serialize, Deserialize};
136//!
137//! #[derive(VariantsStruct)]
138//! #[struct_derive(Debug, Default, Serialize, Deserialize)]
139//! enum Hello {
140//! World,
141//! There
142//! }
143//! ```
144//!
145//! would produce the following code:
146//!
147//! ```
148//! # use serde::{Serialize, Deserialize};
149//! #[derive(Debug, Default, Serialize, Deserialize)]
150//! struct HelloStruct<T> {
151//! pub world: T,
152//! pub there: T
153//! }
154//!
155//! // impl block omitted
156//! ```
157//!
158//! ## Trait Bounds
159//!
160//! By default the struct's type argument `T` has no trait bounds, but you can add them with the `struct_bounds` attribute. For example, this:
161//!
162//! ```
163//! # use variants_struct::VariantsStruct;
164//! #[derive(VariantsStruct)]
165//! #[struct_bounds(Copy + Clone)]
166//! enum Hello {
167//! World,
168//! There
169//! }
170//! ```
171//!
172//! would produce the following code:
173//!
174//! ```
175//! struct HelloStruct<T: Copy + Clone> {
176//! # go_away: T,
177//! // fields omitted
178//! }
179//!
180//! impl<T: Copy + Clone> HelloStruct<T> {
181//! // methods omitted
182//! }
183//! ```
184//!
185//! ## Arbitrary attributes
186//!
187//! To apply other arbitrary attributes to the struct, use `#[struct_attr(...)]`. For example, if you apply
188//! `serde::Serialize` to the struct, and your bounds already include a trait that requires `T: Serialize`,
189//! serde will give an error. Serde documentation tells you to add `#[serde(bound(serialize = ...))]`,
190//! and you can pass that along with `struct_attr`.
191//!
192//! ```
193//! # use variants_struct::VariantsStruct;
194//! # use serde::Serialize;
195//! trait MyTrait: Serialize {}
196//!
197//! #[derive(VariantsStruct)]
198//! #[struct_derive(Serialize)]
199//! #[struct_bounds(MyTrait)]
200//! #[struct_attr(serde(bound(serialize = "T: MyTrait")))]
201//! enum MyEnum {
202//! MyVariant
203//! }
204//! ```
205//!
206//! ## Combinations
207//!
208//! Note that many derives don't require that the type argument `T` fulfills any trait bounds. For example, applying the `Clone`
209//! derive to the struct only makes the struct cloneable if `T` is cloneable, and still allows un-cloneable types to be used with the struct.
210//!
211//! So if you want the struct to *always* be cloneable, you have to use both the derive and the trait bound:
212//!
213//! ```
214//! # use variants_struct::VariantsStruct;
215//! #[derive(VariantsStruct)]
216//! #[struct_derive(Clone)]
217//! #[struct_bounds(Clone)]
218//! enum MyEnum {
219//! MyVariant
220//! }
221//! ```
222//!
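//! These attributes (along with `struct_name` and `struct_attr`) can be given in any order. `struct_derive` and
//! `struct_attr` may be repeated, in which case their contents accumulate; if `struct_bounds` or `struct_name` is
//! repeated, only the last occurrence takes effect.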
224//!
225//! # Tuple and Struct Variants
226//!
//! Tuple variants are turned into a `HashMap`, where the data stored in the tuple is the key (so the data must
//! implement `Hash` and `Eq`). Unfortunately, variants with more than one field in them are not supported.
229//!
230//! Tuple variants are omitted from the struct's `new` function. For example, this:
231//!
232//! ```
233//! # use variants_struct::VariantsStruct;
234//! #[derive(VariantsStruct)]
235//! enum Hello {
236//! World,
237//! There(i32)
238//! }
239//! ```
240//!
241//! produces the following code:
242//!
243//! ```
244//! # enum Hello {
245//! # World,
246//! # There(i32)
247//! # }
248//! struct HelloStruct<T> {
249//! pub world: T,
250//! pub there: std::collections::HashMap<i32, T>
251//! }
252//!
253//! impl<T> HelloStruct<T> {
//!     pub fn new(world: T) -> HelloStruct<T> {
255//! HelloStruct {
256//! world,
257//! there: std::collections::HashMap::new()
258//! }
259//! }
260//!
261//! pub fn get_unchecked(&self, var: &Hello) -> &T {
262//! match var {
263//! &Hello::World => &self.world,
264//! &Hello::There(key) => self.there.get(&key)
265//! .expect("tuple variant key not found in hashmap")
266//! }
267//! }
268//!
269//! pub fn get_mut_unchecked(&mut self, var: &Hello) -> &mut T {
270//! match var {
271//! &Hello::World => &mut self.world,
272//! &Hello::There(key) => self.there.get_mut(&key)
273//! .expect("tuple variant key not found in hashmap")
274//! }
275//! }
276//!
277//! pub fn get(&self, var: &Hello) -> Option<&T> {
278//! match var {
279//! &Hello::World => Some(&self.world),
280//! &Hello::There(key) => self.there.get(&key)
281//! }
282//! }
283//!
284//! pub fn get_mut(&mut self, var: &Hello) -> Option<&mut T> {
285//! match var {
286//! &Hello::World => Some(&mut self.world),
287//! &Hello::There(key) => self.there.get_mut(&key)
288//! }
289//! }
290//! }
291//! ```
292//!
//! Notice that the `new` function now only takes the `world` argument, and the unchecked getter methods query the
//! hashmap and panic (via `expect`) if the key is not present.
//!
//! The same also works for struct variants that have exactly one field; that field's value is used as the key.
296
297use check_keyword::CheckKeyword;
298use heck::ToSnekCase;
299use proc_macro::TokenStream;
300use proc_macro_error2::{emit_error, proc_macro_error};
301use quote::{format_ident, quote};
302use syn::{Fields, Ident, ItemEnum, parse_macro_input};
303
304/// Stores basic information about variants.
305struct VariantInfo {
306 normal: Ident,
307 snake: Ident,
308 fields: Fields,
309}
310
311/// Derives the variants struct and impl.
312#[proc_macro_error]
313#[proc_macro_derive(
314 VariantsStruct,
315 attributes(struct_bounds, struct_derive, struct_name, field_name, struct_attr)
316)]
317pub fn variants_struct(input: TokenStream) -> TokenStream {
318 let input = parse_macro_input!(input as ItemEnum);
319 let enum_ident = input.ident.clone();
320 let mut struct_ident = format_ident!("{}Struct", input.ident);
321 let visibility = input.vis.clone();
322
323 // read the `struct_bounds`, `struct_derive`, and `struct_name` attributes. (ignore any others)
324 let mut bounds = quote! {};
325 let mut derives = vec![];
326 let mut attrs = vec![];
327 for attr in input.clone().attrs {
328 if attr.path().is_ident("struct_bounds") {
329 let syn::Meta::List(l) = attr.meta else {
330 emit_error!(
331 attr,
332 "struct_bounds must be of the form #[struct_bounds(Bound)]"
333 );
334 return quote! {}.into();
335 };
336 bounds = l.tokens;
337 } else if attr.path().is_ident("struct_derive") {
338 attr.parse_nested_meta(|meta| {
339 derives.push(meta.path);
340 Ok(())
341 })
342 .unwrap();
343 } else if attr.path().is_ident("struct_name") {
344 if let syn::Meta::NameValue(syn::MetaNameValue { value, .. }) = attr.meta {
345 if let syn::Expr::Lit(syn::ExprLit {
346 lit: syn::Lit::Str(lit_str),
347 ..
348 }) = value
349 {
350 struct_ident = format_ident!("{}", lit_str.value());
351 } else {
352 emit_error!(value, "must be a str literal");
353 }
354 }
355 } else if attr.path().is_ident("struct_attr") {
356 let syn::Meta::List(l) = attr.meta else {
357 emit_error!(attr, "struct_attr must be of the form #[struct_attr(attr)]");
358 return quote! {}.into();
359 };
360 attrs.push(l.tokens);
361 }
362 }
363
364 if input.variants.is_empty() {
365 return (quote! {
366 #[derive(#(#derives),*)]
367 #visibility struct #struct_ident;
368 })
369 .into();
370 }
371
372 let vars: Vec<_> = input
373 .clone()
374 .variants
375 .iter()
376 .map(|var| {
377 let mut names = vec![];
378 for attr in &var.attrs {
379 if attr.path().is_ident("field_name") {
380 if let syn::Meta::NameValue(syn::MetaNameValue { value, .. }) = &attr.meta {
381 if let syn::Expr::Lit(syn::ExprLit {
382 lit: syn::Lit::Str(lit_str),
383 ..
384 }) = value
385 {
386 names.push(lit_str.value());
387 } else {
388 emit_error!(value, "must be a str literal");
389 }
390 }
391 }
392 }
393
394 let snake = if names.is_empty() {
395 format_ident!("{}", var.ident.to_string().to_snek_case().into_safe())
396 } else {
397 format_ident!("{}", names.first().unwrap().into_safe())
398 };
399 VariantInfo {
400 normal: var.ident.clone(),
401 snake,
402 fields: var.fields.clone(),
403 }
404 })
405 .collect();
406
407 // generate the fields and impl code
408 let mut field_idents = vec![];
409 let mut field_names = vec![];
410 let mut struct_fields = vec![];
411 let mut get_uncheckeds = vec![];
412 let mut get_mut_uncheckeds = vec![];
413 let mut gets = vec![];
414 let mut get_muts = vec![];
415 let mut new_args = vec![];
416 let mut new_fields = vec![];
417 for VariantInfo {
418 normal,
419 snake,
420 fields,
421 } in &vars
422 {
423 field_idents.push(snake.clone());
424 field_names.push(snake.to_string());
425 match fields {
426 Fields::Unit => {
427 struct_fields.push(quote! { pub #snake: T });
428 gets.push(quote! { &#enum_ident::#normal => Some(&self.#snake) });
429 get_muts.push(quote! { &#enum_ident::#normal => Some(&mut self.#snake) });
430 get_uncheckeds.push(quote! { &#enum_ident::#normal => &self.#snake });
431 get_mut_uncheckeds.push(quote! { &#enum_ident::#normal => &mut self.#snake });
432 new_args.push(quote! {#snake: T});
433 new_fields.push(quote! {#snake});
434 }
435 Fields::Unnamed(syn::FieldsUnnamed { unnamed, .. }) => {
436 if unnamed.len() == 1 {
437 let ty = unnamed.first().unwrap().clone().ty;
438 struct_fields.push(quote! {
439 pub #snake: std::collections::HashMap<#ty, T>
440 });
441 gets.push(quote! {
442 &#enum_ident::#normal(key) => self.#snake.get(&key)
443 });
444 get_muts.push(quote! {
445 &#enum_ident::#normal(key) => self.#snake.get_mut(&key)
446 });
447 get_uncheckeds.push(quote! {
448 &#enum_ident::#normal(key) => self.#snake.get(&key)
449 .expect("tuple variant key not found in hashmap")
450 });
451 get_mut_uncheckeds.push(quote! {
452 &#enum_ident::#normal(key) => self.#snake.get_mut(&key)
453 .expect("tuple variant key not found in hashmap")
454 });
455 new_fields.push(quote! {#snake: std::collections::HashMap::new()});
456 } else {
457 emit_error!(unnamed, "only tuples with one value are allowed");
458 }
459 }
460 Fields::Named(syn::FieldsNamed { named, .. }) => {
461 if named.len() == 1 {
462 let ty = named.first().unwrap().clone().ty;
463 let ident = named.first().unwrap().ident.clone().unwrap();
464 struct_fields.push(quote! {
465 pub #snake: std::collections::HashMap<#ty, T>
466 });
467 gets.push(quote! {
468 &#enum_ident::#normal {#ident} => self.#snake.get(&#ident)
469 });
470 get_muts.push(quote! {
471 &#enum_ident::#normal {#ident} => self.#snake.get_mut(&#ident)
472 });
473 get_uncheckeds.push(quote! {
474 &#enum_ident::#normal {#ident} => self.#snake.get(&#ident)
475 .expect("tuple variant key not found in hashmap")
476 });
477 get_mut_uncheckeds.push(quote! {
478 &#enum_ident::#normal {#ident} => self.#snake.get_mut(&#ident)
479 .expect("tuple variant key not found in hashmap")
480 });
481 new_fields.push(quote! {#snake: std::collections::HashMap::new()});
482 } else {
483 emit_error!(named, "only structs with one field are allowed");
484 }
485 }
486 }
487 }
488
489 // combine it all together
490 (quote! {
491 #[derive(#(#derives),*)]
492 #(#[#attrs])*
493 #visibility struct #struct_ident<T: #bounds> {
494 #(#struct_fields),*
495 }
496
497 impl<T: #bounds> #struct_ident<T> {
498 pub fn new(#(#new_args),*) -> #struct_ident<T> {
499 #struct_ident {
500 #(#new_fields),*
501 }
502 }
503
504 pub fn get_unchecked(&self, var: &#enum_ident) -> &T {
505 match var {
506 #(#get_uncheckeds),*
507 }
508 }
509
510 pub fn get_mut_unchecked(&mut self, var: &#enum_ident) -> &mut T {
511 match var {
512 #(#get_mut_uncheckeds),*
513 }
514 }
515
516 pub fn get(&self, var: &#enum_ident) -> Option<&T> {
517 match var {
518 #(#gets),*
519 }
520 }
521
522 pub fn get_mut(&mut self, var: &#enum_ident) -> Option<&mut T> {
523 match var {
524 #(#get_muts),*
525 }
526 }
527 }
528
529 impl<T: #bounds> std::ops::Index<#enum_ident> for #struct_ident<T> {
530 type Output = T;
531 fn index(&self, var: #enum_ident) -> &T {
532 self.get_unchecked(&var)
533 }
534 }
535
536 impl<T: #bounds> std::ops::IndexMut<#enum_ident> for #struct_ident<T> {
537 fn index_mut(&mut self, var: #enum_ident) -> &mut T {
538 self.get_mut_unchecked(&var)
539 }
540 }
541
542 impl<T: #bounds> std::ops::Index<&#enum_ident> for #struct_ident<T> {
543 type Output = T;
544 fn index(&self, var: &#enum_ident) -> &T {
545 self.get_unchecked(var)
546 }
547 }
548
549 impl<T: #bounds> std::ops::IndexMut<&#enum_ident> for #struct_ident<T> {
550 fn index_mut(&mut self, var: &#enum_ident) -> &mut T {
551 self.get_mut_unchecked(var)
552 }
553 }
554 })
555 .into()
556}