1#![allow(non_snake_case)]
2#![warn(missing_docs)]
3
4use macro_magic::import_tokens_attr;
7use proc_macro::TokenStream;
8use proc_macro2::{TokenStream as TokenStream2, TokenTree};
9use quote::{ToTokens, format_ident, quote};
10use rand::{distr::StandardUniform, prelude::Distribution};
11use std::{
12 cell::RefCell,
13 collections::{HashMap, HashSet},
14 fmt::Debug,
15 sync::atomic::AtomicU64,
16};
17use syn::{
18 Error, GenericParam, Generics, Ident, ImplItem, ImplItemFn, Item, ItemFn, ItemImpl, ItemMod,
19 ItemTrait, Path, Result, Signature, TraitItem, TraitItemFn, TraitItemType, TypePath,
20 Visibility, WherePredicate,
21 parse::{Nothing, Parse, ParseStream},
22 parse_macro_input, parse_quote, parse_str, parse2,
23 punctuated::Punctuated,
24 spanned::Spanned,
25 visit::Visit,
26 visit_mut::VisitMut,
27};
28
29#[cfg(feature = "debug")]
30use proc_utils::PrettyPrint;
31
32mod generic_visitor;
33use generic_visitor::*;
34
/// Running count of `#[impl_supertrait]` expansions; each expansion takes the next value
/// and uses it to give its generated `use … as …` alias a unique name.
static IMPL_COUNT: AtomicU64 = AtomicU64::new(0);

