1#![allow(non_snake_case)]
2#![warn(missing_docs)]
3
4use macro_magic::import_tokens_attr;
7use proc_macro::TokenStream;
8use proc_macro2::{TokenStream as TokenStream2, TokenTree};
9use quote::{format_ident, quote, ToTokens};
10use rand::{distr::StandardUniform, prelude::Distribution};
11use std::{
12 cell::RefCell,
13 collections::{HashMap, HashSet},
14 fmt::Debug,
15 sync::atomic::AtomicU64,
16};
17use syn::{
18 parse::{Nothing, Parse, ParseStream},
19 parse2, parse_macro_input, parse_quote, parse_str,
20 punctuated::Punctuated,
21 spanned::Spanned,
22 visit::Visit,
23 visit_mut::VisitMut,
24 Error, GenericParam, Generics, Ident, ImplItem, ImplItemFn, Item, ItemFn, ItemImpl, ItemMod,
25 ItemTrait, Path, Result, Signature, TraitItem, TraitItemFn, TraitItemType, TypePath,
26 Visibility, WherePredicate,
27};
28
29#[cfg(feature = "debug")]
30use proc_utils::PrettyPrint;
31
32mod generic_visitor;
33use generic_visitor::*;
34
/// Number of impls expanded so far by `impl_supertrait_internal`; used as a suffix to keep every
/// generated trait-import alias (`<Type><Trait>TraitImpl_<n>`) unique.
static IMPL_COUNT: AtomicU64 = AtomicU64::new(0);
thread_local! {
    /// Path (stored as a token string) at which generated code expects the `supertrait` crate.
    /// Defaults to `::supertrait` and is overwritten by `set_supertrait_path!`.
    static SUPERTRAIT_PATH: RefCell<String> = RefCell::new(String::from("::supertrait"));
}
39
40fn get_supertrait_path() -> Path {
41 SUPERTRAIT_PATH.with(|p| parse_str(p.borrow().clone().as_str()).unwrap())
42}
43
44fn random<T>() -> T
45where
46 StandardUniform: Distribution<T>,
47{
48 rand::random()
49}
50
51#[proc_macro]
59pub fn set_supertrait_path(tokens: TokenStream) -> TokenStream {
60 let path = parse_macro_input!(tokens as Path);
61 SUPERTRAIT_PATH.with(|p| p.replace(path.to_token_stream().to_string()));
62 quote!().into()
63}
64
/// A trait definition annotated with `#[supertrait]`, with its items sorted into the groups
/// that `supertrait_internal` handles differently.
struct SuperTraitDef {
    /// The trait exactly as written by the user (all items included).
    pub orig_trait: ItemTrait,
    /// `const fn` items (their signature has `constness` set).
    pub const_fns: Vec<TraitItemFn>,
    /// Associated types that declare a default value (`type Foo = Bar;`).
    pub types_with_defaults: Vec<TraitItemType>,
    /// Everything else: non-const fns, associated types without defaults, consts, macros, ...
    pub other_items: Vec<TraitItem>,
}
71
72impl Parse for SuperTraitDef {
73 fn parse(input: ParseStream) -> Result<Self> {
74 let orig_trait = input.parse::<ItemTrait>()?;
75 let mut const_fns: Vec<TraitItemFn> = Vec::new();
76 let mut types_with_defaults: Vec<TraitItemType> = Vec::new();
77 let mut other_items: Vec<TraitItem> = Vec::new();
78 for trait_item in &orig_trait.items {
79 match trait_item {
80 TraitItem::Fn(trait_item_fn) => match trait_item_fn.sig.constness {
81 Some(_) => const_fns.push(trait_item_fn.clone()),
82 None => other_items.push(trait_item.clone()),
83 },
84 TraitItem::Type(typ) => match typ.default {
85 Some(_) => types_with_defaults.push(typ.clone()),
86 None => other_items.push(trait_item.clone()),
87 },
88 other_item => other_items.push(other_item.clone()),
89 }
90 }
91 Ok(SuperTraitDef {
92 orig_trait,
93 const_fns,
94 types_with_defaults,
95 other_items,
96 })
97 }
98}
99
/// Two renderings of the same (filtered) generic parameter list, as built by `filter_generics`.
struct FilteredGenerics {
    /// "Use" position: bare parameter names only (e.g. `<'a, T, N>`), with the filtered
    /// where clause attached.
    use_generics: Generics,
    /// "Impl"/declaration position: the full parameter declarations (bounds, defaults), with
    /// the same filtered where clause.
    impl_generics: Generics,
    /// Idents removed by `strip_default_generics` because they declared a default; empty until
    /// that method is called.
    has_defaults: HashSet<Ident>,
}
109
110impl FilteredGenerics {
111 fn strip_default_generics(&mut self) {
112 let has_defaults = self
114 .impl_generics
115 .params
116 .iter()
117 .filter_map(|g| match g {
118 GenericParam::Lifetime(_) => None,
119 GenericParam::Type(typ) => match typ.default {
120 Some(_) => Some(typ.force_get_ident()),
121 None => None,
122 },
123 GenericParam::Const(constant) => match constant.default {
124 Some(_) => Some(constant.force_get_ident()), None => None,
126 },
127 })
128 .collect::<HashSet<Ident>>();
129 self.impl_generics.params = self
131 .impl_generics
132 .params
133 .iter()
134 .filter(|g| !has_defaults.contains(&g.force_get_ident()))
135 .cloned()
136 .collect();
137 self.use_generics.params = self
139 .use_generics
140 .params
141 .iter()
142 .filter(|g| !has_defaults.contains(&g.force_get_ident()))
143 .cloned()
144 .collect();
145 self.has_defaults = has_defaults;
146 }
147}
148
149fn filter_generics(generics: &Generics, whitelist: &HashSet<GenericUsage>) -> FilteredGenerics {
150 let filtered_generic_params = generics
151 .params
152 .iter()
153 .cloned()
154 .filter(|g| whitelist.contains(&g.into()));
155
156 let filtered_where_clause = match &generics.where_clause {
158 Some(where_clause) => {
159 let mut where_clause = where_clause.clone();
160 let predicates_filtered = where_clause.predicates.iter().filter(|p| match *p {
161 WherePredicate::Lifetime(lifetime) => {
162 whitelist.contains(&GenericUsage::from_lifetime(&lifetime.lifetime))
163 }
164 WherePredicate::Type(typ) => {
165 whitelist.contains(&GenericUsage::from_type(&typ.bounded_ty))
166 }
167 _ => unimplemented!(),
168 });
169 where_clause.predicates = parse_quote!(#(#predicates_filtered),*);
170 Some(where_clause)
171 }
172 None => None,
173 };
174
175 let use_generic_params = filtered_generic_params.clone().map(|g| match g {
176 GenericParam::Lifetime(lifetime) => lifetime.lifetime.to_token_stream(),
177 GenericParam::Type(typ) => typ.ident.to_token_stream(),
178 GenericParam::Const(constant) => constant.ident.to_token_stream(),
179 });
180
181 let use_generics = Generics {
182 lt_token: parse_quote!(<),
183 params: parse_quote!(#(#use_generic_params),*),
184 gt_token: parse_quote!(>),
185 where_clause: filtered_where_clause.clone(),
186 };
187
188 let impl_generics = Generics {
189 lt_token: parse_quote!(<),
190 params: parse_quote!(#(#filtered_generic_params),*),
191 gt_token: parse_quote!(>),
192 where_clause: filtered_where_clause,
193 };
194
195 FilteredGenerics {
196 use_generics,
197 impl_generics,
198 has_defaults: HashSet::new(),
199 }
200}
201
202struct DefaultRemover;
203
204impl VisitMut for DefaultRemover {
205 fn visit_generics_mut(&mut self, generics: &mut Generics) {
206 for param in &mut generics.params {
207 if let syn::GenericParam::Type(ref mut type_param) = param {
208 type_param.default = None;
209 }
210 }
211 }
212}
213
214trait PunctuatedExtension<T: ToTokens + Clone, P: ToTokens + std::default::Default>: Sized {
215 fn push_front(&mut self, value: T);
216}
217
218impl<T: ToTokens + Clone, P: ToTokens + std::default::Default> PunctuatedExtension<T, P>
219 for Punctuated<T, P>
220{
221 fn push_front(&mut self, value: T) {
222 let mut new_punctuated = Punctuated::new();
223 new_punctuated.push(value);
224 new_punctuated.extend(self.iter().cloned());
225 *self = new_punctuated;
226 }
227}
228
229#[proc_macro_attribute]
411pub fn supertrait(attr: TokenStream, tokens: TokenStream) -> TokenStream {
412 let mut attr = attr;
413 let debug = debug_feature(&mut attr);
414 match supertrait_internal(attr, tokens, debug) {
415 Ok(tokens) => tokens.into(),
416 Err(err) => err.into_compile_error().into(),
417 }
418}
419
420fn debug_feature(attr: &mut TokenStream) -> bool {
421 if let Ok(ident) = syn::parse::<Ident>(attr.clone()) {
422 if ident == "debug" {
423 *attr = TokenStream::new();
424 if cfg!(feature = "debug") {
425 true
426 } else {
427 println!("warning: the 'debug' feature must be enabled for debug to work on supertrait attributes.");
428 false
429 }
430 } else {
431 false
432 }
433 } else {
434 false
435 }
436}
437
438fn supertrait_internal(
439 attr: impl Into<TokenStream2>,
440 tokens: impl Into<TokenStream2>,
441 #[allow(unused)] debug: bool,
442) -> Result<TokenStream2> {
443 parse2::<Nothing>(attr.into())?;
444 let def = parse2::<SuperTraitDef>(tokens.into())?;
445 let export_tokens_ident = format_ident!("{}_exported_tokens", def.orig_trait.ident);
446 let mut modified_trait = def.orig_trait;
447 modified_trait.items = def.other_items;
448 let ident = modified_trait.ident.clone();
449 let attrs = modified_trait.attrs.clone();
450 let mut defaults = def.types_with_defaults;
451 let unfilled_defaults = defaults
452 .iter()
453 .cloned()
454 .map(|mut typ| {
455 typ.default = None;
456 typ
457 })
458 .collect::<Vec<_>>();
459 let mut visitor = FindGenericParam::new(&modified_trait.generics);
460 let mut replace_self = ReplaceSelfType {
461 replace_type: parse_quote!(__Self),
462 };
463 for trait_item_type in &mut defaults {
464 visitor.visit_trait_item_type(trait_item_type);
465 replace_self.visit_trait_item_type_mut(trait_item_type);
466 }
467
468 let mut default_generics = filter_generics(&modified_trait.generics, &visitor.usages);
469 default_generics
470 .impl_generics
471 .params
472 .push_front(parse_quote!(__Self));
473 default_generics
474 .use_generics
475 .params
476 .push_front(parse_quote!(__Self));
477
478 let default_impl_generics = default_generics.impl_generics;
479 let default_use_generics = default_generics.use_generics;
480
481 modified_trait.ident = parse_quote!(Trait);
482
483 let trait_use_generic_params = modified_trait.generics.params.iter().map(|g| match g {
484 GenericParam::Lifetime(lifetime) => lifetime.lifetime.to_token_stream(),
485 GenericParam::Type(typ) => typ.ident.to_token_stream(),
486 GenericParam::Const(constant) => constant.ident.to_token_stream(),
487 });
488
489 let trait_impl_generics = modified_trait.generics.clone();
490 let trait_use_generics = Generics {
491 lt_token: parse_quote!(<),
492 params: parse_quote!(#(#trait_use_generic_params),*),
493 gt_token: parse_quote!(>),
494 where_clause: modified_trait.generics.where_clause.clone(),
495 };
496
497 let const_fns = def.const_fns;
498 let mut trait_impl_generics_fn: ItemFn = parse_quote! { fn trait_impl_generics() {} };
499 let mut trait_use_generics_fn: ItemFn = parse_quote! { fn trait_use_generics() {} };
500 let mut default_impl_generics_fn: ItemFn = parse_quote! { fn default_impl_generics() {} };
501 let mut default_use_generics_fn: ItemFn = parse_quote! { fn default_use_generics() {} };
502 trait_impl_generics_fn.sig.generics = trait_impl_generics.clone();
503 trait_use_generics_fn.sig.generics = trait_use_generics;
504 default_impl_generics_fn.sig.generics = default_impl_generics.clone();
505 default_use_generics_fn.sig.generics = default_use_generics.clone();
506
507 modified_trait
508 .items
509 .extend(unfilled_defaults.iter().map(|item| parse_quote!(#item)));
510
511 let converted_const_fns = const_fns.iter().map(|const_fn| {
512 let mut item = const_fn.clone();
513 item.sig.constness = None;
514 let item: TraitItem = parse_quote!(#item);
515 item
516 });
517 modified_trait.items.extend(converted_const_fns);
518
519 let supertrait_path = get_supertrait_path();
520
521 let random_value: u32 = random();
523 let sealed_ident = format_ident!("SupertraitSealed{random_value}");
524 let sealed_trait: ItemTrait = parse_quote!(pub trait #sealed_ident {});
525 modified_trait.supertraits.push(parse_quote!(#sealed_ident));
526
527 let mut default_remover = DefaultRemover {};
528 let mut default_impl_generics_no_defaults = default_impl_generics.clone();
529 default_remover.visit_generics_mut(&mut default_impl_generics_no_defaults);
530
531 modified_trait.vis = parse_quote!(pub);
532
533 for def in defaults.iter_mut() {
534 def.bounds.clear()
535 }
536
537 let output = quote! {
538 #(#attrs)*
539 #[allow(non_snake_case)]
540 pub mod #ident {
541 use super::*;
542
543 pub struct Defaults;
545
546 pub trait DefaultTypes #default_impl_generics {
550 #[doc(hidden)]
552 type __Self;
553 #(#unfilled_defaults)*
554 }
555
556 impl #default_impl_generics_no_defaults DefaultTypes #default_use_generics for Defaults {
557 #(#defaults)*
558 #[doc(hidden)]
559 type __Self = ();
560 }
561
562 #[doc(hidden)]
565 #sealed_trait
566
567 #(#attrs)*
570 #modified_trait
571
572 #[#supertrait_path::__private::macro_magic::export_tokens_no_emit(#export_tokens_ident)]
573 mod exported_tokens {
574 trait ConstFns {
575 #(#const_fns)*
576 }
577
578 #trait_impl_generics_fn
579 #trait_use_generics_fn
580 #default_impl_generics_fn
581 #default_use_generics_fn
582
583 mod default_items {
584 #(#defaults)*
585 }
586
587 const #sealed_ident: () = ();
588 }
589 }
590 };
591 #[cfg(feature = "debug")]
592 if debug {
593 output.pretty_print();
594 }
595 Ok(output)
596}
597
598#[doc(hidden)]
599#[import_tokens_attr(format!(
600 "{}::__private::macro_magic",
601 get_supertrait_path().to_token_stream().to_string()
602))]
603#[proc_macro_attribute]
604pub fn __impl_supertrait(attr: TokenStream, tokens: TokenStream) -> TokenStream {
605 match impl_supertrait_internal(attr, tokens) {
606 Ok(tokens) => tokens.into(),
607 Err(err) => err.into_compile_error().into(),
608 }
609}
610
611#[proc_macro_attribute]
757pub fn impl_supertrait(attr: TokenStream, tokens: TokenStream) -> TokenStream {
758 let mut attr = attr;
759 let debug = debug_feature(&mut attr);
760 parse_macro_input!(attr as Nothing);
761 let item_impl = parse_macro_input!(tokens as ItemImpl);
762 let trait_being_impled = match item_impl.trait_.clone() {
763 Some((_, path, _)) => path,
764 None => return Error::new(
765 item_impl.span(),
766 "Supertrait impls must have a trait being implemented. Inherent impls are not supported."
767 ).into_compile_error().into(),
768 }.strip_trailing_generics();
769 let supertrait_path = get_supertrait_path();
770 let export_tokens_ident = format_ident!(
771 "{}_exported_tokens",
772 trait_being_impled.segments.last().unwrap().ident
773 );
774 let debug_tokens = if debug {
775 quote!(#[debug_mode])
776 } else {
777 quote!()
778 };
779 let output = quote! {
780 #[#supertrait_path::__impl_supertrait(#trait_being_impled::#export_tokens_ident)]
781 #debug_tokens
782 #item_impl
783 };
784 output.into()
785}
786
/// Data recovered from a trait's `exported_tokens` module (emitted by `supertrait_internal`
/// and decoded by the `TryFrom<ItemMod>` impl).
struct ImportedTokens {
    /// The trait's `const fn` items, taken from `trait ConstFns`.
    const_fns: Vec<TraitItemFn>,
    /// The trait's generics in declaration form (from `fn trait_impl_generics`).
    trait_impl_generics: Generics,
    /// The trait's generics as bare names (from `fn trait_use_generics`); currently unused.
    #[allow(unused)]
    trait_use_generics: Generics,
    /// Generics of the `DefaultTypes` helper trait (`__Self` first) in declaration form;
    /// currently unused.
    #[allow(unused)]
    default_impl_generics: Generics,
    /// Generics of `DefaultTypes` as bare names (`__Self` first).
    default_use_generics: Generics,
    /// The trait's defaulted associated types (with `Self` rewritten to `__Self`), re-parsed
    /// as trait items.
    default_items: Vec<TraitItem>,
    /// Name of the trait's generated sealed supertrait.
    sealed_ident: Ident,
}
798
799impl TryFrom<ItemMod> for ImportedTokens {
800 type Error = Error;
801
802 fn try_from(item_mod: ItemMod) -> std::result::Result<Self, Self::Error> {
803 if item_mod.ident != "exported_tokens" {
804 return Err(Error::new(
805 item_mod.ident.span(),
806 "expected `exported_tokens`.",
807 ));
808 }
809 let item_mod_span = item_mod.span();
810 let Some((_, main_body)) = item_mod.content else {
811 return Err(Error::new(
812 item_mod_span,
813 "`exported_tokens` module must have a defined body.",
814 ));
815 };
816 let Some(Item::Trait(ItemTrait {
817 ident: const_fns_ident,
818 items: const_fns,
819 ..
820 })) = main_body.get(0)
821 else {
822 return Err(Error::new(
823 item_mod_span,
824 "the first item in `exported_tokens` should be a trait called `ConstFns`.",
825 ));
826 };
827 if const_fns_ident != "ConstFns" {
828 return Err(Error::new(const_fns_ident.span(), "expected `ConstFns`."));
829 }
830
831 let const_fns: Vec<TraitItemFn> = const_fns
832 .into_iter()
833 .map(|item| match item {
834 TraitItem::Fn(trait_item_fn) => Ok(trait_item_fn.clone()),
835 _ => return Err(Error::new(item.span(), "expected `fn`")), })
837 .collect::<std::result::Result<_, Self::Error>>()?;
838
839 let Some(Item::Fn(ItemFn {
840 sig:
841 Signature {
842 ident: trait_impl_generics_ident,
843 generics: trait_impl_generics,
844 ..
845 },
846 ..
847 })) = main_body.get(1)
848 else {
849 return Err(Error::new(
850 item_mod_span,
851 "the second item in `exported_tokens` should be an fn called `trait_impl_generics`.",
852 ));
853 };
854 if trait_impl_generics_ident != "trait_impl_generics" {
855 return Err(Error::new(
856 trait_impl_generics_ident.span(),
857 "expected `trait_impl_generics`.",
858 ));
859 }
860
861 let Some(Item::Fn(ItemFn {
862 sig:
863 Signature {
864 ident: trait_use_generics_ident,
865 generics: trait_use_generics,
866 ..
867 },
868 ..
869 })) = main_body.get(2)
870 else {
871 return Err(Error::new(
872 item_mod_span,
873 "the third item in `exported_tokens` should be an fn called `trait_use_generics`.",
874 ));
875 };
876 if trait_use_generics_ident != "trait_use_generics" {
877 return Err(Error::new(
878 trait_use_generics_ident.span(),
879 "expected `trait_use_generics`.",
880 ));
881 }
882
883 let Some(Item::Fn(ItemFn {
884 sig:
885 Signature {
886 ident: default_impl_generics_ident,
887 generics: default_impl_generics,
888 ..
889 },
890 ..
891 })) = main_body.get(3)
892 else {
893 return Err(Error::new(
894 item_mod_span,
895 "the fourth item in `exported_tokens` should be an fn called `default_impl_generics`.",
896 ));
897 };
898 if default_impl_generics_ident != "default_impl_generics" {
899 return Err(Error::new(
900 default_impl_generics_ident.span(),
901 "expected `default_impl_generics`.",
902 ));
903 }
904
905 let Some(Item::Fn(ItemFn {
906 sig:
907 Signature {
908 ident: default_use_generics_ident,
909 generics: default_use_generics,
910 ..
911 },
912 ..
913 })) = main_body.get(4)
914 else {
915 return Err(Error::new(
916 item_mod_span,
917 "the fifth item in `exported_tokens` should be an fn called `default_use_generics`.",
918 ));
919 };
920 if default_use_generics_ident != "default_use_generics" {
921 return Err(Error::new(
922 default_use_generics_ident.span(),
923 "expected `default_use_generics`.",
924 ));
925 }
926
927 let Some(Item::Mod(default_items_mod)) = main_body.get(5) else {
928 return Err(Error::new(
929 item_mod_span,
930 "the sixth item in `exported_tokens` should be a module called `default_items_mod`.",
931 ));
932 };
933 if default_items_mod.ident != "default_items" {
934 return Err(Error::new(
935 default_items_mod.ident.span(),
936 "expected `default_items`.",
937 ));
938 }
939 let Some((_, default_items)) = default_items_mod.content.clone() else {
940 return Err(Error::new(
941 default_items_mod.ident.span(),
942 "`default_items` item must be an inline module.",
943 ));
944 };
945 let default_items: Vec<TraitItem> = default_items
946 .iter()
947 .map(|item| parse_quote!(#item))
948 .collect();
949
950 let Some(Item::Const(sealed_const)) = main_body.get(6) else {
951 return Err(Error::new(
952 item_mod_span,
953 "the seventh item in `exported_tokens` should be a const specifying the sealed ident.",
954 ));
955 };
956 let sealed_ident = sealed_const.ident.clone();
957
958 Ok(ImportedTokens {
959 const_fns,
960 trait_impl_generics: trait_impl_generics.clone(),
961 trait_use_generics: trait_use_generics.clone(),
962 default_impl_generics: default_impl_generics.clone(),
963 default_use_generics: default_use_generics.clone(),
964 default_items: default_items,
965 sealed_ident,
966 })
967 }
968}
969
970trait GetIdent {
971 fn get_ident(&self) -> Option<Ident>;
972}
973
974impl GetIdent for TraitItem {
975 fn get_ident(&self) -> Option<Ident> {
976 use TraitItem::*;
977 match self {
978 Const(item_const) => Some(item_const.ident.clone()),
979 Fn(item_fn) => Some(item_fn.sig.ident.clone()),
980 Type(item_type) => Some(item_type.ident.clone()),
981 _ => None,
982 }
983 }
984}
985
986impl GetIdent for ImplItem {
987 fn get_ident(&self) -> Option<Ident> {
988 use ImplItem::*;
989 match self {
990 Const(item_const) => Some(item_const.ident.clone()),
991 Fn(item_fn) => Some(item_fn.sig.ident.clone()),
992 Type(item_type) => Some(item_type.ident.clone()),
993 _ => None,
994 }
995 }
996}
997
998trait FlattenGroups {
999 fn flatten_groups(&self) -> TokenStream2;
1000}
1001
1002impl FlattenGroups for TokenStream2 {
1003 fn flatten_groups(&self) -> TokenStream2 {
1004 let mut iter = self.clone().into_iter();
1005 let mut final_tokens = TokenStream2::new();
1006 while let Some(token) = iter.next() {
1007 if let TokenTree::Group(group) = &token {
1008 let flattened = group.stream().flatten_groups();
1009 final_tokens.extend(quote!(<#flattened>));
1010 continue;
1011 }
1012 final_tokens.extend([token]);
1013 }
1014 final_tokens
1015 }
1016}
1017
1018trait ForceGetIdent: ToTokens {
1019 fn force_get_ident(&self) -> Ident {
1020 let mut iter = self.to_token_stream().flatten_groups().into_iter();
1021 let mut final_tokens = TokenStream2::new();
1022 while let Some(token) = iter.next() {
1023 let mut tmp = final_tokens.clone();
1024 tmp.extend([token.clone()]);
1025 if parse2::<Ident>(tmp).is_ok() {
1026 final_tokens.extend([token]);
1027 }
1028 }
1029 parse_quote!(#final_tokens)
1030 }
1031}
1032
1033trait StripTrailingGenerics {
1034 fn strip_trailing_generics(&self) -> Self;
1035}
1036
1037impl StripTrailingGenerics for Path {
1038 fn strip_trailing_generics(&self) -> Self {
1039 let mut tmp = self.clone();
1040 let Some(last) = tmp.segments.last_mut() else {
1041 unreachable!()
1042 };
1043 let ident = last.ident.clone();
1044 *last = parse_quote!(#ident);
1045 tmp
1046 }
1047}
1048
// Blanket impl: every syntax node that can be converted to tokens gets `force_get_ident`.
impl<T: ToTokens> ForceGetIdent for T {}
1050
1051fn merge_generics(a: &Generics, b: &Generics) -> Generics {
1052 let mut params = b.params.clone();
1053 params.extend(a.params.clone());
1054 let where_clause = match &a.where_clause {
1055 Some(a_where) => match &b.where_clause {
1056 Some(b_where) => {
1057 let mut combined_where = b_where.clone();
1058 combined_where.predicates.extend(a_where.predicates.clone());
1059 Some(combined_where)
1060 }
1061 None => a.where_clause.clone(),
1062 },
1063 None => b.where_clause.clone(),
1064 };
1065 Generics {
1066 lt_token: b.lt_token.clone(),
1067 params,
1068 gt_token: b.gt_token.clone(),
1069 where_clause: where_clause,
1070 }
1071}
1072
1073struct ReplaceSelfAssociatedType {
1074 replace_prefix: TokenStream2,
1075}
1076
1077impl VisitMut for ReplaceSelfAssociatedType {
1078 fn visit_type_path_mut(&mut self, type_path: &mut TypePath) {
1079 if type_path.path.segments.len() < 2 {
1080 return;
1081 }
1082 let first_seg = type_path.path.segments.first().unwrap();
1083 if first_seg.ident != "Self" {
1084 return;
1085 }
1086 let segments = type_path.path.segments.iter().skip(1);
1087 let replace = self.replace_prefix.clone();
1088 *type_path = parse_quote!(#replace::#(#segments)::*)
1089 }
1090}
1091
1092struct ReplaceType {
1093 search_type: RemappedGeneric,
1094 replace_type: GenericParam,
1095}
1096
1097impl VisitMut for ReplaceType {
1098 fn visit_ident_mut(&mut self, ident: &mut Ident) {
1099 let search_ident = self.search_type.ident();
1100 if ident != search_ident {
1101 return;
1102 }
1103 *ident = self.replace_type.force_get_ident();
1104 }
1105}
1106
1107struct ReplaceSelfType {
1108 replace_type: Ident,
1109}
1110
1111impl VisitMut for ReplaceSelfType {
1112 fn visit_ident_mut(&mut self, ident: &mut Ident) {
1113 if ident != "Self" {
1114 return;
1115 }
1116 *ident = self.replace_type.clone();
1117 }
1118}
1119
1120#[derive(Clone, PartialEq, Eq, Hash, Debug, PartialOrd, Ord)]
1121enum RemappedGeneric {
1122 Lifetime(Ident),
1123 Type(Ident),
1124 Const(Ident),
1125}
1126
1127impl RemappedGeneric {
1128 fn ident(&self) -> &Ident {
1129 match self {
1130 RemappedGeneric::Lifetime(ident) => ident,
1131 RemappedGeneric::Type(ident) => ident,
1132 RemappedGeneric::Const(ident) => ident,
1133 }
1134 }
1135}
1136
/// Implementation of `#[impl_supertrait]` (reached via `__impl_supertrait`).
///
/// `foreign_tokens` is the `exported_tokens` module imported from the trait's definition site.
/// `item_tokens` is the user's `impl Trait for Type` block. The output contains:
/// * The impl, retargeted from `path::Foo<..>` to `path::Foo::Trait<..>`. Every defaulted
///   associated type the user did not provide is filled in from `Defaults`/`DefaultTypes`,
///   and `const fn`s are replaced by non-const copies.
/// * An impl of the trait's sealed supertrait for the target type.
/// * An inherent impl holding `pub const fn` versions (user-provided or trait defaults).
/// * A hidden `use ... as <Type><Trait>TraitImpl_<n>` so that the trait is in scope.
///
/// # Errors
/// Returns an error if:
/// * the impl is inherent;
/// * the imported tokens are malformed;
/// * a `const fn` without a default is not implemented.
fn impl_supertrait_internal(
    foreign_tokens: impl Into<TokenStream2>,
    item_tokens: impl Into<TokenStream2>,
) -> Result<TokenStream2> {
    let mut item_impl = parse2::<ItemImpl>(item_tokens.into())?;
    // Remove the `#[debug_mode]` marker added by `impl_supertrait`, remembering that it was set.
    #[cfg(feature = "debug")]
    let mut debug = false;
    #[cfg(feature = "debug")]
    for (i, attr) in item_impl.attrs.iter().enumerate() {
        let Some(ident) = attr.path().get_ident() else {
            continue;
        };
        if ident == "debug_mode" {
            debug = true;
            item_impl.attrs.remove(i);
            break;
        }
    }
    let Some((_, trait_path, _)) = &mut item_impl.trait_ else {
        return Err(Error::new(
            item_impl.span(),
            "#[impl_supertrait] can only be attached to non-inherent impls involving a trait \
            that has `#[supertrait]` attached to it.",
        ));
    };
    let impl_target = item_impl.self_ty.clone();

    let ImportedTokens {
        const_fns,
        trait_impl_generics,
        trait_use_generics: _,
        default_impl_generics,
        default_use_generics,
        default_items,
        sealed_ident,
    } = ImportedTokens::try_from(parse2::<ItemMod>(foreign_tokens.into())?)?;

    // Map each trait generic parameter to the argument at the same position in the impl's
    // trait path. Parameters without a corresponding argument are left unmapped.
    // NOTE(review): positions are matched directly, so omitted (elided) lifetime arguments
    // shift every later mapping.
    let mut remapped_type_params: HashMap<RemappedGeneric, GenericParam> = HashMap::new();
    for (i, param) in trait_impl_generics.params.iter().enumerate() {
        let remapped = match param {
            GenericParam::Lifetime(lifetime) => {
                RemappedGeneric::Lifetime(lifetime.lifetime.ident.clone())
            }
            GenericParam::Type(typ) => RemappedGeneric::Type(typ.ident.clone()),
            GenericParam::Const(constant) => RemappedGeneric::Const(constant.ident.clone()),
        };
        let last_seg = trait_path.segments.last().unwrap();
        let args: Vec<TokenStream2> = match last_seg.arguments.clone() {
            syn::PathArguments::None => continue,
            syn::PathArguments::AngleBracketed(args) => {
                args.args.into_iter().map(|a| a.to_token_stream()).collect()
            }
            syn::PathArguments::Parenthesized(args) => args
                .inputs
                .into_iter()
                .map(|a| a.to_token_stream())
                .collect(),
        };

        if i >= args.len() {
            continue;
        }
        let target = args[i].clone();
        let target: GenericParam = parse_quote!(#target);
        remapped_type_params.insert(remapped, target);
    }

    // `__Self` (the `Self` stand-in used by `DefaultTypes`) maps to the impl's target type.
    remapped_type_params.insert(
        RemappedGeneric::Type(parse_quote!(__Self)),
        parse_quote!(#impl_target),
    );

    // `trait_mod` is the generated module (`path::Foo`); the impl is retargeted to the real
    // trait inside it (`path::Foo::Trait<..>`).
    let trait_mod = trait_path.clone().strip_trailing_generics();
    let trait_mod_ident = trait_mod.segments.last().unwrap().ident.clone();
    trait_path.segments.insert(
        trait_path.segments.len() - 1,
        parse_quote!(#trait_mod_ident),
    );
    trait_path.segments.last_mut().unwrap().ident = parse_quote!(Trait);

    let mut filtered_tmp = FilteredGenerics {
        impl_generics: default_impl_generics,
        use_generics: default_use_generics,
        has_defaults: HashSet::new(),
    };
    filtered_tmp.strip_default_generics();
    let default_use_generics = filtered_tmp.use_generics;

    // Start from the trait's defaulted items, each pointing at the value provided by
    // `<Defaults as DefaultTypes<..>>`, then let the user's own items override them by name.
    let mut final_items: HashMap<Ident, ImplItem> = HashMap::new();
    for item in default_items {
        let item_ident = item.get_ident().unwrap();
        let mut item: ImplItem = parse_quote!(#item);
        use ImplItem::*;
        match &mut item {
            Const(item_const) => {
                item_const.expr = parse_quote!(<#trait_mod::Defaults as #trait_mod::DefaultTypes #default_use_generics>::#item_ident)
            }
            Type(item_type) => {
                item_type.ty = parse_quote!(<#trait_mod::Defaults as #trait_mod::DefaultTypes #default_use_generics>::#item_ident)
            }
            _ => unimplemented!("this item has no notion of defaults"),
        }
        for search in remapped_type_params.keys() {
            let replace = &remapped_type_params[search];
            let mut visitor = ReplaceType {
                search_type: search.clone(),
                replace_type: replace.clone(),
            };
            visitor.visit_impl_item_mut(&mut item);
        }
        // NOTE(review): parsing a type as an `ImplItem` looks like it would panic. This branch
        // is presumably unreachable because `default_items` only holds the trait's own
        // defaulted types; confirm.
        if item_ident == "__Self" {
            item = parse_quote!(#impl_target);
        }
        final_items.insert(item_ident, item);
    }
    let mut final_verbatim_items: Vec<ImplItem> = Vec::new();
    for item in &item_impl.items {
        let Some(item_ident) = item.get_ident() else {
            final_verbatim_items.push(item.clone());
            continue;
        };
        final_items.insert(item_ident, item.clone());
    }

    // NOTE(review): `HashMap` iteration order is unspecified, so the order of items in the
    // emitted impl may differ between compilations.
    let mut final_items = final_items.values().cloned().collect::<Vec<_>>();
    final_items.extend(final_verbatim_items);
    item_impl.items = final_items;

    // Pull the user's `const fn`s out of the trait impl, remembering their names.
    let mut impl_const_fn_idents: HashSet<Ident> = HashSet::new();
    let mut impl_const_fns: Vec<ImplItem>;
    (impl_const_fns, item_impl.items) = item_impl.items.into_iter().partition(|item| {
        let ImplItem::Fn(impl_item_fn) = item else {
            return false;
        };
        if impl_item_fn.sig.constness.is_none() {
            return false;
        };
        impl_const_fn_idents.insert(impl_item_fn.sig.ident.clone());
        true
    });

    // Trait const fns the user did not implement fall back to their default body; a missing
    // one without a default is an error.
    for item in &const_fns {
        if !impl_const_fn_idents.contains(&item.sig.ident) {
            if item.default.is_none() {
                return Err(Error::new(
                    item_impl.span(),
                    format!("missing impl for `{}`.", item.sig.ident),
                ));
            }
            impl_const_fns.push(parse_quote!(#item));
        }
    }
    // In the inherent impl, `Self::Assoc` must be spelled `<Target as path::Foo::Trait<..>>::Assoc`.
    for const_fn in impl_const_fns.iter_mut() {
        let mut last_seg = trait_path.segments.last().unwrap().clone();
        last_seg.ident = parse_quote!(Trait);
        let mut visitor = ReplaceSelfAssociatedType {
            replace_prefix: quote!(<#impl_target as #trait_mod::#last_seg>),
        };
        visitor.visit_impl_item_mut(const_fn);
        let ImplItem::Fn(const_fn) = const_fn else {
            unreachable!()
        };
        const_fn.vis = parse_quote!(pub);
    }

    // Alias name for the hidden trait import; made unique by the global impl counter.
    let impl_index = IMPL_COUNT.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
    let trait_import_name: Ident = format_ident!(
        "{}{}TraitImpl_{}",
        impl_target.clone().force_get_ident(),
        item_impl.trait_.clone().unwrap().1.force_get_ident(),
        impl_index,
    );

    // Non-const, private copies that go back into the trait impl.
    let converted_const_fns = impl_const_fns.iter().map(|const_fn| {
        let mut const_fn: ImplItemFn = parse_quote!(#const_fn);
        const_fn.sig.constness = None;
        const_fn.vis = Visibility::Inherited;
        let item: ImplItem = parse_quote!(#const_fn);
        item
    });

    // `pub const fn` versions for the inherent impl. Each one also declares the trait generics
    // it uses (without defaults), since the inherent impl itself has no generics.
    let impl_const_fns = impl_const_fns.iter().map(|const_fn| {
        let mut const_fn_visitor = FindGenericParam::new(&trait_impl_generics);
        const_fn_visitor.visit_impl_item(const_fn);
        let const_fn_generics =
            filter_generics(&trait_impl_generics, &const_fn_visitor.usages).impl_generics;
        let mut const_fn: ImplItemFn = parse_quote!(#const_fn);
        const_fn.sig.generics = merge_generics(&const_fn_generics, &const_fn.sig.generics);
        const_fn.sig.generics.params = const_fn
            .sig
            .generics
            .params
            .iter()
            .cloned()
            .map(|g| match g {
                GenericParam::Lifetime(lifetime) => GenericParam::Lifetime(lifetime),
                GenericParam::Type(typ) => {
                    let mut typ = typ.clone();
                    typ.default = None;
                    GenericParam::Type(typ)
                }
                GenericParam::Const(constant) => {
                    let mut constant = constant.clone();
                    constant.default = None;
                    GenericParam::Const(constant)
                }
            })
            .collect();
        const_fn
    });

    // Keep only the impl generics that the final impl body still uses.
    item_impl.items.extend(converted_const_fns);
    let mut impl_visitor = FindGenericParam::new(&item_impl.generics);
    impl_visitor.visit_item_impl(&item_impl);
    let mut filtered_generics = filter_generics(&item_impl.generics, &impl_visitor.usages);
    filtered_generics.strip_default_generics();
    item_impl.generics = filtered_generics.impl_generics;

    let inherent_impl = if impl_const_fns.len() > 0 {
        Some(quote! {
            impl #impl_target {
                #(#impl_const_fns)*
            }
        })
    } else {
        None
    };

    let output = quote! {
        #item_impl

        impl #trait_mod::#sealed_ident for #impl_target {}

        #inherent_impl

        #[doc(hidden)]
        #[allow(unused)]
        use #trait_mod::Trait as #trait_import_name;
    };
    #[cfg(feature = "debug")]
    if debug {
        output.pretty_print();
    }
    Ok(output)
}