unit_enum/
lib.rs

1#![doc = include_str!("lib.md")]
2
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{parse_macro_input, Data, DeriveInput, Error, Expr, Fields
6          , Type, Variant};
7
8/// Derives the `UnitEnum` trait for an enum.
9///
10/// This macro can be used on enums with unit variants (no fields) and optionally one "other" variant
11/// that can hold arbitrary discriminant values.
12///
13/// # Attributes
14/// - `#[repr(type)]`: Optional for regular enums, defaults to i32. Required when using an "other" variant.
15/// - `#[unit_enum(other)]`: Marks a variant as the catch-all for undefined discriminant values.
16///   The type of this variant must match the repr type.
17///
18/// # Requirements
19/// - The enum must contain only unit variants, except for one optional "other" variant
20/// - The "other" variant, if present, must:
21///   - Be marked with `#[unit_enum(other)]`
22///   - Have exactly one unnamed field matching the repr type
23///   - Be the only variant with the "other" attribute
24///   - Have a matching `#[repr(type)]` attribute
25///
26/// # Examples
27///
28/// Basic usage with unit variants (repr is optional):
29/// ```rust
30/// # use unit_enum::UnitEnum;
31/// #[derive(UnitEnum)]
32/// enum Example {
33///     A,
34///     B = 10,
35///     C,
36/// }
37/// ```
38///
39/// Usage with explicit repr:
40/// ```rust
41/// # use unit_enum::UnitEnum;
42/// #[derive(UnitEnum)]
43/// #[repr(u16)]
44/// enum Color {
45///     Red = 10,
46///     Green,
47///     Blue = 45654,
48/// }
49/// ```
50///
51/// Usage with an "other" variant (repr required):
52/// ```rust
53/// # use unit_enum::UnitEnum;
54/// #[derive(UnitEnum)]
55/// #[repr(u16)]
56/// enum Status {
57///     Active = 1,
58///     Inactive = 2,
59///     #[unit_enum(other)]
60///     Unknown(u16),  // type must match repr
61/// }
62/// ```
63#[proc_macro_derive(UnitEnum, attributes(unit_enum))]
64pub fn unit_enum_derive(input: TokenStream) -> TokenStream {
65    let ast = parse_macro_input!(input as DeriveInput);
66
67    match validate_and_process(&ast) {
68        Ok((discriminant_type, unit_variants, other_variant)) => {
69            impl_unit_enum(&ast, &discriminant_type, &unit_variants, other_variant)
70        }
71        Err(e) => e.to_compile_error().into(),
72    }
73}
74
75struct ValidationResult<'a> {
76    unit_variants: Vec<&'a Variant>,
77    other_variant: Option<(&'a Variant, Type)>,
78}
79
80fn validate_and_process(ast: &DeriveInput) -> Result<(Type, Vec<&Variant>, Option<(&Variant, Type)>), Error> {
81    // Get discriminant type from #[repr] attribute
82    let discriminant_type = get_discriminant_type(ast)?;
83
84    let data_enum = match &ast.data {
85        Data::Enum(data_enum) => data_enum,
86        _ => return Err(Error::new_spanned(ast, "UnitEnum can only be derived for enums")),
87    };
88
89    let mut validation = ValidationResult {
90        unit_variants: Vec::new(),
91        other_variant: None,
92    };
93
94    // Validate each variant
95    for variant in &data_enum.variants {
96        match &variant.fields {
97            Fields::Unit => {
98                if has_unit_enum_attr(variant) {
99                    return Err(Error::new_spanned(variant,
100                                                  "Unit variants cannot have #[unit_enum] attributes"));
101                }
102                validation.unit_variants.push(variant);
103            }
104            Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
105                if has_unit_enum_other_attr(variant) {
106                    if validation.other_variant.is_some() {
107                        return Err(Error::new_spanned(variant,
108                                                      "Multiple #[unit_enum(other)] variants found. Only one is allowed"));
109                    }
110                    validation.other_variant = Some((variant, fields.unnamed[0].ty.clone()));
111                } else {
112                    return Err(Error::new_spanned(variant,
113                                                  "Non-unit variant must be marked with #[unit_enum(other)] to be used as the catch-all variant"));
114                }
115            }
116            _ => return Err(Error::new_spanned(variant,
117                                               "Invalid variant. UnitEnum only supports unit variants and a single tuple variant marked with #[unit_enum(other)]")),
118        }
119    }
120
121    Ok((discriminant_type, validation.unit_variants, validation.other_variant))
122}
123
124fn get_discriminant_type(ast: &DeriveInput) -> Result<Type, Error> {
125    ast.attrs.iter()
126        .find(|attr| attr.path().is_ident("repr"))
127        .map_or(Ok(syn::parse_quote!(i32)), |attr| {
128            attr.parse_args::<Type>()
129                .map_err(|_| Error::new_spanned(attr, "Invalid repr attribute"))
130        })
131}
132
133fn has_unit_enum_attr(variant: &Variant) -> bool {
134    variant.attrs.iter().any(|attr| attr.path().is_ident("unit_enum"))
135}
136
137fn has_unit_enum_other_attr(variant: &Variant) -> bool {
138    variant.attrs.iter().any(|attr| {
139        attr.path().is_ident("unit_enum") &&
140            attr.parse_nested_meta(|meta| {
141                if meta.path.is_ident("other") {
142                    Ok(())
143                } else {
144                    Err(meta.error("Invalid unit_enum attribute"))
145                }
146            }).is_ok()
147    })
148}
149
150fn compute_discriminants(variants: &[&Variant]) -> Vec<Expr> {
151    let mut discriminants = Vec::with_capacity(variants.len());
152    let mut last_discriminant: Option<Expr> = None;
153
154    for variant in variants {
155        let discriminant = variant.discriminant.as_ref().map(|(_, expr)| expr.clone())
156            .or_else(|| {
157                last_discriminant.clone().map(|expr| syn::parse_quote! { #expr + 1 })
158            })
159            .unwrap_or_else(|| syn::parse_quote! { 0 });
160
161        discriminants.push(discriminant.clone());
162        last_discriminant = Some(discriminant);
163    }
164
165    discriminants
166}
167
168fn impl_unit_enum(
169    ast: &DeriveInput,
170    discriminant_type: &Type,
171    unit_variants: &[&Variant],
172    other_variant: Option<(&Variant, Type)>,
173) -> TokenStream {
174    let name = &ast.ident;
175    let num_variants = unit_variants.len();
176    let discriminants = compute_discriminants(unit_variants);
177
178    let name_impl = generate_name_impl(name, unit_variants, &other_variant);
179    let ordinal_impl = generate_ordinal_impl(name, unit_variants, &other_variant, num_variants);
180    let from_ordinal_impl = generate_from_ordinal_impl(name, unit_variants);
181    let discriminant_impl = generate_discriminant_impl(name, unit_variants, &other_variant, discriminant_type, &discriminants);
182    let from_discriminant_impl = generate_from_discriminant_impl(name, unit_variants, &other_variant, discriminant_type, &discriminants);
183    let values_impl = generate_values_impl(name, unit_variants, &discriminants, &other_variant);
184
185    quote! {
186        impl #name {
187            #name_impl
188
189            #ordinal_impl
190
191            #from_ordinal_impl
192
193            #discriminant_impl
194
195            #from_discriminant_impl
196
197            /// Returns the total number of unit variants in the enum (excluding the "other" variant if present).
198            ///
199            /// # Examples
200            ///
201            /// ```ignore
202            /// # use unit_enum::UnitEnum;
203            /// #[derive(UnitEnum)]
204            /// enum Example {
205            ///     A,
206            ///     B,
207            ///     #[unit_enum(other)]
208            ///     Other(i32),
209            /// }
210            ///
211            /// assert_eq!(Example::len(), 2);
212            /// ```
213            pub fn len() -> usize {
214                #num_variants
215            }
216
217            #values_impl
218        }
219    }.into()
220}
221
222fn generate_name_impl(
223    name: &syn::Ident,
224    unit_variants: &[&Variant],
225    other_variant: &Option<(&Variant, Type)>,
226) -> proc_macro2::TokenStream {
227    let unit_match_arms = unit_variants.iter().map(|variant| {
228        let variant_name = &variant.ident;
229        quote! { #name::#variant_name => stringify!(#variant_name) }
230    });
231
232    let other_arm = other_variant.as_ref().map(|(variant, _)| {
233        let variant_name = &variant.ident;
234        quote! { #name::#variant_name(_) => stringify!(#variant_name) }
235    });
236
237    quote! {
238        /// Returns the name of the enum variant as a string.
239        ///
240        /// # Examples
241        ///
242        /// ```ignore
243        /// # use unit_enum::UnitEnum;
244        /// #[derive(UnitEnum)]
245        /// enum Example {
246        ///     A,
247        ///     B = 10,
248        ///     C,
249        /// }
250        ///
251        /// assert_eq!(Example::A.name(), "A");
252        /// assert_eq!(Example::B.name(), "B");
253        /// assert_eq!(Example::C.name(), "C");
254        /// ```
255        pub fn name(&self) -> &str {
256            match self {
257                #(#unit_match_arms,)*
258                #other_arm
259            }
260        }
261    }
262}
263
264fn generate_ordinal_impl(
265    name: &syn::Ident,
266    unit_variants: &[&Variant],
267    other_variant: &Option<(&Variant, Type)>,
268    num_variants: usize,
269) -> proc_macro2::TokenStream {
270    let unit_match_arms = unit_variants.iter().enumerate().map(|(index, variant)| {
271        let variant_name = &variant.ident;
272        quote! { #name::#variant_name => #index }
273    });
274
275    let other_arm = other_variant.as_ref().map(|(variant, _)| {
276        let variant_name = &variant.ident;
277        quote! { #name::#variant_name(_) => #num_variants }
278    });
279
280    quote! {
281        /// Returns the zero-based ordinal of the enum variant.
282        ///
283        /// For enums with an "other" variant, it returns the position after all unit variants.
284        ///
285        /// # Examples
286        ///
287        /// ```ignore
288        /// # use unit_enum::UnitEnum;
289        /// #[derive(UnitEnum)]
290        /// enum Example {
291        ///     A,      // ordinal: 0
292        ///     B = 10, // ordinal: 1
293        ///     C,      // ordinal: 2
294        /// }
295        ///
296        /// assert_eq!(Example::A.ordinal(), 0);
297        /// assert_eq!(Example::B.ordinal(), 1);
298        /// assert_eq!(Example::C.ordinal(), 2);
299        /// ```
300        pub fn ordinal(&self) -> usize {
301            match self {
302                #(#unit_match_arms,)*
303                #other_arm
304            }
305        }
306    }
307}
308fn generate_from_ordinal_impl(
309    name: &syn::Ident,
310    unit_variants: &[&Variant],
311) -> proc_macro2::TokenStream {
312    let match_arms = unit_variants.iter().enumerate().map(|(index, variant)| {
313        let variant_name = &variant.ident;
314        quote! { #index => Some(#name::#variant_name) }
315    });
316
317    quote! {
318        /// Converts a zero-based ordinal to an enum variant, if possible.
319        ///
320        /// Returns `Some(variant)` if the ordinal corresponds to a unit variant,
321        /// or `None` if the ordinal is out of range or would correspond to the "other" variant.
322        ///
323        /// # Examples
324        ///
325        /// ```ignore
326        /// # use unit_enum::UnitEnum;
327        /// # #[derive(Debug, PartialEq)]
328        /// #[derive(UnitEnum)]
329        /// enum Example {
330        ///     A,
331        ///     B,
332        ///     #[unit_enum(other)]
333        ///     Other(i32),
334        /// }
335        ///
336        /// assert_eq!(Example::from_ordinal(0), Some(Example::A));
337        /// assert_eq!(Example::from_ordinal(2), None); // Other variant
338        /// assert_eq!(Example::from_ordinal(99), None); // Out of range
339        /// ```
340        pub fn from_ordinal(ord: usize) -> Option<Self> {
341            match ord {
342                #(#match_arms,)*
343                _ => None
344            }
345        }
346    }
347}
348
349fn generate_discriminant_impl(
350    name: &syn::Ident,
351    unit_variants: &[&Variant],
352    other_variant: &Option<(&Variant, Type)>,
353    discriminant_type: &Type,
354    discriminants: &[Expr],
355) -> proc_macro2::TokenStream {
356    let unit_match_arms = unit_variants.iter().zip(discriminants).map(|(variant, discriminant)| {
357        let variant_name = &variant.ident;
358        quote! { #name::#variant_name => #discriminant as #discriminant_type }
359    });
360
361    let other_arm = other_variant.as_ref().map(|(variant, _)| {
362        let variant_name = &variant.ident;
363        quote! { #name::#variant_name(val) => *val }
364    });
365
366    quote! {
367        /// Returns the discriminant value of the enum variant.
368        ///
369        /// For "other" variants, returns the contained value.
370        ///
371        /// # Examples
372        ///
373        /// ```ignore
374        /// # use unit_enum::UnitEnum;
375        /// #[derive(UnitEnum)]
376        /// enum Example {
377        ///     A,      // 0
378        ///     B = 10, // 10
379        ///     C,      // 11
380        /// }
381        ///
382        /// assert_eq!(Example::A.discriminant(), 0);
383        /// assert_eq!(Example::B.discriminant(), 10);
384        /// assert_eq!(Example::C.discriminant(), 11);
385        /// ```
386         pub fn discriminant(&self) -> #discriminant_type {
387            match self {
388                #(#unit_match_arms,)*
389                #other_arm
390            }
391        }
392    }
393}
394
395fn generate_from_discriminant_impl(
396    name: &syn::Ident,
397    unit_variants: &[&Variant],
398    other_variant: &Option<(&Variant, Type)>,
399    discriminant_type: &Type,
400    discriminants: &[Expr],
401) -> proc_macro2::TokenStream {
402    if let Some((other_variant, _)) = other_variant {
403        let match_arms = unit_variants.iter().zip(discriminants).map(|(variant, discriminant)| {
404            let variant_name = &variant.ident;
405            quote! { x if x == (#discriminant as #discriminant_type) => #name::#variant_name }
406        });
407
408        let other_name = &other_variant.ident;
409        quote! {
410            /// Converts a discriminant value to an enum variant.
411            ///
412            /// For enums with an "other" variant, this will always return a value,
413            /// using the "other" variant for undefined discriminants.
414            ///
415            /// # Examples
416            ///
417            /// ```ignore
418            /// # use unit_enum::UnitEnum;
419            /// #[derive(UnitEnum, PartialEq, Debug)]
420            /// #[repr(u8)]
421            /// enum Example {
422            ///     A,      // 0
423            ///     B = 10, // 10
424            ///     #[unit_enum(other)]
425            ///     Other(u8),
426            /// }
427            ///
428            /// assert_eq!(Example::from_discriminant(0), Example::A);
429            /// assert_eq!(Example::from_discriminant(10), Example::B);
430            /// assert_eq!(Example::from_discriminant(42), Example::Other(42));
431            /// ```
432            pub fn from_discriminant(discr: #discriminant_type) -> Self {
433                match discr {
434                    #(#match_arms,)*
435                    other => #name::#other_name(other)
436                }
437            }
438        }
439    } else {
440        let match_arms = unit_variants.iter().zip(discriminants).map(|(variant, discriminant)| {
441            let variant_name = &variant.ident;
442            quote! { x if x == (#discriminant as #discriminant_type) => Some(#name::#variant_name) }
443        });
444
445        quote! {
446            /// Converts a discriminant value to an enum variant, if possible.
447            ///
448            /// Returns `Some(variant)` if the discriminant corresponds to a defined variant,
449            /// or `None` if the discriminant is undefined.
450            ///
451            /// # Examples
452            ///
453            /// ```ignore
454            /// # use unit_enum::UnitEnum;
455            /// #[derive(UnitEnum, PartialEq, Debug)]
456            /// #[repr(u8)]
457            /// enum Example {
458            ///     A,      // 0
459            ///     B = 10, // 10
460            ///     C,      // 11
461            /// }
462            ///
463            /// assert_eq!(Example::from_discriminant(0), Some(Example::A));
464            /// assert_eq!(Example::from_discriminant(10), Some(Example::B));
465            /// assert_eq!(Example::from_discriminant(42), None);
466            /// ```
467            pub fn from_discriminant(discr: #discriminant_type) -> Option<Self> {
468                match discr {
469                    #(#match_arms,)*
470                    _ => None
471                }
472            }
473        }
474    }
475}
476
477fn generate_values_impl(
478    name: &syn::Ident,
479    unit_variants: &[&Variant],
480    discriminants: &[Expr],
481    _other_variant: &Option<(&Variant, Type)>,
482) -> proc_macro2::TokenStream {
483    // Create a vector of variant expressions paired with their discriminants
484    let variant_exprs = unit_variants.iter().zip(discriminants).map(|(variant, discriminant)| {
485        let variant_name = &variant.ident;
486        quote! {
487            #name::#variant_name // The variant
488        }
489    });
490
491    // Collect variants into a Vec to ensure consistent ordering
492    quote! {
493        /// Returns an iterator over all unit variants of the enum.
494        ///
495        /// Note: This does not include values from the "other" variant, if present.
496        ///
497        /// # Examples
498        ///
499        /// ```ignore
500        /// # use unit_enum::UnitEnum;
501        /// #[derive(UnitEnum, PartialEq, Debug)]
502        /// enum Example {
503        ///     A,
504        ///     B,
505        ///     #[unit_enum(other)]
506        ///     Other(i32),
507        /// }
508        ///
509        /// let values: Vec<_> = Example::values().collect();
510        /// assert_eq!(values, vec![Example::A, Example::B]);
511        /// ```
512        pub fn values() -> impl Iterator<Item = Self> {
513            vec![
514                #(#variant_exprs),*
515            ].into_iter()
516        }
517    }
518}