tsukuyomi_juniper/
request.rs

1use {
2    crate::{error::GraphQLParseError, Schema},
3    futures::Future,
4    http::{Method, Response, StatusCode},
5    juniper::{DefaultScalarValue, InputValue, ScalarRefValue, ScalarValue},
6    percent_encoding::percent_decode,
7    serde::Deserialize,
8    tsukuyomi::{
9        error::Error,
10        extractor::Extractor,
11        future::{Async, Poll, TryFuture},
12        input::{
13            body::{ReadAll, RequestBody},
14            header::ContentType,
15            Input,
16        },
17        responder::Responder,
18    },
19};
20
21/// Create an `Extractor` that parses the incoming request as GraphQL query.
22pub fn request<S>() -> impl Extractor<
23    Output = (GraphQLRequest<S>,), //
24    Error = Error,
25    Extract = impl TryFuture<Ok = (GraphQLRequest<S>,), Error = Error> + Send + 'static,
26>
27where
28    S: ScalarValue + Send + 'static,
29    for<'a> &'a S: ScalarRefValue<'a>,
30{
31    #[allow(missing_debug_implementations)]
32    #[derive(Copy, Clone)]
33    enum RequestKind {
34        Json,
35        GraphQL,
36    }
37
38    #[allow(missing_debug_implementations)]
39    enum State {
40        Init,
41        Receive(ReadAll, RequestKind),
42    }
43
44    tsukuyomi::extractor::extract(|| {
45        let mut state = State::Init;
46        tsukuyomi::future::poll_fn(move |input| loop {
47            state = match state {
48                State::Init => {
49                    if input.request.method() == Method::GET {
50                        return parse_query_request(input).map(|request| Async::Ready((request,)));
51                    } else if input.request.method() == Method::POST {
52                        let kind = match tsukuyomi::input::header::parse::<ContentType>(input) {
53                            Ok(Some(mime)) if *mime == mime::APPLICATION_JSON => RequestKind::Json,
54                            Ok(Some(mime)) if *mime == "application/graphql" => {
55                                RequestKind::GraphQL
56                            }
57                            Ok(Some(..)) => return Err(GraphQLParseError::InvalidMime.into()),
58                            Ok(None) => return Err(GraphQLParseError::MissingMime.into()),
59                            Err(err) => return Err(err),
60                        };
61
62                        let read_all = match input.locals.remove(&RequestBody::KEY) {
63                            Some(body) => body.read_all(),
64                            None => {
65                                return Err(tsukuyomi::error::internal_server_error(
66                                    "the payload has already stolen by another extractor",
67                                ))
68                            }
69                        };
70                        State::Receive(read_all, kind)
71                    } else {
72                        return Err(GraphQLParseError::InvalidRequestMethod.into());
73                    }
74                }
75                State::Receive(ref mut read_all, kind) => {
76                    let data = futures::try_ready!(read_all.poll());
77                    match kind {
78                        RequestKind::Json => {
79                            let request = serde_json::from_slice(&*data)
80                                .map_err(GraphQLParseError::ParseJson)?;
81                            return Ok(Async::Ready((request,)));
82                        }
83                        RequestKind::GraphQL => {
84                            return String::from_utf8(data.to_vec())
85                                .map(|query| {
86                                    Async::Ready((GraphQLRequest::single(query, None, None),))
87                                })
88                                .map_err(|e| GraphQLParseError::DecodeUtf8(e.utf8_error()).into())
89                        }
90                    }
91                }
92            };
93        })
94    })
95}
96
97fn parse_query_request<S>(input: &mut Input<'_>) -> tsukuyomi::Result<GraphQLRequest<S>>
98where
99    S: ScalarValue,
100    for<'a> &'a S: ScalarRefValue<'a>,
101{
102    let query_str = input
103        .request
104        .uri()
105        .query()
106        .ok_or_else(|| GraphQLParseError::MissingQuery)?;
107    parse_query_str(query_str).map_err(Into::into)
108}
109
110fn parse_query_str<S>(s: &str) -> Result<GraphQLRequest<S>, GraphQLParseError>
111where
112    S: ScalarValue,
113    for<'a> &'a S: ScalarRefValue<'a>,
114{
115    #[derive(Debug, serde::Deserialize)]
116    struct ParsedQuery {
117        query: String,
118        operation_name: Option<String>,
119        variables: Option<String>,
120    }
121    let parsed: ParsedQuery =
122        serde_urlencoded::from_str(s).map_err(GraphQLParseError::ParseQuery)?;
123
124    let query = percent_decode(parsed.query.as_ref())
125        .decode_utf8()
126        .map_err(GraphQLParseError::DecodeUtf8)?
127        .into_owned();
128
129    let operation_name = parsed.operation_name.map_or(Ok(None), |s| {
130        percent_decode(s.as_ref())
131            .decode_utf8()
132            .map_err(GraphQLParseError::DecodeUtf8)
133            .map(|s| s.into_owned())
134            .map(Some)
135    })?;
136
137    let variables = parsed
138        .variables
139        .map_or(Ok(None), |s| -> Result<_, GraphQLParseError> {
140            let decoded = percent_decode(s.as_ref())
141                .decode_utf8()
142                .map_err(GraphQLParseError::DecodeUtf8)?;
143            let variables = serde_json::from_str(&*decoded)
144                .map(Some)
145                .map_err(GraphQLParseError::ParseJson)?;
146            Ok(variables)
147        })?;
148
149    Ok(GraphQLRequest::single(query, operation_name, variables))
150}
151
152/// The type representing a GraphQL request from the client.
153#[derive(Debug, serde::Deserialize)]
154#[serde(bound = "InputValue<S>: Deserialize<'de>")]
155pub struct GraphQLRequest<S: ScalarValue = DefaultScalarValue>(GraphQLRequestKind<S>);
156
157#[derive(Debug, Deserialize)]
158#[serde(untagged, bound = "InputValue<S>: Deserialize<'de>")]
159enum GraphQLRequestKind<S: ScalarValue> {
160    Single(juniper::http::GraphQLRequest<S>),
161    Batch(Vec<juniper::http::GraphQLRequest<S>>),
162}
163
164impl<S> GraphQLRequest<S>
165where
166    S: ScalarValue,
167    for<'a> &'a S: ScalarRefValue<'a>,
168{
169    fn single(
170        query: String,
171        operation_name: Option<String>,
172        variables: Option<InputValue<S>>,
173    ) -> Self {
174        GraphQLRequest(GraphQLRequestKind::Single(
175            juniper::http::GraphQLRequest::new(query, operation_name, variables),
176        ))
177    }
178
179    /// Creates a `Responder` that executes this request using the specified schema and context.
180    pub fn execute<T, CtxT>(self, schema: T, context: CtxT) -> GraphQLResponse<T, CtxT, S>
181    where
182        T: Schema<S> + Send + 'static,
183        CtxT: AsRef<T::Context> + Send + 'static,
184        S: Send + 'static,
185    {
186        GraphQLResponse {
187            request: self,
188            schema,
189            context,
190        }
191    }
192}
193
194/// The type representing the result from the executing a GraphQL request.
195#[derive(Debug)]
196pub struct GraphQLResponse<T, CtxT, S: ScalarValue = DefaultScalarValue> {
197    request: GraphQLRequest<S>,
198    schema: T,
199    context: CtxT,
200}
201
202impl<T, CtxT, S> Responder for GraphQLResponse<T, CtxT, S>
203where
204    T: Schema<S> + Send + 'static,
205    CtxT: AsRef<T::Context> + Send + 'static,
206    S: ScalarValue + Send + 'static,
207    for<'a> &'a S: ScalarRefValue<'a>,
208{
209    type Response = Response<Vec<u8>>;
210    type Error = Error;
211    type Respond = GraphQLRespond;
212
213    fn respond(self) -> Self::Respond {
214        let Self {
215            request,
216            schema,
217            context,
218        } = self;
219        let handle = tsukuyomi_server::rt::spawn_fn(move || -> tsukuyomi::Result<_> {
220            use self::GraphQLRequestKind::*;
221            match request.0 {
222                Single(request) => {
223                    let response = request.execute(schema.as_root_node(), context.as_ref());
224                    let status = if response.is_ok() {
225                        StatusCode::OK
226                    } else {
227                        StatusCode::BAD_REQUEST
228                    };
229                    let body = serde_json::to_vec(&response)
230                        .map_err(tsukuyomi::error::internal_server_error)?;
231                    Ok(Response::builder()
232                        .status(status)
233                        .header("content-type", "application/json")
234                        .body(body)
235                        .expect("should be a valid response"))
236                }
237                Batch(requests) => {
238                    let responses: Vec<_> = requests
239                        .iter()
240                        .map(|request| request.execute(schema.as_root_node(), context.as_ref()))
241                        .collect();
242                    let status = if responses.iter().all(|response| response.is_ok()) {
243                        StatusCode::OK
244                    } else {
245                        StatusCode::BAD_REQUEST
246                    };
247                    let body = serde_json::to_vec(&responses)
248                        .map_err(tsukuyomi::error::internal_server_error)?;
249                    Ok(Response::builder()
250                        .status(status)
251                        .header("content-type", "application/json")
252                        .body(body)
253                        .expect("should be a valid response"))
254                }
255            }
256        });
257
258        GraphQLRespond { handle }
259    }
260}
261
262#[doc(hidden)]
263#[allow(missing_debug_implementations)]
264pub struct GraphQLRespond {
265    handle: tsukuyomi_server::rt::SpawnHandle<
266        tsukuyomi::Result<Response<Vec<u8>>>,
267        tsukuyomi_server::rt::BlockingError,
268    >,
269}
270
271impl TryFuture for GraphQLRespond {
272    type Ok = Response<Vec<u8>>;
273    type Error = Error;
274
275    #[inline]
276    fn poll_ready(&mut self, _: &mut Input<'_>) -> Poll<Self::Ok, Self::Error> {
277        futures::try_ready!(self
278            .handle
279            .poll()
280            .map_err(tsukuyomi::error::internal_server_error))
281        .map(Into::into)
282    }
283}