surrealdb_async_graphql_axum/
extract.rs1use std::{io::ErrorKind, marker::PhantomData};
2
3use async_graphql::{futures_util::TryStreamExt, http::MultipartOptions, ParseRequestError};
4use axum::{
5 extract::{FromRequest, Request},
6 http::{self, Method},
7 response::IntoResponse,
8};
9use tokio_util::compat::TokioAsyncReadCompatExt;
10
11pub struct GraphQLRequest<R = rejection::GraphQLRejection>(
13 pub async_graphql::Request,
14 PhantomData<R>,
15);
16
17impl<R> GraphQLRequest<R> {
18 #[must_use]
20 pub fn into_inner(self) -> async_graphql::Request {
21 self.0
22 }
23}
24
25pub mod rejection {
27 use async_graphql::ParseRequestError;
28 use axum::{
29 body::Body,
30 http,
31 http::StatusCode,
32 response::{IntoResponse, Response},
33 };
34
35 pub struct GraphQLRejection(pub ParseRequestError);
37
38 impl IntoResponse for GraphQLRejection {
39 fn into_response(self) -> Response {
40 match self.0 {
41 ParseRequestError::PayloadTooLarge => http::Response::builder()
42 .status(StatusCode::PAYLOAD_TOO_LARGE)
43 .body(Body::empty())
44 .unwrap(),
45 bad_request => http::Response::builder()
46 .status(StatusCode::BAD_REQUEST)
47 .body(Body::from(format!("{:?}", bad_request)))
48 .unwrap(),
49 }
50 }
51 }
52
53 impl From<ParseRequestError> for GraphQLRejection {
54 fn from(err: ParseRequestError) -> Self {
55 GraphQLRejection(err)
56 }
57 }
58}
59
60#[async_trait::async_trait]
61impl<S, R> FromRequest<S> for GraphQLRequest<R>
62where
63 S: Send + Sync,
64 R: IntoResponse + From<ParseRequestError>,
65{
66 type Rejection = R;
67
68 async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
69 Ok(GraphQLRequest(
70 GraphQLBatchRequest::<R>::from_request(req, state)
71 .await?
72 .0
73 .into_single()?,
74 PhantomData,
75 ))
76 }
77}
78
79pub struct GraphQLBatchRequest<R = rejection::GraphQLRejection>(
81 pub async_graphql::BatchRequest,
82 PhantomData<R>,
83);
84
85impl<R> GraphQLBatchRequest<R> {
86 #[must_use]
88 pub fn into_inner(self) -> async_graphql::BatchRequest {
89 self.0
90 }
91}
92
93#[async_trait::async_trait]
94impl<S, R> FromRequest<S> for GraphQLBatchRequest<R>
95where
96 S: Send + Sync,
97 R: IntoResponse + From<ParseRequestError>,
98{
99 type Rejection = R;
100
101 async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> {
102 if req.method() == Method::GET {
103 let uri = req.uri();
104 let res = async_graphql::http::parse_query_string(uri.query().unwrap_or_default())
105 .map_err(|err| {
106 ParseRequestError::Io(std::io::Error::new(
107 ErrorKind::Other,
108 format!("failed to parse graphql request from uri query: {}", err),
109 ))
110 });
111 Ok(Self(async_graphql::BatchRequest::Single(res?), PhantomData))
112 } else {
113 let content_type = req
114 .headers()
115 .get(http::header::CONTENT_TYPE)
116 .and_then(|value| value.to_str().ok())
117 .map(ToString::to_string);
118 let body_stream = req
119 .into_body()
120 .into_data_stream()
121 .map_err(|err| std::io::Error::new(ErrorKind::Other, err.to_string()));
122 let body_reader = tokio_util::io::StreamReader::new(body_stream).compat();
123 Ok(Self(
124 async_graphql::http::receive_batch_body(
125 content_type,
126 body_reader,
127 MultipartOptions::default(),
128 )
129 .await?,
130 PhantomData,
131 ))
132 }
133 }
134}