thisenum_impl/
lib.rs

1#![doc = include_str!("../README.md")]
2// --------------------------------------------------
3// external
4// --------------------------------------------------
5use quote::{
6    quote,
7    ToTokens,
8};
9use syn::{
10    Meta,
11    Data,
12    Type,
13    DataEnum,
14    Attribute,
15    DeriveInput,
16    MetaNameValue,
17    parse_macro_input,
18};
19use unzip_n::unzip_n;
20use thiserror::Error;
21use proc_macro::TokenStream;
22
23// --------------------------------------------------
24// local
25// --------------------------------------------------
26mod prelude;
27use prelude::*;
28unzip_n!(3);
29
30#[derive(Error, Debug)]
31/// All errors that can occur while deriving [`Const`]
32/// or [`ConstEach`]
33enum Error {
34    #[error("`{0}` can only be derived for enums")]
35    DeriveForNonEnum(String),
36    #[error("Missing #[armtype = ...] attribute {0}, required for `{1}`-derived enum")]
37    MissingArmType(String, String),
38    #[error("Missing #[value = ...] attribute, expected for `{0}`-derived enum")]
39    MissingValue(String),
40    #[error("Attemping to parse non-literal attribute for `value`: not yet supported")]
41    NonLiteralValue,
42}
43
44#[proc_macro_derive(Const, attributes(value, armtype))]
45/// Add's constants to each arm of an enum
46/// 
47/// * To get the value as a reference, call the function [`<enum_name>::value`]
48/// * However, direct comparison to non-reference values are possible with
49///   [`PartialEq`]
50/// 
51/// The `#[armtype = ...]` attribute is required for this macro to function, 
52/// and must be applied to **the enum**, since all values share the same type.
53/// 
54/// All values set will return a [`&'static T`] reference. To the input type,
55/// of [`T`] AND [`&T`]. If multiple references are used (e.g. `&&T`), then
56/// the return type will be [`&'static &T`].
57/// 
58/// # Example
59/// 
60/// ```
61/// use thisenum::Const;
62/// 
63/// #[derive(Const, Debug)]
64/// #[armtype(i32)]
65/// enum MyEnum {
66///     #[value = 0]
67///     A,
68///     #[value = 1]
69///     B,
70/// }
71/// 
72/// #[derive(Const, Debug)]
73/// #[armtype(&[u8])]
74/// enum Tags {
75///     #[value = b"\x00\x01\x7f"]
76///     Key,
77///     #[value = b"\xba\x5e"]
78///     Length,
79///     #[value = b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f"]
80///     Data,
81/// }
82/// 
83/// fn main() {
84///     // it's prefered to use the function call to `value` 
85///     // to get a [`&'static T`] reference to the value
86///     assert_eq!(MyEnum::A.value(), &0);
87///     assert_eq!(MyEnum::B.value(), &1);
88///     assert_eq!(Tags::Key.value(), b"\x00\x01\x7f");
89///     assert_eq!(Tags::Length.value(), b"\xba\x5e");
90///     assert_eq!(Tags::Data.value(), b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f");
91/// 
92///     // can also check equality without the function call. This must compare the input 
93///     // type defined in `#[armtype = ...]`
94///     //
95///     // to use this, use the `eq` feature in `Cargo.toml`: thisenum = { version = "x", features = ["eq"] }
96///     #[cfg(feature = "eq")]
97///     assert_eq!(Tags::Length, b"\xba\x5e");
98/// }
99/// ```
100pub fn thisenum_const(input: TokenStream) -> TokenStream {
101    let name = "Const";
102    let input = parse_macro_input!(input as DeriveInput);
103    // --------------------------------------------------
104    // extract the name, variants, and values
105    // --------------------------------------------------
106    let enum_name = &input.ident;
107    let variants = match input.data {
108        Data::Enum(DataEnum { variants, .. }) => variants,
109        _ => panic!("{}", Error::DeriveForNonEnum(name.into())),
110    };
111    // --------------------------------------------------
112    // extract the type
113    // --------------------------------------------------
114    let (type_name, deref) = match get_deref_type(&input.attrs) {
115        Some((type_name, deref)) => (type_name, deref),
116        None => panic!("{}", Error::MissingArmType("applied to enum".into(), name.into())),
117    };
118    let type_name_raw = match get_type(&input.attrs) {
119        Some(type_name_raw) => type_name_raw,
120        None => panic!("{}", Error::MissingArmType("applied to enum".into(), name.into())),
121    };
122    // --------------------------------------------------
123    // get unique assigned values
124    // --------------------------------------------------
125    let values = variants
126        .iter()
127        .map(|variant| get_val(name.into(), &variant.attrs))
128        .collect::<Result<Vec<_>, _>>()
129        .unwrap();
130    let values_string = values.iter().map(|v| v.to_string()).collect::<Vec<_>>();
131    let repeated_values_string = values_string.clone().into_iter().repeated();
132    // --------------------------------------------------
133    // generate the output tokens
134    // --------------------------------------------------
135    let (
136        // #[cfg(feature = "debug")]
137        _debug_arms,
138        variant_match_arms,
139        mut variant_inv_match_arms
140    ) = variants
141        .iter()
142        .map(|variant| {
143            let variant_name = &variant.ident;
144            // ------------------------------------------------
145            // number of args in the variant
146            // ------------------------------------------------
147            // e.g.: enum Test { VariantA(i23), VariantB(String, String) }
148            // will have 1 (i23) and 2 (String, String)
149            // ------------------------------------------------
150            let num_args = match variant.fields {
151                syn::Fields::Named(syn::FieldsNamed { ref named, .. }) => named.len(),
152                syn::Fields::Unnamed(syn::FieldsUnnamed { ref unnamed, .. }) => unnamed.len(),
153                syn::Fields::Unit => 0,
154            };
155            let value = match get_val(name.into(), &variant.attrs) {
156                Ok(value) => value,
157                Err(e) => panic!("{}", e),
158            };
159            // ------------------------------------------------
160            // check if the value is unique
161            // this is used to prevent unreachable arms
162            // ------------------------------------------------
163            let val_repeated = repeated_values_string.contains(&value.to_string());
164            // ------------------------------------------------
165            // if the type input is a reference (e.g. &[u8] or &str)
166            // then the return type will be 
167            // * `&'static [u8]` or
168            // * `&'static str`
169            //
170            // otherwise, if the input is not a reference (e.g. u8 or f32)
171            // then the return type will be
172            // * `&'static u8` or
173            // * `&'static f32`
174            //
175            // as a result, need to ensure we are removing / adding
176            // the `&` symbol wherever necessary
177            // ------------------------------------------------
178            let args_tokens = match num_args {
179                0 => quote! {},
180                _ => {
181                    let args = (0..num_args).map(|_| quote! { _ });
182                    quote! { ( #(#args),* ) }
183                },
184            };
185            // ------------------------------------------------
186            // debug arms implementation
187            // ------------------------------------------------
188            let debug_arm = match get_val(name.into(), &variant.attrs) {
189                Ok(_) => quote! { #enum_name::#variant_name #args_tokens => write!(f, concat!(stringify!(#enum_name), "::", stringify!(#variant_name), ": {:?}"), self.value()), },
190                Err(e) => panic!("{}", e),
191            };
192            // ------------------------------------------------
193            // variant -> value
194            // ------------------------------------------------
195            let vma = match deref {
196                true => quote! { #enum_name::#variant_name #args_tokens => #value, },
197                false => quote! { #enum_name::#variant_name #args_tokens => &#value, },
198            };
199            // ------------------------------------------------
200            // value -> variant
201            // ------------------------------------------------
202            match (num_args, val_repeated) {
203                (0, false) => (debug_arm, vma, Some(quote! { #value => Ok(#enum_name::#variant_name), })),
204                (_, _) => (debug_arm, vma, None),
205            }
206        })
207        .into_iter()
208        .unzip_n_vec();
209    // --------------------------------------------------
210    // get the vima for repeated values
211    // --------------------------------------------------
212    let mut repeated_indices = values_string
213        .clone()
214        .into_iter()
215        .repeated_idx();
216    repeated_indices.sort_by(|a, b| b.cmp(a));
217    repeated_indices
218        .iter()
219        .for_each(|i| { variant_inv_match_arms.remove(*i); } );
220    let variant_inv_match_arms_repeated = values_string
221        .clone()
222        .into_iter()
223        .positions()
224        .iter()
225        .map(|(_, pos)| match pos.len() {
226            ..=1 => quote! {},
227            _ => {
228                let val = values[pos[0]].clone();
229                quote! { #val => Err(::thisenum::Error::UnreachableValue(format!("{:?}", #val))), }
230            }
231        })
232        .collect::<Vec<_>>();
233    // --------------------------------------------------
234    // get all the indices of variants which have nested args
235    // --------------------------------------------------
236    let arg_indices = variant_inv_match_arms
237        .iter()
238        .enumerate()
239        .filter(|(i, v)| v.is_none() && !repeated_indices.contains(&i))
240        .map(|(i, _)| i)
241        .collect::<Vec<_>>();
242    let variant_inv_match_arms_args = values
243        .clone()
244        .into_iter()
245        .zip(variants)
246        .enumerate()
247        .filter(|(i, _)| arg_indices.contains(i))
248        .map(|(_, (value, variant))| {
249            let variant_name = &variant.ident;
250            quote! { #value => Err(::thisenum::Error::UnableToReturnVariant(stringify!(#variant_name).into())), }
251        })
252        .collect::<Vec<_>>();
253    // --------------------------------------------------
254    // see deref comment above
255    // --------------------------------------------------
256    let into_impl = match deref {
257        false => quote! {
258            #[automatically_derived]
259            #[doc = concat!(" [`Into`] implementation for [`", stringify!(#enum_name), "`]")]
260            impl ::std::convert::Into<#type_name_raw> for #enum_name {
261                #[inline]
262                fn into(self) -> #type_name_raw {
263                    *self.value()
264                }
265            }
266        },
267        true => quote! { },
268    };
269    let mut expanded = quote! {
270        #[automatically_derived]
271        impl #enum_name {
272            #[inline]
273            /// Returns the value of the enum variant
274            /// defined by [`Const`]
275            /// 
276            /// # Returns
277            /// 
278            #[doc = concat!(" * [`&'static ", stringify!(#type_name), "`]")]
279            pub fn value(&self) -> &'static #type_name {
280                match self {
281                    #( #variant_match_arms )*
282                }
283            }
284        }
285        #into_impl
286    };
287
288    if cfg!(feature = "debug") {
289        expanded = quote! {
290            #expanded
291            #[automatically_derived]
292            #[doc = concat!(" [`Debug`] implementation for [`", stringify!(#enum_name), "`]")]
293            impl ::std::fmt::Debug for #enum_name {
294                fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
295                    match self {
296                        #( #_debug_arms )*
297                    }
298                }
299            }
300        };
301    }
302
303    if cfg!(feature = "eq") {
304        let variant_par_eq_lhs = match deref {
305            true => quote! { &self.value() == other },
306            false => quote! { self.value() == other },
307        };
308        let variant_par_eq_rhs = match deref {
309            true => quote! { &other.value() == self },
310            false => quote! { other.value() == self },
311        };
312        expanded = quote! {
313            #expanded
314            #[automatically_derived]
315            #[doc = concat!(" [`PartialEq<", stringify!(#type_name_raw) ,">`] implementation for [`", stringify!(#enum_name), "`]")]
316            ///
317            #[doc = concat!(" This is the LHS of the [`PartialEq`] implementation between [`", stringify!(#enum_name), "`] and [`", stringify!(#type_name_raw), "`]")]
318            /// 
319            /// # Returns
320            /// 
321            /// * [`true`] if the type and the enum are equal
322            /// * [`false`] if the type and the enum are not equal
323            impl ::std::cmp::PartialEq<#type_name_raw> for #enum_name {
324                #[inline]
325                fn eq(&self, other: &#type_name_raw) -> bool {
326                    #variant_par_eq_lhs
327                }
328            }
329            #[automatically_derived]
330            #[doc = concat!(" [`PartialEq<", stringify!(#enum_name) ,">`] implementation for [`", stringify!(#type_name_raw), "`]")]
331            /// 
332            #[doc = concat!(" This is the RHS of the [`PartialEq`] implementation between [`", stringify!(#enum_name), "`] and [`", stringify!(#type_name_raw), "`]")]
333            /// 
334            /// # Returns
335            /// 
336            /// * [`true`] if the enum and the type are equal
337            /// * [`false`] if the enum and the type are not equal
338            impl ::std::cmp::PartialEq<#enum_name> for #type_name_raw {
339                #[inline]
340                fn eq(&self, other: &#enum_name) -> bool {
341                    #variant_par_eq_rhs
342                }
343            }
344        };
345    }
346
347    let variant_inv_match_arms = variant_inv_match_arms.into_iter().filter(|v| v.is_some()).map(|v| v.unwrap());
348    expanded = quote! {
349        #expanded
350        #[automatically_derived]
351        #[doc = concat!(" [`TryFrom`] implementation for [`", stringify!(#enum_name), "`]")]
352        ///
353        /// This is able to be derived since none of the Arms of the Enum had
354        /// any arguments. If that is the case, this implementation is 
355        /// non-existent.
356        /// 
357        /// # Returns
358        /// 
359        /// * [`Ok(T)`] where `T` is the enum variant
360        /// * [`Err(Error)`] if the conversion fails
361        impl ::std::convert::TryFrom<#type_name_raw> for #enum_name {
362            type Error = ::thisenum::Error;
363            #[inline]
364            fn try_from(value: #type_name_raw) -> Result<Self, Self::Error> {
365                match value {
366                    #( #variant_inv_match_arms )*
367                    #( #variant_inv_match_arms_repeated )*
368                    #( #variant_inv_match_arms_args )*
369                    _ => Err(::thisenum::Error::InvalidValue(format!("{:?}", value), stringify!(#enum_name).into())),
370                }
371            }
372        }
373    };
374    // --------------------------------------------------
375    // return
376    // --------------------------------------------------
377    TokenStream::from(expanded)
378}
379
380#[proc_macro_derive(ConstEach, attributes(value, armtype))]
381/// Add's constants of any type to each arm of an enum
382/// 
383/// To get the value, the type must be explicitly passed
384/// as a generic to [`<enum_name>::value`]. This will automatically
385/// try to convert constant to the expected type using [`std::any::Any`] 
386/// and [`downcast_ref`]. Currently [`TryFrom`] is not supported, so typing
387/// is fairly strict. Upon failure, it will return [`None`].
388/// 
389/// * To get the value as a reference, call the function [`<enum_name>::value`]
390/// * Unlike [`Const`], this macro does not enable direct comparison
391///   using [`PartialEq`] when imported using the `eq` feature.
392/// 
393/// The `#[armtype = ...]` attribute is **NOT*** required for this macro to function, 
394/// but ***CAN** be applied to ***each individual arm*** of the enum, since values
395/// are not expected to share a type. If no type is given, then the type is
396/// inferred from the literal value in the `#[value = ...]` attribute.
397/// 
398/// All values set will return a [`Option<&'static T>`] reference. To the input type,
399/// of [`T`] AND [`&T`]. If multiple references are used (e.g. `&&T`), then
400/// the return type will be [`Option<&'static &T>`].
401/// 
402/// # Example
403/// 
404/// ```
405/// use thisenum::ConstEach;
406/// 
407/// #[derive(ConstEach, Debug)]
408/// enum MyEnum {
409///     #[armtype(u8)]
410///     #[value = 0xAA]
411///     A,
412///     #[value = "test3"]
413///     B,
414/// }
415/// 
416/// #[derive(ConstEach, Debug)]
417/// enum Tags {
418///     #[value = b"\x00\x01"]
419///     Key,
420///     #[armtype(u16)]
421///     #[value = 24250]
422///     Length,
423///     #[armtype(&[u8])]
424///     #[value = b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f"]
425///     Data,
426/// }
427/// 
428/// fn main() {
429///     // [`ConstEach`] examples
430///     assert!(MyEnum::A.value::<u8>().is_some());
431///     assert!(MyEnum::A.value::<Vec<f32>>().is_none());
432///     assert!(MyEnum::B.value::<u8>().is_none());
433///     assert!(MyEnum::B.value::<&str>().is_some());
434///     assert!(Tags::Data.value::<&[u8]>().is_some());
435/// 
436///     // An infered type. This will be as strict as possible,
437///     // therefore [`&[u8]`] will fail but [`&[u8; 2]`] will succeed
438///     assert!(Tags::Key.value::<&[u8; 2]>().is_some());
439///     assert!(Tags::Key.value::<&[u8; 5]>().is_none());
440///     assert!(Tags::Key.value::<&[u8]>().is_none());
441///     assert!(u16::from_le_bytes(**Tags::Key.value::<&[u8; 2]>().unwrap()) == 0x0100);
442/// 
443///     // casting as anything other than the defined / inferred type will
444///     // fail, since this uses [`downcast_ref`] from [`std::any::Any`]
445///     assert!(Tags::Length.value::<u16>().is_some());
446///     assert!(Tags::Length.value::<u32>().is_none());
447///     assert!(Tags::Length.value::<u64>().is_none());
448/// 
449///     // however, can always convert to a different type
450///     // after value is successfully acquired
451///     assert!(*Tags::Length.value::<u16>().unwrap() as u32 == 24250);
452/// }
453/// ```
454pub fn thisenum_const_each(input: TokenStream) -> TokenStream {
455    let name = "ConstEach";
456    let input = parse_macro_input!(input as DeriveInput);
457    // --------------------------------------------------
458    // extract the name, variants, and values
459    // --------------------------------------------------
460    let enum_name = &input.ident;
461    let variants = match input.data {
462        Data::Enum(DataEnum { variants, .. }) => variants,
463        _ => panic!("{}", Error::DeriveForNonEnum(name.into())),
464    };
465    // --------------------------------------------------
466    // generate the output tokens
467    // --------------------------------------------------
468    let variant_code = variants.iter().map(|variant| {
469        let variant_name = &variant.ident;
470        match (get_type(&variant.attrs), get_val(name.into(), &variant.attrs)) {
471            // ------------------------------------------------
472            // if type is specified, use it
473            // ------------------------------------------------
474            (Some(typ), Ok(value)) => quote! {
475                #enum_name::#variant_name => {
476                    let val: &dyn ::std::any::Any = &(#value as #typ);
477                    val.downcast_ref::<T>()
478                },
479
480            },
481            // ------------------------------------------------
482            // no type specified, try to infer
483            // ------------------------------------------------
484            (None, Ok(value)) => quote! {
485                #enum_name::#variant_name => {
486                    let val: &dyn ::std::any::Any = &#value;
487                    val.downcast_ref::<T>()
488                },
489            },
490            // ------------------------------------------------
491            // unable to infer type
492            // ------------------------------------------------
493            (_, Err(_)) => quote! { #enum_name::#variant_name => None, },
494        }
495    });
496    // ------------------------------------------------
497    // return
498    // ------------------------------------------------
499    let expanded = quote! {
500        #[automatically_derived]
501        #[doc = concat!(" [`ConstEach`] implementation for [`", stringify!(#enum_name), "`]")]
502        impl #enum_name {
503            pub fn value<T: 'static>(&self) -> Option<&'static T> {
504                match self {
505                    #( #variant_code )*
506                    _ => None,
507                }
508            }
509        }
510    };
511    TokenStream::from(expanded)
512}
513
514/// Helper function to extract the value from a [`MetaNameValue`], aka `#[value = <value>]`
515///
516/// # Input
517///
518/// ```text
519/// #[value = <value>]
520/// ```
521///
522/// # Output
523///
524/// [`TokenStream`] containing the value `<value>`, or [`Err`] if the attribute is not present / invalid
525fn get_val(name: String, attrs: &[Attribute]) -> Result<proc_macro2::TokenStream, Error> {
526    for attr in attrs {
527        if !attr.path.is_ident("value") { continue; }
528        match attr.parse_meta() {
529            Ok(meta) => match meta {
530                Meta::NameValue(MetaNameValue { lit, .. }) => return Ok(lit.into_token_stream()),
531                Meta::List(list) => {
532                    let tokens = list.nested.iter().map(|nested_meta| {
533                        match nested_meta {
534                            syn::NestedMeta::Lit(lit) => lit.to_token_stream(),
535                            syn::NestedMeta::Meta(meta) => meta.to_token_stream(),
536                        }
537                    });
538                    return Ok(quote! { #( #tokens )* });
539                }
540                Meta::Path(_) => return Ok(meta.into_token_stream())
541            },
542            Err(_) => {
543                return Err(Error::NonLiteralValue);
544                /*
545                // Maybe for future:
546                // --------------------------------------------------
547                let elems = attr
548                    .to_token_stream()
549                    .to_string();
550                // println!("elems: {}", elems);
551                let mut elems = elems
552                    .trim()
553                    .trim_start_matches("#[")
554                    .rsplit_once("]")
555                    .unwrap()
556                    .0
557                    .split("=")
558                    .collect::<Vec<_>>();
559                // println!("elems: {:?}", elems);
560                elems.remove(0);
561                // println!("elems: {:?}", elems);
562                return Ok(elems
563                    .join("=")
564                    .trim()
565                    .parse::<proc_macro2::TokenStream>()?);
566                // --------------------------------------------------
567                */
568            },
569        }
570    }
571    Err(Error::MissingValue(name))
572}
573
574/// Helper function to extract the type from the [`Attribute`], aka `#[armtype(<type>)]`
575/// 
576/// Will indicate whether or not the type should be dereferenced or not. Useful
577/// for the [`Const`] macro
578///
579/// # Input
580///
581/// ```text
582/// #[armtype(<type>)]
583/// ```
584///
585/// # Output
586///
587/// [`None`] if the attribute is not present / invalid
588/// 
589/// Otherwise a tuple:
590/// 
591/// * 0 - [`Type`] containing the type `<type>` (already de-referenced)
592/// * 1 - An additional flag that indicates if the type has been de-referenced
593fn get_deref_type(attrs: &[Attribute]) -> Option<(Type, bool)> {
594    for attr in attrs {
595        if !attr.path.is_ident("armtype") { continue; }
596        let tokens = match attr.parse_args::<proc_macro2::TokenStream>() {
597            Ok(tokens) => tokens,
598            Err(_) => return None,
599        };
600        let deref = tokens
601            .to_string()
602            .trim()
603            .starts_with('&');
604        let tokens = match deref {
605            true => {
606                let mut tokens = tokens.into_iter();
607                let _ = tokens.next();
608                tokens.collect::<proc_macro2::TokenStream>()
609            }
610            false => tokens,
611        };
612        return match syn::parse2::<Type>(tokens).ok() {
613            Some(type_name) => Some((type_name, deref)),
614            None => None
615        }
616    }
617    None
618}
619
620/// Helper function to extract the type from the [`Attribute`], aka `#[armtype(<type>)]`
621/// 
622/// Will return the raw [`Type`]. Useful for the [`Const`] and the [`ConstEach`]
623/// macros
624///
625/// # Input
626///
627/// ```text
628/// #[armtype(<type>)]
629/// ```
630///
631/// # Output
632///
633/// [`None`] if the attribute is not present / invalid
634/// 
635/// Otherwise [`Some<Type>`] containing the type `<type>`
636fn get_type(attrs: &[Attribute]) -> Option<Type> {
637    for attr in attrs {
638        if !attr.path.is_ident("armtype") { continue; }
639        let tokens = match attr.parse_args::<proc_macro2::TokenStream>() {
640            Ok(tokens) => tokens,
641            Err(_) => return None,
642        };
643        return syn::parse2::<Type>(
644            tokens
645            .into_iter()
646            .collect::<proc_macro2::TokenStream>()
647        ).ok()
648    }
649    None
650}