sea_orm/entity/link.rs

1use crate::{EntityTrait, QuerySelect, RelationDef, Select, join_tbl_on_condition};
2use sea_query::{
3    Alias, CommonTableExpression, Condition, IntoIden, IntoTableRef, JoinType, UnionType,
4};
5
/// Alias of [`RelationDef`], used to describe one hop in the chain returned by
/// [`Linked::link`].
pub type LinkDef = RelationDef;
8
9/// A Trait for links between Entities
10pub trait Linked {
11    #[allow(missing_docs)]
12    type FromEntity: EntityTrait;
13
14    #[allow(missing_docs)]
15    type ToEntity: EntityTrait;
16
17    /// Link for an Entity
18    fn link(&self) -> Vec<LinkDef>;
19
20    /// Find all the Entities that are linked to the Entity
21    fn find_linked(&self) -> Select<Self::ToEntity> {
22        find_linked(self.link().into_iter().rev(), JoinType::InnerJoin)
23    }
24}
25
26pub(crate) fn find_linked<I, E>(links: I, join: JoinType) -> Select<E>
27where
28    I: Iterator<Item = LinkDef>,
29    E: EntityTrait,
30{
31    let mut select = Select::new();
32    for (i, mut rel) in links.enumerate() {
33        let from_tbl = format!("r{i}").into_iden();
34        let to_tbl = if i > 0 {
35            format!("r{}", i - 1).into_iden()
36        } else {
37            rel.to_tbl.sea_orm_table().clone()
38        };
39        let table_ref = rel.from_tbl;
40
41        let mut condition = Condition::all().add(join_tbl_on_condition(
42            from_tbl.clone(),
43            to_tbl.clone(),
44            rel.from_col,
45            rel.to_col,
46        ));
47        if let Some(f) = rel.on_condition.take() {
48            condition = condition.add(f(from_tbl.clone(), to_tbl.clone()));
49        }
50
51        select.query().join_as(join, table_ref, from_tbl, condition);
52    }
53    select
54}
55
56pub(crate) fn find_linked_recursive<E>(
57    mut initial_query: Select<E>,
58    mut link: Vec<LinkDef>,
59) -> Select<E>
60where
61    E: EntityTrait,
62{
63    let cte_name = Alias::new("cte");
64
65    let Some(first) = link.first_mut() else {
66        return initial_query;
67    };
68    first.from_tbl = cte_name.clone().into_table_ref();
69    let mut recursive_query: Select<E> =
70        find_linked(link.into_iter().rev(), JoinType::InnerJoin).select_only();
71    initial_query.query.exprs_mut_for_each(|expr| {
72        recursive_query.query.expr(expr.clone());
73    });
74
75    let mut cte_query = initial_query.query.clone();
76    cte_query.union(UnionType::All, recursive_query.query);
77
78    let cte = CommonTableExpression::new()
79        .table_name(cte_name.clone())
80        .query(cte_query)
81        .to_owned();
82
83    let mut select = E::find().select_only();
84    initial_query.query.exprs_mut_for_each(|expr| {
85        select.query.expr(expr.clone());
86    });
87    select
88        .query
89        .from_clear()
90        .from_as(cte_name, E::default())
91        .with_cte(cte);
92    select
93}