serde_querystring_axum/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use std::ops::Deref;
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use axum_core::{
8    extract::FromRequestParts,
9    response::{IntoResponse, Response},
10};
11use http::{request::Parts, StatusCode};
12use serde::de::DeserializeOwned;
13use serde_querystring::de::Error;
14
15pub use serde_querystring::de::ParseMode;
16
/// Axum's `Query` extractor, modified to use serde-querystring.
///
/// `T` is expected to implement [`serde::Deserialize`]. A request without a query string
/// is deserialized from an empty query.
///
/// # Example
///
/// ```rust,no_run
/// use axum::{
///     routing::get,
///     Router,
/// };
/// use serde::Deserialize;
/// use serde_querystring_axum::QueryString;
///
/// #[derive(Deserialize)]
/// struct Pagination {
///     page: usize,
///     per_page: usize,
/// }
///
/// // This will parse query strings like `?page=2&per_page=30` into `Pagination`
/// // structs.
/// async fn list_things(pagination: QueryString<Pagination>) {
///     let pagination: Pagination = pagination.0;
///
///     // ...
/// }
///
/// let app = Router::new().route("/list_things", get(list_things));
/// # async {
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
///
/// By default the query string is parsed in [`ParseMode::Duplicate`] mode. If it cannot be
/// deserialized, the request is rejected with a `400 Bad Request` response whose body is
/// `Failed to deserialize query string`.
///
/// To change the parsing mode or the rejection response, add a [`QueryStringConfig`] to
/// the request extensions, for example with an `Extension` layer:
///
/// ```rust,no_run
/// use axum::{Router, Extension, http::StatusCode};
/// use serde_querystring_axum::{ParseMode, QueryStringConfig};
///
/// let app = Router::new().layer(Extension(
///     QueryStringConfig::new(ParseMode::Brackets).ehandler(|err| {
///         (StatusCode::BAD_REQUEST, err.to_string()) // return type should impl IntoResponse
///     }),
/// ));
/// # async {
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
///
#[derive(Debug, Clone, Copy, Default)]
pub struct QueryString<T>(pub T);
72
73#[async_trait]
74impl<T, S> FromRequestParts<S> for QueryString<T>
75where
76    T: DeserializeOwned,
77    S: Send + Sync,
78{
79    type Rejection = Response;
80
81    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
82        let QueryStringConfig { mode, ehandler } = parts
83            .extensions
84            .get::<QueryStringConfig>()
85            .cloned()
86            .unwrap_or_default();
87
88        let query = parts.uri.query().unwrap_or_default();
89        let value = serde_querystring::from_str(query, mode).map_err(|e| {
90            if let Some(ehandler) = ehandler {
91                ehandler(e)
92            } else {
93                QueryStringError::default().into_response()
94            }
95        })?;
96        Ok(QueryString(value))
97    }
98}
99
100impl<T> Deref for QueryString<T> {
101    type Target = T;
102
103    fn deref(&self) -> &Self::Target {
104        &self.0
105    }
106}
107
108/// QueryString extractor configuration
109///
110/// ```rust,no_run
111/// use axum::{Router, Extension, http::StatusCode};
112/// use serde_querystring_axum::{ParseMode, QueryStringConfig};
113///
114/// let app = Router::new().layer(Extension(
115///     QueryStringConfig::new(ParseMode::Brackets)
116///     .ehandler(|err| {
117///         (StatusCode::BAD_REQUEST, err.to_string()) // return type should impl IntoResponse
118///     }),
119/// ));
120/// # async {
121/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
122/// # };
123/// ```
124///
125#[derive(Clone)]
126pub struct QueryStringConfig {
127    mode: ParseMode,
128    ehandler: Option<Arc<dyn Fn(Error) -> Response + Send + Sync>>,
129}
130
131impl Default for QueryStringConfig {
132    fn default() -> Self {
133        Self {
134            mode: ParseMode::Duplicate,
135            ehandler: None,
136        }
137    }
138}
139
140impl QueryStringConfig {
141    pub fn new(mode: ParseMode) -> Self {
142        Self {
143            mode,
144            ehandler: None,
145        }
146    }
147
148    pub fn mode(mut self, mode: ParseMode) -> Self {
149        self.mode = mode;
150        self
151    }
152
153    pub fn ehandler<F, R>(mut self, ehandler: F) -> Self
154    where
155        F: Fn(Error) -> R + Send + Sync + 'static,
156        R: IntoResponse,
157    {
158        self.ehandler = Some(Arc::new(move |e| ehandler(e).into_response()));
159        self
160    }
161}
162
163#[derive(Debug)]
164struct QueryStringError {
165    status: StatusCode,
166    body: String,
167}
168
169impl Default for QueryStringError {
170    fn default() -> Self {
171        Self {
172            status: StatusCode::BAD_REQUEST,
173            body: String::from("Failed to deserialize query string"),
174        }
175    }
176}
177
178impl IntoResponse for QueryStringError {
179    fn into_response(self) -> Response {
180        (self.status, self.body).into_response()
181    }
182}
183
#[cfg(test)]
mod tests {
    use std::fmt::Debug;

