// spring_sea_orm/pagination.rs

1use schemars::JsonSchema;
2use sea_orm::{
3    ConnectionTrait, EntityTrait, FromQueryResult, PaginatorTrait, Select, Selector, SelectorTrait,
4};
5use serde::{Deserialize, Serialize};
6use spring::async_trait;
7use thiserror::Error;
8
9/// pagination information.
10#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
11pub struct Pagination {
12    #[serde(default = "default_page")]
13    pub page: u64,
14    #[serde(default = "default_size")]
15    pub size: u64,
16}
17
/// Serde fallback for `Pagination::page` when the field is absent: the first page.
fn default_page() -> u64 {
    u64::MIN
}
/// Serde fallback for `Pagination::size` when the field is absent.
fn default_size() -> u64 {
    const DEFAULT_PAGE_SIZE: u64 = 20;
    DEFAULT_PAGE_SIZE
}
24
25impl Pagination {
26    pub fn empty_page<T>(&self) -> Page<T> {
27        Page::new(vec![], self, 0)
28    }
29}
30
#[cfg(feature = "with-web")]
mod web {
    use super::Pagination;
    use crate::config::SeaOrmWebConfig;
    use schemars::JsonSchema;
    use serde::Deserialize;
    use spring_web::axum::extract::rejection::QueryRejection;
    use spring_web::axum::extract::{FromRequestParts, Query};
    use spring_web::axum::http::request::Parts;
    use spring_web::axum::response::IntoResponse;
    use spring_web::extractor::RequestPartsExt;
    use std::result::Result as StdResult;
    use thiserror::Error;

