1use quote::quote;
2use syn::*;
3
4use lazy_static::lazy_static;
5use proc_macro::TokenStream;
6use proc_macro2::TokenStream as TokenStream2;
7use std::collections::HashMap;
8
/// Case convention applied to enum variant identifiers when converting them to strings.
///
/// Selected by a `#[serde(rename_all = "...")]` / `#[serde(rename = "...")]` attribute
/// (see `get_naming_style`); `None` means the identifier is emitted unchanged.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
enum NamingStyle {
    /// `snake_case`
    SnakeCase,
    /// `camelCase`
    CamelCase,
    /// `SCREAMING_SNAKE_CASE`
    ScreamingSnakeCase,
    /// No attribute given: keep the identifier as written.
    None,
}
16
17lazy_static! {
18 static ref NAME_MAP: HashMap<NamingStyle, fn(&str) -> String> = {
19 let mut m = HashMap::new();
20
21 m.insert(NamingStyle::SnakeCase, to_snake_case as fn(&str) -> String);
23 m.insert(NamingStyle::CamelCase, to_camel_case);
24 m.insert(NamingStyle::ScreamingSnakeCase, to_screaming_snake_case);
25 m
26 };
27}
28
29#[proc_macro_derive(ToString)]
30pub fn to_string_enum(item: TokenStream) -> TokenStream {
31 let target = parse_macro_input!(item as DeriveInput);
32 let data = get_enum_from_input(&target);
33
34 let ident = &target.ident;
35
36 let style = get_naming_style(target.attrs.iter());
37
38 let to_str_arms = create_to_str_arms(&data, style);
39
40 let out = quote! {
41 impl std::convert::From<&#ident> for &'static str {
42 fn from(v: &#ident) -> &'static str {
43 match v {
44 #(#ident::#to_str_arms),*
45 }
46 }
47 }
48
49 impl std::string::ToString for #ident {
50 fn to_string(&self) -> String {
51 <&#ident as std::convert::Into<&'static str>>::into(self).to_string()
52 }
53 }
54 };
55 out.into()
56}
57
58#[proc_macro_derive(Serialize_enum, attributes(serde))]
59pub fn serialize_enum(item: TokenStream) -> TokenStream {
60 let target = parse_macro_input!(item as DeriveInput);
61 let data = get_enum_from_input(&target);
62
63 let style = get_naming_style(target.attrs.iter());
64
65 let target_ident = &target.ident;
66 let ser_arms = create_ser_arms(&data, style);
67 let out = quote! {
68 impl serde::Serialize for #target_ident {
69 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
70 where
71 S: serde::Serializer
72 {
73 match self {
74 #(#ser_arms),*
75 }
76 }
77 }
78 };
79 out.into()
80}
81
82#[proc_macro_derive(Deserialize_enum, attributes(serde))]
83pub fn deserialize_enum(item: TokenStream) -> TokenStream {
84 let target = parse_macro_input!(item as DeriveInput);
85 let data = get_enum_from_input(&target);
86
87 let style = get_naming_style(target.attrs.iter());
88
89 let target_ident = &target.ident;
90 let de_arms = create_de_arms(&data, style);
91 let out = quote! {
92 impl<'de> serde::Deserialize<'de> for #target_ident {
93 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
94 where
95 D: serde::Deserializer<'de>
96 {
97 Ok(
98 match <&str>::deserialize(deserializer)? {
99 #(#de_arms),*,
100 _ => { unimplemented!() }
101 }
102 )
103 }
104 }
105 };
106 out.into()
107}
108
109fn get_naming_style<'a>(target: impl Iterator<Item = &'a Attribute>) -> NamingStyle {
110 for a in target {
111 if let Some(i) = a.path.get_ident() {
112 if i == "serde" {
113 if let Ok(ExprParen { expr, .. }) = parse2::<ExprParen>(a.tokens.clone()) {
114 if let Expr::Assign(ea) = expr.as_ref() {
115 if let Expr::Path(ep) = ea.left.as_ref() {
116 if let Some(i) = ep.path.get_ident() {
117 if i == "rename" || i == "rename_all" {
118 if let Expr::Lit(ExprLit {
119 lit: Lit::Str(s), ..
120 }) = ea.right.as_ref()
121 {
122 return match s.value().as_str() {
123 "snake_case" => NamingStyle::SnakeCase,
124 "camelCase" => NamingStyle::CamelCase,
125 "SCREAMING_SNAKE_CASE" => {
126 NamingStyle::ScreamingSnakeCase
127 }
128 _ => {
129 panic!(
130 "Unsupported style. \
131 Available: `snake_case`, `camelCase`"
132 )
133 }
134 };
135 }
136 }
137 }
138 }
139 }
140 }
141 }
142 }
143 }
144 NamingStyle::None
145}
146
147fn get_variant_alias(v: &Variant) -> Option<String> {
148 for a in v.attrs.iter() {
149 if let Some(i) = a.path.get_ident() {
150 if i == "serde" {
151 if let Ok(ExprParen { expr, .. }) = parse2::<ExprParen>(a.tokens.clone()) {
152 if let Expr::Assign(ea) = expr.as_ref() {
153 if let Expr::Path(ep) = ea.left.as_ref() {
154 if let Some(i) = ep.path.get_ident() {
155 if i == "name" {
156 if let Expr::Lit(ExprLit {
157 lit: Lit::Str(s), ..
158 }) = ea.right.as_ref()
159 {
160 return Some(s.value());
161 }
162 }
163 }
164 }
165 }
166 }
167 }
168 }
169 }
170 None
171}
172
173fn get_enum_from_input(target: &DeriveInput) -> DataEnum {
174 if !target.generics.params.is_empty() {
175 panic!("`Serialize_enum` target cannot have any generics parameters!");
176 }
177
178 if let Data::Enum(ref e) = target.data {
179 e.clone()
180 } else {
181 panic!("`Serialize_enum` can only be applied to enums!");
182 }
183}
184
185fn create_ser_arms(target: &DataEnum, n: NamingStyle) -> impl Iterator<Item = TokenStream2> {
186 target.variants.clone().into_iter().map(move |v| {
187 assert!(matches!(v.fields, Fields::Unit));
188 let ident = &v.ident;
189 let value = format_variant(&v, n);
190
191 quote! {
192 Self::#ident => { serializer.serialize_str(#value) }
193 }
194 })
195}
196
197fn create_to_str_arms(target: &DataEnum, n: NamingStyle) -> impl Iterator<Item = TokenStream2> {
198 target.variants.clone().into_iter().map(move |v| {
199 let ident = &v.ident;
200 let value = format_variant(&v, n);
201
202 quote! {
203 #ident => #value
204 }
205 })
206}
207
208fn create_de_arms(target: &DataEnum, n: NamingStyle) -> impl Iterator<Item = TokenStream2> {
209 target.variants.clone().into_iter().map(move |v| {
210 assert!(matches!(v.fields, Fields::Unit));
211
212 let ident = &v.ident;
213 let value = format_variant(&v, n);
214
215 quote! {
216 #value => Self::#ident
217 }
218 })
219}
220
221fn format_variant(v: &Variant, parent_style: NamingStyle) -> String {
222 if let Some(s) = get_variant_alias(v) {
223 return s;
224 }
225
226 let own_style = get_naming_style(v.attrs.iter());
227
228 match own_style {
229 NamingStyle::None => match parent_style {
230 NamingStyle::None => v.ident.to_string(),
231 ps => NAME_MAP.get(&ps).unwrap()(&v.ident.to_string()),
232 },
233 os => NAME_MAP.get(&os).unwrap()(&v.ident.to_string()),
234 }
235}
236
/// Converts an identifier such as `FooBar` to `snake_case` (`foo_bar`).
///
/// Behaviour:
/// - Every uppercase character except a leading one is preceded by `_`.
/// - Every uppercase character is lowercased; all other characters are copied as-is.
/// - Runs of capitals are split per character (`"HTTP"` -> `"h_t_t_p"`).
/// - An empty input yields an empty string.
///
/// The previous version tested `v.is_empty()` where it meant `!v.is_empty()`. As a
/// result it panicked on empty input and dropped the first character of every other
/// input.
fn to_snake_case(v: &str) -> String {
    let mut out = String::with_capacity(v.len());
    for (i, c) in v.char_indices() {
        if c.is_uppercase() {
            if i != 0 {
                out.push('_');
            }
            out.extend(c.to_lowercase());
        } else {
            out.push(c);
        }
    }
    out
}
254
/// Converts an identifier such as `FooBar` to `camelCase` (`fooBar`).
///
/// Only the first character is changed, using ASCII lowercasing. The rest of the
/// string is copied verbatim, and an empty input yields an empty string.
fn to_camel_case(v: &str) -> String {
    let mut chars = v.chars();
    match chars.next() {
        Some(first) => {
            let mut result = String::with_capacity(v.len());
            result.push(first.to_ascii_lowercase());
            result.push_str(chars.as_str());
            result
        }
        None => String::new(),
    }
}
261
/// Converts an identifier such as `FooBar` to `SCREAMING_SNAKE_CASE` (`FOO_BAR`).
///
/// Every uppercase character except a leading one is preceded by `_`. Every character
/// is then ASCII-uppercased.
fn to_screaming_snake_case(v: &str) -> String {
    let mut result = String::with_capacity(v.len());
    for (idx, ch) in v.char_indices() {
        let starts_word = idx > 0 && ch.is_uppercase();
        if starts_word {
            result.push('_');
        }
        result.push(ch.to_ascii_uppercase());
    }
    result
}