serwus/
pagination.rs

1use diesel::{
2    prelude::*, query_builder::*, query_dsl::methods::LoadQuery, sql_types::BigInt,
3};
4
5use crate::containers::ListResponse;
6use crate::db_pool::{Db, DbConnection};
7
8
9
10pub trait Paginate: Sized {
11    fn paginate(self, page: i64) -> Paginated<Self>;
12}
13
14impl<T> Paginate for T {
15    fn paginate(self, page: i64) -> Paginated<Self> {
16        Paginated {
17            query: self,
18            page,
19            per_page: DEFAULT_PER_PAGE,
20            offset: (page - 1) * DEFAULT_PER_PAGE,
21        }
22    }
23}
24
/// Page size applied by [`Paginate::paginate`] until it is overridden with
/// [`Paginated::per_page`].
const DEFAULT_PER_PAGE: i64 = 10;
26
/// A query wrapped with `LIMIT`/`OFFSET` pagination state.
///
/// Created by [`Paginate::paginate`] and adjusted with
/// [`Paginated::per_page`]. Both compute `offset` from `page` and
/// `per_page`, so the three fields stay consistent.
#[derive(Debug, Clone, Copy, QueryId)]
pub struct Paginated<T> {
    /// The inner query being paginated.
    query: T,
    /// Requested page number (1-based: page 1 has offset 0).
    page: i64,
    /// Maximum number of rows per page; bound as the `LIMIT` parameter.
    per_page: i64,
    /// Number of rows to skip, `(page - 1) * per_page`; bound as `OFFSET`.
    offset: i64,
}
34
35impl<T> Paginated<T> {
36    #[must_use]
37    pub fn per_page(self, per_page: i64) -> Self {
38        Paginated {
39            per_page,
40            offset: (self.page - 1) * per_page,
41            ..self
42        }
43    }
44
45    pub fn load_and_count_pages<'a, U>(
46        self,
47        conn: &mut DbConnection,
48    ) -> QueryResult<ListResponse<U>>
49    where
50        Self: LoadQuery<'a, DbConnection, (U, i64)>,
51    {
52        let per_page = self.per_page;
53        let page = self.page;
54
55        let results = self.load::<(U, i64)>(conn)?;
56        let total = results.first().map(|x| x.1).unwrap_or(0);
57        let records = results.into_iter().map(|x| x.0).collect();
58        let total_pages = (total as f64 / per_page as f64).ceil() as i64;
59
60        let next_page = (page < total_pages).then_some(page + 1);
61
62        Ok(ListResponse {
63            total,
64            total_pages,
65            next_page,
66            data: records,
67        })
68    }
69}
70
71impl<T: Query> Query for Paginated<T> {
72    type SqlType = (T::SqlType, BigInt);
73}
74
75impl<T> RunQueryDsl<DbConnection> for Paginated<T> {}
76
/// Renders the wrapped query as a derived table, with a window count and
/// bound `LIMIT`/`OFFSET` parameters appended:
///
/// ```sql
/// SELECT *, COUNT(*) OVER () FROM (<inner query>) t LIMIT <per_page> OFFSET <offset>
/// ```
impl<T> QueryFragment<Db> for Paginated<T>
where
    T: QueryFragment<Db>,
{
    fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Db>) -> QueryResult<()> {
        // The window count is computed over the whole inner result set, so
        // each returned row carries the total row count before LIMIT/OFFSET.
        out.push_sql("SELECT *, COUNT(*) OVER () FROM (");
        self.query.walk_ast(out.reborrow())?;
        // `t` aliases the subquery; a derived table needs an alias here.
        out.push_sql(") t LIMIT ");
        // Page size and offset go in as bind parameters, not inlined SQL.
        out.push_bind_param::<BigInt, _>(&self.per_page)?;
        out.push_sql(" OFFSET ");
        out.push_bind_param::<BigInt, _>(&self.offset)?;
        Ok(())
    }
}