tibba_util/
request.rs

1// Copyright 2025 Tree xie.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use axum::Json;
16use axum::body::Body;
17use axum::extract::{FromRequest, FromRequestParts};
18use axum::http::header::HeaderMap;
19use axum::http::request::Parts;
20use axum::http::{Request, header};
21use serde::de::DeserializeOwned;
22use tibba_error::Error;
23use validator::Validate;
24
25fn map_err(err: impl ToString, sub_category: &str) -> Error {
26    Error::new(err)
27        .with_category("params")
28        .with_sub_category(sub_category)
29}
30
31#[derive(Debug, Clone, Copy, Default)]
32pub struct JsonParams<T>(pub T);
33
34impl<T, S> FromRequest<S> for JsonParams<T>
35where
36    T: DeserializeOwned + Validate,
37    S: Send + Sync,
38{
39    type Rejection = Error;
40
41    async fn from_request(req: Request<Body>, state: &S) -> Result<Self, Self::Rejection> {
42        if json_content_type(req.headers()) {
43            let Json(value) = Json::<T>::from_request(req, state)
44                .await
45                .map_err(|err| map_err(err, "from_json"))?;
46            value.validate().map_err(|e| map_err(e, "validate"))?;
47
48            Ok(JsonParams(value))
49        } else {
50            Err(map_err("Missing json content type", "from_json"))
51        }
52    }
53}
54
55fn json_content_type(headers: &HeaderMap) -> bool {
56    let content_type = if let Some(content_type) = headers.get(header::CONTENT_TYPE) {
57        content_type
58    } else {
59        return false;
60    };
61
62    let content_type = if let Ok(content_type) = content_type.to_str() {
63        content_type
64    } else {
65        return false;
66    };
67
68    content_type.contains("application/json")
69}
70
71#[derive(Debug, Clone, Copy, Default)]
72pub struct QueryParams<T>(pub T);
73
74impl<T, S> FromRequestParts<S> for QueryParams<T>
75where
76    T: DeserializeOwned + Validate,
77    S: Send + Sync,
78{
79    type Rejection = Error;
80
81    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
82        let query = parts.uri.query().unwrap_or_default();
83        let params: T =
84            serde_urlencoded::from_str(query).map_err(|err| map_err(err, "from_query"))?;
85        params.validate().map_err(|e| map_err(e, "validate"))?;
86        Ok(QueryParams(params))
87    }
88}