sumtype_macro/
lib.rs

1use darling::ast::NestedMeta;
2use darling::FromMeta;
3use derive_syn_parse::Parse;
4use proc_macro::TokenStream as TokenStream1;
5use proc_macro2::Span;
6use proc_macro2::TokenStream;
7use proc_macro_error::{abort, proc_macro_error};
8use std::collections::{HashMap, HashSet};
9use syn::punctuated::Punctuated;
10use syn::spanned::Spanned;
11use syn::*;
12use template_quote::{quote, ToTokens};
13
14mod sumtrait_internal;
15
/// Returns a pseudo-random `u64`, used to build unique suffixes for generated identifiers.
///
/// A `RandomState` is freshly keyed on construction, so the hash of an empty input
/// differs between states; no external RNG crate is needed.
fn random() -> u64 {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    let state = RandomState::new();
    let hasher = state.build_hasher();
    hasher.finish()
}
22
23fn generic_param_to_arg(i: GenericParam) -> GenericArgument {
24    match i {
25        GenericParam::Lifetime(LifetimeParam { lifetime, .. }) => {
26            GenericArgument::Lifetime(lifetime)
27        }
28        GenericParam::Type(TypeParam { ident, .. }) => GenericArgument::Type(parse_quote!(#ident)),
29        GenericParam::Const(ConstParam { ident, .. }) => {
30            GenericArgument::Const(parse_quote!(#ident))
31        }
32    }
33}
34
35fn merge_generic_params(
36    args1: impl IntoIterator<Item = GenericParam, IntoIter: Clone>,
37    args2: impl IntoIterator<Item = GenericParam, IntoIter: Clone>,
38) -> impl Iterator<Item = GenericParam> {
39    let it1 = args1.into_iter();
40    let it2 = args2.into_iter();
41    it1.clone()
42        .filter(|arg| {
43            if let GenericParam::Lifetime(_) = arg {
44                true
45            } else {
46                false
47            }
48        })
49        .chain(it2.clone().filter(|arg| {
50            if let GenericParam::Lifetime(_) = arg {
51                true
52            } else {
53                false
54            }
55        }))
56        .chain(it1.clone().filter(|arg| {
57            if let GenericParam::Const(_) = arg {
58                true
59            } else {
60                false
61            }
62        }))
63        .chain(it2.clone().filter(|arg| {
64            if let GenericParam::Const(_) = arg {
65                true
66            } else {
67                false
68            }
69        }))
70        .chain(it1.clone().filter(|arg| {
71            if let GenericParam::Type(_) = arg {
72                true
73            } else {
74                false
75            }
76        }))
77        .chain(it2.clone().filter(|arg| {
78            if let GenericParam::Type(_) = arg {
79                true
80            } else {
81                false
82            }
83        }))
84}
85
86fn merge_generic_args(
87    args1: impl IntoIterator<Item = GenericArgument, IntoIter: Clone>,
88    args2: impl IntoIterator<Item = GenericArgument, IntoIter: Clone>,
89) -> impl Iterator<Item = GenericArgument> {
90    let it1 = args1.into_iter();
91    let it2 = args2.into_iter();
92    it1.clone()
93        .filter(|arg| {
94            if let GenericArgument::Lifetime(_) = arg {
95                true
96            } else {
97                false
98            }
99        })
100        .chain(it2.clone().filter(|arg| {
101            if let GenericArgument::Lifetime(_) = arg {
102                true
103            } else {
104                false
105            }
106        }))
107        .chain(it1.clone().filter(|arg| {
108            if let GenericArgument::Const(_) = arg {
109                true
110            } else {
111                false
112            }
113        }))
114        .chain(it2.clone().filter(|arg| {
115            if let GenericArgument::Const(_) = arg {
116                true
117            } else {
118                false
119            }
120        }))
121        .chain(it1.clone().filter(|arg| {
122            if let GenericArgument::Type(_) = arg {
123                true
124            } else {
125                false
126            }
127        }))
128        .chain(it2.clone().filter(|arg| {
129            if let GenericArgument::Type(_) = arg {
130                true
131            } else {
132                false
133            }
134        }))
135        .chain(it1.filter(|arg| match arg {
136            GenericArgument::AssocType(_)
137            | GenericArgument::AssocConst(_)
138            | GenericArgument::Constraint(_) => true,
139            _ => false,
140        }))
141        .chain(it2.filter(|arg| match arg {
142            GenericArgument::AssocType(_)
143            | GenericArgument::AssocConst(_)
144            | GenericArgument::Constraint(_) => true,
145            _ => false,
146        }))
147}
148
149fn path_of_ident(ident: Ident, is_super: bool) -> Path {
150    let mut segments = vec![];
151    if is_super {
152        segments.push(PathSegment {
153            ident: Ident::new("super", Span::call_site()),
154            arguments: PathArguments::None,
155        });
156    }
157    segments.push(PathSegment {
158        ident,
159        arguments: PathArguments::None,
160    });
161    Path {
162        leading_colon: None,
163        segments: segments.into_iter().collect(),
164    }
165}
166
167fn split_for_impl(
168    generics: Option<&Generics>,
169) -> (Vec<GenericParam>, Vec<GenericArgument>, Vec<WherePredicate>) {
170    if let Some(generics) = generics {
171        let (_, ty_generics, where_clause) = generics.split_for_impl();
172        let ty_generics: std::result::Result<AngleBracketedGenericArguments, _> =
173            parse2(ty_generics.into_token_stream());
174        (
175            generics.params.iter().cloned().collect(),
176            ty_generics
177                .map(|g| g.args.into_iter().collect())
178                .unwrap_or(vec![]),
179            where_clause
180                .map(|w| w.predicates.iter().cloned().collect())
181                .unwrap_or(vec![]),
182        )
183    } else {
184        (vec![], vec![], vec![])
185    }
186}
187
/// Arguments of the `#[sumtype(...)]` attribute: a `+`-separated list of trait paths.
/// Each path is later invoked as a macro (`#path!(...)`, see `SumTypeImpl::gen`) to
/// implement that trait for the generated enum.
#[derive(Parse)]
struct Arguments {
    // `parse_terminated` consumes the whole input and allows a trailing `+`.
    #[call(Punctuated::parse_terminated)]
    bounds: Punctuated<Path, Token![+]>,
}
193
/// A kind of implementation to generate for the hidden sum-type enum.
enum SumTypeImpl {
    /// Implement the trait at this path, by invoking the path as a macro.
    Trait(Path),
}
197
impl SumTypeImpl {
    /// Emits the macro invocation that implements the selected trait for the enum.
    ///
    /// For `SumTypeImpl::Trait(path)` this produces `path!( ... );`. The arguments are
    /// passed positionally in the order annotated by the inline comments below. The
    /// invoked macro (presumably generated alongside the trait elsewhere in this crate;
    /// TODO confirm) must accept exactly this shape.
    fn gen(
        &self,
        enum_path: &Path,
        unspecified_ty_params: &[Ident],
        variants: &[(Ident, Type)],
        impl_generics: Vec<GenericParam>,
        ty_generics: Vec<GenericArgument>,
        where_clause: Vec<WherePredicate>,
        constraint_expr_trait_ident: &Ident,
    ) -> TokenStream {
        match self {
            SumTypeImpl::Trait(trait_path) => {
                // The `/* ... */` comments are not tokens; they only document the positions.
                quote! {
                    #trait_path!(
                        /* constraint_expr_trait_ident = */ #constraint_expr_trait_ident,
                        /* trait_path = */ #trait_path,
                        /* enum_path = */ #enum_path,
                        /* unspecified_ty_params = */ [#(#unspecified_ty_params),*],
                        /* variants = */ [#(for (id, ty) in variants),{#id:#ty}],
                        /* impl_generics_base = */ [ #(#impl_generics),* ],
                        /* ty_generics_base = */ [#(#ty_generics),*],
                        /* where_clause_base = */ { #(#where_clause),* },
                    );
                }
            }
        }
    }
}
227
/// Information recorded for each expression-position `sumtype!(...)` that was expanded.
struct ExprMacroInfo {
    // Span of the macro invocation, used for error reporting.
    span: Span,
    // Name of the enum variant that wraps this expression (`__SumType_Variant_<n>`).
    variant_ident: Ident,
    // Marker-type name used to carry the explicit type; `Some` only if the macro gave a type.
    reftype_ident: Option<Ident>,
    // Lifetimes each generic type parameter must outlive, derived from the explicit type.
    analyzed_bounds: HashMap<Ident, HashSet<Lifetime>>,
    // The macro's own `for<...>` params and `where` clause; all expressions must agree.
    generics: Generics,
}
235
/// Information recorded for each type-position `sumtype!(...)` that was expanded.
struct TypeMacroInfo {
    // Span of the macro invocation (currently unused, hence the leading underscore).
    _span: Span,
    // Generic arguments written inside the type macro; checked against the expression generics.
    generic_args: Punctuated<GenericArgument, Token![,]>,
}
240
241// Factory methods to process on supported tree elements
242trait ProcessTree: Sized {
243    // Collect macros in both type context and expr context. Replace macros with code.
244    fn collect_inline_macro(
245        &mut self,
246        enum_path: &Path,
247        typeref_path: &Path,
248        constraint_expr_trait_path: &Path,
249        generics: Option<&Generics>,
250        is_module: bool,
251    ) -> (Vec<ExprMacroInfo>, Vec<TypeMacroInfo>);
252
253    fn emit_items(
254        mut self,
255        args: &Arguments,
256        generics_env: Option<&Generics>,
257        is_module: bool,
258        vis: Visibility,
259    ) -> (TokenStream, Self) {
260        let r = random();
261        let enum_ident = Ident::new(&format!("__Sumtype_Enum_{}", r), Span::call_site());
262        let typeref_ident =
263            Ident::new(&format!("__Sumtype_TypeRef_Trait_{}", r), Span::call_site());
264        let constraint_expr_trait_ident = Ident::new(
265            &format!("__Sumtype_ConstraintExprTrait_{}", r),
266            Span::call_site(),
267        );
268        let (found_exprs, type_emitted) = self.collect_inline_macro(
269            &path_of_ident(enum_ident.clone(), is_module),
270            &path_of_ident(typeref_ident.clone(), is_module),
271            &path_of_ident(constraint_expr_trait_ident.clone(), is_module),
272            generics_env,
273            is_module,
274        );
275        let reftypes = found_exprs
276            .iter()
277            .filter_map(|info| info.reftype_ident.clone())
278            .collect::<Vec<_>>();
279        let (impl_generics_env, _, where_clause_env) = split_for_impl(generics_env);
280        if found_exprs.len() == 0 {
281            abort!(Span::call_site(), "Cannot find any sumtype!() in expr");
282        }
283        let expr_generics_list = found_exprs.iter().fold(HashMap::new(), |mut acc, info| {
284            *acc.entry(info.generics.clone()).or_insert(0usize) += 1;
285            acc
286        });
287        if expr_generics_list.len() != 1 {
288            let mut expr_gparams = expr_generics_list.into_iter().collect::<Vec<_>>();
289            expr_gparams.sort_by_key(|item| item.1);
290            abort!(expr_gparams[0].0.span(), "Generic argument mismatch");
291        }
292        let expr_generics = expr_generics_list.into_iter().next().unwrap().0;
293        let mut analyzed = found_exprs.iter().fold(HashMap::new(), |mut acc, info| {
294            for (id, lts) in &info.analyzed_bounds {
295                acc.entry(id.clone())
296                    .or_insert(HashSet::new())
297                    .extend(lts.iter().map(|lt| TypeParamBound::Lifetime(lt.clone())));
298            }
299            acc
300        });
301        if let Some(where_clause) = &expr_generics.where_clause {
302            for pred in &where_clause.predicates {
303                match pred {
304                    WherePredicate::Type(PredicateType {
305                        bounded_ty, bounds, ..
306                    }) => {
307                        if let Type::Path(path) = bounded_ty {
308                            if path.qself.is_none() {
309                                if let Some(id) = path.path.get_ident() {
310                                    analyzed
311                                        .entry(id.clone())
312                                        .or_insert(HashSet::new())
313                                        .extend(bounds.clone());
314                                }
315                            }
316                        }
317                    }
318                    _ => (),
319                }
320            }
321        }
322        let expr_garg = expr_generics
323            .params
324            .iter()
325            .cloned()
326            .map(generic_param_to_arg)
327            .collect::<Vec<_>>();
328        for info in &type_emitted {
329            if info.generic_args.len() != expr_garg.len()
330                || !expr_garg
331                    .iter()
332                    .zip(&info.generic_args)
333                    .all(|two| match two {
334                        (GenericArgument::Lifetime(_), GenericArgument::Lifetime(_))
335                        | (GenericArgument::Const(_), GenericArgument::Const(_))
336                        | (GenericArgument::Type(_), GenericArgument::Type(_)) => true,
337                        _ => false,
338                    })
339            {
340                abort!(
341                    info.generic_args.span(),
342                    "The generic arguments are incompatible with generic params in expression."
343                )
344            }
345        }
346        let mut impl_generics =
347            merge_generic_params(impl_generics_env, expr_generics.params).collect::<Vec<_>>();
348        for g in impl_generics.iter_mut() {
349            if let GenericParam::Type(TypeParam { ident, bounds, .. }) = g {
350                if let Some(bs) = analyzed.get(ident) {
351                    for b in bs {
352                        bounds.push(b.clone());
353                    }
354                }
355            }
356        }
357        let ty_generics = impl_generics
358            .iter()
359            .cloned()
360            .map(generic_param_to_arg)
361            .collect::<Vec<_>>();
362        let where_clause = expr_generics
363            .where_clause
364            .clone()
365            .map(|wc| wc.predicates)
366            .into_iter()
367            .flatten()
368            .chain(where_clause_env)
369            .collect::<Vec<_>>();
370        let (unspecified_ty_params, variants) = found_exprs.iter().enumerate().fold(
371            (vec![], vec![]),
372            |(mut ty_params, mut variants), (i, info)| {
373                if let Some(reft) = &info.reftype_ident {
374                    variants.push((
375                        info.variant_ident.clone(),
376                        parse_quote!(<#reft as #typeref_ident<#(#ty_generics),*>>::Type),
377                    ));
378                } else {
379                    let tp_ident =
380                        Ident::new(&format!("__Sumtype_TypeParam_{}", i), Span::call_site());
381                    variants.push((info.variant_ident.clone(), parse_quote!(#tp_ident)));
382                    ty_params.push(tp_ident);
383                }
384                (ty_params, variants)
385            },
386        );
387        if let (Some(info), true) = (
388            found_exprs
389                .iter()
390                .filter(|info| info.reftype_ident.is_none())
391                .next(),
392            type_emitted.len() > 0,
393        ) {
394            abort!(
395                &info.span,
396                r#"
397To emit full type, you should specify the type.
398Example: sumtype!(std::iter::empty(), std::iter::Empty<T>)
399"#
400            )
401        } else {
402            let replaced_ty_generics: Vec<_> = ty_generics
403                .iter()
404                .map(|ga| match ga {
405                    GenericArgument::Lifetime(lt) => quote!(& #lt ()),
406                    GenericArgument::Const(_) => quote!(),
407                    o => quote!(#o),
408                })
409                .collect();
410            let constraint_traits = (0..args.bounds.len())
411                .map(|n| {
412                    Ident::new(
413                        &format!("__Sumtype_ConstraintExprTrait_{}_{}", n, random()),
414                        Span::call_site(),
415                    )
416                })
417                .collect::<Vec<_>>();
418            let out = quote! {
419                #(for reft in &reftypes) {
420                    #[doc(hidden)]
421                    struct #reft;
422                }
423                #[doc(hidden)]
424                trait #typeref_ident <#(#impl_generics),*> { type Type; }
425                #[doc(hidden)]
426                #vis enum #enum_ident <
427                    #(#impl_generics),*
428                    #(if impl_generics.len() > 0 && unspecified_ty_params.len() > 0) { , }
429                    #(#unspecified_ty_params),*
430                > {
431                    #(for (ident, ty) in &variants) {
432                        #ident ( #ty ),
433                    }
434                    __Uninhabited(
435                        (
436                            ::core::convert::Infallible,
437                            #(::core::marker::PhantomData<#replaced_ty_generics>),*
438                        )
439                    ),
440                }
441                #[doc(hidden)]
442                trait #constraint_expr_trait_ident<#(#impl_generics),*> {}
443                impl<#(#impl_generics,)*__Sumtype_TypeParam> #constraint_expr_trait_ident<#(#ty_generics),*> for __Sumtype_TypeParam
444                where
445                    #(for t in &constraint_traits) {
446                        __Sumtype_TypeParam: #t<#(#ty_generics),*>,
447                    }
448                    #(#where_clause,)*
449                {}
450                #(for (trait_, constraint_trait) in args.bounds.iter().zip(&constraint_traits)) {
451                    #{ SumTypeImpl::Trait(trait_.clone()).gen(
452                        &path_of_ident(enum_ident.clone(), false),
453                        unspecified_ty_params.as_slice(),
454                        variants.as_slice(),
455                        impl_generics.clone(),
456                        ty_generics.clone(),
457                        where_clause.clone(),
458                        constraint_trait,
459                    ) }
460                }
461            };
462            (out, self)
463        }
464    }
465}
466
467const _: () = {
468    use syn::visit_mut::VisitMut;
    /// Syntax-tree visitor that expands inline `sumtype!` macros and records them.
    struct Visitor<'a> {
        // Path to the generated enum (possibly `super::`-prefixed).
        enum_path: &'a Path,
        // Path to the generated helper trait carrying each explicit variant type.
        typeref_path: &'a Path,
        // Path to the generated trait every expression value is checked against.
        constraint_expr_trait_path: &'a Path,
        // One entry per expanded expression-position macro, in visiting order.
        found_exprs: Vec<ExprMacroInfo>,
        // One entry per expanded type-position macro.
        emit_type: Vec<TypeMacroInfo>,
        // Generics of the enclosing item, if any.
        generics: Option<&'a Generics>,
        // Whether generated paths are `super::`-prefixed (processing a module).
        is_module: bool,
    }
478
479    impl ProcessTree for Block {
480        fn collect_inline_macro(
481            &mut self,
482            enum_path: &Path,
483            typeref_path: &Path,
484            constraint_expr_trait_path: &Path,
485            generics: Option<&Generics>,
486            is_module: bool,
487        ) -> (Vec<ExprMacroInfo>, Vec<TypeMacroInfo>) {
488            let mut visitor = Visitor::new(
489                enum_path,
490                typeref_path,
491                constraint_expr_trait_path,
492                generics,
493                is_module,
494            );
495            visitor.visit_block_mut(self);
496            (visitor.found_exprs, visitor.emit_type)
497        }
498    }
499
500    impl ProcessTree for Item {
501        fn collect_inline_macro(
502            &mut self,
503            enum_path: &Path,
504            typeref_path: &Path,
505            constraint_expr_trait_path: &Path,
506            generics: Option<&Generics>,
507            is_module: bool,
508        ) -> (Vec<ExprMacroInfo>, Vec<TypeMacroInfo>) {
509            let mut visitor = Visitor::new(
510                enum_path,
511                typeref_path,
512                constraint_expr_trait_path,
513                generics,
514                is_module,
515            );
516            visitor.visit_item_mut(self);
517            (visitor.found_exprs, visitor.emit_type)
518        }
519    }
520
521    impl ProcessTree for Stmt {
522        fn collect_inline_macro(
523            &mut self,
524            enum_path: &Path,
525            typeref_path: &Path,
526            constraint_expr_trait_path: &Path,
527            generics: Option<&Generics>,
528            is_module: bool,
529        ) -> (Vec<ExprMacroInfo>, Vec<TypeMacroInfo>) {
530            let mut visitor = Visitor::new(
531                enum_path,
532                typeref_path,
533                constraint_expr_trait_path,
534                generics,
535                is_module,
536            );
537            visitor.visit_stmt_mut(self);
538            (visitor.found_exprs, visitor.emit_type)
539        }
540    }
541
542    impl<'a> Visitor<'a> {
543        fn new(
544            enum_path: &'a Path,
545            typeref_path: &'a Path,
546            constraint_expr_trait_path: &'a Path,
547            generics: Option<&'a Generics>,
548            is_module: bool,
549        ) -> Self {
550            Self {
551                enum_path,
552                typeref_path,
553                constraint_expr_trait_path,
554                found_exprs: Vec::new(),
555                emit_type: Vec::new(),
556                generics,
557                is_module,
558            }
559        }
560        fn do_type_macro(&mut self, mac: &Macro) -> TokenStream {
561            #[derive(Parse)]
562            struct Arg {
563                #[call(Punctuated::parse_terminated)]
564                generic_args: Punctuated<GenericArgument, Token![,]>,
565            }
566            let arg: Arg = mac
567                .parse_body()
568                .unwrap_or_else(|e| abort!(e.span(), &format!("{}", &e)));
569            let ty_generics = merge_generic_args(
570                self.generics
571                    .iter()
572                    .map(|g| g.params.iter().cloned().map(generic_param_to_arg))
573                    .flatten(),
574                arg.generic_args.clone(),
575            )
576            .collect::<Vec<_>>();
577            self.emit_type.push(TypeMacroInfo {
578                _span: mac.span(),
579                generic_args: arg.generic_args,
580            });
581            quote! {
582                #{&self.enum_path}
583                #(if ty_generics.len() > 0){
584                    <#(#ty_generics),*>
585                }
586            }
587        }
588
589        fn analyze_lifetime_bounds(
590            &self,
591            generics: &Generics,
592            ty: &Type,
593        ) -> HashMap<Ident, HashSet<Lifetime>> {
594            struct LifetimeVisitor {
595                generic_lifetimes: HashSet<Lifetime>,
596                generic_params: HashSet<Ident>,
597                lifetime_stack: Vec<Lifetime>,
598                result: HashMap<Ident, HashSet<Lifetime>>,
599            }
600            use syn::visit::Visit;
601            impl<'ast> syn::visit::Visit<'ast> for LifetimeVisitor {
602                fn visit_type_reference(&mut self, i: &TypeReference) {
603                    if let Some(lt) = &i.lifetime {
604                        if self.generic_lifetimes.contains(lt) {
605                            self.lifetime_stack.push(lt.clone());
606                            syn::visit::visit_type_reference(self, i);
607                            self.lifetime_stack.pop();
608                            return;
609                        }
610                    }
611                    syn::visit::visit_type_reference(self, i);
612                }
613                fn visit_type_path(&mut self, i: &TypePath) {
614                    if i.qself.is_none() {
615                        if let Some(id) = i.path.get_ident() {
616                            if self.generic_params.contains(id) {
617                                self.result
618                                    .entry(id.clone())
619                                    .or_insert(HashSet::new())
620                                    .extend(self.lifetime_stack.clone());
621                            }
622                            return;
623                        }
624                    }
625                    syn::visit::visit_type_path(self, i);
626                }
627            }
628            let mut visitor = LifetimeVisitor {
629                generic_lifetimes: generics
630                    .params
631                    .iter()
632                    .filter_map(|p| {
633                        if let GenericParam::Lifetime(LifetimeParam { lifetime, .. }) = p {
634                            Some(lifetime.clone())
635                        } else {
636                            None
637                        }
638                    })
639                    .collect(),
640                generic_params: generics
641                    .params
642                    .iter()
643                    .filter_map(|p| {
644                        if let GenericParam::Type(TypeParam { ident, .. }) = p {
645                            Some(ident.clone())
646                        } else {
647                            None
648                        }
649                    })
650                    .collect(),
651                lifetime_stack: Vec::new(),
652                result: HashMap::new(),
653            };
654            visitor.visit_type(ty);
655            visitor.result
656        }
657
658        fn do_expr_macro(&mut self, mac: &Macro) -> TokenStream {
659            #[derive(Parse)]
660            struct Arg {
661                expr: Expr,
662                _comma_token: Option<Token![,]>,
663                _for_token: Option<Token![for]>,
664                #[prefix(Option<Token![<]>)]
665                #[postfix(Option<Token![>]>)]
666                #[parse_if(_for_token.is_some())]
667                #[call(Punctuated::parse_separated_nonempty)]
668                for_generics: Option<Punctuated<GenericParam, Token![,]>>,
669                #[parse_if(_comma_token.is_some())]
670                ty: Option<Type>,
671                #[parse_if(_comma_token.is_some())]
672                where_clause: Option<Option<WhereClause>>,
673            }
674            let arg: Arg = mac
675                .parse_body()
676                .unwrap_or_else(|e| abort!(e.span(), &format!("{}", &e)));
677            let n = self.found_exprs.len();
678            let variant_ident = Ident::new(&format!("__SumType_Variant_{}", n), Span::call_site());
679            let reftype_ident = Ident::new(
680                &format!("__SumType_RefType_{}_{}", random(), n),
681                Span::call_site(),
682            );
683            let reftype_path = path_of_ident(reftype_ident.clone(), self.is_module);
684            let id_fn_ident =
685                Ident::new(&format!("__sum_type_id_fn_{}", random()), Span::call_site());
686            let (mut impl_generics, _, where_clause) = split_for_impl(self.generics);
687            let analyzed =
688                if let (Some(generics), Some(ty)) = (self.generics.as_ref(), arg.ty.as_ref()) {
689                    self.analyze_lifetime_bounds(*generics, ty)
690                } else {
691                    HashMap::new()
692                };
693            let generics = Generics {
694                params: arg.for_generics.clone().unwrap_or(Default::default()),
695                where_clause: arg.where_clause.unwrap_or(Some(WhereClause {
696                    predicates: Punctuated::new(),
697                    where_token: Default::default(),
698                })),
699                ..Default::default()
700            };
701            for g in impl_generics.iter_mut() {
702                if let GenericParam::Type(TypeParam { ident, bounds, .. }) = g {
703                    if let Some(lts) = analyzed.get(ident) {
704                        for lt in lts {
705                            bounds.push(TypeParamBound::Lifetime(lt.clone().clone()));
706                        }
707                    }
708                }
709            }
710            let impl_generics =
711                merge_generic_params(impl_generics, generics.params.clone()).collect::<Vec<_>>();
712            let ty_generics = impl_generics
713                .iter()
714                .cloned()
715                .map(generic_param_to_arg)
716                .collect::<Vec<_>>();
717            let where_clause = generics
718                .where_clause
719                .clone()
720                .map(|wc| wc.predicates)
721                .into_iter()
722                .flatten()
723                .chain(where_clause)
724                .collect::<Vec<_>>();
725            self.found_exprs.push(ExprMacroInfo {
726                span: mac.span(),
727                variant_ident: variant_ident.clone(),
728                reftype_ident: arg.ty.as_ref().map(|_| reftype_ident.clone()),
729                analyzed_bounds: analyzed.clone(),
730                generics,
731            });
732            quote! {
733                {
734                    #(if let Some(ty) = &arg.ty){
735                        impl<#(#impl_generics,)*> #{&self.typeref_path} <#(#ty_generics),*> for #reftype_path
736                            #(if where_clause.len() > 0) {
737                                where #(#where_clause,)*
738                            }
739                        {
740                            type Type = #ty;
741                        }
742                    }
743                    fn #id_fn_ident<
744                        #(#impl_generics,)* __SumType_T: #{&self.constraint_expr_trait_path}<#(#ty_generics),*>
745                    >(t: __SumType_T) -> __SumType_T
746                    #(if where_clause.len() > 0) {
747                        where #(#where_clause,)*
748                    }
749                    { t }
750                    #id_fn_ident::<#(#ty_generics,)*_>(#{&self.enum_path}::#variant_ident(#{&arg.expr}))
751                }
752            }
753        }
754    }
755
756    impl<'a> VisitMut for Visitor<'a> {
757        fn visit_type_mut(&mut self, ty: &mut Type) {
758            if let Type::Macro(tm) = &*ty {
759                if tm.mac.path.is_ident("sumtype") {
760                    let out = self.do_type_macro(&tm.mac);
761                    *ty = parse2(out).unwrap();
762                    return;
763                }
764            }
765            syn::visit_mut::visit_type_mut(self, ty);
766        }
767
768        fn visit_expr_mut(&mut self, expr: &mut Expr) {
769            if let Expr::Macro(em) = &*expr {
770                if em.mac.path.is_ident("sumtype") {
771                    let out = self.do_expr_macro(&em.mac);
772                    *expr = parse2(out).unwrap();
773                    return;
774                }
775            }
776            syn::visit_mut::visit_expr_mut(self, expr);
777        }
778
779        fn visit_stmt_mut(&mut self, stmt: &mut Stmt) {
780            if let Stmt::Macro(sm) = &*stmt {
781                if sm.mac.path.is_ident("sumtype") {
782                    let out = self.do_expr_macro(&sm.mac);
783                    *stmt = parse2(out).unwrap();
784                    return;
785                }
786            }
787            syn::visit_mut::visit_stmt_mut(self, stmt);
788        }
789    }
790};
791
792fn inner(args: &Arguments, input: TokenStream) -> TokenStream {
793    let public = Visibility::Public(Default::default());
794    if let Ok(block) = parse2::<Block>(input.clone()) {
795        let (out, block) = block.emit_items(args, None, false, public);
796        quote! { #out #[allow(non_local_definitions)] #block }
797    } else if let Ok(item_trait) = parse2::<ItemTrait>(input.clone()) {
798        let generics = item_trait.generics.clone();
799        let vis = item_trait.vis.clone();
800        let (out, item) = Item::Trait(item_trait).emit_items(args, Some(&generics), false, vis);
801        quote! { #out #[allow(non_local_definitions)] #item }
802    } else if let Ok(item_impl) = parse2::<ItemImpl>(input.clone()) {
803        let generics = item_impl.generics.clone();
804        let (out, item) = Item::Impl(item_impl).emit_items(args, Some(&generics), false, public);
805        quote! { #out #[allow(non_local_definitions)] #item }
806    } else if let Ok(item_fn) = parse2::<ItemFn>(input.clone()) {
807        let generics = item_fn.sig.generics.clone();
808        let vis = item_fn.vis.clone();
809        let (out, item) = Item::Fn(item_fn).emit_items(args, Some(&generics), false, vis);
810        quote! { #out #[allow(non_local_definitions)] #item }
811    } else if let Ok(item_mod) = parse2::<ItemMod>(input.clone()) {
812        let (out, item) = Item::Mod(item_mod).emit_items(args, None, true, public);
813        quote! { #out #[allow(non_local_definitions)] #item }
814    } else if let Ok(item) = parse2::<Item>(input.clone()) {
815        let (out, item) = item.emit_items(args, None, false, public);
816        quote! { #out #[allow(non_local_definitions)] #item }
817    } else if let Ok(stmt) = parse2::<Stmt>(input.clone()) {
818        let (out, stmt) = stmt.emit_items(args, None, false, public);
819        quote! { #out #[allow(non_local_definitions)] #stmt }
820    } else {
821        abort!(input.span(), "This element is not supported")
822    }
823}
824
825fn process_supported_supertraits<'a>(
826    traits: impl IntoIterator<Item = &'a TypeParamBound>,
827    krate: &Path,
828) -> (Vec<Path>, Vec<Path>) {
829    let mut supertraits = Vec::new();
830    let mut derive_traits = Vec::new();
831    for tpb in traits.into_iter() {
832        if let TypeParamBound::Trait(tb) = tpb {
833            if let Some(ident) = tb.path.get_ident() {
834                match ident.to_string().as_str() {
835                    "Copy" | "Clone" | "Hash" | "Eq" => {
836                        supertraits.push(parse_quote!(#krate::traits::#ident))
837                    }
838                    "PartialEq" => derive_traits.push(parse_quote!(PartialEq)),
839                    o if o.starts_with("__SumTrait_Sealed") => (),
840                    _ => (),
841                }
842            } else {
843                supertraits.push(tb.path.clone())
844            }
845        } else {
846            abort!(tpb.span(), "Only path is supported");
847        }
848    }
849    (supertraits, derive_traits)
850}
851
852fn collect_typeref_types(input: &ItemTrait) -> Vec<Type> {
853    fn can_make_typeref_type(ty: &Type, generics: &Generics) -> bool {
854        use syn::visit::Visit;
855        struct Visitor(bool, Vec<Ident>, Vec<Ident>);
856        let generics_param_tys = generics
857            .params
858            .iter()
859            .filter_map(|p| {
860                if let GenericParam::Type(TypeParam { ident, .. }) = p {
861                    Some(ident.clone())
862                } else {
863                    None
864                }
865            })
866            .collect::<Vec<_>>();
867        let generics_param_vals = generics
868            .params
869            .iter()
870            .filter_map(|p| {
871                if let GenericParam::Const(ConstParam { ident, .. }) = p {
872                    Some(ident.clone())
873                } else {
874                    None
875                }
876            })
877            .collect::<Vec<_>>();
878        impl<'a> syn::visit::Visit<'a> for Visitor {
879            fn visit_type(&mut self, i: &Type) {
880                match i {
881                    Type::ImplTrait(_) | Type::Verbatim(_) | Type::Infer(_) | Type::Macro(_) => {
882                        self.0 = false;
883                    }
884                    Type::Reference(TypeReference { lifetime, .. }) => {
885                        if let Some(lifetime) = lifetime {
886                            if &lifetime.ident != "static" {
887                                self.0 = false;
888                            }
889                        } else {
890                            self.0 = false;
891                        }
892                    }
893                    Type::Path(tp) => {
894                        if tp.qself.is_none() && self.2.iter().any(|ident| tp.path.is_ident(ident))
895                            || (tp.path.segments.len() >= 1 && &tp.path.segments[0].ident == "Self")
896                        {
897                            self.0 = false;
898                        }
899                    }
900                    _ => (),
901                }
902                syn::visit::visit_type(self, i)
903            }
904            fn visit_expr(&mut self, i: &Expr) {
905                if let Expr::Path(ExprPath { qself, path, .. }) = i {
906                    if qself.is_none() && self.1.iter().any(|ident| path.is_ident(ident)) {
907                        self.0 = false;
908                    }
909                }
910            }
911        }
912        let mut visitor = Visitor(true, generics_param_tys, generics_param_vals);
913        visitor.visit_type(ty);
914        visitor.0
915    }
916    use syn::visit::Visit;
917    struct Visitor(Vec<Type>, Generics);
918    impl<'a> syn::visit::Visit<'a> for Visitor {
919        fn visit_type(&mut self, i: &Type) {
920            if can_make_typeref_type(i, &self.1) {
921                self.0.push(i.clone());
922            } else {
923                syn::visit::visit_type(self, i)
924            }
925        }
926    }
927    let mut visitor = Visitor(Vec::new(), input.generics.clone());
928    visitor.visit_item_trait(input);
929    visitor.0
930}
931
932fn sumtrait_impl(
933    args: Option<Path>,
934    marker_path: &Path,
935    krate: &Path,
936    input: ItemTrait,
937) -> TokenStream {
938    let (supertraits, derive_traits) = process_supported_supertraits(&input.supertraits, krate);
939    for item in &input.items {
940        match item {
941            TraitItem::Const(_) => abort!(item.span(), "associated const is not supported"),
942            TraitItem::Fn(tfn) => {
943                if tfn.sig.inputs.len() == 0 || !matches!(&tfn.sig.inputs[0], FnArg::Receiver(_)) {
944                    abort!(tfn.sig.span(), "requires receiver")
945                }
946            }
947            TraitItem::Type(tty) => {
948                if tty.default.is_some() {
949                    abort!(tty.span(), "associated type defaults is not supported")
950                }
951                if tty.generics.params.len() > 0 || tty.generics.where_clause.is_some() {
952                    abort!(
953                        tty.generics.span(),
954                        "generalized associated types is not supported"
955                    )
956                }
957            }
958            o => abort!(o.span(), "Not supported"),
959        }
960    }
961    let temporary_mac_name =
962        Ident::new(&format!("__sumtype_macro_{}", random()), Span::call_site());
963    let typeref_types = collect_typeref_types(&input);
964    let (_, _, where_clause) = input.generics.split_for_impl();
965    let typeref_id = random() as usize;
966    quote! {
967        #input
968        #(for (i, ty) in typeref_types.iter().enumerate()) {
969            impl<#(for p in &input.generics.params),{#p}> #krate::TypeRef<#typeref_id, #i> for #marker_path #where_clause {
970                type Type = #ty;
971            }
972        }
973
974        #[doc(hidden)]
975        #[macro_export]
976        macro_rules! #temporary_mac_name {
977            ($($t:tt)*) => {
978                #krate::_sumtrait_internal!(
979                    { $($t)* }
980                    /* typerefs= */  [#(#typeref_types),*],
981                    /* item_trait= */  {#input},
982                    /* typeref_id= */  #typeref_id,
983                    /* krate= */  #krate,
984                    /* marker_path= */ #marker_path,
985                    /* implementation= */  [#{args.map(|m| quote!(#m)).unwrap_or(quote!(_))}],
986                    /* supertraits= */ [#(#supertraits),*],
987                    /* derive_traits= */ [#(#derive_traits),*],
988                );
989            };
990        }
991        #[doc(hidden)]
992        #{&input.vis} use #temporary_mac_name as #{&input.ident};
993    }
994}
995
996#[doc(hidden)]
997#[proc_macro_error]
998#[proc_macro]
999#[proc_debug::proc_debug]
1000pub fn _sumtrait_internal(input: TokenStream1) -> TokenStream1 {
1001    sumtrait_internal::sumtrait_internal(input.into()).into()
1002}
1003
1004#[proc_macro_error]
1005#[proc_macro_attribute]
1006#[proc_debug::proc_debug]
1007pub fn sumtrait(attr: TokenStream1, input: TokenStream1) -> TokenStream1 {
1008    #[derive(FromMeta, Debug)]
1009    struct SumtraitArgs {
1010        implement: Option<Path>,
1011        krate: Option<Path>,
1012        marker: Path,
1013    }
1014    let args = SumtraitArgs::from_list(&NestedMeta::parse_meta_list(attr.into()).unwrap()).unwrap();
1015
1016    let krate = args.krate.unwrap_or(parse_quote!(::sumtype));
1017    sumtrait_impl(
1018        args.implement,
1019        &args.marker,
1020        &krate,
1021        parse(input).unwrap_or_else(|_| abort!(Span::call_site(), "Requires trait definition")),
1022    )
1023    .into()
1024}
1025
1026/// Enabling `sumtype!(..)` macro in the context.
1027///
1028/// For each context marked by `#[sumtype]`, sumtype makes an union type of several
1029/// [`std::iter::Iterator`] types. To intern an expression of `Iterator` into the union type, you
1030/// can use `sumtype!([expr])` syntax. This is an example of returning unified `Iterator`:
1031///
1032/// ```
1033/// # use sumtype::sumtype;
1034/// # use std::iter::Iterator;
1035/// #[sumtype]
1036/// fn return_iter(a: bool) -> impl Iterator<Item = ()> {
1037///     if a {
1038///         sumtype!(std::iter::once(()))
1039///     } else {
1040///         sumtype!(vec![()].into_iter())
1041///     }
1042/// }
1043/// ```
1044///
1045/// This function returns [`std::iter::Once`] or [`std::vec::IntoIter`] depends on `a` value. The
1046/// `#[sumtype]` system make annonymous union type that is also [`std::iter::Iterator`], and wrap
1047/// each `sumtype!(..)` expression with the union type. The mechanism is zerocost when `a` is fixed
1048/// in the compile process.
1049///
1050/// You can place exact (non-annonymous) type using `sumtype!()` macro in type context. In this
1051/// way, you should specify type using `sumtype!([expr], [type])` format like:
1052///
1053/// ```
1054/// # use sumtype::sumtype;
1055/// # use std::iter::Iterator;
1056/// #[sumtype]
1057/// fn return_iter_explicit(a: bool) -> sumtype!() {
1058///     if a {
1059///         sumtype!(std::iter::once(()), std::iter::Once<()>)
1060///     } else {
1061///         sumtype!(vec![()].into_iter(), std::vec::IntoIter<()>)
1062///     }
1063/// }
1064/// ```
1065///
1066/// You may need to use additional generic parameters in `[ty]`, the following example will be
1067/// useful:
1068///
1069/// ```
1070/// # use sumtype::sumtype;
1071/// # use std::iter::Iterator;
1072/// struct S;
1073///
1074/// #[sumtype]
1075/// impl S {
1076///     fn f<'a, T>(t: &'a T, count: usize) -> sumtype!['a, T] {
1077///         if count == 0 {
1078///             sumtype!(std::iter::empty(), for<'a, T: 'a> std::iter::Empty<&'a T>)
1079///         } else {
1080///             sumtype!(
1081///                 std::iter::repeat(t).take(count),
1082///                 for<'a, T: 'a> std::iter::Take<std::iter::Repeat<&'a T>>
1083///             )
1084///         }
1085///     }
1086/// }
1087/// ```
1088#[proc_macro_error]
1089#[proc_macro_attribute]
1090#[proc_debug::proc_debug]
1091pub fn sumtype(attr: TokenStream1, input: TokenStream1) -> TokenStream1 {
1092    inner(&parse_macro_input!(attr as Arguments), input.into()).into()
1093}