1use crate::{
2 ConnectionTrait, DbBackend, EntityTrait, FromQueryResult, Select, SelectModel, SelectTwo,
3 SelectTwoModel, Selector, SelectorRaw, SelectorTrait, error::*,
4};
5use async_stream::stream;
6use futures_util::Stream;
7use sea_query::{Expr, SelectStatement};
8use std::{marker::PhantomData, pin::Pin};
9
10pub type PinBoxStream<'db, Item> = Pin<Box<dyn Stream<Item = Item> + 'db>>;
12
13#[derive(Clone, Debug)]
15pub struct Paginator<'db, C, S>
16where
17 C: ConnectionTrait,
18 S: SelectorTrait + 'db,
19{
20 pub(crate) query: SelectStatement,
21 pub(crate) page: u64,
22 pub(crate) page_size: u64,
23 pub(crate) db: &'db C,
24 pub(crate) selector: PhantomData<S>,
25}
26
/// Item and page counts returned together by [`Paginator::num_items_and_pages`].
///
/// Derives `PartialEq`/`Eq`/`Hash` so callers can compare or key on the counts directly.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct ItemsAndPagesNumber {
    /// Total number of rows the paginated query matches.
    pub number_of_items: u64,
    /// Number of pages needed to hold `number_of_items` (the last page may be partial).
    pub number_of_pages: u64,
}
35
36impl<'db, C, S> Paginator<'db, C, S>
39where
40 C: ConnectionTrait,
41 S: SelectorTrait + 'db,
42{
43 pub async fn fetch_page(&self, page: u64) -> Result<Vec<S::Item>, DbErr> {
45 let query = self
46 .query
47 .clone()
48 .limit(self.page_size)
49 .offset(self.page_size * page)
50 .to_owned();
51 let rows = self.db.query_all(&query).await?;
52 let mut buffer = Vec::with_capacity(rows.len());
53 for row in rows.into_iter() {
54 buffer.push(S::from_raw_query_result(row)?);
56 }
57 Ok(buffer)
58 }
59
60 pub async fn fetch(&self) -> Result<Vec<S::Item>, DbErr> {
62 self.fetch_page(self.page).await
63 }
64
65 pub async fn num_items(&self) -> Result<u64, DbErr> {
67 let db_backend = self.db.get_database_backend();
68 let query = SelectStatement::new()
69 .expr(Expr::cust("COUNT(*) AS num_items"))
70 .from_subquery(
71 self.query
72 .clone()
73 .reset_limit()
74 .reset_offset()
75 .clear_order_by()
76 .to_owned(),
77 "sub_query",
78 )
79 .to_owned();
80 let result = match self.db.query_one(&query).await? {
81 Some(res) => res,
82 None => return Ok(0),
83 };
84 let num_items = match db_backend {
85 DbBackend::Postgres => result.try_get::<i64>("", "num_items")? as u64,
86 _ => result.try_get::<i32>("", "num_items")? as u64,
87 };
88 Ok(num_items)
89 }
90
91 pub async fn num_pages(&self) -> Result<u64, DbErr> {
93 let num_items = self.num_items().await?;
94 let num_pages = self.compute_pages_number(num_items);
95 Ok(num_pages)
96 }
97
98 pub async fn num_items_and_pages(&self) -> Result<ItemsAndPagesNumber, DbErr> {
100 let number_of_items = self.num_items().await?;
101 let number_of_pages = self.compute_pages_number(number_of_items);
102
103 Ok(ItemsAndPagesNumber {
104 number_of_items,
105 number_of_pages,
106 })
107 }
108
109 fn compute_pages_number(&self, num_items: u64) -> u64 {
111 (num_items / self.page_size) + (num_items % self.page_size > 0) as u64
112 }
113
114 pub fn next(&mut self) {
116 self.page += 1;
117 }
118
119 pub fn cur_page(&self) -> u64 {
121 self.page
122 }
123
124 pub async fn fetch_and_next(&mut self) -> Result<Option<Vec<S::Item>>, DbErr> {
157 let vec = self.fetch().await?;
158 self.next();
159 let opt = if !vec.is_empty() { Some(vec) } else { None };
160 Ok(opt)
161 }
162
163 pub fn into_stream(mut self) -> PinBoxStream<'db, Result<Vec<S::Item>, DbErr>> {
198 Box::pin(stream! {
199 while let Some(vec) = self.fetch_and_next().await? {
200 yield Ok(vec);
201 }
202 })
203 }
204}
205
206#[async_trait::async_trait]
207pub trait PaginatorTrait<'db, C>
209where
210 C: ConnectionTrait,
211{
212 type Selector: SelectorTrait + Send + Sync + 'db;
214
215 fn paginate(self, db: &'db C, page_size: u64) -> Paginator<'db, C, Self::Selector>;
217
218 async fn count(self, db: &'db C) -> Result<u64, DbErr>
220 where
221 Self: Send + Sized,
222 {
223 self.paginate(db, 1).num_items().await
224 }
225}
226
227impl<'db, C, S> PaginatorTrait<'db, C> for Selector<S>
228where
229 C: ConnectionTrait,
230 S: SelectorTrait + Send + Sync + 'db,
231{
232 type Selector = S;
233
234 fn paginate(self, db: &'db C, page_size: u64) -> Paginator<'db, C, S> {
235 assert!(page_size != 0, "page_size should not be zero");
236 Paginator {
237 query: self.query,
238 page: 0,
239 page_size,
240 db,
241 selector: PhantomData,
242 }
243 }
244}
245
246impl<'db, C, S> PaginatorTrait<'db, C> for SelectorRaw<S>
247where
248 C: ConnectionTrait,
249 S: SelectorTrait + Send + Sync + 'db,
250{
251 type Selector = S;
252 fn paginate(self, db: &'db C, page_size: u64) -> Paginator<'db, C, S> {
253 assert!(page_size != 0, "page_size should not be zero");
254 let sql = self.stmt.sql.trim()[6..].trim().to_owned();
255 let mut query = SelectStatement::new();
256 query.expr(if let Some(values) = self.stmt.values {
257 Expr::cust_with_values(sql, values.0)
258 } else {
259 Expr::cust(sql)
260 });
261
262 Paginator {
263 query,
264 page: 0,
265 page_size,
266 db,
267 selector: PhantomData,
268 }
269 }
270}
271
272impl<'db, C, M, E> PaginatorTrait<'db, C> for Select<E>
273where
274 C: ConnectionTrait,
275 E: EntityTrait<Model = M>,
276 M: FromQueryResult + Sized + Send + Sync + 'db,
277{
278 type Selector = SelectModel<M>;
279
280 fn paginate(self, db: &'db C, page_size: u64) -> Paginator<'db, C, Self::Selector> {
281 self.into_model().paginate(db, page_size)
282 }
283}
284
285impl<'db, C, M, N, E, F> PaginatorTrait<'db, C> for SelectTwo<E, F>
286where
287 C: ConnectionTrait,
288 E: EntityTrait<Model = M>,
289 F: EntityTrait<Model = N>,
290 M: FromQueryResult + Sized + Send + Sync + 'db,
291 N: FromQueryResult + Sized + Send + Sync + 'db,
292{
293 type Selector = SelectTwoModel<M, N>;
294
295 fn paginate(self, db: &'db C, page_size: u64) -> Paginator<'db, C, Self::Selector> {
296 self.into_model().paginate(db, page_size)
297 }
298}
299
#[cfg(test)]
#[cfg(feature = "mock")]
mod tests {
    use super::*;
    use crate::entity::prelude::*;
    use crate::{DatabaseConnection, DbBackend, MockDatabase, Transaction};
    use crate::{Statement, tests_cfg::*};
    use futures_util::TryStreamExt;
    use pretty_assertions::assert_eq;
    use sea_query::{Expr, SelectStatement, Value};
    use std::sync::LazyLock;

    /// Raw Postgres SQL selecting the same columns as `fruit::Entity::find()`.
    static RAW_STMT: LazyLock<Statement> = LazyLock::new(|| {
        Statement::from_sql_and_values(
            DbBackend::Postgres,
            r#"SELECT "fruit"."id", "fruit"."name", "fruit"."cake_id" FROM "fruit""#,
            [],
        )
    });

    /// Mock Postgres connection answering three queries with a full page (two fruits), a
    /// partial page (one fruit) and an empty page, in that order; also returns those pages.
    fn setup() -> (DatabaseConnection, Vec<Vec<fruit::Model>>) {
        let page1 = vec![
            fruit::Model {
                id: 1,
                name: "Blueberry".into(),
                cake_id: Some(1),
            },
            fruit::Model {
                id: 2,
                name: "Raspberry".into(),
                cake_id: Some(1),
            },
        ];

        let page2 = vec![fruit::Model {
            id: 3,
            name: "Strawberry".into(),
            cake_id: Some(2),
        }];

        let page3 = Vec::<fruit::Model>::new();

        let db = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([page1.clone(), page2.clone(), page3.clone()])
            .into_connection();

        (db, vec![page1, page2, page3])
    }

    /// Mock Postgres connection whose single query result has `num_items = 3`.
    fn setup_num_items() -> (DatabaseConnection, i64) {
        let num_items = 3;
        let db = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([[maplit::btreemap! {
                "num_items" => Into::<Value>::into(num_items),
            }]])
            .into_connection();

        (db, num_items)
    }

    /// The unpaginated statement equivalent to `fruit::Entity::find()`.
    fn fruit_select() -> SelectStatement {
        SelectStatement::new()
            .exprs([
                Expr::col((fruit::Entity, fruit::Column::Id)),
                Expr::col((fruit::Entity, fruit::Column::Name)),
                Expr::col((fruit::Entity, fruit::Column::CakeId)),
            ])
            .from(fruit::Entity)
            .to_owned()
    }

    /// Asserts that `db` executed exactly the queries for pages 0, 1 and 2 of
    /// `fruit_select()` with a page size of 2 (offsets 0, 2 and 4).
    #[track_caller]
    fn assert_three_page_queries(db: DatabaseConnection) {
        let backend = db.get_database_backend();
        let stmts: Vec<Statement> = (0..3_u64)
            .map(|page| backend.build(fruit_select().offset(page * 2).limit(2)))
            .collect();
        assert_eq!(db.into_transaction_log(), Transaction::wrap(stmts));
    }

    /// Asserts that `db` executed exactly the `COUNT(*)` query that `num_items` wraps
    /// around `fruit_select()`.
    #[track_caller]
    fn assert_count_query(db: DatabaseConnection) {
        let select = SelectStatement::new()
            .expr(Expr::cust("COUNT(*) AS num_items"))
            .from_subquery(fruit_select(), "sub_query")
            .to_owned();
        let stmt = db.get_database_backend().build(&select);
        assert_eq!(db.into_transaction_log(), Transaction::wrap([stmt]));
    }

    #[smol_potat::test]
    async fn fetch_page() -> Result<(), DbErr> {
        let (db, pages) = setup();

        let paginator = fruit::Entity::find().paginate(&db, 2);

        assert_eq!(paginator.fetch_page(0).await?, pages[0].clone());
        assert_eq!(paginator.fetch_page(1).await?, pages[1].clone());
        assert_eq!(paginator.fetch_page(2).await?, pages[2].clone());

        assert_three_page_queries(db);
        Ok(())
    }

    #[smol_potat::test]
    async fn fetch_page_raw() -> Result<(), DbErr> {
        let (db, pages) = setup();

        let paginator = fruit::Entity::find()
            .from_raw_sql(RAW_STMT.clone())
            .paginate(&db, 2);

        assert_eq!(paginator.fetch_page(0).await?, pages[0].clone());
        assert_eq!(paginator.fetch_page(1).await?, pages[1].clone());
        assert_eq!(paginator.fetch_page(2).await?, pages[2].clone());

        assert_three_page_queries(db);
        Ok(())
    }

    #[smol_potat::test]
    async fn fetch() -> Result<(), DbErr> {
        let (db, pages) = setup();

        let mut paginator = fruit::Entity::find().paginate(&db, 2);

        assert_eq!(paginator.fetch().await?, pages[0].clone());
        paginator.next();

        assert_eq!(paginator.fetch().await?, pages[1].clone());
        paginator.next();

        assert_eq!(paginator.fetch().await?, pages[2].clone());

        assert_three_page_queries(db);
        Ok(())
    }

    #[smol_potat::test]
    async fn fetch_raw() -> Result<(), DbErr> {
        let (db, pages) = setup();

        let mut paginator = fruit::Entity::find()
            .from_raw_sql(RAW_STMT.clone())
            .paginate(&db, 2);

        assert_eq!(paginator.fetch().await?, pages[0].clone());
        paginator.next();

        assert_eq!(paginator.fetch().await?, pages[1].clone());
        paginator.next();

        assert_eq!(paginator.fetch().await?, pages[2].clone());

        assert_three_page_queries(db);
        Ok(())
    }

    #[smol_potat::test]
    async fn num_pages() -> Result<(), DbErr> {
        let (db, num_items) = setup_num_items();

        let num_items = num_items as u64;
        let page_size = 2_u64;
        let num_pages = num_items.div_ceil(page_size);
        let paginator = fruit::Entity::find().paginate(&db, page_size);

        assert_eq!(paginator.num_pages().await?, num_pages);

        assert_count_query(db);
        Ok(())
    }

    #[smol_potat::test]
    async fn num_pages_raw() -> Result<(), DbErr> {
        let (db, num_items) = setup_num_items();

        let num_items = num_items as u64;
        let page_size = 2_u64;
        let num_pages = num_items.div_ceil(page_size);
        let paginator = fruit::Entity::find()
            .from_raw_sql(RAW_STMT.clone())
            .paginate(&db, page_size);

        assert_eq!(paginator.num_pages().await?, num_pages);

        assert_count_query(db);
        Ok(())
    }

    #[smol_potat::test]
    async fn num_items_and_pages() -> Result<(), DbErr> {
        let (db, num_items) = setup_num_items();

        let paginator = fruit::Entity::find().paginate(&db, 2);
        let counts = paginator.num_items_and_pages().await?;

        // 3 items at 2 per page need 2 pages, from a single count query.
        assert_eq!(counts.number_of_items, num_items as u64);
        assert_eq!(counts.number_of_pages, 2);

        assert_count_query(db);
        Ok(())
    }

    #[smol_potat::test]
    async fn next_and_cur_page() -> Result<(), DbErr> {
        let (db, _) = setup();

        let mut paginator = fruit::Entity::find().paginate(&db, 2);

        assert_eq!(paginator.cur_page(), 0);
        paginator.next();

        assert_eq!(paginator.cur_page(), 1);
        paginator.next();

        assert_eq!(paginator.cur_page(), 2);
        Ok(())
    }

    #[smol_potat::test]
    async fn next_and_cur_page_raw() -> Result<(), DbErr> {
        let (db, _) = setup();

        let mut paginator = fruit::Entity::find()
            .from_raw_sql(RAW_STMT.clone())
            .paginate(&db, 2);

        assert_eq!(paginator.cur_page(), 0);
        paginator.next();

        assert_eq!(paginator.cur_page(), 1);
        paginator.next();

        assert_eq!(paginator.cur_page(), 2);
        Ok(())
    }

    #[smol_potat::test]
    async fn fetch_and_next() -> Result<(), DbErr> {
        let (db, pages) = setup();

        let mut paginator = fruit::Entity::find().paginate(&db, 2);

        assert_eq!(paginator.cur_page(), 0);
        assert_eq!(paginator.fetch_and_next().await?, Some(pages[0].clone()));

        assert_eq!(paginator.cur_page(), 1);
        assert_eq!(paginator.fetch_and_next().await?, Some(pages[1].clone()));

        assert_eq!(paginator.cur_page(), 2);
        assert_eq!(paginator.fetch_and_next().await?, None);

        assert_three_page_queries(db);
        Ok(())
    }

    #[smol_potat::test]
    async fn fetch_and_next_raw() -> Result<(), DbErr> {
        let (db, pages) = setup();

        let mut paginator = fruit::Entity::find()
            .from_raw_sql(RAW_STMT.clone())
            .paginate(&db, 2);

        assert_eq!(paginator.cur_page(), 0);
        assert_eq!(paginator.fetch_and_next().await?, Some(pages[0].clone()));

        assert_eq!(paginator.cur_page(), 1);
        assert_eq!(paginator.fetch_and_next().await?, Some(pages[1].clone()));

        assert_eq!(paginator.cur_page(), 2);
        assert_eq!(paginator.fetch_and_next().await?, None);

        assert_three_page_queries(db);
        Ok(())
    }

    #[smol_potat::test]
    async fn into_stream() -> Result<(), DbErr> {
        let (db, pages) = setup();

        let mut fruit_stream = fruit::Entity::find().paginate(&db, 2).into_stream();

        assert_eq!(fruit_stream.try_next().await?, Some(pages[0].clone()));
        assert_eq!(fruit_stream.try_next().await?, Some(pages[1].clone()));
        assert_eq!(fruit_stream.try_next().await?, None);

        // The boxed stream borrows `db`; release it before `db` is consumed.
        drop(fruit_stream);

        assert_three_page_queries(db);
        Ok(())
    }

    #[smol_potat::test]
    async fn into_stream_raw() -> Result<(), DbErr> {
        let (db, pages) = setup();

        let mut fruit_stream = fruit::Entity::find()
            .from_raw_sql(RAW_STMT.clone())
            .paginate(&db, 2)
            .into_stream();

        assert_eq!(fruit_stream.try_next().await?, Some(pages[0].clone()));
        assert_eq!(fruit_stream.try_next().await?, Some(pages[1].clone()));
        assert_eq!(fruit_stream.try_next().await?, None);

        drop(fruit_stream);

        assert_three_page_queries(db);
        Ok(())
    }

    #[smol_potat::test]
    async fn into_stream_raw_leading_spaces() -> Result<(), DbErr> {
        let (db, pages) = setup();

        let raw_stmt = Statement::from_sql_and_values(
            DbBackend::Postgres,
            r#"  SELECT "fruit"."id", "fruit"."name", "fruit"."cake_id" FROM "fruit"  "#,
            [],
        );

        let mut fruit_stream = fruit::Entity::find()
            .from_raw_sql(raw_stmt.clone())
            .paginate(&db, 2)
            .into_stream();

        assert_eq!(fruit_stream.try_next().await?, Some(pages[0].clone()));
        assert_eq!(fruit_stream.try_next().await?, Some(pages[1].clone()));
        assert_eq!(fruit_stream.try_next().await?, None);

        drop(fruit_stream);

        assert_three_page_queries(db);
        Ok(())
    }

    #[smol_potat::test]
    #[should_panic]
    async fn error() {
        let (db, _pages) = setup();

        fruit::Entity::find().paginate(&db, 0);
    }
}