1use darling::ast::NestedMeta;
2use darling::FromMeta;
3use derive_syn_parse::Parse;
4use proc_macro::TokenStream as TokenStream1;
5use proc_macro2::Span;
6use proc_macro2::TokenStream;
7use proc_macro_error::{abort, proc_macro_error};
8use std::collections::{HashMap, HashSet};
9use syn::punctuated::Punctuated;
10use syn::spanned::Spanned;
11use syn::*;
12use template_quote::{quote, ToTokens};
13
14mod sumtrait_internal;
15
/// Returns a pseudo-random `u64`, used as a suffix for generated identifiers.
///
/// Every call creates a fresh `RandomState` and finishes a hasher that was never fed
/// any data. The value is therefore only as random as the hasher keys. This is not a
/// cryptographic source.
fn random() -> u64 {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    let state = RandomState::new();
    let hasher = state.build_hasher();
    hasher.finish()
}
22
23fn generic_param_to_arg(i: GenericParam) -> GenericArgument {
24 match i {
25 GenericParam::Lifetime(LifetimeParam { lifetime, .. }) => {
26 GenericArgument::Lifetime(lifetime)
27 }
28 GenericParam::Type(TypeParam { ident, .. }) => GenericArgument::Type(parse_quote!(#ident)),
29 GenericParam::Const(ConstParam { ident, .. }) => {
30 GenericArgument::Const(parse_quote!(#ident))
31 }
32 }
33}
34
35fn merge_generic_params(
36 args1: impl IntoIterator<Item = GenericParam, IntoIter: Clone>,
37 args2: impl IntoIterator<Item = GenericParam, IntoIter: Clone>,
38) -> impl Iterator<Item = GenericParam> {
39 let it1 = args1.into_iter();
40 let it2 = args2.into_iter();
41 it1.clone()
42 .filter(|arg| {
43 if let GenericParam::Lifetime(_) = arg {
44 true
45 } else {
46 false
47 }
48 })
49 .chain(it2.clone().filter(|arg| {
50 if let GenericParam::Lifetime(_) = arg {
51 true
52 } else {
53 false
54 }
55 }))
56 .chain(it1.clone().filter(|arg| {
57 if let GenericParam::Const(_) = arg {
58 true
59 } else {
60 false
61 }
62 }))
63 .chain(it2.clone().filter(|arg| {
64 if let GenericParam::Const(_) = arg {
65 true
66 } else {
67 false
68 }
69 }))
70 .chain(it1.clone().filter(|arg| {
71 if let GenericParam::Type(_) = arg {
72 true
73 } else {
74 false
75 }
76 }))
77 .chain(it2.clone().filter(|arg| {
78 if let GenericParam::Type(_) = arg {
79 true
80 } else {
81 false
82 }
83 }))
84}
85
86fn merge_generic_args(
87 args1: impl IntoIterator<Item = GenericArgument, IntoIter: Clone>,
88 args2: impl IntoIterator<Item = GenericArgument, IntoIter: Clone>,
89) -> impl Iterator<Item = GenericArgument> {
90 let it1 = args1.into_iter();
91 let it2 = args2.into_iter();
92 it1.clone()
93 .filter(|arg| {
94 if let GenericArgument::Lifetime(_) = arg {
95 true
96 } else {
97 false
98 }
99 })
100 .chain(it2.clone().filter(|arg| {
101 if let GenericArgument::Lifetime(_) = arg {
102 true
103 } else {
104 false
105 }
106 }))
107 .chain(it1.clone().filter(|arg| {
108 if let GenericArgument::Const(_) = arg {
109 true
110 } else {
111 false
112 }
113 }))
114 .chain(it2.clone().filter(|arg| {
115 if let GenericArgument::Const(_) = arg {
116 true
117 } else {
118 false
119 }
120 }))
121 .chain(it1.clone().filter(|arg| {
122 if let GenericArgument::Type(_) = arg {
123 true
124 } else {
125 false
126 }
127 }))
128 .chain(it2.clone().filter(|arg| {
129 if let GenericArgument::Type(_) = arg {
130 true
131 } else {
132 false
133 }
134 }))
135 .chain(it1.filter(|arg| match arg {
136 GenericArgument::AssocType(_)
137 | GenericArgument::AssocConst(_)
138 | GenericArgument::Constraint(_) => true,
139 _ => false,
140 }))
141 .chain(it2.filter(|arg| match arg {
142 GenericArgument::AssocType(_)
143 | GenericArgument::AssocConst(_)
144 | GenericArgument::Constraint(_) => true,
145 _ => false,
146 }))
147}
148
149fn path_of_ident(ident: Ident, is_super: bool) -> Path {
150 let mut segments = vec![];
151 if is_super {
152 segments.push(PathSegment {
153 ident: Ident::new("super", Span::call_site()),
154 arguments: PathArguments::None,
155 });
156 }
157 segments.push(PathSegment {
158 ident,
159 arguments: PathArguments::None,
160 });
161 Path {
162 leading_colon: None,
163 segments: segments.into_iter().collect(),
164 }
165}
166
167fn split_for_impl(
168 generics: Option<&Generics>,
169) -> (Vec<GenericParam>, Vec<GenericArgument>, Vec<WherePredicate>) {
170 if let Some(generics) = generics {
171 let (_, ty_generics, where_clause) = generics.split_for_impl();
172 let ty_generics: std::result::Result<AngleBracketedGenericArguments, _> =
173 parse2(ty_generics.into_token_stream());
174 (
175 generics.params.iter().cloned().collect(),
176 ty_generics
177 .map(|g| g.args.into_iter().collect())
178 .unwrap_or(vec![]),
179 where_clause
180 .map(|w| w.predicates.iter().cloned().collect())
181 .unwrap_or(vec![]),
182 )
183 } else {
184 (vec![], vec![], vec![])
185 }
186}
187
/// Arguments of the `#[sumtype(...)]` attribute: a `+`-separated list of trait paths.
///
/// For every path, `emit_items` invokes a macro of that same path (`path!(...)`) to
/// implement the trait on the generated enum. `Punctuated::parse_terminated` also
/// accepts an empty list and a trailing `+`.
#[derive(Parse)]
struct Arguments {
    #[call(Punctuated::parse_terminated)]
    bounds: Punctuated<Path, Token![+]>,
}
193
/// A trait implementation to generate for the sum-type enum.
enum SumTypeImpl {
    /// Implement the trait at this path by invoking the macro at the same path
    /// (see `SumTypeImpl::gen`).
    Trait(Path),
}
197
198impl SumTypeImpl {
199 fn gen(
200 &self,
201 enum_path: &Path,
202 unspecified_ty_params: &[Ident],
203 variants: &[(Ident, Type)],
204 impl_generics: Vec<GenericParam>,
205 ty_generics: Vec<GenericArgument>,
206 where_clause: Vec<WherePredicate>,
207 constraint_expr_trait_ident: &Ident,
208 ) -> TokenStream {
209 match self {
210 SumTypeImpl::Trait(trait_path) => {
211 quote! {
212 #trait_path!(
213 #constraint_expr_trait_ident,
214 #trait_path,
215 #enum_path,
216 [#(#unspecified_ty_params),*],
217 [#(for (id, ty) in variants),{#id:#ty}],
218 [ #(#impl_generics),* ],
219 [#(#ty_generics),*],
220 { #(#where_clause),* },
221 );
222 }
223 }
224 }
225 }
226}
227
/// What the visitor recorded about one expression-position `sumtype!(...)` invocation.
struct ExprMacroInfo {
    /// Span of the whole macro invocation, used for diagnostics.
    span: Span,
    /// Enum variant generated for this expression (`__SumType_Variant_<n>`).
    variant_ident: Ident,
    /// Marker struct through which the explicit type of the expression is named via
    /// the type-ref trait. `None` when the invocation gave no type.
    reftype_ident: Option<Ident>,
    /// For each type parameter of the enclosing generics, the generic lifetimes of the
    /// references it appears under in the explicit type. Later added as `T: 'a` bounds.
    analyzed_bounds: HashMap<Ident, HashSet<Lifetime>>,
    /// The invocation's own `for<...>` parameters and `where` clause.
    generics: Generics,
}
235
/// What the visitor recorded about one type-position `sumtype![...]` invocation.
struct TypeMacroInfo {
    /// Span of the invocation. Not read in this file (hence the leading underscore).
    _span: Span,
    /// The generic arguments written inside `sumtype![...]`. `emit_items` checks them
    /// against the generics of the expression invocations.
    generic_args: Punctuated<GenericArgument, Token![,]>,
}
240
241trait ProcessTree: Sized {
243 fn collect_inline_macro(
245 &mut self,
246 enum_path: &Path,
247 typeref_path: &Path,
248 constraint_expr_trait_path: &Path,
249 generics: Option<&Generics>,
250 is_module: bool,
251 ) -> (Vec<ExprMacroInfo>, Vec<TypeMacroInfo>);
252
253 fn emit_items(
254 mut self,
255 args: &Arguments,
256 generics_env: Option<&Generics>,
257 is_module: bool,
258 vis: Visibility,
259 ) -> (TokenStream, Self) {
260 let r = random();
261 let enum_ident = Ident::new(&format!("__Sumtype_Enum_{}", r), Span::call_site());
262 let typeref_ident =
263 Ident::new(&format!("__Sumtype_TypeRef_Trait_{}", r), Span::call_site());
264 let constraint_expr_trait_ident = Ident::new(
265 &format!("__Sumtype_ConstraintExprTrait_{}", r),
266 Span::call_site(),
267 );
268 let (found_exprs, type_emitted) = self.collect_inline_macro(
269 &path_of_ident(enum_ident.clone(), is_module),
270 &path_of_ident(typeref_ident.clone(), is_module),
271 &path_of_ident(constraint_expr_trait_ident.clone(), is_module),
272 generics_env,
273 is_module,
274 );
275 let reftypes = found_exprs
276 .iter()
277 .filter_map(|info| info.reftype_ident.clone())
278 .collect::<Vec<_>>();
279 let (impl_generics_env, _, where_clause_env) = split_for_impl(generics_env);
280 if found_exprs.len() == 0 {
281 abort!(Span::call_site(), "Cannot find any sumtype!() in expr");
282 }
283 let expr_generics_list = found_exprs.iter().fold(HashMap::new(), |mut acc, info| {
284 *acc.entry(info.generics.clone()).or_insert(0usize) += 1;
285 acc
286 });
287 if expr_generics_list.len() != 1 {
288 let mut expr_gparams = expr_generics_list.into_iter().collect::<Vec<_>>();
289 expr_gparams.sort_by_key(|item| item.1);
290 abort!(expr_gparams[0].0.span(), "Generic argument mismatch");
291 }
292 let expr_generics = expr_generics_list.into_iter().next().unwrap().0;
293 let mut analyzed = found_exprs.iter().fold(HashMap::new(), |mut acc, info| {
294 for (id, lts) in &info.analyzed_bounds {
295 acc.entry(id.clone())
296 .or_insert(HashSet::new())
297 .extend(lts.iter().map(|lt| TypeParamBound::Lifetime(lt.clone())));
298 }
299 acc
300 });
301 if let Some(where_clause) = &expr_generics.where_clause {
302 for pred in &where_clause.predicates {
303 match pred {
304 WherePredicate::Type(PredicateType {
305 bounded_ty, bounds, ..
306 }) => {
307 if let Type::Path(path) = bounded_ty {
308 if path.qself.is_none() {
309 if let Some(id) = path.path.get_ident() {
310 analyzed
311 .entry(id.clone())
312 .or_insert(HashSet::new())
313 .extend(bounds.clone());
314 }
315 }
316 }
317 }
318 _ => (),
319 }
320 }
321 }
322 let expr_garg = expr_generics
323 .params
324 .iter()
325 .cloned()
326 .map(generic_param_to_arg)
327 .collect::<Vec<_>>();
328 for info in &type_emitted {
329 if info.generic_args.len() != expr_garg.len()
330 || !expr_garg
331 .iter()
332 .zip(&info.generic_args)
333 .all(|two| match two {
334 (GenericArgument::Lifetime(_), GenericArgument::Lifetime(_))
335 | (GenericArgument::Const(_), GenericArgument::Const(_))
336 | (GenericArgument::Type(_), GenericArgument::Type(_)) => true,
337 _ => false,
338 })
339 {
340 abort!(
341 info.generic_args.span(),
342 "The generic arguments are incompatible with generic params in expression."
343 )
344 }
345 }
346 let mut impl_generics =
347 merge_generic_params(impl_generics_env, expr_generics.params).collect::<Vec<_>>();
348 for g in impl_generics.iter_mut() {
349 if let GenericParam::Type(TypeParam { ident, bounds, .. }) = g {
350 if let Some(bs) = analyzed.get(ident) {
351 for b in bs {
352 bounds.push(b.clone());
353 }
354 }
355 }
356 }
357 let ty_generics = impl_generics
358 .iter()
359 .cloned()
360 .map(generic_param_to_arg)
361 .collect::<Vec<_>>();
362 let where_clause = expr_generics
363 .where_clause
364 .clone()
365 .map(|wc| wc.predicates)
366 .into_iter()
367 .flatten()
368 .chain(where_clause_env)
369 .collect::<Vec<_>>();
370 let (unspecified_ty_params, variants) = found_exprs.iter().enumerate().fold(
371 (vec![], vec![]),
372 |(mut ty_params, mut variants), (i, info)| {
373 if let Some(reft) = &info.reftype_ident {
374 variants.push((
375 info.variant_ident.clone(),
376 parse_quote!(<#reft as #typeref_ident<#(#ty_generics),*>>::Type),
377 ));
378 } else {
379 let tp_ident =
380 Ident::new(&format!("__Sumtype_TypeParam_{}", i), Span::call_site());
381 variants.push((info.variant_ident.clone(), parse_quote!(#tp_ident)));
382 ty_params.push(tp_ident);
383 }
384 (ty_params, variants)
385 },
386 );
387 if let (Some(info), true) = (
388 found_exprs
389 .iter()
390 .filter(|info| info.reftype_ident.is_none())
391 .next(),
392 type_emitted.len() > 0,
393 ) {
394 abort!(
395 &info.span,
396 r#"
397To emit full type, you should specify the type.
398Example: sumtype!(std::iter::empty(), std::iter::Empty<T>)
399"#
400 )
401 } else {
402 let replaced_ty_generics: Vec<_> = ty_generics
403 .iter()
404 .map(|ga| match ga {
405 GenericArgument::Lifetime(lt) => quote!(& #lt ()),
406 GenericArgument::Const(_) => quote!(),
407 o => quote!(#o),
408 })
409 .collect();
410 let constraint_traits = (0..args.bounds.len())
411 .map(|n| {
412 Ident::new(
413 &format!("__Sumtype_ConstraintExprTrait_{}_{}", n, random()),
414 Span::call_site(),
415 )
416 })
417 .collect::<Vec<_>>();
418 let out = quote! {
419 #(for reft in &reftypes) {
420 #[doc(hidden)]
421 struct #reft;
422 }
423 #[doc(hidden)]
424 trait #typeref_ident <#(#impl_generics),*> { type Type; }
425 #[doc(hidden)]
426 #vis enum #enum_ident <
427 #(#impl_generics),*
428 #(if impl_generics.len() > 0 && unspecified_ty_params.len() > 0) { , }
429 #(#unspecified_ty_params),*
430 > {
431 #(for (ident, ty) in &variants) {
432 #ident ( #ty ),
433 }
434 __Uninhabited(
435 (
436 ::core::convert::Infallible,
437 #(::core::marker::PhantomData<#replaced_ty_generics>),*
438 )
439 ),
440 }
441 #[doc(hidden)]
442 trait #constraint_expr_trait_ident<#(#impl_generics),*> {}
443 impl<#(#impl_generics,)*__Sumtype_TypeParam> #constraint_expr_trait_ident<#(#ty_generics),*> for __Sumtype_TypeParam
444 where
445 #(for t in &constraint_traits) {
446 __Sumtype_TypeParam: #t<#(#ty_generics),*>,
447 }
448 #(#where_clause,)*
449 {}
450 #(for (trait_, constraint_trait) in args.bounds.iter().zip(&constraint_traits)) {
451 #{ SumTypeImpl::Trait(trait_.clone()).gen(
452 &path_of_ident(enum_ident.clone(), false),
453 unspecified_ty_params.as_slice(),
454 variants.as_slice(),
455 impl_generics.clone(),
456 ty_generics.clone(),
457 where_clause.clone(),
458 constraint_trait,
459 ) }
460 }
461 };
462 (out, self)
463 }
464 }
465}
466
467const _: () = {
468 use syn::visit_mut::VisitMut;
469 struct Visitor<'a> {
470 enum_path: &'a Path,
471 typeref_path: &'a Path,
472 constraint_expr_trait_path: &'a Path,
473 found_exprs: Vec<ExprMacroInfo>,
474 emit_type: Vec<TypeMacroInfo>,
475 generics: Option<&'a Generics>,
476 is_module: bool,
477 }
478
479 impl ProcessTree for Block {
480 fn collect_inline_macro(
481 &mut self,
482 enum_path: &Path,
483 typeref_path: &Path,
484 constraint_expr_trait_path: &Path,
485 generics: Option<&Generics>,
486 is_module: bool,
487 ) -> (Vec<ExprMacroInfo>, Vec<TypeMacroInfo>) {
488 let mut visitor = Visitor::new(
489 enum_path,
490 typeref_path,
491 constraint_expr_trait_path,
492 generics,
493 is_module,
494 );
495 visitor.visit_block_mut(self);
496 (visitor.found_exprs, visitor.emit_type)
497 }
498 }
499
500 impl ProcessTree for Item {
501 fn collect_inline_macro(
502 &mut self,
503 enum_path: &Path,
504 typeref_path: &Path,
505 constraint_expr_trait_path: &Path,
506 generics: Option<&Generics>,
507 is_module: bool,
508 ) -> (Vec<ExprMacroInfo>, Vec<TypeMacroInfo>) {
509 let mut visitor = Visitor::new(
510 enum_path,
511 typeref_path,
512 constraint_expr_trait_path,
513 generics,
514 is_module,
515 );
516 visitor.visit_item_mut(self);
517 (visitor.found_exprs, visitor.emit_type)
518 }
519 }
520
521 impl ProcessTree for Stmt {
522 fn collect_inline_macro(
523 &mut self,
524 enum_path: &Path,
525 typeref_path: &Path,
526 constraint_expr_trait_path: &Path,
527 generics: Option<&Generics>,
528 is_module: bool,
529 ) -> (Vec<ExprMacroInfo>, Vec<TypeMacroInfo>) {
530 let mut visitor = Visitor::new(
531 enum_path,
532 typeref_path,
533 constraint_expr_trait_path,
534 generics,
535 is_module,
536 );
537 visitor.visit_stmt_mut(self);
538 (visitor.found_exprs, visitor.emit_type)
539 }
540 }
541
542 impl<'a> Visitor<'a> {
543 fn new(
544 enum_path: &'a Path,
545 typeref_path: &'a Path,
546 constraint_expr_trait_path: &'a Path,
547 generics: Option<&'a Generics>,
548 is_module: bool,
549 ) -> Self {
550 Self {
551 enum_path,
552 typeref_path,
553 constraint_expr_trait_path,
554 found_exprs: Vec::new(),
555 emit_type: Vec::new(),
556 generics,
557 is_module,
558 }
559 }
560 fn do_type_macro(&mut self, mac: &Macro) -> TokenStream {
561 #[derive(Parse)]
562 struct Arg {
563 #[call(Punctuated::parse_terminated)]
564 generic_args: Punctuated<GenericArgument, Token![,]>,
565 }
566 let arg: Arg = mac
567 .parse_body()
568 .unwrap_or_else(|e| abort!(e.span(), &format!("{}", &e)));
569 let ty_generics = merge_generic_args(
570 self.generics
571 .iter()
572 .map(|g| g.params.iter().cloned().map(generic_param_to_arg))
573 .flatten(),
574 arg.generic_args.clone(),
575 )
576 .collect::<Vec<_>>();
577 self.emit_type.push(TypeMacroInfo {
578 _span: mac.span(),
579 generic_args: arg.generic_args,
580 });
581 quote! {
582 #{&self.enum_path}
583 #(if ty_generics.len() > 0){
584 <#(#ty_generics),*>
585 }
586 }
587 }
588
589 fn analyze_lifetime_bounds(
590 &self,
591 generics: &Generics,
592 ty: &Type,
593 ) -> HashMap<Ident, HashSet<Lifetime>> {
594 struct LifetimeVisitor {
595 generic_lifetimes: HashSet<Lifetime>,
596 generic_params: HashSet<Ident>,
597 lifetime_stack: Vec<Lifetime>,
598 result: HashMap<Ident, HashSet<Lifetime>>,
599 }
600 use syn::visit::Visit;
601 impl<'ast> syn::visit::Visit<'ast> for LifetimeVisitor {
602 fn visit_type_reference(&mut self, i: &TypeReference) {
603 if let Some(lt) = &i.lifetime {
604 if self.generic_lifetimes.contains(lt) {
605 self.lifetime_stack.push(lt.clone());
606 syn::visit::visit_type_reference(self, i);
607 self.lifetime_stack.pop();
608 return;
609 }
610 }
611 syn::visit::visit_type_reference(self, i);
612 }
613 fn visit_type_path(&mut self, i: &TypePath) {
614 if i.qself.is_none() {
615 if let Some(id) = i.path.get_ident() {
616 if self.generic_params.contains(id) {
617 self.result
618 .entry(id.clone())
619 .or_insert(HashSet::new())
620 .extend(self.lifetime_stack.clone());
621 }
622 return;
623 }
624 }
625 syn::visit::visit_type_path(self, i);
626 }
627 }
628 let mut visitor = LifetimeVisitor {
629 generic_lifetimes: generics
630 .params
631 .iter()
632 .filter_map(|p| {
633 if let GenericParam::Lifetime(LifetimeParam { lifetime, .. }) = p {
634 Some(lifetime.clone())
635 } else {
636 None
637 }
638 })
639 .collect(),
640 generic_params: generics
641 .params
642 .iter()
643 .filter_map(|p| {
644 if let GenericParam::Type(TypeParam { ident, .. }) = p {
645 Some(ident.clone())
646 } else {
647 None
648 }
649 })
650 .collect(),
651 lifetime_stack: Vec::new(),
652 result: HashMap::new(),
653 };
654 visitor.visit_type(ty);
655 visitor.result
656 }
657
658 fn do_expr_macro(&mut self, mac: &Macro) -> TokenStream {
659 #[derive(Parse)]
660 struct Arg {
661 expr: Expr,
662 _comma_token: Option<Token![,]>,
663 _for_token: Option<Token![for]>,
664 #[prefix(Option<Token![<]>)]
665 #[postfix(Option<Token![>]>)]
666 #[parse_if(_for_token.is_some())]
667 #[call(Punctuated::parse_separated_nonempty)]
668 for_generics: Option<Punctuated<GenericParam, Token![,]>>,
669 #[parse_if(_comma_token.is_some())]
670 ty: Option<Type>,
671 #[parse_if(_comma_token.is_some())]
672 where_clause: Option<Option<WhereClause>>,
673 }
674 let arg: Arg = mac
675 .parse_body()
676 .unwrap_or_else(|e| abort!(e.span(), &format!("{}", &e)));
677 let n = self.found_exprs.len();
678 let variant_ident = Ident::new(&format!("__SumType_Variant_{}", n), Span::call_site());
679 let reftype_ident = Ident::new(
680 &format!("__SumType_RefType_{}_{}", random(), n),
681 Span::call_site(),
682 );
683 let reftype_path = path_of_ident(reftype_ident.clone(), self.is_module);
684 let id_fn_ident =
685 Ident::new(&format!("__sum_type_id_fn_{}", random()), Span::call_site());
686 let (mut impl_generics, _, where_clause) = split_for_impl(self.generics);
687 let analyzed =
688 if let (Some(generics), Some(ty)) = (self.generics.as_ref(), arg.ty.as_ref()) {
689 self.analyze_lifetime_bounds(*generics, ty)
690 } else {
691 HashMap::new()
692 };
693 let generics = Generics {
694 params: arg.for_generics.clone().unwrap_or(Default::default()),
695 where_clause: arg.where_clause.unwrap_or(Some(WhereClause {
696 predicates: Punctuated::new(),
697 where_token: Default::default(),
698 })),
699 ..Default::default()
700 };
701 for g in impl_generics.iter_mut() {
702 if let GenericParam::Type(TypeParam { ident, bounds, .. }) = g {
703 if let Some(lts) = analyzed.get(ident) {
704 for lt in lts {
705 bounds.push(TypeParamBound::Lifetime(lt.clone().clone()));
706 }
707 }
708 }
709 }
710 let impl_generics =
711 merge_generic_params(impl_generics, generics.params.clone()).collect::<Vec<_>>();
712 let ty_generics = impl_generics
713 .iter()
714 .cloned()
715 .map(generic_param_to_arg)
716 .collect::<Vec<_>>();
717 let where_clause = generics
718 .where_clause
719 .clone()
720 .map(|wc| wc.predicates)
721 .into_iter()
722 .flatten()
723 .chain(where_clause)
724 .collect::<Vec<_>>();
725 self.found_exprs.push(ExprMacroInfo {
726 span: mac.span(),
727 variant_ident: variant_ident.clone(),
728 reftype_ident: arg.ty.as_ref().map(|_| reftype_ident.clone()),
729 analyzed_bounds: analyzed.clone(),
730 generics,
731 });
732 quote! {
733 {
734 #(if let Some(ty) = &arg.ty){
735 impl<#(#impl_generics,)*> #{&self.typeref_path} <#(#ty_generics),*> for #reftype_path
736 #(if where_clause.len() > 0) {
737 where #(#where_clause,)*
738 }
739 {
740 type Type = #ty;
741 }
742 }
743 fn #id_fn_ident<
744 #(#impl_generics,)* __SumType_T: #{&self.constraint_expr_trait_path}<#(#ty_generics),*>
745 >(t: __SumType_T) -> __SumType_T
746 #(if where_clause.len() > 0) {
747 where #(#where_clause,)*
748 }
749 { t }
750 #id_fn_ident::<#(#ty_generics,)*_>(#{&self.enum_path}::#variant_ident(#{&arg.expr}))
751 }
752 }
753 }
754 }
755
756 impl<'a> VisitMut for Visitor<'a> {
757 fn visit_type_mut(&mut self, ty: &mut Type) {
758 if let Type::Macro(tm) = &*ty {
759 if tm.mac.path.is_ident("sumtype") {
760 let out = self.do_type_macro(&tm.mac);
761 *ty = parse2(out).unwrap();
762 return;
763 }
764 }
765 syn::visit_mut::visit_type_mut(self, ty);
766 }
767
768 fn visit_expr_mut(&mut self, expr: &mut Expr) {
769 if let Expr::Macro(em) = &*expr {
770 if em.mac.path.is_ident("sumtype") {
771 let out = self.do_expr_macro(&em.mac);
772 *expr = parse2(out).unwrap();
773 return;
774 }
775 }
776 syn::visit_mut::visit_expr_mut(self, expr);
777 }
778
779 fn visit_stmt_mut(&mut self, stmt: &mut Stmt) {
780 if let Stmt::Macro(sm) = &*stmt {
781 if sm.mac.path.is_ident("sumtype") {
782 let out = self.do_expr_macro(&sm.mac);
783 *stmt = parse2(out).unwrap();
784 return;
785 }
786 }
787 syn::visit_mut::visit_stmt_mut(self, stmt);
788 }
789 }
790};
791
792fn inner(args: &Arguments, input: TokenStream) -> TokenStream {
793 let public = Visibility::Public(Default::default());
794 if let Ok(block) = parse2::<Block>(input.clone()) {
795 let (out, block) = block.emit_items(args, None, false, public);
796 quote! { #out #[allow(non_local_definitions)] #block }
797 } else if let Ok(item_trait) = parse2::<ItemTrait>(input.clone()) {
798 let generics = item_trait.generics.clone();
799 let vis = item_trait.vis.clone();
800 let (out, item) = Item::Trait(item_trait).emit_items(args, Some(&generics), false, vis);
801 quote! { #out #[allow(non_local_definitions)] #item }
802 } else if let Ok(item_impl) = parse2::<ItemImpl>(input.clone()) {
803 let generics = item_impl.generics.clone();
804 let (out, item) = Item::Impl(item_impl).emit_items(args, Some(&generics), false, public);
805 quote! { #out #[allow(non_local_definitions)] #item }
806 } else if let Ok(item_fn) = parse2::<ItemFn>(input.clone()) {
807 let generics = item_fn.sig.generics.clone();
808 let vis = item_fn.vis.clone();
809 let (out, item) = Item::Fn(item_fn).emit_items(args, Some(&generics), false, vis);
810 quote! { #out #[allow(non_local_definitions)] #item }
811 } else if let Ok(item_mod) = parse2::<ItemMod>(input.clone()) {
812 let (out, item) = Item::Mod(item_mod).emit_items(args, None, true, public);
813 quote! { #out #[allow(non_local_definitions)] #item }
814 } else if let Ok(item) = parse2::<Item>(input.clone()) {
815 let (out, item) = item.emit_items(args, None, false, public);
816 quote! { #out #[allow(non_local_definitions)] #item }
817 } else if let Ok(stmt) = parse2::<Stmt>(input.clone()) {
818 let (out, stmt) = stmt.emit_items(args, None, false, public);
819 quote! { #out #[allow(non_local_definitions)] #stmt }
820 } else {
821 abort!(input.span(), "This element is not supported")
822 }
823}
824
825fn process_supported_supertraits<'a>(
826 traits: impl IntoIterator<Item = &'a TypeParamBound>,
827 krate: &Path,
828) -> (Vec<Path>, Vec<Path>) {
829 let mut supertraits = Vec::new();
830 let mut derive_traits = Vec::new();
831 for tpb in traits.into_iter() {
832 if let TypeParamBound::Trait(tb) = tpb {
833 if let Some(ident) = tb.path.get_ident() {
834 match ident.to_string().as_str() {
835 "Copy" | "Clone" | "Hash" | "Eq" => {
836 supertraits.push(parse_quote!(#krate::traits::#ident))
837 }
838 "PartialEq" => derive_traits.push(parse_quote!(PartialEq)),
839 o if o.starts_with("__SumTrait_Sealed") => (),
840 _ => (),
841 }
842 } else {
843 supertraits.push(tb.path.clone())
844 }
845 } else {
846 abort!(tpb.span(), "Only path is supported");
847 }
848 }
849 (supertraits, derive_traits)
850}
851
852fn collect_typeref_types(input: &ItemTrait) -> Vec<Type> {
853 fn can_make_typeref_type(ty: &Type, generics: &Generics) -> bool {
854 use syn::visit::Visit;
855 struct Visitor(bool, Vec<Ident>, Vec<Ident>);
856 let generics_param_tys = generics
857 .params
858 .iter()
859 .filter_map(|p| {
860 if let GenericParam::Type(TypeParam { ident, .. }) = p {
861 Some(ident.clone())
862 } else {
863 None
864 }
865 })
866 .collect::<Vec<_>>();
867 let generics_param_vals = generics
868 .params
869 .iter()
870 .filter_map(|p| {
871 if let GenericParam::Const(ConstParam { ident, .. }) = p {
872 Some(ident.clone())
873 } else {
874 None
875 }
876 })
877 .collect::<Vec<_>>();
878 impl<'a> syn::visit::Visit<'a> for Visitor {
879 fn visit_type(&mut self, i: &Type) {
880 match i {
881 Type::ImplTrait(_) | Type::Verbatim(_) | Type::Infer(_) | Type::Macro(_) => {
882 self.0 = false;
883 }
884 Type::Reference(TypeReference { lifetime, .. }) => {
885 if let Some(lifetime) = lifetime {
886 if &lifetime.ident != "static" {
887 self.0 = false;
888 }
889 } else {
890 self.0 = false;
891 }
892 }
893 Type::Path(tp) => {
894 if tp.qself.is_none() && self.2.iter().any(|ident| tp.path.is_ident(ident))
895 || (tp.path.segments.len() >= 1 && &tp.path.segments[0].ident == "Self")
896 {
897 self.0 = false;
898 }
899 }
900 _ => (),
901 }
902 syn::visit::visit_type(self, i)
903 }
904 fn visit_expr(&mut self, i: &Expr) {
905 if let Expr::Path(ExprPath { qself, path, .. }) = i {
906 if qself.is_none() && self.1.iter().any(|ident| path.is_ident(ident)) {
907 self.0 = false;
908 }
909 }
910 }
911 }
912 let mut visitor = Visitor(true, generics_param_tys, generics_param_vals);
913 visitor.visit_type(ty);
914 visitor.0
915 }
916 use syn::visit::Visit;
917 struct Visitor(Vec<Type>, Generics);
918 impl<'a> syn::visit::Visit<'a> for Visitor {
919 fn visit_type(&mut self, i: &Type) {
920 if can_make_typeref_type(i, &self.1) {
921 self.0.push(i.clone());
922 } else {
923 syn::visit::visit_type(self, i)
924 }
925 }
926 }
927 let mut visitor = Visitor(Vec::new(), input.generics.clone());
928 visitor.visit_item_trait(input);
929 visitor.0
930}
931
932fn sumtrait_impl(
933 args: Option<Path>,
934 marker_path: &Path,
935 krate: &Path,
936 input: ItemTrait,
937) -> TokenStream {
938 let (supertraits, derive_traits) = process_supported_supertraits(&input.supertraits, krate);
939 for item in &input.items {
940 match item {
941 TraitItem::Const(_) => abort!(item.span(), "associated const is not supported"),
942 TraitItem::Fn(tfn) => {
943 if tfn.sig.inputs.len() == 0 || !matches!(&tfn.sig.inputs[0], FnArg::Receiver(_)) {
944 abort!(tfn.sig.span(), "requires receiver")
945 }
946 }
947 TraitItem::Type(tty) => {
948 if tty.default.is_some() {
949 abort!(tty.span(), "associated type defaults is not supported")
950 }
951 if tty.generics.params.len() > 0 || tty.generics.where_clause.is_some() {
952 abort!(
953 tty.generics.span(),
954 "generalized associated types is not supported"
955 )
956 }
957 }
958 o => abort!(o.span(), "Not supported"),
959 }
960 }
961 let temporary_mac_name =
962 Ident::new(&format!("__sumtype_macro_{}", random()), Span::call_site());
963 let typeref_types = collect_typeref_types(&input);
964 let (_, _, where_clause) = input.generics.split_for_impl();
965 let typeref_id = random() as usize;
966 quote! {
967 #input
968 #(for (i, ty) in typeref_types.iter().enumerate()) {
969 impl<#(for p in &input.generics.params),{#p}> #krate::TypeRef<#typeref_id, #i> for #marker_path #where_clause {
970 type Type = #ty;
971 }
972 }
973
974 #[doc(hidden)]
975 #[macro_export]
976 macro_rules! #temporary_mac_name {
977 ($($t:tt)*) => {
978 #krate::_sumtrait_internal!(
979 { $($t)* }
980 [#(#typeref_types),*],
981 {#input},
982 #typeref_id,
983 #krate,
984 #marker_path,
985 [#{args.map(|m| quote!(#m)).unwrap_or(quote!(_))}],
986 [#(#supertraits),*],
987 [#(#derive_traits),*],
988 );
989 };
990 }
991 #[doc(hidden)]
992 #{&input.vis} use #temporary_mac_name as #{&input.ident};
993 }
994}
995
/// Internal proc-macro invoked by the macros that `#[sumtrait]` generates. It is not
/// part of the public API.
///
/// Forwards the tokens to `sumtrait_internal::sumtrait_internal`.
#[doc(hidden)]
#[proc_macro_error]
#[proc_macro]
#[proc_debug::proc_debug]
pub fn _sumtrait_internal(input: TokenStream1) -> TokenStream1 {
    sumtrait_internal::sumtrait_internal(input.into()).into()
}
1003
1004#[proc_macro_error]
1005#[proc_macro_attribute]
1006#[proc_debug::proc_debug]
1007pub fn sumtrait(attr: TokenStream1, input: TokenStream1) -> TokenStream1 {
1008 #[derive(FromMeta, Debug)]
1009 struct SumtraitArgs {
1010 implement: Option<Path>,
1011 krate: Option<Path>,
1012 marker: Path,
1013 }
1014 let args = SumtraitArgs::from_list(&NestedMeta::parse_meta_list(attr.into()).unwrap()).unwrap();
1015
1016 let krate = args.krate.unwrap_or(parse_quote!(::sumtype));
1017 sumtrait_impl(
1018 args.implement,
1019 &args.marker,
1020 &krate,
1021 parse(input).unwrap_or_else(|_| abort!(Span::call_site(), "Requires trait definition")),
1022 )
1023 .into()
1024}
1025
1026#[proc_macro_error]
1089#[proc_macro_attribute]
1090#[proc_debug::proc_debug]
1091pub fn sumtype(attr: TokenStream1, input: TokenStream1) -> TokenStream1 {
1092 inner(&parse_macro_input!(attr as Arguments), input.into()).into()
1093}