    use axum::{
        body::{Body, HttpBody},
        extract::FromRequest,
        routing::get,
        Extension, Router,
    };
    use http::{Request, StatusCode};
    use serde::Deserialize;
    use tower::ServiceExt;

    use super::*;

    /// Extracts `QueryString<T>` from a bodiless request to `uri` and asserts that the
    /// deserialized value equals `value`.
    async fn check<T>(uri: impl AsRef<str>, value: T)
    where
        T: DeserializeOwned + PartialEq + Debug,
    {
        let req = Request::builder().uri(uri.as_ref()).body(()).unwrap();
        assert_eq!(
            QueryString::<T>::from_request(req, &()).await.unwrap().0,
            value
        );
    }

    /// Without a `QueryStringConfig` extension the default mode is used. Missing keys map
    /// to `None`, and repeated keys collect into a sequence.
    #[tokio::test]
    async fn test_query() {
        #[derive(Debug, PartialEq, Deserialize)]
        struct Pagination {
            size: Option<u64>,
            pages: Option<Vec<u64>>,
        }

        check(
            "http://example.com/test",
            Pagination {
                size: None,
                pages: None,
            },
        )
        .await;

        check(
            "http://example.com/test?size=10",
            Pagination {
                size: Some(10),
                pages: None,
            },
        )
        .await;

        check(
            "http://example.com/test?size=10&pages=20",
            Pagination {
                size: Some(10),
                pages: Some(vec![20]),
            },
        )
        .await;

        check(
            "http://example.com/test?size=10&pages=20&pages=21&pages=22",
            Pagination {
                size: Some(10),
                pages: Some(vec![20, 21, 22]),
            },
        )
        .await;
    }

    /// A `QueryStringConfig` extension switches the parse mode. In `Brackets` mode the
    /// elements are expected in bracket-index order, not in query-string order.
    #[tokio::test]
    async fn test_config_mode() {
        #[derive(Deserialize)]
        struct Params {
            n: Vec<i32>,
        }

        async fn handler(q: QueryString<Params>) -> String {
            format!("{}-{}", q.n.first().unwrap(), q.n.get(2).unwrap())
        }

        let app = Router::new()
            .route("/", get(handler))
            .layer(Extension(QueryStringConfig::new(ParseMode::Brackets)));
        let res = app
            .oneshot(
                Request::builder()
                    .uri("/?n[3]=300&n[2]=200&n[1]=100")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        let (parts, mut body) = res.into_parts();

        assert_eq!(parts.status, StatusCode::OK);
        assert_eq!(body.data().await.unwrap().unwrap(), "100-300");
    }

    /// Without a custom handler, a failed deserialization yields the built-in rejection.
    #[tokio::test]
    async fn correct_rejection_default() {
        #[derive(Deserialize)]
        // `n` only needs to fail deserialization; it is never read.
        #[allow(dead_code)]
        struct Params {
            n: i32,
        }

        async fn handler(_: QueryString<Params>) {}

        let app = Router::new().route("/", get(handler));
        let res = app
            .oneshot(
                Request::builder()
                    .uri("/?n=string")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        let (parts, mut body) = res.into_parts();

        assert_eq!(parts.status, StatusCode::BAD_REQUEST);
        assert_eq!(
            body.data().await.unwrap().unwrap(),
            "Failed to deserialize query string"
        );
    }

    /// A handler set via `QueryStringConfig::ehandler` replaces the built-in rejection.
    #[tokio::test]
    async fn correct_rejection_custom() {
        #[derive(Deserialize)]
        // `n` only needs to fail deserialization; it is never read.
        #[allow(dead_code)]
        struct Params {
            n: i32,
        }

        async fn handler(_: QueryString<Params>) {}

        let app = Router::new().route("/", get(handler)).layer(Extension(
            QueryStringConfig::default().ehandler(|_err| {
                (
                    StatusCode::BAD_GATEWAY,
                    String::from("Something went wrong"),
                )
            }),
        ));

        let res = app
            .oneshot(
                Request::builder()
                    .uri("/?n=string")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        let (parts, mut body) = res.into_parts();

        assert_eq!(parts.status, StatusCode::BAD_GATEWAY);
        assert_eq!(body.data().await.unwrap().unwrap(), "Something went wrong");
    }
}