Skip to main content

wasmer_enumset_derive/
lib.rs

1#![recursion_limit="256"]
2
3extern crate proc_macro;
4
5use darling::*;
6use proc_macro::TokenStream;
7use proc_macro2::{TokenStream as SynTokenStream, Literal, Span};
8use std::collections::HashSet;
9use syn::{*, Result, Error};
10use syn::spanned::Spanned;
11use quote::*;
12
13/// Helper function for emitting compile errors.
14fn error<T>(span: Span, message: &str) -> Result<T> {
15    Err(Error::new(span, message))
16}
17
18/// Decodes the custom attributes for our custom derive.
19#[derive(FromDeriveInput, Default)]
20#[darling(attributes(enumset), default)]
21struct EnumsetAttrs {
22    no_ops: bool,
23    serialize_as_list: bool,
24    serialize_deny_unknown: bool,
25    #[darling(default)]
26    serialize_repr: Option<String>,
27    #[darling(default)]
28    crate_name: Option<String>,
29}
30
31/// An variant in the enum set type.
32struct EnumSetValue {
33    /// The name of the variant.
34    name: Ident,
35    /// The discriminant of the variant.
36    variant_repr: u32,
37}
38
39/// Stores information about the enum set type.
40#[allow(dead_code)]
41struct EnumSetInfo {
42    /// The name of the enum.
43    name: Ident,
44    /// The crate name to use.
45    crate_name: Option<Ident>,
46    /// The numeric type to serialize the enum as.
47    explicit_serde_repr: Option<Ident>,
48    /// Whether the underlying repr of the enum supports negative values.
49    has_signed_repr: bool,
50    /// Whether the underlying repr of the enum supports values higher than 2^32.
51    has_large_repr: bool,
52    /// A list of variants in the enum.
53    variants: Vec<EnumSetValue>,
54
55    /// The highest encountered variant discriminant.
56    max_discrim: u32,
57    /// The current variant discriminant. Used to track, e.g. `A=10,B,C`.
58    cur_discrim: u32,
59    /// A list of variant names that are already in use.
60    used_variant_names: HashSet<String>,
61    /// A list of variant discriminants that are already in use.
62    used_discriminants: HashSet<u32>,
63
64    /// Avoid generating operator overloads on the enum type.
65    no_ops: bool,
66    /// Serialize the enum as a list.
67    serialize_as_list: bool,
68    /// Disallow unknown bits while deserializing the enum.
69    serialize_deny_unknown: bool,
70}
71impl EnumSetInfo {
72    fn new(input: &DeriveInput, attrs: EnumsetAttrs) -> EnumSetInfo {
73        EnumSetInfo {
74            name: input.ident.clone(),
75            crate_name: attrs.crate_name.map(|x| Ident::new(&x, Span::call_site())),
76            explicit_serde_repr: attrs.serialize_repr.map(|x| Ident::new(&x, Span::call_site())),
77            has_signed_repr: false,
78            has_large_repr: false,
79            variants: Vec::new(),
80            max_discrim: 0,
81            cur_discrim: 0,
82            used_variant_names: HashSet::new(),
83            used_discriminants: HashSet::new(),
84            no_ops: attrs.no_ops,
85            serialize_as_list: attrs.serialize_as_list,
86            serialize_deny_unknown: attrs.serialize_deny_unknown
87        }
88    }
89
90    /// Sets an explicit repr for the enumset.
91    fn push_explicit_repr(&mut self, attr_span: Span, repr: &str) -> Result<()> {
92        // Check whether the repr is supported, and if so, set some flags for better error
93        // messages later on.
94        match repr {
95            "Rust" | "C" | "u8" | "u16" | "u32" => Ok(()),
96            "usize" | "u64" | "u128" => {
97                self.has_large_repr = true;
98                Ok(())
99            }
100            "i8" | "i16" | "i32" => {
101                self.has_signed_repr = true;
102                Ok(())
103            }
104            "isize" | "i64" | "i128" => {
105                self.has_signed_repr = true;
106                self.has_large_repr = true;
107                Ok(())
108            }
109            _ => error(attr_span, "Unsupported repr.")
110        }
111    }
112    /// Adds a variant to the enumset.
113    fn push_variant(&mut self, variant: &Variant) -> Result<()> {
114        if self.used_variant_names.contains(&variant.ident.to_string()) {
115            error(variant.span(), "Duplicated variant name.")
116        } else if let Fields::Unit = variant.fields {
117            // Parse the discriminant.
118            if let Some((_, expr)) = &variant.discriminant {
119                let discriminant_fail_message = format!(
120                    "Enum set discriminants must be `u32`s.{}",
121                    if self.has_signed_repr || self.has_large_repr {
122                        format!(
123                            " ({} discrimiants are still unsupported with reprs that allow them.)",
124                            if self.has_large_repr {
125                                "larger"
126                            } else if self.has_signed_repr {
127                                "negative"
128                            } else {
129                                "larger or negative"
130                            }
131                        )
132                    } else {
133                        String::new()
134                    },
135                );
136                if let Expr::Lit(ExprLit { lit: Lit::Int(i), .. }) = expr {
137                    match i.base10_parse() {
138                        Ok(val) => self.cur_discrim = val,
139                        Err(_) => error(expr.span(), &discriminant_fail_message)?,
140                    }
141                } else {
142                    error(variant.span(), &discriminant_fail_message)?;
143                }
144            }
145
146            // Validate the discriminant.
147            let discriminant = self.cur_discrim;
148            if discriminant >= 128 {
149                let message = if self.variants.len() <= 127 {
150                    "`#[derive(EnumSetType)]` currently only supports discriminants up to 127."
151                } else {
152                    "`#[derive(EnumSetType)]` currently only supports enums up to 128 variants."
153                };
154                error(variant.span(), message)?;
155            }
156            if self.used_discriminants.contains(&discriminant) {
157                error(variant.span(), "Duplicated enum discriminant.")?;
158            }
159
160            // Add the variant to the info.
161            self.cur_discrim += 1;
162            if discriminant > self.max_discrim {
163                self.max_discrim = discriminant;
164            }
165            self.variants.push(EnumSetValue {
166                name: variant.ident.clone(),
167                variant_repr: discriminant,
168            });
169            self.used_variant_names.insert(variant.ident.to_string());
170            self.used_discriminants.insert(discriminant);
171
172            Ok(())
173        } else {
174            error(variant.span(), "`#[derive(EnumSetType)]` can only be used on fieldless enums.")
175        }
176    }
177    /// Validate the enumset type.
178    fn validate(&self) -> Result<()> {
179        // Check if all bits of the bitset can fit in the serialization representation.
180        if let Some(explicit_serde_repr) = &self.explicit_serde_repr {
181            let is_overflowed = match explicit_serde_repr.to_string().as_str() {
182                "u8" => self.max_discrim >= 8,
183                "u16" => self.max_discrim >= 16,
184                "u32" => self.max_discrim >= 32,
185                "u64" => self.max_discrim >= 64,
186                "u128" => self.max_discrim >= 128,
187                _ => error(
188                    Span::call_site(),
189                    "Only `u8`, `u16`, `u32`, `u64` and `u128` are supported for serde_repr."
190                )?,
191            };
192            if is_overflowed {
193                error(Span::call_site(), "serialize_repr cannot be smaller than bitset.")?;
194            }
195        }
196        Ok(())
197    }
198
199    /// Computes the underlying type used to store the enumset.
200    fn enumset_repr(&self) -> SynTokenStream {
201        if self.max_discrim <= 7 {
202            quote! { u8 }
203        } else if self.max_discrim <= 15 {
204            quote! { u16 }
205        } else if self.max_discrim <= 31 {
206            quote! { u32 }
207        } else if self.max_discrim <= 63 {
208            quote! { u64 }
209        } else if self.max_discrim <= 127 {
210            quote! { u128 }
211        } else {
212            panic!("max_variant > 127?")
213        }
214    }
215    /// Computes the underlying type used to serialize the enumset.
216    #[cfg(feature = "serde")]
217    fn serde_repr(&self) -> SynTokenStream {
218        if let Some(serde_repr) = &self.explicit_serde_repr {
219            quote! { #serde_repr }
220        } else {
221            self.enumset_repr()
222        }
223    }
224
225    /// Returns a bitmask of all variants in the set.
226    fn all_variants(&self) -> u128 {
227        let mut accum = 0u128;
228        for variant in &self.variants {
229            assert!(variant.variant_repr <= 127);
230            accum |= 1u128 << variant.variant_repr as u128;
231        }
232        accum
233    }
234}
235
236/// Generates the actual `EnumSetType` impl.
237fn enum_set_type_impl(info: EnumSetInfo) -> SynTokenStream {
238    let name = &info.name;
239    let enumset = match &info.crate_name {
240        Some(crate_name) => quote!(::#crate_name),
241        None => quote!(::wasmer_enumset),
242    };
243    let typed_enumset = quote!(#enumset::EnumSet<#name>);
244    let core = quote!(#enumset::__internal::core_export);
245
246    let repr = info.enumset_repr();
247    let all_variants = Literal::u128_unsuffixed(info.all_variants());
248
249    let ops = if info.no_ops {
250        quote! {}
251    } else {
252        quote! {
253            impl <O : Into<#typed_enumset>> #core::ops::Sub<O> for #name {
254                type Output = #typed_enumset;
255                fn sub(self, other: O) -> Self::Output {
256                    #enumset::EnumSet::only(self) - other.into()
257                }
258            }
259            impl <O : Into<#typed_enumset>> #core::ops::BitAnd<O> for #name {
260                type Output = #typed_enumset;
261                fn bitand(self, other: O) -> Self::Output {
262                    #enumset::EnumSet::only(self) & other.into()
263                }
264            }
265            impl <O : Into<#typed_enumset>> #core::ops::BitOr<O> for #name {
266                type Output = #typed_enumset;
267                fn bitor(self, other: O) -> Self::Output {
268                    #enumset::EnumSet::only(self) | other.into()
269                }
270            }
271            impl <O : Into<#typed_enumset>> #core::ops::BitXor<O> for #name {
272                type Output = #typed_enumset;
273                fn bitxor(self, other: O) -> Self::Output {
274                    #enumset::EnumSet::only(self) ^ other.into()
275                }
276            }
277            impl #core::ops::Not for #name {
278                type Output = #typed_enumset;
279                fn not(self) -> Self::Output {
280                    !#enumset::EnumSet::only(self)
281                }
282            }
283            impl #core::cmp::PartialEq<#typed_enumset> for #name {
284                fn eq(&self, other: &#typed_enumset) -> bool {
285                    #enumset::EnumSet::only(*self) == *other
286                }
287            }
288        }
289    };
290
291
292    #[cfg(feature = "serde")]
293    let serde = quote!(#enumset::__internal::serde);
294
295    #[cfg(feature = "serde")]
296    let serde_ops = if info.serialize_as_list {
297        let expecting_str = format!("a list of {}", name);
298        quote! {
299            fn serialize<S: #serde::Serializer>(
300                set: #enumset::EnumSet<#name>, ser: S,
301            ) -> #core::result::Result<S::Ok, S::Error> {
302                use #serde::ser::SerializeSeq;
303                let mut seq = ser.serialize_seq(#core::prelude::v1::Some(set.len()))?;
304                for bit in set {
305                    seq.serialize_element(&bit)?;
306                }
307                seq.end()
308            }
309            fn deserialize<'de, D: #serde::Deserializer<'de>>(
310                de: D,
311            ) -> #core::result::Result<#enumset::EnumSet<#name>, D::Error> {
312                struct Visitor;
313                impl <'de> #serde::de::Visitor<'de> for Visitor {
314                    type Value = #enumset::EnumSet<#name>;
315                    fn expecting(
316                        &self, formatter: &mut #core::fmt::Formatter,
317                    ) -> #core::fmt::Result {
318                        write!(formatter, #expecting_str)
319                    }
320                    fn visit_seq<A>(
321                        mut self, mut seq: A,
322                    ) -> #core::result::Result<Self::Value, A::Error> where
323                        A: #serde::de::SeqAccess<'de>
324                    {
325                        let mut accum = #enumset::EnumSet::<#name>::new();
326                        while let #core::prelude::v1::Some(val) = seq.next_element::<#name>()? {
327                            accum |= val;
328                        }
329                        #core::prelude::v1::Ok(accum)
330                    }
331                }
332                de.deserialize_seq(Visitor)
333            }
334        }
335    } else {
336        let serialize_repr = info.serde_repr();
337        let check_unknown = if info.serialize_deny_unknown {
338            quote! {
339                if value & !#all_variants != 0 {
340                    use #serde::de::Error;
341                    return #core::prelude::v1::Err(
342                        D::Error::custom("enumset contains unknown bits")
343                    )
344                }
345            }
346        } else {
347            quote! { }
348        };
349        quote! {
350            fn serialize<S: #serde::Serializer>(
351                set: #enumset::EnumSet<#name>, ser: S,
352            ) -> #core::result::Result<S::Ok, S::Error> {
353                #serde::Serialize::serialize(&(set.__enumset_underlying as #serialize_repr), ser)
354            }
355            fn deserialize<'de, D: #serde::Deserializer<'de>>(
356                de: D,
357            ) -> #core::result::Result<#enumset::EnumSet<#name>, D::Error> {
358                let value = <#serialize_repr as #serde::Deserialize>::deserialize(de)?;
359                #check_unknown
360                #core::prelude::v1::Ok(#enumset::EnumSet {
361                    __enumset_underlying: (value & #all_variants) as #repr,
362                })
363            }
364        }
365    };
366
367    #[cfg(not(feature = "serde"))]
368    let serde_ops = quote! { };
369
370    let is_uninhabited = info.variants.is_empty();
371    let is_zst = info.variants.len() == 1;
372    let into_impl = if is_uninhabited {
373        quote! {
374            fn enum_into_u32(self) -> u32 {
375                panic!(concat!(stringify!(#name), " is uninhabited."))
376            }
377            unsafe fn enum_from_u32(val: u32) -> Self {
378                panic!(concat!(stringify!(#name), " is uninhabited."))
379            }
380        }
381    } else if is_zst {
382        let variant = &info.variants[0].name;
383        quote! {
384            fn enum_into_u32(self) -> u32 {
385                self as u32
386            }
387            unsafe fn enum_from_u32(val: u32) -> Self {
388                #name::#variant
389            }
390        }
391    } else {
392        let variant_name: Vec<_> = info.variants.iter().map(|x| &x.name).collect();
393        let variant_value: Vec<_> = info.variants.iter().map(|x| x.variant_repr).collect();
394
395        let const_field: Vec<_> = ["IS_U8", "IS_U16", "IS_U32", "IS_U64", "IS_U128"]
396            .iter().map(|x| Ident::new(x, Span::call_site())).collect();
397        let int_type: Vec<_> = ["u8", "u16", "u32", "u64", "u128"]
398            .iter().map(|x| Ident::new(x, Span::call_site())).collect();
399
400        quote! {
401            fn enum_into_u32(self) -> u32 {
402                self as u32
403            }
404            unsafe fn enum_from_u32(val: u32) -> Self {
405                // We put these in const fields so the branches they guard aren't generated even
406                // on -O0
407                #(const #const_field: bool =
408                    #core::mem::size_of::<#name>() == #core::mem::size_of::<#int_type>();)*
409                match val {
410                    // Every valid variant value has an explicit branch. If they get optimized out,
411                    // great. If the representation has changed somehow, and they don't, oh well,
412                    // there's still no UB.
413                    #(#variant_value => #name::#variant_name,)*
414                    // Helps hint to the LLVM that this is a transmute. Note that this branch is
415                    // still unreachable.
416                    #(x if #const_field => {
417                        let x = x as #int_type;
418                        *(&x as *const _ as *const #name)
419                    })*
420                    // Default case. Sometimes causes LLVM to generate a table instead of a simple
421                    // transmute, but, oh well.
422                    _ => #core::hint::unreachable_unchecked(),
423                }
424            }
425        }
426    };
427
428    let eq_impl = if is_uninhabited {
429        quote!(panic!(concat!(stringify!(#name), " is uninhabited.")))
430    } else {
431        quote!((*self as u32) == (*other as u32))
432    };
433
434    quote! {
435        unsafe impl #enumset::__internal::EnumSetTypePrivate for #name {
436            type Repr = #repr;
437            const ALL_BITS: Self::Repr = #all_variants;
438            #into_impl
439            #serde_ops
440        }
441
442        unsafe impl #enumset::EnumSetType for #name { }
443
444        impl #core::cmp::PartialEq for #name {
445            fn eq(&self, other: &Self) -> bool {
446                #eq_impl
447            }
448        }
449        impl #core::cmp::Eq for #name { }
450        impl #core::clone::Clone for #name {
451            fn clone(&self) -> Self {
452                *self
453            }
454        }
455        impl #core::marker::Copy for #name { }
456
457        #ops
458    }
459}
460
461/// A wrapper that parses the input enum.
462#[proc_macro_derive(EnumSetType, attributes(enumset))]
463pub fn derive_enum_set_type(input: TokenStream) -> TokenStream {
464    let input: DeriveInput = parse_macro_input!(input);
465    let attrs: EnumsetAttrs = match EnumsetAttrs::from_derive_input(&input) {
466        Ok(attrs) => attrs,
467        Err(e) => return e.write_errors().into(),
468    };
469    match derive_enum_set_type_0(input, attrs) {
470        Ok(v) => v,
471        Err(e) => e.to_compile_error().into(),
472    }
473}
474fn derive_enum_set_type_0(input: DeriveInput, attrs: EnumsetAttrs) -> Result<TokenStream> {
475    if !input.generics.params.is_empty() {
476        error(
477            input.generics.span(),
478            "`#[derive(EnumSetType)]` cannot be used on enums with type parameters.",
479        )
480    } else if let Data::Enum(data) = &input.data {
481        let mut info = EnumSetInfo::new(&input, attrs);
482        for attr in &input.attrs {
483            if attr.path.is_ident(&Ident::new("repr", Span::call_site())) {
484                let meta: Ident = attr.parse_args()?;
485                info.push_explicit_repr(attr.span(), meta.to_string().as_str())?;
486            }
487        }
488        for variant in &data.variants {
489            info.push_variant(variant)?;
490        }
491        info.validate()?;
492        Ok(enum_set_type_impl(info).into())
493    } else {
494        error(input.span(), "`#[derive(EnumSetType)]` may only be used on enums")
495    }
496}