Skip to main content

type_leak/
lib.rs

1#![doc = include_str!("./README.md")]
2
3use gotgraph::graph::{Graph as GotGraph, GraphUpdate};
4use gotgraph::prelude::VecGraph;
5use proc_macro2::{Span, TokenStream};
6use std::cell::RefCell;
7use std::collections::{HashMap, HashSet};
8use std::rc::Rc;
9use syn::parse::Parse;
10use syn::punctuated::Punctuated;
11use syn::spanned::Spanned;
12use syn::visit::Visit;
13use syn::visit_mut::VisitMut;
14use syn::*;
15use template_quote::{quote, ToTokens};
16
17pub use syn;
18
19/// Configuration for building a [`Leaker`].
20#[derive(Clone, Debug, PartialEq, Eq, Hash)]
21pub struct LeakerConfig {
22    pub allowed_paths: Vec<Path>,
23    pub self_ty_can_be_interned: bool,
24}
25
26impl Default for LeakerConfig {
27    fn default() -> Self {
28        Self {
29            allowed_paths: Vec::new(),
30            self_ty_can_be_interned: true,
31        }
32    }
33}
34
35impl LeakerConfig {
36    pub fn new() -> Self {
37        Default::default()
38    }
39
40    /// Allow primitive types from `::core::primitive`.
41    pub fn allow_primitive(&mut self) -> &mut Self {
42        self.allowed_paths.extend([
43            parse_quote!(bool),
44            parse_quote!(char),
45            parse_quote!(str),
46            parse_quote!(u8),
47            parse_quote!(u16),
48            parse_quote!(u32),
49            parse_quote!(u64),
50            parse_quote!(u128),
51            parse_quote!(usize),
52            parse_quote!(i8),
53            parse_quote!(i16),
54            parse_quote!(i32),
55            parse_quote!(i64),
56            parse_quote!(i128),
57            parse_quote!(isize),
58            parse_quote!(f32),
59            parse_quote!(f64),
60        ]);
61        self
62    }
63
64    /// Allow items in the standard prelude (`std::prelude::v1`).
65    pub fn allow_std_prelude(&mut self) -> &mut Self {
66        self.allowed_paths.extend([
67            parse_quote!(Box),
68            parse_quote!(String),
69            parse_quote!(Vec),
70            parse_quote!(Option),
71            parse_quote!(Result),
72            parse_quote!(Copy),
73            parse_quote!(Send),
74            parse_quote!(Sized),
75            parse_quote!(Sync),
76            parse_quote!(Unpin),
77            parse_quote!(Drop),
78            parse_quote!(Fn),
79            parse_quote!(FnMut),
80            parse_quote!(FnOnce),
81            parse_quote!(Clone),
82            parse_quote!(PartialEq),
83            parse_quote!(PartialOrd),
84            parse_quote!(Eq),
85            parse_quote!(Ord),
86            parse_quote!(AsRef),
87            parse_quote!(AsMut),
88            parse_quote!(Into),
89            parse_quote!(From),
90            parse_quote!(Default),
91            parse_quote!(Iterator),
92            parse_quote!(IntoIterator),
93            parse_quote!(Extend),
94            parse_quote!(FromIterator),
95            parse_quote!(ToOwned),
96            parse_quote!(ToString),
97            parse_quote!(drop),
98        ]);
99        self
100    }
101
102    /// Allow paths starting with `crate`.
103    pub fn allow_crate(&mut self) -> &mut Self {
104        self.allowed_paths.push(parse_quote!(crate));
105        self
106    }
107}
108
109/// The entry point of this crate.
110pub struct Leaker {
111    graph: VecGraph<Type, ()>,
112    pub allowed_paths: Vec<Path>,
113    pub self_ty_can_be_interned: bool,
114    reachable_types: HashSet<Type>,
115    pending_operations: Option<Vec<GraphOperation>>,
116}
117
118/// Error represents that the type is not internable.
119#[derive(Debug, Clone)]
120pub struct NotInternableError(pub TokenStream);
121
122impl std::fmt::Display for NotInternableError {
123    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
124        f.write_str(&format!("use absolute path: '{}'", &self.0))
125    }
126}
127
128impl std::error::Error for NotInternableError {}
129
130impl NotInternableError {
131    pub fn span(&self) -> Span {
132        self.0.span()
133    }
134}
135
136/// Encode [`GenericArgument`]s to a type.
137pub fn encode_generics_args_to_ty<'a>(iter: impl IntoIterator<Item = &'a GenericArgument>) -> Type {
138    Type::Tuple(TypeTuple {
139        paren_token: Default::default(),
140        elems: iter
141            .into_iter()
142            .map(|param| -> Type {
143                match param {
144                    GenericArgument::Lifetime(lifetime) => {
145                        parse_quote!(& #lifetime ())
146                    }
147                    GenericArgument::Const(expr) => {
148                        parse_quote!([(); #expr as usize])
149                    }
150                    GenericArgument::Type(ty) => ty.clone(),
151                    _ => panic!(),
152                }
153            })
154            .collect(),
155    })
156}
157
158/// Encode [`GenericParam`]s to a type.
159pub fn encode_generics_params_to_ty<'a>(iter: impl IntoIterator<Item = &'a GenericParam>) -> Type {
160    Type::Tuple(TypeTuple {
161        paren_token: Default::default(),
162        elems: iter
163            .into_iter()
164            .map(|param| -> Type {
165                match param {
166                    GenericParam::Lifetime(LifetimeParam { lifetime, .. }) => {
167                        parse_quote!(& #lifetime ())
168                    }
169                    GenericParam::Const(ConstParam { ident, .. }) => {
170                        parse_quote!([(); #ident as usize])
171                    }
172                    GenericParam::Type(TypeParam { ident, .. }) => parse_quote!(#ident),
173                }
174            })
175            .collect(),
176    })
177}
178
179/// Represents result of [`Leaker::check()`], holding the cause in its tuple item.
180#[derive(Clone, Debug)]
181pub enum CheckResult {
182    /// The type must be interned because it directly depends on the type context.
183    MustIntern(TokenStream),
184    /// The type must be interned because one of the types which it constituted with must be
185    /// interned.
186    MustInternOrInherit(TokenStream),
187    /// The type must not be interned, because `type_leak` cannot intern this. (e.g. `impl Trait`s,
188    /// `_`, trait bounds with relative trait path, ...)
189    MustNotIntern(TokenStream),
190    /// Other.
191    Neutral,
192}
193
194#[derive(Debug, Clone)]
195struct GraphOperation {
196    node_type: Type,
197    parent: Option<Type>,
198    is_root: bool,
199    is_must_intern: bool,
200}
201
202struct AnalyzeVisitor<'a> {
203    leaker: &'a mut Leaker,
204    generics: Generics,
205    parent: Option<Type>,
206    error: Option<TokenStream>,
207    operations: Vec<GraphOperation>,
208}
209
210impl<'a, 'ast> Visit<'ast> for AnalyzeVisitor<'a> {
211    fn visit_trait_item_fn(&mut self, i: &TraitItemFn) {
212        let mut visitor = AnalyzeVisitor {
213            leaker: self.leaker,
214            generics: self.generics.clone(),
215            parent: self.parent.clone(),
216            error: self.error.clone(),
217            operations: Vec::new(),
218        };
219        for g in &i.sig.generics.params {
220            visitor.generics.params.push(g.clone());
221        }
222        let mut i = i.clone();
223        i.default = None;
224        i.semi_token = Some(Default::default());
225        syn::visit::visit_trait_item_fn(&mut visitor, &i);
226        if visitor.error.is_some() {
227            self.error = visitor.error;
228        }
229        self.operations.extend(visitor.operations);
230    }
231
232    fn visit_receiver(&mut self, i: &Receiver) {
233        for attr in &i.attrs {
234            self.visit_attribute(attr);
235        }
236        if let Some((_, Some(lt))) = &i.reference {
237            self.visit_lifetime(lt);
238        }
239        if i.colon_token.is_some() {
240            self.visit_type(&i.ty);
241        }
242    }
243
244    fn visit_trait_item_type(&mut self, i: &TraitItemType) {
245        let mut visitor = AnalyzeVisitor {
246            leaker: self.leaker,
247            generics: self.generics.clone(),
248            parent: self.parent.clone(),
249            error: self.error.clone(),
250            operations: Vec::new(),
251        };
252        for g in &i.generics.params {
253            visitor.generics.params.push(g.clone());
254        }
255        syn::visit::visit_trait_item_type(&mut visitor, i);
256        self.error = visitor.error;
257        self.operations.extend(visitor.operations);
258    }
259
260    fn visit_type(&mut self, i: &Type) {
261        match self.leaker.check(i, &self.generics) {
262            Err((_, s)) => {
263                // Emit error and terminate searching
264                self.error = Some(s);
265            }
266            Ok(CheckResult::MustNotIntern(_)) => (),
267            o => {
268                // Collect the operation instead of directly modifying the graph
269                let is_must_intern = matches!(o, Ok(CheckResult::MustIntern(_)));
270                let is_root = self.parent.is_none();
271
272                self.operations.push(GraphOperation {
273                    node_type: i.clone(),
274                    parent: self.parent.clone(),
275                    is_root,
276                    is_must_intern,
277                });
278
279                if !is_must_intern {
280                    // Perform recursive call
281                    let parent = self.parent.clone();
282                    self.parent = Some(i.clone());
283                    syn::visit::visit_type(self, i);
284                    self.parent = parent;
285                }
286            }
287        }
288    }
289
290    fn visit_trait_bound(&mut self, i: &TraitBound) {
291        // Trait bounds with non-absolute oath
292        if i.path.leading_colon.is_none()
293            && !matches!(
294                i.path.segments.iter().next(),
295                Some(PathSegment {
296                    ident,
297                    arguments: PathArguments::None,
298                }) if ident.to_string() == "$crate"
299            )
300            && !self.leaker.allowed_paths.iter().any(|allowed| {
301                allowed.leading_colon == i.path.leading_colon
302                    && allowed.segments.len() <= i.path.segments.len()
303                    && allowed.segments.iter().zip(i.path.segments.iter()).all(
304                        |(allowed_seg, seg)| {
305                            allowed_seg.arguments == seg.arguments
306                                && allowed_seg.ident.to_string() == seg.ident.to_string()
307                        },
308                    )
309            })
310        {
311            self.error = Some(quote!(#i));
312            return;
313        }
314        syn::visit::visit_trait_bound(self, i)
315    }
316
317    fn visit_expr_path(&mut self, i: &ExprPath) {
318        if i.path.leading_colon.is_none() {
319            // Value or trait bounds with relative oath
320            self.error = Some(quote!(#i));
321        } else {
322            syn::visit::visit_expr_path(self, i)
323        }
324    }
325}
326
327impl Default for Leaker {
328    fn default() -> Self {
329        Self::new()
330    }
331}
332
333impl Leaker {
334    /// Initialize with [`ItemStruct`]
335    ///
336    /// ```
337    /// # use type_leak::*;
338    /// # use syn::*;
339    /// let test_struct: ItemStruct = parse_quote!(
340    ///     pub struct MyStruct<'a, T1, T2: ::path::to::MyType1<MyType4>> {
341    ///         field1: MyType1,
342    ///         field2: (MyType2, MyType3<MyType1>, MyType4, MyType5),
343    ///         field3: &'a (T1, T2),
344    ///     }
345    /// );
346    /// let leaker = Leaker::from_struct(&test_struct).unwrap();
347    /// ```
348    pub fn from_struct(input: &ItemStruct) -> std::result::Result<Self, NotInternableError> {
349        let mut leaker = Self::new();
350        leaker.intern_with(&input.generics, |visitor| {
351            visitor.visit_item_struct(input);
352        })?;
353        Ok(leaker)
354    }
355
356    /// Initialize with [`ItemEnum`]
357    pub fn from_enum(input: &ItemEnum) -> std::result::Result<Self, NotInternableError> {
358        let mut leaker = Self::new();
359        leaker.intern_with(&input.generics, |visitor| {
360            visitor.visit_item_enum(input);
361        })?;
362        Ok(leaker)
363    }
364
365    /// Build an [`Leaker`] with given trait.
366    ///
367    /// Unlike enum nor struct, it requires ~alternative path~, an absolute path of a struct
368    /// which is declared the same crate with the leaker trait and also visible from
369    /// [`Referrer`]s' context. That struct is used as an `impl` target of `Repeater`
370    /// instead of the Leaker's path.
371    ///
372    /// ```
373    /// # use syn::*;
374    /// # use type_leak::Leaker;
375    /// let s: ItemTrait = parse_quote!{
376    ///     pub trait MyTrait<T, U> {
377    ///         fn func(self, t: T) -> U;
378    ///     }
379    /// };
380    /// let alternate: ItemStruct = parse_quote!{
381    ///     pub struct MyAlternate;
382    /// };
383    /// let _ = Leaker::from_trait(&s);
384    /// ```
385    pub fn from_trait(input: &ItemTrait) -> std::result::Result<Self, NotInternableError> {
386        let mut leaker = Self::new();
387        leaker.intern_with(&input.generics, |visitor| {
388            visitor.visit_item_trait(input);
389        })?;
390        Ok(leaker)
391    }
392
393    /// Initialize empty [`Leaker`] with no generics.
394    pub fn new() -> Self {
395        Self {
396            graph: VecGraph::default(),
397            allowed_paths: Vec::new(),
398            self_ty_can_be_interned: true,
399            reachable_types: HashSet::new(),
400            pending_operations: None,
401        }
402    }
403
404    /// Initialize [`Leaker`] from configuration.
405    pub fn from_config(config: LeakerConfig) -> Self {
406        Self {
407            graph: VecGraph::default(),
408            allowed_paths: config.allowed_paths,
409            self_ty_can_be_interned: config.self_ty_can_be_interned,
410            reachable_types: HashSet::new(),
411            pending_operations: None,
412        }
413    }
414
415    /// Apply collected graph operations using scope_mut()
416    fn apply_operations(&mut self, operations: Vec<GraphOperation>) {
417        if operations.is_empty() {
418            return;
419        }
420
421        // Store operations for proper graph reduction in reduce_roots()
422        self.pending_operations = Some(operations);
423    }
424
425    /// Intern ast elements with visitor
426    pub fn intern_with<'ast, F>(
427        &mut self,
428        generics: &Generics,
429        f: F,
430    ) -> std::result::Result<&mut Self, NotInternableError>
431    where
432        F: FnOnce(&mut dyn syn::visit::Visit<'ast>),
433    {
434        let mut visitor = AnalyzeVisitor {
435            leaker: self,
436            parent: None,
437            error: None,
438            generics: generics.clone(),
439            operations: Vec::new(),
440        };
441        f(&mut visitor);
442        if let Some(e) = visitor.error {
443            Err(NotInternableError(e))
444        } else {
445            let operations = visitor.operations;
446            self.apply_operations(operations);
447            Ok(self)
448        }
449    }
450
451    /// Check that the internableness of give `ty`. It returns `Err` in contradiction (the type
452    /// must and must not be interned).
453    ///
454    ///
455    /// See [`CheckResult`].
456    pub fn check(
457        &self,
458        ty: &Type,
459        generics: &Generics,
460    ) -> std::result::Result<CheckResult, (TokenStream, TokenStream)> {
461        use syn::visit::Visit;
462        #[derive(Clone)]
463        struct Visitor {
464            generic_lifetimes: Vec<Lifetime>,
465            generic_idents: Vec<Ident>,
466            impossible: Option<TokenStream>,
467            must: Option<(isize, TokenStream)>,
468            self_ty_can_be_interned: bool,
469            allowed_paths: Vec<Path>,
470        }
471
472        const _: () = {
473            use syn::visit::Visit;
474            impl<'a> Visit<'a> for Visitor {
475                fn visit_type(&mut self, i: &Type) {
476                    match i {
477                        Type::BareFn(TypeBareFn {
478                            lifetimes,
479                            inputs,
480                            output,
481                            ..
482                        }) => {
483                            let mut visitor = self.clone();
484                            visitor.must = None;
485                            visitor.generic_lifetimes.extend(
486                                lifetimes
487                                    .as_ref()
488                                    .map(|ls| {
489                                        ls.lifetimes.iter().map(|gp| {
490                                            if let GenericParam::Lifetime(lt) = gp {
491                                                lt.lifetime.clone()
492                                            } else {
493                                                panic!()
494                                            }
495                                        })
496                                    })
497                                    .into_iter()
498                                    .flatten(),
499                            );
500                            for input in inputs {
501                                visitor.visit_type(&input.ty);
502                            }
503                            if let ReturnType::Type(_, output) = output {
504                                visitor.visit_type(output.as_ref());
505                            }
506                            if self.impossible.is_none() {
507                                self.impossible = visitor.impossible;
508                            }
509                            match (self.must.clone(), visitor.must) {
510                                (Some((s_l, _)), Some((v_l, v_m))) if v_l + 1 < s_l => {
511                                    self.must = Some((v_l + 1, v_m))
512                                }
513                                (None, Some((v_l, v_m))) => self.must = Some((v_l + 1, v_m)),
514                                _ => (),
515                            }
516                            return;
517                        }
518                        Type::TraitObject(TypeTraitObject { bounds, .. }) => {
519                            for bound in bounds {
520                                match bound {
521                                    TypeParamBound::Trait(TraitBound {
522                                        lifetimes, path, ..
523                                    }) => {
524                                        let mut visitor = self.clone();
525                                        visitor.must = None;
526                                        visitor.generic_lifetimes.extend(
527                                            lifetimes
528                                                .as_ref()
529                                                .map(|ls| {
530                                                    ls.lifetimes.iter().map(|gp| {
531                                                        if let GenericParam::Lifetime(lt) = gp {
532                                                            lt.lifetime.clone()
533                                                        } else {
534                                                            panic!()
535                                                        }
536                                                    })
537                                                })
538                                                .into_iter()
539                                                .flatten(),
540                                        );
541                                        visitor.visit_path(path);
542                                        if self.impossible.is_none() {
543                                            self.impossible = visitor.impossible;
544                                        }
545                                        match (self.must.clone(), visitor.must) {
546                                            (Some((s_l, _)), Some((v_l, v_m))) if v_l + 1 < s_l => {
547                                                self.must = Some((v_l + 1, v_m))
548                                            }
549                                            (None, Some((v_l, v_m))) => {
550                                                self.must = Some((v_l + 1, v_m))
551                                            }
552                                            _ => (),
553                                        }
554                                        return;
555                                    }
556                                    TypeParamBound::Verbatim(_) => {
557                                        self.impossible = Some(quote!(#bound));
558                                        return;
559                                    }
560                                    _ => (),
561                                }
562                            }
563                        }
564                        Type::ImplTrait(_) | Type::Macro(_) | Type::Verbatim(_) => {
565                            self.impossible = Some(quote!(#i));
566                        }
567                        Type::Reference(TypeReference { lifetime, .. }) => {
568                            if lifetime.is_none() {
569                                self.impossible = Some(quote!(#i));
570                            }
571                        }
572                        _ => (),
573                    }
574                    let mut visitor = self.clone();
575                    visitor.must = None;
576                    syn::visit::visit_type(&mut visitor, i);
577                    if visitor.impossible.is_some() {
578                        self.impossible = visitor.impossible;
579                    }
580                    match (self.must.clone(), visitor.must) {
581                        (Some((s_l, _)), Some((v_l, v_m))) if v_l + 1 < s_l => {
582                            self.must = Some((v_l + 1, v_m))
583                        }
584                        (None, Some((v_l, v_m))) => self.must = Some((v_l + 1, v_m)),
585                        _ => (),
586                    }
587                }
588                fn visit_qself(&mut self, i: &QSelf) {
589                    if i.as_token.is_none() {
590                        let ty = &i.ty;
591                        self.impossible = Some(quote!(#ty));
592                    }
593                    syn::visit::visit_qself(self, i)
594                }
595                fn visit_lifetime(&mut self, i: &Lifetime) {
596                    match i.to_string().as_str() {
597                        "'static" => (),
598                        "'_" => self.impossible = Some(quote!(#i)),
599                        _ if self.generic_lifetimes.iter().any(|lt| lt == i) => {}
600                        _ => {
601                            self.must = Some((-1, quote!(#i)));
602                        }
603                    }
604                    syn::visit::visit_lifetime(self, i)
605                }
606                fn visit_expr(&mut self, i: &Expr) {
607                    match i {
608                        Expr::Closure(_) | Expr::Assign(_) | Expr::Verbatim(_) | Expr::Macro(_) => {
609                            self.impossible = Some(quote!(#i));
610                        }
611                        _ => (),
612                    }
613                    syn::visit::visit_expr(self, i)
614                }
615                fn visit_path(&mut self, i: &Path) {
616                    if self.allowed_paths.iter().any(|allowed| {
617                        allowed.leading_colon == i.leading_colon
618                            && allowed.segments.len() <= i.segments.len()
619                            && allowed.segments.iter().zip(i.segments.iter()).all(
620                                |(allowed_seg, seg)| {
621                                    allowed_seg.arguments == seg.arguments
622                                        && allowed_seg.ident.to_string() == seg.ident.to_string()
623                                },
624                            )
625                    }) {
626                        return;
627                    }
628                    if matches!(i.segments.iter().next(), Some(PathSegment { ident, arguments }) if ident == "Self" && arguments.is_none())
629                        && i.leading_colon.is_none()
630                    {
631                        if !self.self_ty_can_be_interned {
632                            self.impossible = Some(quote!(#i));
633                        }
634                    } else {
635                        if i.leading_colon.is_none()
636                            && matches!(
637                                i.segments.iter().next(),
638                                Some(PathSegment {
639                                    ident,
640                                    arguments: PathArguments::None,
641                                }) if ident.to_string() == "$crate"
642                            )
643                        {
644                            return;
645                        }
646                        match (i.leading_colon, i.get_ident()) {
647                            // i is a generic parameter
648                            (None, Some(ident)) if self.generic_idents.contains(ident) => {}
649                            // relative path, not a generic parameter
650                            (None, _) => {
651                                self.must = Some((-1, quote!(#i)));
652                            }
653                            // absolute path
654                            (Some(_), _) => (),
655                        }
656                    }
657                    syn::visit::visit_path(self, i)
658                }
659            }
660        };
661        let mut visitor = Visitor {
662            generic_lifetimes: generics
663                .params
664                .iter()
665                .filter_map(|gp| {
666                    if let GenericParam::Lifetime(lt) = gp {
667                        Some(lt.lifetime.clone())
668                    } else {
669                        None
670                    }
671                })
672                .collect(),
673            generic_idents: generics
674                .params
675                .iter()
676                .filter_map(|gp| match gp {
677                    GenericParam::Type(TypeParam { ident, .. })
678                    | GenericParam::Const(ConstParam { ident, .. }) => Some(ident.clone()),
679                    _ => None,
680                })
681                .collect(),
682            impossible: None,
683            must: None,
684            self_ty_can_be_interned: self.self_ty_can_be_interned,
685            allowed_paths: self.allowed_paths.clone(),
686        };
687        visitor.visit_type(ty);
688        match (visitor.must, visitor.impossible) {
689            (None, None) => Ok(CheckResult::Neutral),
690            (Some((0, span)), None) => Ok(CheckResult::MustIntern(span)),
691            (Some((_, span)), None) => Ok(CheckResult::MustInternOrInherit(span)),
692            (None, Some(span)) => Ok(CheckResult::MustNotIntern(span)),
693            (Some((_, span0)), Some(span1)) => Err((span0, span1)),
694        }
695    }
696
697    /// Finish building the [`Leaker`] and convert it into [`Referrer`].
698    pub fn finish(self) -> Referrer {
699        // Use the reachable types computed by apply_operations
700        let leak_types: Vec<_> = self.reachable_types.into_iter().collect();
701        let map = leak_types
702            .iter()
703            .enumerate()
704            .map(|(n, ty)| (ty.clone(), n))
705            .collect();
706        Referrer { leak_types, map }
707    }
708
709    /// Reduce nodes to decrease cost of {`Leaker`}'s implementation.
710    ///
711    /// # Algorithm
712    ///
713    /// To consult the algorithm, see the following [`Leaker`]'s input:
714    ///
715    /// ```ignore
716    /// pub struct MyStruct<'a, T1, T2: ::path::to::MyType1<MyType4>> {
717    ///     field1: MyType1,
718    ///     field2: (MyType2, MyType3<MyType1>, MyType4, MyType5),
719    ///     field3: &'a (T1, T2),
720    /// }
721    /// ```
722    ///
723    /// [`Leaker`], when initialized with [`Leaker::from_struct()`], analyze the AST and
724    /// construct a DAG which represents all (internable) types and the dependency relations like
725    /// this:
726    ///
727    /// ![before_graph](https://raw.github.com/yasuo-ozu/type-leak/master/resources/before_graph.svg)
728    ///
729    /// The **red** node is flagged as [`CheckResult::MustIntern`] by [`Leaker::check()`] (which means
730    /// the type literature depends on the type context, so it must be interned).
731    ///
732    /// This algorithm reduce the nodes, remaining that all root type (annotated with ★) can be
733    /// expressed with existing **red** nodes.
734    ///
735    /// Here, there are some choice in its freedom:
736    ///
737    /// - Intern all **red** nodes and ignore others (because other nodes are not needed to be
738    ///   intern, or constructable with red nodes)
739    /// - Not directly intern **red** nodes; intern common ancessors instead if it is affordable.
740    ///
741    /// So finally, it results like:
742    ///
743    /// ![after_graph](https://raw.github.com/yasuo-ozu/type-leak/master/resources/after_graph.svg)
744    pub fn reduce_roots(&mut self) {
745        if let Some(operations) = self.pending_operations.take() {
746            self.build_graph_and_reduce(operations);
747        }
748        self.reduce_unreachable_nodes();
749        self.reduce_obvious_nodes();
750    }
751
752    fn build_graph_and_reduce(&mut self, operations: Vec<GraphOperation>) {
753        // Build parent-child relationships from operations
754        let mut children: HashMap<Type, Vec<Type>> = HashMap::new();
755        let mut parents: HashMap<Type, Vec<Type>> = HashMap::new();
756        let mut root_types = HashSet::new();
757        let mut must_intern_types = HashSet::new();
758
759        for operation in &operations {
760            if operation.is_root {
761                root_types.insert(operation.node_type.clone());
762            }
763            if operation.is_must_intern {
764                must_intern_types.insert(operation.node_type.clone());
765            }
766
767            if let Some(parent_type) = &operation.parent {
768                children
769                    .entry(parent_type.clone())
770                    .or_default()
771                    .push(operation.node_type.clone());
772                parents
773                    .entry(operation.node_type.clone())
774                    .or_default()
775                    .push(parent_type.clone());
776            }
777        }
778
779        // Implement proper reduction algorithm:
780        // 1. Keep all must-intern root types
781        // 2. Keep root types that transitively contain must-intern types
782        // 3. Do NOT keep non-root must-intern types that are only components of kept root types
783
784        let mut kept_types = HashSet::new();
785
786        // Step 1: Keep must-intern types that are also roots
787        for root_type in &root_types {
788            if must_intern_types.contains(root_type) {
789                kept_types.insert(root_type.clone());
790            }
791        }
792
793        // Step 2: Keep root types that contain must-intern descendants
794        for root_type in &root_types {
795            if !kept_types.contains(root_type)
796                && self.contains_must_intern_descendant(root_type, &children, &must_intern_types)
797            {
798                kept_types.insert(root_type.clone());
799            }
800        }
801
802        // Step 3: Only add must-intern types that are NOT components of any kept type
803        for must_intern_type in &must_intern_types {
804            if !self.is_component_of_any_kept_type(must_intern_type, &kept_types, &children) {
805                kept_types.insert(must_intern_type.clone());
806            }
807        }
808
809        self.reachable_types = kept_types;
810    }
811
812    fn contains_must_intern_descendant(
813        &self,
814        root: &Type,
815        children: &HashMap<Type, Vec<Type>>,
816        must_intern_types: &HashSet<Type>,
817    ) -> bool {
818        let mut visited = HashSet::new();
819        let mut stack = vec![root.clone()];
820
821        while let Some(current) = stack.pop() {
822            if !visited.insert(current.clone()) {
823                continue;
824            }
825
826            if must_intern_types.contains(&current) {
827                return true;
828            }
829
830            if let Some(child_list) = children.get(&current) {
831                for child in child_list {
832                    if !visited.contains(child) {
833                        stack.push(child.clone());
834                    }
835                }
836            }
837        }
838
839        false
840    }
841
842    fn is_component_of_any_kept_type(
843        &self,
844        target_type: &Type,
845        kept_types: &HashSet<Type>,
846        children: &HashMap<Type, Vec<Type>>,
847    ) -> bool {
848        for kept_type in kept_types {
849            if self.is_descendant_of(target_type, kept_type, children) {
850                return true;
851            }
852        }
853        false
854    }
855
856    fn is_descendant_of(
857        &self,
858        target: &Type,
859        ancestor: &Type,
860        children: &HashMap<Type, Vec<Type>>,
861    ) -> bool {
862        if target == ancestor {
863            return false; // Not a descendant of itself
864        }
865
866        let mut visited = HashSet::new();
867        let mut stack = vec![ancestor.clone()];
868
869        while let Some(current) = stack.pop() {
870            if !visited.insert(current.clone()) {
871                continue;
872            }
873
874            if let Some(child_list) = children.get(&current) {
875                for child in child_list {
876                    if child == target {
877                        return true;
878                    }
879                    if !visited.contains(child) {
880                        stack.push(child.clone());
881                    }
882                }
883            }
884        }
885
886        false
887    }
888
889    fn reduce_obvious_nodes(&mut self) {
890        // Remove intermediate nodes that have exactly one parent and one child
891        // This step eliminates unnecessary intermediate types
892    }
893
894    fn reduce_unreachable_nodes(&mut self) {
895        if let Some(operations) = self.pending_operations.take() {
896            self.apply_vecgraph_reduction(operations);
897        }
898    }
899
    /// Records `operations` into `self.graph` and recomputes
    /// `self.reachable_types` from them.
    ///
    /// Every distinct type mentioned by an operation (as `node_type` or as
    /// `parent`) gets exactly one graph node. Every operation that has a
    /// `parent` adds an edge from the parent's node to the child's node.
    /// The reachable set is then derived from `operations` by
    /// `analyze_reachable_types`; the graph built here is not consulted for it.
    ///
    /// NOTE(review): `reduce_roots` takes `pending_operations` before calling
    /// `reduce_unreachable_nodes`, so this method is not reached on that path
    /// and does not populate `self.graph` there — confirm this is intended.
    fn apply_vecgraph_reduction(&mut self, operations: Vec<GraphOperation>) {
        let mut root_types = HashSet::new();
        let mut must_intern_types = HashSet::new();

        // Collect graph operations
        for operation in &operations {
            if operation.is_root {
                root_types.insert(operation.node_type.clone());
            }
            if operation.is_must_intern {
                must_intern_types.insert(operation.node_type.clone());
            }
        }

        // Build the graph using scope_mut
        self.graph.scope_mut(|mut ctx| {
            // Type -> node tag, so a type shared by several operations is
            // added to the graph only once.
            let mut type_to_node = HashMap::new();

            // Add all nodes first
            for operation in &operations {
                if !type_to_node.contains_key(&operation.node_type) {
                    let node_tag = ctx.add_node(operation.node_type.clone());
                    type_to_node.insert(operation.node_type.clone(), node_tag);
                }

                if let Some(parent_type) = &operation.parent {
                    if !type_to_node.contains_key(parent_type) {
                        let parent_node_tag = ctx.add_node(parent_type.clone());
                        type_to_node.insert(parent_type.clone(), parent_node_tag);
                    }
                }
            }

            // Add edges between nodes
            // (every node and parent type was inserted above, so the
            // indexing below cannot miss)
            for operation in &operations {
                if let Some(parent_type) = &operation.parent {
                    let parent_node = type_to_node[parent_type];
                    let child_node = type_to_node[&operation.node_type];
                    ctx.add_edge((), parent_node, child_node);
                }
            }
        });

        // Perform analysis using the same logic as build_graph_and_reduce
        self.reachable_types =
            self.analyze_reachable_types(&root_types, &must_intern_types, &operations);
    }
947
948    fn analyze_reachable_types(
949        &self,
950        root_types: &HashSet<Type>,
951        must_intern_types: &HashSet<Type>,
952        operations: &[GraphOperation],
953    ) -> HashSet<Type> {
954        // Implement the same logic as build_graph_and_reduce
955        let mut final_types = HashSet::new();
956
957        // Keep must-intern root types
958        for root_type in root_types {
959            if must_intern_types.contains(root_type) {
960                final_types.insert(root_type.clone());
961            }
962        }
963
964        // Keep root types that contain must-intern descendants
965        for root_type in root_types {
966            if !final_types.contains(root_type)
967                && self.contains_must_intern_in_tree(root_type, operations, must_intern_types)
968            {
969                final_types.insert(root_type.clone());
970            }
971        }
972
973        // Only keep must-intern types that aren't subcomponents of kept root types
974        for must_intern_type in must_intern_types {
975            if !self.is_subcomponent_of_kept_root(must_intern_type, &final_types, operations) {
976                final_types.insert(must_intern_type.clone());
977            }
978        }
979
980        final_types
981    }
982
983    fn contains_must_intern_in_tree(
984        &self,
985        root: &Type,
986        operations: &[GraphOperation],
987        must_intern_types: &HashSet<Type>,
988    ) -> bool {
989        let mut stack = vec![root.clone()];
990        let mut visited = HashSet::new();
991
992        while let Some(current) = stack.pop() {
993            if !visited.insert(current.clone()) {
994                continue;
995            }
996
997            if must_intern_types.contains(&current) {
998                return true;
999            }
1000
1001            for operation in operations {
1002                if let Some(parent_type) = &operation.parent {
1003                    if parent_type == &current && !visited.contains(&operation.node_type) {
1004                        stack.push(operation.node_type.clone());
1005                    }
1006                }
1007            }
1008        }
1009
1010        false
1011    }
1012
1013    fn is_subcomponent_of_kept_root(
1014        &self,
1015        target: &Type,
1016        kept_roots: &HashSet<Type>,
1017        operations: &[GraphOperation],
1018    ) -> bool {
1019        for root in kept_roots {
1020            if self.contains_must_intern_in_tree(root, operations, &HashSet::from([target.clone()]))
1021                && root != target
1022            {
1023                return true;
1024            }
1025        }
1026        false
1027    }
1028
1029    // TODO: Reimplement reduce_unreachable_nodes with VecGraph APIs
1030}
1031
/// Holds the list of types that need to be encoded, produced by [`Leaker::finish()`].
///
/// The `Referrer` provides methods to iterate over leak types and to expand/transform
/// types using a custom mapping function.
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct Referrer {
    /// The types to encode; a type's position here is its index.
    leak_types: Vec<Type>,
    /// Reverse lookup from a type in `leak_types` to its index there.
    map: HashMap<Type, usize>,
}
1042impl Referrer {
1043    /// Returns `true` if there are no types to be encoded.
1044    pub fn is_empty(&self) -> bool {
1045        self.leak_types.is_empty()
1046    }
1047
1048    /// Returns an iterator over the types that need to be encoded.
1049    ///
1050    /// ```ignore
1051    /// let args = encode_generics_params_to_ty(&generics.params);
1052    /// referrer.iter().enumerate().map(|(ix, ty)| quote!{
1053    ///     impl #crate::TypeRef<#ix, #args> for #marker_ty {
1054    ///         type Type = #ty;
1055    ///     }
1056    /// }).collect::<TokenStream>();
1057    /// ```
1058    pub fn iter(&self) -> impl Iterator<Item = &Type> {
1059        self.leak_types.iter()
1060    }
1061
1062    /// Converts this `Referrer` into a [`Visitor`] that can be used to transform AST nodes.
1063    ///
1064    /// The `type_fn` callback is called for each type that needs to be converted,
1065    /// receiving the original type and its index, and returning the replacement type.
1066    /// See [`Referrer::expand()`] for details.
1067    pub fn into_visitor<F: FnMut(Type, usize) -> Type>(self, type_fn: F) -> Visitor<F> {
1068        Visitor(self, Rc::new(RefCell::new(type_fn)))
1069    }
1070
1071    /// Expands a type by replacing any interned types using the provided mapping function.
1072    ///
1073    /// The `type_fn` callback is called for each subtype that needs to be converted,
1074    /// receiving the original type and its index, and returning the replacement type.
1075    ///
1076    /// ```ignore
1077    /// let args = encode_generics_params_to_ty(&generics.params);
1078    /// referrer.expand(ty, |_, ix| {
1079    ///     parse_quote!(<#marker_ty as #crate::TypeRef<#ix, #args>>::Type)
1080    /// })
1081    /// ```
1082    pub fn expand(&self, ty: Type, type_fn: impl FnMut(Type, usize) -> Type) -> Type {
1083        use syn::fold::Fold;
1084        struct Folder<'a, F>(&'a Referrer, F);
1085        impl<'a, F: FnMut(Type, usize) -> Type> Fold for Folder<'a, F> {
1086            fn fold_type(&mut self, ty: Type) -> Type {
1087                if let Some(idx) = self.0.map.get(&ty) {
1088                    self.1(ty, *idx)
1089                } else {
1090                    syn::fold::fold_type(self, ty)
1091                }
1092            }
1093        }
1094        let mut folder = Folder(self, type_fn);
1095        folder.fold_type(ty)
1096    }
1097}
1098
/// AST rewriter produced by [`Referrer::into_visitor()`].
///
/// Replaces types found in the wrapped [`Referrer`] with the result of the
/// callback. The callback sits behind `Rc<RefCell<_>>`, so clones of the
/// visitor share it. Generic type parameters and block-local item names are
/// removed from a scoped copy's mapping so they are not replaced.
#[derive(Debug)]
pub struct Visitor<F>(Referrer, Rc<RefCell<F>>);
1101
1102impl<F> Clone for Visitor<F> {
1103    fn clone(&self) -> Self {
1104        Self(self.0.clone(), self.1.clone())
1105    }
1106}
1107
1108impl<F: FnMut(Type, usize) -> Type> Visitor<F> {
1109    fn with_generics(&mut self, generics: &mut Generics) -> Self {
1110        let mut visitor = self.clone();
1111        for gp in generics.params.iter_mut() {
1112            if let GenericParam::Type(TypeParam { ident, .. }) = gp {
1113                visitor.0.map.remove(&parse_quote!(#ident));
1114            }
1115            visitor.visit_generic_param_mut(gp);
1116        }
1117        visitor
1118    }
1119
1120    fn with_signature(&mut self, sig: &mut Signature) -> Self {
1121        let mut visitor = self.with_generics(&mut sig.generics);
1122        for input in sig.inputs.iter_mut() {
1123            visitor.visit_fn_arg_mut(input);
1124        }
1125        visitor.visit_return_type_mut(&mut sig.output);
1126        visitor
1127    }
1128}
1129
1130impl<F: FnMut(Type, usize) -> Type> VisitMut for Visitor<F> {
1131    fn visit_type_mut(&mut self, i: &mut Type) {
1132        *i = self.0.expand(i.clone(), &mut *self.1.borrow_mut());
1133    }
1134    fn visit_item_struct_mut(&mut self, i: &mut ItemStruct) {
1135        for attr in i.attrs.iter_mut() {
1136            self.visit_attribute_mut(attr);
1137        }
1138        let mut visitor = self.with_generics(&mut i.generics);
1139        visitor.visit_fields_mut(&mut i.fields);
1140    }
1141    fn visit_item_enum_mut(&mut self, i: &mut ItemEnum) {
1142        for attr in i.attrs.iter_mut() {
1143            self.visit_attribute_mut(attr);
1144        }
1145        let mut visitor = self.with_generics(&mut i.generics);
1146        for variant in i.variants.iter_mut() {
1147            visitor.visit_variant_mut(variant);
1148        }
1149    }
1150    fn visit_item_trait_mut(&mut self, i: &mut ItemTrait) {
1151        for attr in i.attrs.iter_mut() {
1152            self.visit_attribute_mut(attr);
1153        }
1154        let mut visitor = self.with_generics(&mut i.generics);
1155        for supertrait in i.supertraits.iter_mut() {
1156            visitor.visit_type_param_bound_mut(supertrait);
1157        }
1158        for item in i.items.iter_mut() {
1159            visitor.visit_trait_item_mut(item);
1160        }
1161    }
1162    fn visit_item_union_mut(&mut self, i: &mut ItemUnion) {
1163        for attr in i.attrs.iter_mut() {
1164            self.visit_attribute_mut(attr);
1165        }
1166        let mut visitor = self.with_generics(&mut i.generics);
1167        visitor.visit_fields_named_mut(&mut i.fields);
1168    }
1169    fn visit_item_type_mut(&mut self, i: &mut ItemType) {
1170        for attr in i.attrs.iter_mut() {
1171            self.visit_attribute_mut(attr);
1172        }
1173        let mut visitor = self.with_generics(&mut i.generics);
1174        visitor.visit_type_mut(&mut i.ty);
1175    }
1176    fn visit_item_fn_mut(&mut self, i: &mut ItemFn) {
1177        for attr in i.attrs.iter_mut() {
1178            self.visit_attribute_mut(attr);
1179        }
1180        let mut visitor = self.with_signature(&mut i.sig);
1181        visitor.visit_block_mut(i.block.as_mut());
1182    }
1183    fn visit_trait_item_fn_mut(&mut self, i: &mut TraitItemFn) {
1184        for attr in i.attrs.iter_mut() {
1185            self.visit_attribute_mut(attr);
1186        }
1187        let mut visitor = self.with_signature(&mut i.sig);
1188        if let Some(block) = &mut i.default {
1189            visitor.visit_block_mut(block);
1190        }
1191    }
1192    fn visit_trait_item_type_mut(&mut self, i: &mut TraitItemType) {
1193        for attr in i.attrs.iter_mut() {
1194            self.visit_attribute_mut(attr);
1195        }
1196        let mut visitor = self.with_generics(&mut i.generics);
1197        for bound in i.bounds.iter_mut() {
1198            visitor.visit_type_param_bound_mut(bound);
1199        }
1200        if let Some((_, ty)) = &mut i.default {
1201            visitor.visit_type_mut(ty);
1202        }
1203    }
1204    fn visit_trait_item_const_mut(&mut self, i: &mut TraitItemConst) {
1205        for attr in i.attrs.iter_mut() {
1206            self.visit_attribute_mut(attr);
1207        }
1208        let mut visitor = self.with_generics(&mut i.generics);
1209        visitor.visit_type_mut(&mut i.ty);
1210        if let Some((_, expr)) = &mut i.default {
1211            visitor.visit_expr_mut(expr);
1212        }
1213    }
1214    fn visit_block_mut(&mut self, i: &mut Block) {
1215        let mut visitor = self.clone();
1216        for stmt in &i.stmts {
1217            match stmt {
1218                Stmt::Item(Item::Struct(ItemStruct { ident, .. }))
1219                | Stmt::Item(Item::Enum(ItemEnum { ident, .. }))
1220                | Stmt::Item(Item::Union(ItemUnion { ident, .. }))
1221                | Stmt::Item(Item::Trait(ItemTrait { ident, .. }))
1222                | Stmt::Item(Item::Type(ItemType { ident, .. })) => {
1223                    visitor.0.map.remove(&parse_quote!(#ident));
1224                }
1225                _ => (),
1226            }
1227        }
1228        for stmt in i.stmts.iter_mut() {
1229            visitor.visit_stmt_mut(stmt);
1230        }
1231    }
1232}
1233
1234impl Parse for Referrer {
1235    fn parse(input: parse::ParseStream) -> Result<Self> {
1236        let content: syn::parse::ParseBuffer<'_>;
1237        parenthesized!(content in input);
1238        let leak_types: Vec<Type> = Punctuated::<Type, Token![,]>::parse_terminated(&content)?
1239            .into_iter()
1240            .collect();
1241        let map = leak_types
1242            .iter()
1243            .enumerate()
1244            .map(|(n, ty)| (ty.clone(), n))
1245            .collect();
1246        Ok(Self { map, leak_types })
1247    }
1248}
1249
1250impl ToTokens for Referrer {
1251    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
1252        tokens.extend(quote! {
1253            (
1254                #(for item in &self.leak_types), {
1255                    #item
1256                }
1257            )
1258        })
1259    }
1260}