1use crate::{
2 error::*, ConnectionTrait, DbBackend, EntityTrait, FromQueryResult, Select, SelectModel,
3 SelectTwo, SelectTwoModel, Selector, SelectorRaw, SelectorTrait,
4};
5use async_stream::stream;
6use futures_util::Stream;
7use sea_query::{Alias, Expr, SelectStatement};
8use std::{marker::PhantomData, pin::Pin};
9
/// Heap-allocated, pinned stream of `Item`s that may borrow data for `'db`.
///
/// The trait object carries no `Send` bound. This is the return type of
/// [`Paginator::into_stream`].
pub type PinBoxStream<'db, Item> = Pin<Box<dyn Stream<Item = Item> + 'db>>;
12
/// Runs a `SELECT` one page at a time against connection `C`, converting each
/// row with the selector `S`.
///
/// Values are built by the [`PaginatorTrait::paginate`] implementations in this
/// module.
#[derive(Clone, Debug)]
pub struct Paginator<'db, C, S>
where
    C: ConnectionTrait,
    S: SelectorTrait + 'db,
{
    /// Base query; `fetch_page` applies `LIMIT`/`OFFSET` to a clone of it.
    pub(crate) query: SelectStatement,
    /// Zero-based index of the current page, read by `fetch` and advanced by `next`.
    pub(crate) page: u64,
    /// Rows per page; the `paginate` implementations in this module assert it is non-zero.
    pub(crate) page_size: u64,
    /// Connection the page and count queries are executed on.
    pub(crate) db: &'db C,
    /// Marker for the selector type that converts rows into `S::Item`.
    pub(crate) selector: PhantomData<S>,
}
26
/// Row count of a paginated query together with the number of pages those rows
/// span, as returned by [`Paginator::num_items_and_pages`].
///
/// This is plain data, so it is `Copy` and can be compared and hashed directly.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ItemsAndPagesNumber {
    /// Number of rows matched by the query, ignoring pagination.
    pub number_of_items: u64,
    /// Number of pages of `page_size` rows needed to hold `number_of_items`.
    pub number_of_pages: u64,
}
35
36impl<'db, C, S> Paginator<'db, C, S>
39where
40 C: ConnectionTrait,
41 S: SelectorTrait + 'db,
42{
43 pub async fn fetch_page(&self, page: u64) -> Result<Vec<S::Item>, DbErr> {
45 let query = self
46 .query
47 .clone()
48 .limit(self.page_size)
49 .offset(self.page_size * page)
50 .to_owned();
51 let builder = self.db.get_database_backend();
52 let stmt = builder.build(&query);
53 let rows = self.db.query_all(stmt).await?;
54 let mut buffer = Vec::with_capacity(rows.len());
55 for row in rows.into_iter() {
56 buffer.push(S::from_raw_query_result(row)?);
58 }
59 Ok(buffer)
60 }
61
62 pub async fn fetch(&self) -> Result<Vec<S::Item>, DbErr> {
64 self.fetch_page(self.page).await
65 }
66
67 pub async fn num_items(&self) -> Result<u64, DbErr> {
69 let builder = self.db.get_database_backend();
70 let stmt = SelectStatement::new()
71 .expr(Expr::cust("COUNT(*) AS num_items"))
72 .from_subquery(
73 self.query
74 .clone()
75 .reset_limit()
76 .reset_offset()
77 .clear_order_by()
78 .to_owned(),
79 Alias::new("sub_query"),
80 )
81 .to_owned();
82 let stmt = builder.build(&stmt);
83 let result = match self.db.query_one(stmt).await? {
84 Some(res) => res,
85 None => return Ok(0),
86 };
87 let num_items = match builder {
88 DbBackend::Postgres => result.try_get::<i64>("", "num_items")? as u64,
89 _ => result.try_get::<i32>("", "num_items")? as u64,
90 };
91 Ok(num_items)
92 }
93
94 pub async fn num_pages(&self) -> Result<u64, DbErr> {
96 let num_items = self.num_items().await?;
97 let num_pages = self.compute_pages_number(num_items);
98 Ok(num_pages)
99 }
100
101 pub async fn num_items_and_pages(&self) -> Result<ItemsAndPagesNumber, DbErr> {
103 let number_of_items = self.num_items().await?;
104 let number_of_pages = self.compute_pages_number(number_of_items);
105
106 Ok(ItemsAndPagesNumber {
107 number_of_items,
108 number_of_pages,
109 })
110 }
111
112 fn compute_pages_number(&self, num_items: u64) -> u64 {
114 (num_items / self.page_size) + (num_items % self.page_size > 0) as u64
115 }
116
117 pub fn next(&mut self) {
119 self.page += 1;
120 }
121
122 pub fn cur_page(&self) -> u64 {
124 self.page
125 }
126
127 pub async fn fetch_and_next(&mut self) -> Result<Option<Vec<S::Item>>, DbErr> {
160 let vec = self.fetch().await?;
161 self.next();
162 let opt = if !vec.is_empty() { Some(vec) } else { None };
163 Ok(opt)
164 }
165
166 pub fn into_stream(mut self) -> PinBoxStream<'db, Result<Vec<S::Item>, DbErr>> {
201 Box::pin(stream! {
202 while let Some(vec) = self.fetch_and_next().await? {
203 yield Ok(vec);
204 }
205 })
206 }
207}
208
209#[async_trait::async_trait]
210pub trait PaginatorTrait<'db, C>
212where
213 C: ConnectionTrait,
214{
215 type Selector: SelectorTrait + Send + Sync + 'db;
217
218 fn paginate(self, db: &'db C, page_size: u64) -> Paginator<'db, C, Self::Selector>;
220
221 async fn count(self, db: &'db C) -> Result<u64, DbErr>
223 where
224 Self: Send + Sized,
225 {
226 self.paginate(db, 1).num_items().await
227 }
228}
229
230impl<'db, C, S> PaginatorTrait<'db, C> for Selector<S>
231where
232 C: ConnectionTrait,
233 S: SelectorTrait + Send + Sync + 'db,
234{
235 type Selector = S;
236
237 fn paginate(self, db: &'db C, page_size: u64) -> Paginator<'db, C, S> {
238 assert!(page_size != 0, "page_size should not be zero");
239 Paginator {
240 query: self.query,
241 page: 0,
242 page_size,
243 db,
244 selector: PhantomData,
245 }
246 }
247}
248
249impl<'db, C, S> PaginatorTrait<'db, C> for SelectorRaw<S>
250where
251 C: ConnectionTrait,
252 S: SelectorTrait + Send + Sync + 'db,
253{
254 type Selector = S;
255 fn paginate(self, db: &'db C, page_size: u64) -> Paginator<'db, C, S> {
256 assert!(page_size != 0, "page_size should not be zero");
257 let sql = self.stmt.sql.trim()[6..].trim();
258 let mut query = SelectStatement::new();
259 query.expr(if let Some(values) = self.stmt.values {
260 Expr::cust_with_values(sql, values.0)
261 } else {
262 Expr::cust(sql)
263 });
264
265 Paginator {
266 query,
267 page: 0,
268 page_size,
269 db,
270 selector: PhantomData,
271 }
272 }
273}
274
275impl<'db, C, M, E> PaginatorTrait<'db, C> for Select<E>
276where
277 C: ConnectionTrait,
278 E: EntityTrait<Model = M>,
279 M: FromQueryResult + Sized + Send + Sync + 'db,
280{
281 type Selector = SelectModel<M>;
282
283 fn paginate(self, db: &'db C, page_size: u64) -> Paginator<'db, C, Self::Selector> {
284 self.into_model().paginate(db, page_size)
285 }
286}
287
288impl<'db, C, M, N, E, F> PaginatorTrait<'db, C> for SelectTwo<E, F>
289where
290 C: ConnectionTrait,
291 E: EntityTrait<Model = M>,
292 F: EntityTrait<Model = N>,
293 M: FromQueryResult + Sized + Send + Sync + 'db,
294 N: FromQueryResult + Sized + Send + Sync + 'db,
295{
296 type Selector = SelectTwoModel<M, N>;
297
298 fn paginate(self, db: &'db C, page_size: u64) -> Paginator<'db, C, Self::Selector> {
299 self.into_model().paginate(db, page_size)
300 }
301}
302
#[cfg(test)]
#[cfg(feature = "mock")]
mod tests {
    use super::*;
    use crate::entity::prelude::*;
    use crate::{tests_cfg::*, ConnectionTrait, Statement};
    use crate::{DatabaseConnection, DbBackend, MockDatabase, Transaction};
    use futures_util::TryStreamExt;
    use once_cell::sync::Lazy;
    use pretty_assertions::assert_eq;
    use sea_query::{Alias, Expr, SelectStatement, Value};

    /// Raw-SQL spelling of `fruit::Entity::find()`, shared by the `*_raw` tests.
    static RAW_STMT: Lazy<Statement> = Lazy::new(|| {
        Statement::from_sql_and_values(
            DbBackend::Postgres,
            r#"SELECT "fruit"."id", "fruit"."name", "fruit"."cake_id" FROM "fruit""#,
            [],
        )
    });

    /// Mock Postgres connection that answers three successive queries with two
    /// fruits, one fruit and no fruit. The same pages are returned for comparison.
    fn setup() -> (DatabaseConnection, Vec<Vec<fruit::Model>>) {
        let pages = vec![
            vec![
                fruit::Model {
                    id: 1,
                    name: "Blueberry".into(),
                    cake_id: Some(1),
                },
                fruit::Model {
                    id: 2,
                    name: "Raspberry".into(),
                    cake_id: Some(1),
                },
            ],
            vec![fruit::Model {
                id: 3,
                name: "Strawberry".into(),
                cake_id: Some(2),
            }],
            Vec::<fruit::Model>::new(),
        ];

        let db = MockDatabase::new(DbBackend::Postgres)
            .append_query_results(pages.clone())
            .into_connection();

        (db, pages)
    }

    /// Mock Postgres connection whose single query result reports `num_items = 3`.
    fn setup_num_items() -> (DatabaseConnection, i64) {
        let num_items: i64 = 3;
        let count_row = maplit::btreemap! {
            "num_items" => Into::<Value>::into(num_items),
        };
        let db = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([[count_row]])
            .into_connection();

        (db, num_items)
    }

    /// The `SELECT` that both `fruit::Entity::find()` and `RAW_STMT` are expected
    /// to issue, before any `LIMIT`/`OFFSET` is applied.
    fn fruit_select() -> SelectStatement {
        SelectStatement::new()
            .exprs([
                Expr::col((fruit::Entity, fruit::Column::Id)),
                Expr::col((fruit::Entity, fruit::Column::Name)),
                Expr::col((fruit::Entity, fruit::Column::CakeId)),
            ])
            .from(fruit::Entity)
            .to_owned()
    }

    /// Asserts that `db` saw exactly the three page queries produced by walking
    /// the `setup()` fixture two rows at a time (offsets 0, 2 and 4, limit 2).
    fn assert_page_queries(db: DatabaseConnection) {
        let backend = db.get_database_backend();
        let mut select = fruit_select();
        let expected = [
            backend.build(select.clone().offset(0).limit(2)),
            backend.build(select.clone().offset(2).limit(2)),
            backend.build(select.offset(4).limit(2)),
        ];
        assert_eq!(db.into_transaction_log(), Transaction::wrap(expected));
    }

    /// Asserts that `db` saw only the `COUNT(*)` query that `num_items` wraps
    /// around the fruit select.
    fn assert_count_query(db: DatabaseConnection) {
        let backend = db.get_database_backend();
        let count = SelectStatement::new()
            .expr(Expr::cust("COUNT(*) AS num_items"))
            .from_subquery(fruit_select(), Alias::new("sub_query"))
            .to_owned();
        assert_eq!(
            db.into_transaction_log(),
            Transaction::wrap([backend.build(&count)])
        );
    }

    #[smol_potat::test]
    async fn fetch_page() -> Result<(), DbErr> {
        let (db, pages) = setup();
        let paginator = fruit::Entity::find().paginate(&db, 2);

        for (index, expected) in pages.iter().enumerate() {
            assert_eq!(paginator.fetch_page(index as u64).await?, *expected);
        }

        assert_page_queries(db);
        Ok(())
    }

    #[smol_potat::test]
    async fn fetch_page_raw() -> Result<(), DbErr> {
        let (db, pages) = setup();
        let paginator = fruit::Entity::find()
            .from_raw_sql(RAW_STMT.clone())
            .paginate(&db, 2);

        for (index, expected) in pages.iter().enumerate() {
            assert_eq!(paginator.fetch_page(index as u64).await?, *expected);
        }

        assert_page_queries(db);
        Ok(())
    }

    #[smol_potat::test]
    async fn fetch() -> Result<(), DbErr> {
        let (db, pages) = setup();
        let mut paginator = fruit::Entity::find().paginate(&db, 2);

        assert_eq!(paginator.fetch().await?, pages[0]);
        paginator.next();
        assert_eq!(paginator.fetch().await?, pages[1]);
        paginator.next();
        assert_eq!(paginator.fetch().await?, pages[2]);

        assert_page_queries(db);
        Ok(())
    }

    #[smol_potat::test]
    async fn fetch_raw() -> Result<(), DbErr> {
        let (db, pages) = setup();
        let mut paginator = fruit::Entity::find()
            .from_raw_sql(RAW_STMT.clone())
            .paginate(&db, 2);

        assert_eq!(paginator.fetch().await?, pages[0]);
        paginator.next();
        assert_eq!(paginator.fetch().await?, pages[1]);
        paginator.next();
        assert_eq!(paginator.fetch().await?, pages[2]);

        assert_page_queries(db);
        Ok(())
    }

    #[smol_potat::test]
    async fn num_pages() -> Result<(), DbErr> {
        let (db, num_items) = setup_num_items();
        let page_size = 2_u64;
        let num_items = num_items as u64;
        let expected_pages = (num_items / page_size) + (num_items % page_size > 0) as u64;

        let paginator = fruit::Entity::find().paginate(&db, page_size);
        assert_eq!(paginator.num_pages().await?, expected_pages);

        assert_count_query(db);
        Ok(())
    }

    #[smol_potat::test]
    async fn num_pages_raw() -> Result<(), DbErr> {
        let (db, num_items) = setup_num_items();
        let page_size = 2_u64;
        let num_items = num_items as u64;
        let expected_pages = (num_items / page_size) + (num_items % page_size > 0) as u64;

        let paginator = fruit::Entity::find()
            .from_raw_sql(RAW_STMT.clone())
            .paginate(&db, page_size);
        assert_eq!(paginator.num_pages().await?, expected_pages);

        assert_count_query(db);
        Ok(())
    }

    #[smol_potat::test]
    async fn next_and_cur_page() -> Result<(), DbErr> {
        let (db, _) = setup();
        let mut paginator = fruit::Entity::find().paginate(&db, 2);

        for expected in 0..3_u64 {
            if expected > 0 {
                paginator.next();
            }
            assert_eq!(paginator.cur_page(), expected);
        }
        Ok(())
    }

    #[smol_potat::test]
    async fn next_and_cur_page_raw() -> Result<(), DbErr> {
        let (db, _) = setup();
        let mut paginator = fruit::Entity::find()
            .from_raw_sql(RAW_STMT.clone())
            .paginate(&db, 2);

        for expected in 0..3_u64 {
            if expected > 0 {
                paginator.next();
            }
            assert_eq!(paginator.cur_page(), expected);
        }
        Ok(())
    }

    #[smol_potat::test]
    async fn fetch_and_next() -> Result<(), DbErr> {
        let (db, pages) = setup();
        let mut paginator = fruit::Entity::find().paginate(&db, 2);

        assert_eq!(paginator.cur_page(), 0);
        assert_eq!(paginator.fetch_and_next().await?, Some(pages[0].clone()));
        assert_eq!(paginator.cur_page(), 1);
        assert_eq!(paginator.fetch_and_next().await?, Some(pages[1].clone()));
        assert_eq!(paginator.cur_page(), 2);
        assert_eq!(paginator.fetch_and_next().await?, None);

        assert_page_queries(db);
        Ok(())
    }

    #[smol_potat::test]
    async fn fetch_and_next_raw() -> Result<(), DbErr> {
        let (db, pages) = setup();
        let mut paginator = fruit::Entity::find()
            .from_raw_sql(RAW_STMT.clone())
            .paginate(&db, 2);

        assert_eq!(paginator.cur_page(), 0);
        assert_eq!(paginator.fetch_and_next().await?, Some(pages[0].clone()));
        assert_eq!(paginator.cur_page(), 1);
        assert_eq!(paginator.fetch_and_next().await?, Some(pages[1].clone()));
        assert_eq!(paginator.cur_page(), 2);
        assert_eq!(paginator.fetch_and_next().await?, None);

        assert_page_queries(db);
        Ok(())
    }

    #[smol_potat::test]
    async fn into_stream() -> Result<(), DbErr> {
        let (db, pages) = setup();
        let mut fruit_stream = fruit::Entity::find().paginate(&db, 2).into_stream();

        assert_eq!(fruit_stream.try_next().await?, Some(pages[0].clone()));
        assert_eq!(fruit_stream.try_next().await?, Some(pages[1].clone()));
        assert_eq!(fruit_stream.try_next().await?, None);

        // The boxed stream borrows `db`; release it before consuming the log.
        drop(fruit_stream);
        assert_page_queries(db);
        Ok(())
    }

    #[smol_potat::test]
    async fn into_stream_raw() -> Result<(), DbErr> {
        let (db, pages) = setup();
        let mut fruit_stream = fruit::Entity::find()
            .from_raw_sql(RAW_STMT.clone())
            .paginate(&db, 2)
            .into_stream();

        assert_eq!(fruit_stream.try_next().await?, Some(pages[0].clone()));
        assert_eq!(fruit_stream.try_next().await?, Some(pages[1].clone()));
        assert_eq!(fruit_stream.try_next().await?, None);

        // The boxed stream borrows `db`; release it before consuming the log.
        drop(fruit_stream);
        assert_page_queries(db);
        Ok(())
    }

    #[smol_potat::test]
    async fn into_stream_raw_leading_spaces() -> Result<(), DbErr> {
        let (db, pages) = setup();
        let padded_stmt = Statement::from_sql_and_values(
            DbBackend::Postgres,
            r#"  SELECT "fruit"."id", "fruit"."name", "fruit"."cake_id" FROM "fruit"  "#,
            [],
        );

        let mut fruit_stream = fruit::Entity::find()
            .from_raw_sql(padded_stmt)
            .paginate(&db, 2)
            .into_stream();

        assert_eq!(fruit_stream.try_next().await?, Some(pages[0].clone()));
        assert_eq!(fruit_stream.try_next().await?, Some(pages[1].clone()));
        assert_eq!(fruit_stream.try_next().await?, None);

        // The boxed stream borrows `db`; release it before consuming the log.
        drop(fruit_stream);
        assert_page_queries(db);
        Ok(())
    }

    #[smol_potat::test]
    #[should_panic]
    async fn error() {
        let (db, _pages) = setup();

        fruit::Entity::find().paginate(&db, 0);
    }
}