1#![doc = include_str!("../README.md")]
2
3use std::ops::Deref;
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use axum_core::{
8 extract::FromRequestParts,
9 response::{IntoResponse, Response},
10};
11use http::{request::Parts, StatusCode};
12use serde::de::DeserializeOwned;
13use serde_querystring::de::Error;
14
15pub use serde_querystring::de::ParseMode;
16
/// Extractor that deserializes the request's query string into `T` using
/// `serde_querystring`.
///
/// The parse mode and the rejection response can be customised by placing a
/// [`QueryStringConfig`] in the request extensions; when none is present,
/// [`QueryStringConfig::default`] is used. The deserialized value is available
/// as field `.0` or through `Deref`.
#[derive(Clone, Copy, Debug, Default)]
pub struct QueryString<T>(pub T);
72
73#[async_trait]
74impl<T, S> FromRequestParts<S> for QueryString<T>
75where
76 T: DeserializeOwned,
77 S: Send + Sync,
78{
79 type Rejection = Response;
80
81 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
82 let QueryStringConfig { mode, ehandler } = parts
83 .extensions
84 .get::<QueryStringConfig>()
85 .cloned()
86 .unwrap_or_default();
87
88 let query = parts.uri.query().unwrap_or_default();
89 let value = serde_querystring::from_str(query, mode).map_err(|e| {
90 if let Some(ehandler) = ehandler {
91 ehandler(e)
92 } else {
93 QueryStringError::default().into_response()
94 }
95 })?;
96 Ok(QueryString(value))
97 }
98}
99
100impl<T> Deref for QueryString<T> {
101 type Target = T;
102
103 fn deref(&self) -> &Self::Target {
104 &self.0
105 }
106}
107
108#[derive(Clone)]
126pub struct QueryStringConfig {
127 mode: ParseMode,
128 ehandler: Option<Arc<dyn Fn(Error) -> Response + Send + Sync>>,
129}
130
131impl Default for QueryStringConfig {
132 fn default() -> Self {
133 Self {
134 mode: ParseMode::Duplicate,
135 ehandler: None,
136 }
137 }
138}
139
140impl QueryStringConfig {
141 pub fn new(mode: ParseMode) -> Self {
142 Self {
143 mode,
144 ehandler: None,
145 }
146 }
147
148 pub fn mode(mut self, mode: ParseMode) -> Self {
149 self.mode = mode;
150 self
151 }
152
153 pub fn ehandler<F, R>(mut self, ehandler: F) -> Self
154 where
155 F: Fn(Error) -> R + Send + Sync + 'static,
156 R: IntoResponse,
157 {
158 self.ehandler = Some(Arc::new(move |e| ehandler(e).into_response()));
159 self
160 }
161}
162
163#[derive(Debug)]
164struct QueryStringError {
165 status: StatusCode,
166 body: String,
167}
168
169impl Default for QueryStringError {
170 fn default() -> Self {
171 Self {
172 status: StatusCode::BAD_REQUEST,
173 body: String::from("Failed to deserialize query string"),
174 }
175 }
176}
177
178impl IntoResponse for QueryStringError {
179 fn into_response(self) -> Response {
180 (self.status, self.body).into_response()
181 }
182}
183
#[cfg(test)]
mod tests {
    use std::fmt::Debug;

    use axum::{
        body::{Body, HttpBody},
        extract::FromRequest,
        routing::get,
        Extension, Router,
    };
    use http::{Request, StatusCode};
    use serde::Deserialize;
    use tower::ServiceExt;

    use super::*;

    /// Runs the extractor on a request for `uri` with no `QueryStringConfig`
    /// installed (so the default config applies) and asserts that it yields
    /// `value`.
    async fn check<T>(uri: impl AsRef<str>, value: T)
    where
        T: DeserializeOwned + PartialEq + Debug,
    {
        let req = Request::builder().uri(uri.as_ref()).body(()).unwrap();
        assert_eq!(
            QueryString::<T>::from_request(req, &()).await.unwrap().0,
            value
        );
    }

    /// Default config: absent keys map to `None`, and repeated keys collect
    /// into a `Vec`.
    #[tokio::test]
    async fn test_query() {
        #[derive(Debug, PartialEq, Deserialize)]
        struct Pagination {
            size: Option<u64>,
            pages: Option<Vec<u64>>,
        }

        check(
            "http://example.com/test",
            Pagination {
                size: None,
                pages: None,
            },
        )
        .await;

        check(
            "http://example.com/test?size=10",
            Pagination {
                size: Some(10),
                pages: None,
            },
        )
        .await;

        check(
            "http://example.com/test?size=10&pages=20",
            Pagination {
                size: Some(10),
                pages: Some(vec![20]),
            },
        )
        .await;

        check(
            "http://example.com/test?size=10&pages=20&pages=21&pages=22",
            Pagination {
                size: Some(10),
                pages: Some(vec![20, 21, 22]),
            },
        )
        .await;
    }

    /// A `QueryStringConfig` extension switches the parse mode. In `Brackets`
    /// mode the indexed values come back ordered by their index.
    #[tokio::test]
    async fn test_config_mode() {
        #[derive(Deserialize)]
        struct Params {
            n: Vec<i32>,
        }

        async fn handler(q: QueryString<Params>) -> String {
            format!("{}-{}", q.n.first().unwrap(), q.n.get(2).unwrap())
        }

        let app = Router::new()
            .route("/", get(handler))
            .layer(Extension(QueryStringConfig::new(ParseMode::Brackets)));
        let res = app
            .oneshot(
                Request::builder()
                    .uri("/?n[3]=300&n[2]=200&n[1]=100")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        let (parts, mut body) = res.into_parts();

        assert_eq!(parts.status, StatusCode::OK);
        assert_eq!(body.data().await.unwrap().unwrap(), "100-300");
    }

    /// Without a custom handler, a deserialization failure yields the
    /// built-in `400 Bad Request` rejection.
    #[tokio::test]
    async fn correct_rejection_default() {
        #[derive(Deserialize)]
        // `n` only exists so deserialization of a non-integer can fail; it is
        // never read.
        #[allow(dead_code)]
        struct Params {
            n: i32,
        }

        async fn handler(_: QueryString<Params>) {}

        let app = Router::new().route("/", get(handler));
        let res = app
            .oneshot(
                Request::builder()
                    .uri("/?n=string")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        let (parts, mut body) = res.into_parts();

        assert_eq!(parts.status, StatusCode::BAD_REQUEST);
        assert_eq!(
            body.data().await.unwrap().unwrap(),
            "Failed to deserialize query string"
        );
    }

    /// A custom handler from the config replaces the built-in rejection.
    #[tokio::test]
    async fn correct_rejection_custom() {
        #[derive(Deserialize)]
        // `n` only exists so deserialization of a non-integer can fail; it is
        // never read.
        #[allow(dead_code)]
        struct Params {
            n: i32,
        }

        async fn handler(_: QueryString<Params>) {}

        let app = Router::new().route("/", get(handler)).layer(Extension(
            QueryStringConfig::default().ehandler(|_err| {
                (
                    StatusCode::BAD_GATEWAY,
                    String::from("Something went wrong"),
                )
            }),
        ));

        let res = app
            .oneshot(
                Request::builder()
                    .uri("/?n=string")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        let (parts, mut body) = res.into_parts();

        assert_eq!(parts.status, StatusCode::BAD_GATEWAY);
        assert_eq!(body.data().await.unwrap().unwrap(), "Something went wrong");
    }
}
351}