1use darling::ast::NestedMeta;
2use darling::FromMeta;
3use derive_syn_parse::Parse;
4use proc_macro::TokenStream as TokenStream1;
5use proc_macro2::Span;
6use proc_macro2::TokenStream;
7use proc_macro_error::{abort, proc_macro_error};
8use std::collections::{HashMap, HashSet};
9use syn::punctuated::Punctuated;
10use syn::spanned::Spanned;
11use syn::*;
12use template_quote::{quote, ToTokens};
13use type_leak::*;
14
15mod sumtrait_internal;
16
/// Returns a pseudo-random `u64`, used in this file to give generated items unique names.
///
/// The value is what a freshly created `std::collections::hash_map::RandomState` hasher
/// reports from `finish()` before anything has been written to it.
fn random() -> u64 {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};

    let state = RandomState::new();
    let hasher = state.build_hasher();
    hasher.finish()
}
23
24fn generic_param_to_arg(i: GenericParam) -> GenericArgument {
25 match i {
26 GenericParam::Lifetime(LifetimeParam { lifetime, .. }) => {
27 GenericArgument::Lifetime(lifetime)
28 }
29 GenericParam::Type(TypeParam { ident, .. }) => GenericArgument::Type(parse_quote!(#ident)),
30 GenericParam::Const(ConstParam { ident, .. }) => {
31 GenericArgument::Const(parse_quote!(#ident))
32 }
33 }
34}
35
36fn merge_generic_params(
37 args1: impl IntoIterator<Item = GenericParam, IntoIter: Clone>,
38 args2: impl IntoIterator<Item = GenericParam, IntoIter: Clone>,
39) -> impl Iterator<Item = GenericParam> {
40 let it1 = args1.into_iter();
41 let it2 = args2.into_iter();
42 it1.clone()
43 .filter(|arg| matches!(arg, GenericParam::Lifetime(_)))
44 .chain(
45 it2.clone()
46 .filter(|arg| matches!(arg, GenericParam::Lifetime(_))),
47 )
48 .chain(
49 it1.clone()
50 .filter(|arg| matches!(arg, GenericParam::Const(_))),
51 )
52 .chain(
53 it2.clone()
54 .filter(|arg| matches!(arg, GenericParam::Const(_))),
55 )
56 .chain(
57 it1.clone()
58 .filter(|arg| matches!(arg, GenericParam::Type(_))),
59 )
60 .chain(
61 it2.clone()
62 .filter(|arg| matches!(arg, GenericParam::Type(_))),
63 )
64}
65
66fn merge_generic_args(
67 args1: impl IntoIterator<Item = GenericArgument, IntoIter: Clone>,
68 args2: impl IntoIterator<Item = GenericArgument, IntoIter: Clone>,
69) -> impl Iterator<Item = GenericArgument> {
70 let it1 = args1.into_iter();
71 let it2 = args2.into_iter();
72 it1.clone()
73 .filter(|arg| matches!(arg, GenericArgument::Lifetime(_)))
74 .chain(
75 it2.clone()
76 .filter(|arg| matches!(arg, GenericArgument::Lifetime(_))),
77 )
78 .chain(
79 it1.clone()
80 .filter(|arg| matches!(arg, GenericArgument::Const(_))),
81 )
82 .chain(
83 it2.clone()
84 .filter(|arg| matches!(arg, GenericArgument::Const(_))),
85 )
86 .chain(
87 it1.clone()
88 .filter(|arg| matches!(arg, GenericArgument::Type(_))),
89 )
90 .chain(
91 it2.clone()
92 .filter(|arg| matches!(arg, GenericArgument::Type(_))),
93 )
94 .chain(it1.filter(|arg| {
95 matches!(
96 arg,
97 GenericArgument::AssocType(_)
98 | GenericArgument::AssocConst(_)
99 | GenericArgument::Constraint(_)
100 )
101 }))
102 .chain(it2.filter(|arg| {
103 matches!(
104 arg,
105 GenericArgument::AssocType(_)
106 | GenericArgument::AssocConst(_)
107 | GenericArgument::Constraint(_)
108 )
109 }))
110}
111
112fn path_of_ident(ident: Ident, is_super: bool) -> Path {
113 let mut segments = vec![];
114 if is_super {
115 segments.push(PathSegment {
116 ident: Ident::new("super", Span::call_site()),
117 arguments: PathArguments::None,
118 });
119 }
120 segments.push(PathSegment {
121 ident,
122 arguments: PathArguments::None,
123 });
124 Path {
125 leading_colon: None,
126 segments: segments.into_iter().collect(),
127 }
128}
129
130fn split_for_impl(
131 generics: Option<&Generics>,
132) -> (Vec<GenericParam>, Vec<GenericArgument>, Vec<WherePredicate>) {
133 if let Some(generics) = generics {
134 let (_, ty_generics, where_clause) = generics.split_for_impl();
135 let ty_generics: std::result::Result<AngleBracketedGenericArguments, _> =
136 parse2(ty_generics.into_token_stream());
137 (
138 generics.params.iter().cloned().collect(),
139 ty_generics
140 .map(|g| g.args.into_iter().collect())
141 .unwrap_or(vec![]),
142 where_clause
143 .map(|w| w.predicates.iter().cloned().collect())
144 .unwrap_or(vec![]),
145 )
146 } else {
147 (vec![], vec![], vec![])
148 }
149}
150
/// Arguments of the `#[sumtype(...)]` attribute: a `+`-separated list of paths.
#[derive(Parse)]
struct Arguments {
    // Each path is later invoked as a function-like macro (see `SumTypeImpl::gen`).
    #[call(Punctuated::parse_terminated)]
    bounds: Punctuated<Path, Token![+]>,
}
156
/// A kind of implementation that can be generated for the sum type enum.
enum SumTypeImpl {
    /// Implementation produced by invoking the function-like macro found at this path.
    Trait(Path),
}
160
161impl SumTypeImpl {
162 #[allow(clippy::too_many_arguments)]
163 fn gen(
164 &self,
165 enum_path: &Path,
166 unspecified_ty_params: &[Ident],
167 variants: &[(Ident, Type)],
168 impl_generics: Vec<GenericParam>,
169 ty_generics: Vec<GenericArgument>,
170 where_clause: Vec<WherePredicate>,
171 constraint_expr_trait_ident: &Ident,
172 ) -> TokenStream {
173 match self {
174 SumTypeImpl::Trait(trait_path) => {
175 quote! {
176 #trait_path!(
177 #constraint_expr_trait_ident,
178 #trait_path,
179 #enum_path,
180 [#(#unspecified_ty_params),*],
181 [#(for (id, ty) in variants),{#id:#ty}],
182 [ #(#impl_generics),* ],
183 [#(#ty_generics),*],
184 { #(#where_clause),* },
185 );
186 }
187 }
188 }
189 }
190}
191
/// Information recorded for one expression-position `sumtype!(...)` invocation.
struct ExprMacroInfo {
    /// Span of the macro invocation, used for diagnostics.
    span: Span,
    /// Name of the enum variant that wraps this expression.
    variant_ident: Ident,
    /// Marker type carrying the expression's type; `Some` only when the invocation
    /// specified a type.
    reftype_ident: Option<Ident>,
    /// Lifetime bounds inferred per type parameter from the specified type.
    analyzed_bounds: HashMap<Ident, HashSet<Lifetime>>,
    /// Generics (`for<...>` parameters and where clause) written in the invocation.
    generics: Generics,
}
199
/// Information recorded for one type-position `sumtype!(...)` invocation.
struct TypeMacroInfo {
    /// Span of the macro invocation (currently unused).
    _span: Span,
    /// Generic arguments written inside the invocation.
    generic_args: Punctuated<GenericArgument, Token![,]>,
}
204
205trait ProcessTree: Sized {
207 fn collect_inline_macro(
209 &mut self,
210 enum_path: &Path,
211 typeref_path: &Path,
212 constraint_expr_trait_path: &Path,
213 generics: Option<&Generics>,
214 is_module: bool,
215 ) -> (Vec<ExprMacroInfo>, Vec<TypeMacroInfo>);
216
217 fn emit_items(
218 mut self,
219 args: &Arguments,
220 generics_env: Option<&Generics>,
221 is_module: bool,
222 vis: Visibility,
223 ) -> (TokenStream, Self) {
224 let r = random();
225 let enum_ident = Ident::new(&format!("__Sumtype_Enum_{}", r), Span::call_site());
226 let typeref_ident =
227 Ident::new(&format!("__Sumtype_TypeRef_Trait_{}", r), Span::call_site());
228 let constraint_expr_trait_ident = Ident::new(
229 &format!("__Sumtype_ConstraintExprTrait_{}", r),
230 Span::call_site(),
231 );
232 let (found_exprs, type_emitted) = self.collect_inline_macro(
233 &path_of_ident(enum_ident.clone(), is_module),
234 &path_of_ident(typeref_ident.clone(), is_module),
235 &path_of_ident(constraint_expr_trait_ident.clone(), is_module),
236 generics_env,
237 is_module,
238 );
239 let reftypes = found_exprs
240 .iter()
241 .filter_map(|info| info.reftype_ident.clone())
242 .collect::<Vec<_>>();
243 let (impl_generics_env, _, where_clause_env) = split_for_impl(generics_env);
244 if found_exprs.is_empty() {
245 abort!(Span::call_site(), "Cannot find any sumtype!() in expr");
246 }
247 let expr_generics_list = found_exprs.iter().fold(HashMap::new(), |mut acc, info| {
248 *acc.entry(info.generics.clone()).or_insert(0usize) += 1;
249 acc
250 });
251 if expr_generics_list.len() != 1 {
252 let mut expr_gparams = expr_generics_list.into_iter().collect::<Vec<_>>();
253 expr_gparams.sort_by_key(|item| item.1);
254 abort!(expr_gparams[0].0.span(), "Generic argument mismatch");
255 }
256 let expr_generics = expr_generics_list.into_iter().next().unwrap().0;
257 let mut analyzed = found_exprs.iter().fold(
258 HashMap::new(),
259 |mut acc: HashMap<Ident, HashSet<TypeParamBound>>, info| {
260 for (id, lts) in &info.analyzed_bounds {
261 acc.entry(id.clone())
262 .or_default()
263 .extend(lts.iter().map(|lt| TypeParamBound::Lifetime(lt.clone())));
264 }
265 acc
266 },
267 );
268 if let Some(where_clause) = &expr_generics.where_clause {
269 for pred in &where_clause.predicates {
270 if let WherePredicate::Type(PredicateType {
271 bounded_ty: Type::Path(path),
272 bounds,
273 ..
274 }) = pred
275 {
276 if path.qself.is_none() {
277 if let Some(id) = path.path.get_ident() {
278 analyzed
279 .entry(id.clone())
280 .or_insert(HashSet::new())
281 .extend(bounds.clone());
282 }
283 }
284 }
285 }
286 }
287 let expr_garg = expr_generics
288 .params
289 .iter()
290 .cloned()
291 .map(generic_param_to_arg)
292 .collect::<Vec<_>>();
293 for info in &type_emitted {
294 if info.generic_args.len() != expr_garg.len()
295 || !expr_garg.iter().zip(&info.generic_args).all(|two| {
296 matches!(
297 two,
298 (GenericArgument::Lifetime(_), GenericArgument::Lifetime(_))
299 | (GenericArgument::Const(_), GenericArgument::Const(_))
300 | (GenericArgument::Type(_), GenericArgument::Type(_))
301 )
302 })
303 {
304 abort!(
305 info.generic_args.span(),
306 "The generic arguments are incompatible with generic params in expression."
307 )
308 }
309 }
310 let mut impl_generics =
311 merge_generic_params(impl_generics_env, expr_generics.params).collect::<Vec<_>>();
312 for g in impl_generics.iter_mut() {
313 if let GenericParam::Type(TypeParam { ident, bounds, .. }) = g {
314 if let Some(bs) = analyzed.get(ident) {
315 for b in bs {
316 bounds.push(b.clone());
317 }
318 }
319 }
320 }
321 let ty_generics = impl_generics
322 .iter()
323 .cloned()
324 .map(generic_param_to_arg)
325 .collect::<Vec<_>>();
326 let where_clause = expr_generics
327 .where_clause
328 .clone()
329 .map(|wc| wc.predicates)
330 .into_iter()
331 .flatten()
332 .chain(where_clause_env)
333 .collect::<Vec<_>>();
334 let (unspecified_ty_params, variants) = found_exprs.iter().enumerate().fold(
335 (vec![], vec![]),
336 |(mut ty_params, mut variants), (i, info)| {
337 if let Some(reft) = &info.reftype_ident {
338 variants.push((
339 info.variant_ident.clone(),
340 parse_quote!(<#reft as #typeref_ident<#(#ty_generics),*>>::Type),
341 ));
342 } else {
343 let tp_ident =
344 Ident::new(&format!("__Sumtype_TypeParam_{}", i), Span::call_site());
345 variants.push((info.variant_ident.clone(), parse_quote!(#tp_ident)));
346 ty_params.push(tp_ident);
347 }
348 (ty_params, variants)
349 },
350 );
351 if let (Some(info), true) = (
352 found_exprs.iter().find(|info| info.reftype_ident.is_none()),
353 !type_emitted.is_empty(),
354 ) {
355 abort!(
356 &info.span,
357 r#"
358To emit full type, you should specify the type.
359Example: sumtype!(std::iter::empty(), std::iter::Empty<T>)
360"#
361 )
362 } else {
363 let replaced_ty_generics: Vec<_> = ty_generics
364 .iter()
365 .map(|ga| match ga {
366 GenericArgument::Lifetime(lt) => quote!(& #lt ()),
367 GenericArgument::Const(_) => quote!(),
368 o => quote!(#o),
369 })
370 .collect();
371 let constraint_traits = (0..args.bounds.len())
372 .map(|n| {
373 Ident::new(
374 &format!("__Sumtype_ConstraintExprTrait_{}_{}", n, random()),
375 Span::call_site(),
376 )
377 })
378 .collect::<Vec<_>>();
379 let out = quote! {
380 #(for reft in &reftypes) {
381 #[doc(hidden)]
382 #[allow(non_camel_case_types)]
383 #[allow(non_camel_case_types)]
384 struct #reft;
385 }
386 #[doc(hidden)]
387 #[allow(non_camel_case_types)]
388 trait #typeref_ident <#(#impl_generics),*> { type Type; }
389 #[doc(hidden)]
390 #[allow(non_camel_case_types)]
391 #vis enum #enum_ident <
392 #(#impl_generics),*
393 #(if !impl_generics.is_empty() && !unspecified_ty_params.is_empty()) { , }
394 #(#unspecified_ty_params),*
395 > {
396 #(for (ident, ty) in &variants) {
397 #ident ( #ty ),
398 }
399 __Uninhabited(
400 (
401 ::core::convert::Infallible,
402 #(::core::marker::PhantomData<#replaced_ty_generics>),*
403 )
404 ),
405 }
406 #[doc(hidden)]
407 #[allow(non_camel_case_types)]
408 trait #constraint_expr_trait_ident<#(#impl_generics),*> {}
409 impl<#(#impl_generics,)*__Sumtype_TypeParam> #constraint_expr_trait_ident<#(#ty_generics),*> for __Sumtype_TypeParam
410 where
411 #(for t in &constraint_traits) {
412 __Sumtype_TypeParam: #t<#(#ty_generics),*>,
413 }
414 #(#where_clause,)*
415 {}
416 #(for (trait_, constraint_trait) in args.bounds.iter().zip(&constraint_traits)) {
417 #{ SumTypeImpl::Trait(trait_.clone()).gen(
418 &path_of_ident(enum_ident.clone(), false),
419 unspecified_ty_params.as_slice(),
420 variants.as_slice(),
421 impl_generics.clone(),
422 ty_generics.clone(),
423 where_clause.clone(),
424 constraint_trait,
425 ) }
426 }
427 };
428 (out, self)
429 }
430 }
431}
432
433const _: () = {
434 use syn::visit_mut::VisitMut;
    /// Mutable syntax-tree visitor that replaces `sumtype!` invocations and records them.
    struct Visitor<'a> {
        /// Path used in rewritten code to refer to the generated enum.
        enum_path: &'a Path,
        /// Path used in rewritten code to refer to the generated type-reference trait.
        typeref_path: &'a Path,
        /// Path used in rewritten code to refer to the generated constraint trait.
        constraint_expr_trait_path: &'a Path,
        /// Expression-position invocations found so far, in visiting order.
        found_exprs: Vec<ExprMacroInfo>,
        /// Type-position invocations found so far, in visiting order.
        emit_type: Vec<TypeMacroInfo>,
        /// Generics of the item being processed, if any.
        generics: Option<&'a Generics>,
        /// Whether the processed tree is a module; generated marker types are then
        /// referred to through `super::`.
        is_module: bool,
    }
444
445 impl ProcessTree for Block {
446 fn collect_inline_macro(
447 &mut self,
448 enum_path: &Path,
449 typeref_path: &Path,
450 constraint_expr_trait_path: &Path,
451 generics: Option<&Generics>,
452 is_module: bool,
453 ) -> (Vec<ExprMacroInfo>, Vec<TypeMacroInfo>) {
454 let mut visitor = Visitor::new(
455 enum_path,
456 typeref_path,
457 constraint_expr_trait_path,
458 generics,
459 is_module,
460 );
461 visitor.visit_block_mut(self);
462 (visitor.found_exprs, visitor.emit_type)
463 }
464 }
465
466 impl ProcessTree for Item {
467 fn collect_inline_macro(
468 &mut self,
469 enum_path: &Path,
470 typeref_path: &Path,
471 constraint_expr_trait_path: &Path,
472 generics: Option<&Generics>,
473 is_module: bool,
474 ) -> (Vec<ExprMacroInfo>, Vec<TypeMacroInfo>) {
475 let mut visitor = Visitor::new(
476 enum_path,
477 typeref_path,
478 constraint_expr_trait_path,
479 generics,
480 is_module,
481 );
482 visitor.visit_item_mut(self);
483 (visitor.found_exprs, visitor.emit_type)
484 }
485 }
486
487 impl ProcessTree for Stmt {
488 fn collect_inline_macro(
489 &mut self,
490 enum_path: &Path,
491 typeref_path: &Path,
492 constraint_expr_trait_path: &Path,
493 generics: Option<&Generics>,
494 is_module: bool,
495 ) -> (Vec<ExprMacroInfo>, Vec<TypeMacroInfo>) {
496 let mut visitor = Visitor::new(
497 enum_path,
498 typeref_path,
499 constraint_expr_trait_path,
500 generics,
501 is_module,
502 );
503 visitor.visit_stmt_mut(self);
504 (visitor.found_exprs, visitor.emit_type)
505 }
506 }
507
508 impl<'a> Visitor<'a> {
509 fn new(
510 enum_path: &'a Path,
511 typeref_path: &'a Path,
512 constraint_expr_trait_path: &'a Path,
513 generics: Option<&'a Generics>,
514 is_module: bool,
515 ) -> Self {
516 Self {
517 enum_path,
518 typeref_path,
519 constraint_expr_trait_path,
520 found_exprs: Vec::new(),
521 emit_type: Vec::new(),
522 generics,
523 is_module,
524 }
525 }
526 fn do_type_macro(&mut self, mac: &Macro) -> TokenStream {
527 #[derive(Parse)]
528 struct Arg {
529 #[call(Punctuated::parse_terminated)]
530 generic_args: Punctuated<GenericArgument, Token![,]>,
531 }
532 let arg: Arg = mac
533 .parse_body()
534 .unwrap_or_else(|e| abort!(e.span(), &format!("{}", &e)));
535 let ty_generics = merge_generic_args(
536 self.generics
537 .iter()
538 .flat_map(|g| g.params.iter().cloned().map(generic_param_to_arg)),
539 arg.generic_args.clone(),
540 )
541 .collect::<Vec<_>>();
542 self.emit_type.push(TypeMacroInfo {
543 _span: mac.span(),
544 generic_args: arg.generic_args,
545 });
546 quote! {
547 #{&self.enum_path}
548 #(if !ty_generics.is_empty()){
549 <#(#ty_generics),*>
550 }
551 }
552 }
553
554 fn analyze_lifetime_bounds(
555 &self,
556 generics: &Generics,
557 ty: &Type,
558 ) -> HashMap<Ident, HashSet<Lifetime>> {
559 struct LifetimeVisitor {
560 generic_lifetimes: HashSet<Lifetime>,
561 generic_params: HashSet<Ident>,
562 lifetime_stack: Vec<Lifetime>,
563 result: HashMap<Ident, HashSet<Lifetime>>,
564 }
565 use syn::visit::Visit;
566 impl syn::visit::Visit<'_> for LifetimeVisitor {
567 fn visit_type_reference(&mut self, i: &TypeReference) {
568 if let Some(lt) = &i.lifetime {
569 if self.generic_lifetimes.contains(lt) {
570 self.lifetime_stack.push(lt.clone());
571 syn::visit::visit_type_reference(self, i);
572 self.lifetime_stack.pop();
573 return;
574 }
575 }
576 syn::visit::visit_type_reference(self, i);
577 }
578 fn visit_type_path(&mut self, i: &TypePath) {
579 if i.qself.is_none() {
580 if let Some(id) = i.path.get_ident() {
581 if self.generic_params.contains(id) {
582 self.result
583 .entry(id.clone())
584 .or_default()
585 .extend(self.lifetime_stack.clone());
586 }
587 return;
588 }
589 }
590 syn::visit::visit_type_path(self, i);
591 }
592 }
593 let mut visitor = LifetimeVisitor {
594 generic_lifetimes: generics
595 .params
596 .iter()
597 .filter_map(|p| {
598 if let GenericParam::Lifetime(LifetimeParam { lifetime, .. }) = p {
599 Some(lifetime.clone())
600 } else {
601 None
602 }
603 })
604 .collect(),
605 generic_params: generics
606 .params
607 .iter()
608 .filter_map(|p| {
609 if let GenericParam::Type(TypeParam { ident, .. }) = p {
610 Some(ident.clone())
611 } else {
612 None
613 }
614 })
615 .collect(),
616 lifetime_stack: Vec::new(),
617 result: HashMap::new(),
618 };
619 visitor.visit_type(ty);
620 visitor.result
621 }
622
623 fn do_expr_macro(&mut self, mac: &Macro) -> TokenStream {
624 #[derive(Parse)]
625 struct Arg {
626 expr: Expr,
627 _comma_token: Option<Token![,]>,
628 _for_token: Option<Token![for]>,
629 #[prefix(Option<Token![<]>)]
630 #[postfix(Option<Token![>]>)]
631 #[parse_if(_for_token.is_some())]
632 #[call(Punctuated::parse_separated_nonempty)]
633 for_generics: Option<Punctuated<GenericParam, Token![,]>>,
634 #[parse_if(_comma_token.is_some())]
635 ty: Option<Type>,
636 #[parse_if(_comma_token.is_some())]
637 where_clause: Option<Option<WhereClause>>,
638 }
639 let arg: Arg = mac
640 .parse_body()
641 .unwrap_or_else(|e| abort!(e.span(), &format!("{}", &e)));
642 let n = self.found_exprs.len();
643 let variant_ident = Ident::new(&format!("__SumType_Variant_{}", n), Span::call_site());
644 let reftype_ident = Ident::new(
645 &format!("__SumType_RefType_{}_{}", random(), n),
646 Span::call_site(),
647 );
648 let reftype_path = path_of_ident(reftype_ident.clone(), self.is_module);
649 let id_fn_ident =
650 Ident::new(&format!("__sum_type_id_fn_{}", random()), Span::call_site());
651 let (mut impl_generics, _, where_clause) = split_for_impl(self.generics);
652 let analyzed =
653 if let (Some(generics), Some(ty)) = (self.generics.as_ref(), arg.ty.as_ref()) {
654 self.analyze_lifetime_bounds(generics, ty)
655 } else {
656 HashMap::new()
657 };
658 let generics = Generics {
659 params: arg.for_generics.clone().unwrap_or_default(),
660 where_clause: arg.where_clause.unwrap_or(Some(WhereClause {
661 predicates: Punctuated::new(),
662 where_token: Default::default(),
663 })),
664 ..Default::default()
665 };
666 for g in impl_generics.iter_mut() {
667 if let GenericParam::Type(TypeParam { ident, bounds, .. }) = g {
668 if let Some(lts) = analyzed.get(ident) {
669 for lt in lts {
670 bounds.push(TypeParamBound::Lifetime(lt.clone().clone()));
671 }
672 }
673 }
674 }
675 let impl_generics =
676 merge_generic_params(impl_generics, generics.params.clone()).collect::<Vec<_>>();
677 let ty_generics = impl_generics
678 .iter()
679 .cloned()
680 .map(generic_param_to_arg)
681 .collect::<Vec<_>>();
682 let where_clause = generics
683 .where_clause
684 .clone()
685 .map(|wc| wc.predicates)
686 .into_iter()
687 .flatten()
688 .chain(where_clause)
689 .collect::<Vec<_>>();
690 self.found_exprs.push(ExprMacroInfo {
691 span: mac.span(),
692 variant_ident: variant_ident.clone(),
693 reftype_ident: arg.ty.as_ref().map(|_| reftype_ident.clone()),
694 analyzed_bounds: analyzed.clone(),
695 generics,
696 });
697 quote! {
698 {
699 #(if let Some(ty) = &arg.ty){
700 impl<#(#impl_generics,)*> #{&self.typeref_path} <#(#ty_generics),*> for #reftype_path
701 #(if !where_clause.is_empty()) {
702 where #(#where_clause,)*
703 }
704 {
705 type Type = #ty;
706 }
707 }
708 fn #id_fn_ident<
709 #(#impl_generics,)* __SumType_T: #{&self.constraint_expr_trait_path}<#(#ty_generics),*>
710 >(t: __SumType_T) -> __SumType_T
711 #(if !where_clause.is_empty()) {
712 where #(#where_clause,)*
713 }
714 { t }
715 #id_fn_ident::<#(#ty_generics,)*_>(#{&self.enum_path}::#variant_ident(#{&arg.expr}))
716 }
717 }
718 }
719 }
720
721 impl VisitMut for Visitor<'_> {
722 fn visit_type_mut(&mut self, ty: &mut Type) {
723 if let Type::Macro(tm) = &*ty {
724 if tm.mac.path.is_ident("sumtype") {
725 let out = self.do_type_macro(&tm.mac);
726 *ty = parse2(out).unwrap();
727 return;
728 }
729 }
730 syn::visit_mut::visit_type_mut(self, ty);
731 }
732
733 fn visit_expr_mut(&mut self, expr: &mut Expr) {
734 if let Expr::Macro(em) = &*expr {
735 if em.mac.path.is_ident("sumtype") {
736 let out = self.do_expr_macro(&em.mac);
737 *expr = parse2(out).unwrap();
738 return;
739 }
740 }
741 syn::visit_mut::visit_expr_mut(self, expr);
742 }
743
744 fn visit_stmt_mut(&mut self, stmt: &mut Stmt) {
745 if let Stmt::Macro(sm) = &*stmt {
746 if sm.mac.path.is_ident("sumtype") {
747 let out = self.do_expr_macro(&sm.mac);
748 *stmt = parse2(out).unwrap();
749 return;
750 }
751 }
752 syn::visit_mut::visit_stmt_mut(self, stmt);
753 }
754 }
755};
756
757fn inner(args: &Arguments, input: TokenStream) -> TokenStream {
758 let public = Visibility::Public(Default::default());
759 if let Ok(block) = parse2::<Block>(input.clone()) {
760 let (out, block) = block.emit_items(args, None, false, public);
761 quote! { #out #[allow(non_local_definitions)] #block }
762 } else if let Ok(item_trait) = parse2::<ItemTrait>(input.clone()) {
763 let generics = item_trait.generics.clone();
764 let vis = item_trait.vis.clone();
765 let (out, item) = Item::Trait(item_trait).emit_items(args, Some(&generics), false, vis);
766 quote! { #out #[allow(non_local_definitions)] #item }
767 } else if let Ok(item_impl) = parse2::<ItemImpl>(input.clone()) {
768 let generics = item_impl.generics.clone();
769 let (out, item) = Item::Impl(item_impl).emit_items(args, Some(&generics), false, public);
770 quote! { #out #[allow(non_local_definitions)] #item }
771 } else if let Ok(item_fn) = parse2::<ItemFn>(input.clone()) {
772 let generics = item_fn.sig.generics.clone();
773 let vis = item_fn.vis.clone();
774 let (out, item) = Item::Fn(item_fn).emit_items(args, Some(&generics), false, vis);
775 quote! { #out #[allow(non_local_definitions)] #item }
776 } else if let Ok(item_mod) = parse2::<ItemMod>(input.clone()) {
777 let (out, item) = Item::Mod(item_mod).emit_items(args, None, true, public);
778 quote! { #out #[allow(non_local_definitions)] #item }
779 } else if let Ok(item) = parse2::<Item>(input.clone()) {
780 let (out, item) = item.emit_items(args, None, false, public);
781 quote! { #out #[allow(non_local_definitions)] #item }
782 } else if let Ok(stmt) = parse2::<Stmt>(input.clone()) {
783 let (out, stmt) = stmt.emit_items(args, None, false, public);
784 quote! { #out #[allow(non_local_definitions)] #stmt }
785 } else {
786 abort!(input.span(), "This element is not supported")
787 }
788}
789
790fn process_supported_supertraits<'a>(
791 traits: impl IntoIterator<Item = &'a TypeParamBound>,
792 krate: &Path,
793) -> (Vec<Path>, Vec<Path>) {
794 let mut supertraits = Vec::new();
795 let mut derive_traits = Vec::new();
796 for tpb in traits.into_iter() {
797 if let TypeParamBound::Trait(tb) = tpb {
798 if let Some(ident) = tb.path.get_ident() {
799 match ident.to_string().as_str() {
800 "Copy" | "Clone" | "Hash" | "Eq" => {
801 supertraits.push(parse_quote!(#krate::traits::#ident))
802 }
803 "PartialEq" => derive_traits.push(parse_quote!(PartialEq)),
804 o if o.starts_with("__SumTrait_Sealed") => (),
805 _ => (),
806 }
807 } else {
808 supertraits.push(tb.path.clone())
809 }
810 } else {
811 abort!(tpb.span(), "Only path is supported");
812 }
813 }
814 (supertraits, derive_traits)
815}
816
817fn collect_typeref_types(input: &ItemTrait) -> Vec<Type> {
818 let mut leaker = Leaker::from_trait(input)
819 .unwrap_or_else(|_| Leaker::with_generics(input.generics.clone()));
820 leaker.self_ty_can_be_interned = false;
821 leaker
822 .finish()
823 .iter()
824 .cloned()
825 .collect()
826}
827
828fn sumtrait_impl(
829 args: Option<Path>,
830 marker_path: &Path,
831 krate: &Path,
832 input: ItemTrait,
833) -> TokenStream {
834 let (supertraits, derive_traits) = process_supported_supertraits(&input.supertraits, krate);
835 for item in &input.items {
836 match item {
837 TraitItem::Const(_) => abort!(item.span(), "associated const is not supported"),
838 TraitItem::Fn(tfn) => {
839 if tfn.sig.inputs.is_empty() || !matches!(&tfn.sig.inputs[0], FnArg::Receiver(_)) {
840 abort!(tfn.sig.span(), "requires receiver")
841 }
842 }
843 TraitItem::Type(tty) => {
844 if tty.default.is_some() {
845 abort!(tty.span(), "associated type defaults is not supported")
846 }
847 if !tty.generics.params.is_empty() || tty.generics.where_clause.is_some() {
848 abort!(
849 tty.generics.span(),
850 "generalized associated types is not supported"
851 )
852 }
853 }
854 o => abort!(o.span(), "Not supported"),
855 }
856 }
857 let temporary_mac_name =
858 Ident::new(&format!("__sumtype_macro_{}", random()), Span::call_site());
859 let typeref_types = collect_typeref_types(&input);
860 let (_, _, where_clause) = input.generics.split_for_impl();
861 let typeref_id = random() as usize;
862 quote! {
863 #input
864 #(for (i, ty) in typeref_types.iter().enumerate()) {
865 impl<#(for p in &input.generics.params),{#p}> #krate::TypeRef<#typeref_id, #i> for #marker_path #where_clause {
866 type Type = #ty;
867 }
868 }
869
870 #[doc(hidden)]
871 #[macro_export]
872 macro_rules! #temporary_mac_name {
873 ($($t:tt)*) => {
874 #krate::_sumtrait_internal!(
875 { $($t)* }
876 [#(#typeref_types),*],
877 {#input},
878 #typeref_id,
879 #krate,
880 #marker_path,
881 [#{args.map(|m| quote!(#m)).unwrap_or(quote!(_))}],
882 [#(#supertraits),*],
883 [#(#derive_traits),*],
884 );
885 };
886 }
887 #[doc(hidden)]
888 #{&input.vis} use #temporary_mac_name as #{&input.ident};
889 }
890}
891
892#[doc(hidden)]
893#[proc_macro_error]
894#[proc_macro]
895pub fn _sumtrait_internal(input: TokenStream1) -> TokenStream1 {
896 sumtrait_internal::sumtrait_internal(input.into()).into()
897}
898
899#[proc_macro_error]
900#[proc_macro_attribute]
901pub fn sumtrait(attr: TokenStream1, input: TokenStream1) -> TokenStream1 {
902 #[derive(FromMeta, Debug)]
903 struct SumtraitArgs {
904 implement: Option<Path>,
905 krate: Option<Path>,
906 marker: Path,
907 }
908 let args = SumtraitArgs::from_list(&NestedMeta::parse_meta_list(attr.into()).unwrap()).unwrap();
909
910 let krate = args.krate.unwrap_or(parse_quote!(::sumtype));
911 sumtrait_impl(
912 args.implement,
913 &args.marker,
914 &krate,
915 parse(input).unwrap_or_else(|_| abort!(Span::call_site(), "Requires trait definition")),
916 )
917 .into()
918}
919
920#[proc_macro_error]
983#[proc_macro_attribute]
984pub fn sumtype(attr: TokenStream1, input: TokenStream1) -> TokenStream1 {
985 inner(&parse_macro_input!(attr as Arguments), input.into()).into()
986}