serde_enum/
lib.rs

1use quote::quote;
2use syn::*;
3
4use lazy_static::lazy_static;
5use proc_macro::TokenStream;
6use proc_macro2::TokenStream as TokenStream2;
7use std::collections::HashMap;
8
/// Case convention used to turn a variant identifier into its string form.
///
/// `None` means no convention was requested and the identifier is used as written.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Hash)]
enum NamingStyle {
    /// `snake_case`, selected by the attribute value `"snake_case"`.
    SnakeCase,
    /// `camelCase`, selected by the attribute value `"camelCase"`.
    CamelCase,
    /// `SCREAMING_SNAKE_CASE`, selected by the attribute value `"SCREAMING_SNAKE_CASE"`.
    ScreamingSnakeCase,
    /// No conversion requested; the identifier is kept verbatim.
    None,
}
16
17lazy_static! {
18    static ref NAME_MAP: HashMap<NamingStyle, fn(&str) -> String> = {
19        let mut m = HashMap::new();
20
21        // I have no actual idea why this is working like that, and why I need this cast
22        m.insert(NamingStyle::SnakeCase, to_snake_case as fn(&str) -> String);
23        m.insert(NamingStyle::CamelCase, to_camel_case);
24        m.insert(NamingStyle::ScreamingSnakeCase, to_screaming_snake_case);
25        m
26    };
27}
28
29#[proc_macro_derive(ToString)]
30pub fn to_string_enum(item: TokenStream) -> TokenStream {
31    let target = parse_macro_input!(item as DeriveInput);
32    let data = get_enum_from_input(&target);
33
34    let ident = &target.ident;
35
36    let style = get_naming_style(target.attrs.iter());
37
38    let to_str_arms = create_to_str_arms(&data, style);
39
40    let out = quote! {
41        impl std::convert::From<&#ident> for &'static str {
42            fn from(v: &#ident) -> &'static str {
43                match v {
44                    #(#ident::#to_str_arms),*
45                }
46            }
47        }
48
49        impl std::string::ToString for #ident {
50            fn to_string(&self) -> String {
51                <&#ident as std::convert::Into<&'static str>>::into(self).to_string()
52            }
53        }
54    };
55    out.into()
56}
57
58#[proc_macro_derive(Serialize_enum, attributes(serde))]
59pub fn serialize_enum(item: TokenStream) -> TokenStream {
60    let target = parse_macro_input!(item as DeriveInput);
61    let data = get_enum_from_input(&target);
62
63    let style = get_naming_style(target.attrs.iter());
64
65    let target_ident = &target.ident;
66    let ser_arms = create_ser_arms(&data, style);
67    let out = quote! {
68        impl serde::Serialize for #target_ident {
69            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
70            where
71                S: serde::Serializer
72            {
73                match self {
74                    #(#ser_arms),*
75                }
76            }
77        }
78    };
79    out.into()
80}
81
82#[proc_macro_derive(Deserialize_enum, attributes(serde))]
83pub fn deserialize_enum(item: TokenStream) -> TokenStream {
84    let target = parse_macro_input!(item as DeriveInput);
85    let data = get_enum_from_input(&target);
86
87    let style = get_naming_style(target.attrs.iter());
88
89    let target_ident = &target.ident;
90    let de_arms = create_de_arms(&data, style);
91    let out = quote! {
92        impl<'de> serde::Deserialize<'de> for #target_ident {
93            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
94            where
95                D: serde::Deserializer<'de>
96            {
97                Ok(
98                    match <&str>::deserialize(deserializer)? {
99                        #(#de_arms),*,
100                        _ => { unimplemented!() }
101                    }
102                )
103            }
104        }
105    };
106    out.into()
107}
108
109fn get_naming_style<'a>(target: impl Iterator<Item = &'a Attribute>) -> NamingStyle {
110    for a in target {
111        if let Some(i) = a.path.get_ident() {
112            if i == "serde" {
113                if let Ok(ExprParen { expr, .. }) = parse2::<ExprParen>(a.tokens.clone()) {
114                    if let Expr::Assign(ea) = expr.as_ref() {
115                        if let Expr::Path(ep) = ea.left.as_ref() {
116                            if let Some(i) = ep.path.get_ident() {
117                                if i == "rename" || i == "rename_all" {
118                                    if let Expr::Lit(ExprLit {
119                                        lit: Lit::Str(s), ..
120                                    }) = ea.right.as_ref()
121                                    {
122                                        return match s.value().as_str() {
123                                            "snake_case" => NamingStyle::SnakeCase,
124                                            "camelCase" => NamingStyle::CamelCase,
125                                            "SCREAMING_SNAKE_CASE" => {
126                                                NamingStyle::ScreamingSnakeCase
127                                            }
128                                            _ => {
129                                                panic!(
130                                                    "Unsupported style. \
131                                                    Available: `snake_case`, `camelCase`"
132                                                )
133                                            }
134                                        };
135                                    }
136                                }
137                            }
138                        }
139                    }
140                }
141            }
142        }
143    }
144    NamingStyle::None
145}
146
147fn get_variant_alias(v: &Variant) -> Option<String> {
148    for a in v.attrs.iter() {
149        if let Some(i) = a.path.get_ident() {
150            if i == "serde" {
151                if let Ok(ExprParen { expr, .. }) = parse2::<ExprParen>(a.tokens.clone()) {
152                    if let Expr::Assign(ea) = expr.as_ref() {
153                        if let Expr::Path(ep) = ea.left.as_ref() {
154                            if let Some(i) = ep.path.get_ident() {
155                                if i == "name" {
156                                    if let Expr::Lit(ExprLit {
157                                        lit: Lit::Str(s), ..
158                                    }) = ea.right.as_ref()
159                                    {
160                                        return Some(s.value());
161                                    }
162                                }
163                            }
164                        }
165                    }
166                }
167            }
168        }
169    }
170    None
171}
172
173fn get_enum_from_input(target: &DeriveInput) -> DataEnum {
174    if !target.generics.params.is_empty() {
175        panic!("`Serialize_enum` target cannot have any generics parameters!");
176    }
177
178    if let Data::Enum(ref e) = target.data {
179        e.clone()
180    } else {
181        panic!("`Serialize_enum` can only be applied to enums!");
182    }
183}
184
185fn create_ser_arms(target: &DataEnum, n: NamingStyle) -> impl Iterator<Item = TokenStream2> {
186    target.variants.clone().into_iter().map(move |v| {
187        assert!(matches!(v.fields, Fields::Unit));
188        let ident = &v.ident;
189        let value = format_variant(&v, n);
190
191        quote! {
192            Self::#ident => { serializer.serialize_str(#value) }
193        }
194    })
195}
196
197fn create_to_str_arms(target: &DataEnum, n: NamingStyle) -> impl Iterator<Item = TokenStream2> {
198    target.variants.clone().into_iter().map(move |v| {
199        let ident = &v.ident;
200        let value = format_variant(&v, n);
201
202        quote! {
203            #ident => #value
204        }
205    })
206}
207
208fn create_de_arms(target: &DataEnum, n: NamingStyle) -> impl Iterator<Item = TokenStream2> {
209    target.variants.clone().into_iter().map(move |v| {
210        assert!(matches!(v.fields, Fields::Unit));
211
212        let ident = &v.ident;
213        let value = format_variant(&v, n);
214
215        quote! {
216            #value => Self::#ident
217        }
218    })
219}
220
221fn format_variant(v: &Variant, parent_style: NamingStyle) -> String {
222    if let Some(s) = get_variant_alias(v) {
223        return s;
224    }
225
226    let own_style = get_naming_style(v.attrs.iter());
227
228    match own_style {
229        NamingStyle::None => match parent_style {
230            NamingStyle::None => v.ident.to_string(),
231            ps => NAME_MAP.get(&ps).unwrap()(&v.ident.to_string()),
232        },
233        os => NAME_MAP.get(&os).unwrap()(&v.ident.to_string()),
234    }
235}
236
/// Converts a `PascalCase` identifier to `snake_case` (e.g. `FooBar` -> `foo_bar`).
///
/// An underscore is inserted before every uppercase character except the first,
/// and ASCII letters are lowercased. This mirrors serde's own `snake_case` rule.
/// The empty string maps to itself.
fn to_snake_case(v: &str) -> String {
    let mut out = String::with_capacity(v.len());
    for (i, c) in v.char_indices() {
        // `i` is a byte offset, but it is 0 only for the first character.
        if i > 0 && c.is_uppercase() {
            out.push('_');
        }
        out.push(c.to_ascii_lowercase());
    }
    out
}
254
/// Converts a `PascalCase` identifier to `camelCase` by ASCII-lowercasing its
/// first character (e.g. `FooBar` -> `fooBar`). The empty string maps to itself.
///
/// Builds the result directly from the first char and the remaining slice,
/// without first copying the whole input into an intermediate `String`.
fn to_camel_case(v: &str) -> String {
    let mut chars = v.chars();
    match chars.next() {
        Some(first) => {
            let mut out = String::with_capacity(v.len());
            out.push(first.to_ascii_lowercase());
            out.push_str(chars.as_str());
            out
        }
        None => String::new(),
    }
}
261
/// Converts an identifier to `SCREAMING_SNAKE_CASE` (e.g. `FooBar` -> `FOO_BAR`).
///
/// An underscore precedes every uppercase character except the first, and ASCII
/// letters are uppercased.
fn to_screaming_snake_case(v: &str) -> String {
    let mut result = String::with_capacity(v.len());
    for (idx, ch) in v.char_indices() {
        let starts_word = idx != 0 && ch.is_uppercase();
        if starts_word {
            result.push('_');
        }
        result.push(ch.to_ascii_uppercase());
    }
    result
}