sea_orm/entity/
compound.rs

1#![allow(missing_docs)]
2use super::{ColumnTrait, EntityTrait, PrimaryKeyToColumn, PrimaryKeyTrait};
3use crate::{
4    ConnectionTrait, DbErr, IntoSimpleExpr, ItemsAndPagesNumber, Iterable, ModelTrait, QueryFilter,
5    QueryOrder,
6};
7use sea_query::{IntoValueTuple, Order, TableRef};
8use std::marker::PhantomData;
9
10mod has_many;
11mod has_one;
12
13pub use has_many::{HasMany, Iter as HasManyIter};
14pub use has_one::HasOne;
15
16#[async_trait::async_trait]
17pub trait EntityLoaderTrait<E: EntityTrait>: QueryFilter + QueryOrder + Clone {
18    /// The return type of this loader
19    type ModelEx: ModelTrait<Entity = E>;
20
21    /// Find a model by primary key
22    fn filter_by_id<T>(mut self, values: T) -> Self
23    where
24        T: Into<<E::PrimaryKey as PrimaryKeyTrait>::ValueType>,
25    {
26        let mut keys = E::PrimaryKey::iter();
27        for v in values.into().into_value_tuple() {
28            if let Some(key) = keys.next() {
29                let col = key.into_column();
30                self.filter_mut(col.eq(v));
31            } else {
32                unreachable!("primary key arity mismatch");
33            }
34        }
35        self
36    }
37
38    /// Apply order by primary key to the query statement
39    fn order_by_id_asc(self) -> Self {
40        self.order_by_id(Order::Asc)
41    }
42
43    /// Apply order by primary key to the query statement
44    fn order_by_id_desc(self) -> Self {
45        self.order_by_id(Order::Desc)
46    }
47
48    /// Apply order by primary key to the query statement
49    fn order_by_id(mut self, order: Order) -> Self {
50        for key in E::PrimaryKey::iter() {
51            let col = key.into_column();
52            <Self as QueryOrder>::query(&mut self)
53                .order_by_expr(col.into_simple_expr(), order.clone());
54        }
55        self
56    }
57
58    /// Paginate query.
59    fn paginate<'db, C: ConnectionTrait>(
60        self,
61        db: &'db C,
62        page_size: u64,
63    ) -> EntityLoaderPaginator<'db, C, E, Self> {
64        EntityLoaderPaginator {
65            loader: self,
66            page: 0,
67            page_size,
68            db,
69            phantom: PhantomData,
70        }
71    }
72
73    #[doc(hidden)]
74    async fn fetch<C: ConnectionTrait>(
75        self,
76        db: &C,
77        page: u64,
78        page_size: u64,
79    ) -> Result<Vec<Self::ModelEx>, DbErr>;
80
81    #[doc(hidden)]
82    async fn num_items<C: ConnectionTrait>(self, db: &C, page_size: u64) -> Result<u64, DbErr>;
83}
84
85#[derive(Debug)]
86pub struct EntityLoaderPaginator<'db, C, E, L>
87where
88    C: ConnectionTrait,
89    E: EntityTrait,
90    L: EntityLoaderTrait<E>,
91{
92    pub(crate) loader: L,
93    pub(crate) page: u64,
94    pub(crate) page_size: u64,
95    pub(crate) db: &'db C,
96    pub(crate) phantom: PhantomData<E>,
97}
98
99/// Just a marker trait on EntityReverse
100pub trait EntityReverse {
101    type Entity: EntityTrait;
102}
103
104/// Subject to change, not yet stable
105#[derive(Debug, Copy, Clone, PartialEq)]
106pub struct EntityLoaderWithSelf<R: EntityTrait, S: EntityTrait>(pub R, pub S);
107
108/// Subject to change, not yet stable
109#[derive(Debug, Copy, Clone, PartialEq)]
110pub struct EntityLoaderWithSelfRev<R: EntityTrait, S: EntityReverse>(pub R, pub S);
111
112#[derive(Debug, Clone, PartialEq)]
113pub enum LoadTarget {
114    TableRef(TableRef),
115    TableRefRev(TableRef),
116    Relation(String),
117}
118
119impl<'db, C, E, L> EntityLoaderPaginator<'db, C, E, L>
120where
121    C: ConnectionTrait,
122    E: EntityTrait,
123    L: EntityLoaderTrait<E>,
124{
125    /// Fetch a specific page; page index starts from zero
126    pub async fn fetch_page(&self, page: u64) -> Result<Vec<L::ModelEx>, DbErr> {
127        self.loader
128            .clone()
129            .fetch(self.db, page, self.page_size)
130            .await
131    }
132
133    /// Fetch the current page
134    pub async fn fetch(&self) -> Result<Vec<L::ModelEx>, DbErr> {
135        self.fetch_page(self.page).await
136    }
137
138    /// Get the total number of items
139    pub async fn num_items(&self) -> Result<u64, DbErr> {
140        self.loader.clone().num_items(self.db, self.page_size).await
141    }
142
143    /// Get the total number of pages
144    pub async fn num_pages(&self) -> Result<u64, DbErr> {
145        let num_items = self.num_items().await?;
146        let num_pages = self.compute_pages_number(num_items);
147        Ok(num_pages)
148    }
149
150    /// Get the total number of items and pages
151    pub async fn num_items_and_pages(&self) -> Result<ItemsAndPagesNumber, DbErr> {
152        let number_of_items = self.num_items().await?;
153        let number_of_pages = self.compute_pages_number(number_of_items);
154
155        Ok(ItemsAndPagesNumber {
156            number_of_items,
157            number_of_pages,
158        })
159    }
160
161    /// Compute the number of pages for the current page
162    fn compute_pages_number(&self, num_items: u64) -> u64 {
163        (num_items / self.page_size) + (num_items % self.page_size > 0) as u64
164    }
165
166    /// Increment the page counter
167    pub fn next(&mut self) {
168        self.page += 1;
169    }
170
171    /// Get current page number
172    pub fn cur_page(&self) -> u64 {
173        self.page
174    }
175
176    /// Fetch one page and increment the page counter
177    pub async fn fetch_and_next(&mut self) -> Result<Option<Vec<L::ModelEx>>, DbErr> {
178        let vec = self.fetch().await?;
179        self.next();
180        let opt = if !vec.is_empty() { Some(vec) } else { None };
181        Ok(opt)
182    }
183}
184
#[cfg(test)]
mod test {
    use crate::{ModelTrait, tests_cfg::cake};

    /// Checks that a `cake::Model` converts to `cake::ModelEx` and back
    /// without changing its data.
    #[test]
    fn test_model_ex_convert() {
        let model = cake::Model {
            id: 12,
            name: "hello".into(),
        };
        let model_ex: cake::ModelEx = model.clone().into();

        // The two types compare equal to each other in both directions.
        assert_eq!(model, model_ex);
        assert_eq!(model_ex, model);
        assert_eq!(model.id, model_ex.id);
        assert_eq!(model.name, model_ex.name);

        // Column access through `ModelTrait::get` returns the stored values.
        assert_eq!(model_ex.get(cake::Column::Id), 12i32.into());
        assert_eq!(model_ex.get(cake::Column::Name), "hello".into());

        // Converting back gives the original model.
        assert_eq!(cake::Model::from(model_ex), model);
    }
}