1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
use proc_macro::TokenStream;

use quote::quote;
use syn::{Data, DeriveInput, Expr, Fields, parse_macro_input, punctuated::Punctuated, token::Comma, Variant};

#[proc_macro_derive(UnitEnum)]
pub fn unit_enum_derive(input: TokenStream) -> TokenStream {
    let ast = parse_macro_input!(input as DeriveInput);
    impl_unit_enum(&ast)
}

fn impl_unit_enum(ast: &DeriveInput) -> TokenStream {
    let name = &ast.ident;
    let data = match &ast.data {
        Data::Enum(data_enum) => data_enum,
        _ => panic!("UnitEnum is only defined for enums!"),
    };

    let variants = &data.variants;
    let num_variants = variants.len(); // Count the number of variants

    let name_match_arms = variants.iter().enumerate().map(|(index, variant)| {
        let variant_name = &variant.ident;
        match &variant.fields {
            Fields::Unit => quote! { #name::#variant_name => stringify!(#variant_name) },
            _ => panic!("UnitEnum only supports unit variants (no fields)"),
        }
    });

    let ordinal_match_arms = variants.iter().enumerate().map(|(index, variant)| {
        let variant_name = &variant.ident;
        match &variant.fields {
            Fields::Unit => quote! { #name::#variant_name => #index },
            _ => panic!("UnitEnum only supports unit variants (no fields)"),
        }
    });

    let from_ordinal_match_arms = variants.iter().enumerate().map(|(index, variant)| {
        let variant_name = &variant.ident;
        quote! { #index => Some(#name::#variant_name) }
    });

    let values_arms = (0..num_variants).map(|index| {
        quote! { #name::from_ordinal(#index).unwrap() }
    });

    let (discriminant_match_arms, from_discriminant_match_arms) = generate_discriminant_arms(variants, name);

    let gen = quote! {
        impl #name {
            /// Returns the name of the enum variant.
            pub fn name(&self) -> &str {
                match self {
                    #(#name_match_arms,)*
                }
            }

            /// Returns the zero-based ordinal of the enum variant.
            pub fn ordinal(&self) -> usize {
                match self {
                    #(#ordinal_match_arms,)*
                }
            }

            /// Converts a zero-based ordinal to an enum variant, if possible.
            ///
            /// Returns `None` if the ordinal is out of range.
            pub fn from_ordinal(ord: usize) -> Option<Self> {
                match ord {
                    #(#from_ordinal_match_arms,)*
                    _ => None,
                }
            }

            /// Returns the discriminant value of the enum variant.
            pub fn discriminant(&self) -> i32 {
                match self {
                    #(#discriminant_match_arms,)*
                }
            }

            /// Converts a discriminant value back to an enum variant, if possible.
            ///
            /// Returns `None` if the discriminant does not correspond to any variant.
            pub fn from_discriminant(discr: i32) -> Option<Self> {
                match discr {
                    #(#from_discriminant_match_arms,)*
                    _ => None,
                }
            }

            /// Returns the total number of variants in the enum.
            pub fn len() -> usize {
                #num_variants
            }

            /// Returns an iterator over all variants of the enum.
            ///
            /// This method creates an iterator that yields each variant in
            /// definition order, starting from the first variant.
            pub fn values() -> impl Iterator<Item = Self> {
                vec![#(#values_arms,)*].into_iter()

            }
        }
    };
    gen.into()
}

/// Builds the match arms for the generated `discriminant()` and
/// `from_discriminant()` methods.
///
/// Returns `(discriminant_arms, from_discriminant_arms)`, one arm per variant
/// in declaration order:
/// * `Enum::Variant => (Enum::Variant as i32)`
/// * `x if x == (Enum::Variant as i32) => Some(Enum::Variant)`
///
/// The value comes from casting the variant itself, so rustc applies the
/// language's own discriminant rules (explicit `= expr`, implicit
/// "previous + 1", `#[repr]`). Re-deriving them here by concatenating tokens
/// produced wrong values: implicit variants in `from_discriminant` all saw the
/// final running value, and `#expr + 1` / `#expr as i32` mis-associated with
/// operators inside `expr` (e.g. `1 << 2 + 1`).
///
/// The `as` cast is only valid for fieldless variants. The caller must pass only
/// unit variants; `impl_unit_enum` rejects any other kind.
fn generate_discriminant_arms(
    variants: &Punctuated<Variant, Comma>,
    name: &syn::Ident,
) -> (Vec<proc_macro2::TokenStream>, Vec<proc_macro2::TokenStream>) {
    variants
        .iter()
        .map(|variant| {
            let variant_name = &variant.ident;
            // Parenthesized so the cast cannot bind into surrounding tokens.
            let value = quote! { (#name::#variant_name as i32) };
            (
                quote! { #name::#variant_name => #value },
                quote! { x if x == #value => Some(#name::#variant_name) },
            )
        })
        .unzip()
}