1use schemars::JsonSchema;
2use sea_orm::{
3 ConnectionTrait, EntityTrait, FromQueryResult, PaginatorTrait, Select, Selector, SelectorTrait,
4};
5use serde::{Deserialize, Serialize};
6use spring::async_trait;
7use thiserror::Error;
8
9#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
11pub struct Pagination {
12 #[serde(default = "default_page")]
13 pub page: u64,
14 #[serde(default = "default_size")]
15 pub size: u64,
16}
17
/// Serde fallback for [`Pagination::page`]: the first page.
fn default_page() -> u64 {
    u64::MIN
}

/// Serde fallback for [`Pagination::size`]: twenty rows per page.
fn default_size() -> u64 {
    const DEFAULT_PAGE_SIZE: u64 = 20;
    DEFAULT_PAGE_SIZE
}
24
25impl Pagination {
26 pub fn empty_page<T>(&self) -> Page<T> {
27 Page::new(vec![], self, 0)
28 }
29}
30
#[cfg(feature = "with-web")]
mod web {
    use super::Pagination;
    use crate::config::SeaOrmWebConfig;
    use schemars::JsonSchema;
    use serde::Deserialize;
    use spring_web::axum::extract::rejection::QueryRejection;
    use spring_web::axum::extract::{FromRequestParts, Query};
    use spring_web::axum::http::request::Parts;
    use spring_web::axum::response::IntoResponse;
    use spring_web::extractor::RequestPartsExt;
    use std::result::Result as StdResult;
    use thiserror::Error;

