1use std::collections::HashMap;
10
11use proc_macro::TokenStream;
12use proc_macro2::{Ident, Span};
13use quote::{quote, ToTokens};
14use syn;
15
16#[proc_macro_derive(SetEnumFields)]
17pub fn enum_fields_macro_derive(input: TokenStream) -> TokenStream {
18 let ast = syn::parse(input).unwrap();
19 self::impl_for_input(&ast)
20}
21
22fn collect_available_fields<'input>(enum_data: &'input syn::DataEnum) -> HashMap<String, Vec<&'input syn::Field>> {
23 let mut fields = HashMap::new();
24
25 for variant in &enum_data.variants {
26 for field in &variant.fields {
27 if let Some(field_ident) = &field.ident {
28 let ident = field_ident.to_string();
29 fields.entry(ident)
30 .or_insert(Vec::new())
31 .push(field);
32 }
33 }
34 }
35
36 fields
37}
38
39fn impl_for_input(ast: &syn::DeriveInput) -> TokenStream {
40 let fail_message = "`EnumFields` is only applicable to `enum`s";
41 match &ast.data {
42 syn::Data::Enum(data_enum) => impl_for_enum(ast, &data_enum),
43 syn::Data::Union(data_union) => syn::Error::new(data_union.union_token.span, fail_message).to_compile_error().into(),
44 syn::Data::Struct(data_struct) => syn::Error::new(data_struct.struct_token.span, fail_message).to_compile_error().into(),
45 }
46}
47
48fn impl_for_enum(ast: &syn::DeriveInput, enum_data: &syn::DataEnum) -> TokenStream {
49 let name = &ast.ident;
50
51 let fields = collect_available_fields(enum_data);
53
54 let mut data = proc_macro2::TokenStream::new();
55
56 let mut field_idents: Vec<Ident> = vec![];
57
58 for (field_name, fields) in fields {
59 let field_present_everywhere = fields.len() == enum_data.variants.len();
60
61 let generics = &ast.generics;
62 let field_type = &fields[0].ty;
63 let field_name_ident = Ident::new(&field_name, Span::call_site());
64
65 let mut variants = proc_macro2::TokenStream::new();
66 let mut mut_set_variances = proc_macro2::TokenStream::new();
67
68
69 for variant in &enum_data.variants {
70 let name = &variant.ident;
71
72 let variant_field_ident = variant.fields.iter()
73 .find(|variant_field| {
74 if let Some(variant_field_ident) = &variant_field.ident {
75 if variant_field_ident.to_string() == field_name {
76 true
77 } else {
78 false
79 }
80 } else {
81 false
82 }
83 })
84 .map(|field| {
85 field.ident.as_ref().unwrap()
86 });
87
88 match variant_field_ident {
89 Some(variant_field_ident) => {
90 if field_present_everywhere {
91 variants.extend(quote! {
92 Self::#name{ #variant_field_ident, .. } => {
93 std::mem::swap(#variant_field_ident, to_set);
94 }
95 });
96 } else {
97 variants.extend(quote! {
98 Self::#name{ #variant_field_ident, .. } => {
99 std::mem::swap(#variant_field_ident, to_set);
100 }
101 });
102 }
103
104 if field_present_everywhere {
105 mut_set_variances.extend(quote! {
106 Self::#name{ #variant_field_ident, .. } => #variant_field_ident,
107 });
108 } else {
109 mut_set_variances.extend(quote! {
110 Self::#name{ #variant_field_ident, .. } => Some(#variant_field_ident),
111 });
112 }
113
114 }
115 None => {
116 if let Some(first_field) = variant.fields.iter().next() {
118 if first_field.ident.is_some() {
119 mut_set_variances.extend(quote! {
120 Self::#name{ .. } => None,
121 });
122 } else {
123 mut_set_variances.extend(quote! {
124 Self::#name(..) => None,
125 });
126 }
127 } else {
128 mut_set_variances.extend(quote! {
129 Self::#name => None,
130 });
131 }
132 }
133 }
134 }
135
136 let variant_field_ident = fields[0].ident.as_ref();
137 if variant_field_ident.is_some() {
138 let set_value = Ident::new(format!("set_{}", variant_field_ident.as_ref().unwrap().to_string()).as_str(), Span::call_site());
139 data.extend(quote! {
140 impl #generics #name #generics {
141 pub fn #set_value(&mut self, to_set: &mut #field_type) {
142 match self {
144 #variants
145 _ => {}
146 };
147 }
148 }
149 });
150 }
151
152 let ty = if field_present_everywhere {
153 quote! {
154 &mut #field_type
155 }
156 } else {
157 quote! {
158 Option<&mut #field_type>
159 }
160 };
161
162 let field_name_mut = Ident::new(format!("{}_mut", variant_field_ident.unwrap()).as_str(), Span::call_site());
163 data.extend(quote! {
164 impl #generics #name #generics {
165 pub fn #field_name_mut(&mut self) -> #ty {
166 match self {
168 #mut_set_variances
169 }
170 }
171 }
172 });
173
174
175 }
176
177 data.into()
178}