1use crate::{
2 ConnectionTrait, DbBackend, EntityTrait, FromQueryResult, Select, SelectModel, SelectTwo,
3 SelectTwoModel, Selector, SelectorRaw, SelectorTrait, error::*,
4};
5use async_stream::stream;
6use futures_util::Stream;
7use sea_query::{Expr, SelectStatement};
8use std::{marker::PhantomData, pin::Pin};
9
10pub type PinBoxStream<'db, Item> = Pin<Box<dyn Stream<Item = Item> + 'db>>;
12
13#[derive(Clone, Debug)]
15pub struct Paginator<'db, C, S>
16where
17 C: ConnectionTrait,
18 S: SelectorTrait + 'db,
19{
20 pub(crate) query: SelectStatement,
21 pub(crate) page: u64,
22 pub(crate) page_size: u64,
23 pub(crate) db: &'db C,
24 pub(crate) selector: PhantomData<S>,
25}
26
/// Total row count of a paginated query together with the number of pages
/// those rows occupy, as returned by `Paginator::num_items_and_pages`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ItemsAndPagesNumber {
    /// Number of rows the query returns, ignoring LIMIT/OFFSET.
    pub number_of_items: u64,
    /// Number of pages needed to hold `number_of_items` rows at the
    /// paginator's page size (the last page may be partially filled).
    pub number_of_pages: u64,
}
35
36impl<'db, C, S> Paginator<'db, C, S>
39where
40 C: ConnectionTrait,
41 S: SelectorTrait + 'db,
42{
43 pub async fn fetch_page(&self, page: u64) -> Result<Vec<S::Item>, DbErr> {
45 let query = self
46 .query
47 .clone()
48 .limit(self.page_size)
49 .offset(self.page_size * page)
50 .to_owned();
51 let rows = self.db.query_all(&query).await?;
52 let mut buffer = Vec::with_capacity(rows.len());
53 for row in rows.into_iter() {
54 buffer.push(S::from_raw_query_result(row)?);
55 }
56 Ok(buffer)
57 }
58
59 pub async fn fetch(&self) -> Result<Vec<S::Item>, DbErr> {
61 self.fetch_page(self.page).await
62 }
63
64 pub async fn num_items(&self) -> Result<u64, DbErr> {
66 let db_backend = self.db.get_database_backend();
67 let query = SelectStatement::new()
68 .expr(Expr::cust("COUNT(*) AS num_items"))
69 .from_subquery(
70 self.query
71 .clone()
72 .reset_limit()
73 .reset_offset()
74 .clear_order_by()
75 .to_owned(),
76 "sub_query",
77 )
78 .to_owned();
79 let result = match self.db.query_one(&query).await? {
80 Some(res) => res,
81 None => return Ok(0),
82 };
83 let num_items = match db_backend {
84 DbBackend::Postgres => result.try_get::<i64>("", "num_items")? as u64,
85 _ => result.try_get::<i32>("", "num_items")? as u64,
86 };
87 Ok(num_items)
88 }
89
90 pub async fn num_pages(&self) -> Result<u64, DbErr> {
92 let num_items = self.num_items().await?;
93 let num_pages = self.compute_pages_number(num_items);
94 Ok(num_pages)
95 }
96
97 pub async fn num_items_and_pages(&self) -> Result<ItemsAndPagesNumber, DbErr> {
99 let number_of_items = self.num_items().await?;
100 let number_of_pages = self.compute_pages_number(number_of_items);
101
102 Ok(ItemsAndPagesNumber {
103 number_of_items,
104 number_of_pages,
105 })
106 }
107
108 fn compute_pages_number(&self, num_items: u64) -> u64 {
110 (num_items / self.page_size) + (num_items % self.page_size > 0) as u64
111 }
112
113 pub fn next(&mut self) {
115 self.page += 1;
116 }
117
118 pub fn cur_page(&self) -> u64 {
120 self.page
121 }
122
123 pub async fn fetch_and_next(&mut self) -> Result<Option<Vec<S::Item>>, DbErr> {
156 let vec = self.fetch().await?;
157 self.next();
158 let opt = if !vec.is_empty() { Some(vec) } else { None };
159 Ok(opt)
160 }
161
162 pub fn into_stream(mut self) -> PinBoxStream<'db, Result<Vec<S::Item>, DbErr>> {
197 Box::pin(stream! {
198 while let Some(vec) = self.fetch_and_next().await? {
199 yield Ok(vec);
200 }
201 })
202 }
203}
204
205#[async_trait::async_trait]
206pub trait PaginatorTrait<'db, C>
208where
209 C: ConnectionTrait,
210{
211 type Selector: SelectorTrait + Send + Sync + 'db;
213
214 fn paginate(self, db: &'db C, page_size: u64) -> Paginator<'db, C, Self::Selector>;
216
217 async fn count(self, db: &'db C) -> Result<u64, DbErr>
219 where
220 Self: Send + Sized,
221 {
222 self.paginate(db, 1).num_items().await
223 }
224}
225
226impl<'db, C, S> PaginatorTrait<'db, C> for Selector<S>
227where
228 C: ConnectionTrait,
229 S: SelectorTrait + Send + Sync + 'db,
230{
231 type Selector = S;
232
233 fn paginate(self, db: &'db C, page_size: u64) -> Paginator<'db, C, S> {
234 assert!(page_size != 0, "page_size should not be zero");
235 Paginator {
236 query: self.query,
237 page: 0,
238 page_size,
239 db,
240 selector: PhantomData,
241 }
242 }
243}
244
245impl<'db, C, S> PaginatorTrait<'db, C> for SelectorRaw<S>
246where
247 C: ConnectionTrait,
248 S: SelectorTrait + Send + Sync + 'db,
249{
250 type Selector = S;
251 fn paginate(self, db: &'db C, page_size: u64) -> Paginator<'db, C, S> {
252 assert!(page_size != 0, "page_size should not be zero");
253 let sql = self.stmt.sql.trim()[6..].trim().to_owned();
254 let mut query = SelectStatement::new();
255 query.expr(if let Some(values) = self.stmt.values {
256 Expr::cust_with_values(sql, values.0)
257 } else {
258 Expr::cust(sql)
259 });
260
261 Paginator {
262 query,
263 page: 0,
264 page_size,
265 db,
266 selector: PhantomData,
267 }
268 }
269}
270
271impl<'db, C, M, E> PaginatorTrait<'db, C> for Select<E>
272where
273 C: ConnectionTrait,
274 E: EntityTrait<Model = M>,
275 M: FromQueryResult + Sized + Send + Sync + 'db,
276{
277 type Selector = SelectModel<M>;
278
279 fn paginate(self, db: &'db C, page_size: u64) -> Paginator<'db, C, Self::Selector> {
280 self.into_model().paginate(db, page_size)
281 }
282}
283
284impl<'db, C, M, N, E, F> PaginatorTrait<'db, C> for SelectTwo<E, F>
285where
286 C: ConnectionTrait,
287 E: EntityTrait<Model = M>,
288 F: EntityTrait<Model = N>,
289 M: FromQueryResult + Sized + Send + Sync + 'db,
290 N: FromQueryResult + Sized + Send + Sync + 'db,
291{
292 type Selector = SelectTwoModel<M, N>;
293
294 fn paginate(self, db: &'db C, page_size: u64) -> Paginator<'db, C, Self::Selector> {
295 self.into_model().paginate(db, page_size)
296 }
297}
298
#[cfg(test)]
#[cfg(feature = "mock")]
mod tests {
    use super::*;
    use crate::entity::prelude::*;
    use crate::{DatabaseConnection, DbBackend, MockDatabase, Transaction};
    use crate::{Statement, tests_cfg::*};
    use futures_util::TryStreamExt;
    use pretty_assertions::assert_eq;
    use sea_query::{Expr, SelectStatement, Value};
    use std::sync::LazyLock;

    /// Raw-SQL form of `fruit::Entity::find()` for the Postgres backend.
    static RAW_STMT: LazyLock<Statement> = LazyLock::new(|| {
        Statement::from_sql_and_values(
            DbBackend::Postgres,
            r#"SELECT "fruit"."id", "fruit"."name", "fruit"."cake_id" FROM "fruit""#,
            [],
        )
    });

    /// Mock connection serving two non-empty pages and then an empty one,
    /// returned together with those pages.
    fn setup() -> (DatabaseConnection, Vec<Vec<fruit::Model>>) {
        let pages = vec![
            vec![
                fruit::Model {
                    id: 1,
                    name: "Blueberry".into(),
                    cake_id: Some(1),
                },
                fruit::Model {
                    id: 2,
                    name: "Raspberry".into(),
                    cake_id: Some(1),
                },
            ],
            vec![fruit::Model {
                id: 3,
                name: "Strawberry".into(),
                cake_id: Some(2),
            }],
            Vec::<fruit::Model>::new(),
        ];

        let db = MockDatabase::new(DbBackend::Postgres)
            .append_query_results(pages.clone())
            .into_connection();

        (db, pages)
    }

    /// Mock connection whose single result row carries `num_items = 3`.
    fn setup_num_items() -> (DatabaseConnection, i64) {
        let num_items = 3;
        let row = maplit::btreemap! {
            "num_items" => Into::<Value>::into(num_items),
        };
        let db = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([[row]])
            .into_connection();

        (db, num_items)
    }

    /// The statement `fruit::Entity::find()` renders, without LIMIT/OFFSET.
    fn fruit_select() -> SelectStatement {
        SelectStatement::new()
            .exprs([
                Expr::col((fruit::Entity, fruit::Column::Id)),
                Expr::col((fruit::Entity, fruit::Column::Name)),
                Expr::col((fruit::Entity, fruit::Column::CakeId)),
            ])
            .from(fruit::Entity)
            .to_owned()
    }

    /// Assert that `db` logged exactly the three page queries (offsets 0, 2, 4)
    /// issued by a paginator with a page size of 2.
    fn assert_three_pages_logged(db: DatabaseConnection) {
        let backend = db.get_database_backend();
        let select = fruit_select();
        let expected =
            [0_u64, 2, 4].map(|offset| backend.build(select.clone().offset(offset).limit(2)));
        assert_eq!(db.into_transaction_log(), Transaction::wrap(expected));
    }

    /// Assert that `db` logged exactly one `COUNT(*)` query over the fruit select.
    fn assert_count_logged(db: DatabaseConnection) {
        let backend = db.get_database_backend();
        let count_query = SelectStatement::new()
            .expr(Expr::cust("COUNT(*) AS num_items"))
            .from_subquery(fruit_select(), "sub_query")
            .to_owned();
        assert_eq!(
            db.into_transaction_log(),
            Transaction::wrap([backend.build(&count_query)])
        );
    }

    /// Drain `stream`, asserting it yields the two non-empty pages and then ends.
    async fn assert_streams_two_pages(
        mut stream: PinBoxStream<'_, Result<Vec<fruit::Model>, DbErr>>,
        pages: &[Vec<fruit::Model>],
    ) -> Result<(), DbErr> {
        assert_eq!(stream.try_next().await?, Some(pages[0].clone()));
        assert_eq!(stream.try_next().await?, Some(pages[1].clone()));
        assert_eq!(stream.try_next().await?, None);
        Ok(())
    }

    /// Expected number of pages for `num_items` rows at `page_size` rows per page.
    fn expected_pages(num_items: u64, page_size: u64) -> u64 {
        num_items / page_size + u64::from(num_items % page_size != 0)
    }

    #[smol_potat::test]
    async fn fetch_page() -> Result<(), DbErr> {
        let (db, pages) = setup();
        let paginator = fruit::Entity::find().paginate(&db, 2);

        for (index, page) in pages.iter().enumerate() {
            assert_eq!(paginator.fetch_page(index as u64).await?, *page);
        }

        assert_three_pages_logged(db);
        Ok(())
    }

    #[smol_potat::test]
    async fn fetch_page_raw() -> Result<(), DbErr> {
        let (db, pages) = setup();
        let paginator = fruit::Entity::find()
            .from_raw_sql(RAW_STMT.clone())
            .paginate(&db, 2);

        for (index, page) in pages.iter().enumerate() {
            assert_eq!(paginator.fetch_page(index as u64).await?, *page);
        }

        assert_three_pages_logged(db);
        Ok(())
    }

    #[smol_potat::test]
    async fn fetch() -> Result<(), DbErr> {
        let (db, pages) = setup();
        let mut paginator = fruit::Entity::find().paginate(&db, 2);

        for (index, page) in pages.iter().enumerate() {
            if index > 0 {
                paginator.next();
            }
            assert_eq!(paginator.fetch().await?, *page);
        }

        assert_three_pages_logged(db);
        Ok(())
    }

    #[smol_potat::test]
    async fn fetch_raw() -> Result<(), DbErr> {
        let (db, pages) = setup();
        let mut paginator = fruit::Entity::find()
            .from_raw_sql(RAW_STMT.clone())
            .paginate(&db, 2);

        for (index, page) in pages.iter().enumerate() {
            if index > 0 {
                paginator.next();
            }
            assert_eq!(paginator.fetch().await?, *page);
        }

        assert_three_pages_logged(db);
        Ok(())
    }

    #[smol_potat::test]
    async fn num_pages() -> Result<(), DbErr> {
        let (db, num_items) = setup_num_items();
        let page_size = 2_u64;
        let paginator = fruit::Entity::find().paginate(&db, page_size);

        assert_eq!(
            paginator.num_pages().await?,
            expected_pages(num_items as u64, page_size)
        );

        assert_count_logged(db);
        Ok(())
    }

    #[smol_potat::test]
    async fn num_pages_raw() -> Result<(), DbErr> {
        let (db, num_items) = setup_num_items();
        let page_size = 2_u64;
        let paginator = fruit::Entity::find()
            .from_raw_sql(RAW_STMT.clone())
            .paginate(&db, page_size);

        assert_eq!(
            paginator.num_pages().await?,
            expected_pages(num_items as u64, page_size)
        );

        assert_count_logged(db);
        Ok(())
    }

    #[smol_potat::test]
    async fn next_and_cur_page() -> Result<(), DbErr> {
        let (db, _) = setup();
        let mut paginator = fruit::Entity::find().paginate(&db, 2);

        for expected in 0..3 {
            if expected > 0 {
                paginator.next();
            }
            assert_eq!(paginator.cur_page(), expected);
        }
        Ok(())
    }

    #[smol_potat::test]
    async fn next_and_cur_page_raw() -> Result<(), DbErr> {
        let (db, _) = setup();
        let mut paginator = fruit::Entity::find()
            .from_raw_sql(RAW_STMT.clone())
            .paginate(&db, 2);

        for expected in 0..3 {
            if expected > 0 {
                paginator.next();
            }
            assert_eq!(paginator.cur_page(), expected);
        }
        Ok(())
    }

    #[smol_potat::test]
    async fn fetch_and_next() -> Result<(), DbErr> {
        let (db, pages) = setup();
        let mut paginator = fruit::Entity::find().paginate(&db, 2);

        let expected = [Some(pages[0].clone()), Some(pages[1].clone()), None];
        for (page_index, want) in (0_u64..).zip(expected) {
            assert_eq!(paginator.cur_page(), page_index);
            assert_eq!(paginator.fetch_and_next().await?, want);
        }

        assert_three_pages_logged(db);
        Ok(())
    }

    #[smol_potat::test]
    async fn fetch_and_next_raw() -> Result<(), DbErr> {
        let (db, pages) = setup();
        let mut paginator = fruit::Entity::find()
            .from_raw_sql(RAW_STMT.clone())
            .paginate(&db, 2);

        let expected = [Some(pages[0].clone()), Some(pages[1].clone()), None];
        for (page_index, want) in (0_u64..).zip(expected) {
            assert_eq!(paginator.cur_page(), page_index);
            assert_eq!(paginator.fetch_and_next().await?, want);
        }

        assert_three_pages_logged(db);
        Ok(())
    }

    #[smol_potat::test]
    async fn into_stream() -> Result<(), DbErr> {
        let (db, pages) = setup();

        let fruit_stream = fruit::Entity::find().paginate(&db, 2).into_stream();
        assert_streams_two_pages(fruit_stream, &pages).await?;

        assert_three_pages_logged(db);
        Ok(())
    }

    #[smol_potat::test]
    async fn into_stream_raw() -> Result<(), DbErr> {
        let (db, pages) = setup();

        let fruit_stream = fruit::Entity::find()
            .from_raw_sql(RAW_STMT.clone())
            .paginate(&db, 2)
            .into_stream();
        assert_streams_two_pages(fruit_stream, &pages).await?;

        assert_three_pages_logged(db);
        Ok(())
    }

    #[smol_potat::test]
    async fn into_stream_raw_leading_spaces() -> Result<(), DbErr> {
        let (db, pages) = setup();

        let raw_stmt = Statement::from_sql_and_values(
            DbBackend::Postgres,
            r#"  SELECT "fruit"."id", "fruit"."name", "fruit"."cake_id" FROM "fruit"  "#,
            [],
        );

        let fruit_stream = fruit::Entity::find()
            .from_raw_sql(raw_stmt)
            .paginate(&db, 2)
            .into_stream();
        assert_streams_two_pages(fruit_stream, &pages).await?;

        assert_three_pages_logged(db);
        Ok(())
    }

    #[smol_potat::test]
    #[should_panic]
    async fn error() {
        let (db, _pages) = setup();

        fruit::Entity::find().paginate(&db, 0);
    }
}