surrealdb_async_graphql_axum/
extract.rs

1use std::{io::ErrorKind, marker::PhantomData};
2
3use async_graphql::{futures_util::TryStreamExt, http::MultipartOptions, ParseRequestError};
4use axum::{
5    extract::{FromRequest, Request},
6    http::{self, Method},
7    response::IntoResponse,
8};
9use tokio_util::compat::TokioAsyncReadCompatExt;
10
11/// Extractor for GraphQL request.
12pub struct GraphQLRequest<R = rejection::GraphQLRejection>(
13    pub async_graphql::Request,
14    PhantomData<R>,
15);
16
17impl<R> GraphQLRequest<R> {
18    /// Unwraps the value to `async_graphql::Request`.
19    #[must_use]
20    pub fn into_inner(self) -> async_graphql::Request {
21        self.0
22    }
23}
24
25/// Rejection response types.
26pub mod rejection {
27    use async_graphql::ParseRequestError;
28    use axum::{
29        body::Body,
30        http,
31        http::StatusCode,
32        response::{IntoResponse, Response},
33    };
34
35    /// Rejection used for [`GraphQLRequest`](GraphQLRequest).
36    pub struct GraphQLRejection(pub ParseRequestError);
37
38    impl IntoResponse for GraphQLRejection {
39        fn into_response(self) -> Response {
40            match self.0 {
41                ParseRequestError::PayloadTooLarge => http::Response::builder()
42                    .status(StatusCode::PAYLOAD_TOO_LARGE)
43                    .body(Body::empty())
44                    .unwrap(),
45                bad_request => http::Response::builder()
46                    .status(StatusCode::BAD_REQUEST)
47                    .body(Body::from(format!("{:?}", bad_request)))
48                    .unwrap(),
49            }
50        }
51    }
52
53    impl From<ParseRequestError> for GraphQLRejection {
54        fn from(err: ParseRequestError) -> Self {
55            GraphQLRejection(err)
56        }
57    }
58}
59
60#[async_trait::async_trait]
61impl<S, R> FromRequest<S> for GraphQLRequest<R>
62where
63    S: Send + Sync,
64    R: IntoResponse + From<ParseRequestError>,
65{
66    type Rejection = R;
67
68    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
69        Ok(GraphQLRequest(
70            GraphQLBatchRequest::<R>::from_request(req, state)
71                .await?
72                .0
73                .into_single()?,
74            PhantomData,
75        ))
76    }
77}
78
79/// Extractor for GraphQL batch request.
80pub struct GraphQLBatchRequest<R = rejection::GraphQLRejection>(
81    pub async_graphql::BatchRequest,
82    PhantomData<R>,
83);
84
85impl<R> GraphQLBatchRequest<R> {
86    /// Unwraps the value to `async_graphql::BatchRequest`.
87    #[must_use]
88    pub fn into_inner(self) -> async_graphql::BatchRequest {
89        self.0
90    }
91}
92
93#[async_trait::async_trait]
94impl<S, R> FromRequest<S> for GraphQLBatchRequest<R>
95where
96    S: Send + Sync,
97    R: IntoResponse + From<ParseRequestError>,
98{
99    type Rejection = R;
100
101    async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> {
102        if req.method() == Method::GET {
103            let uri = req.uri();
104            let res = async_graphql::http::parse_query_string(uri.query().unwrap_or_default())
105                .map_err(|err| {
106                    ParseRequestError::Io(std::io::Error::new(
107                        ErrorKind::Other,
108                        format!("failed to parse graphql request from uri query: {}", err),
109                    ))
110                });
111            Ok(Self(async_graphql::BatchRequest::Single(res?), PhantomData))
112        } else {
113            let content_type = req
114                .headers()
115                .get(http::header::CONTENT_TYPE)
116                .and_then(|value| value.to_str().ok())
117                .map(ToString::to_string);
118            let body_stream = req
119                .into_body()
120                .into_data_stream()
121                .map_err(|err| std::io::Error::new(ErrorKind::Other, err.to_string()));
122            let body_reader = tokio_util::io::StreamReader::new(body_stream).compat();
123            Ok(Self(
124                async_graphql::http::receive_batch_body(
125                    content_type,
126                    body_reader,
127                    MultipartOptions::default(),
128                )
129                .await?,
130                PhantomData,
131            ))
132        }
133    }
134}