1#![doc = include_str!("./README.md")]
2
3use petgraph::graph::{DefaultIx, NodeIndex};
4use petgraph::visit::EdgeRef;
5use petgraph::Graph;
6use proc_macro2::Span;
7use std::collections::{HashMap, HashSet};
8use syn::parse::Parse;
9use syn::spanned::Spanned;
10use syn::visit::Visit;
11use syn::*;
12use template_quote::{quote, ToTokens};
13
14pub use syn;
15
16pub struct Leaker {
18 generics: Generics,
20 graph: Graph<Type, ()>,
21 map: HashMap<Type, NodeIndex<DefaultIx>>,
22 must_intern_nodes: HashSet<NodeIndex<DefaultIx>>,
23 root_nodes: HashSet<NodeIndex<DefaultIx>>,
24 pub implementor_type_fn: Box<dyn Fn(PathArguments) -> Type>,
29}
30
31#[derive(Debug, Clone)]
33pub struct NotInternableError(pub Span);
34
35impl std::fmt::Display for NotInternableError {
36 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37 f.write_str("Not internable")
38 }
39}
40
41impl std::error::Error for NotInternableError {}
42
43pub fn encode_generics_to_ty<'a>(iter: impl IntoIterator<Item = &'a GenericArgument>) -> Type {
45 Type::Tuple(TypeTuple {
46 paren_token: Default::default(),
47 elems: iter
48 .into_iter()
49 .map(|param| -> Type {
50 match param {
51 GenericArgument::Lifetime(lifetime) => {
52 parse_quote!(& #lifetime ())
53 }
54 GenericArgument::Const(expr) => {
55 parse_quote!([(); #expr as usize])
56 }
57 GenericArgument::Type(ty) => ty.clone(),
58 _ => panic!(),
59 }
60 })
61 .collect(),
62 })
63}
64
65#[derive(Clone, Debug)]
67pub enum CheckResult {
68 MustIntern(Span),
70 MustInternOrInherit(Span),
73 MustNotIntern(Span),
76 Neutral,
78}
79
80struct AnalyzeVisitor<'a> {
81 leaker: &'a mut Leaker,
82 generics: Generics,
83 parent: Option<NodeIndex<DefaultIx>>,
84 error: Option<Span>,
85}
86
87impl<'a, 'ast> Visit<'ast> for AnalyzeVisitor<'a> {
88 fn visit_trait_item_fn(&mut self, i: &TraitItemFn) {
89 let mut visitor = AnalyzeVisitor {
90 leaker: &mut self.leaker,
91 generics: self.generics.clone(),
92 parent: self.parent.clone(),
93 error: self.error.clone(),
94 };
95 for g in &i.sig.generics.params {
96 visitor.generics.params.push(g.clone());
97 }
98 let mut i = i.clone();
99 i.default = None;
100 i.semi_token = Some(Default::default());
101 syn::visit::visit_trait_item_fn(&mut visitor, &mut i);
102 self.error = visitor.error;
103 }
104
105 fn visit_trait_item_type(&mut self, i: &TraitItemType) {
106 let mut visitor = AnalyzeVisitor {
107 leaker: &mut self.leaker,
108 generics: self.generics.clone(),
109 parent: self.parent.clone(),
110 error: self.error.clone(),
111 };
112 for g in &i.generics.params {
113 visitor.generics.params.push(g.clone());
114 }
115 syn::visit::visit_trait_item_type(&mut visitor, i);
116 self.error = visitor.error;
117 }
118
119 fn visit_type(&mut self, i: &Type) {
120 match self.leaker.check(&self.generics, i) {
121 Err((_, s)) | Ok(CheckResult::MustNotIntern(s)) => {
122 self.error = Some(s);
124 }
125 o => {
126 let child = self
127 .leaker
128 .map
129 .entry(i.clone())
130 .or_insert_with(|| self.leaker.graph.add_node(i.clone()))
131 .clone();
132 if let Some(parent) = &self.parent {
133 self.leaker
134 .graph
135 .add_edge(parent.clone(), child.clone(), ());
136 } else {
137 self.leaker.root_nodes.insert(child.clone());
138 }
139 if let Ok(CheckResult::MustIntern(_)) = o {
140 self.leaker.must_intern_nodes.insert(child.clone());
142 } else {
143 let parent = self.parent;
145 self.parent = Some(child);
146 syn::visit::visit_type(self, i);
147 self.parent = parent;
148 }
149 }
150 }
151 }
152
153 fn visit_trait_bound(&mut self, i: &TraitBound) {
154 if i.path.leading_colon.is_none() {
155 self.error = Some(i.span());
157 } else {
158 syn::visit::visit_trait_bound(self, i)
159 }
160 }
161
162 fn visit_expr_path(&mut self, i: &ExprPath) {
163 if i.path.leading_colon.is_none() {
164 self.error = Some(i.span());
166 } else {
167 syn::visit::visit_expr_path(self, i)
168 }
169 }
170}
171
172impl Leaker {
173 pub fn from_struct(input: &ItemStruct) -> std::result::Result<Self, NotInternableError> {
188 let name = input.ident.clone();
189 let mut leaker = Leaker::with_generics_and_implementor(
190 input.generics.clone(),
191 Box::new(move |args: PathArguments| parse_quote!(#name #args)),
192 );
193 let mut visitor = AnalyzeVisitor {
194 leaker: &mut leaker,
195 parent: None,
196 error: None,
197 generics: input.generics.clone(),
198 };
199 visitor.visit_item_struct(input);
200 if let Some(e) = visitor.error {
201 return Err(NotInternableError(e));
202 }
203 Ok(leaker)
204 }
205
206 pub fn from_enum(input: &ItemEnum) -> std::result::Result<Self, NotInternableError> {
208 let name = input.ident.clone();
209 let mut leaker = Self::with_generics_and_implementor(
210 input.generics.clone(),
211 Box::new(move |args: PathArguments| parse_quote!(#name #args)),
212 );
213 let mut visitor = AnalyzeVisitor {
214 leaker: &mut leaker,
215 parent: None,
216 error: None,
217 generics: input.generics.clone(),
218 };
219 visitor.visit_item_enum(input);
220 if let Some(e) = visitor.error {
221 return Err(NotInternableError(e));
222 }
223 Ok(leaker)
224 }
225
226 pub fn from_trait(
247 input: &ItemTrait,
248 implementor_type_fn: Box<dyn Fn(PathArguments) -> Type>,
249 ) -> std::result::Result<Self, NotInternableError> {
250 let mut leaker =
251 Self::with_generics_and_implementor(input.generics.clone(), implementor_type_fn);
252 let mut visitor = AnalyzeVisitor {
253 leaker: &mut leaker,
254 parent: None,
255 error: None,
256 generics: input.generics.clone(),
257 };
258 visitor.visit_item_trait(input);
259 if let Some(e) = visitor.error {
260 return Err(NotInternableError(e));
261 }
262 Ok(leaker)
263 }
264
265 pub fn with_generics_and_implementor(
270 generics: Generics,
271 implementor_type_fn: Box<dyn Fn(PathArguments) -> Type>,
272 ) -> Self {
273 Self {
274 generics,
275 graph: Graph::new(),
276 map: HashMap::new(),
277 must_intern_nodes: HashSet::new(),
278 root_nodes: HashSet::new(),
279 implementor_type_fn,
280 }
281 }
282
283 pub fn intern(
285 &mut self,
286 generics: Generics,
287 ty: &Type,
288 ) -> std::result::Result<&mut Self, NotInternableError> {
289 let mut visitor = AnalyzeVisitor {
290 leaker: self,
291 parent: None,
292 error: None,
293 generics,
294 };
295 visitor.visit_type(ty);
296 if let Some(e) = visitor.error {
297 Err(NotInternableError(e))
298 } else {
299 Ok(self)
300 }
301 }
302
303 pub fn check(
309 &self,
310 generics: &Generics,
311 ty: &Type,
312 ) -> std::result::Result<CheckResult, (Span, Span)> {
313 use syn::visit::Visit;
314 #[derive(Clone)]
315 struct Visitor {
316 generic_lifetimes: Vec<Lifetime>,
317 generic_idents: Vec<Ident>,
318 impossible: Option<Span>,
319 must: Option<(isize, Span)>,
320 }
321
322 const _: () = {
323 use syn::visit::Visit;
324 impl<'a> Visit<'a> for Visitor {
325 fn visit_type(&mut self, i: &Type) {
326 match i {
327 Type::BareFn(TypeBareFn {
328 lifetimes,
329 inputs,
330 output,
331 ..
332 }) => {
333 let mut visitor = self.clone();
334 visitor.must = None;
335 visitor.generic_lifetimes.extend(
336 lifetimes
337 .as_ref()
338 .map(|ls| {
339 ls.lifetimes.iter().map(|gp| {
340 if let GenericParam::Lifetime(lt) = gp {
341 lt.lifetime.clone()
342 } else {
343 panic!()
344 }
345 })
346 })
347 .into_iter()
348 .flatten(),
349 );
350 for input in inputs {
351 visitor.visit_type(&input.ty);
352 }
353 if let ReturnType::Type(_, output) = output {
354 visitor.visit_type(output.as_ref());
355 }
356 if self.impossible.is_none() {
357 self.impossible = visitor.impossible;
358 }
359 match (self.must.clone(), visitor.must.clone()) {
360 (Some((s_l, _)), Some((v_l, v_m))) if v_l + 1 < s_l => {
361 self.must = Some((v_l + 1, v_m))
362 }
363 (None, Some((v_l, v_m))) => self.must = Some((v_l + 1, v_m)),
364 _ => (),
365 }
366 return;
367 }
368 Type::TraitObject(TypeTraitObject { bounds, .. }) => {
369 for bound in bounds {
370 match bound {
371 TypeParamBound::Trait(TraitBound {
372 lifetimes, path, ..
373 }) => {
374 let mut visitor = self.clone();
375 visitor.must = None;
376 visitor.generic_lifetimes.extend(
377 lifetimes
378 .as_ref()
379 .map(|ls| {
380 ls.lifetimes.iter().map(|gp| {
381 if let GenericParam::Lifetime(lt) = gp {
382 lt.lifetime.clone()
383 } else {
384 panic!()
385 }
386 })
387 })
388 .into_iter()
389 .flatten(),
390 );
391 visitor.visit_path(path);
392 if self.impossible.is_none() {
393 self.impossible = visitor.impossible;
394 }
395 match (self.must.clone(), visitor.must.clone()) {
396 (Some((s_l, _)), Some((v_l, v_m))) if v_l + 1 < s_l => {
397 self.must = Some((v_l + 1, v_m))
398 }
399 (None, Some((v_l, v_m))) => {
400 self.must = Some((v_l + 1, v_m))
401 }
402 _ => (),
403 }
404 return;
405 }
406 TypeParamBound::Verbatim(_) => {
407 self.impossible = Some(bound.span());
408 return;
409 }
410 _ => (),
411 }
412 }
413 }
414 Type::ImplTrait(_)
415 | Type::Infer(_)
416 | Type::Macro(_)
417 | Type::Verbatim(_) => {
418 self.impossible = Some(i.span());
419 }
420 _ => (),
421 }
422 let mut visitor = self.clone();
423 visitor.must = None;
424 syn::visit::visit_type(&mut visitor, i);
425 if visitor.impossible.is_some() {
426 self.impossible = visitor.impossible;
427 }
428 match (self.must.clone(), visitor.must.clone()) {
429 (Some((s_l, _)), Some((v_l, v_m))) if v_l + 1 < s_l => {
430 self.must = Some((v_l + 1, v_m))
431 }
432 (None, Some((v_l, v_m))) => self.must = Some((v_l + 1, v_m)),
433 _ => (),
434 }
435 }
436 fn visit_qself(&mut self, i: &QSelf) {
437 if i.as_token.is_none() {
438 self.impossible = Some(i.span());
439 }
440 syn::visit::visit_qself(self, i)
441 }
442 fn visit_lifetime(&mut self, i: &Lifetime) {
443 if i.to_string() != "'static"
444 && !self.generic_lifetimes.iter().any(|lt| lt == i)
445 {
446 self.must = Some((-1, i.span()));
447 }
448 syn::visit::visit_lifetime(self, i)
449 }
450 fn visit_expr(&mut self, i: &Expr) {
451 match i {
452 Expr::Closure(_)
453 | Expr::Assign(_)
454 | Expr::Verbatim(_)
455 | Expr::Macro(_)
456 | Expr::Infer(_) => {
457 self.impossible = Some(i.span());
458 }
459 _ => (),
460 }
461 syn::visit::visit_expr(self, i)
462 }
463 fn visit_path(&mut self, i: &Path) {
464 if matches!(i.segments.iter().next(), Some(PathSegment { ident, arguments }) if ident == "Self" && arguments.is_none())
465 {
466 } else {
468 match (i.leading_colon, i.get_ident()) {
469 (None, Some(ident))
471 if self.generic_idents.contains(&ident) || ident == "Self" => {}
472 (None, _) => {
474 self.must = Some((-1, i.span()));
475 }
476 (Some(_), _) => (),
478 }
479 }
480
481 syn::visit::visit_path(self, i)
482 }
483 }
484 };
485 let mut visitor = Visitor {
486 generic_lifetimes: generics
487 .params
488 .iter()
489 .filter_map(|gp| {
490 if let GenericParam::Lifetime(lt) = gp {
491 Some(lt.lifetime.clone())
492 } else {
493 None
494 }
495 })
496 .collect(),
497 generic_idents: generics
498 .params
499 .iter()
500 .filter_map(|gp| match gp {
501 GenericParam::Type(TypeParam { ident, .. })
502 | GenericParam::Const(ConstParam { ident, .. }) => Some(ident.clone()),
503 _ => None,
504 })
505 .collect(),
506 impossible: None,
507 must: None,
508 };
509 visitor.visit_type(ty);
510 match (visitor.must, visitor.impossible) {
511 (None, None) => Ok(CheckResult::Neutral),
512 (Some((0, span)), None) => Ok(CheckResult::MustIntern(span)),
513 (Some((_, span)), None) => Ok(CheckResult::MustInternOrInherit(span)),
514 (None, Some(span)) => Ok(CheckResult::MustNotIntern(span)),
515 (Some((_, span0)), Some(span1)) => Err((span0, span1)),
516 }
517 }
518
519 pub fn finish(
521 self,
522 mut repeater_path_fn: impl FnMut(usize) -> Path,
523 ) -> (Vec<ItemImpl>, Referrer) {
524 let (impl_generics, _, where_clause) = self.generics.split_for_impl();
525 let id_map: HashMap<_, _> = self
526 .root_nodes
527 .iter()
528 .enumerate()
529 .map(|(n, idx)| {
530 (
531 self.graph.node_weight(idx.clone()).expect("hello3").clone(),
532 n,
533 )
534 })
535 .collect();
536 let path_args = PathArguments::AngleBracketed(AngleBracketedGenericArguments {
537 colon2_token: None,
538 lt_token: Token),
539 args: self
540 .generics
541 .params
542 .iter()
543 .map(|param| match param {
544 GenericParam::Lifetime(lifetime_param) => {
545 GenericArgument::Lifetime(lifetime_param.lifetime.clone())
546 }
547 GenericParam::Type(TypeParam { ident, .. }) => {
548 GenericArgument::Type(parse_quote!(#ident))
549 }
550 GenericParam::Const(ConstParam { ident, .. }) => {
551 GenericArgument::Const(parse_quote!(#ident))
552 }
553 })
554 .collect(),
555 gt_token: Token),
556 });
557 (
558 id_map
559 .iter()
560 .map(|(ty, n)| {
561 parse2(quote! {
562 impl #impl_generics #{repeater_path_fn(*n)} for #{(*self.implementor_type_fn)(path_args.clone())} #where_clause {
563 type Type = #ty;
564 }
565 }).expect("hello4")
566 })
567 .collect(),
568 Referrer { map: id_map },
569 )
570 }
571
572 #[cfg_attr(doc, aquamarine::aquamarine)]
573 pub fn reduce_roots(&mut self) {
642 self.reduce_unreachable_nodes();
643 self.reduce_obvious_nodes();
644 }
646
647 fn reduce_obvious_nodes(&mut self) {
648 let mut must_intern_nodes: Vec<_> = self.must_intern_nodes.iter().cloned().collect();
649 let mut root_nodes: Vec<_> = self.root_nodes.iter().cloned().collect();
650 let nodes: Vec<_> = self
651 .graph
652 .node_indices()
653 .filter_map(|n| {
654 let mut it = self.graph.edges(n.clone());
655 if let Some(edge) = it.next() {
656 if let None = it.next() {
657 return Some((edge.source(), edge.target(), edge.id()));
658 }
659 }
660 None
661 })
662 .collect();
663 for (node1, node2, edge) in nodes {
664 self.graph.remove_edge(edge);
665 for (edge_in_source, edge_in) in self
666 .graph
667 .edges_directed(node1, petgraph::Direction::Incoming)
668 .map(|er| (er.source(), er.id()))
669 .collect::<Vec<_>>()
670 {
671 self.graph.update_edge(edge_in_source, node2, ());
672 self.graph.remove_edge(edge_in);
673 must_intern_nodes.iter_mut().for_each(|n| {
674 if n == &node1 {
675 *n = node2.clone()
676 }
677 });
678 root_nodes.iter_mut().for_each(|n| {
679 if n == &node1 {
680 *n = node2.clone()
681 }
682 });
683 }
684 self.graph.remove_node(node1);
685 }
686 self.must_intern_nodes = must_intern_nodes.into_iter().collect();
687 self.root_nodes = root_nodes.into_iter().collect();
688 }
689
690 fn reduce_unreachable_nodes(&mut self) {
691 let reachable_forward = get_reachable_nodes(&self.graph, self.root_nodes.iter().cloned());
692 self.graph.reverse();
693 let reachable_backward =
694 get_reachable_nodes(&self.graph, self.must_intern_nodes.iter().cloned());
695 self.graph.reverse();
696 let reachable = reachable_forward
697 .intersection(&reachable_backward)
698 .cloned()
699 .collect::<HashSet<_>>();
700 self.graph = self.graph.filter_map(
701 |ix, node| reachable.contains(&ix).then_some(node.clone()),
702 |ix, _| {
703 let (e1, e2) = self.graph.edge_endpoints(ix).unwrap();
704 (reachable.contains(&e1) && reachable.contains(&e2)).then_some(())
705 },
706 );
707 let mut new_graph = Graph::new();
708 let mut node_map = HashMap::new();
709 for node in self.graph.node_indices() {
710 if reachable.contains(&node) {
711 let new_node = new_graph.add_node(self.graph[node].clone());
712 node_map.insert(node, new_node);
713 }
714 }
715 for edge in self.graph.edge_indices() {
716 let (n1, n2) = self.graph.edge_endpoints(edge).unwrap();
717 if let (Some(nn1), Some(nn2)) = (node_map.get(&n1), node_map.get(&n2)) {
718 new_graph.add_edge(*nn1, *nn2, ());
719 }
720 }
721 self.root_nodes = self
722 .root_nodes
723 .iter()
724 .filter_map(|n| node_map.get(n).cloned())
725 .collect();
726 self.must_intern_nodes = self
727 .must_intern_nodes
728 .iter()
729 .filter_map(|n| node_map.get(n).cloned())
730 .collect();
731 }
732}
733
734fn get_reachable_nodes<N, E>(
735 graph: &Graph<N, E>,
736 roots: impl IntoIterator<Item = NodeIndex<DefaultIx>>,
737) -> HashSet<NodeIndex<DefaultIx>> {
738 let mut ret: HashSet<_> = roots.into_iter().collect();
739 let mut frontier = ret.clone();
740 loop {
741 let mut buf: HashSet<_> = frontier
742 .iter()
743 .map(|node| graph.neighbors(node.clone()))
744 .flatten()
745 .collect();
746 if !buf.iter().any(|node| ret.insert(node.clone())) {
747 break;
748 }
749 std::mem::swap(&mut frontier, &mut buf);
750 }
751 ret
752}
753
754#[derive(Clone, PartialEq, Eq, Debug)]
755pub struct Referrer {
756 map: HashMap<Type, usize>,
757}
758
759impl Referrer {
760 pub fn is_empty(&self) -> bool {
761 self.map.is_empty()
762 }
763
764 pub fn expand(
765 &self,
766 ty: Type,
767 leaker_ty: &Type,
768 repeater_path_fn: impl FnMut(usize) -> Path,
769 ) -> Type {
770 use syn::fold::Fold;
771
772 struct Folder<'a, F>(&'a Referrer, &'a Type, F);
773 impl<'a, F: FnMut(usize) -> Path> Fold for Folder<'a, F> {
774 fn fold_type(&mut self, ty: Type) -> Type {
775 if let Some(idx) = self.0.map.get(&ty) {
776 parse2(quote! {
777 <#{&self.1} as #{&self.2(*idx)}>::Type
778 })
779 .expect("hello2")
780 } else {
781 syn::fold::fold_type(self, ty)
782 }
783 }
784 }
785 let mut folder = Folder(self, leaker_ty, repeater_path_fn);
786 folder.fold_type(ty)
787 }
788}
789
790impl Parse for Referrer {
791 fn parse(input: parse::ParseStream) -> Result<Self> {
792 let mut map = HashMap::new();
793 let map_content: syn::parse::ParseBuffer<'_>;
794 parenthesized!(map_content in input);
795 while !map_content.is_empty() {
796 let key: Type = map_content.parse()?;
797 map_content.parse::<Token![:]>()?;
798 let value: LitInt = map_content.parse()?;
799 map.insert(key, value.base10_parse()?);
800 if map_content.is_empty() {
801 break;
802 }
803 map_content.parse::<Token![,]>()?;
804 }
805 Ok(Self { map })
806 }
807}
808
809impl ToTokens for Referrer {
810 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
811 tokens.extend(quote! {
812 (
813 #(for (key, val) in &self.map), {
814 #key: #val
815 }
816 )
817 })
818 }
819}