unit_enum/lib.rs
1#![doc = include_str!("lib.md")]
2
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{parse_macro_input, Data, DeriveInput, Error, Expr, Fields
6 , Type, Variant};
7
8/// Derives the `UnitEnum` trait for an enum.
9///
10/// This macro can be used on enums with unit variants (no fields) and optionally one "other" variant
11/// that can hold arbitrary discriminant values.
12///
13/// # Attributes
14/// - `#[repr(type)]`: Optional for regular enums, defaults to i32. Required when using an "other" variant.
15/// - `#[unit_enum(other)]`: Marks a variant as the catch-all for undefined discriminant values.
16/// The type of this variant must match the repr type.
17///
18/// # Requirements
19/// - The enum must contain only unit variants, except for one optional "other" variant
20/// - The "other" variant, if present, must:
21/// - Be marked with `#[unit_enum(other)]`
22/// - Have exactly one unnamed field matching the repr type
23/// - Be the only variant with the "other" attribute
24/// - Have a matching `#[repr(type)]` attribute
25///
26/// # Examples
27///
28/// Basic usage with unit variants (repr is optional):
29/// ```rust
30/// # use unit_enum::UnitEnum;
31/// #[derive(UnitEnum)]
32/// enum Example {
33/// A,
34/// B = 10,
35/// C,
36/// }
37/// ```
38///
39/// Usage with explicit repr:
40/// ```rust
41/// # use unit_enum::UnitEnum;
42/// #[derive(UnitEnum)]
43/// #[repr(u16)]
44/// enum Color {
45/// Red = 10,
46/// Green,
47/// Blue = 45654,
48/// }
49/// ```
50///
51/// Usage with an "other" variant (repr required):
52/// ```rust
53/// # use unit_enum::UnitEnum;
54/// #[derive(UnitEnum)]
55/// #[repr(u16)]
56/// enum Status {
57/// Active = 1,
58/// Inactive = 2,
59/// #[unit_enum(other)]
60/// Unknown(u16), // type must match repr
61/// }
62/// ```
63#[proc_macro_derive(UnitEnum, attributes(unit_enum))]
64pub fn unit_enum_derive(input: TokenStream) -> TokenStream {
65 let ast = parse_macro_input!(input as DeriveInput);
66
67 match validate_and_process(&ast) {
68 Ok((discriminant_type, unit_variants, other_variant)) => {
69 impl_unit_enum(&ast, &discriminant_type, &unit_variants, other_variant)
70 }
71 Err(e) => e.to_compile_error().into(),
72 }
73}
74
75struct ValidationResult<'a> {
76 unit_variants: Vec<&'a Variant>,
77 other_variant: Option<(&'a Variant, Type)>,
78}
79
80fn validate_and_process(ast: &DeriveInput) -> Result<(Type, Vec<&Variant>, Option<(&Variant, Type)>), Error> {
81 // Get discriminant type from #[repr] attribute
82 let discriminant_type = get_discriminant_type(ast)?;
83
84 let data_enum = match &ast.data {
85 Data::Enum(data_enum) => data_enum,
86 _ => return Err(Error::new_spanned(ast, "UnitEnum can only be derived for enums")),
87 };
88
89 let mut validation = ValidationResult {
90 unit_variants: Vec::new(),
91 other_variant: None,
92 };
93
94 // Validate each variant
95 for variant in &data_enum.variants {
96 match &variant.fields {
97 Fields::Unit => {
98 if has_unit_enum_attr(variant) {
99 return Err(Error::new_spanned(variant,
100 "Unit variants cannot have #[unit_enum] attributes"));
101 }
102 validation.unit_variants.push(variant);
103 }
104 Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
105 if has_unit_enum_other_attr(variant) {
106 if validation.other_variant.is_some() {
107 return Err(Error::new_spanned(variant,
108 "Multiple #[unit_enum(other)] variants found. Only one is allowed"));
109 }
110 validation.other_variant = Some((variant, fields.unnamed[0].ty.clone()));
111 } else {
112 return Err(Error::new_spanned(variant,
113 "Non-unit variant must be marked with #[unit_enum(other)] to be used as the catch-all variant"));
114 }
115 }
116 _ => return Err(Error::new_spanned(variant,
117 "Invalid variant. UnitEnum only supports unit variants and a single tuple variant marked with #[unit_enum(other)]")),
118 }
119 }
120
121 Ok((discriminant_type, validation.unit_variants, validation.other_variant))
122}
123
124fn get_discriminant_type(ast: &DeriveInput) -> Result<Type, Error> {
125 ast.attrs.iter()
126 .find(|attr| attr.path().is_ident("repr"))
127 .map_or(Ok(syn::parse_quote!(i32)), |attr| {
128 attr.parse_args::<Type>()
129 .map_err(|_| Error::new_spanned(attr, "Invalid repr attribute"))
130 })
131}
132
133fn has_unit_enum_attr(variant: &Variant) -> bool {
134 variant.attrs.iter().any(|attr| attr.path().is_ident("unit_enum"))
135}
136
137fn has_unit_enum_other_attr(variant: &Variant) -> bool {
138 variant.attrs.iter().any(|attr| {
139 attr.path().is_ident("unit_enum") &&
140 attr.parse_nested_meta(|meta| {
141 if meta.path.is_ident("other") {
142 Ok(())
143 } else {
144 Err(meta.error("Invalid unit_enum attribute"))
145 }
146 }).is_ok()
147 })
148}
149
150fn compute_discriminants(variants: &[&Variant]) -> Vec<Expr> {
151 let mut discriminants = Vec::with_capacity(variants.len());
152 let mut last_discriminant: Option<Expr> = None;
153
154 for variant in variants {
155 let discriminant = variant.discriminant.as_ref().map(|(_, expr)| expr.clone())
156 .or_else(|| {
157 last_discriminant.clone().map(|expr| syn::parse_quote! { #expr + 1 })
158 })
159 .unwrap_or_else(|| syn::parse_quote! { 0 });
160
161 discriminants.push(discriminant.clone());
162 last_discriminant = Some(discriminant);
163 }
164
165 discriminants
166}
167
168fn impl_unit_enum(
169 ast: &DeriveInput,
170 discriminant_type: &Type,
171 unit_variants: &[&Variant],
172 other_variant: Option<(&Variant, Type)>,
173) -> TokenStream {
174 let name = &ast.ident;
175 let num_variants = unit_variants.len();
176 let discriminants = compute_discriminants(unit_variants);
177
178 let name_impl = generate_name_impl(name, unit_variants, &other_variant);
179 let ordinal_impl = generate_ordinal_impl(name, unit_variants, &other_variant, num_variants);
180 let from_ordinal_impl = generate_from_ordinal_impl(name, unit_variants);
181 let discriminant_impl = generate_discriminant_impl(name, unit_variants, &other_variant, discriminant_type, &discriminants);
182 let from_discriminant_impl = generate_from_discriminant_impl(name, unit_variants, &other_variant, discriminant_type, &discriminants);
183 let values_impl = generate_values_impl(name, unit_variants, &discriminants, &other_variant);
184
185 quote! {
186 impl #name {
187 #name_impl
188
189 #ordinal_impl
190
191 #from_ordinal_impl
192
193 #discriminant_impl
194
195 #from_discriminant_impl
196
197 /// Returns the total number of unit variants in the enum (excluding the "other" variant if present).
198 ///
199 /// # Examples
200 ///
201 /// ```ignore
202 /// # use unit_enum::UnitEnum;
203 /// #[derive(UnitEnum)]
204 /// enum Example {
205 /// A,
206 /// B,
207 /// #[unit_enum(other)]
208 /// Other(i32),
209 /// }
210 ///
211 /// assert_eq!(Example::len(), 2);
212 /// ```
213 pub fn len() -> usize {
214 #num_variants
215 }
216
217 #values_impl
218 }
219 }.into()
220}
221
222fn generate_name_impl(
223 name: &syn::Ident,
224 unit_variants: &[&Variant],
225 other_variant: &Option<(&Variant, Type)>,
226) -> proc_macro2::TokenStream {
227 let unit_match_arms = unit_variants.iter().map(|variant| {
228 let variant_name = &variant.ident;
229 quote! { #name::#variant_name => stringify!(#variant_name) }
230 });
231
232 let other_arm = other_variant.as_ref().map(|(variant, _)| {
233 let variant_name = &variant.ident;
234 quote! { #name::#variant_name(_) => stringify!(#variant_name) }
235 });
236
237 quote! {
238 /// Returns the name of the enum variant as a string.
239 ///
240 /// # Examples
241 ///
242 /// ```ignore
243 /// # use unit_enum::UnitEnum;
244 /// #[derive(UnitEnum)]
245 /// enum Example {
246 /// A,
247 /// B = 10,
248 /// C,
249 /// }
250 ///
251 /// assert_eq!(Example::A.name(), "A");
252 /// assert_eq!(Example::B.name(), "B");
253 /// assert_eq!(Example::C.name(), "C");
254 /// ```
255 pub fn name(&self) -> &str {
256 match self {
257 #(#unit_match_arms,)*
258 #other_arm
259 }
260 }
261 }
262}
263
264fn generate_ordinal_impl(
265 name: &syn::Ident,
266 unit_variants: &[&Variant],
267 other_variant: &Option<(&Variant, Type)>,
268 num_variants: usize,
269) -> proc_macro2::TokenStream {
270 let unit_match_arms = unit_variants.iter().enumerate().map(|(index, variant)| {
271 let variant_name = &variant.ident;
272 quote! { #name::#variant_name => #index }
273 });
274
275 let other_arm = other_variant.as_ref().map(|(variant, _)| {
276 let variant_name = &variant.ident;
277 quote! { #name::#variant_name(_) => #num_variants }
278 });
279
280 quote! {
281 /// Returns the zero-based ordinal of the enum variant.
282 ///
283 /// For enums with an "other" variant, it returns the position after all unit variants.
284 ///
285 /// # Examples
286 ///
287 /// ```ignore
288 /// # use unit_enum::UnitEnum;
289 /// #[derive(UnitEnum)]
290 /// enum Example {
291 /// A, // ordinal: 0
292 /// B = 10, // ordinal: 1
293 /// C, // ordinal: 2
294 /// }
295 ///
296 /// assert_eq!(Example::A.ordinal(), 0);
297 /// assert_eq!(Example::B.ordinal(), 1);
298 /// assert_eq!(Example::C.ordinal(), 2);
299 /// ```
300 pub fn ordinal(&self) -> usize {
301 match self {
302 #(#unit_match_arms,)*
303 #other_arm
304 }
305 }
306 }
307}
308fn generate_from_ordinal_impl(
309 name: &syn::Ident,
310 unit_variants: &[&Variant],
311) -> proc_macro2::TokenStream {
312 let match_arms = unit_variants.iter().enumerate().map(|(index, variant)| {
313 let variant_name = &variant.ident;
314 quote! { #index => Some(#name::#variant_name) }
315 });
316
317 quote! {
318 /// Converts a zero-based ordinal to an enum variant, if possible.
319 ///
320 /// Returns `Some(variant)` if the ordinal corresponds to a unit variant,
321 /// or `None` if the ordinal is out of range or would correspond to the "other" variant.
322 ///
323 /// # Examples
324 ///
325 /// ```ignore
326 /// # use unit_enum::UnitEnum;
327 /// # #[derive(Debug, PartialEq)]
328 /// #[derive(UnitEnum)]
329 /// enum Example {
330 /// A,
331 /// B,
332 /// #[unit_enum(other)]
333 /// Other(i32),
334 /// }
335 ///
336 /// assert_eq!(Example::from_ordinal(0), Some(Example::A));
337 /// assert_eq!(Example::from_ordinal(2), None); // Other variant
338 /// assert_eq!(Example::from_ordinal(99), None); // Out of range
339 /// ```
340 pub fn from_ordinal(ord: usize) -> Option<Self> {
341 match ord {
342 #(#match_arms,)*
343 _ => None
344 }
345 }
346 }
347}
348
349fn generate_discriminant_impl(
350 name: &syn::Ident,
351 unit_variants: &[&Variant],
352 other_variant: &Option<(&Variant, Type)>,
353 discriminant_type: &Type,
354 discriminants: &[Expr],
355) -> proc_macro2::TokenStream {
356 let unit_match_arms = unit_variants.iter().zip(discriminants).map(|(variant, discriminant)| {
357 let variant_name = &variant.ident;
358 quote! { #name::#variant_name => #discriminant as #discriminant_type }
359 });
360
361 let other_arm = other_variant.as_ref().map(|(variant, _)| {
362 let variant_name = &variant.ident;
363 quote! { #name::#variant_name(val) => *val }
364 });
365
366 quote! {
367 /// Returns the discriminant value of the enum variant.
368 ///
369 /// For "other" variants, returns the contained value.
370 ///
371 /// # Examples
372 ///
373 /// ```ignore
374 /// # use unit_enum::UnitEnum;
375 /// #[derive(UnitEnum)]
376 /// enum Example {
377 /// A, // 0
378 /// B = 10, // 10
379 /// C, // 11
380 /// }
381 ///
382 /// assert_eq!(Example::A.discriminant(), 0);
383 /// assert_eq!(Example::B.discriminant(), 10);
384 /// assert_eq!(Example::C.discriminant(), 11);
385 /// ```
386 pub fn discriminant(&self) -> #discriminant_type {
387 match self {
388 #(#unit_match_arms,)*
389 #other_arm
390 }
391 }
392 }
393}
394
395fn generate_from_discriminant_impl(
396 name: &syn::Ident,
397 unit_variants: &[&Variant],
398 other_variant: &Option<(&Variant, Type)>,
399 discriminant_type: &Type,
400 discriminants: &[Expr],
401) -> proc_macro2::TokenStream {
402 if let Some((other_variant, _)) = other_variant {
403 let match_arms = unit_variants.iter().zip(discriminants).map(|(variant, discriminant)| {
404 let variant_name = &variant.ident;
405 quote! { x if x == (#discriminant as #discriminant_type) => #name::#variant_name }
406 });
407
408 let other_name = &other_variant.ident;
409 quote! {
410 /// Converts a discriminant value to an enum variant.
411 ///
412 /// For enums with an "other" variant, this will always return a value,
413 /// using the "other" variant for undefined discriminants.
414 ///
415 /// # Examples
416 ///
417 /// ```ignore
418 /// # use unit_enum::UnitEnum;
419 /// #[derive(UnitEnum, PartialEq, Debug)]
420 /// #[repr(u8)]
421 /// enum Example {
422 /// A, // 0
423 /// B = 10, // 10
424 /// #[unit_enum(other)]
425 /// Other(u8),
426 /// }
427 ///
428 /// assert_eq!(Example::from_discriminant(0), Example::A);
429 /// assert_eq!(Example::from_discriminant(10), Example::B);
430 /// assert_eq!(Example::from_discriminant(42), Example::Other(42));
431 /// ```
432 pub fn from_discriminant(discr: #discriminant_type) -> Self {
433 match discr {
434 #(#match_arms,)*
435 other => #name::#other_name(other)
436 }
437 }
438 }
439 } else {
440 let match_arms = unit_variants.iter().zip(discriminants).map(|(variant, discriminant)| {
441 let variant_name = &variant.ident;
442 quote! { x if x == (#discriminant as #discriminant_type) => Some(#name::#variant_name) }
443 });
444
445 quote! {
446 /// Converts a discriminant value to an enum variant, if possible.
447 ///
448 /// Returns `Some(variant)` if the discriminant corresponds to a defined variant,
449 /// or `None` if the discriminant is undefined.
450 ///
451 /// # Examples
452 ///
453 /// ```ignore
454 /// # use unit_enum::UnitEnum;
455 /// #[derive(UnitEnum, PartialEq, Debug)]
456 /// #[repr(u8)]
457 /// enum Example {
458 /// A, // 0
459 /// B = 10, // 10
460 /// C, // 11
461 /// }
462 ///
463 /// assert_eq!(Example::from_discriminant(0), Some(Example::A));
464 /// assert_eq!(Example::from_discriminant(10), Some(Example::B));
465 /// assert_eq!(Example::from_discriminant(42), None);
466 /// ```
467 pub fn from_discriminant(discr: #discriminant_type) -> Option<Self> {
468 match discr {
469 #(#match_arms,)*
470 _ => None
471 }
472 }
473 }
474 }
475}
476
477fn generate_values_impl(
478 name: &syn::Ident,
479 unit_variants: &[&Variant],
480 discriminants: &[Expr],
481 _other_variant: &Option<(&Variant, Type)>,
482) -> proc_macro2::TokenStream {
483 // Create a vector of variant expressions paired with their discriminants
484 let variant_exprs = unit_variants.iter().zip(discriminants).map(|(variant, discriminant)| {
485 let variant_name = &variant.ident;
486 quote! {
487 #name::#variant_name // The variant
488 }
489 });
490
491 // Collect variants into a Vec to ensure consistent ordering
492 quote! {
493 /// Returns an iterator over all unit variants of the enum.
494 ///
495 /// Note: This does not include values from the "other" variant, if present.
496 ///
497 /// # Examples
498 ///
499 /// ```ignore
500 /// # use unit_enum::UnitEnum;
501 /// #[derive(UnitEnum, PartialEq, Debug)]
502 /// enum Example {
503 /// A,
504 /// B,
505 /// #[unit_enum(other)]
506 /// Other(i32),
507 /// }
508 ///
509 /// let values: Vec<_> = Example::values().collect();
510 /// assert_eq!(values, vec![Example::A, Example::B]);
511 /// ```
512 pub fn values() -> impl Iterator<Item = Self> {
513 vec![
514 #(#variant_exprs),*
515 ].into_iter()
516 }
517 }
518}