1extern crate proc_macro;
2
3use inflector::Inflector;
4use proc_macro2::Span;
5use quote::{quote, ToTokens};
6use syn::{
7 parse,
8 parse::Parse,
9 parse2, parse_quote,
10 punctuated::{Pair, Punctuated},
11 spanned::Spanned,
12 Data, DataEnum, DeriveInput, Expr, ExprLit, Field, FieldMutability, Fields, FieldsNamed,
13 FieldsUnnamed, Generics, Ident, ImplItem, Item, ItemImpl, Lifetime, Lit, Meta, MetaNameValue,
14 Token, Type, TypeReference, TypeTuple, Variant, Visibility, WhereClause,
15};
16
17#[proc_macro_derive(TaggedUnion, attributes(tagged_union))]
54pub fn derive_tagged_union(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
55 let input: DeriveInput = syn::parse(input).expect("failed to parse derive input");
56 let generics: Generics = input.generics.clone();
57
58 let items = match input.data {
59 Data::Enum(e) => expand(&input.ident, &input.vis, &input.generics, e),
60 _ => panic!("`Is` can be applied only on enums"),
61 };
62
63 quote!(
64 #(#items)*
65 )
66 .into()
67}
68
/// Parsed body of a variant-level naming attribute such as `(name = "foo")`.
#[derive(Debug)]
struct Input {
    /// Replacement for the snake_case variant name inside generated method names
    /// (`is_{name}`, `as_{name}`, ...).
    name: String,
}
73
74impl Parse for Input {
75 fn parse(input: parse::ParseStream) -> syn::Result<Self> {
76 let _: Ident = input.parse()?;
77 let _: Token![=] = input.parse()?;
78
79 let name = input.parse::<ExprLit>()?;
80
81 Ok(Input {
82 name: match name.lit {
83 Lit::Str(s) => s.value(),
84 _ => panic!("is(name = ...) expects a string literal"),
85 },
86 })
87 }
88}
89
90fn make_impl_item_for_enum(
91 enum_name: &Ident,
92 vis: &Visibility,
93 generics: &Generics,
94 input: &DataEnum,
95) -> Item {
96 let mut items = create_cast_methods_from_orig_enum(input);
97
98 parse_quote!(
99 #[automatically_derived]
101 impl #enum_name {}
102 )
103}
104
105fn make_ref_enum(
106 enum_name: &Ident,
107 vis: &Visibility,
108 generics: &Generics,
109 input: &DataEnum,
110 mutable: bool,
111) -> Item {
112 let new_type_name = ref_enum_name(enum_name, mutable);
113
114 let docs = format!(
115 "A {mutable} reference to the enum [`{name}`]. This is different from `&{name}` because \
116 this type supports creation from a subset of [`{name}`]",
117 name = enum_name,
118 mutable = if mutable { "mutable" } else { "immutable" },
119 );
120
121 let mut variants: Punctuated<Variant, Token![,]> = Default::default();
122
123 for v in &input.variants {
124 let variant = &v.ident;
125 let mut fields: Punctuated<Field, Token![,]> = Default::default();
126 let fields = match &v.fields {
127 Fields::Unnamed(fields_) => {
128 for f in fields_.unnamed.iter() {
129 let ty = add_ref(f.ty.clone(), mutable);
130 fields.push(Field {
131 attrs: Default::default(),
132 vis: Visibility::Inherited,
133 mutability: FieldMutability::None,
134 ident: None,
135 colon_token: None,
136 ty,
137 });
138 }
139
140 Fields::Unnamed(FieldsUnnamed {
141 paren_token: Default::default(),
142 unnamed: fields,
143 })
144 }
145 Fields::Named(fields_) => {
146 for f in fields_.named.iter() {
147 let ty = add_ref(f.ty.clone(), mutable);
148 let name = f.ident.clone().unwrap();
149
150 fields.push(Field {
151 attrs: Default::default(),
152 vis: Visibility::Inherited,
153 mutability: FieldMutability::None,
154 ident: Some(name),
155 colon_token: None,
156 ty,
157 });
158 }
159
160 Fields::Named(FieldsNamed {
161 brace_token: Default::default(),
162 named: fields,
163 })
164 }
165 _ => todo!("ref enum for unit variant"),
166 };
167
168 let variant = parse_quote!(#variant #fields);
169
170 variants.push(variant);
171 }
172
173 parse_quote!(
174 #[doc = #docs]
175 pub enum #new_type_name<'tu> {
178 #variants
179 }
180 )
181}
182
183fn make_impl_item_for_ref_enum(
184 enum_name: &Ident,
185 generics: &Generics,
186 input: &DataEnum,
187 mutable: bool,
188) -> Item {
189 let new_type_name = ref_enum_name(enum_name, mutable);
190
191 parse_quote!(
192 #[automatically_derived]
194 impl<'tu> #new_type_name<'tu> {}
195 )
196}
197
198fn ref_enum_name(enum_name: &Ident, mutable: bool) -> Ident {
199 let mut name = enum_name.to_string();
200 if mutable {
201 name.push_str("Ref");
202 } else {
203 name.push_str("MutRef");
204 }
205 Ident::new(&name, enum_name.span())
206}
207
208fn expand(enum_name: &Ident, vis: &Visibility, generics: &Generics, input: DataEnum) -> Vec<Item> {
209 vec![
210 make_impl_item_for_enum(enum_name, vis, generics, &input),
211 make_ref_enum(enum_name, vis, generics, &input, false),
212 make_impl_item_for_ref_enum(enum_name, generics, &input, false),
213 make_ref_enum(enum_name, vis, generics, &input, true),
214 make_impl_item_for_ref_enum(enum_name, generics, &input, true),
215 ]
216}
217
218fn create_cast_methods_from_orig_enum(input: &DataEnum) -> Vec<ImplItem> {
219 let mut items = vec![];
220
221 for v in &input.variants {
222 let attrs = v
223 .attrs
224 .iter()
225 .filter(|attr| attr.path().is_ident("is"))
226 .collect::<Vec<_>>();
227 if attrs.len() >= 2 {
228 panic!("derive(Is) expects no attribute or one attribute")
229 }
230 let i = match attrs.into_iter().next() {
231 None => Input {
232 name: {
233 v.ident.to_string().to_snake_case()
234 },
236 },
237 Some(attr) => {
238 let mut input = Input {
241 name: Default::default(),
242 };
243
244 let mut apply = |v: &MetaNameValue| {
245 assert!(
246 v.path.is_ident("name"),
247 "Currently, is() only supports `is(name = 'foo')`"
248 );
249
250 input.name = match &v.value {
251 Expr::Lit(ExprLit {
252 lit: Lit::Str(s), ..
253 }) => s.value(),
254 _ => unimplemented!(
255 "is(): name must be a string literal but {:?} is provided",
256 v.value
257 ),
258 };
259 };
260
261 match &attr.meta {
262 Meta::NameValue(v) => {
263 apply(v)
265 }
266 Meta::List(l) => {
267 input = parse2(l.tokens.clone()).expect("failed to parse input");
269 }
270 _ => unimplemented!("is({:?})", attr.meta),
271 }
272
273 input
274 }
275 };
276
277 let name = &*i.name;
278 {
279 let name_of_is = Ident::new(&format!("is_{name}"), v.ident.span());
280 let docs_of_is = format!(
281 "Returns `true` if `self` is of variant [`{variant}`].\n\n[`{variant}`]: \
282 #variant.{variant}",
283 variant = v.ident,
284 );
285
286 let variant = &v.ident;
287
288 let item_impl: ItemImpl = parse_quote!(
289 impl Type {
290 #[doc = #docs_of_is]
291 #[inline]
292 pub const fn #name_of_is(&self) -> bool {
293 match *self {
294 Self::#variant { .. } => true,
295 _ => false,
296 }
297 }
298 }
299 );
300
301 items.extend(item_impl.items);
302 }
303
304 {
305 let name_of_cast = Ident::new(&format!("as_{name}"), v.ident.span());
306 let name_of_cast_mut = Ident::new(&format!("as_mut_{name}"), v.ident.span());
307 let name_of_expect = Ident::new(&format!("expect_{name}"), v.ident.span());
308 let name_of_take = Ident::new(name, v.ident.span());
309
310 let docs_of_cast = format!(
311 "Returns `Some` if `self` is a reference of variant [`{variant}`], and `None` \
312 otherwise.\n\n[`{variant}`]: #variant.{variant}",
313 variant = v.ident,
314 );
315 let docs_of_cast_mut = format!(
316 "Returns `Some` if `self` is a mutable reference of variant [`{variant}`], and \
317 `None` otherwise.\n\n[`{variant}`]: #variant.{variant}",
318 variant = v.ident,
319 );
320 let docs_of_expect = format!(
321 "Unwraps the value, yielding the content of [`{variant}`].\n\n# Panics\n\nPanics \
322 if the value is not [`{variant}`], with a panic message including the content of \
323 `self`.\n\n[`{variant}`]: #variant.{variant}",
324 variant = v.ident,
325 );
326 let docs_of_take = format!(
327 "Returns `Some` if `self` is of variant [`{variant}`], and `None` \
328 otherwise.\n\n[`{variant}`]: #variant.{variant}",
329 variant = v.ident,
330 );
331
332 if let Fields::Unnamed(fields) = &v.fields {
333 let types = fields.unnamed.iter().map(|f| f.ty.clone());
334 let cast_ty = types_to_type(types.clone().map(|ty| add_ref(ty, false)));
335 let cast_ty_mut = types_to_type(types.clone().map(|ty| add_ref(ty, true)));
336 let ty = types_to_type(types);
337
338 let mut fields: Punctuated<Ident, Token![,]> = fields
339 .unnamed
340 .clone()
341 .into_pairs()
342 .enumerate()
343 .map(|(i, pair)| {
344 let handle = |f: Field| {
345 Ident::new(&format!("v{i}"), f.span())
347 };
348 match pair {
349 Pair::Punctuated(v, p) => Pair::Punctuated(handle(v), p),
350 Pair::End(v) => Pair::End(handle(v)),
351 }
352 })
353 .collect();
354
355 if let Some(mut pair) = fields.pop() {
360 if let Pair::Punctuated(v, _) = pair {
361 pair = Pair::End(v);
362 }
363 fields.extend(std::iter::once(pair));
364 }
365
366 let variant = &v.ident;
367
368 let item_impl: ItemImpl = parse_quote!(
369 impl #ty {
370 #[doc = #docs_of_cast]
371 #[inline]
372 pub fn #name_of_cast(&self) -> Option<#cast_ty> {
373 match self {
374 Self::#variant(#fields) => Some((#fields)),
375 _ => None,
376 }
377 }
378
379 #[doc = #docs_of_cast_mut]
380 #[inline]
381 pub fn #name_of_cast_mut(&mut self) -> Option<#cast_ty_mut> {
382 match self {
383 Self::#variant(#fields) => Some((#fields)),
384 _ => None,
385 }
386 }
387
388 #[doc = #docs_of_expect]
389 #[inline]
390 pub fn #name_of_expect(self) -> #ty
391 where
392 Self: ::std::fmt::Debug,
393 {
394 match self {
395 Self::#variant(#fields) => (#fields),
396 _ => panic!("called expect on {:?}", self),
397 }
398 }
399
400 #[doc = #docs_of_take]
401 #[inline]
402 pub fn #name_of_take(self) -> Option<#ty> {
403 match self {
404 Self::#variant(#fields) => Some((#fields)),
405 _ => None,
406 }
407 }
408 }
409 );
410
411 items.extend(item_impl.items);
412 }
413 }
414 }
415
416 items
417}
418
419fn types_to_type(types: impl Iterator<Item = Type>) -> Type {
420 let mut types: Punctuated<_, _> = types.collect();
421 if types.len() == 1 {
422 types.pop().expect("len is 1").into_value()
423 } else {
424 TypeTuple {
425 paren_token: Default::default(),
426 elems: types,
427 }
428 .into()
429 }
430}
431
432fn add_ref(ty: Type, mutable: bool) -> Type {
433 Type::Reference(TypeReference {
434 and_token: Default::default(),
435 lifetime: Some(Lifetime::new("'tu", Span::call_site())),
436 mutability: if mutable {
437 Some(Default::default())
438 } else {
439 None
440 },
441 elem: Box::new(ty),
442 })
443}
444
445trait ItemImplExt {
447 fn with_generics(self, generics: Generics) -> Self;
478}
479
480impl ItemImplExt for ItemImpl {
481 fn with_generics(mut self, mut generics: Generics) -> Self {
482 let need_new_punct = !generics.params.empty_or_trailing();
485 if need_new_punct {
486 generics
487 .params
488 .push_punct(syn::token::Comma(Span::call_site()));
489 }
490
491 if let Some(t) = generics.lt_token {
493 self.generics.lt_token = Some(t)
494 }
495 if let Some(t) = generics.gt_token {
496 self.generics.gt_token = Some(t)
497 }
498
499 let ty = self.self_ty;
500
501 let mut item: ItemImpl = {
503 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
504 let item = if let Some((ref polarity, ref path, ref for_token)) = self.trait_ {
505 quote! {
506 impl #impl_generics #polarity #path #for_token #ty #ty_generics #where_clause {}
507 }
508 } else {
509 quote! {
510 impl #impl_generics #ty #ty_generics #where_clause {}
511
512 }
513 };
514 parse2(item.into_token_stream())
515 .unwrap_or_else(|err| panic!("with_generics failed: {}", err))
516 };
517
518 item.generics
520 .params
521 .extend(self.generics.params.into_pairs());
522 match self.generics.where_clause {
523 Some(WhereClause {
524 ref mut predicates, ..
525 }) => predicates.extend(
526 generics
527 .where_clause
528 .into_iter()
529 .flat_map(|wc| wc.predicates.into_pairs()),
530 ),
531 ref mut opt @ None => *opt = generics.where_clause,
532 }
533
534 ItemImpl {
535 attrs: self.attrs,
536 defaultness: self.defaultness,
537 unsafety: self.unsafety,
538 impl_token: self.impl_token,
539 brace_token: self.brace_token,
540 items: self.items,
541 ..item
542 }
543 }
544}