Skip to main content

type_leak/
lib.rs

1#![doc = include_str!("./README.md")]
2
3use gotgraph::graph::{Graph as GotGraph, GraphUpdate};
4use gotgraph::prelude::VecGraph;
5use proc_macro2::Span;
6use std::cell::RefCell;
7use std::collections::{HashMap, HashSet};
8use std::rc::Rc;
9use syn::parse::Parse;
10use syn::punctuated::Punctuated;
11use syn::spanned::Spanned;
12use syn::visit::Visit;
13use syn::visit_mut::VisitMut;
14use syn::*;
15use template_quote::{quote, ToTokens};
16
17pub use syn;
18
19/// The entry point of this crate.
20pub struct Leaker {
21    graph: VecGraph<Type, ()>,
22    pub allowed_paths: Vec<Path>,
23    pub self_ty_can_be_interned: bool,
24    reachable_types: HashSet<Type>,
25    pending_operations: Option<Vec<GraphOperation>>,
26}
27
28/// Error represents that the type is not internable.
29#[derive(Debug, Clone)]
30pub struct NotInternableError(pub Span);
31
32impl std::fmt::Display for NotInternableError {
33    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34        f.write_str("Not internable")
35    }
36}
37
38impl std::error::Error for NotInternableError {}
39
40/// Encode [`GenericArgument`]s to a type.
41pub fn encode_generics_args_to_ty<'a>(iter: impl IntoIterator<Item = &'a GenericArgument>) -> Type {
42    Type::Tuple(TypeTuple {
43        paren_token: Default::default(),
44        elems: iter
45            .into_iter()
46            .map(|param| -> Type {
47                match param {
48                    GenericArgument::Lifetime(lifetime) => {
49                        parse_quote!(& #lifetime ())
50                    }
51                    GenericArgument::Const(expr) => {
52                        parse_quote!([(); #expr as usize])
53                    }
54                    GenericArgument::Type(ty) => ty.clone(),
55                    _ => panic!(),
56                }
57            })
58            .collect(),
59    })
60}
61
62/// Encode [`GenericParam`]s to a type.
63pub fn encode_generics_params_to_ty<'a>(iter: impl IntoIterator<Item = &'a GenericParam>) -> Type {
64    Type::Tuple(TypeTuple {
65        paren_token: Default::default(),
66        elems: iter
67            .into_iter()
68            .map(|param| -> Type {
69                match param {
70                    GenericParam::Lifetime(LifetimeParam { lifetime, .. }) => {
71                        parse_quote!(& #lifetime ())
72                    }
73                    GenericParam::Const(ConstParam { ident, .. }) => {
74                        parse_quote!([(); #ident as usize])
75                    }
76                    GenericParam::Type(TypeParam { ident, .. }) => parse_quote!(#ident),
77                }
78            })
79            .collect(),
80    })
81}
82
83/// Represents result of [`Leaker::check()`], holding the cause in its tuple item.
84#[derive(Clone, Debug)]
85pub enum CheckResult {
86    /// The type must be interned because it directly depends on the type context.
87    MustIntern(Span),
88    /// The type must be interned because one of the types which it constituted with must be
89    /// interned.
90    MustInternOrInherit(Span),
91    /// The type must not be interned, because `type_leak` cannot intern this. (e.g. `impl Trait`s,
92    /// `_`, trait bounds with relative trait path, ...)
93    MustNotIntern(Span),
94    /// Other.
95    Neutral,
96}
97
98#[derive(Debug, Clone)]
99struct GraphOperation {
100    node_type: Type,
101    parent: Option<Type>,
102    is_root: bool,
103    is_must_intern: bool,
104}
105
106struct AnalyzeVisitor<'a> {
107    leaker: &'a mut Leaker,
108    generics: Generics,
109    parent: Option<Type>,
110    error: Option<Span>,
111    operations: Vec<GraphOperation>,
112}
113
114impl<'a, 'ast> Visit<'ast> for AnalyzeVisitor<'a> {
115    fn visit_trait_item_fn(&mut self, i: &TraitItemFn) {
116        let mut visitor = AnalyzeVisitor {
117            leaker: self.leaker,
118            generics: self.generics.clone(),
119            parent: self.parent.clone(),
120            error: self.error,
121            operations: Vec::new(),
122        };
123        for g in &i.sig.generics.params {
124            visitor.generics.params.push(g.clone());
125        }
126        let mut i = i.clone();
127        i.default = None;
128        i.semi_token = Some(Default::default());
129        syn::visit::visit_trait_item_fn(&mut visitor, &i);
130        self.error = visitor.error;
131        self.operations.extend(visitor.operations);
132    }
133
134    fn visit_receiver(&mut self, i: &Receiver) {
135        for attr in &i.attrs {
136            self.visit_attribute(attr);
137        }
138        if let Some((_, Some(lt))) = &i.reference {
139            self.visit_lifetime(lt);
140        }
141        if i.colon_token.is_some() {
142            self.visit_type(&i.ty);
143        }
144    }
145
146    fn visit_trait_item_type(&mut self, i: &TraitItemType) {
147        let mut visitor = AnalyzeVisitor {
148            leaker: self.leaker,
149            generics: self.generics.clone(),
150            parent: self.parent.clone(),
151            error: self.error,
152            operations: Vec::new(),
153        };
154        for g in &i.generics.params {
155            visitor.generics.params.push(g.clone());
156        }
157        syn::visit::visit_trait_item_type(&mut visitor, i);
158        self.error = visitor.error;
159        self.operations.extend(visitor.operations);
160    }
161
162    fn visit_type(&mut self, i: &Type) {
163        match self.leaker.check(i, &self.generics) {
164            Err((_, s)) => {
165                // Emit error and terminate searching
166                self.error = Some(s);
167            }
168            Ok(CheckResult::MustNotIntern(_)) => (),
169            o => {
170                // Collect the operation instead of directly modifying the graph
171                let is_must_intern = matches!(o, Ok(CheckResult::MustIntern(_)));
172                let is_root = self.parent.is_none();
173
174                self.operations.push(GraphOperation {
175                    node_type: i.clone(),
176                    parent: self.parent.clone(),
177                    is_root,
178                    is_must_intern,
179                });
180
181                if !is_must_intern {
182                    // Perform recursive call
183                    let parent = self.parent.clone();
184                    self.parent = Some(i.clone());
185                    syn::visit::visit_type(self, i);
186                    self.parent = parent;
187                }
188            }
189        }
190    }
191
192    fn visit_trait_bound(&mut self, i: &TraitBound) {
193        if i.path.leading_colon.is_none() {
194            // Trait bounds with relative oath
195            self.error = Some(i.span());
196        } else {
197            syn::visit::visit_trait_bound(self, i)
198        }
199    }
200
201    fn visit_expr_path(&mut self, i: &ExprPath) {
202        if i.path.leading_colon.is_none() {
203            // Value or trait bounds with relative oath
204            self.error = Some(i.span());
205        } else {
206            syn::visit::visit_expr_path(self, i)
207        }
208    }
209}
210
211impl Default for Leaker {
212    fn default() -> Self {
213        Self::new()
214    }
215}
216
217impl Leaker {
218    /// Initialize with [`ItemStruct`]
219    ///
220    /// ```
221    /// # use type_leak::*;
222    /// # use syn::*;
223    /// let test_struct: ItemStruct = parse_quote!(
224    ///     pub struct MyStruct<'a, T1, T2: ::path::to::MyType1<MyType4>> {
225    ///         field1: MyType1,
226    ///         field2: (MyType2, MyType3<MyType1>, MyType4, MyType5),
227    ///         field3: &'a (T1, T2),
228    ///     }
229    /// );
230    /// let leaker = Leaker::from_struct(&test_struct).unwrap();
231    /// ```
232    pub fn from_struct(input: &ItemStruct) -> std::result::Result<Self, NotInternableError> {
233        let mut leaker = Self::new();
234        leaker.intern_with(&input.generics, |visitor| {
235            visitor.visit_item_struct(input);
236        })?;
237        Ok(leaker)
238    }
239
240    /// Initialize with [`ItemEnum`]
241    pub fn from_enum(input: &ItemEnum) -> std::result::Result<Self, NotInternableError> {
242        let mut leaker = Self::new();
243        leaker.intern_with(&input.generics, |visitor| {
244            visitor.visit_item_enum(input);
245        })?;
246        Ok(leaker)
247    }
248
249    /// Build an [`Leaker`] with given trait.
250    ///
251    /// Unlike enum nor struct, it requires ~alternative path~, an absolute path of a struct
252    /// which is declared the same crate with the leaker trait and also visible from
253    /// [`Referrer`]s' context. That struct is used as an `impl` target of `Repeater`
254    /// instead of the Leaker's path.
255    ///
256    /// ```
257    /// # use syn::*;
258    /// # use type_leak::Leaker;
259    /// let s: ItemTrait = parse_quote!{
260    ///     pub trait MyTrait<T, U> {
261    ///         fn func(self, t: T) -> U;
262    ///     }
263    /// };
264    /// let alternate: ItemStruct = parse_quote!{
265    ///     pub struct MyAlternate;
266    /// };
267    /// let _ = Leaker::from_trait(&s);
268    /// ```
269    pub fn from_trait(input: &ItemTrait) -> std::result::Result<Self, NotInternableError> {
270        let mut leaker = Self::new();
271        leaker.intern_with(&input.generics, |visitor| {
272            visitor.visit_item_trait(input);
273        })?;
274        Ok(leaker)
275    }
276
277    /// Initialize empty [`Leaker`] with no generics.
278    pub fn new() -> Self {
279        Self {
280            graph: VecGraph::default(),
281            allowed_paths: Vec::new(),
282            self_ty_can_be_interned: true,
283            reachable_types: HashSet::new(),
284            pending_operations: None,
285        }
286    }
287
288    /// Apply collected graph operations using scope_mut()
289    fn apply_operations(&mut self, operations: Vec<GraphOperation>) {
290        if operations.is_empty() {
291            return;
292        }
293
294        // Store operations for proper graph reduction in reduce_roots()
295        self.pending_operations = Some(operations);
296    }
297
298    /// Intern ast elements with visitor
299    pub fn intern_with<'ast, F>(
300        &mut self,
301        generics: &Generics,
302        f: F,
303    ) -> std::result::Result<&mut Self, NotInternableError>
304    where
305        F: FnOnce(&mut dyn syn::visit::Visit<'ast>),
306    {
307        let mut visitor = AnalyzeVisitor {
308            leaker: self,
309            parent: None,
310            error: None,
311            generics: generics.clone(),
312            operations: Vec::new(),
313        };
314        f(&mut visitor);
315        if let Some(e) = visitor.error {
316            Err(NotInternableError(e))
317        } else {
318            let operations = visitor.operations;
319            self.apply_operations(operations);
320            Ok(self)
321        }
322    }
323
324    /// Allow primitive types from `::core::primitive`.
325    pub fn allow_primitive(&mut self) -> &mut Self {
326        self.allowed_paths.extend([
327            parse_quote!(bool),
328            parse_quote!(char),
329            parse_quote!(str),
330            parse_quote!(u8),
331            parse_quote!(u16),
332            parse_quote!(u32),
333            parse_quote!(u64),
334            parse_quote!(u128),
335            parse_quote!(usize),
336            parse_quote!(i8),
337            parse_quote!(i16),
338            parse_quote!(i32),
339            parse_quote!(i64),
340            parse_quote!(i128),
341            parse_quote!(isize),
342            parse_quote!(f32),
343            parse_quote!(f64),
344        ]);
345        self
346    }
347
348    /// Allow items in the standard prelude (`std::prelude::v1`).
349    pub fn allow_std_prelude(&mut self) -> &mut Self {
350        self.allowed_paths.extend([
351            parse_quote!(Box),
352            parse_quote!(String),
353            parse_quote!(Vec),
354            parse_quote!(Option),
355            parse_quote!(Result),
356            parse_quote!(Copy),
357            parse_quote!(Send),
358            parse_quote!(Sized),
359            parse_quote!(Sync),
360            parse_quote!(Unpin),
361            parse_quote!(Drop),
362            parse_quote!(Fn),
363            parse_quote!(FnMut),
364            parse_quote!(FnOnce),
365            parse_quote!(Clone),
366            parse_quote!(PartialEq),
367            parse_quote!(PartialOrd),
368            parse_quote!(Eq),
369            parse_quote!(Ord),
370            parse_quote!(AsRef),
371            parse_quote!(AsMut),
372            parse_quote!(Into),
373            parse_quote!(From),
374            parse_quote!(Default),
375            parse_quote!(Iterator),
376            parse_quote!(IntoIterator),
377            parse_quote!(Extend),
378            parse_quote!(FromIterator),
379            parse_quote!(ToOwned),
380            parse_quote!(ToString),
381            parse_quote!(drop),
382        ]);
383        self
384    }
385
386    /// Check that the internableness of give `ty`. It returns `Err` in contradiction (the type
387    /// must and must not be interned).
388    ///
389    ///
390    /// See [`CheckResult`].
391    pub fn check(
392        &self,
393        ty: &Type,
394        generics: &Generics,
395    ) -> std::result::Result<CheckResult, (Span, Span)> {
396        use syn::visit::Visit;
397        #[derive(Clone)]
398        struct Visitor {
399            generic_lifetimes_base: Vec<Lifetime>,
400            generic_idents_base: Vec<Ident>,
401            generic_lifetimes: Vec<Lifetime>,
402            generic_idents: Vec<Ident>,
403            impossible: Option<Span>,
404            must: Option<(isize, Span)>,
405            self_ty_can_be_interned: bool,
406            allowed_paths: Vec<Path>,
407        }
408
409        const _: () = {
410            use syn::visit::Visit;
411            impl<'a> Visit<'a> for Visitor {
412                fn visit_type(&mut self, i: &Type) {
413                    match i {
414                        Type::BareFn(TypeBareFn {
415                            lifetimes,
416                            inputs,
417                            output,
418                            ..
419                        }) => {
420                            let mut visitor = self.clone();
421                            visitor.must = None;
422                            visitor.generic_lifetimes.extend(
423                                lifetimes
424                                    .as_ref()
425                                    .map(|ls| {
426                                        ls.lifetimes.iter().map(|gp| {
427                                            if let GenericParam::Lifetime(lt) = gp {
428                                                lt.lifetime.clone()
429                                            } else {
430                                                panic!()
431                                            }
432                                        })
433                                    })
434                                    .into_iter()
435                                    .flatten(),
436                            );
437                            for input in inputs {
438                                visitor.visit_type(&input.ty);
439                            }
440                            if let ReturnType::Type(_, output) = output {
441                                visitor.visit_type(output.as_ref());
442                            }
443                            if self.impossible.is_none() {
444                                self.impossible = visitor.impossible;
445                            }
446                            match (self.must, visitor.must) {
447                                (Some((s_l, _)), Some((v_l, v_m))) if v_l + 1 < s_l => {
448                                    self.must = Some((v_l + 1, v_m))
449                                }
450                                (None, Some((v_l, v_m))) => self.must = Some((v_l + 1, v_m)),
451                                _ => (),
452                            }
453                            return;
454                        }
455                        Type::TraitObject(TypeTraitObject { bounds, .. }) => {
456                            for bound in bounds {
457                                match bound {
458                                    TypeParamBound::Trait(TraitBound {
459                                        lifetimes, path, ..
460                                    }) => {
461                                        let mut visitor = self.clone();
462                                        visitor.must = None;
463                                        visitor.generic_lifetimes.extend(
464                                            lifetimes
465                                                .as_ref()
466                                                .map(|ls| {
467                                                    ls.lifetimes.iter().map(|gp| {
468                                                        if let GenericParam::Lifetime(lt) = gp {
469                                                            lt.lifetime.clone()
470                                                        } else {
471                                                            panic!()
472                                                        }
473                                                    })
474                                                })
475                                                .into_iter()
476                                                .flatten(),
477                                        );
478                                        visitor.visit_path(path);
479                                        if self.impossible.is_none() {
480                                            self.impossible = visitor.impossible;
481                                        }
482                                        match (self.must, visitor.must) {
483                                            (Some((s_l, _)), Some((v_l, v_m))) if v_l + 1 < s_l => {
484                                                self.must = Some((v_l + 1, v_m))
485                                            }
486                                            (None, Some((v_l, v_m))) => {
487                                                self.must = Some((v_l + 1, v_m))
488                                            }
489                                            _ => (),
490                                        }
491                                        return;
492                                    }
493                                    TypeParamBound::Verbatim(_) => {
494                                        self.impossible = Some(bound.span());
495                                        return;
496                                    }
497                                    _ => (),
498                                }
499                            }
500                        }
501                        Type::ImplTrait(_) | Type::Macro(_) | Type::Verbatim(_) => {
502                            self.impossible = Some(i.span());
503                        }
504                        Type::Reference(TypeReference { lifetime, .. }) => {
505                            if lifetime.is_none() {
506                                self.impossible = Some(i.span());
507                            }
508                        }
509                        _ => (),
510                    }
511                    let mut visitor = self.clone();
512                    visitor.must = None;
513                    syn::visit::visit_type(&mut visitor, i);
514                    if visitor.impossible.is_some() {
515                        self.impossible = visitor.impossible;
516                    }
517                    match (self.must, visitor.must) {
518                        (Some((s_l, _)), Some((v_l, v_m))) if v_l + 1 < s_l => {
519                            self.must = Some((v_l + 1, v_m))
520                        }
521                        (None, Some((v_l, v_m))) => self.must = Some((v_l + 1, v_m)),
522                        _ => (),
523                    }
524                }
525                fn visit_qself(&mut self, i: &QSelf) {
526                    if i.as_token.is_none() {
527                        self.impossible = Some(i.span());
528                    }
529                    syn::visit::visit_qself(self, i)
530                }
531                fn visit_lifetime(&mut self, i: &Lifetime) {
532                    match i.to_string().as_str() {
533                        "'static" => (),
534                        "'_" => self.impossible = Some(i.span()),
535                        _ if self.generic_lifetimes_base.iter().any(|lt| lt == i) => {}
536                        _ if self.generic_lifetimes.iter().any(|lt| lt == i) => {
537                            self.impossible = Some(i.span());
538                        }
539                        _ => {
540                            self.must = Some((-1, i.span()));
541                        }
542                    }
543                    syn::visit::visit_lifetime(self, i)
544                }
545                fn visit_expr(&mut self, i: &Expr) {
546                    match i {
547                        Expr::Closure(_) | Expr::Assign(_) | Expr::Verbatim(_) | Expr::Macro(_) => {
548                            self.impossible = Some(i.span());
549                        }
550                        _ => (),
551                    }
552                    syn::visit::visit_expr(self, i)
553                }
554                fn visit_path(&mut self, i: &Path) {
555                    if self.allowed_paths.iter().any(|allowed| {
556                        allowed.leading_colon == i.leading_colon
557                            && allowed.segments.len() <= i.segments.len()
558                            && allowed
559                                .segments
560                                .iter()
561                                .zip(i.segments.iter())
562                                .all(|(allowed_seg, seg)| allowed_seg == seg)
563                    }) {
564                        return;
565                    }
566                    if matches!(i.segments.iter().next(), Some(PathSegment { ident, arguments }) if ident == "Self" && arguments.is_none())
567                        && i.leading_colon.is_none()
568                    {
569                        if !self.self_ty_can_be_interned {
570                            self.impossible = Some(i.span());
571                        }
572                    } else {
573                        match (i.leading_colon, i.get_ident()) {
574                            // i is a generic parameter
575                            (None, Some(ident)) if self.generic_idents_base.contains(ident) => {}
576                            (None, Some(ident)) if self.generic_idents.contains(ident) => {
577                                self.impossible = Some(i.span());
578                            }
579                            (None, _) => {
580                                if !matches!(
581                                    i.segments.iter().next(),
582                                    Some(PathSegment {
583                                        ident,
584                                        arguments: PathArguments::None,
585                                    }) if ident == "$crate"
586                                ) {
587                                    // relative path, not a generic parameter
588                                    self.must = Some((-1, i.span()));
589                                }
590                            }
591                            // absolute path
592                            (Some(_), _) => (),
593                        }
594                    }
595                    syn::visit::visit_path(self, i)
596                }
597            }
598        };
599        let mut visitor = Visitor {
600            generic_lifetimes_base: generics
601                .params
602                .iter()
603                .filter_map(|gp| {
604                    if let GenericParam::Lifetime(lt) = gp {
605                        Some(lt.lifetime.clone())
606                    } else {
607                        None
608                    }
609                })
610                .collect(),
611            generic_idents_base: generics
612                .params
613                .iter()
614                .filter_map(|gp| match gp {
615                    GenericParam::Type(TypeParam { ident, .. })
616                    | GenericParam::Const(ConstParam { ident, .. }) => Some(ident.clone()),
617                    _ => None,
618                })
619                .collect(),
620            generic_lifetimes: generics
621                .params
622                .iter()
623                .filter_map(|gp| {
624                    if let GenericParam::Lifetime(lt) = gp {
625                        Some(lt.lifetime.clone())
626                    } else {
627                        None
628                    }
629                })
630                .collect(),
631            generic_idents: generics
632                .params
633                .iter()
634                .filter_map(|gp| match gp {
635                    GenericParam::Type(TypeParam { ident, .. })
636                    | GenericParam::Const(ConstParam { ident, .. }) => Some(ident.clone()),
637                    _ => None,
638                })
639                .collect(),
640            impossible: None,
641            must: None,
642            self_ty_can_be_interned: self.self_ty_can_be_interned,
643            allowed_paths: self.allowed_paths.clone(),
644        };
645        visitor.visit_type(ty);
646        match (visitor.must, visitor.impossible) {
647            (None, None) => Ok(CheckResult::Neutral),
648            (Some((0, span)), None) => Ok(CheckResult::MustIntern(span)),
649            (Some((_, span)), None) => Ok(CheckResult::MustInternOrInherit(span)),
650            (None, Some(span)) => Ok(CheckResult::MustNotIntern(span)),
651            (Some((_, span0)), Some(span1)) => Err((span0, span1)),
652        }
653    }
654
655    /// Finish building the [`Leaker`] and convert it into [`Referrer`].
656    pub fn finish(self) -> Referrer {
657        // Use the reachable types computed by apply_operations
658        let leak_types: Vec<_> = self.reachable_types.into_iter().collect();
659        let map = leak_types
660            .iter()
661            .enumerate()
662            .map(|(n, ty)| (ty.clone(), n))
663            .collect();
664        Referrer { leak_types, map }
665    }
666
667    /// Reduce nodes to decrease cost of {`Leaker`}'s implementation.
668    ///
669    /// # Algorithm
670    ///
671    /// To consult the algorithm, see the following [`Leaker`]'s input:
672    ///
673    /// ```ignore
674    /// pub struct MyStruct<'a, T1, T2: ::path::to::MyType1<MyType4>> {
675    ///     field1: MyType1,
676    ///     field2: (MyType2, MyType3<MyType1>, MyType4, MyType5),
677    ///     field3: &'a (T1, T2),
678    /// }
679    /// ```
680    ///
681    /// [`Leaker`], when initialized with [`Leaker::from_struct()`], analyze the AST and
682    /// construct a DAG which represents all (internable) types and the dependency relations like
683    /// this:
684    ///
685    /// ![before_graph](https://raw.github.com/yasuo-ozu/type-leak/master/resources/before_graph.svg)
686    ///
687    /// The **red** node is flagged as [`CheckResult::MustIntern`] by [`Leaker::check()`] (which means
688    /// the type literature depends on the type context, so it must be interned).
689    ///
690    /// This algorithm reduce the nodes, remaining that all root type (annotated with ★) can be
691    /// expressed with existing **red** nodes.
692    ///
693    /// Here, there are some choice in its freedom:
694    ///
695    /// - Intern all **red** nodes and ignore others (because other nodes are not needed to be
696    ///   intern, or constructable with red nodes)
697    /// - Not directly intern **red** nodes; intern common ancessors instead if it is affordable.
698    ///
699    /// So finally, it results like:
700    ///
701    /// ![after_graph](https://raw.github.com/yasuo-ozu/type-leak/master/resources/after_graph.svg)
702    pub fn reduce_roots(&mut self) {
703        if let Some(operations) = self.pending_operations.take() {
704            self.build_graph_and_reduce(operations);
705        }
706        self.reduce_unreachable_nodes();
707        self.reduce_obvious_nodes();
708    }
709
710    fn build_graph_and_reduce(&mut self, operations: Vec<GraphOperation>) {
711        // Build parent-child relationships from operations
712        let mut children: HashMap<Type, Vec<Type>> = HashMap::new();
713        let mut parents: HashMap<Type, Vec<Type>> = HashMap::new();
714        let mut root_types = HashSet::new();
715        let mut must_intern_types = HashSet::new();
716
717        for operation in &operations {
718            if operation.is_root {
719                root_types.insert(operation.node_type.clone());
720            }
721            if operation.is_must_intern {
722                must_intern_types.insert(operation.node_type.clone());
723            }
724
725            if let Some(parent_type) = &operation.parent {
726                children
727                    .entry(parent_type.clone())
728                    .or_default()
729                    .push(operation.node_type.clone());
730                parents
731                    .entry(operation.node_type.clone())
732                    .or_default()
733                    .push(parent_type.clone());
734            }
735        }
736
737        // Implement proper reduction algorithm:
738        // 1. Keep all must-intern root types
739        // 2. Keep root types that transitively contain must-intern types
740        // 3. Do NOT keep non-root must-intern types that are only components of kept root types
741
742        let mut kept_types = HashSet::new();
743
744        // Step 1: Keep must-intern types that are also roots
745        for root_type in &root_types {
746            if must_intern_types.contains(root_type) {
747                kept_types.insert(root_type.clone());
748            }
749        }
750
751        // Step 2: Keep root types that contain must-intern descendants
752        for root_type in &root_types {
753            if !kept_types.contains(root_type)
754                && self.contains_must_intern_descendant(root_type, &children, &must_intern_types)
755            {
756                kept_types.insert(root_type.clone());
757            }
758        }
759
760        // Step 3: Only add must-intern types that are NOT components of any kept type
761        for must_intern_type in &must_intern_types {
762            if !self.is_component_of_any_kept_type(must_intern_type, &kept_types, &children) {
763                kept_types.insert(must_intern_type.clone());
764            }
765        }
766
767        self.reachable_types = kept_types;
768    }
769
770    fn contains_must_intern_descendant(
771        &self,
772        root: &Type,
773        children: &HashMap<Type, Vec<Type>>,
774        must_intern_types: &HashSet<Type>,
775    ) -> bool {
776        let mut visited = HashSet::new();
777        let mut stack = vec![root.clone()];
778
779        while let Some(current) = stack.pop() {
780            if !visited.insert(current.clone()) {
781                continue;
782            }
783
784            if must_intern_types.contains(&current) {
785                return true;
786            }
787
788            if let Some(child_list) = children.get(&current) {
789                for child in child_list {
790                    if !visited.contains(child) {
791                        stack.push(child.clone());
792                    }
793                }
794            }
795        }
796
797        false
798    }
799
800    fn is_component_of_any_kept_type(
801        &self,
802        target_type: &Type,
803        kept_types: &HashSet<Type>,
804        children: &HashMap<Type, Vec<Type>>,
805    ) -> bool {
806        for kept_type in kept_types {
807            if self.is_descendant_of(target_type, kept_type, children) {
808                return true;
809            }
810        }
811        false
812    }
813
814    fn is_descendant_of(
815        &self,
816        target: &Type,
817        ancestor: &Type,
818        children: &HashMap<Type, Vec<Type>>,
819    ) -> bool {
820        if target == ancestor {
821            return false; // Not a descendant of itself
822        }
823
824        let mut visited = HashSet::new();
825        let mut stack = vec![ancestor.clone()];
826
827        while let Some(current) = stack.pop() {
828            if !visited.insert(current.clone()) {
829                continue;
830            }
831
832            if let Some(child_list) = children.get(&current) {
833                for child in child_list {
834                    if child == target {
835                        return true;
836                    }
837                    if !visited.contains(child) {
838                        stack.push(child.clone());
839                    }
840                }
841            }
842        }
843
844        false
845    }
846
847    fn reduce_obvious_nodes(&mut self) {
848        // Remove intermediate nodes that have exactly one parent and one child
849        // This step eliminates unnecessary intermediate types
850    }
851
852    fn reduce_unreachable_nodes(&mut self) {
853        if let Some(operations) = self.pending_operations.take() {
854            self.apply_vecgraph_reduction(operations);
855        }
856    }
857
858    fn apply_vecgraph_reduction(&mut self, operations: Vec<GraphOperation>) {
859        let mut root_types = HashSet::new();
860        let mut must_intern_types = HashSet::new();
861
862        // Collect graph operations
863        for operation in &operations {
864            if operation.is_root {
865                root_types.insert(operation.node_type.clone());
866            }
867            if operation.is_must_intern {
868                must_intern_types.insert(operation.node_type.clone());
869            }
870        }
871
872        // Build the graph using scope_mut
873        self.graph.scope_mut(|mut ctx| {
874            let mut type_to_node = HashMap::new();
875
876            // Add all nodes first
877            for operation in &operations {
878                if !type_to_node.contains_key(&operation.node_type) {
879                    let node_tag = ctx.add_node(operation.node_type.clone());
880                    type_to_node.insert(operation.node_type.clone(), node_tag);
881                }
882
883                if let Some(parent_type) = &operation.parent {
884                    if !type_to_node.contains_key(parent_type) {
885                        let parent_node_tag = ctx.add_node(parent_type.clone());
886                        type_to_node.insert(parent_type.clone(), parent_node_tag);
887                    }
888                }
889            }
890
891            // Add edges between nodes
892            for operation in &operations {
893                if let Some(parent_type) = &operation.parent {
894                    let parent_node = type_to_node[parent_type];
895                    let child_node = type_to_node[&operation.node_type];
896                    ctx.add_edge((), parent_node, child_node);
897                }
898            }
899        });
900
901        // Perform analysis using the same logic as build_graph_and_reduce
902        self.reachable_types =
903            self.analyze_reachable_types(&root_types, &must_intern_types, &operations);
904    }
905
906    fn analyze_reachable_types(
907        &self,
908        root_types: &HashSet<Type>,
909        must_intern_types: &HashSet<Type>,
910        operations: &[GraphOperation],
911    ) -> HashSet<Type> {
912        // Implement the same logic as build_graph_and_reduce
913        let mut final_types = HashSet::new();
914
915        // Keep must-intern root types
916        for root_type in root_types {
917            if must_intern_types.contains(root_type) {
918                final_types.insert(root_type.clone());
919            }
920        }
921
922        // Keep root types that contain must-intern descendants
923        for root_type in root_types {
924            if !final_types.contains(root_type)
925                && self.contains_must_intern_in_tree(root_type, operations, must_intern_types)
926            {
927                final_types.insert(root_type.clone());
928            }
929        }
930
931        // Only keep must-intern types that aren't subcomponents of kept root types
932        for must_intern_type in must_intern_types {
933            if !self.is_subcomponent_of_kept_root(must_intern_type, &final_types, operations) {
934                final_types.insert(must_intern_type.clone());
935            }
936        }
937
938        final_types
939    }
940
941    fn contains_must_intern_in_tree(
942        &self,
943        root: &Type,
944        operations: &[GraphOperation],
945        must_intern_types: &HashSet<Type>,
946    ) -> bool {
947        let mut stack = vec![root.clone()];
948        let mut visited = HashSet::new();
949
950        while let Some(current) = stack.pop() {
951            if !visited.insert(current.clone()) {
952                continue;
953            }
954
955            if must_intern_types.contains(&current) {
956                return true;
957            }
958
959            for operation in operations {
960                if let Some(parent_type) = &operation.parent {
961                    if parent_type == &current && !visited.contains(&operation.node_type) {
962                        stack.push(operation.node_type.clone());
963                    }
964                }
965            }
966        }
967
968        false
969    }
970
971    fn is_subcomponent_of_kept_root(
972        &self,
973        target: &Type,
974        kept_roots: &HashSet<Type>,
975        operations: &[GraphOperation],
976    ) -> bool {
977        for root in kept_roots {
978            if self.contains_must_intern_in_tree(root, operations, &HashSet::from([target.clone()]))
979                && root != target
980            {
981                return true;
982            }
983        }
984        false
985    }
986
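987    // TODO: `reduce_unreachable_nodes` now builds `self.graph`, but `reachable_types` is still derived by rescanning the raw operation list; switch that analysis to traverse the graph via the VecGraph APIs.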
988}
989
990/// Holds the list of types that need to be encoded, produced by [`Leaker::finish()`].
991///
992/// The `Referrer` provides methods to iterate over leak types and to expand/transform
993/// types using a custom mapping function.
994#[derive(Clone, PartialEq, Eq, Debug)]
995pub struct Referrer {
996    leak_types: Vec<Type>,
997    map: HashMap<Type, usize>,
998}
999
1000impl Referrer {
1001    /// Returns `true` if there are no types to be encoded.
1002    pub fn is_empty(&self) -> bool {
1003        self.leak_types.is_empty()
1004    }
1005
1006    /// Returns an iterator over the types that need to be encoded.
1007    ///
1008    /// ```ignore
1009    /// let args = encode_generics_params_to_ty(&generics.params);
1010    /// referrer.iter().enumerate().map(|(ix, ty)| quote!{
1011    ///     impl #crate::TypeRef<#ix, #args> for #marker_ty {
1012    ///         type Type = #ty;
1013    ///     }
1014    /// }).collect::<TokenStream>();
1015    /// ```
1016    pub fn iter(&self) -> impl Iterator<Item = &Type> {
1017        self.leak_types.iter()
1018    }
1019
1020    /// Converts this `Referrer` into a [`Visitor`] that can be used to transform AST nodes.
1021    ///
1022    /// The `type_fn` callback is called for each type that needs to be converted,
1023    /// receiving the original type and its index, and returning the replacement type.
1024    /// See [`Referrer::expand()`] for details.
1025    pub fn into_visitor<F: FnMut(Type, usize) -> Type>(self, type_fn: F) -> Visitor<F> {
1026        Visitor(self, Rc::new(RefCell::new(type_fn)))
1027    }
1028
1029    /// Expands a type by replacing any interned types using the provided mapping function.
1030    ///
1031    /// The `type_fn` callback is called for each subtype that needs to be converted,
1032    /// receiving the original type and its index, and returning the replacement type.
1033    ///
1034    /// ```ignore
1035    /// let args = encode_generics_params_to_ty(&generics.params);
1036    /// referrer.expand(ty, |_, ix| {
1037    ///     parse_quote!(<#marker_ty as #crate::TypeRef<#ix, #args>>::Type)
1038    /// })
1039    /// ```
1040    pub fn expand(&self, ty: Type, type_fn: impl FnMut(Type, usize) -> Type) -> Type {
1041        use syn::fold::Fold;
1042        struct Folder<'a, F>(&'a Referrer, F);
1043        impl<'a, F: FnMut(Type, usize) -> Type> Fold for Folder<'a, F> {
1044            fn fold_type(&mut self, ty: Type) -> Type {
1045                if let Some(idx) = self.0.map.get(&ty) {
1046                    self.1(ty, *idx)
1047                } else {
1048                    syn::fold::fold_type(self, ty)
1049                }
1050            }
1051        }
1052        let mut folder = Folder(self, type_fn);
1053        folder.fold_type(ty)
1054    }
1055}
1056
1057#[derive(Debug)]
1058pub struct Visitor<F>(Referrer, Rc<RefCell<F>>);
1059
1060impl<F> Clone for Visitor<F> {
1061    fn clone(&self) -> Self {
1062        Self(self.0.clone(), self.1.clone())
1063    }
1064}
1065
1066impl<F: FnMut(Type, usize) -> Type> Visitor<F> {
1067    fn with_generics(&mut self, generics: &mut Generics) -> Self {
1068        let mut visitor = self.clone();
1069        for gp in generics.params.iter_mut() {
1070            if let GenericParam::Type(TypeParam { ident, .. }) = gp {
1071                visitor.0.map.remove(&parse_quote!(#ident));
1072            }
1073            visitor.visit_generic_param_mut(gp);
1074        }
1075        visitor
1076    }
1077
1078    fn with_signature(&mut self, sig: &mut Signature) -> Self {
1079        let mut visitor = self.with_generics(&mut sig.generics);
1080        for input in sig.inputs.iter_mut() {
1081            visitor.visit_fn_arg_mut(input);
1082        }
1083        visitor.visit_return_type_mut(&mut sig.output);
1084        visitor
1085    }
1086}
1087
1088impl<F: FnMut(Type, usize) -> Type> VisitMut for Visitor<F> {
1089    fn visit_type_mut(&mut self, i: &mut Type) {
1090        *i = self.0.expand(i.clone(), &mut *self.1.borrow_mut());
1091    }
1092    fn visit_item_struct_mut(&mut self, i: &mut ItemStruct) {
1093        for attr in i.attrs.iter_mut() {
1094            self.visit_attribute_mut(attr);
1095        }
1096        let mut visitor = self.with_generics(&mut i.generics);
1097        visitor.visit_fields_mut(&mut i.fields);
1098    }
1099    fn visit_item_enum_mut(&mut self, i: &mut ItemEnum) {
1100        for attr in i.attrs.iter_mut() {
1101            self.visit_attribute_mut(attr);
1102        }
1103        let mut visitor = self.with_generics(&mut i.generics);
1104        for variant in i.variants.iter_mut() {
1105            visitor.visit_variant_mut(variant);
1106        }
1107    }
1108    fn visit_item_trait_mut(&mut self, i: &mut ItemTrait) {
1109        for attr in i.attrs.iter_mut() {
1110            self.visit_attribute_mut(attr);
1111        }
1112        let mut visitor = self.with_generics(&mut i.generics);
1113        for supertrait in i.supertraits.iter_mut() {
1114            visitor.visit_type_param_bound_mut(supertrait);
1115        }
1116        for item in i.items.iter_mut() {
1117            visitor.visit_trait_item_mut(item);
1118        }
1119    }
1120    fn visit_item_union_mut(&mut self, i: &mut ItemUnion) {
1121        for attr in i.attrs.iter_mut() {
1122            self.visit_attribute_mut(attr);
1123        }
1124        let mut visitor = self.with_generics(&mut i.generics);
1125        visitor.visit_fields_named_mut(&mut i.fields);
1126    }
1127    fn visit_item_type_mut(&mut self, i: &mut ItemType) {
1128        for attr in i.attrs.iter_mut() {
1129            self.visit_attribute_mut(attr);
1130        }
1131        let mut visitor = self.with_generics(&mut i.generics);
1132        visitor.visit_type_mut(&mut i.ty);
1133    }
1134    fn visit_item_fn_mut(&mut self, i: &mut ItemFn) {
1135        for attr in i.attrs.iter_mut() {
1136            self.visit_attribute_mut(attr);
1137        }
1138        let mut visitor = self.with_signature(&mut i.sig);
1139        visitor.visit_block_mut(i.block.as_mut());
1140    }
1141    fn visit_trait_item_fn_mut(&mut self, i: &mut TraitItemFn) {
1142        for attr in i.attrs.iter_mut() {
1143            self.visit_attribute_mut(attr);
1144        }
1145        let mut visitor = self.with_signature(&mut i.sig);
1146        if let Some(block) = &mut i.default {
1147            visitor.visit_block_mut(block);
1148        }
1149    }
1150    fn visit_trait_item_type_mut(&mut self, i: &mut TraitItemType) {
1151        for attr in i.attrs.iter_mut() {
1152            self.visit_attribute_mut(attr);
1153        }
1154        let mut visitor = self.with_generics(&mut i.generics);
1155        for bound in i.bounds.iter_mut() {
1156            visitor.visit_type_param_bound_mut(bound);
1157        }
1158        if let Some((_, ty)) = &mut i.default {
1159            visitor.visit_type_mut(ty);
1160        }
1161    }
1162    fn visit_trait_item_const_mut(&mut self, i: &mut TraitItemConst) {
1163        for attr in i.attrs.iter_mut() {
1164            self.visit_attribute_mut(attr);
1165        }
1166        let mut visitor = self.with_generics(&mut i.generics);
1167        visitor.visit_type_mut(&mut i.ty);
1168        if let Some((_, expr)) = &mut i.default {
1169            visitor.visit_expr_mut(expr);
1170        }
1171    }
1172    fn visit_block_mut(&mut self, i: &mut Block) {
1173        let mut visitor = self.clone();
1174        for stmt in &i.stmts {
1175            match stmt {
1176                Stmt::Item(Item::Struct(ItemStruct { ident, .. }))
1177                | Stmt::Item(Item::Enum(ItemEnum { ident, .. }))
1178                | Stmt::Item(Item::Union(ItemUnion { ident, .. }))
1179                | Stmt::Item(Item::Trait(ItemTrait { ident, .. }))
1180                | Stmt::Item(Item::Type(ItemType { ident, .. })) => {
1181                    visitor.0.map.remove(&parse_quote!(#ident));
1182                }
1183                _ => (),
1184            }
1185        }
1186        for stmt in i.stmts.iter_mut() {
1187            visitor.visit_stmt_mut(stmt);
1188        }
1189    }
1190}
1191
1192impl Parse for Referrer {
1193    fn parse(input: parse::ParseStream) -> Result<Self> {
1194        let content: syn::parse::ParseBuffer<'_>;
1195        parenthesized!(content in input);
1196        let leak_types: Vec<Type> = Punctuated::<Type, Token![,]>::parse_terminated(&content)?
1197            .into_iter()
1198            .collect();
1199        let map = leak_types
1200            .iter()
1201            .enumerate()
1202            .map(|(n, ty)| (ty.clone(), n))
1203            .collect();
1204        Ok(Self { map, leak_types })
1205    }
1206}
1207
1208impl ToTokens for Referrer {
1209    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
1210        tokens.extend(quote! {
1211            (
1212                #(for item in &self.leak_types), {
1213                    #item
1214                }
1215            )
1216        })
1217    }
1218}