    /// Rejection produced when a [`Pagination`] cannot be extracted from a request.
    #[derive(Debug, Error)]
    pub enum SeaOrmWebErr {
        /// The query string could not be deserialized into `page`/`size` parameters.
        #[error(transparent)]
        QueryRejection(#[from] QueryRejection),

        /// Error from `spring_web`; in this module it comes from looking up
        /// [`SeaOrmWebConfig`] on the request.
        #[error(transparent)]
        WebError(#[from] spring_web::error::WebError),
    }

    impl IntoResponse for SeaOrmWebErr {
        /// Delegates to the wrapped error's own response conversion.
        fn into_response(self) -> spring_web::axum::response::Response {
            match self {
                Self::QueryRejection(e) => e.into_response(),
                Self::WebError(e) => e.into_response(),
            }
        }
    }

    /// Raw query parameters. Both are optional so that missing values are
    /// filled from [`SeaOrmWebConfig`] rather than from serde defaults.
    #[derive(Debug, Clone, Deserialize, JsonSchema)]
    struct OptionalPagination {
        /// Page number as sent by the client (one-based when `one_indexed` is set).
        page: Option<u64>,
        /// Requested page size before clamping.
        size: Option<u64>,
    }

    impl<S> FromRequestParts<S> for Pagination
    where
        S: Sync,
    {
        type Rejection = SeaOrmWebErr;

        /// Builds a [`Pagination`] from the `page` and `size` query parameters.
        ///
        /// * `size` defaults to `default_page_size` and is capped at `max_page_size`.
        /// * `page` defaults to the first page. When `one_indexed` is configured
        ///   the client value is shifted down by one (with `0` treated like `1`),
        ///   so the resulting [`Pagination::page`] is always zero-based.
        ///
        /// # Errors
        /// Returns [`SeaOrmWebErr::QueryRejection`] for a malformed query string
        /// and [`SeaOrmWebErr::WebError`] if the config lookup fails.
        async fn from_request_parts(
            parts: &mut Parts,
            _state: &S,
        ) -> StdResult<Self, Self::Rejection> {
            let Query(params) = Query::<OptionalPagination>::try_from_uri(&parts.uri)?;
            let config = parts.get_config::<SeaOrmWebConfig>()?;

            let size = match params.size {
                Some(size) => size.min(config.max_page_size),
                None => config.default_page_size,
            };

            let page = match params.page {
                // `saturating_sub` maps a client-supplied `0` onto the first page.
                Some(page) if config.one_indexed => page.saturating_sub(1),
                Some(page) => page,
                None => 0,
            };

            Ok(Pagination { page, size })
        }
    }

    #[cfg(feature = "with-web-openapi")]
    impl spring_web::aide::OperationInput for Pagination {
        /// Documents the same query parameters the extractor actually reads.
        fn operation_input(
            ctx: &mut spring_web::aide::generate::GenContext,
            operation: &mut spring_web::aide::openapi::Operation,
        ) {
            <Query<OptionalPagination> as spring_web::aide::OperationInput>::operation_input(
                ctx, operation,
            );
        }

        fn inferred_early_responses(
            ctx: &mut spring_web::aide::generate::GenContext,
            operation: &mut spring_web::aide::openapi::Operation,
        ) -> Vec<(
            Option<spring_web::aide::openapi::StatusCode>,
            spring_web::aide::openapi::Response,
        )> {
            <Query<OptionalPagination> as spring_web::aide::OperationInput>::inferred_early_responses(ctx, operation)
        }
    }
}
129
130#[derive(Debug, Serialize, JsonSchema)]
134pub struct Page<T> {
135 pub content: Vec<T>,
136 pub size: u64,
137 pub page: u64,
138 pub total_elements: u64,
140 pub total_pages: u64,
142}
143
144impl<T> Page<T> {
145 pub fn new(content: Vec<T>, pagination: &Pagination, total: u64) -> Self {
146 Self {
147 content,
148 size: pagination.size,
149 page: pagination.page,
150 total_elements: total,
151 total_pages: Self::total_pages(total, pagination.size),
152 }
153 }
154
155 fn total_pages(total: u64, size: u64) -> u64 {
157 if size == 0 {
158 return 0;
159 }
160 (total / size) + u64::from(!total.is_multiple_of(size))
161 }
162
163 pub fn iter(&self) -> std::slice::Iter<'_, T> {
165 self.content.iter()
166 }
167
168 pub fn map<F, R>(self, func: F) -> Page<R>
170 where
171 F: FnMut(T) -> R,
172 {
173 let Page {
174 content,
175 size,
176 page,
177 total_elements,
178 total_pages,
179 } = self;
180 let content = content.into_iter().map(func).collect();
181 Page {
182 content,
183 size,
184 page,
185 total_elements,
186 total_pages,
187 }
188 }
189
190 #[inline]
191 pub fn is_empty(&self) -> bool {
192 self.content.is_empty()
193 }
194
195 #[inline]
196 pub fn is_first(&self) -> bool {
197 self.page == 0
198 }
199
200 #[inline]
201 pub fn is_last(&self) -> bool {
202 self.page + 1 >= self.total_pages
203 }
204}
205
206#[derive(Debug, Error)]
207pub enum OrmError {
208 #[error(transparent)]
209 DbErr(#[from] sea_orm::DbErr),
210}
211
212pub type PageResult<T> = std::result::Result<Page<T>, OrmError>;
213
214#[async_trait]
215pub trait PaginationExt<'db, C, M>
217where
218 C: ConnectionTrait,
219{
220 async fn page(self, db: &'db C, pagination: &Pagination) -> PageResult<M>;
222}
223
224#[async_trait]
225impl<'db, C, M, E> PaginationExt<'db, C, M> for Select<E>
226where
227 C: ConnectionTrait,
228 E: EntityTrait<Model = M>,
229 M: FromQueryResult + Sized + Send + Sync + 'db,
230{
231 async fn page(self, db: &'db C, pagination: &Pagination) -> PageResult<M> {
232 let paginator = self.paginate(db, pagination.size);
233 let total = paginator.num_items().await?;
234 let content = paginator.fetch_page(pagination.page).await?;
235 Ok(Page::new(content, pagination, total))
236 }
237}
238
239#[async_trait]
240impl<'db, C, S> PaginationExt<'db, C, S::Item> for Selector<S>
241where
242 C: ConnectionTrait,
243 S: SelectorTrait + Send + Sync + 'db,
244{
245 async fn page(self, db: &'db C, pagination: &Pagination) -> PageResult<S::Item> {
246 let paginator = self.paginate(db, pagination.size);
247 let total = paginator.num_items().await?;
248 let content = paginator.fetch_page(pagination.page).await?;
249 Ok(Page::new(content, pagination, total))
250 }
251}