1#![allow(missing_docs)]
2use super::{ColumnTrait, EntityTrait, PrimaryKeyToColumn, PrimaryKeyTrait};
3use crate::{
4 ConnectionTrait, DbErr, IntoSimpleExpr, ItemsAndPagesNumber, Iterable, ModelTrait,
5 PinBoxStream, QueryFilter, QueryOrder, Related,
6};
7use async_stream::stream;
8use sea_query::{IntoValueTuple, Order, TableRef};
9use std::marker::PhantomData;
10
11mod has_many;
12mod has_one;
13
14pub use has_many::{HasMany, Iter as HasManyIter};
15pub use has_one::HasOne;
16
17#[async_trait::async_trait]
18pub trait EntityLoaderTrait<E: EntityTrait>: QueryFilter + QueryOrder + Clone {
19 type ModelEx: ModelTrait<Entity = E>;
21
22 fn filter_by_id<T>(mut self, values: T) -> Self
24 where
25 T: Into<<E::PrimaryKey as PrimaryKeyTrait>::ValueType>,
26 {
27 let mut keys = E::PrimaryKey::iter();
28 for v in values.into().into_value_tuple() {
29 if let Some(key) = keys.next() {
30 let col = key.into_column();
31 self.filter_mut(col.eq(v));
32 } else {
33 unreachable!("primary key arity mismatch");
34 }
35 }
36 self
37 }
38
39 fn order_by_id_asc(self) -> Self {
41 self.order_by_id(Order::Asc)
42 }
43
44 fn order_by_id_desc(self) -> Self {
46 self.order_by_id(Order::Desc)
47 }
48
49 fn order_by_id(mut self, order: Order) -> Self {
51 for key in E::PrimaryKey::iter() {
52 let col = key.into_column();
53 <Self as QueryOrder>::query(&mut self)
54 .order_by_expr(col.into_simple_expr(), order.clone());
55 }
56 self
57 }
58
59 fn paginate<'db, C: ConnectionTrait>(
61 self,
62 db: &'db C,
63 page_size: u64,
64 ) -> EntityLoaderPaginator<'db, C, E, Self> {
65 EntityLoaderPaginator {
66 loader: self,
67 page: 0,
68 page_size,
69 db,
70 phantom: PhantomData,
71 }
72 }
73
74 #[doc(hidden)]
75 async fn fetch<C: ConnectionTrait>(
76 self,
77 db: &C,
78 page: u64,
79 page_size: u64,
80 ) -> Result<Vec<Self::ModelEx>, DbErr>;
81
82 #[doc(hidden)]
83 async fn num_items<C: ConnectionTrait>(self, db: &C, page_size: u64) -> Result<u64, DbErr>;
84}
85
86#[derive(Debug)]
87pub struct EntityLoaderPaginator<'db, C, E, L>
88where
89 C: ConnectionTrait,
90 E: EntityTrait,
91 L: EntityLoaderTrait<E>,
92{
93 pub(crate) loader: L,
94 pub(crate) page: u64,
95 pub(crate) page_size: u64,
96 pub(crate) db: &'db C,
97 pub(crate) phantom: PhantomData<E>,
98}
99
100#[derive(Debug, Clone, PartialEq)]
101pub enum LoadTarget {
102 TableRef(TableRef),
103 Relation(String),
104}
105
106pub trait EntityLoaderWithParam<E: EntityTrait> {
107 fn into_with_param(self) -> (LoadTarget, Option<LoadTarget>);
108}
109
110impl<E, R> EntityLoaderWithParam<E> for R
111where
112 E: EntityTrait,
113 R: EntityTrait,
114 E: Related<R>,
115{
116 fn into_with_param(self) -> (LoadTarget, Option<LoadTarget>) {
117 (LoadTarget::TableRef(self.table_ref()), None)
118 }
119}
120
121impl<E, R, S> EntityLoaderWithParam<E> for (R, S)
122where
123 E: EntityTrait,
124 R: EntityTrait,
125 E: Related<R>,
126 S: EntityTrait,
127 R: Related<S>,
128{
129 fn into_with_param(self) -> (LoadTarget, Option<LoadTarget>) {
130 (
131 LoadTarget::TableRef(self.0.table_ref()),
132 Some(LoadTarget::TableRef(self.1.table_ref())),
133 )
134 }
135}
136
137impl<'db, C, E, L> EntityLoaderPaginator<'db, C, E, L>
138where
139 C: ConnectionTrait,
140 E: EntityTrait,
141 L: EntityLoaderTrait<E>,
142{
143 pub async fn fetch_page(&self, page: u64) -> Result<Vec<L::ModelEx>, DbErr> {
145 self.loader
146 .clone()
147 .fetch(self.db, page, self.page_size)
148 .await
149 }
150
151 pub async fn fetch(&self) -> Result<Vec<L::ModelEx>, DbErr> {
153 self.fetch_page(self.page).await
154 }
155
156 pub async fn num_items(&self) -> Result<u64, DbErr> {
158 self.loader.clone().num_items(self.db, self.page_size).await
159 }
160
161 pub async fn num_pages(&self) -> Result<u64, DbErr> {
163 let num_items = self.num_items().await?;
164 let num_pages = self.compute_pages_number(num_items);
165 Ok(num_pages)
166 }
167
168 pub async fn num_items_and_pages(&self) -> Result<ItemsAndPagesNumber, DbErr> {
170 let number_of_items = self.num_items().await?;
171 let number_of_pages = self.compute_pages_number(number_of_items);
172
173 Ok(ItemsAndPagesNumber {
174 number_of_items,
175 number_of_pages,
176 })
177 }
178
179 fn compute_pages_number(&self, num_items: u64) -> u64 {
181 (num_items / self.page_size) + (num_items % self.page_size > 0) as u64
182 }
183
184 pub fn next(&mut self) {
186 self.page += 1;
187 }
188
189 pub fn cur_page(&self) -> u64 {
191 self.page
192 }
193
194 pub async fn fetch_and_next(&mut self) -> Result<Option<Vec<L::ModelEx>>, DbErr> {
196 let vec = self.fetch().await?;
197 self.next();
198 let opt = if !vec.is_empty() { Some(vec) } else { None };
199 Ok(opt)
200 }
201
202 pub fn into_stream(mut self) -> PinBoxStream<'db, Result<Vec<L::ModelEx>, DbErr>>
204 where
205 L: 'db,
206 {
207 Box::pin(stream! {
208 while let Some(vec) = self.fetch_and_next().await? {
209 yield Ok(vec);
210 }
211 })
212 }
213}
214
#[cfg(test)]
mod test {
    use crate::ModelTrait;
    use crate::tests_cfg::cake;

    /// A `cake::Model` and the `cake::ModelEx` converted from it must:
    /// - compare equal in both directions;
    /// - expose the same fields and column values;
    /// - convert back to the original model.
    #[test]
    fn test_model_ex_convert() {
        let model = cake::Model {
            id: 12,
            name: "hello".into(),
        };
        let model_ex: cake::ModelEx = model.clone().into();

        // Cross-type equality must hold both ways.
        assert_eq!(model, model_ex);
        assert_eq!(model_ex, model);

        // Field-by-field agreement.
        assert_eq!(model.id, model_ex.id);
        assert_eq!(model.name, model_ex.name);

        // Column access through `ModelTrait::get` on the extended model.
        assert_eq!(model_ex.get(cake::Column::Id), 12i32.into());
        assert_eq!(model_ex.get(cake::Column::Name), "hello".into());

        // Converting back yields the original model.
        assert_eq!(cake::Model::from(model_ex), model);
    }
}