thread_local! {
    /// Path (stored as its string form) that generated code uses to reach the `supertrait`
    /// crate. Defaults to `::supertrait`; replaced by `set_supertrait_path!`.
    static SUPERTRAIT_PATH: RefCell<String> = RefCell::new("::supertrait".to_owned());
}
39
40fn get_supertrait_path() -> Path {
41 SUPERTRAIT_PATH.with(|p| parse_str(p.borrow().clone().as_str()).unwrap())
42}
43
44fn random<T>() -> T
45where
46 StandardUniform: Distribution<T>,
47{
48 rand::random()
49}
50
51#[proc_macro]
59pub fn set_supertrait_path(tokens: TokenStream) -> TokenStream {
60 let path = parse_macro_input!(tokens as Path);
61 SUPERTRAIT_PATH.with(|p| p.replace(path.to_token_stream().to_string()));
62 quote!().into()
63}
64
/// A trait definition split into the item categories that `#[supertrait]` handles
/// differently.
struct SuperTraitDef {
    /// The trait exactly as parsed, still containing all of its items.
    pub orig_trait: ItemTrait,
    /// The trait's `const fn` items.
    pub const_fns: Vec<TraitItemFn>,
    /// Associated types that declare a default (`type Foo = Bar;`).
    pub types_with_defaults: Vec<TraitItemType>,
    /// Every other item (non-`const` fns, associated types without a default, and all
    /// remaining item kinds), in their original order.
    pub other_items: Vec<TraitItem>,
}
71
72impl Parse for SuperTraitDef {
73 fn parse(input: ParseStream) -> Result<Self> {
74 let orig_trait = input.parse::<ItemTrait>()?;
75 let mut const_fns: Vec<TraitItemFn> = Vec::new();
76 let mut types_with_defaults: Vec<TraitItemType> = Vec::new();
77 let mut other_items: Vec<TraitItem> = Vec::new();
78 for trait_item in &orig_trait.items {
79 match trait_item {
80 TraitItem::Fn(trait_item_fn) => match trait_item_fn.sig.constness {
81 Some(_) => const_fns.push(trait_item_fn.clone()),
82 None => other_items.push(trait_item.clone()),
83 },
84 TraitItem::Type(typ) => match typ.default {
85 Some(_) => types_with_defaults.push(typ.clone()),
86 None => other_items.push(trait_item.clone()),
87 },
88 other_item => other_items.push(other_item.clone()),
89 }
90 }
91 Ok(SuperTraitDef {
92 orig_trait,
93 const_fns,
94 types_with_defaults,
95 other_items,
96 })
97 }
98}
99
/// Two renderings of the same (filtered) generic parameter list, as built by
/// `filter_generics`.
struct FilteredGenerics {
    /// Bare parameter names only (e.g. `<'a, T, N>`), for naming an already-generic item.
    use_generics: Generics,
    /// Full parameter declarations (bounds and defaults included), for use after `impl`.
    impl_generics: Generics,
    /// Identifiers of the parameters that `strip_default_generics` removed because they
    /// declared a default; empty until that method is called.
    has_defaults: HashSet<Ident>,
}
109
110impl FilteredGenerics {
111 fn strip_default_generics(&mut self) {
112 let has_defaults = self
114 .impl_generics
115 .params
116 .iter()
117 .filter_map(|g| match g {
118 GenericParam::Lifetime(_) => None,
119 GenericParam::Type(typ) => match typ.default {
120 Some(_) => Some(typ.force_get_ident()),
121 None => None,
122 },
123 GenericParam::Const(constant) => match constant.default {
124 Some(_) => Some(constant.force_get_ident()), None => None,
126 },
127 })
128 .collect::<HashSet<Ident>>();
129 self.impl_generics.params = self
131 .impl_generics
132 .params
133 .iter()
134 .filter(|g| !has_defaults.contains(&g.force_get_ident()))
135 .cloned()
136 .collect();
137 self.use_generics.params = self
139 .use_generics
140 .params
141 .iter()
142 .filter(|g| !has_defaults.contains(&g.force_get_ident()))
143 .cloned()
144 .collect();
145 self.has_defaults = has_defaults;
146 }
147}
148
149fn filter_generics(generics: &Generics, whitelist: &HashSet<GenericUsage>) -> FilteredGenerics {
150 let filtered_generic_params = generics
151 .params
152 .iter()
153 .cloned()
154 .filter(|g| whitelist.contains(&g.into()));
155
156 let filtered_where_clause = match &generics.where_clause {
158 Some(where_clause) => {
159 let mut where_clause = where_clause.clone();
160 let predicates_filtered = where_clause.predicates.iter().filter(|p| match *p {
161 WherePredicate::Lifetime(lifetime) => {
162 whitelist.contains(&GenericUsage::from_lifetime(&lifetime.lifetime))
163 }
164 WherePredicate::Type(typ) => {
165 whitelist.contains(&GenericUsage::from_type(&typ.bounded_ty))
166 }
167 _ => unimplemented!(),
168 });
169 where_clause.predicates = parse_quote!(#(#predicates_filtered),*);
170 Some(where_clause)
171 }
172 None => None,
173 };
174
175 let use_generic_params = filtered_generic_params.clone().map(|g| match g {
176 GenericParam::Lifetime(lifetime) => lifetime.lifetime.to_token_stream(),
177 GenericParam::Type(typ) => typ.ident.to_token_stream(),
178 GenericParam::Const(constant) => constant.ident.to_token_stream(),
179 });
180
181 let use_generics = Generics {
182 lt_token: parse_quote!(<),
183 params: parse_quote!(#(#use_generic_params),*),
184 gt_token: parse_quote!(>),
185 where_clause: filtered_where_clause.clone(),
186 };
187
188 let impl_generics = Generics {
189 lt_token: parse_quote!(<),
190 params: parse_quote!(#(#filtered_generic_params),*),
191 gt_token: parse_quote!(>),
192 where_clause: filtered_where_clause,
193 };
194
195 FilteredGenerics {
196 use_generics,
197 impl_generics,
198 has_defaults: HashSet::new(),
199 }
200}
201
202struct DefaultRemover;
203
204impl VisitMut for DefaultRemover {
205 fn visit_generics_mut(&mut self, generics: &mut Generics) {
206 for param in &mut generics.params {
207 if let syn::GenericParam::Type(type_param) = param {
208 type_param.default = None;
209 }
210 }
211 }
212}
213
214trait PunctuatedExtension<T: ToTokens + Clone, P: ToTokens + std::default::Default>: Sized {
215 fn push_front(&mut self, value: T);
216}
217
218impl<T: ToTokens + Clone, P: ToTokens + std::default::Default> PunctuatedExtension<T, P>
219 for Punctuated<T, P>
220{
221 fn push_front(&mut self, value: T) {
222 let mut new_punctuated = Punctuated::new();
223 new_punctuated.push(value);
224 new_punctuated.extend(self.iter().cloned());
225 *self = new_punctuated;
226 }
227}
228
229#[proc_macro_attribute]
411pub fn supertrait(attr: TokenStream, tokens: TokenStream) -> TokenStream {
412 let mut attr = attr;
413 let debug = debug_feature(&mut attr);
414 match supertrait_internal(attr, tokens, debug) {
415 Ok(tokens) => tokens.into(),
416 Err(err) => err.into_compile_error().into(),
417 }
418}
419
420fn debug_feature(attr: &mut TokenStream) -> bool {
421 if let Ok(ident) = syn::parse::<Ident>(attr.clone()) {
422 if ident == "debug" {
423 *attr = TokenStream::new();
424 if cfg!(feature = "debug") {
425 true
426 } else {
427 println!(
428 "warning: the 'debug' feature must be enabled for debug to work on supertrait attributes."
429 );
430 false
431 }
432 } else {
433 false
434 }
435 } else {
436 false
437 }
438}
439
/// Implementation of `#[supertrait]`.
///
/// `attr` must be empty; the `supertrait` entry point strips any `debug` argument before
/// calling this. `tokens` is the trait definition. The emitted module is named after the
/// trait and contains:
///
/// * `Defaults` and the `DefaultTypes` trait. These carry the trait's associated-type
///   defaults, with `Self` rewritten to the extra leading generic parameter `__Self`.
/// * A sealing trait whose name ends in a random number.
/// * The trait itself. It is renamed to `Trait`, made `pub` and sealed. Defaults are
///   stripped from its associated types and `const` is removed from its `const fn`s.
/// * An `exported_tokens` module, exported (not emitted) through `macro_magic` for use by
///   `#[impl_supertrait]`. Its item order must match what `ImportedTokens::try_from` expects.
///
/// `debug` is only read when the `debug` feature is enabled; the expansion is then
/// pretty-printed.
///
/// # Errors
///
/// Returns an error if `attr` is not empty or `tokens` does not parse as a trait.
fn supertrait_internal(
    attr: impl Into<TokenStream2>,
    tokens: impl Into<TokenStream2>,
    #[allow(unused)] debug: bool,
) -> Result<TokenStream2> {
    parse2::<Nothing>(attr.into())?;
    let def = parse2::<SuperTraitDef>(tokens.into())?;
    // Name under which `exported_tokens` is exported; `impl_supertrait` builds the same name.
    let export_tokens_ident = format_ident!("{}_exported_tokens", def.orig_trait.ident);
    // Start from the trait with its `const fn`s and defaulted types removed; they are
    // re-added in modified form below.
    let mut modified_trait = def.orig_trait;
    modified_trait.items = def.other_items;
    let ident = modified_trait.ident.clone();
    let attrs = modified_trait.attrs.clone();
    let mut defaults = def.types_with_defaults;
    // The defaulted associated types as bare declarations (bounds kept, default removed).
    let unfilled_defaults = defaults
        .iter()
        .cloned()
        .map(|mut typ| {
            typ.default = None;
            typ
        })
        .collect::<Vec<_>>();
    // Record which trait generics the defaults use, and rewrite `Self` to `__Self` in them.
    let mut visitor = FindGenericParam::new(&modified_trait.generics);
    let mut replace_self = ReplaceSelfType {
        replace_type: parse_quote!(__Self),
    };
    for trait_item_type in &mut defaults {
        visitor.visit_trait_item_type(trait_item_type);
        replace_self.visit_trait_item_type_mut(trait_item_type);
    }

    // Generics of `DefaultTypes`: `__Self` followed by the trait generics the defaults use.
    let mut default_generics = filter_generics(&modified_trait.generics, &visitor.usages);
    default_generics
        .impl_generics
        .params
        .push_front(parse_quote!(__Self));
    default_generics
        .use_generics
        .params
        .push_front(parse_quote!(__Self));

    let default_impl_generics = default_generics.impl_generics;
    let default_use_generics = default_generics.use_generics;

    modified_trait.ident = parse_quote!(Trait);

    // The trait's generic parameters as bare names.
    let trait_use_generic_params = modified_trait.generics.params.iter().map(|g| match g {
        GenericParam::Lifetime(lifetime) => lifetime.lifetime.to_token_stream(),
        GenericParam::Type(typ) => typ.ident.to_token_stream(),
        GenericParam::Const(constant) => constant.ident.to_token_stream(),
    });

    let trait_impl_generics = modified_trait.generics.clone();
    let trait_use_generics = Generics {
        lt_token: parse_quote!(<),
        params: parse_quote!(#(#trait_use_generic_params),*),
        gt_token: parse_quote!(>),
        where_clause: modified_trait.generics.where_clause.clone(),
    };

    let const_fns = def.const_fns;
    // Empty marker fns whose only purpose is to carry a `Generics` value through
    // `exported_tokens`.
    let mut trait_impl_generics_fn: ItemFn = parse_quote! { fn trait_impl_generics() {} };
    let mut trait_use_generics_fn: ItemFn = parse_quote! { fn trait_use_generics() {} };
    let mut default_impl_generics_fn: ItemFn = parse_quote! { fn default_impl_generics() {} };
    let mut default_use_generics_fn: ItemFn = parse_quote! { fn default_use_generics() {} };
    trait_impl_generics_fn.sig.generics = trait_impl_generics.clone();
    trait_use_generics_fn.sig.generics = trait_use_generics;
    default_impl_generics_fn.sig.generics = default_impl_generics.clone();
    default_use_generics_fn.sig.generics = default_use_generics.clone();

    // The trait keeps the defaulted types as plain declarations ...
    modified_trait
        .items
        .extend(unfilled_defaults.iter().map(|item| parse_quote!(#item)));

    // ... and the `const fn`s as ordinary (non-`const`) fns.
    let converted_const_fns = const_fns.iter().map(|const_fn| {
        let mut item = const_fn.clone();
        item.sig.constness = None;
        let item: TraitItem = parse_quote!(#item);
        item
    });
    modified_trait.items.extend(converted_const_fns);

    let supertrait_path = get_supertrait_path();

    // Seal the trait with a randomly named supertrait; `#[impl_supertrait]` implements it.
    // NOTE(review): the random suffix makes this name differ between compilations.
    let random_value: u32 = random();
    let sealed_ident = format_ident!("SupertraitSealed{random_value}");
    let sealed_trait: ItemTrait = parse_quote!(pub trait #sealed_ident {});
    modified_trait.supertraits.push(parse_quote!(#sealed_ident));

    // A copy of the `DefaultTypes` generics for the impl header, where defaults are not
    // allowed.
    let mut default_remover = DefaultRemover {};
    let mut default_impl_generics_no_defaults = default_impl_generics.clone();
    default_remover.visit_generics_mut(&mut default_impl_generics_no_defaults);

    modified_trait.vis = parse_quote!(pub);

    // The defaults are emitted as impl items below, which cannot carry bounds.
    for def in defaults.iter_mut() {
        def.bounds.clear()
    }

    let output = quote! {
        #(#attrs)*
        #[allow(non_snake_case)]
        pub mod #ident {
            use super::*;

            pub struct Defaults;

            pub trait DefaultTypes #default_impl_generics {
                #[doc(hidden)]
                type __Self;
                #(#unfilled_defaults)*
            }

            impl #default_impl_generics_no_defaults DefaultTypes #default_use_generics for Defaults {
                #(#defaults)*
                #[doc(hidden)]
                type __Self = ();
            }

            #[doc(hidden)]
            #sealed_trait

            #(#attrs)*
            #modified_trait

            // The item order inside this module is what `ImportedTokens::try_from` parses.
            #[#supertrait_path::__private::macro_magic::export_tokens_no_emit(#export_tokens_ident)]
            mod exported_tokens {
                trait ConstFns {
                    #(#const_fns)*
                }

                #trait_impl_generics_fn
                #trait_use_generics_fn
                #default_impl_generics_fn
                #default_use_generics_fn

                mod default_items {
                    #(#defaults)*
                }

                const #sealed_ident: () = ();
            }
        }
    };
    #[cfg(feature = "debug")]
    if debug {
        output.pretty_print();
    }
    Ok(output)
}
599
600#[doc(hidden)]
601#[import_tokens_attr(format!(
602 "{}::__private::macro_magic",
603 get_supertrait_path().to_token_stream().to_string()
604))]
605#[proc_macro_attribute]
606pub fn __impl_supertrait(attr: TokenStream, tokens: TokenStream) -> TokenStream {
607 match impl_supertrait_internal(attr, tokens) {
608 Ok(tokens) => tokens.into(),
609 Err(err) => err.into_compile_error().into(),
610 }
611}
612
613#[proc_macro_attribute]
759pub fn impl_supertrait(attr: TokenStream, tokens: TokenStream) -> TokenStream {
760 let mut attr = attr;
761 let debug = debug_feature(&mut attr);
762 parse_macro_input!(attr as Nothing);
763 let item_impl = parse_macro_input!(tokens as ItemImpl);
764 let trait_being_impled = match item_impl.trait_.clone() {
765 Some((_, path, _)) => path,
766 None => return Error::new(
767 item_impl.span(),
768 "Supertrait impls must have a trait being implemented. Inherent impls are not supported."
769 ).into_compile_error().into(),
770 }.strip_trailing_generics();
771 let supertrait_path = get_supertrait_path();
772 let export_tokens_ident = format_ident!(
773 "{}_exported_tokens",
774 trait_being_impled.segments.last().unwrap().ident
775 );
776 let debug_tokens = if debug {
777 quote!(#[debug_mode])
778 } else {
779 quote!()
780 };
781 let output = quote! {
782 #[#supertrait_path::__impl_supertrait(#trait_being_impled::#export_tokens_ident)]
783 #debug_tokens
784 #item_impl
785 };
786 output.into()
787}
788
/// The contents of the `exported_tokens` module that `#[supertrait]` emits for a trait,
/// parsed back while expanding an impl of that trait.
struct ImportedTokens {
    /// The `const fn` items of the exported `ConstFns` trait.
    const_fns: Vec<TraitItemFn>,
    /// Generics carried by the `trait_impl_generics` marker fn (the trait's full generics).
    trait_impl_generics: Generics,
    /// Generics carried by the `trait_use_generics` marker fn (bare parameter names).
    #[allow(unused)]
    trait_use_generics: Generics,
    /// Generics carried by the `default_impl_generics` marker fn.
    #[allow(unused)]
    default_impl_generics: Generics,
    /// Generics carried by the `default_use_generics` marker fn.
    default_use_generics: Generics,
    /// The items of the `default_items` module, re-parsed as trait items.
    default_items: Vec<TraitItem>,
    /// Name of the trait's sealing supertrait (taken from the final `const` item).
    sealed_ident: Ident,
}
800
801impl TryFrom<ItemMod> for ImportedTokens {
802 type Error = Error;
803
804 fn try_from(item_mod: ItemMod) -> std::result::Result<Self, Self::Error> {
805 if item_mod.ident != "exported_tokens" {
806 return Err(Error::new(
807 item_mod.ident.span(),
808 "expected `exported_tokens`.",
809 ));
810 }
811 let item_mod_span = item_mod.span();
812 let Some((_, main_body)) = item_mod.content else {
813 return Err(Error::new(
814 item_mod_span,
815 "`exported_tokens` module must have a defined body.",
816 ));
817 };
818 let Some(Item::Trait(ItemTrait {
819 ident: const_fns_ident,
820 items: const_fns,
821 ..
822 })) = main_body.get(0)
823 else {
824 return Err(Error::new(
825 item_mod_span,
826 "the first item in `exported_tokens` should be a trait called `ConstFns`.",
827 ));
828 };
829 if const_fns_ident != "ConstFns" {
830 return Err(Error::new(const_fns_ident.span(), "expected `ConstFns`."));
831 }
832
833 let const_fns: Vec<TraitItemFn> = const_fns
834 .into_iter()
835 .map(|item| match item {
836 TraitItem::Fn(trait_item_fn) => Ok(trait_item_fn.clone()),
837 _ => return Err(Error::new(item.span(), "expected `fn`")), })
839 .collect::<std::result::Result<_, Self::Error>>()?;
840
841 let Some(Item::Fn(ItemFn {
842 sig:
843 Signature {
844 ident: trait_impl_generics_ident,
845 generics: trait_impl_generics,
846 ..
847 },
848 ..
849 })) = main_body.get(1)
850 else {
851 return Err(Error::new(
852 item_mod_span,
853 "the second item in `exported_tokens` should be an fn called `trait_impl_generics`.",
854 ));
855 };
856 if trait_impl_generics_ident != "trait_impl_generics" {
857 return Err(Error::new(
858 trait_impl_generics_ident.span(),
859 "expected `trait_impl_generics`.",
860 ));
861 }
862
863 let Some(Item::Fn(ItemFn {
864 sig:
865 Signature {
866 ident: trait_use_generics_ident,
867 generics: trait_use_generics,
868 ..
869 },
870 ..
871 })) = main_body.get(2)
872 else {
873 return Err(Error::new(
874 item_mod_span,
875 "the third item in `exported_tokens` should be an fn called `trait_use_generics`.",
876 ));
877 };
878 if trait_use_generics_ident != "trait_use_generics" {
879 return Err(Error::new(
880 trait_use_generics_ident.span(),
881 "expected `trait_use_generics`.",
882 ));
883 }
884
885 let Some(Item::Fn(ItemFn {
886 sig:
887 Signature {
888 ident: default_impl_generics_ident,
889 generics: default_impl_generics,
890 ..
891 },
892 ..
893 })) = main_body.get(3)
894 else {
895 return Err(Error::new(
896 item_mod_span,
897 "the fourth item in `exported_tokens` should be an fn called `default_impl_generics`.",
898 ));
899 };
900 if default_impl_generics_ident != "default_impl_generics" {
901 return Err(Error::new(
902 default_impl_generics_ident.span(),
903 "expected `default_impl_generics`.",
904 ));
905 }
906
907 let Some(Item::Fn(ItemFn {
908 sig:
909 Signature {
910 ident: default_use_generics_ident,
911 generics: default_use_generics,
912 ..
913 },
914 ..
915 })) = main_body.get(4)
916 else {
917 return Err(Error::new(
918 item_mod_span,
919 "the fifth item in `exported_tokens` should be an fn called `default_use_generics`.",
920 ));
921 };
922 if default_use_generics_ident != "default_use_generics" {
923 return Err(Error::new(
924 default_use_generics_ident.span(),
925 "expected `default_use_generics`.",
926 ));
927 }
928
929 let Some(Item::Mod(default_items_mod)) = main_body.get(5) else {
930 return Err(Error::new(
931 item_mod_span,
932 "the sixth item in `exported_tokens` should be a module called `default_items_mod`.",
933 ));
934 };
935 if default_items_mod.ident != "default_items" {
936 return Err(Error::new(
937 default_items_mod.ident.span(),
938 "expected `default_items`.",
939 ));
940 }
941 let Some((_, default_items)) = default_items_mod.content.clone() else {
942 return Err(Error::new(
943 default_items_mod.ident.span(),
944 "`default_items` item must be an inline module.",
945 ));
946 };
947 let default_items: Vec<TraitItem> = default_items
948 .iter()
949 .map(|item| parse_quote!(#item))
950 .collect();
951
952 let Some(Item::Const(sealed_const)) = main_body.get(6) else {
953 return Err(Error::new(
954 item_mod_span,
955 "the seventh item in `exported_tokens` should be a const specifying the sealed ident.",
956 ));
957 };
958 let sealed_ident = sealed_const.ident.clone();
959
960 Ok(ImportedTokens {
961 const_fns,
962 trait_impl_generics: trait_impl_generics.clone(),
963 trait_use_generics: trait_use_generics.clone(),
964 default_impl_generics: default_impl_generics.clone(),
965 default_use_generics: default_use_generics.clone(),
966 default_items: default_items,
967 sealed_ident,
968 })
969 }
970}
971
/// Looks up the name of an associated item.
trait GetIdent {
    /// Returns the item's identifier. The implementations in this file return `None` for
    /// every item kind other than const, fn and type.
    fn get_ident(&self) -> Option<Ident>;
}
975
976impl GetIdent for TraitItem {
977 fn get_ident(&self) -> Option<Ident> {
978 use TraitItem::*;
979 match self {
980 Const(item_const) => Some(item_const.ident.clone()),
981 Fn(item_fn) => Some(item_fn.sig.ident.clone()),
982 Type(item_type) => Some(item_type.ident.clone()),
983 _ => None,
984 }
985 }
986}
987
988impl GetIdent for ImplItem {
989 fn get_ident(&self) -> Option<Ident> {
990 use ImplItem::*;
991 match self {
992 Const(item_const) => Some(item_const.ident.clone()),
993 Fn(item_fn) => Some(item_fn.sig.ident.clone()),
994 Type(item_type) => Some(item_type.ident.clone()),
995 _ => None,
996 }
997 }
998}
999
/// Removes delimiter groups from a token stream.
trait FlattenGroups {
    /// Returns a copy of the stream in which every delimited group is replaced, recursively,
    /// by its contents wrapped in `<` and `>` punctuation. The result contains no groups.
    fn flatten_groups(&self) -> TokenStream2;
}
1003
1004impl FlattenGroups for TokenStream2 {
1005 fn flatten_groups(&self) -> TokenStream2 {
1006 let mut iter = self.clone().into_iter();
1007 let mut final_tokens = TokenStream2::new();
1008 while let Some(token) = iter.next() {
1009 if let TokenTree::Group(group) = &token {
1010 let flattened = group.stream().flatten_groups();
1011 final_tokens.extend(quote!(<#flattened>));
1012 continue;
1013 }
1014 final_tokens.extend([token]);
1015 }
1016 final_tokens
1017 }
1018}
1019
1020trait ForceGetIdent: ToTokens {
1021 fn force_get_ident(&self) -> Ident {
1022 let mut iter = self.to_token_stream().flatten_groups().into_iter();
1023 let mut final_tokens = TokenStream2::new();
1024 while let Some(token) = iter.next() {
1025 let mut tmp = final_tokens.clone();
1026 tmp.extend([token.clone()]);
1027 if parse2::<Ident>(tmp).is_ok() {
1028 final_tokens.extend([token]);
1029 }
1030 }
1031 parse_quote!(#final_tokens)
1032 }
1033}
1034
/// Removes the generic arguments from the final segment of a path.
trait StripTrailingGenerics {
    /// Returns a copy whose last segment is reduced to its bare identifier, e.g. `a::Foo<T>`
    /// becomes `a::Foo`. Earlier segments are left untouched.
    fn strip_trailing_generics(&self) -> Self;
}
1038
1039impl StripTrailingGenerics for Path {
1040 fn strip_trailing_generics(&self) -> Self {
1041 let mut tmp = self.clone();
1042 let Some(last) = tmp.segments.last_mut() else {
1043 unreachable!()
1044 };
1045 let ident = last.ident.clone();
1046 *last = parse_quote!(#ident);
1047 tmp
1048 }
1049}
1050
1051impl<T: ToTokens> ForceGetIdent for T {}
1052
1053fn merge_generics(a: &Generics, b: &Generics) -> Generics {
1054 let mut params = b.params.clone();
1055 params.extend(a.params.clone());
1056 let where_clause = match &a.where_clause {
1057 Some(a_where) => match &b.where_clause {
1058 Some(b_where) => {
1059 let mut combined_where = b_where.clone();
1060 combined_where.predicates.extend(a_where.predicates.clone());
1061 Some(combined_where)
1062 }
1063 None => a.where_clause.clone(),
1064 },
1065 None => b.where_clause.clone(),
1066 };
1067 Generics {
1068 lt_token: b.lt_token.clone(),
1069 params,
1070 gt_token: b.gt_token.clone(),
1071 where_clause: where_clause,
1072 }
1073}
1074
1075struct ReplaceSelfAssociatedType {
1076 replace_prefix: TokenStream2,
1077}
1078
1079impl VisitMut for ReplaceSelfAssociatedType {
1080 fn visit_type_path_mut(&mut self, type_path: &mut TypePath) {
1081 if type_path.path.segments.len() < 2 {
1082 return;
1083 }
1084 let first_seg = type_path.path.segments.first().unwrap();
1085 if first_seg.ident != "Self" {
1086 return;
1087 }
1088 let segments = type_path.path.segments.iter().skip(1);
1089 let replace = self.replace_prefix.clone();
1090 *type_path = parse_quote!(#replace::#(#segments)::*)
1091 }
1092}
1093
1094struct ReplaceType {
1095 search_type: RemappedGeneric,
1096 replace_type: GenericParam,
1097}
1098
1099impl VisitMut for ReplaceType {
1100 fn visit_ident_mut(&mut self, ident: &mut Ident) {
1101 let search_ident = self.search_type.ident();
1102 if ident != search_ident {
1103 return;
1104 }
1105 *ident = self.replace_type.force_get_ident();
1106 }
1107}
1108
1109struct ReplaceSelfType {
1110 replace_type: Ident,
1111}
1112
1113impl VisitMut for ReplaceSelfType {
1114 fn visit_ident_mut(&mut self, ident: &mut Ident) {
1115 if ident != "Self" {
1116 return;
1117 }
1118 *ident = self.replace_type.clone();
1119 }
1120}
1121
1122#[derive(Clone, PartialEq, Eq, Hash, Debug, PartialOrd, Ord)]
1123enum RemappedGeneric {
1124 Lifetime(Ident),
1125 Type(Ident),
1126 Const(Ident),
1127}
1128
1129impl RemappedGeneric {
1130 fn ident(&self) -> &Ident {
1131 match self {
1132 RemappedGeneric::Lifetime(ident) => ident,
1133 RemappedGeneric::Type(ident) => ident,
1134 RemappedGeneric::Const(ident) => ident,
1135 }
1136 }
1137}
1138
1139fn impl_supertrait_internal(
1140 foreign_tokens: impl Into<TokenStream2>,
1141 item_tokens: impl Into<TokenStream2>,
1142) -> Result<TokenStream2> {
1143 let mut item_impl = parse2::<ItemImpl>(item_tokens.into())?;
1144 #[cfg(feature = "debug")]
1145 let mut debug = false;
1146 #[cfg(feature = "debug")]
1147 for (i, attr) in item_impl.attrs.iter().enumerate() {
1148 let Some(ident) = attr.path().get_ident() else {
1149 continue;
1150 };
1151 if ident == "debug_mode" {
1152 debug = true;
1153 item_impl.attrs.remove(i);
1154 break;
1155 }
1156 }
1157 let Some((_, trait_path, _)) = &mut item_impl.trait_ else {
1158 return Err(Error::new(
1159 item_impl.span(),
1160 "#[impl_supertrait] can only be attached to non-inherent impls involving a trait \
1161 that has `#[supertrait]` attached to it.",
1162 ));
1163 };
1164 let impl_target = item_impl.self_ty.clone();
1165
1166 let ImportedTokens {
1167 const_fns,
1168 trait_impl_generics,
1169 trait_use_generics: _,
1170 default_impl_generics,
1171 default_use_generics,
1172 default_items,
1173 sealed_ident,
1174 } = ImportedTokens::try_from(parse2::<ItemMod>(foreign_tokens.into())?)?;
1175
1176 let mut remapped_type_params: HashMap<RemappedGeneric, GenericParam> = HashMap::new();
1177 for (i, param) in trait_impl_generics.params.iter().enumerate() {
1178 let remapped = match param {
1179 GenericParam::Lifetime(lifetime) => {
1180 RemappedGeneric::Lifetime(lifetime.lifetime.ident.clone())
1181 }
1182 GenericParam::Type(typ) => RemappedGeneric::Type(typ.ident.clone()),
1183 GenericParam::Const(constant) => RemappedGeneric::Const(constant.ident.clone()),
1184 };
1185 let last_seg = trait_path.segments.last().unwrap();
1186 let args: Vec<TokenStream2> = match last_seg.arguments.clone() {
1187 syn::PathArguments::None => continue,
1188 syn::PathArguments::AngleBracketed(args) => {
1189 args.args.into_iter().map(|a| a.to_token_stream()).collect()
1190 }
1191 syn::PathArguments::Parenthesized(args) => args
1192 .inputs
1193 .into_iter()
1194 .map(|a| a.to_token_stream())
1195 .collect(),
1196 };
1197
1198 if i >= args.len() {
1199 continue;
1200 }
1201 let target = args[i].clone();
1202 let target: GenericParam = parse_quote!(#target);
1203 remapped_type_params.insert(remapped, target);
1204 }
1205
1206 remapped_type_params.insert(
1208 RemappedGeneric::Type(parse_quote!(__Self)),
1209 parse_quote!(#impl_target),
1210 );
1211
1212 let trait_mod = trait_path.clone().strip_trailing_generics();
1213 let trait_mod_ident = trait_mod.segments.last().unwrap().ident.clone();
1214 trait_path.segments.insert(
1215 trait_path.segments.len() - 1,
1216 parse_quote!(#trait_mod_ident),
1217 );
1218 trait_path.segments.last_mut().unwrap().ident = parse_quote!(Trait);
1219
1220 let mut filtered_tmp = FilteredGenerics {
1222 impl_generics: default_impl_generics,
1223 use_generics: default_use_generics,
1224 has_defaults: HashSet::new(),
1225 };
1226 filtered_tmp.strip_default_generics();
1227 let default_use_generics = filtered_tmp.use_generics;
1228
1229 let mut final_items: HashMap<Ident, ImplItem> = HashMap::new();
1230 for item in default_items {
1231 let item_ident = item.get_ident().unwrap();
1232 let mut item: ImplItem = parse_quote!(#item);
1233 use ImplItem::*;
1234 match &mut item {
1235 Const(item_const) => {
1236 item_const.expr = parse_quote!(<#trait_mod::Defaults as #trait_mod::DefaultTypes #default_use_generics>::#item_ident)
1237 }
1238 Type(item_type) => {
1239 item_type.ty = parse_quote!(<#trait_mod::Defaults as #trait_mod::DefaultTypes #default_use_generics>::#item_ident)
1240 }
1241 _ => unimplemented!("this item has no notion of defaults"),
1242 }
1243 for search in remapped_type_params.keys() {
1244 let replace = &remapped_type_params[search];
1245 let mut visitor = ReplaceType {
1246 search_type: search.clone(),
1247 replace_type: replace.clone(),
1248 };
1249 visitor.visit_impl_item_mut(&mut item);
1250 }
1251 if item_ident == "__Self" {
1252 item = parse_quote!(#impl_target);
1253 }
1254 final_items.insert(item_ident, item);
1255 }
1256 let mut final_verbatim_items: Vec<ImplItem> = Vec::new();
1257 for item in &item_impl.items {
1258 let Some(item_ident) = item.get_ident() else {
1259 final_verbatim_items.push(item.clone());
1260 continue;
1261 };
1262 final_items.insert(item_ident, item.clone());
1263 }
1264
1265 let mut final_items = final_items.values().cloned().collect::<Vec<_>>();
1266 final_items.extend(final_verbatim_items);
1267 item_impl.items = final_items;
1268
1269 let mut impl_const_fn_idents: HashSet<Ident> = HashSet::new();
1270 let mut impl_const_fns: Vec<ImplItem>;
1271 (impl_const_fns, item_impl.items) = item_impl.items.into_iter().partition(|item| {
1272 let ImplItem::Fn(impl_item_fn) = item else {
1273 return false;
1274 };
1275 if impl_item_fn.sig.constness.is_none() {
1276 return false;
1277 };
1278 impl_const_fn_idents.insert(impl_item_fn.sig.ident.clone());
1279 true
1280 });
1281
1282 for item in &const_fns {
1283 if !impl_const_fn_idents.contains(&item.sig.ident) {
1284 if item.default.is_none() {
1285 return Err(Error::new(
1286 item_impl.span(),
1287 format!("missing impl for `{}`.", item.sig.ident),
1288 ));
1289 }
1290 impl_const_fns.push(parse_quote!(#item));
1291 }
1292 }
1293 for const_fn in impl_const_fns.iter_mut() {
1294 let mut last_seg = trait_path.segments.last().unwrap().clone();
1295 last_seg.ident = parse_quote!(Trait);
1296 let mut visitor = ReplaceSelfAssociatedType {
1297 replace_prefix: quote!(<#impl_target as #trait_mod::#last_seg>),
1298 };
1299 visitor.visit_impl_item_mut(const_fn);
1300 let ImplItem::Fn(const_fn) = const_fn else {
1301 unreachable!()
1302 };
1303 const_fn.vis = parse_quote!(pub);
1304 }
1305
1306 let impl_index = IMPL_COUNT.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
1307 let trait_import_name: Ident = format_ident!(
1308 "{}{}TraitImpl_{}",
1309 impl_target.clone().force_get_ident(),
1310 item_impl.trait_.clone().unwrap().1.force_get_ident(),
1311 impl_index,
1312 );
1313
1314 let converted_const_fns = impl_const_fns.iter().map(|const_fn| {
1315 let mut const_fn: ImplItemFn = parse_quote!(#const_fn);
1316 const_fn.sig.constness = None;
1317 const_fn.vis = Visibility::Inherited;
1318 let item: ImplItem = parse_quote!(#const_fn);
1319 item
1320 });
1321
1322 let impl_const_fns = impl_const_fns.iter().map(|const_fn| {
1323 let mut const_fn_visitor = FindGenericParam::new(&trait_impl_generics);
1324 const_fn_visitor.visit_impl_item(const_fn);
1325 let const_fn_generics =
1326 filter_generics(&trait_impl_generics, &const_fn_visitor.usages).impl_generics;
1327 let mut const_fn: ImplItemFn = parse_quote!(#const_fn);
1328 const_fn.sig.generics = merge_generics(&const_fn_generics, &const_fn.sig.generics);
1329 const_fn.sig.generics.params = const_fn
1330 .sig
1331 .generics
1332 .params
1333 .iter()
1334 .cloned()
1335 .map(|g| match g {
1336 GenericParam::Lifetime(lifetime) => GenericParam::Lifetime(lifetime),
1337 GenericParam::Type(typ) => {
1338 let mut typ = typ.clone();
1339 typ.default = None;
1340 GenericParam::Type(typ)
1341 }
1342 GenericParam::Const(constant) => {
1343 let mut constant = constant.clone();
1344 constant.default = None;
1345 GenericParam::Const(constant)
1346 }
1347 })
1348 .collect();
1349 const_fn
1350 });
1351
1352 item_impl.items.extend(converted_const_fns);
1353 let mut impl_visitor = FindGenericParam::new(&item_impl.generics);
1354 impl_visitor.visit_item_impl(&item_impl);
1355 let mut filtered_generics = filter_generics(&item_impl.generics, &impl_visitor.usages);
1356 filtered_generics.strip_default_generics();
1357 item_impl.generics = filtered_generics.impl_generics;
1358
1359 let inherent_impl = if impl_const_fns.len() > 0 {
1360 Some(quote! {
1361 impl #impl_target {
1363 #(#impl_const_fns)*
1364 }
1365 })
1366 } else {
1367 None
1368 };
1369
1370 let output = quote! {
1371 #item_impl
1372
1373 impl #trait_mod::#sealed_ident for #impl_target {}
1374
1375 #inherent_impl
1376
1377 #[doc(hidden)]
1378 #[allow(unused)]
1379 use #trait_mod::Trait as #trait_import_name;
1380 };
1381 #[cfg(feature = "debug")]
1382 if debug {
1383 output.pretty_print();
1384 }
1385 Ok(output)
1386}