1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
#![recursion_limit = "512"]

extern crate proc_macro;

use proc_macro2::{Ident, Span, TokenStream};
use quote::{quote, ToTokens};
use syn::{
    bracketed,
    parse::{Parse, ParseBuffer},
    parse2, parse_macro_input, parse_quote,
    punctuated::Punctuated,
    spanned::Spanned,
    Error, ItemEnum, Path, Token,
};

#[derive(Debug)]
struct Args {
    derives: Vec<Path>,
    traits: Vec<Path>,
}

impl Parse for Args {
    fn parse(input: &ParseBuffer) -> Result<Self, Error> {
        let kvps = Punctuated::<_, Token![,]>::parse_terminated_with(input, |input| {
            let key: Path = input.parse()?;
            let key = key.segments.first().unwrap().value().ident.to_string();
            input.parse::<Token![=]>()?;
            let value;
            bracketed!(value in input);
            let value = Punctuated::<Path, Token![,]>::parse_terminated(&value)?;
            let value = value.into_iter().collect();
            Ok((key, value))
        })?;
        let mut args = Args {
            derives: Vec::new(),
            traits: Vec::new(),
        };
        for (key, value) in kvps.into_iter() {
            match &key[..] {
                "derive" => args.derives = value,
                "trait_obj" => args.traits = value,
                _ => unimplemented!("error"),
            }
        }
        Ok(args)
    }
}

#[doc(hidden)]
#[proc_macro_attribute]
pub fn tyenum(attr: proc_macro::TokenStream, item: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let Args {
        derives,
        traits,
    } = parse_macro_input!(attr as Args);
    let item: TokenStream = item.into();
    let mut tyenum: ItemEnum = parse2(item).unwrap();
    let name = &tyenum.ident;
    let mut impls = Vec::new();
    let mut tyenum_trait_object_impl = Vec::new();
    let mut tyenum_derive_variants = Vec::new();
    for v in tyenum.variants.iter_mut() {
        let ty: &dyn ToTokens;
        let ident;
        {
            let mut iter = v.fields.iter();
            ty = if let Some(f) = iter.next() {
                if let Some(f) = iter.next() {
                    return Error::new(f.span(), "maximum one field in variants allowed").to_compile_error().into();
                }
                &f.ty
            } else {
                &v.ident
            };
            ident = &v.ident;
            impls.push(quote! {
                impl From<#ty> for #name {
                    fn from(variant: #ty) -> Self {
                        #name::#ident(variant)
                    }
                }

                impl std::convert::TryFrom<#name> for #ty {
                    type Error = tyenum::TryFromTyenumError;

                    fn try_from(e: #name) -> Result<Self, tyenum::TryFromTyenumError> {
                        if let #name::#ident(variant) = e {
                            Ok(variant)
                        } else {
                            Err(tyenum::TryFromTyenumError)
                        }
                    }
                }

                impl tyenum::IsTypeOf<#name> for #ty {
                    fn is_type_of(e: &#name) -> bool {
                        if let #name::#ident(_) = e {
                            true
                        } else {
                            false
                        }
                    }
                }
            });
            tyenum_trait_object_impl.push(quote! {#name::#ident(ref mut v) => v.trait_obj()});
            tyenum_derive_variants.push(quote! {#name::#ident(v)});
        }
        *v = parse_quote!(#ident(#ty));
    }
    let tyenum_trait_object_impl_ref = &tyenum_trait_object_impl;
    let trait_name = Ident::new(&format!("{}ToTraitObject", name), Span::call_site());
    for trt in traits {
        impls.push(quote! {
            impl<'a> #name {
                fn trait_obj(&'a mut self) -> Option<&'a mut dyn #trt> {
                    match self {
                        #(#tyenum_trait_object_impl_ref),*
                    }
                }
            }

            trait #trait_name<'a, T> {
                fn trait_obj(&'a mut self) -> Option<T>;
            }

            impl<'a, I> #trait_name<'a, &'a mut dyn #trt> for I {
                default fn trait_obj(&'a mut self) -> Option<&'a mut dyn #trt> {
                    None
                }
            }

            impl<'a, I: #trt> #trait_name<'a, &'a mut dyn #trt> for I {
                fn trait_obj(&'a mut self) -> Option<&'a mut dyn #trt> {
                    Some(self)
                }
            }
        })
    }
    let tyenum_derive_variants_ref = &tyenum_derive_variants;
    for drv in derives {
        impls.push(quote! {
            impl std::ops::Deref for #name {
                type Target = dyn #drv;
                fn deref(&self) -> &Self::Target {
                    match self {
                        #(#tyenum_derive_variants_ref => v as & Self::Target),*
                    }
                }
            }

            impl std::ops::DerefMut for #name {
                fn deref_mut(&mut self) -> &mut Self::Target {
                    match self {
                        #(#tyenum_derive_variants_ref => v as &mut Self::Target),*
                    }
                }
            }
        });
    }
    quote!(
        impl #name {
            fn is<T: tyenum::IsTypeOf<#name>>(&self) -> bool {
                T::is_type_of(self)
            }
        }

        #tyenum

        #(#impls)*
    )
    .into()
}