type_leak/
lib.rs

1#![doc = include_str!("./README.md")]
2
3use petgraph::graph::{DefaultIx, NodeIndex};
4use petgraph::visit::EdgeRef;
5use petgraph::Graph;
6use proc_macro2::Span;
7use std::cell::RefCell;
8use std::collections::{HashMap, HashSet};
9use std::rc::Rc;
10use syn::parse::Parse;
11use syn::punctuated::Punctuated;
12use syn::spanned::Spanned;
13use syn::visit::Visit;
14use syn::visit_mut::VisitMut;
15use syn::*;
16use template_quote::{quote, ToTokens};
17
18pub use syn;
19
20/// The entry point of this crate.
21pub struct Leaker {
22    generics: Generics,
23    graph: Graph<Type, ()>,
24    map: HashMap<Type, NodeIndex<DefaultIx>>,
25    must_intern_nodes: HashSet<NodeIndex<DefaultIx>>,
26    root_nodes: HashSet<NodeIndex<DefaultIx>>,
27}
28
29/// Error represents that the type is not internable.
30#[derive(Debug, Clone)]
31pub struct NotInternableError(pub Span);
32
33impl std::fmt::Display for NotInternableError {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        f.write_str("Not internable")
36    }
37}
38
39impl std::error::Error for NotInternableError {}
40
41/// Encode [`GenericParam`]s to a type.
42pub fn encode_generics_to_ty<'a>(iter: impl IntoIterator<Item = &'a GenericArgument>) -> Type {
43    Type::Tuple(TypeTuple {
44        paren_token: Default::default(),
45        elems: iter
46            .into_iter()
47            .map(|param| -> Type {
48                match param {
49                    GenericArgument::Lifetime(lifetime) => {
50                        parse_quote!(& #lifetime ())
51                    }
52                    GenericArgument::Const(expr) => {
53                        parse_quote!([(); #expr as usize])
54                    }
55                    GenericArgument::Type(ty) => ty.clone(),
56                    _ => panic!(),
57                }
58            })
59            .collect(),
60    })
61}
62
63/// Represents result of [`Leaker::check()`], holding the cause in its tuple item.
64#[derive(Clone, Debug)]
65pub enum CheckResult {
66    /// The type must be interned because it directly depends on the type context.
67    MustIntern(Span),
68    /// The type must be interned because one of the types which it constituted with must be
69    /// interned.
70    MustInternOrInherit(Span),
71    /// The type must not be interned, because `type_leak` cannot intern this. (e.g. `impl Trait`s,
72    /// `_`, trait bounds with relative trait path, ...)
73    MustNotIntern(Span),
74    /// Other.
75    Neutral,
76}
77
78struct AnalyzeVisitor<'a> {
79    leaker: &'a mut Leaker,
80    generics: Generics,
81    generics_base: Generics,
82    parent: Option<NodeIndex<DefaultIx>>,
83    error: Option<Span>,
84}
85
86impl<'a, 'ast> Visit<'ast> for AnalyzeVisitor<'a> {
87    fn visit_trait_item_fn(&mut self, i: &TraitItemFn) {
88        let mut visitor = AnalyzeVisitor {
89            leaker: &mut self.leaker,
90            generics: self.generics.clone(),
91            generics_base: self.generics_base.clone(),
92            parent: self.parent.clone(),
93            error: self.error.clone(),
94        };
95        for g in &i.sig.generics.params {
96            visitor.generics.params.push(g.clone());
97        }
98        let mut i = i.clone();
99        i.default = None;
100        i.semi_token = Some(Default::default());
101        syn::visit::visit_trait_item_fn(&mut visitor, &mut i);
102        self.error = visitor.error;
103    }
104
105    fn visit_receiver(&mut self, i: &Receiver) {
106        for attr in &i.attrs {
107            self.visit_attribute(attr);
108        }
109        if let Some((_, Some(lt))) = &i.reference {
110            self.visit_lifetime(lt);
111        }
112        if i.colon_token.is_some() {
113            self.visit_type(&i.ty);
114        }
115    }
116
117    fn visit_trait_item_type(&mut self, i: &TraitItemType) {
118        let mut visitor = AnalyzeVisitor {
119            leaker: &mut self.leaker,
120            generics: self.generics.clone(),
121            generics_base: self.generics_base.clone(),
122            parent: self.parent.clone(),
123            error: self.error.clone(),
124        };
125        for g in &i.generics.params {
126            visitor.generics.params.push(g.clone());
127        }
128        syn::visit::visit_trait_item_type(&mut visitor, i);
129        self.error = visitor.error;
130    }
131
132    fn visit_type(&mut self, i: &Type) {
133        match self.leaker.check(&self.generics_base, &self.generics, i) {
134            Err((_, s)) => {
135                // Emit error and terminate searching
136                self.error = Some(s);
137            }
138            Ok(CheckResult::MustNotIntern(_)) => (),
139            o => {
140                let child = self
141                    .leaker
142                    .map
143                    .entry(i.clone())
144                    .or_insert_with(|| self.leaker.graph.add_node(i.clone()))
145                    .clone();
146                if let Some(parent) = &self.parent {
147                    self.leaker
148                        .graph
149                        .add_edge(parent.clone(), child.clone(), ());
150                } else {
151                    self.leaker.root_nodes.insert(child.clone());
152                }
153                if let Ok(CheckResult::MustIntern(_)) = o {
154                    // Terminate searching
155                    self.leaker.must_intern_nodes.insert(child.clone());
156                } else {
157                    // Perform recuesive call
158                    let parent = self.parent;
159                    self.parent = Some(child);
160                    syn::visit::visit_type(self, i);
161                    self.parent = parent;
162                }
163            }
164        }
165    }
166
167    fn visit_trait_bound(&mut self, i: &TraitBound) {
168        if i.path.leading_colon.is_none() {
169            // Trait bounds with relative oath
170            self.error = Some(i.span());
171        } else {
172            syn::visit::visit_trait_bound(self, i)
173        }
174    }
175
176    fn visit_expr_path(&mut self, i: &ExprPath) {
177        if i.path.leading_colon.is_none() {
178            // Value or trait bounds with relative oath
179            self.error = Some(i.span());
180        } else {
181            syn::visit::visit_expr_path(self, i)
182        }
183    }
184}
185
186impl Leaker {
187    /// Initialize with [`ItemStruct`]
188    ///
189    /// ```
190    /// # use type_leak::*;
191    /// # use syn::*;
192    /// let test_struct: ItemStruct = parse_quote!(
193    ///     pub struct MyStruct<'a, T1, T2: ::path::to::MyType1<MyType4>> {
194    ///         field1: MyType1,
195    ///         field2: (MyType2, MyType3<MyType1>, MyType4, MyType5),
196    ///         field3: &'a (T1, T2),
197    ///     }
198    /// );
199    /// let leaker = Leaker::from_struct(&test_struct).unwrap();
200    /// ```
201    pub fn from_struct(input: &ItemStruct) -> std::result::Result<Self, NotInternableError> {
202        let mut leaker = Leaker::with_generics(input.generics.clone());
203        let mut visitor = AnalyzeVisitor {
204            leaker: &mut leaker,
205            parent: None,
206            error: None,
207            generics: input.generics.clone(),
208            generics_base: input.generics.clone(),
209        };
210        visitor.visit_item_struct(input);
211        if let Some(e) = visitor.error {
212            return Err(NotInternableError(e));
213        }
214        Ok(leaker)
215    }
216
217    /// Initialize with [`ItemEnum`]
218    pub fn from_enum(input: &ItemEnum) -> std::result::Result<Self, NotInternableError> {
219        let mut leaker = Self::with_generics(input.generics.clone());
220        let mut visitor = AnalyzeVisitor {
221            leaker: &mut leaker,
222            parent: None,
223            error: None,
224            generics: input.generics.clone(),
225            generics_base: input.generics.clone(),
226        };
227        visitor.visit_item_enum(input);
228        if let Some(e) = visitor.error {
229            return Err(NotInternableError(e));
230        }
231        Ok(leaker)
232    }
233
234    /// Build an [`Leaker`] with given trait.
235    ///
236    /// Unlike enum nor struct, it requires ~alternative path~, an absolute path of a struct
237    /// which is declared the same crate with the leaker trait and also visible from
238    /// [`Referrer`]s' context. That struct is used as an `impl` target of `Repeater`
239    /// instead of the Leaker's path.
240    ///
241    /// ```
242    /// # use syn::*;
243    /// # use type_leak::Leaker;
244    /// let s: ItemTrait = parse_quote!{
245    ///     pub trait MyTrait<T, U> {
246    ///         fn func(self, t: T) -> U;
247    ///     }
248    /// };
249    /// let alternate: ItemStruct = parse_quote!{
250    ///     pub struct MyAlternate;
251    /// };
252    /// let _ = Leaker::from_trait(&s);
253    /// ```
254    pub fn from_trait(input: &ItemTrait) -> std::result::Result<Self, NotInternableError> {
255        let mut leaker = Self::with_generics(input.generics.clone());
256        let mut visitor = AnalyzeVisitor {
257            leaker: &mut leaker,
258            parent: None,
259            error: None,
260            generics: input.generics.clone(),
261            generics_base: input.generics.clone(),
262        };
263        visitor.visit_item_trait(input);
264        if let Some(e) = visitor.error {
265            return Err(NotInternableError(e));
266        }
267        Ok(leaker)
268    }
269
270    /// Initialize empty [`Leaker`] with given generics.
271    ///
272    /// Types, consts, lifetimes defined in the [`Generics`] is treated as "no needs to be interned" although
273    /// they looks like relative path names.
274    pub fn with_generics(generics: Generics) -> Self {
275        Self {
276            generics,
277            graph: Graph::new(),
278            map: HashMap::new(),
279            must_intern_nodes: HashSet::new(),
280            root_nodes: HashSet::new(),
281        }
282    }
283
284    /// Intern the given type as a root node.
285    pub fn intern(
286        &mut self,
287        generics: Generics,
288        ty: &Type,
289    ) -> std::result::Result<&mut Self, NotInternableError> {
290        let mut visitor = AnalyzeVisitor {
291            leaker: self,
292            parent: None,
293            error: None,
294            generics_base: generics.clone(),
295            generics,
296        };
297        visitor.visit_type(ty);
298        if let Some(e) = visitor.error {
299            Err(NotInternableError(e))
300        } else {
301            Ok(self)
302        }
303    }
304
305    /// Check that the internableness of give `ty`. It returns `Err` in contradiction (the type
306    /// must and must not be interned).
307    ///
308    ///
309    /// See [`CheckResult`].
310    pub fn check(
311        &self,
312        generics_base: &Generics,
313        generics: &Generics,
314        ty: &Type,
315    ) -> std::result::Result<CheckResult, (Span, Span)> {
316        use syn::visit::Visit;
317        #[derive(Clone)]
318        struct Visitor {
319            generic_lifetimes_base: Vec<Lifetime>,
320            generic_idents_base: Vec<Ident>,
321            generic_lifetimes: Vec<Lifetime>,
322            generic_idents: Vec<Ident>,
323            impossible: Option<Span>,
324            must: Option<(isize, Span)>,
325        }
326
327        const _: () = {
328            use syn::visit::Visit;
329            impl<'a> Visit<'a> for Visitor {
330                fn visit_type(&mut self, i: &Type) {
331                    match i {
332                        Type::BareFn(TypeBareFn {
333                            lifetimes,
334                            inputs,
335                            output,
336                            ..
337                        }) => {
338                            let mut visitor = self.clone();
339                            visitor.must = None;
340                            visitor.generic_lifetimes.extend(
341                                lifetimes
342                                    .as_ref()
343                                    .map(|ls| {
344                                        ls.lifetimes.iter().map(|gp| {
345                                            if let GenericParam::Lifetime(lt) = gp {
346                                                lt.lifetime.clone()
347                                            } else {
348                                                panic!()
349                                            }
350                                        })
351                                    })
352                                    .into_iter()
353                                    .flatten(),
354                            );
355                            for input in inputs {
356                                visitor.visit_type(&input.ty);
357                            }
358                            if let ReturnType::Type(_, output) = output {
359                                visitor.visit_type(output.as_ref());
360                            }
361                            if self.impossible.is_none() {
362                                self.impossible = visitor.impossible;
363                            }
364                            match (self.must.clone(), visitor.must.clone()) {
365                                (Some((s_l, _)), Some((v_l, v_m))) if v_l + 1 < s_l => {
366                                    self.must = Some((v_l + 1, v_m))
367                                }
368                                (None, Some((v_l, v_m))) => self.must = Some((v_l + 1, v_m)),
369                                _ => (),
370                            }
371                            return;
372                        }
373                        Type::TraitObject(TypeTraitObject { bounds, .. }) => {
374                            for bound in bounds {
375                                match bound {
376                                    TypeParamBound::Trait(TraitBound {
377                                        lifetimes, path, ..
378                                    }) => {
379                                        let mut visitor = self.clone();
380                                        visitor.must = None;
381                                        visitor.generic_lifetimes.extend(
382                                            lifetimes
383                                                .as_ref()
384                                                .map(|ls| {
385                                                    ls.lifetimes.iter().map(|gp| {
386                                                        if let GenericParam::Lifetime(lt) = gp {
387                                                            lt.lifetime.clone()
388                                                        } else {
389                                                            panic!()
390                                                        }
391                                                    })
392                                                })
393                                                .into_iter()
394                                                .flatten(),
395                                        );
396                                        visitor.visit_path(path);
397                                        if self.impossible.is_none() {
398                                            self.impossible = visitor.impossible;
399                                        }
400                                        match (self.must.clone(), visitor.must.clone()) {
401                                            (Some((s_l, _)), Some((v_l, v_m))) if v_l + 1 < s_l => {
402                                                self.must = Some((v_l + 1, v_m))
403                                            }
404                                            (None, Some((v_l, v_m))) => {
405                                                self.must = Some((v_l + 1, v_m))
406                                            }
407                                            _ => (),
408                                        }
409                                        return;
410                                    }
411                                    TypeParamBound::Verbatim(_) => {
412                                        self.impossible = Some(bound.span());
413                                        return;
414                                    }
415                                    _ => (),
416                                }
417                            }
418                        }
419                        Type::ImplTrait(_) | Type::Macro(_) | Type::Verbatim(_) => {
420                            self.impossible = Some(i.span());
421                        }
422                        Type::Reference(TypeReference { lifetime, .. }) => {
423                            if lifetime.is_none() {
424                                self.impossible = Some(i.span());
425                            }
426                        }
427                        _ => (),
428                    }
429                    let mut visitor = self.clone();
430                    visitor.must = None;
431                    syn::visit::visit_type(&mut visitor, i);
432                    if visitor.impossible.is_some() {
433                        self.impossible = visitor.impossible;
434                    }
435                    match (self.must.clone(), visitor.must.clone()) {
436                        (Some((s_l, _)), Some((v_l, v_m))) if v_l + 1 < s_l => {
437                            self.must = Some((v_l + 1, v_m))
438                        }
439                        (None, Some((v_l, v_m))) => self.must = Some((v_l + 1, v_m)),
440                        _ => (),
441                    }
442                }
443                fn visit_qself(&mut self, i: &QSelf) {
444                    if i.as_token.is_none() {
445                        self.impossible = Some(i.span());
446                    }
447                    syn::visit::visit_qself(self, i)
448                }
449                fn visit_lifetime(&mut self, i: &Lifetime) {
450                    match i.to_string().as_str() {
451                        "'static" => (),
452                        "'_" => self.impossible = Some(i.span()),
453                        _ if self.generic_lifetimes_base.iter().any(|lt| lt == i) => {}
454                        _ if self.generic_lifetimes.iter().any(|lt| lt == i) => {
455                            self.impossible = Some(i.span());
456                        }
457                        _ => {
458                            self.must = Some((-1, i.span()));
459                        }
460                    }
461                    syn::visit::visit_lifetime(self, i)
462                }
463                fn visit_expr(&mut self, i: &Expr) {
464                    match i {
465                        Expr::Closure(_) | Expr::Assign(_) | Expr::Verbatim(_) | Expr::Macro(_) => {
466                            self.impossible = Some(i.span());
467                        }
468                        _ => (),
469                    }
470                    syn::visit::visit_expr(self, i)
471                }
472                fn visit_path(&mut self, i: &Path) {
473                    if matches!(i.segments.iter().next(), Some(PathSegment { ident, arguments }) if ident == "Self" && arguments.is_none())
474                        && i.leading_colon.is_none()
475                    {
476                        // do nothing
477                    } else {
478                        match (i.leading_colon, i.get_ident()) {
479                            // i is a generic parameter
480                            (None, Some(ident)) if self.generic_idents_base.contains(&ident) => {}
481                            (None, Some(ident)) if self.generic_idents.contains(&ident) => {
482                                self.impossible = Some(i.span());
483                            }
484                            // relative path, not a generic parameter
485                            (None, _) => {
486                                self.must = Some((-1, i.span()));
487                            }
488                            // absolute path
489                            (Some(_), _) => (),
490                        }
491                    }
492                    syn::visit::visit_path(self, i)
493                }
494            }
495        };
496        let mut visitor = Visitor {
497            generic_lifetimes_base: generics_base
498                .params
499                .iter()
500                .filter_map(|gp| {
501                    if let GenericParam::Lifetime(lt) = gp {
502                        Some(lt.lifetime.clone())
503                    } else {
504                        None
505                    }
506                })
507                .collect(),
508            generic_idents_base: generics_base
509                .params
510                .iter()
511                .filter_map(|gp| match gp {
512                    GenericParam::Type(TypeParam { ident, .. })
513                    | GenericParam::Const(ConstParam { ident, .. }) => Some(ident.clone()),
514                    _ => None,
515                })
516                .collect(),
517            generic_lifetimes: generics
518                .params
519                .iter()
520                .filter_map(|gp| {
521                    if let GenericParam::Lifetime(lt) = gp {
522                        Some(lt.lifetime.clone())
523                    } else {
524                        None
525                    }
526                })
527                .collect(),
528            generic_idents: generics
529                .params
530                .iter()
531                .filter_map(|gp| match gp {
532                    GenericParam::Type(TypeParam { ident, .. })
533                    | GenericParam::Const(ConstParam { ident, .. }) => Some(ident.clone()),
534                    _ => None,
535                })
536                .collect(),
537            impossible: None,
538            must: None,
539        };
540        visitor.visit_type(ty);
541        match (visitor.must, visitor.impossible) {
542            (None, None) => Ok(CheckResult::Neutral),
543            (Some((0, span)), None) => Ok(CheckResult::MustIntern(span)),
544            (Some((_, span)), None) => Ok(CheckResult::MustInternOrInherit(span)),
545            (None, Some(span)) => Ok(CheckResult::MustNotIntern(span)),
546            (Some((_, span0)), Some(span1)) => Err((span0, span1)),
547        }
548    }
549
550    pub fn generics(&self) -> &Generics {
551        &self.generics
552    }
553
554    /// Finish building the [`Leaker`] and convert it into [`Referrer`].
555    pub fn finish(self) -> Referrer {
556        let leak_types: Vec<_> = self
557            .root_nodes
558            .iter()
559            .map(|node| self.graph.node_weight(node.clone()).unwrap().clone())
560            .collect();
561        let map = leak_types
562            .iter()
563            .enumerate()
564            .map(|(n, ty)| (ty.clone(), n))
565            .collect();
566        Referrer { leak_types, map }
567    }
568
569    #[cfg_attr(doc, aquamarine::aquamarine)]
570    /// Reduce nodes to decrease cost of {`Leaker`}'s implementation.
571    ///
572    /// # Algorithm
573    ///
574    /// To consult the algorithm, see the following [`Leaker`]'s input:
575    ///
576    /// ```ignore
577    /// pub struct MyStruct<'a, T1, T2: ::path::to::MyType1<MyType4>> {
578    ///     field1: MyType1,
579    ///     field2: (MyType2, MyType3<MyType1>, MyType4, MyType5),
580    ///     field3: &'a (T1, T2),
581    /// }
582    /// ```
583    ///
584    /// [`Leaker`], when initialized with [`Leaker::from_struct()`], analyze the AST and
585    /// construct a DAG which represents all (internable) types and the dependency relations like
586    /// this:
587    ///
588    /// ```mermaid
589    /// graph TD
590    ///   1["★MyType1
591    ///   (Node 1)"]
592    ///   2["★(MyType2, MyType3#lt;MyType1#gt;, MyType4, MyType5)
593    ///   (Node 2)"] -->3["MyType2
594    ///   (Node 3)"]
595    ///   2 -->4["MyType3#lt;MyType1#gt;
596    ///   (Node 4)"]
597    ///   2 -->0["★MyType4
598    ///   (Node 0)"]
599    ///   2 -->5["MyType5
600    ///   (Node 5)"]
601    ///   6["★&a (T1, T2)
602    ///   (Node 6)"] -->7["(T1, T2)
603    ///   (Node 7)"]
604    ///   7 -->8["T1
605    ///   (Node 8)"]
606    ///   7 -->9["T2
607    ///   (Node 9)"]
608    ///   classDef redNode stroke:#ff0000;
609    ///   class 0,1,3,4,5 redNode;
610    /// ```
611    ///
612    /// The **red** node is flagged as [`CheckResult::MustIntern`] by [`Leaker::check()`] (which means
613    /// the type literature depends on the type context, so it must be interned).
614    ///
615    /// This algorithm reduce the nodes, remaining that all root type (annotated with ★) can be
616    /// expressed with existing **red** nodes.
617    ///
618    /// Here, there are some choice in its freedom:
619    ///
620    /// - Intern all **red** nodes and ignore others (because other nodes are not needed to be
621    /// intern, or constructable with red nodes)
622    /// - Not directly intern **red** nodes; intern common ancessors instead if it is affordable.
623    ///
624    /// So finally, it results like:
625    ///
626    /// ```mermaid
627    /// graph TD
628    ///   0["MyType4
629    ///   (Node 0)"]
630    ///   1["MyType1
631    ///   (Node 1)"]
632    ///   2["(MyType2, MyType3 #lt; MyType1 #gt;, MyType4, MyType5)
633    ///   (Node 2)"]
634    ///   classDef redNode stroke:#ff0000;
635    ///   class 0,1,2 redNode;
636    /// ```
637    ///
638    pub fn reduce_roots(&mut self) {
639        self.reduce_unreachable_nodes();
640        self.reduce_obvious_nodes();
641        // TODO: unobvious root reduction with heaulistics
642    }
643
644    fn reduce_obvious_nodes(&mut self) {
645        let mut must_intern_nodes: Vec<_> = self.must_intern_nodes.iter().cloned().collect();
646        let mut root_nodes: Vec<_> = self.root_nodes.iter().cloned().collect();
647        let nodes: Vec<_> = self
648            .graph
649            .node_indices()
650            .filter_map(|n| {
651                let mut it = self.graph.edges(n.clone());
652                if let Some(edge) = it.next() {
653                    if let None = it.next() {
654                        return Some((edge.source(), edge.target(), edge.id()));
655                    }
656                }
657                None
658            })
659            .collect();
660        let mut removing_nodes = HashSet::new();
661        for (node1, node2, edge) in nodes {
662            self.graph.remove_edge(edge);
663            for (edge_in_source, edge_in) in self
664                .graph
665                .edges_directed(node1, petgraph::Direction::Incoming)
666                .map(|er| (er.source(), er.id()))
667                .collect::<Vec<_>>()
668            {
669                self.graph.update_edge(edge_in_source, node2, ());
670                self.graph.remove_edge(edge_in);
671            }
672            must_intern_nodes.iter_mut().for_each(|n| {
673                if n == &node1 {
674                    *n = node2.clone()
675                }
676            });
677            root_nodes.iter_mut().for_each(|n| {
678                if n == &node1 {
679                    *n = node2.clone()
680                }
681            });
682            removing_nodes.insert(node1);
683        }
684        let mut new_graph = Graph::new();
685        let mut node_map = HashMap::new();
686        for node in self.graph.node_indices() {
687            if !removing_nodes.contains(&node) {
688                let new_node = new_graph.add_node(self.graph[node].clone());
689                node_map.insert(node, new_node);
690            }
691        }
692        for edge in self.graph.edge_indices() {
693            let (n1, n2) = self.graph.edge_endpoints(edge).unwrap();
694            if let (Some(nn1), Some(nn2)) = (node_map.get(&n1), node_map.get(&n2)) {
695                new_graph.add_edge(*nn1, *nn2, ());
696            }
697        }
698        self.root_nodes = root_nodes
699            .iter()
700            .filter_map(|n| node_map.get(n).cloned())
701            .collect();
702        self.must_intern_nodes = must_intern_nodes
703            .iter()
704            .filter_map(|n| node_map.get(n).cloned())
705            .collect();
706        let _ = std::mem::replace(&mut self.graph, new_graph);
707    }
708
709    fn reduce_unreachable_nodes(&mut self) {
710        let reachable_forward = get_reachable_nodes(&self.graph, self.root_nodes.iter().cloned());
711        self.graph.reverse();
712        let reachable_backward =
713            get_reachable_nodes(&self.graph, self.must_intern_nodes.iter().cloned());
714        self.graph.reverse();
715        let reachable = reachable_forward
716            .intersection(&reachable_backward)
717            .cloned()
718            .collect::<HashSet<_>>();
719        // self.graph = self.graph.filter_map(
720        //     |ix, node| reachable.contains(&ix).then_some(node.clone()),
721        //     |ix, _| {
722        //         let (e1, e2) = self.graph.edge_endpoints(ix).unwrap();
723        //         (reachable.contains(&e1) && reachable.contains(&e2)).then_some(())
724        //     },
725        // );
726        let mut new_graph = Graph::new();
727        let mut node_map = HashMap::new();
728        for node in self.graph.node_indices() {
729            if reachable.contains(&node) {
730                let new_node = new_graph.add_node(self.graph[node].clone());
731                node_map.insert(node, new_node);
732            }
733        }
734        for edge in self.graph.edge_indices() {
735            let (n1, n2) = self.graph.edge_endpoints(edge).unwrap();
736            if let (Some(nn1), Some(nn2)) = (node_map.get(&n1), node_map.get(&n2)) {
737                new_graph.add_edge(*nn1, *nn2, ());
738            }
739        }
740        self.root_nodes = self
741            .root_nodes
742            .iter()
743            .filter_map(|n| node_map.get(n).cloned())
744            .collect();
745        self.must_intern_nodes = self
746            .must_intern_nodes
747            .iter()
748            .filter_map(|n| node_map.get(n).cloned())
749            .collect();
750        let _ = std::mem::replace(&mut self.graph, new_graph);
751    }
752}
753
754fn get_reachable_nodes<N, E>(
755    graph: &Graph<N, E>,
756    roots: impl IntoIterator<Item = NodeIndex<DefaultIx>>,
757) -> HashSet<NodeIndex<DefaultIx>> {
758    let mut ret: HashSet<_> = roots.into_iter().collect();
759    let mut frontier = ret.clone();
760    loop {
761        let mut buf: HashSet<_> = frontier
762            .iter()
763            .map(|node| graph.neighbors(node.clone()))
764            .flatten()
765            .collect();
766        let mut count = 0;
767        for node in &buf {
768            if ret.insert(node.clone()) {
769                count += 1;
770            }
771        }
772        if count == 0 {
773            break;
774        }
775        std::mem::swap(&mut frontier, &mut buf);
776    }
777    ret
778}
779
780#[derive(Clone, PartialEq, Eq, Debug)]
781pub struct Referrer {
782    leak_types: Vec<Type>,
783    map: HashMap<Type, usize>,
784}
785
786impl Referrer {
787    pub fn is_empty(&self) -> bool {
788        self.leak_types.is_empty()
789    }
790
791    pub fn iter(&self) -> impl Iterator<Item = &Type> {
792        self.leak_types.iter()
793    }
794
795    pub fn into_visitor<F: FnMut(Type, usize) -> Type>(self, type_fn: F) -> Visitor<F> {
796        Visitor(self, Rc::new(RefCell::new(type_fn)))
797    }
798
799    pub fn expand(&self, ty: Type, type_fn: impl FnMut(Type, usize) -> Type) -> Type {
800        use syn::fold::Fold;
801        struct Folder<'a, F>(&'a Referrer, F);
802        impl<'a, F: FnMut(Type, usize) -> Type> Fold for Folder<'a, F> {
803            fn fold_type(&mut self, ty: Type) -> Type {
804                if let Some(idx) = self.0.map.get(&ty) {
805                    self.1(ty, *idx)
806                } else {
807                    syn::fold::fold_type(self, ty)
808                }
809            }
810        }
811        let mut folder = Folder(self, type_fn);
812        folder.fold_type(ty)
813    }
814}
815
816#[derive(Debug)]
817pub struct Visitor<F>(Referrer, Rc<RefCell<F>>);
818
819impl<F> Clone for Visitor<F> {
820    fn clone(&self) -> Self {
821        Self(self.0.clone(), self.1.clone())
822    }
823}
824
825impl<F: FnMut(Type, usize) -> Type> Visitor<F> {
826    fn with_generics(&mut self, generics: &mut Generics) -> Self {
827        let mut visitor = self.clone();
828        for gp in generics.params.iter_mut() {
829            if let GenericParam::Type(TypeParam { ident, .. }) = gp {
830                visitor.0.map.remove(&parse_quote!(#ident));
831            }
832            visitor.visit_generic_param_mut(gp);
833        }
834        visitor
835    }
836
837    fn with_signature(&mut self, sig: &mut Signature) -> Self {
838        let mut visitor = self.with_generics(&mut sig.generics);
839        for input in sig.inputs.iter_mut() {
840            visitor.visit_fn_arg_mut(input);
841        }
842        visitor.visit_return_type_mut(&mut sig.output);
843        visitor
844    }
845}
846
847impl<F: FnMut(Type, usize) -> Type> VisitMut for Visitor<F> {
848    fn visit_type_mut(&mut self, i: &mut Type) {
849        *i = self.0.expand(i.clone(), &mut *self.1.borrow_mut());
850    }
851    fn visit_item_struct_mut(&mut self, i: &mut ItemStruct) {
852        for attr in i.attrs.iter_mut() {
853            self.visit_attribute_mut(attr);
854        }
855        let mut visitor = self.with_generics(&mut i.generics);
856        visitor.visit_fields_mut(&mut i.fields);
857    }
858    fn visit_item_enum_mut(&mut self, i: &mut ItemEnum) {
859        for attr in i.attrs.iter_mut() {
860            self.visit_attribute_mut(attr);
861        }
862        let mut visitor = self.with_generics(&mut i.generics);
863        for variant in i.variants.iter_mut() {
864            visitor.visit_variant_mut(variant);
865        }
866    }
867    fn visit_item_trait_mut(&mut self, i: &mut ItemTrait) {
868        for attr in i.attrs.iter_mut() {
869            self.visit_attribute_mut(attr);
870        }
871        let mut visitor = self.with_generics(&mut i.generics);
872        for supertrait in i.supertraits.iter_mut() {
873            visitor.visit_type_param_bound_mut(supertrait);
874        }
875        for item in i.items.iter_mut() {
876            visitor.visit_trait_item_mut(item);
877        }
878    }
879    fn visit_item_union_mut(&mut self, i: &mut ItemUnion) {
880        for attr in i.attrs.iter_mut() {
881            self.visit_attribute_mut(attr);
882        }
883        let mut visitor = self.with_generics(&mut i.generics);
884        visitor.visit_fields_named_mut(&mut i.fields);
885    }
886    fn visit_item_type_mut(&mut self, i: &mut ItemType) {
887        for attr in i.attrs.iter_mut() {
888            self.visit_attribute_mut(attr);
889        }
890        let mut visitor = self.with_generics(&mut i.generics);
891        visitor.visit_type_mut(&mut i.ty);
892    }
893    fn visit_item_fn_mut(&mut self, i: &mut ItemFn) {
894        for attr in i.attrs.iter_mut() {
895            self.visit_attribute_mut(attr);
896        }
897        let mut visitor = self.with_signature(&mut i.sig);
898        visitor.visit_block_mut(i.block.as_mut());
899    }
900    fn visit_trait_item_fn_mut(&mut self, i: &mut TraitItemFn) {
901        for attr in i.attrs.iter_mut() {
902            self.visit_attribute_mut(attr);
903        }
904        let mut visitor = self.with_signature(&mut i.sig);
905        if let Some(block) = &mut i.default {
906            visitor.visit_block_mut(block);
907        }
908    }
909    fn visit_trait_item_type_mut(&mut self, i: &mut TraitItemType) {
910        for attr in i.attrs.iter_mut() {
911            self.visit_attribute_mut(attr);
912        }
913        let mut visitor = self.with_generics(&mut i.generics);
914        for bound in i.bounds.iter_mut() {
915            visitor.visit_type_param_bound_mut(bound);
916        }
917        if let Some((_, ty)) = &mut i.default {
918            visitor.visit_type_mut(ty);
919        }
920    }
921    fn visit_trait_item_const_mut(&mut self, i: &mut TraitItemConst) {
922        for attr in i.attrs.iter_mut() {
923            self.visit_attribute_mut(attr);
924        }
925        let mut visitor = self.with_generics(&mut i.generics);
926        visitor.visit_type_mut(&mut i.ty);
927        if let Some((_, expr)) = &mut i.default {
928            visitor.visit_expr_mut(expr);
929        }
930    }
931    fn visit_block_mut(&mut self, i: &mut Block) {
932        let mut visitor = self.clone();
933        for stmt in &i.stmts {
934            match stmt {
935                Stmt::Item(Item::Struct(ItemStruct { ident, .. }))
936                | Stmt::Item(Item::Enum(ItemEnum { ident, .. }))
937                | Stmt::Item(Item::Union(ItemUnion { ident, .. }))
938                | Stmt::Item(Item::Trait(ItemTrait { ident, .. }))
939                | Stmt::Item(Item::Type(ItemType { ident, .. })) => {
940                    visitor.0.map.remove(&parse_quote!(#ident));
941                }
942                _ => (),
943            }
944        }
945        for stmt in i.stmts.iter_mut() {
946            visitor.visit_stmt_mut(stmt);
947        }
948    }
949}
950
951impl Parse for Referrer {
952    fn parse(input: parse::ParseStream) -> Result<Self> {
953        let content: syn::parse::ParseBuffer<'_>;
954        parenthesized!(content in input);
955        let leak_types: Vec<Type> = Punctuated::<Type, Token![,]>::parse_terminated(&content)?
956            .into_iter()
957            .collect();
958        let map = leak_types
959            .iter()
960            .enumerate()
961            .map(|(n, ty)| (ty.clone(), n))
962            .collect();
963        Ok(Self { map, leak_types })
964    }
965}
966
/// Emits the leaked types as a parenthesized, comma-separated list.
///
/// This mirrors the input format accepted by the `Parse` impl for `Referrer`.
// `#(for item in ...), { ... }` is template_quote's loop syntax; the `,`
// before the body is the separator placed between iterations.
impl ToTokens for Referrer {
    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
        tokens.extend(quote! {
            (
                #(for item in &self.leak_types), {
                    #item
                }
            )
        })
    }
}