1#![recursion_limit="256"]
2
3extern crate proc_macro;
4
5use darling::*;
6use proc_macro::TokenStream;
7use proc_macro2::{TokenStream as SynTokenStream, Literal, Span};
8use std::collections::HashSet;
9use syn::{*, Result, Error};
10use syn::spanned::Spanned;
11use quote::*;
12
13fn error<T>(span: Span, message: &str) -> Result<T> {
15 Err(Error::new(span, message))
16}
17
/// Options accepted in the `#[enumset(...)]` attribute of an enum deriving `EnumSetType`.
///
/// The container-level `default` lets every option be omitted.
#[derive(FromDeriveInput, Default)]
#[darling(attributes(enumset), default)]
struct EnumsetAttrs {
    /// When set, the operator overloads between the enum and its `EnumSet` are not generated.
    no_ops: bool,
    /// (serde) Serialize the set as a sequence of variants instead of as an integer bitmask.
    serialize_as_list: bool,
    /// (serde) When deserializing an integer bitmask, reject bits that belong to no variant
    /// instead of masking them off.
    serialize_deny_unknown: bool,
    /// Integer type for the serialized bitmask; validated to be one of `u8`..`u128` and wide
    /// enough for every discriminant.
    #[darling(default)]
    serialize_repr: Option<String>,
    /// Crate name used in generated paths; `wasmer_enumset` is used when omitted.
    #[darling(default)]
    crate_name: Option<String>,
}
30
31struct EnumSetValue {
33 name: Ident,
35 variant_repr: u32,
37}
38
39#[allow(dead_code)]
41struct EnumSetInfo {
42 name: Ident,
44 crate_name: Option<Ident>,
46 explicit_serde_repr: Option<Ident>,
48 has_signed_repr: bool,
50 has_large_repr: bool,
52 variants: Vec<EnumSetValue>,
54
55 max_discrim: u32,
57 cur_discrim: u32,
59 used_variant_names: HashSet<String>,
61 used_discriminants: HashSet<u32>,
63
64 no_ops: bool,
66 serialize_as_list: bool,
68 serialize_deny_unknown: bool,
70}
71impl EnumSetInfo {
72 fn new(input: &DeriveInput, attrs: EnumsetAttrs) -> EnumSetInfo {
73 EnumSetInfo {
74 name: input.ident.clone(),
75 crate_name: attrs.crate_name.map(|x| Ident::new(&x, Span::call_site())),
76 explicit_serde_repr: attrs.serialize_repr.map(|x| Ident::new(&x, Span::call_site())),
77 has_signed_repr: false,
78 has_large_repr: false,
79 variants: Vec::new(),
80 max_discrim: 0,
81 cur_discrim: 0,
82 used_variant_names: HashSet::new(),
83 used_discriminants: HashSet::new(),
84 no_ops: attrs.no_ops,
85 serialize_as_list: attrs.serialize_as_list,
86 serialize_deny_unknown: attrs.serialize_deny_unknown
87 }
88 }
89
90 fn push_explicit_repr(&mut self, attr_span: Span, repr: &str) -> Result<()> {
92 match repr {
95 "Rust" | "C" | "u8" | "u16" | "u32" => Ok(()),
96 "usize" | "u64" | "u128" => {
97 self.has_large_repr = true;
98 Ok(())
99 }
100 "i8" | "i16" | "i32" => {
101 self.has_signed_repr = true;
102 Ok(())
103 }
104 "isize" | "i64" | "i128" => {
105 self.has_signed_repr = true;
106 self.has_large_repr = true;
107 Ok(())
108 }
109 _ => error(attr_span, "Unsupported repr.")
110 }
111 }
112 fn push_variant(&mut self, variant: &Variant) -> Result<()> {
114 if self.used_variant_names.contains(&variant.ident.to_string()) {
115 error(variant.span(), "Duplicated variant name.")
116 } else if let Fields::Unit = variant.fields {
117 if let Some((_, expr)) = &variant.discriminant {
119 let discriminant_fail_message = format!(
120 "Enum set discriminants must be `u32`s.{}",
121 if self.has_signed_repr || self.has_large_repr {
122 format!(
123 " ({} discrimiants are still unsupported with reprs that allow them.)",
124 if self.has_large_repr {
125 "larger"
126 } else if self.has_signed_repr {
127 "negative"
128 } else {
129 "larger or negative"
130 }
131 )
132 } else {
133 String::new()
134 },
135 );
136 if let Expr::Lit(ExprLit { lit: Lit::Int(i), .. }) = expr {
137 match i.base10_parse() {
138 Ok(val) => self.cur_discrim = val,
139 Err(_) => error(expr.span(), &discriminant_fail_message)?,
140 }
141 } else {
142 error(variant.span(), &discriminant_fail_message)?;
143 }
144 }
145
146 let discriminant = self.cur_discrim;
148 if discriminant >= 128 {
149 let message = if self.variants.len() <= 127 {
150 "`#[derive(EnumSetType)]` currently only supports discriminants up to 127."
151 } else {
152 "`#[derive(EnumSetType)]` currently only supports enums up to 128 variants."
153 };
154 error(variant.span(), message)?;
155 }
156 if self.used_discriminants.contains(&discriminant) {
157 error(variant.span(), "Duplicated enum discriminant.")?;
158 }
159
160 self.cur_discrim += 1;
162 if discriminant > self.max_discrim {
163 self.max_discrim = discriminant;
164 }
165 self.variants.push(EnumSetValue {
166 name: variant.ident.clone(),
167 variant_repr: discriminant,
168 });
169 self.used_variant_names.insert(variant.ident.to_string());
170 self.used_discriminants.insert(discriminant);
171
172 Ok(())
173 } else {
174 error(variant.span(), "`#[derive(EnumSetType)]` can only be used on fieldless enums.")
175 }
176 }
177 fn validate(&self) -> Result<()> {
179 if let Some(explicit_serde_repr) = &self.explicit_serde_repr {
181 let is_overflowed = match explicit_serde_repr.to_string().as_str() {
182 "u8" => self.max_discrim >= 8,
183 "u16" => self.max_discrim >= 16,
184 "u32" => self.max_discrim >= 32,
185 "u64" => self.max_discrim >= 64,
186 "u128" => self.max_discrim >= 128,
187 _ => error(
188 Span::call_site(),
189 "Only `u8`, `u16`, `u32`, `u64` and `u128` are supported for serde_repr."
190 )?,
191 };
192 if is_overflowed {
193 error(Span::call_site(), "serialize_repr cannot be smaller than bitset.")?;
194 }
195 }
196 Ok(())
197 }
198
199 fn enumset_repr(&self) -> SynTokenStream {
201 if self.max_discrim <= 7 {
202 quote! { u8 }
203 } else if self.max_discrim <= 15 {
204 quote! { u16 }
205 } else if self.max_discrim <= 31 {
206 quote! { u32 }
207 } else if self.max_discrim <= 63 {
208 quote! { u64 }
209 } else if self.max_discrim <= 127 {
210 quote! { u128 }
211 } else {
212 panic!("max_variant > 127?")
213 }
214 }
215 #[cfg(feature = "serde")]
217 fn serde_repr(&self) -> SynTokenStream {
218 if let Some(serde_repr) = &self.explicit_serde_repr {
219 quote! { #serde_repr }
220 } else {
221 self.enumset_repr()
222 }
223 }
224
225 fn all_variants(&self) -> u128 {
227 let mut accum = 0u128;
228 for variant in &self.variants {
229 assert!(variant.variant_repr <= 127);
230 accum |= 1u128 << variant.variant_repr as u128;
231 }
232 accum
233 }
234}
235
236fn enum_set_type_impl(info: EnumSetInfo) -> SynTokenStream {
238 let name = &info.name;
239 let enumset = match &info.crate_name {
240 Some(crate_name) => quote!(::#crate_name),
241 None => quote!(::wasmer_enumset),
242 };
243 let typed_enumset = quote!(#enumset::EnumSet<#name>);
244 let core = quote!(#enumset::__internal::core_export);
245
246 let repr = info.enumset_repr();
247 let all_variants = Literal::u128_unsuffixed(info.all_variants());
248
249 let ops = if info.no_ops {
250 quote! {}
251 } else {
252 quote! {
253 impl <O : Into<#typed_enumset>> #core::ops::Sub<O> for #name {
254 type Output = #typed_enumset;
255 fn sub(self, other: O) -> Self::Output {
256 #enumset::EnumSet::only(self) - other.into()
257 }
258 }
259 impl <O : Into<#typed_enumset>> #core::ops::BitAnd<O> for #name {
260 type Output = #typed_enumset;
261 fn bitand(self, other: O) -> Self::Output {
262 #enumset::EnumSet::only(self) & other.into()
263 }
264 }
265 impl <O : Into<#typed_enumset>> #core::ops::BitOr<O> for #name {
266 type Output = #typed_enumset;
267 fn bitor(self, other: O) -> Self::Output {
268 #enumset::EnumSet::only(self) | other.into()
269 }
270 }
271 impl <O : Into<#typed_enumset>> #core::ops::BitXor<O> for #name {
272 type Output = #typed_enumset;
273 fn bitxor(self, other: O) -> Self::Output {
274 #enumset::EnumSet::only(self) ^ other.into()
275 }
276 }
277 impl #core::ops::Not for #name {
278 type Output = #typed_enumset;
279 fn not(self) -> Self::Output {
280 !#enumset::EnumSet::only(self)
281 }
282 }
283 impl #core::cmp::PartialEq<#typed_enumset> for #name {
284 fn eq(&self, other: &#typed_enumset) -> bool {
285 #enumset::EnumSet::only(*self) == *other
286 }
287 }
288 }
289 };
290
291
292 #[cfg(feature = "serde")]
293 let serde = quote!(#enumset::__internal::serde);
294
295 #[cfg(feature = "serde")]
296 let serde_ops = if info.serialize_as_list {
297 let expecting_str = format!("a list of {}", name);
298 quote! {
299 fn serialize<S: #serde::Serializer>(
300 set: #enumset::EnumSet<#name>, ser: S,
301 ) -> #core::result::Result<S::Ok, S::Error> {
302 use #serde::ser::SerializeSeq;
303 let mut seq = ser.serialize_seq(#core::prelude::v1::Some(set.len()))?;
304 for bit in set {
305 seq.serialize_element(&bit)?;
306 }
307 seq.end()
308 }
309 fn deserialize<'de, D: #serde::Deserializer<'de>>(
310 de: D,
311 ) -> #core::result::Result<#enumset::EnumSet<#name>, D::Error> {
312 struct Visitor;
313 impl <'de> #serde::de::Visitor<'de> for Visitor {
314 type Value = #enumset::EnumSet<#name>;
315 fn expecting(
316 &self, formatter: &mut #core::fmt::Formatter,
317 ) -> #core::fmt::Result {
318 write!(formatter, #expecting_str)
319 }
320 fn visit_seq<A>(
321 mut self, mut seq: A,
322 ) -> #core::result::Result<Self::Value, A::Error> where
323 A: #serde::de::SeqAccess<'de>
324 {
325 let mut accum = #enumset::EnumSet::<#name>::new();
326 while let #core::prelude::v1::Some(val) = seq.next_element::<#name>()? {
327 accum |= val;
328 }
329 #core::prelude::v1::Ok(accum)
330 }
331 }
332 de.deserialize_seq(Visitor)
333 }
334 }
335 } else {
336 let serialize_repr = info.serde_repr();
337 let check_unknown = if info.serialize_deny_unknown {
338 quote! {
339 if value & !#all_variants != 0 {
340 use #serde::de::Error;
341 return #core::prelude::v1::Err(
342 D::Error::custom("enumset contains unknown bits")
343 )
344 }
345 }
346 } else {
347 quote! { }
348 };
349 quote! {
350 fn serialize<S: #serde::Serializer>(
351 set: #enumset::EnumSet<#name>, ser: S,
352 ) -> #core::result::Result<S::Ok, S::Error> {
353 #serde::Serialize::serialize(&(set.__enumset_underlying as #serialize_repr), ser)
354 }
355 fn deserialize<'de, D: #serde::Deserializer<'de>>(
356 de: D,
357 ) -> #core::result::Result<#enumset::EnumSet<#name>, D::Error> {
358 let value = <#serialize_repr as #serde::Deserialize>::deserialize(de)?;
359 #check_unknown
360 #core::prelude::v1::Ok(#enumset::EnumSet {
361 __enumset_underlying: (value & #all_variants) as #repr,
362 })
363 }
364 }
365 };
366
367 #[cfg(not(feature = "serde"))]
368 let serde_ops = quote! { };
369
370 let is_uninhabited = info.variants.is_empty();
371 let is_zst = info.variants.len() == 1;
372 let into_impl = if is_uninhabited {
373 quote! {
374 fn enum_into_u32(self) -> u32 {
375 panic!(concat!(stringify!(#name), " is uninhabited."))
376 }
377 unsafe fn enum_from_u32(val: u32) -> Self {
378 panic!(concat!(stringify!(#name), " is uninhabited."))
379 }
380 }
381 } else if is_zst {
382 let variant = &info.variants[0].name;
383 quote! {
384 fn enum_into_u32(self) -> u32 {
385 self as u32
386 }
387 unsafe fn enum_from_u32(val: u32) -> Self {
388 #name::#variant
389 }
390 }
391 } else {
392 let variant_name: Vec<_> = info.variants.iter().map(|x| &x.name).collect();
393 let variant_value: Vec<_> = info.variants.iter().map(|x| x.variant_repr).collect();
394
395 let const_field: Vec<_> = ["IS_U8", "IS_U16", "IS_U32", "IS_U64", "IS_U128"]
396 .iter().map(|x| Ident::new(x, Span::call_site())).collect();
397 let int_type: Vec<_> = ["u8", "u16", "u32", "u64", "u128"]
398 .iter().map(|x| Ident::new(x, Span::call_site())).collect();
399
400 quote! {
401 fn enum_into_u32(self) -> u32 {
402 self as u32
403 }
404 unsafe fn enum_from_u32(val: u32) -> Self {
405 #(const #const_field: bool =
408 #core::mem::size_of::<#name>() == #core::mem::size_of::<#int_type>();)*
409 match val {
410 #(#variant_value => #name::#variant_name,)*
414 #(x if #const_field => {
417 let x = x as #int_type;
418 *(&x as *const _ as *const #name)
419 })*
420 _ => #core::hint::unreachable_unchecked(),
423 }
424 }
425 }
426 };
427
428 let eq_impl = if is_uninhabited {
429 quote!(panic!(concat!(stringify!(#name), " is uninhabited.")))
430 } else {
431 quote!((*self as u32) == (*other as u32))
432 };
433
434 quote! {
435 unsafe impl #enumset::__internal::EnumSetTypePrivate for #name {
436 type Repr = #repr;
437 const ALL_BITS: Self::Repr = #all_variants;
438 #into_impl
439 #serde_ops
440 }
441
442 unsafe impl #enumset::EnumSetType for #name { }
443
444 impl #core::cmp::PartialEq for #name {
445 fn eq(&self, other: &Self) -> bool {
446 #eq_impl
447 }
448 }
449 impl #core::cmp::Eq for #name { }
450 impl #core::clone::Clone for #name {
451 fn clone(&self) -> Self {
452 *self
453 }
454 }
455 impl #core::marker::Copy for #name { }
456
457 #ops
458 }
459}
460
461#[proc_macro_derive(EnumSetType, attributes(enumset))]
463pub fn derive_enum_set_type(input: TokenStream) -> TokenStream {
464 let input: DeriveInput = parse_macro_input!(input);
465 let attrs: EnumsetAttrs = match EnumsetAttrs::from_derive_input(&input) {
466 Ok(attrs) => attrs,
467 Err(e) => return e.write_errors().into(),
468 };
469 match derive_enum_set_type_0(input, attrs) {
470 Ok(v) => v,
471 Err(e) => e.to_compile_error().into(),
472 }
473}
474fn derive_enum_set_type_0(input: DeriveInput, attrs: EnumsetAttrs) -> Result<TokenStream> {
475 if !input.generics.params.is_empty() {
476 error(
477 input.generics.span(),
478 "`#[derive(EnumSetType)]` cannot be used on enums with type parameters.",
479 )
480 } else if let Data::Enum(data) = &input.data {
481 let mut info = EnumSetInfo::new(&input, attrs);
482 for attr in &input.attrs {
483 if attr.path.is_ident(&Ident::new("repr", Span::call_site())) {
484 let meta: Ident = attr.parse_args()?;
485 info.push_explicit_repr(attr.span(), meta.to_string().as_str())?;
486 }
487 }
488 for variant in &data.variants {
489 info.push_variant(variant)?;
490 }
491 info.validate()?;
492 Ok(enum_set_type_impl(info).into())
493 } else {
494 error(input.span(), "`#[derive(EnumSetType)]` may only be used on enums")
495 }
496}