set_enum_fields/
lib.rs

1// Copyright (C) 2023 Tristan Gerritsen <tristan@thewoosh.org>
2// All Rights Reserved.
3
4//! # enum-fields
//! Derived from the `enum-fields` crate and updated.
//! The generated setters use a mutable swap (`std::mem::swap`) to replace a
//! field's value, so no `Option` is needed to move the old value out.
8
9use std::collections::HashMap;
10
11use proc_macro::TokenStream;
12use proc_macro2::{Ident, Span};
13use quote::{quote, ToTokens};
14use syn;
15
16#[proc_macro_derive(SetEnumFields)]
17pub fn enum_fields_macro_derive(input: TokenStream) -> TokenStream {
18    let ast = syn::parse(input).unwrap();
19    self::impl_for_input(&ast)
20}
21
22fn collect_available_fields<'input>(enum_data: &'input syn::DataEnum) -> HashMap<String, Vec<&'input syn::Field>> {
23    let mut fields = HashMap::new();
24
25    for variant in &enum_data.variants {
26        for field in &variant.fields {
27            if let Some(field_ident) = &field.ident {
28                let ident = field_ident.to_string();
29                fields.entry(ident)
30                    .or_insert(Vec::new())
31                    .push(field);
32            }
33        }
34    }
35
36    fields
37}
38
39fn impl_for_input(ast: &syn::DeriveInput) -> TokenStream {
40    let fail_message = "`EnumFields` is only applicable to `enum`s";
41    match &ast.data {
42        syn::Data::Enum(data_enum) => impl_for_enum(ast, &data_enum),
43        syn::Data::Union(data_union) => syn::Error::new(data_union.union_token.span, fail_message).to_compile_error().into(),
44        syn::Data::Struct(data_struct) => syn::Error::new(data_struct.struct_token.span, fail_message).to_compile_error().into(),
45    }
46}
47
48fn impl_for_enum(ast: &syn::DeriveInput, enum_data: &syn::DataEnum) -> TokenStream {
49    let name = &ast.ident;
50
51    // Collect available fields
52    let fields = collect_available_fields(enum_data);
53
54    let mut data = proc_macro2::TokenStream::new();
55
56    let mut field_idents: Vec<Ident> = vec![];
57
58    for (field_name, fields) in fields {
59        let field_present_everywhere = fields.len() == enum_data.variants.len();
60
61        let generics = &ast.generics;
62        let field_type = &fields[0].ty;
63        let field_name_ident = Ident::new(&field_name, Span::call_site());
64
65        let mut variants = proc_macro2::TokenStream::new();
66        let mut mut_set_variances = proc_macro2::TokenStream::new();
67
68
69        for variant in &enum_data.variants {
70            let name = &variant.ident;
71
72            let variant_field_ident = variant.fields.iter()
73                .find(|variant_field| {
74                    if let Some(variant_field_ident) = &variant_field.ident {
75                        if variant_field_ident.to_string() == field_name {
76                            true
77                        } else {
78                            false
79                        }
80                    } else {
81                        false
82                    }
83                })
84                .map(|field| {
85                    field.ident.as_ref().unwrap()
86                });
87
88            match variant_field_ident {
89                Some(variant_field_ident) => {
90                    if field_present_everywhere {
91                        variants.extend(quote! {
92                            Self::#name{ #variant_field_ident, .. } =>  {
93                                std::mem::swap(#variant_field_ident, to_set);
94                            }
95                        });
96                    } else {
97                        variants.extend(quote! {
98                            Self::#name{ #variant_field_ident, .. } =>  {
99                                std::mem::swap(#variant_field_ident, to_set);
100                            }
101                        });
102                    }
103
104                    if field_present_everywhere {
105                        mut_set_variances.extend(quote! {
106                        Self::#name{ #variant_field_ident, .. } => #variant_field_ident,
107                    });
108                    } else {
109                        mut_set_variances.extend(quote! {
110                        Self::#name{ #variant_field_ident, .. } => Some(#variant_field_ident),
111                    });
112                    }
113
114                }
115                None => {
116                    // Field not present in field list.
117                    if let Some(first_field) = variant.fields.iter().next() {
118                        if first_field.ident.is_some() {
119                            mut_set_variances.extend(quote! {
120                                Self::#name{ .. } => None,
121                            });
122                        } else {
123                            mut_set_variances.extend(quote! {
124                                Self::#name(..) => None,
125                            });
126                        }
127                    } else {
128                        mut_set_variances.extend(quote! {
129                            Self::#name => None,
130                        });
131                    }
132                }
133            }
134        }
135
136        let variant_field_ident = fields[0].ident.as_ref();
137        if variant_field_ident.is_some() {
138            let set_value = Ident::new(format!("set_{}", variant_field_ident.as_ref().unwrap().to_string()).as_str(), Span::call_site());
139            data.extend(quote! {
140                impl #generics #name #generics {
141                    pub fn #set_value(&mut self, to_set: &mut #field_type) {
142                        //! Get the property of this enum discriminant if it's available
143                        match self {
144                            #variants
145                            _ => {}
146                        };
147                    }
148                }
149            });
150        }
151
152        let ty = if field_present_everywhere {
153            quote! {
154                &mut #field_type
155            }
156        } else {
157            quote! {
158                Option<&mut #field_type>
159            }
160        };
161
162        let field_name_mut = Ident::new(format!("{}_mut", variant_field_ident.unwrap()).as_str(), Span::call_site());
163        data.extend(quote! {
164            impl #generics #name #generics {
165                pub fn #field_name_mut(&mut self) -> #ty {
166                    //! Get the property of this enum discriminant if it's available
167                    match self {
168                        #mut_set_variances
169                    }
170                }
171            }
172        });
173
174
175    }
176
177    data.into()
178}