sumtype_macro/
lib.rs

1use darling::ast::NestedMeta;
2use darling::FromMeta;
3use derive_syn_parse::Parse;
4use proc_macro::TokenStream as TokenStream1;
5use proc_macro2::Span;
6use proc_macro2::TokenStream;
7use proc_macro_error::{abort, proc_macro_error};
8use std::collections::{HashMap, HashSet};
9use syn::punctuated::Punctuated;
10use syn::spanned::Spanned;
11use syn::*;
12use template_quote::{quote, ToTokens};
13use type_leak::*;
14
15mod sumtrait_internal;
16
/// Returns a pseudo-random `u64` used to make generated identifiers unique.
///
/// Every `RandomState::new()` is freshly keyed, so finishing an empty hasher
/// built from it yields a value that (in practice) differs from call to call.
/// Only collision avoidance between expansions matters here, not quality.
fn random() -> u64 {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};

    let state = RandomState::new();
    let hasher = state.build_hasher();
    hasher.finish()
}
23
24fn generic_param_to_arg(i: GenericParam) -> GenericArgument {
25    match i {
26        GenericParam::Lifetime(LifetimeParam { lifetime, .. }) => {
27            GenericArgument::Lifetime(lifetime)
28        }
29        GenericParam::Type(TypeParam { ident, .. }) => GenericArgument::Type(parse_quote!(#ident)),
30        GenericParam::Const(ConstParam { ident, .. }) => {
31            GenericArgument::Const(parse_quote!(#ident))
32        }
33    }
34}
35
36fn merge_generic_params(
37    args1: impl IntoIterator<Item = GenericParam, IntoIter: Clone>,
38    args2: impl IntoIterator<Item = GenericParam, IntoIter: Clone>,
39) -> impl Iterator<Item = GenericParam> {
40    let it1 = args1.into_iter();
41    let it2 = args2.into_iter();
42    it1.clone()
43        .filter(|arg| matches!(arg, GenericParam::Lifetime(_)))
44        .chain(
45            it2.clone()
46                .filter(|arg| matches!(arg, GenericParam::Lifetime(_))),
47        )
48        .chain(
49            it1.clone()
50                .filter(|arg| matches!(arg, GenericParam::Const(_))),
51        )
52        .chain(
53            it2.clone()
54                .filter(|arg| matches!(arg, GenericParam::Const(_))),
55        )
56        .chain(
57            it1.clone()
58                .filter(|arg| matches!(arg, GenericParam::Type(_))),
59        )
60        .chain(
61            it2.clone()
62                .filter(|arg| matches!(arg, GenericParam::Type(_))),
63        )
64}
65
66fn merge_generic_args(
67    args1: impl IntoIterator<Item = GenericArgument, IntoIter: Clone>,
68    args2: impl IntoIterator<Item = GenericArgument, IntoIter: Clone>,
69) -> impl Iterator<Item = GenericArgument> {
70    let it1 = args1.into_iter();
71    let it2 = args2.into_iter();
72    it1.clone()
73        .filter(|arg| matches!(arg, GenericArgument::Lifetime(_)))
74        .chain(
75            it2.clone()
76                .filter(|arg| matches!(arg, GenericArgument::Lifetime(_))),
77        )
78        .chain(
79            it1.clone()
80                .filter(|arg| matches!(arg, GenericArgument::Const(_))),
81        )
82        .chain(
83            it2.clone()
84                .filter(|arg| matches!(arg, GenericArgument::Const(_))),
85        )
86        .chain(
87            it1.clone()
88                .filter(|arg| matches!(arg, GenericArgument::Type(_))),
89        )
90        .chain(
91            it2.clone()
92                .filter(|arg| matches!(arg, GenericArgument::Type(_))),
93        )
94        .chain(it1.filter(|arg| {
95            matches!(
96                arg,
97                GenericArgument::AssocType(_)
98                    | GenericArgument::AssocConst(_)
99                    | GenericArgument::Constraint(_)
100            )
101        }))
102        .chain(it2.filter(|arg| {
103            matches!(
104                arg,
105                GenericArgument::AssocType(_)
106                    | GenericArgument::AssocConst(_)
107                    | GenericArgument::Constraint(_)
108            )
109        }))
110}
111
112fn path_of_ident(ident: Ident, is_super: bool) -> Path {
113    let mut segments = vec![];
114    if is_super {
115        segments.push(PathSegment {
116            ident: Ident::new("super", Span::call_site()),
117            arguments: PathArguments::None,
118        });
119    }
120    segments.push(PathSegment {
121        ident,
122        arguments: PathArguments::None,
123    });
124    Path {
125        leading_colon: None,
126        segments: segments.into_iter().collect(),
127    }
128}
129
130fn split_for_impl(
131    generics: Option<&Generics>,
132) -> (Vec<GenericParam>, Vec<GenericArgument>, Vec<WherePredicate>) {
133    if let Some(generics) = generics {
134        let (_, ty_generics, where_clause) = generics.split_for_impl();
135        let ty_generics: std::result::Result<AngleBracketedGenericArguments, _> =
136            parse2(ty_generics.into_token_stream());
137        (
138            generics.params.iter().cloned().collect(),
139            ty_generics
140                .map(|g| g.args.into_iter().collect())
141                .unwrap_or(vec![]),
142            where_clause
143                .map(|w| w.predicates.iter().cloned().collect())
144                .unwrap_or(vec![]),
145        )
146    } else {
147        (vec![], vec![], vec![])
148    }
149}
150
/// Arguments given to the macro: a `+`-separated list of trait paths.
///
/// For each path, `emit_items` emits an invocation of the macro of the same
/// path (`path!(...)`) to implement that trait for the generated enum.
#[derive(Parse)]
struct Arguments {
    // `parse_terminated` reads until the input ends; a trailing `+` is accepted.
    #[call(Punctuated::parse_terminated)]
    bounds: Punctuated<Path, Token![+]>,
}
156
/// How an implementation for the generated sum-type enum is produced.
enum SumTypeImpl {
    /// Delegate to the macro sharing this trait's path (`trait_path!(...)`).
    Trait(Path),
}
160
161impl SumTypeImpl {
162    #[allow(clippy::too_many_arguments)]
163    fn gen(
164        &self,
165        enum_path: &Path,
166        unspecified_ty_params: &[Ident],
167        variants: &[(Ident, Type)],
168        impl_generics: Vec<GenericParam>,
169        ty_generics: Vec<GenericArgument>,
170        where_clause: Vec<WherePredicate>,
171        constraint_expr_trait_ident: &Ident,
172    ) -> TokenStream {
173        match self {
174            SumTypeImpl::Trait(trait_path) => {
175                quote! {
176                    #trait_path!(
177                        /* constraint_expr_trait_ident = */ #constraint_expr_trait_ident,
178                        /* trait_path = */ #trait_path,
179                        /* enum_path = */ #enum_path,
180                        /* unspecified_ty_params = */ [#(#unspecified_ty_params),*],
181                        /* variants = */ [#(for (id, ty) in variants),{#id:#ty}],
182                        /* impl_generics_base = */ [ #(#impl_generics),* ],
183                        /* ty_generics_base = */ [#(#ty_generics),*],
184                        /* where_clause_base = */ { #(#where_clause),* },
185                    );
186                }
187            }
188        }
189    }
190}
191
/// What the visitor records for each `sumtype!` in expression or statement position.
struct ExprMacroInfo {
    /// Span of the macro invocation; used when reporting errors about it.
    span: Span,
    /// Enum variant generated for this expression (`__SumType_Variant_<n>`).
    variant_ident: Ident,
    /// Marker struct carrying the explicitly written type; `None` when the
    /// invocation gave no type.
    reftype_ident: Option<Ident>,
    /// Lifetimes each enclosing type parameter must outlive, as inferred from
    /// the explicit type by `analyze_lifetime_bounds`.
    analyzed_bounds: HashMap<Ident, HashSet<Lifetime>>,
    /// `for<...>` parameters and `where` clause of the invocation; `emit_items`
    /// requires these to be identical across all invocations it processes.
    generics: Generics,
}
199
/// What the visitor records for each `sumtype!` in type position.
struct TypeMacroInfo {
    /// Span of the macro invocation (currently unused).
    _span: Span,
    /// Generic arguments written in the invocation; `emit_items` checks their
    /// count and kinds against the expressions' `for<...>` parameters.
    generic_args: Punctuated<GenericArgument, Token![,]>,
}
204
205// Factory methods to process on supported tree elements
206trait ProcessTree: Sized {
207    // Collect macros in both type context and expr context. Replace macros with code.
208    fn collect_inline_macro(
209        &mut self,
210        enum_path: &Path,
211        typeref_path: &Path,
212        constraint_expr_trait_path: &Path,
213        generics: Option<&Generics>,
214        is_module: bool,
215    ) -> (Vec<ExprMacroInfo>, Vec<TypeMacroInfo>);
216
217    fn emit_items(
218        mut self,
219        args: &Arguments,
220        generics_env: Option<&Generics>,
221        is_module: bool,
222        vis: Visibility,
223    ) -> (TokenStream, Self) {
224        let r = random();
225        let enum_ident = Ident::new(&format!("__Sumtype_Enum_{}", r), Span::call_site());
226        let typeref_ident =
227            Ident::new(&format!("__Sumtype_TypeRef_Trait_{}", r), Span::call_site());
228        let constraint_expr_trait_ident = Ident::new(
229            &format!("__Sumtype_ConstraintExprTrait_{}", r),
230            Span::call_site(),
231        );
232        let (found_exprs, type_emitted) = self.collect_inline_macro(
233            &path_of_ident(enum_ident.clone(), is_module),
234            &path_of_ident(typeref_ident.clone(), is_module),
235            &path_of_ident(constraint_expr_trait_ident.clone(), is_module),
236            generics_env,
237            is_module,
238        );
239        let reftypes = found_exprs
240            .iter()
241            .filter_map(|info| info.reftype_ident.clone())
242            .collect::<Vec<_>>();
243        let (impl_generics_env, _, where_clause_env) = split_for_impl(generics_env);
244        if found_exprs.is_empty() {
245            abort!(Span::call_site(), "Cannot find any sumtype!() in expr");
246        }
247        let expr_generics_list = found_exprs.iter().fold(HashMap::new(), |mut acc, info| {
248            *acc.entry(info.generics.clone()).or_insert(0usize) += 1;
249            acc
250        });
251        if expr_generics_list.len() != 1 {
252            let mut expr_gparams = expr_generics_list.into_iter().collect::<Vec<_>>();
253            expr_gparams.sort_by_key(|item| item.1);
254            abort!(expr_gparams[0].0.span(), "Generic argument mismatch");
255        }
256        let expr_generics = expr_generics_list.into_iter().next().unwrap().0;
257        let mut analyzed = found_exprs.iter().fold(
258            HashMap::new(),
259            |mut acc: HashMap<Ident, HashSet<TypeParamBound>>, info| {
260                for (id, lts) in &info.analyzed_bounds {
261                    acc.entry(id.clone())
262                        .or_default()
263                        .extend(lts.iter().map(|lt| TypeParamBound::Lifetime(lt.clone())));
264                }
265                acc
266            },
267        );
268        if let Some(where_clause) = &expr_generics.where_clause {
269            for pred in &where_clause.predicates {
270                if let WherePredicate::Type(PredicateType {
271                    bounded_ty: Type::Path(path),
272                    bounds,
273                    ..
274                }) = pred
275                {
276                    if path.qself.is_none() {
277                        if let Some(id) = path.path.get_ident() {
278                            analyzed
279                                .entry(id.clone())
280                                .or_insert(HashSet::new())
281                                .extend(bounds.clone());
282                        }
283                    }
284                }
285            }
286        }
287        let expr_garg = expr_generics
288            .params
289            .iter()
290            .cloned()
291            .map(generic_param_to_arg)
292            .collect::<Vec<_>>();
293        for info in &type_emitted {
294            if info.generic_args.len() != expr_garg.len()
295                || !expr_garg.iter().zip(&info.generic_args).all(|two| {
296                    matches!(
297                        two,
298                        (GenericArgument::Lifetime(_), GenericArgument::Lifetime(_))
299                            | (GenericArgument::Const(_), GenericArgument::Const(_))
300                            | (GenericArgument::Type(_), GenericArgument::Type(_))
301                    )
302                })
303            {
304                abort!(
305                    info.generic_args.span(),
306                    "The generic arguments are incompatible with generic params in expression."
307                )
308            }
309        }
310        let mut impl_generics =
311            merge_generic_params(impl_generics_env, expr_generics.params).collect::<Vec<_>>();
312        for g in impl_generics.iter_mut() {
313            if let GenericParam::Type(TypeParam { ident, bounds, .. }) = g {
314                if let Some(bs) = analyzed.get(ident) {
315                    for b in bs {
316                        bounds.push(b.clone());
317                    }
318                }
319            }
320        }
321        let ty_generics = impl_generics
322            .iter()
323            .cloned()
324            .map(generic_param_to_arg)
325            .collect::<Vec<_>>();
326        let where_clause = expr_generics
327            .where_clause
328            .clone()
329            .map(|wc| wc.predicates)
330            .into_iter()
331            .flatten()
332            .chain(where_clause_env)
333            .collect::<Vec<_>>();
334        let (unspecified_ty_params, variants) = found_exprs.iter().enumerate().fold(
335            (vec![], vec![]),
336            |(mut ty_params, mut variants), (i, info)| {
337                if let Some(reft) = &info.reftype_ident {
338                    variants.push((
339                        info.variant_ident.clone(),
340                        parse_quote!(<#reft as #typeref_ident<#(#ty_generics),*>>::Type),
341                    ));
342                } else {
343                    let tp_ident =
344                        Ident::new(&format!("__Sumtype_TypeParam_{}", i), Span::call_site());
345                    variants.push((info.variant_ident.clone(), parse_quote!(#tp_ident)));
346                    ty_params.push(tp_ident);
347                }
348                (ty_params, variants)
349            },
350        );
351        if let (Some(info), true) = (
352            found_exprs.iter().find(|info| info.reftype_ident.is_none()),
353            !type_emitted.is_empty(),
354        ) {
355            abort!(
356                &info.span,
357                r#"
358To emit full type, you should specify the type.
359Example: sumtype!(std::iter::empty(), std::iter::Empty<T>)
360"#
361            )
362        } else {
363            let replaced_ty_generics: Vec<_> = ty_generics
364                .iter()
365                .map(|ga| match ga {
366                    GenericArgument::Lifetime(lt) => quote!(& #lt ()),
367                    GenericArgument::Const(_) => quote!(),
368                    o => quote!(#o),
369                })
370                .collect();
371            let constraint_traits = (0..args.bounds.len())
372                .map(|n| {
373                    Ident::new(
374                        &format!("__Sumtype_ConstraintExprTrait_{}_{}", n, random()),
375                        Span::call_site(),
376                    )
377                })
378                .collect::<Vec<_>>();
379            let out = quote! {
380                #(for reft in &reftypes) {
381                    #[doc(hidden)]
382                    #[allow(non_camel_case_types)]
383                    #[allow(non_camel_case_types)]
384                    struct #reft;
385                }
386                #[doc(hidden)]
387                #[allow(non_camel_case_types)]
388                trait #typeref_ident <#(#impl_generics),*> { type Type; }
389                #[doc(hidden)]
390                #[allow(non_camel_case_types)]
391                #vis enum #enum_ident <
392                    #(#impl_generics),*
393                    #(if !impl_generics.is_empty() && !unspecified_ty_params.is_empty()) { , }
394                    #(#unspecified_ty_params),*
395                > {
396                    #(for (ident, ty) in &variants) {
397                        #ident ( #ty ),
398                    }
399                    __Uninhabited(
400                        (
401                            ::core::convert::Infallible,
402                            #(::core::marker::PhantomData<#replaced_ty_generics>),*
403                        )
404                    ),
405                }
406                #[doc(hidden)]
407                #[allow(non_camel_case_types)]
408                trait #constraint_expr_trait_ident<#(#impl_generics),*> {}
409                impl<#(#impl_generics,)*__Sumtype_TypeParam> #constraint_expr_trait_ident<#(#ty_generics),*> for __Sumtype_TypeParam
410                where
411                    #(for t in &constraint_traits) {
412                        __Sumtype_TypeParam: #t<#(#ty_generics),*>,
413                    }
414                    #(#where_clause,)*
415                {}
416                #(for (trait_, constraint_trait) in args.bounds.iter().zip(&constraint_traits)) {
417                    #{ SumTypeImpl::Trait(trait_.clone()).gen(
418                        &path_of_ident(enum_ident.clone(), false),
419                        unspecified_ty_params.as_slice(),
420                        variants.as_slice(),
421                        impl_generics.clone(),
422                        ty_generics.clone(),
423                        where_clause.clone(),
424                        constraint_trait,
425                    ) }
426                }
427            };
428            (out, self)
429        }
430    }
431}
432
433const _: () = {
434    use syn::visit_mut::VisitMut;
    /// Syntax-tree visitor that expands `sumtype!` invocations in place and
    /// records what it found.
    struct Visitor<'a> {
        /// Path the replacement code uses to name the generated enum.
        enum_path: &'a Path,
        /// Path of the generated type-reference trait.
        typeref_path: &'a Path,
        /// Path of the generated constraint trait.
        constraint_expr_trait_path: &'a Path,
        /// One entry per expression/statement `sumtype!`, in visiting order.
        found_exprs: Vec<ExprMacroInfo>,
        /// One entry per type `sumtype!`.
        emit_type: Vec<TypeMacroInfo>,
        /// Generics of the enclosing item, if any.
        generics: Option<&'a Generics>,
        /// Whether generated marker structs are referenced through `super::`.
        is_module: bool,
    }
441        generics: Option<&'a Generics>,
442        is_module: bool,
443    }
444
445    impl ProcessTree for Block {
446        fn collect_inline_macro(
447            &mut self,
448            enum_path: &Path,
449            typeref_path: &Path,
450            constraint_expr_trait_path: &Path,
451            generics: Option<&Generics>,
452            is_module: bool,
453        ) -> (Vec<ExprMacroInfo>, Vec<TypeMacroInfo>) {
454            let mut visitor = Visitor::new(
455                enum_path,
456                typeref_path,
457                constraint_expr_trait_path,
458                generics,
459                is_module,
460            );
461            visitor.visit_block_mut(self);
462            (visitor.found_exprs, visitor.emit_type)
463        }
464    }
465
466    impl ProcessTree for Item {
467        fn collect_inline_macro(
468            &mut self,
469            enum_path: &Path,
470            typeref_path: &Path,
471            constraint_expr_trait_path: &Path,
472            generics: Option<&Generics>,
473            is_module: bool,
474        ) -> (Vec<ExprMacroInfo>, Vec<TypeMacroInfo>) {
475            let mut visitor = Visitor::new(
476                enum_path,
477                typeref_path,
478                constraint_expr_trait_path,
479                generics,
480                is_module,
481            );
482            visitor.visit_item_mut(self);
483            (visitor.found_exprs, visitor.emit_type)
484        }
485    }
486
487    impl ProcessTree for Stmt {
488        fn collect_inline_macro(
489            &mut self,
490            enum_path: &Path,
491            typeref_path: &Path,
492            constraint_expr_trait_path: &Path,
493            generics: Option<&Generics>,
494            is_module: bool,
495        ) -> (Vec<ExprMacroInfo>, Vec<TypeMacroInfo>) {
496            let mut visitor = Visitor::new(
497                enum_path,
498                typeref_path,
499                constraint_expr_trait_path,
500                generics,
501                is_module,
502            );
503            visitor.visit_stmt_mut(self);
504            (visitor.found_exprs, visitor.emit_type)
505        }
506    }
507
508    impl<'a> Visitor<'a> {
509        fn new(
510            enum_path: &'a Path,
511            typeref_path: &'a Path,
512            constraint_expr_trait_path: &'a Path,
513            generics: Option<&'a Generics>,
514            is_module: bool,
515        ) -> Self {
516            Self {
517                enum_path,
518                typeref_path,
519                constraint_expr_trait_path,
520                found_exprs: Vec::new(),
521                emit_type: Vec::new(),
522                generics,
523                is_module,
524            }
525        }
526        fn do_type_macro(&mut self, mac: &Macro) -> TokenStream {
527            #[derive(Parse)]
528            struct Arg {
529                #[call(Punctuated::parse_terminated)]
530                generic_args: Punctuated<GenericArgument, Token![,]>,
531            }
532            let arg: Arg = mac
533                .parse_body()
534                .unwrap_or_else(|e| abort!(e.span(), &format!("{}", &e)));
535            let ty_generics = merge_generic_args(
536                self.generics
537                    .iter()
538                    .flat_map(|g| g.params.iter().cloned().map(generic_param_to_arg)),
539                arg.generic_args.clone(),
540            )
541            .collect::<Vec<_>>();
542            self.emit_type.push(TypeMacroInfo {
543                _span: mac.span(),
544                generic_args: arg.generic_args,
545            });
546            quote! {
547                #{&self.enum_path}
548                #(if !ty_generics.is_empty()){
549                    <#(#ty_generics),*>
550                }
551            }
552        }
553
554        fn analyze_lifetime_bounds(
555            &self,
556            generics: &Generics,
557            ty: &Type,
558        ) -> HashMap<Ident, HashSet<Lifetime>> {
559            struct LifetimeVisitor {
560                generic_lifetimes: HashSet<Lifetime>,
561                generic_params: HashSet<Ident>,
562                lifetime_stack: Vec<Lifetime>,
563                result: HashMap<Ident, HashSet<Lifetime>>,
564            }
565            use syn::visit::Visit;
566            impl syn::visit::Visit<'_> for LifetimeVisitor {
567                fn visit_type_reference(&mut self, i: &TypeReference) {
568                    if let Some(lt) = &i.lifetime {
569                        if self.generic_lifetimes.contains(lt) {
570                            self.lifetime_stack.push(lt.clone());
571                            syn::visit::visit_type_reference(self, i);
572                            self.lifetime_stack.pop();
573                            return;
574                        }
575                    }
576                    syn::visit::visit_type_reference(self, i);
577                }
578                fn visit_type_path(&mut self, i: &TypePath) {
579                    if i.qself.is_none() {
580                        if let Some(id) = i.path.get_ident() {
581                            if self.generic_params.contains(id) {
582                                self.result
583                                    .entry(id.clone())
584                                    .or_default()
585                                    .extend(self.lifetime_stack.clone());
586                            }
587                            return;
588                        }
589                    }
590                    syn::visit::visit_type_path(self, i);
591                }
592            }
593            let mut visitor = LifetimeVisitor {
594                generic_lifetimes: generics
595                    .params
596                    .iter()
597                    .filter_map(|p| {
598                        if let GenericParam::Lifetime(LifetimeParam { lifetime, .. }) = p {
599                            Some(lifetime.clone())
600                        } else {
601                            None
602                        }
603                    })
604                    .collect(),
605                generic_params: generics
606                    .params
607                    .iter()
608                    .filter_map(|p| {
609                        if let GenericParam::Type(TypeParam { ident, .. }) = p {
610                            Some(ident.clone())
611                        } else {
612                            None
613                        }
614                    })
615                    .collect(),
616                lifetime_stack: Vec::new(),
617                result: HashMap::new(),
618            };
619            visitor.visit_type(ty);
620            visitor.result
621        }
622
623        fn do_expr_macro(&mut self, mac: &Macro) -> TokenStream {
624            #[derive(Parse)]
625            struct Arg {
626                expr: Expr,
627                _comma_token: Option<Token![,]>,
628                _for_token: Option<Token![for]>,
629                #[prefix(Option<Token![<]>)]
630                #[postfix(Option<Token![>]>)]
631                #[parse_if(_for_token.is_some())]
632                #[call(Punctuated::parse_separated_nonempty)]
633                for_generics: Option<Punctuated<GenericParam, Token![,]>>,
634                #[parse_if(_comma_token.is_some())]
635                ty: Option<Type>,
636                #[parse_if(_comma_token.is_some())]
637                where_clause: Option<Option<WhereClause>>,
638            }
639            let arg: Arg = mac
640                .parse_body()
641                .unwrap_or_else(|e| abort!(e.span(), &format!("{}", &e)));
642            let n = self.found_exprs.len();
643            let variant_ident = Ident::new(&format!("__SumType_Variant_{}", n), Span::call_site());
644            let reftype_ident = Ident::new(
645                &format!("__SumType_RefType_{}_{}", random(), n),
646                Span::call_site(),
647            );
648            let reftype_path = path_of_ident(reftype_ident.clone(), self.is_module);
649            let id_fn_ident =
650                Ident::new(&format!("__sum_type_id_fn_{}", random()), Span::call_site());
651            let (mut impl_generics, _, where_clause) = split_for_impl(self.generics);
652            let analyzed =
653                if let (Some(generics), Some(ty)) = (self.generics.as_ref(), arg.ty.as_ref()) {
654                    self.analyze_lifetime_bounds(generics, ty)
655                } else {
656                    HashMap::new()
657                };
658            let generics = Generics {
659                params: arg.for_generics.clone().unwrap_or_default(),
660                where_clause: arg.where_clause.unwrap_or(Some(WhereClause {
661                    predicates: Punctuated::new(),
662                    where_token: Default::default(),
663                })),
664                ..Default::default()
665            };
666            for g in impl_generics.iter_mut() {
667                if let GenericParam::Type(TypeParam { ident, bounds, .. }) = g {
668                    if let Some(lts) = analyzed.get(ident) {
669                        for lt in lts {
670                            bounds.push(TypeParamBound::Lifetime(lt.clone().clone()));
671                        }
672                    }
673                }
674            }
675            let impl_generics =
676                merge_generic_params(impl_generics, generics.params.clone()).collect::<Vec<_>>();
677            let ty_generics = impl_generics
678                .iter()
679                .cloned()
680                .map(generic_param_to_arg)
681                .collect::<Vec<_>>();
682            let where_clause = generics
683                .where_clause
684                .clone()
685                .map(|wc| wc.predicates)
686                .into_iter()
687                .flatten()
688                .chain(where_clause)
689                .collect::<Vec<_>>();
690            self.found_exprs.push(ExprMacroInfo {
691                span: mac.span(),
692                variant_ident: variant_ident.clone(),
693                reftype_ident: arg.ty.as_ref().map(|_| reftype_ident.clone()),
694                analyzed_bounds: analyzed.clone(),
695                generics,
696            });
697            quote! {
698                {
699                    #(if let Some(ty) = &arg.ty){
700                        impl<#(#impl_generics,)*> #{&self.typeref_path} <#(#ty_generics),*> for #reftype_path
701                            #(if !where_clause.is_empty()) {
702                                where #(#where_clause,)*
703                            }
704                        {
705                            type Type = #ty;
706                        }
707                    }
708                    fn #id_fn_ident<
709                        #(#impl_generics,)* __SumType_T: #{&self.constraint_expr_trait_path}<#(#ty_generics),*>
710                    >(t: __SumType_T) -> __SumType_T
711                    #(if !where_clause.is_empty()) {
712                        where #(#where_clause,)*
713                    }
714                    { t }
715                    #id_fn_ident::<#(#ty_generics,)*_>(#{&self.enum_path}::#variant_ident(#{&arg.expr}))
716                }
717            }
718        }
719    }
720
721    impl VisitMut for Visitor<'_> {
722        fn visit_type_mut(&mut self, ty: &mut Type) {
723            if let Type::Macro(tm) = &*ty {
724                if tm.mac.path.is_ident("sumtype") {
725                    let out = self.do_type_macro(&tm.mac);
726                    *ty = parse2(out).unwrap();
727                    return;
728                }
729            }
730            syn::visit_mut::visit_type_mut(self, ty);
731        }
732
733        fn visit_expr_mut(&mut self, expr: &mut Expr) {
734            if let Expr::Macro(em) = &*expr {
735                if em.mac.path.is_ident("sumtype") {
736                    let out = self.do_expr_macro(&em.mac);
737                    *expr = parse2(out).unwrap();
738                    return;
739                }
740            }
741            syn::visit_mut::visit_expr_mut(self, expr);
742        }
743
744        fn visit_stmt_mut(&mut self, stmt: &mut Stmt) {
745            if let Stmt::Macro(sm) = &*stmt {
746                if sm.mac.path.is_ident("sumtype") {
747                    let out = self.do_expr_macro(&sm.mac);
748                    *stmt = parse2(out).unwrap();
749                    return;
750                }
751            }
752            syn::visit_mut::visit_stmt_mut(self, stmt);
753        }
754    }
755};
756
757fn inner(args: &Arguments, input: TokenStream) -> TokenStream {
758    let public = Visibility::Public(Default::default());
759    if let Ok(block) = parse2::<Block>(input.clone()) {
760        let (out, block) = block.emit_items(args, None, false, public);
761        quote! { #out #[allow(non_local_definitions)] #block }
762    } else if let Ok(item_trait) = parse2::<ItemTrait>(input.clone()) {
763        let generics = item_trait.generics.clone();
764        let vis = item_trait.vis.clone();
765        let (out, item) = Item::Trait(item_trait).emit_items(args, Some(&generics), false, vis);
766        quote! { #out #[allow(non_local_definitions)] #item }
767    } else if let Ok(item_impl) = parse2::<ItemImpl>(input.clone()) {
768        let generics = item_impl.generics.clone();
769        let (out, item) = Item::Impl(item_impl).emit_items(args, Some(&generics), false, public);
770        quote! { #out #[allow(non_local_definitions)] #item }
771    } else if let Ok(item_fn) = parse2::<ItemFn>(input.clone()) {
772        let generics = item_fn.sig.generics.clone();
773        let vis = item_fn.vis.clone();
774        let (out, item) = Item::Fn(item_fn).emit_items(args, Some(&generics), false, vis);
775        quote! { #out #[allow(non_local_definitions)] #item }
776    } else if let Ok(item_mod) = parse2::<ItemMod>(input.clone()) {
777        let (out, item) = Item::Mod(item_mod).emit_items(args, None, true, public);
778        quote! { #out #[allow(non_local_definitions)] #item }
779    } else if let Ok(item) = parse2::<Item>(input.clone()) {
780        let (out, item) = item.emit_items(args, None, false, public);
781        quote! { #out #[allow(non_local_definitions)] #item }
782    } else if let Ok(stmt) = parse2::<Stmt>(input.clone()) {
783        let (out, stmt) = stmt.emit_items(args, None, false, public);
784        quote! { #out #[allow(non_local_definitions)] #stmt }
785    } else {
786        abort!(input.span(), "This element is not supported")
787    }
788}
789
790fn process_supported_supertraits<'a>(
791    traits: impl IntoIterator<Item = &'a TypeParamBound>,
792    krate: &Path,
793) -> (Vec<Path>, Vec<Path>) {
794    let mut supertraits = Vec::new();
795    let mut derive_traits = Vec::new();
796    for tpb in traits.into_iter() {
797        if let TypeParamBound::Trait(tb) = tpb {
798            if let Some(ident) = tb.path.get_ident() {
799                match ident.to_string().as_str() {
800                    "Copy" | "Clone" | "Hash" | "Eq" => {
801                        supertraits.push(parse_quote!(#krate::traits::#ident))
802                    }
803                    "PartialEq" => derive_traits.push(parse_quote!(PartialEq)),
804                    o if o.starts_with("__SumTrait_Sealed") => (),
805                    _ => (),
806                }
807            } else {
808                supertraits.push(tb.path.clone())
809            }
810        } else {
811            abort!(tpb.span(), "Only path is supported");
812        }
813    }
814    (supertraits, derive_traits)
815}
816
817fn collect_typeref_types(input: &ItemTrait) -> Vec<Type> {
818    let mut leaker = Leaker::from_trait(input)
819        .unwrap_or_else(|_| Leaker::with_generics(input.generics.clone()));
820    leaker.self_ty_can_be_interned = false;
821    leaker
822        .finish()
823        .iter()
824        .cloned()
825        .collect()
826}
827
828fn sumtrait_impl(
829    args: Option<Path>,
830    marker_path: &Path,
831    krate: &Path,
832    input: ItemTrait,
833) -> TokenStream {
834    let (supertraits, derive_traits) = process_supported_supertraits(&input.supertraits, krate);
835    for item in &input.items {
836        match item {
837            TraitItem::Const(_) => abort!(item.span(), "associated const is not supported"),
838            TraitItem::Fn(tfn) => {
839                if tfn.sig.inputs.is_empty() || !matches!(&tfn.sig.inputs[0], FnArg::Receiver(_)) {
840                    abort!(tfn.sig.span(), "requires receiver")
841                }
842            }
843            TraitItem::Type(tty) => {
844                if tty.default.is_some() {
845                    abort!(tty.span(), "associated type defaults is not supported")
846                }
847                if !tty.generics.params.is_empty() || tty.generics.where_clause.is_some() {
848                    abort!(
849                        tty.generics.span(),
850                        "generalized associated types is not supported"
851                    )
852                }
853            }
854            o => abort!(o.span(), "Not supported"),
855        }
856    }
857    let temporary_mac_name =
858        Ident::new(&format!("__sumtype_macro_{}", random()), Span::call_site());
859    let typeref_types = collect_typeref_types(&input);
860    let (_, _, where_clause) = input.generics.split_for_impl();
861    let typeref_id = random() as usize;
862    quote! {
863        #input
864        #(for (i, ty) in typeref_types.iter().enumerate()) {
865            impl<#(for p in &input.generics.params),{#p}> #krate::TypeRef<#typeref_id, #i> for #marker_path #where_clause {
866                type Type = #ty;
867            }
868        }
869
870        #[doc(hidden)]
871        #[macro_export]
872        macro_rules! #temporary_mac_name {
873            ($($t:tt)*) => {
874                #krate::_sumtrait_internal!(
875                    { $($t)* }
876                    /* typerefs= */  [#(#typeref_types),*],
877                    /* item_trait= */  {#input},
878                    /* typeref_id= */  #typeref_id,
879                    /* krate= */  #krate,
880                    /* marker_path= */ #marker_path,
881                    /* implementation= */  [#{args.map(|m| quote!(#m)).unwrap_or(quote!(_))}],
882                    /* supertraits= */ [#(#supertraits),*],
883                    /* derive_traits= */ [#(#derive_traits),*],
884                );
885            };
886        }
887        #[doc(hidden)]
888        #{&input.vis} use #temporary_mac_name as #{&input.ident};
889    }
890}
891
892#[doc(hidden)]
893#[proc_macro_error]
894#[proc_macro]
895pub fn _sumtrait_internal(input: TokenStream1) -> TokenStream1 {
896    sumtrait_internal::sumtrait_internal(input.into()).into()
897}
898
899#[proc_macro_error]
900#[proc_macro_attribute]
901pub fn sumtrait(attr: TokenStream1, input: TokenStream1) -> TokenStream1 {
902    #[derive(FromMeta, Debug)]
903    struct SumtraitArgs {
904        implement: Option<Path>,
905        krate: Option<Path>,
906        marker: Path,
907    }
908    let args = SumtraitArgs::from_list(&NestedMeta::parse_meta_list(attr.into()).unwrap()).unwrap();
909
910    let krate = args.krate.unwrap_or(parse_quote!(::sumtype));
911    sumtrait_impl(
912        args.implement,
913        &args.marker,
914        &krate,
915        parse(input).unwrap_or_else(|_| abort!(Span::call_site(), "Requires trait definition")),
916    )
917    .into()
918}
919
/// Enables the `sumtype!(..)` macro within the annotated context.
///
/// For each context marked with `#[sumtype]`, sumtype creates a union type of several
/// [`std::iter::Iterator`] types. To wrap an expression whose type implements `Iterator` into the
/// union type, use the `sumtype!([expr])` syntax. This example returns a unified `Iterator`:
925///
926/// ```
927/// # use sumtype::sumtype;
928/// # use std::iter::Iterator;
929/// #[sumtype]
930/// fn return_iter(a: bool) -> impl Iterator<Item = ()> {
931///     if a {
932///         sumtype!(std::iter::once(()))
933///     } else {
934///         sumtype!(vec![()].into_iter())
935///     }
936/// }
937/// ```
938///
/// This function returns [`std::iter::Once`] or [`std::vec::IntoIter`] depending on the value of
/// `a`. The `#[sumtype]` system creates an anonymous union type that also implements
/// [`std::iter::Iterator`], and wraps each `sumtype!(..)` expression in that union type. The
/// mechanism is zero-cost when `a` is fixed at compile time.
943///
/// You can also name the union type itself (instead of an anonymous type) by using the
/// `sumtype!()` macro in type position. In that case, each expression should also state its own
/// type, using the `sumtype!([expr], [type])` form:
946///
947/// ```
948/// # use sumtype::sumtype;
949/// # use std::iter::Iterator;
950/// #[sumtype]
951/// fn return_iter_explicit(a: bool) -> sumtype!() {
952///     if a {
953///         sumtype!(std::iter::once(()), std::iter::Once<()>)
954///     } else {
955///         sumtype!(vec![()].into_iter(), std::vec::IntoIter<()>)
956///     }
957/// }
958/// ```
959///
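/// The `[type]` argument may need to refer to generic parameters of the enclosing item. List those
/// parameters in the type-position macro (`sumtype!['a, T]`) and bind them with a `for<..>` prefix
/// in each `[type]`, as the following example demonstrates: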
962///
963/// ```
964/// # use sumtype::sumtype;
965/// # use std::iter::Iterator;
966/// struct S;
967///
968/// #[sumtype]
969/// impl S {
970///     fn f<'a, T>(t: &'a T, count: usize) -> sumtype!['a, T] {
971///         if count == 0 {
972///             sumtype!(std::iter::empty(), for<'a, T: 'a> std::iter::Empty<&'a T>)
973///         } else {
974///             sumtype!(
975///                 std::iter::repeat(t).take(count),
976///                 for<'a, T: 'a> std::iter::Take<std::iter::Repeat<&'a T>>
977///             )
978///         }
979///     }
980/// }
981/// ```
982#[proc_macro_error]
983#[proc_macro_attribute]
984pub fn sumtype(attr: TokenStream1, input: TokenStream1) -> TokenStream1 {
985    inner(&parse_macro_input!(attr as Arguments), input.into()).into()
986}