    /// Rejection returned when a [`Pagination`] cannot be extracted from a request.
    #[derive(Debug, Error)]
    pub enum SeaOrmWebErr {
        /// The query string could not be deserialized into `page`/`size`.
        #[error(transparent)]
        QueryRejection(#[from] QueryRejection),

        /// Looking up the `SeaOrmWebConfig` for the request failed.
        #[error(transparent)]
        WebError(#[from] spring_web::error::WebError),
    }

    impl IntoResponse for SeaOrmWebErr {
        fn into_response(self) -> spring_web::axum::response::Response {
            // Each wrapped error already knows how to render itself.
            match self {
                Self::QueryRejection(e) => e.into_response(),
                Self::WebError(e) => e.into_response(),
            }
        }
    }

    /// Raw query-string parameters; both are optional so that the configured
    /// defaults can be applied afterwards.
    #[derive(Debug, Clone, Deserialize, JsonSchema)]
    struct OptionalPagination {
        // Page number as sent by the client (one-based when `one_indexed` is set).
        page: Option<u64>,
        // Requested page size; capped at `max_page_size`.
        size: Option<u64>,
    }

    /// Extracts [`Pagination`] from the request query string, applying the
    /// limits and defaults from `SeaOrmWebConfig`.
    impl<S> FromRequestParts<S> for Pagination
    where
        S: Sync,
    {
        type Rejection = SeaOrmWebErr;

        async fn from_request_parts(
            parts: &mut Parts,
            _state: &S,
        ) -> StdResult<Self, Self::Rejection> {
            let Query(query) = Query::<OptionalPagination>::try_from_uri(&parts.uri)?;
            let config = parts.get_config::<SeaOrmWebConfig>()?;

            // An explicit size is capped at the configured maximum; a missing one
            // falls back to the configured default.
            let size = query
                .size
                .map_or(config.default_page_size, |size| size.min(config.max_page_size));

            // Pages are zero-based internally. In one-indexed mode the client's
            // page is shifted down by one, with 0 treated the same as 1.
            let page = query.page.unwrap_or(0);
            let page = if config.one_indexed {
                page.saturating_sub(1)
            } else {
                page
            };

            Ok(Pagination { page, size })
        }
    }

    #[cfg(feature = "with-web-openapi")]
    impl spring_web::aide::OperationInput for Pagination {
        fn operation_input(
            ctx: &mut spring_web::aide::generate::GenContext,
            operation: &mut spring_web::aide::openapi::Operation,
        ) {
            <Query<OptionalPagination> as spring_web::aide::OperationInput>::operation_input(
                ctx, operation,
            );
        }

        fn inferred_early_responses(
            ctx: &mut spring_web::aide::generate::GenContext,
            operation: &mut spring_web::aide::openapi::Operation,
        ) -> Vec<(
            Option<spring_web::aide::openapi::StatusCode>,
            spring_web::aide::openapi::Response,
        )> {
            <Query<OptionalPagination> as spring_web::aide::OperationInput>::inferred_early_responses(ctx, operation)
        }
    }
}
129
130/// A page is a sublist of a list of objects.
131/// It allows gain information about the position of it in the containing entire list.
132
133#[derive(Debug, Serialize, JsonSchema)]
134pub struct Page<T> {
135    pub content: Vec<T>,
136    pub size: u64,
137    pub page: u64,
138    /// the total amount of elements.
139    pub total_elements: u64,
140    /// the number of total pages.
141    pub total_pages: u64,
142}
143
144impl<T> Page<T> {
145    pub fn new(content: Vec<T>, pagination: &Pagination, total: u64) -> Self {
146        Self {
147            content,
148            size: pagination.size,
149            page: pagination.page,
150            total_elements: total,
151            total_pages: Self::total_pages(total, pagination.size),
152        }
153    }
154
155    /// Compute the number of pages for the current page
156    fn total_pages(total: u64, size: u64) -> u64 {
157        if size == 0 {
158            return 0;
159        }
160        (total / size) + u64::from(!total.is_multiple_of(size))
161    }
162
163    /// iterator for content
164    pub fn iter(&self) -> std::slice::Iter<'_, T> {
165        self.content.iter()
166    }
167
168    /// Returns a new Page with the content of the current one mapped by the given Function
169    pub fn map<F, R>(self, func: F) -> Page<R>
170    where
171        F: FnMut(T) -> R,
172    {
173        let Page {
174            content,
175            size,
176            page,
177            total_elements,
178            total_pages,
179        } = self;
180        let content = content.into_iter().map(func).collect();
181        Page {
182            content,
183            size,
184            page,
185            total_elements,
186            total_pages,
187        }
188    }
189
190    #[inline]
191    pub fn is_empty(&self) -> bool {
192        self.content.is_empty()
193    }
194
195    #[inline]
196    pub fn is_first(&self) -> bool {
197        self.page == 0
198    }
199
200    #[inline]
201    pub fn is_last(&self) -> bool {
202        self.page + 1 >= self.total_pages
203    }
204}
205
206#[derive(Debug, Error)]
207pub enum OrmError {
208    #[error(transparent)]
209    DbErr(#[from] sea_orm::DbErr),
210}
211
212pub type PageResult<T> = std::result::Result<Page<T>, OrmError>;
213
214#[async_trait]
215/// A Trait for any type that can paginate results
216pub trait PaginationExt<'db, C, M>
217where
218    C: ConnectionTrait,
219{
220    /// pagination
221    async fn page(self, db: &'db C, pagination: &Pagination) -> PageResult<M>;
222}
223
224#[async_trait]
225impl<'db, C, M, E> PaginationExt<'db, C, M> for Select<E>
226where
227    C: ConnectionTrait,
228    E: EntityTrait<Model = M>,
229    M: FromQueryResult + Sized + Send + Sync + 'db,
230{
231    async fn page(self, db: &'db C, pagination: &Pagination) -> PageResult<M> {
232        let paginator = self.paginate(db, pagination.size);
233        let total = paginator.num_items().await?;
234        let content = paginator.fetch_page(pagination.page).await?;
235        Ok(Page::new(content, pagination, total))
236    }
237}
238
239#[async_trait]
240impl<'db, C, S> PaginationExt<'db, C, S::Item> for Selector<S>
241where
242    C: ConnectionTrait,
243    S: SelectorTrait + Send + Sync + 'db,
244{
245    async fn page(self, db: &'db C, pagination: &Pagination) -> PageResult<S::Item> {
246        let paginator = self.paginate(db, pagination.size);
247        let total = paginator.num_items().await?;
248        let content = paginator.fetch_page(pagination.page).await?;
249        Ok(Page::new(content, pagination, total))
250    }
251}