tagged_union_macro/
lib.rs

1extern crate proc_macro;
2
3use inflector::Inflector;
4use proc_macro2::Span;
5use quote::{quote, ToTokens};
6use syn::{
7    parse,
8    parse::Parse,
9    parse2, parse_quote,
10    punctuated::{Pair, Punctuated},
11    spanned::Spanned,
12    Data, DataEnum, DeriveInput, Expr, ExprLit, Field, FieldMutability, Fields, FieldsNamed,
13    FieldsUnnamed, Generics, Ident, ImplItem, Item, ItemImpl, Lifetime, Lit, Meta, MetaNameValue,
14    Token, Type, TypeReference, TypeTuple, Variant, Visibility, WhereClause,
15};
16
17/// A proc macro to generate methods like is_variant / expect_variant.
18///
19///
20/// # Example
21///
22/// ```rust
23/// 
24/// use is_macro::Is;
25/// #[derive(Debug, Is)]
26/// pub enum Enum<T> {
27///     A,
28///     B(T),
29///     C(Option<T>),
30/// }
31///
32/// // Rust's type inference cannot handle this.
33/// assert!(Enum::<()>::A.is_a());
34///
35/// assert_eq!(Enum::B(String::from("foo")).b(), Some(String::from("foo")));
36///
37/// assert_eq!(Enum::B(String::from("foo")).expect_b(), String::from("foo"));
38/// ```
39///
40/// # Renaming
41///
42/// ```rust
43/// 
44/// use tagged_union::TaggedUnion;
45/// #[derive(Debug, TaggedUnion)]
46/// pub enum Enum {
47///     #[tagged_union(name = "video_mp4")]
48///     VideoMp4,
49/// }
50///
51/// assert!(Enum::VideoMp4.is_video_mp4());
52/// ```
53#[proc_macro_derive(TaggedUnion, attributes(tagged_union))]
54pub fn derive_tagged_union(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
55    let input: DeriveInput = syn::parse(input).expect("failed to parse derive input");
56    let generics: Generics = input.generics.clone();
57
58    let items = match input.data {
59        Data::Enum(e) => expand(&input.ident, &input.vis, &input.generics, e),
60        _ => panic!("`Is` can be applied only on enums"),
61    };
62
63    quote!(
64        #(#items)*
65    )
66    .into()
67}
68
69#[derive(Debug)]
70struct Input {
71    name: String,
72}
73
74impl Parse for Input {
75    fn parse(input: parse::ParseStream) -> syn::Result<Self> {
76        let _: Ident = input.parse()?;
77        let _: Token![=] = input.parse()?;
78
79        let name = input.parse::<ExprLit>()?;
80
81        Ok(Input {
82            name: match name.lit {
83                Lit::Str(s) => s.value(),
84                _ => panic!("is(name = ...) expects a string literal"),
85            },
86        })
87    }
88}
89
90fn make_impl_item_for_enum(
91    enum_name: &Ident,
92    vis: &Visibility,
93    generics: &Generics,
94    input: &DataEnum,
95) -> Item {
96    let mut items = create_cast_methods_from_orig_enum(input);
97
98    parse_quote!(
99        /// Generated by `#[derive(tagged_union::TaggedUnion)]``.
100        #[automatically_derived]
101        impl #enum_name {}
102    )
103}
104
105fn make_ref_enum(
106    enum_name: &Ident,
107    vis: &Visibility,
108    generics: &Generics,
109    input: &DataEnum,
110    mutable: bool,
111) -> Item {
112    let new_type_name = ref_enum_name(enum_name, mutable);
113
114    let docs = format!(
115        "A {mutable} reference to the enum [`{name}`]. This is different from `&{name}` because \
116         this type supports creation from a subset of [`{name}`]",
117        name = enum_name,
118        mutable = if mutable { "mutable" } else { "immutable" },
119    );
120
121    let mut variants: Punctuated<Variant, Token![,]> = Default::default();
122
123    for v in &input.variants {
124        let variant = &v.ident;
125        let mut fields: Punctuated<Field, Token![,]> = Default::default();
126        let fields = match &v.fields {
127            Fields::Unnamed(fields_) => {
128                for f in fields_.unnamed.iter() {
129                    let ty = add_ref(f.ty.clone(), mutable);
130                    fields.push(Field {
131                        attrs: Default::default(),
132                        vis: Visibility::Inherited,
133                        mutability: FieldMutability::None,
134                        ident: None,
135                        colon_token: None,
136                        ty,
137                    });
138                }
139
140                Fields::Unnamed(FieldsUnnamed {
141                    paren_token: Default::default(),
142                    unnamed: fields,
143                })
144            }
145            Fields::Named(fields_) => {
146                for f in fields_.named.iter() {
147                    let ty = add_ref(f.ty.clone(), mutable);
148                    let name = f.ident.clone().unwrap();
149
150                    fields.push(Field {
151                        attrs: Default::default(),
152                        vis: Visibility::Inherited,
153                        mutability: FieldMutability::None,
154                        ident: Some(name),
155                        colon_token: None,
156                        ty,
157                    });
158                }
159
160                Fields::Named(FieldsNamed {
161                    brace_token: Default::default(),
162                    named: fields,
163                })
164            }
165            _ => todo!("ref enum for unit variant"),
166        };
167
168        let variant = parse_quote!(#variant #fields);
169
170        variants.push(variant);
171    }
172
173    parse_quote!(
174        #[doc = #docs]
175        ///
176        /// Generated by `#[derive(tagged_union::TaggedUnion)]``.
177        pub enum #new_type_name<'tu> {
178            #variants
179        }
180    )
181}
182
183fn make_impl_item_for_ref_enum(
184    enum_name: &Ident,
185    generics: &Generics,
186    input: &DataEnum,
187    mutable: bool,
188) -> Item {
189    let new_type_name = ref_enum_name(enum_name, mutable);
190
191    parse_quote!(
192        /// Generated by `#[derive(tagged_union::TaggedUnion)]``.
193        #[automatically_derived]
194        impl<'tu> #new_type_name<'tu> {}
195    )
196}
197
198fn ref_enum_name(enum_name: &Ident, mutable: bool) -> Ident {
199    let mut name = enum_name.to_string();
200    if mutable {
201        name.push_str("Ref");
202    } else {
203        name.push_str("MutRef");
204    }
205    Ident::new(&name, enum_name.span())
206}
207
208fn expand(enum_name: &Ident, vis: &Visibility, generics: &Generics, input: DataEnum) -> Vec<Item> {
209    vec![
210        make_impl_item_for_enum(enum_name, vis, generics, &input),
211        make_ref_enum(enum_name, vis, generics, &input, false),
212        make_impl_item_for_ref_enum(enum_name, generics, &input, false),
213        make_ref_enum(enum_name, vis, generics, &input, true),
214        make_impl_item_for_ref_enum(enum_name, generics, &input, true),
215    ]
216}
217
218fn create_cast_methods_from_orig_enum(input: &DataEnum) -> Vec<ImplItem> {
219    let mut items = vec![];
220
221    for v in &input.variants {
222        let attrs = v
223            .attrs
224            .iter()
225            .filter(|attr| attr.path().is_ident("is"))
226            .collect::<Vec<_>>();
227        if attrs.len() >= 2 {
228            panic!("derive(Is) expects no attribute or one attribute")
229        }
230        let i = match attrs.into_iter().next() {
231            None => Input {
232                name: {
233                    v.ident.to_string().to_snake_case()
234                    //
235                },
236            },
237            Some(attr) => {
238                //
239
240                let mut input = Input {
241                    name: Default::default(),
242                };
243
244                let mut apply = |v: &MetaNameValue| {
245                    assert!(
246                        v.path.is_ident("name"),
247                        "Currently, is() only supports `is(name = 'foo')`"
248                    );
249
250                    input.name = match &v.value {
251                        Expr::Lit(ExprLit {
252                            lit: Lit::Str(s), ..
253                        }) => s.value(),
254                        _ => unimplemented!(
255                            "is(): name must be a string literal but {:?} is provided",
256                            v.value
257                        ),
258                    };
259                };
260
261                match &attr.meta {
262                    Meta::NameValue(v) => {
263                        //
264                        apply(v)
265                    }
266                    Meta::List(l) => {
267                        // Handle is(name = "foo")
268                        input = parse2(l.tokens.clone()).expect("failed to parse input");
269                    }
270                    _ => unimplemented!("is({:?})", attr.meta),
271                }
272
273                input
274            }
275        };
276
277        let name = &*i.name;
278        {
279            let name_of_is = Ident::new(&format!("is_{name}"), v.ident.span());
280            let docs_of_is = format!(
281                "Returns `true` if `self` is of variant [`{variant}`].\n\n[`{variant}`]: \
282                 #variant.{variant}",
283                variant = v.ident,
284            );
285
286            let variant = &v.ident;
287
288            let item_impl: ItemImpl = parse_quote!(
289                impl Type {
290                    #[doc = #docs_of_is]
291                    #[inline]
292                    pub const fn #name_of_is(&self) -> bool {
293                        match *self {
294                            Self::#variant { .. } => true,
295                            _ => false,
296                        }
297                    }
298                }
299            );
300
301            items.extend(item_impl.items);
302        }
303
304        {
305            let name_of_cast = Ident::new(&format!("as_{name}"), v.ident.span());
306            let name_of_cast_mut = Ident::new(&format!("as_mut_{name}"), v.ident.span());
307            let name_of_expect = Ident::new(&format!("expect_{name}"), v.ident.span());
308            let name_of_take = Ident::new(name, v.ident.span());
309
310            let docs_of_cast = format!(
311                "Returns `Some` if `self` is a reference of variant [`{variant}`], and `None` \
312                 otherwise.\n\n[`{variant}`]: #variant.{variant}",
313                variant = v.ident,
314            );
315            let docs_of_cast_mut = format!(
316                "Returns `Some` if `self` is a mutable reference of variant [`{variant}`], and \
317                 `None` otherwise.\n\n[`{variant}`]: #variant.{variant}",
318                variant = v.ident,
319            );
320            let docs_of_expect = format!(
321                "Unwraps the value, yielding the content of [`{variant}`].\n\n# Panics\n\nPanics \
322                 if the value is not [`{variant}`], with a panic message including the content of \
323                 `self`.\n\n[`{variant}`]: #variant.{variant}",
324                variant = v.ident,
325            );
326            let docs_of_take = format!(
327                "Returns `Some` if `self` is of variant [`{variant}`], and `None` \
328                 otherwise.\n\n[`{variant}`]: #variant.{variant}",
329                variant = v.ident,
330            );
331
332            if let Fields::Unnamed(fields) = &v.fields {
333                let types = fields.unnamed.iter().map(|f| f.ty.clone());
334                let cast_ty = types_to_type(types.clone().map(|ty| add_ref(ty, false)));
335                let cast_ty_mut = types_to_type(types.clone().map(|ty| add_ref(ty, true)));
336                let ty = types_to_type(types);
337
338                let mut fields: Punctuated<Ident, Token![,]> = fields
339                    .unnamed
340                    .clone()
341                    .into_pairs()
342                    .enumerate()
343                    .map(|(i, pair)| {
344                        let handle = |f: Field| {
345                            //
346                            Ident::new(&format!("v{i}"), f.span())
347                        };
348                        match pair {
349                            Pair::Punctuated(v, p) => Pair::Punctuated(handle(v), p),
350                            Pair::End(v) => Pair::End(handle(v)),
351                        }
352                    })
353                    .collect();
354
355                // Make sure that we don't have any trailing punctuation
356                // This ensure that if we have a single unnamed field,
357                // we will produce a value of the form `(v)`,
358                // not a single-element tuple `(v,)`
359                if let Some(mut pair) = fields.pop() {
360                    if let Pair::Punctuated(v, _) = pair {
361                        pair = Pair::End(v);
362                    }
363                    fields.extend(std::iter::once(pair));
364                }
365
366                let variant = &v.ident;
367
368                let item_impl: ItemImpl = parse_quote!(
369                    impl #ty {
370                        #[doc = #docs_of_cast]
371                        #[inline]
372                        pub fn #name_of_cast(&self) -> Option<#cast_ty> {
373                            match self {
374                                Self::#variant(#fields) => Some((#fields)),
375                                _ => None,
376                            }
377                        }
378
379                        #[doc = #docs_of_cast_mut]
380                        #[inline]
381                        pub fn #name_of_cast_mut(&mut self) -> Option<#cast_ty_mut> {
382                            match self {
383                                Self::#variant(#fields) => Some((#fields)),
384                                _ => None,
385                            }
386                        }
387
388                        #[doc = #docs_of_expect]
389                        #[inline]
390                        pub fn #name_of_expect(self) -> #ty
391                        where
392                            Self: ::std::fmt::Debug,
393                        {
394                            match self {
395                                Self::#variant(#fields) => (#fields),
396                                _ => panic!("called expect on {:?}", self),
397                            }
398                        }
399
400                        #[doc = #docs_of_take]
401                        #[inline]
402                        pub fn #name_of_take(self) -> Option<#ty> {
403                            match self {
404                                Self::#variant(#fields) => Some((#fields)),
405                                _ => None,
406                            }
407                        }
408                    }
409                );
410
411                items.extend(item_impl.items);
412            }
413        }
414    }
415
416    items
417}
418
419fn types_to_type(types: impl Iterator<Item = Type>) -> Type {
420    let mut types: Punctuated<_, _> = types.collect();
421    if types.len() == 1 {
422        types.pop().expect("len is 1").into_value()
423    } else {
424        TypeTuple {
425            paren_token: Default::default(),
426            elems: types,
427        }
428        .into()
429    }
430}
431
432fn add_ref(ty: Type, mutable: bool) -> Type {
433    Type::Reference(TypeReference {
434        and_token: Default::default(),
435        lifetime: Some(Lifetime::new("'tu", Span::call_site())),
436        mutability: if mutable {
437            Some(Default::default())
438        } else {
439            None
440        },
441        elem: Box::new(ty),
442    })
443}
444
445/// Extension trait for `ItemImpl` (impl block).
446trait ItemImplExt {
447    /// Instead of
448    ///
449    /// ```rust,ignore
450    /// let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
451    ///
452    /// let item: Item = Quote::new(def_site::<Span>())
453    ///     .quote_with(smart_quote!(
454    /// Vars {
455    /// Type: type_name,
456    /// impl_generics,
457    /// ty_generics,
458    /// where_clause,
459    /// },
460    /// {
461    /// impl impl_generics ::swc_common::AstNode for Type ty_generics
462    /// where_clause {}
463    /// }
464    /// )).parse();
465    /// ```
466    ///
467    /// You can use this like
468    ///
469    /// ```rust,ignore
470    // let item = Quote::new(def_site::<Span>())
471    ///     .quote_with(smart_quote!(Vars { Type: type_name }, {
472    ///         impl ::swc_common::AstNode for Type {}
473    ///     }))
474    ///     .parse::<ItemImpl>()
475    ///     .with_generics(input.generics);
476    /// ```
477    fn with_generics(self, generics: Generics) -> Self;
478}
479
480impl ItemImplExt for ItemImpl {
481    fn with_generics(mut self, mut generics: Generics) -> Self {
482        // TODO: Check conflicting name
483
484        let need_new_punct = !generics.params.empty_or_trailing();
485        if need_new_punct {
486            generics
487                .params
488                .push_punct(syn::token::Comma(Span::call_site()));
489        }
490
491        // Respan
492        if let Some(t) = generics.lt_token {
493            self.generics.lt_token = Some(t)
494        }
495        if let Some(t) = generics.gt_token {
496            self.generics.gt_token = Some(t)
497        }
498
499        let ty = self.self_ty;
500
501        // Handle generics defined on struct, enum, or union.
502        let mut item: ItemImpl = {
503            let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
504            let item = if let Some((ref polarity, ref path, ref for_token)) = self.trait_ {
505                quote! {
506                    impl #impl_generics #polarity #path #for_token #ty #ty_generics #where_clause {}
507                }
508            } else {
509                quote! {
510                    impl #impl_generics #ty #ty_generics #where_clause {}
511
512                }
513            };
514            parse2(item.into_token_stream())
515                .unwrap_or_else(|err| panic!("with_generics failed: {}", err))
516        };
517
518        // Handle generics added by proc-macro.
519        item.generics
520            .params
521            .extend(self.generics.params.into_pairs());
522        match self.generics.where_clause {
523            Some(WhereClause {
524                ref mut predicates, ..
525            }) => predicates.extend(
526                generics
527                    .where_clause
528                    .into_iter()
529                    .flat_map(|wc| wc.predicates.into_pairs()),
530            ),
531            ref mut opt @ None => *opt = generics.where_clause,
532        }
533
534        ItemImpl {
535            attrs: self.attrs,
536            defaultness: self.defaultness,
537            unsafety: self.unsafety,
538            impl_token: self.impl_token,
539            brace_token: self.brace_token,
540            items: self.items,
541            ..item
542        }
543    }
544}