sea_orm/entity/compound.rs

1#![allow(missing_docs)]
2use super::{ColumnTrait, EntityTrait, PrimaryKeyToColumn, PrimaryKeyTrait};
3use crate::{
4    ConnectionTrait, DbErr, IntoSimpleExpr, ItemsAndPagesNumber, Iterable, ModelTrait, QueryFilter,
5    QueryOrder,
6};
7use sea_query::{IntoValueTuple, Order, TableRef};
8use std::marker::PhantomData;
9
10mod has_many;
11mod has_one;
12
13pub use has_many::{HasMany, Iter as HasManyIter};
14pub use has_one::HasOne;
15
16pub trait EntityLoaderTrait<E: EntityTrait>: QueryFilter + QueryOrder + Clone {
17    /// The return type of this loader
18    type ModelEx: ModelTrait<Entity = E>;
19
20    /// Find a model by primary key
21    fn filter_by_id<T>(mut self, values: T) -> Self
22    where
23        T: Into<<E::PrimaryKey as PrimaryKeyTrait>::ValueType>,
24    {
25        let mut keys = E::PrimaryKey::iter();
26        for v in values.into().into_value_tuple() {
27            if let Some(key) = keys.next() {
28                let col = key.into_column();
29                self.filter_mut(col.eq(v));
30            } else {
31                unreachable!("primary key arity mismatch");
32            }
33        }
34        self
35    }
36
37    /// Apply order by primary key to the query statement
38    fn order_by_id_asc(self) -> Self {
39        self.order_by_id(Order::Asc)
40    }
41
42    /// Apply order by primary key to the query statement
43    fn order_by_id_desc(self) -> Self {
44        self.order_by_id(Order::Desc)
45    }
46
47    /// Apply order by primary key to the query statement
48    fn order_by_id(mut self, order: Order) -> Self {
49        for key in E::PrimaryKey::iter() {
50            let col = key.into_column();
51            <Self as QueryOrder>::query(&mut self)
52                .order_by_expr(col.into_simple_expr(), order.clone());
53        }
54        self
55    }
56
57    /// Paginate query.
58    fn paginate<'db, C: ConnectionTrait>(
59        self,
60        db: &'db C,
61        page_size: u64,
62    ) -> EntityLoaderPaginator<'db, C, E, Self> {
63        EntityLoaderPaginator {
64            loader: self,
65            page: 0,
66            page_size,
67            db,
68            phantom: PhantomData,
69        }
70    }
71
72    #[doc(hidden)]
73    fn fetch<C: ConnectionTrait>(
74        self,
75        db: &C,
76        page: u64,
77        page_size: u64,
78    ) -> Result<Vec<Self::ModelEx>, DbErr>;
79
80    #[doc(hidden)]
81    fn num_items<C: ConnectionTrait>(self, db: &C, page_size: u64) -> Result<u64, DbErr>;
82}
83
84#[derive(Debug)]
85pub struct EntityLoaderPaginator<'db, C, E, L>
86where
87    C: ConnectionTrait,
88    E: EntityTrait,
89    L: EntityLoaderTrait<E>,
90{
91    pub(crate) loader: L,
92    pub(crate) page: u64,
93    pub(crate) page_size: u64,
94    pub(crate) db: &'db C,
95    pub(crate) phantom: PhantomData<E>,
96}
97
98/// Just a marker trait on EntityReverse
99pub trait EntityReverse {
100    type Entity: EntityTrait;
101}
102
103/// Subject to change, not yet stable
104#[derive(Debug, Copy, Clone, PartialEq)]
105pub struct EntityLoaderWithSelf<R: EntityTrait, S: EntityTrait>(pub R, pub S);
106
107/// Subject to change, not yet stable
108#[derive(Debug, Copy, Clone, PartialEq)]
109pub struct EntityLoaderWithSelfRev<R: EntityTrait, S: EntityReverse>(pub R, pub S);
110
/// Presumably identifies what an entity loader should load: a table reference (in one of
/// two directions) or a relation identified by a string.
///
/// NOTE(review): the variant semantics below are inferred from their names; confirm them
/// against the code that matches on this enum.
#[derive(Debug, Clone, PartialEq)]
pub enum LoadTarget {
    /// A target referenced by its [`TableRef`].
    TableRef(TableRef),
    /// A target referenced by its [`TableRef`], presumably traversed in the reverse
    /// direction — TODO confirm.
    TableRefRev(TableRef),
    /// A target identified by a string, presumably a relation name — TODO confirm.
    Relation(String),
}
117
118impl<'db, C, E, L> EntityLoaderPaginator<'db, C, E, L>
119where
120    C: ConnectionTrait,
121    E: EntityTrait,
122    L: EntityLoaderTrait<E>,
123{
124    /// Fetch a specific page; page index starts from zero
125    pub fn fetch_page(&self, page: u64) -> Result<Vec<L::ModelEx>, DbErr> {
126        self.loader.clone().fetch(self.db, page, self.page_size)
127    }
128
129    /// Fetch the current page
130    pub fn fetch(&self) -> Result<Vec<L::ModelEx>, DbErr> {
131        self.fetch_page(self.page)
132    }
133
134    /// Get the total number of items
135    pub fn num_items(&self) -> Result<u64, DbErr> {
136        self.loader.clone().num_items(self.db, self.page_size)
137    }
138
139    /// Get the total number of pages
140    pub fn num_pages(&self) -> Result<u64, DbErr> {
141        let num_items = self.num_items()?;
142        let num_pages = self.compute_pages_number(num_items);
143        Ok(num_pages)
144    }
145
146    /// Get the total number of items and pages
147    pub fn num_items_and_pages(&self) -> Result<ItemsAndPagesNumber, DbErr> {
148        let number_of_items = self.num_items()?;
149        let number_of_pages = self.compute_pages_number(number_of_items);
150
151        Ok(ItemsAndPagesNumber {
152            number_of_items,
153            number_of_pages,
154        })
155    }
156
157    /// Compute the number of pages for the current page
158    fn compute_pages_number(&self, num_items: u64) -> u64 {
159        (num_items / self.page_size) + (num_items % self.page_size > 0) as u64
160    }
161
162    /// Increment the page counter
163    pub fn next(&mut self) {
164        self.page += 1;
165    }
166
167    /// Get current page number
168    pub fn cur_page(&self) -> u64 {
169        self.page
170    }
171
172    /// Fetch one page and increment the page counter
173    pub fn fetch_and_next(&mut self) -> Result<Option<Vec<L::ModelEx>>, DbErr> {
174        let vec = self.fetch()?;
175        self.next();
176        let opt = if !vec.is_empty() { Some(vec) } else { None };
177        Ok(opt)
178    }
179}
180
#[cfg(test)]
mod test {
    use crate::ModelTrait;
    use crate::tests_cfg::cake;

    /// `cake::Model` and `cake::ModelEx` must convert into each other losslessly and
    /// compare equal to each other.
    #[test]
    fn test_model_ex_convert() {
        let model = cake::Model {
            id: 12,
            name: "hello".into(),
        };
        let model_ex: cake::ModelEx = model.clone().into();

        // Cross-type equality holds in both directions, field by field as well.
        assert_eq!(model, model_ex);
        assert_eq!(model_ex, model);
        assert_eq!(model.id, model_ex.id);
        assert_eq!(model.name, model_ex.name);

        // Column access through `ModelTrait::get` sees the same values.
        assert_eq!(model_ex.get(cake::Column::Id), 12i32.into());
        assert_eq!(model_ex.get(cake::Column::Name), "hello".into());

        // Converting back yields the original model.
        assert_eq!(cake::Model::from(model_ex), model);
    }
}
204}