thisenum_impl/lib.rs
1#![doc = include_str!("../README.md")]
2// --------------------------------------------------
3// external
4// --------------------------------------------------
5use quote::{
6 quote,
7 ToTokens,
8};
9use syn::{
10 Meta,
11 Data,
12 Type,
13 DataEnum,
14 Attribute,
15 DeriveInput,
16 MetaNameValue,
17 parse_macro_input,
18};
19use unzip_n::unzip_n;
20use thiserror::Error;
21use proc_macro::TokenStream;
22
23// --------------------------------------------------
24// local
25// --------------------------------------------------
26mod prelude;
27use prelude::*;
28unzip_n!(3);
29
30#[derive(Error, Debug)]
31/// All errors that can occur while deriving [`Const`]
32/// or [`ConstEach`]
33enum Error {
34 #[error("`{0}` can only be derived for enums")]
35 DeriveForNonEnum(String),
36 #[error("Missing #[armtype = ...] attribute {0}, required for `{1}`-derived enum")]
37 MissingArmType(String, String),
38 #[error("Missing #[value = ...] attribute, expected for `{0}`-derived enum")]
39 MissingValue(String),
40 #[error("Attemping to parse non-literal attribute for `value`: not yet supported")]
41 NonLiteralValue,
42}
43
44#[proc_macro_derive(Const, attributes(value, armtype))]
45/// Add's constants to each arm of an enum
46///
47/// * To get the value as a reference, call the function [`<enum_name>::value`]
48/// * However, direct comparison to non-reference values are possible with
49/// [`PartialEq`]
50///
51/// The `#[armtype = ...]` attribute is required for this macro to function,
52/// and must be applied to **the enum**, since all values share the same type.
53///
54/// All values set will return a [`&'static T`] reference. To the input type,
55/// of [`T`] AND [`&T`]. If multiple references are used (e.g. `&&T`), then
56/// the return type will be [`&'static &T`].
57///
58/// # Example
59///
60/// ```
61/// use thisenum::Const;
62///
63/// #[derive(Const, Debug)]
64/// #[armtype(i32)]
65/// enum MyEnum {
66/// #[value = 0]
67/// A,
68/// #[value = 1]
69/// B,
70/// }
71///
72/// #[derive(Const, Debug)]
73/// #[armtype(&[u8])]
74/// enum Tags {
75/// #[value = b"\x00\x01\x7f"]
76/// Key,
77/// #[value = b"\xba\x5e"]
78/// Length,
79/// #[value = b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f"]
80/// Data,
81/// }
82///
83/// fn main() {
84/// // it's prefered to use the function call to `value`
85/// // to get a [`&'static T`] reference to the value
86/// assert_eq!(MyEnum::A.value(), &0);
87/// assert_eq!(MyEnum::B.value(), &1);
88/// assert_eq!(Tags::Key.value(), b"\x00\x01\x7f");
89/// assert_eq!(Tags::Length.value(), b"\xba\x5e");
90/// assert_eq!(Tags::Data.value(), b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f");
91///
92/// // can also check equality without the function call. This must compare the input
93/// // type defined in `#[armtype = ...]`
94/// //
95/// // to use this, use the `eq` feature in `Cargo.toml`: thisenum = { version = "x", features = ["eq"] }
96/// #[cfg(feature = "eq")]
97/// assert_eq!(Tags::Length, b"\xba\x5e");
98/// }
99/// ```
100pub fn thisenum_const(input: TokenStream) -> TokenStream {
101 let name = "Const";
102 let input = parse_macro_input!(input as DeriveInput);
103 // --------------------------------------------------
104 // extract the name, variants, and values
105 // --------------------------------------------------
106 let enum_name = &input.ident;
107 let variants = match input.data {
108 Data::Enum(DataEnum { variants, .. }) => variants,
109 _ => panic!("{}", Error::DeriveForNonEnum(name.into())),
110 };
111 // --------------------------------------------------
112 // extract the type
113 // --------------------------------------------------
114 let (type_name, deref) = match get_deref_type(&input.attrs) {
115 Some((type_name, deref)) => (type_name, deref),
116 None => panic!("{}", Error::MissingArmType("applied to enum".into(), name.into())),
117 };
118 let type_name_raw = match get_type(&input.attrs) {
119 Some(type_name_raw) => type_name_raw,
120 None => panic!("{}", Error::MissingArmType("applied to enum".into(), name.into())),
121 };
122 // --------------------------------------------------
123 // get unique assigned values
124 // --------------------------------------------------
125 let values = variants
126 .iter()
127 .map(|variant| get_val(name.into(), &variant.attrs))
128 .collect::<Result<Vec<_>, _>>()
129 .unwrap();
130 let values_string = values.iter().map(|v| v.to_string()).collect::<Vec<_>>();
131 let repeated_values_string = values_string.clone().into_iter().repeated();
132 // --------------------------------------------------
133 // generate the output tokens
134 // --------------------------------------------------
135 let (
136 // #[cfg(feature = "debug")]
137 _debug_arms,
138 variant_match_arms,
139 mut variant_inv_match_arms
140 ) = variants
141 .iter()
142 .map(|variant| {
143 let variant_name = &variant.ident;
144 // ------------------------------------------------
145 // number of args in the variant
146 // ------------------------------------------------
147 // e.g.: enum Test { VariantA(i23), VariantB(String, String) }
148 // will have 1 (i23) and 2 (String, String)
149 // ------------------------------------------------
150 let num_args = match variant.fields {
151 syn::Fields::Named(syn::FieldsNamed { ref named, .. }) => named.len(),
152 syn::Fields::Unnamed(syn::FieldsUnnamed { ref unnamed, .. }) => unnamed.len(),
153 syn::Fields::Unit => 0,
154 };
155 let value = match get_val(name.into(), &variant.attrs) {
156 Ok(value) => value,
157 Err(e) => panic!("{}", e),
158 };
159 // ------------------------------------------------
160 // check if the value is unique
161 // this is used to prevent unreachable arms
162 // ------------------------------------------------
163 let val_repeated = repeated_values_string.contains(&value.to_string());
164 // ------------------------------------------------
165 // if the type input is a reference (e.g. &[u8] or &str)
166 // then the return type will be
167 // * `&'static [u8]` or
168 // * `&'static str`
169 //
170 // otherwise, if the input is not a reference (e.g. u8 or f32)
171 // then the return type will be
172 // * `&'static u8` or
173 // * `&'static f32`
174 //
175 // as a result, need to ensure we are removing / adding
176 // the `&` symbol wherever necessary
177 // ------------------------------------------------
178 let args_tokens = match num_args {
179 0 => quote! {},
180 _ => {
181 let args = (0..num_args).map(|_| quote! { _ });
182 quote! { ( #(#args),* ) }
183 },
184 };
185 // ------------------------------------------------
186 // debug arms implementation
187 // ------------------------------------------------
188 let debug_arm = match get_val(name.into(), &variant.attrs) {
189 Ok(_) => quote! { #enum_name::#variant_name #args_tokens => write!(f, concat!(stringify!(#enum_name), "::", stringify!(#variant_name), ": {:?}"), self.value()), },
190 Err(e) => panic!("{}", e),
191 };
192 // ------------------------------------------------
193 // variant -> value
194 // ------------------------------------------------
195 let vma = match deref {
196 true => quote! { #enum_name::#variant_name #args_tokens => #value, },
197 false => quote! { #enum_name::#variant_name #args_tokens => &#value, },
198 };
199 // ------------------------------------------------
200 // value -> variant
201 // ------------------------------------------------
202 match (num_args, val_repeated) {
203 (0, false) => (debug_arm, vma, Some(quote! { #value => Ok(#enum_name::#variant_name), })),
204 (_, _) => (debug_arm, vma, None),
205 }
206 })
207 .into_iter()
208 .unzip_n_vec();
209 // --------------------------------------------------
210 // get the vima for repeated values
211 // --------------------------------------------------
212 let mut repeated_indices = values_string
213 .clone()
214 .into_iter()
215 .repeated_idx();
216 repeated_indices.sort_by(|a, b| b.cmp(a));
217 repeated_indices
218 .iter()
219 .for_each(|i| { variant_inv_match_arms.remove(*i); } );
220 let variant_inv_match_arms_repeated = values_string
221 .clone()
222 .into_iter()
223 .positions()
224 .iter()
225 .map(|(_, pos)| match pos.len() {
226 ..=1 => quote! {},
227 _ => {
228 let val = values[pos[0]].clone();
229 quote! { #val => Err(::thisenum::Error::UnreachableValue(format!("{:?}", #val))), }
230 }
231 })
232 .collect::<Vec<_>>();
233 // --------------------------------------------------
234 // get all the indices of variants which have nested args
235 // --------------------------------------------------
236 let arg_indices = variant_inv_match_arms
237 .iter()
238 .enumerate()
239 .filter(|(i, v)| v.is_none() && !repeated_indices.contains(&i))
240 .map(|(i, _)| i)
241 .collect::<Vec<_>>();
242 let variant_inv_match_arms_args = values
243 .clone()
244 .into_iter()
245 .zip(variants)
246 .enumerate()
247 .filter(|(i, _)| arg_indices.contains(i))
248 .map(|(_, (value, variant))| {
249 let variant_name = &variant.ident;
250 quote! { #value => Err(::thisenum::Error::UnableToReturnVariant(stringify!(#variant_name).into())), }
251 })
252 .collect::<Vec<_>>();
253 // --------------------------------------------------
254 // see deref comment above
255 // --------------------------------------------------
256 let into_impl = match deref {
257 false => quote! {
258 #[automatically_derived]
259 #[doc = concat!(" [`Into`] implementation for [`", stringify!(#enum_name), "`]")]
260 impl ::std::convert::Into<#type_name_raw> for #enum_name {
261 #[inline]
262 fn into(self) -> #type_name_raw {
263 *self.value()
264 }
265 }
266 },
267 true => quote! { },
268 };
269 let mut expanded = quote! {
270 #[automatically_derived]
271 impl #enum_name {
272 #[inline]
273 /// Returns the value of the enum variant
274 /// defined by [`Const`]
275 ///
276 /// # Returns
277 ///
278 #[doc = concat!(" * [`&'static ", stringify!(#type_name), "`]")]
279 pub fn value(&self) -> &'static #type_name {
280 match self {
281 #( #variant_match_arms )*
282 }
283 }
284 }
285 #into_impl
286 };
287
288 if cfg!(feature = "debug") {
289 expanded = quote! {
290 #expanded
291 #[automatically_derived]
292 #[doc = concat!(" [`Debug`] implementation for [`", stringify!(#enum_name), "`]")]
293 impl ::std::fmt::Debug for #enum_name {
294 fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
295 match self {
296 #( #_debug_arms )*
297 }
298 }
299 }
300 };
301 }
302
303 if cfg!(feature = "eq") {
304 let variant_par_eq_lhs = match deref {
305 true => quote! { &self.value() == other },
306 false => quote! { self.value() == other },
307 };
308 let variant_par_eq_rhs = match deref {
309 true => quote! { &other.value() == self },
310 false => quote! { other.value() == self },
311 };
312 expanded = quote! {
313 #expanded
314 #[automatically_derived]
315 #[doc = concat!(" [`PartialEq<", stringify!(#type_name_raw) ,">`] implementation for [`", stringify!(#enum_name), "`]")]
316 ///
317 #[doc = concat!(" This is the LHS of the [`PartialEq`] implementation between [`", stringify!(#enum_name), "`] and [`", stringify!(#type_name_raw), "`]")]
318 ///
319 /// # Returns
320 ///
321 /// * [`true`] if the type and the enum are equal
322 /// * [`false`] if the type and the enum are not equal
323 impl ::std::cmp::PartialEq<#type_name_raw> for #enum_name {
324 #[inline]
325 fn eq(&self, other: &#type_name_raw) -> bool {
326 #variant_par_eq_lhs
327 }
328 }
329 #[automatically_derived]
330 #[doc = concat!(" [`PartialEq<", stringify!(#enum_name) ,">`] implementation for [`", stringify!(#type_name_raw), "`]")]
331 ///
332 #[doc = concat!(" This is the RHS of the [`PartialEq`] implementation between [`", stringify!(#enum_name), "`] and [`", stringify!(#type_name_raw), "`]")]
333 ///
334 /// # Returns
335 ///
336 /// * [`true`] if the enum and the type are equal
337 /// * [`false`] if the enum and the type are not equal
338 impl ::std::cmp::PartialEq<#enum_name> for #type_name_raw {
339 #[inline]
340 fn eq(&self, other: &#enum_name) -> bool {
341 #variant_par_eq_rhs
342 }
343 }
344 };
345 }
346
347 let variant_inv_match_arms = variant_inv_match_arms.into_iter().filter(|v| v.is_some()).map(|v| v.unwrap());
348 expanded = quote! {
349 #expanded
350 #[automatically_derived]
351 #[doc = concat!(" [`TryFrom`] implementation for [`", stringify!(#enum_name), "`]")]
352 ///
353 /// This is able to be derived since none of the Arms of the Enum had
354 /// any arguments. If that is the case, this implementation is
355 /// non-existent.
356 ///
357 /// # Returns
358 ///
359 /// * [`Ok(T)`] where `T` is the enum variant
360 /// * [`Err(Error)`] if the conversion fails
361 impl ::std::convert::TryFrom<#type_name_raw> for #enum_name {
362 type Error = ::thisenum::Error;
363 #[inline]
364 fn try_from(value: #type_name_raw) -> Result<Self, Self::Error> {
365 match value {
366 #( #variant_inv_match_arms )*
367 #( #variant_inv_match_arms_repeated )*
368 #( #variant_inv_match_arms_args )*
369 _ => Err(::thisenum::Error::InvalidValue(format!("{:?}", value), stringify!(#enum_name).into())),
370 }
371 }
372 }
373 };
374 // --------------------------------------------------
375 // return
376 // --------------------------------------------------
377 TokenStream::from(expanded)
378}
379
380#[proc_macro_derive(ConstEach, attributes(value, armtype))]
381/// Add's constants of any type to each arm of an enum
382///
383/// To get the value, the type must be explicitly passed
384/// as a generic to [`<enum_name>::value`]. This will automatically
385/// try to convert constant to the expected type using [`std::any::Any`]
386/// and [`downcast_ref`]. Currently [`TryFrom`] is not supported, so typing
387/// is fairly strict. Upon failure, it will return [`None`].
388///
389/// * To get the value as a reference, call the function [`<enum_name>::value`]
390/// * Unlike [`Const`], this macro does not enable direct comparison
391/// using [`PartialEq`] when imported using the `eq` feature.
392///
393/// The `#[armtype = ...]` attribute is **NOT*** required for this macro to function,
394/// but ***CAN** be applied to ***each individual arm*** of the enum, since values
395/// are not expected to share a type. If no type is given, then the type is
396/// inferred from the literal value in the `#[value = ...]` attribute.
397///
398/// All values set will return a [`Option<&'static T>`] reference. To the input type,
399/// of [`T`] AND [`&T`]. If multiple references are used (e.g. `&&T`), then
400/// the return type will be [`Option<&'static &T>`].
401///
402/// # Example
403///
404/// ```
405/// use thisenum::ConstEach;
406///
407/// #[derive(ConstEach, Debug)]
408/// enum MyEnum {
409/// #[armtype(u8)]
410/// #[value = 0xAA]
411/// A,
412/// #[value = "test3"]
413/// B,
414/// }
415///
416/// #[derive(ConstEach, Debug)]
417/// enum Tags {
418/// #[value = b"\x00\x01"]
419/// Key,
420/// #[armtype(u16)]
421/// #[value = 24250]
422/// Length,
423/// #[armtype(&[u8])]
424/// #[value = b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f"]
425/// Data,
426/// }
427///
428/// fn main() {
429/// // [`ConstEach`] examples
430/// assert!(MyEnum::A.value::<u8>().is_some());
431/// assert!(MyEnum::A.value::<Vec<f32>>().is_none());
432/// assert!(MyEnum::B.value::<u8>().is_none());
433/// assert!(MyEnum::B.value::<&str>().is_some());
434/// assert!(Tags::Data.value::<&[u8]>().is_some());
435///
436/// // An infered type. This will be as strict as possible,
437/// // therefore [`&[u8]`] will fail but [`&[u8; 2]`] will succeed
438/// assert!(Tags::Key.value::<&[u8; 2]>().is_some());
439/// assert!(Tags::Key.value::<&[u8; 5]>().is_none());
440/// assert!(Tags::Key.value::<&[u8]>().is_none());
441/// assert!(u16::from_le_bytes(**Tags::Key.value::<&[u8; 2]>().unwrap()) == 0x0100);
442///
443/// // casting as anything other than the defined / inferred type will
444/// // fail, since this uses [`downcast_ref`] from [`std::any::Any`]
445/// assert!(Tags::Length.value::<u16>().is_some());
446/// assert!(Tags::Length.value::<u32>().is_none());
447/// assert!(Tags::Length.value::<u64>().is_none());
448///
449/// // however, can always convert to a different type
450/// // after value is successfully acquired
451/// assert!(*Tags::Length.value::<u16>().unwrap() as u32 == 24250);
452/// }
453/// ```
454pub fn thisenum_const_each(input: TokenStream) -> TokenStream {
455 let name = "ConstEach";
456 let input = parse_macro_input!(input as DeriveInput);
457 // --------------------------------------------------
458 // extract the name, variants, and values
459 // --------------------------------------------------
460 let enum_name = &input.ident;
461 let variants = match input.data {
462 Data::Enum(DataEnum { variants, .. }) => variants,
463 _ => panic!("{}", Error::DeriveForNonEnum(name.into())),
464 };
465 // --------------------------------------------------
466 // generate the output tokens
467 // --------------------------------------------------
468 let variant_code = variants.iter().map(|variant| {
469 let variant_name = &variant.ident;
470 match (get_type(&variant.attrs), get_val(name.into(), &variant.attrs)) {
471 // ------------------------------------------------
472 // if type is specified, use it
473 // ------------------------------------------------
474 (Some(typ), Ok(value)) => quote! {
475 #enum_name::#variant_name => {
476 let val: &dyn ::std::any::Any = &(#value as #typ);
477 val.downcast_ref::<T>()
478 },
479
480 },
481 // ------------------------------------------------
482 // no type specified, try to infer
483 // ------------------------------------------------
484 (None, Ok(value)) => quote! {
485 #enum_name::#variant_name => {
486 let val: &dyn ::std::any::Any = &#value;
487 val.downcast_ref::<T>()
488 },
489 },
490 // ------------------------------------------------
491 // unable to infer type
492 // ------------------------------------------------
493 (_, Err(_)) => quote! { #enum_name::#variant_name => None, },
494 }
495 });
496 // ------------------------------------------------
497 // return
498 // ------------------------------------------------
499 let expanded = quote! {
500 #[automatically_derived]
501 #[doc = concat!(" [`ConstEach`] implementation for [`", stringify!(#enum_name), "`]")]
502 impl #enum_name {
503 pub fn value<T: 'static>(&self) -> Option<&'static T> {
504 match self {
505 #( #variant_code )*
506 _ => None,
507 }
508 }
509 }
510 };
511 TokenStream::from(expanded)
512}
513
514/// Helper function to extract the value from a [`MetaNameValue`], aka `#[value = <value>]`
515///
516/// # Input
517///
518/// ```text
519/// #[value = <value>]
520/// ```
521///
522/// # Output
523///
524/// [`TokenStream`] containing the value `<value>`, or [`Err`] if the attribute is not present / invalid
525fn get_val(name: String, attrs: &[Attribute]) -> Result<proc_macro2::TokenStream, Error> {
526 for attr in attrs {
527 if !attr.path.is_ident("value") { continue; }
528 match attr.parse_meta() {
529 Ok(meta) => match meta {
530 Meta::NameValue(MetaNameValue { lit, .. }) => return Ok(lit.into_token_stream()),
531 Meta::List(list) => {
532 let tokens = list.nested.iter().map(|nested_meta| {
533 match nested_meta {
534 syn::NestedMeta::Lit(lit) => lit.to_token_stream(),
535 syn::NestedMeta::Meta(meta) => meta.to_token_stream(),
536 }
537 });
538 return Ok(quote! { #( #tokens )* });
539 }
540 Meta::Path(_) => return Ok(meta.into_token_stream())
541 },
542 Err(_) => {
543 return Err(Error::NonLiteralValue);
544 /*
545 // Maybe for future:
546 // --------------------------------------------------
547 let elems = attr
548 .to_token_stream()
549 .to_string();
550 // println!("elems: {}", elems);
551 let mut elems = elems
552 .trim()
553 .trim_start_matches("#[")
554 .rsplit_once("]")
555 .unwrap()
556 .0
557 .split("=")
558 .collect::<Vec<_>>();
559 // println!("elems: {:?}", elems);
560 elems.remove(0);
561 // println!("elems: {:?}", elems);
562 return Ok(elems
563 .join("=")
564 .trim()
565 .parse::<proc_macro2::TokenStream>()?);
566 // --------------------------------------------------
567 */
568 },
569 }
570 }
571 Err(Error::MissingValue(name))
572}
573
574/// Helper function to extract the type from the [`Attribute`], aka `#[armtype(<type>)]`
575///
576/// Will indicate whether or not the type should be dereferenced or not. Useful
577/// for the [`Const`] macro
578///
579/// # Input
580///
581/// ```text
582/// #[armtype(<type>)]
583/// ```
584///
585/// # Output
586///
587/// [`None`] if the attribute is not present / invalid
588///
589/// Otherwise a tuple:
590///
591/// * 0 - [`Type`] containing the type `<type>` (already de-referenced)
592/// * 1 - An additional flag that indicates if the type has been de-referenced
593fn get_deref_type(attrs: &[Attribute]) -> Option<(Type, bool)> {
594 for attr in attrs {
595 if !attr.path.is_ident("armtype") { continue; }
596 let tokens = match attr.parse_args::<proc_macro2::TokenStream>() {
597 Ok(tokens) => tokens,
598 Err(_) => return None,
599 };
600 let deref = tokens
601 .to_string()
602 .trim()
603 .starts_with('&');
604 let tokens = match deref {
605 true => {
606 let mut tokens = tokens.into_iter();
607 let _ = tokens.next();
608 tokens.collect::<proc_macro2::TokenStream>()
609 }
610 false => tokens,
611 };
612 return match syn::parse2::<Type>(tokens).ok() {
613 Some(type_name) => Some((type_name, deref)),
614 None => None
615 }
616 }
617 None
618}
619
620/// Helper function to extract the type from the [`Attribute`], aka `#[armtype(<type>)]`
621///
622/// Will return the raw [`Type`]. Useful for the [`Const`] and the [`ConstEach`]
623/// macros
624///
625/// # Input
626///
627/// ```text
628/// #[armtype(<type>)]
629/// ```
630///
631/// # Output
632///
633/// [`None`] if the attribute is not present / invalid
634///
635/// Otherwise [`Some<Type>`] containing the type `<type>`
636fn get_type(attrs: &[Attribute]) -> Option<Type> {
637 for attr in attrs {
638 if !attr.path.is_ident("armtype") { continue; }
639 let tokens = match attr.parse_args::<proc_macro2::TokenStream>() {
640 Ok(tokens) => tokens,
641 Err(_) => return None,
642 };
643 return syn::parse2::<Type>(
644 tokens
645 .into_iter()
646 .collect::<proc_macro2::TokenStream>()
647 ).ok()
648 }
649 None
650}