1use {
2 crate::{error::GraphQLParseError, Schema},
3 futures::Future,
4 http::{Method, Response, StatusCode},
5 juniper::{DefaultScalarValue, InputValue, ScalarRefValue, ScalarValue},
6 percent_encoding::percent_decode,
7 serde::Deserialize,
8 tsukuyomi::{
9 error::Error,
10 extractor::Extractor,
11 future::{Async, Poll, TryFuture},
12 input::{
13 body::{ReadAll, RequestBody},
14 header::ContentType,
15 Input,
16 },
17 responder::Responder,
18 },
19};
20
21pub fn request<S>() -> impl Extractor<
23 Output = (GraphQLRequest<S>,), Error = Error,
25 Extract = impl TryFuture<Ok = (GraphQLRequest<S>,), Error = Error> + Send + 'static,
26>
27where
28 S: ScalarValue + Send + 'static,
29 for<'a> &'a S: ScalarRefValue<'a>,
30{
31 #[allow(missing_debug_implementations)]
32 #[derive(Copy, Clone)]
33 enum RequestKind {
34 Json,
35 GraphQL,
36 }
37
38 #[allow(missing_debug_implementations)]
39 enum State {
40 Init,
41 Receive(ReadAll, RequestKind),
42 }
43
44 tsukuyomi::extractor::extract(|| {
45 let mut state = State::Init;
46 tsukuyomi::future::poll_fn(move |input| loop {
47 state = match state {
48 State::Init => {
49 if input.request.method() == Method::GET {
50 return parse_query_request(input).map(|request| Async::Ready((request,)));
51 } else if input.request.method() == Method::POST {
52 let kind = match tsukuyomi::input::header::parse::<ContentType>(input) {
53 Ok(Some(mime)) if *mime == mime::APPLICATION_JSON => RequestKind::Json,
54 Ok(Some(mime)) if *mime == "application/graphql" => {
55 RequestKind::GraphQL
56 }
57 Ok(Some(..)) => return Err(GraphQLParseError::InvalidMime.into()),
58 Ok(None) => return Err(GraphQLParseError::MissingMime.into()),
59 Err(err) => return Err(err),
60 };
61
62 let read_all = match input.locals.remove(&RequestBody::KEY) {
63 Some(body) => body.read_all(),
64 None => {
65 return Err(tsukuyomi::error::internal_server_error(
66 "the payload has already stolen by another extractor",
67 ))
68 }
69 };
70 State::Receive(read_all, kind)
71 } else {
72 return Err(GraphQLParseError::InvalidRequestMethod.into());
73 }
74 }
75 State::Receive(ref mut read_all, kind) => {
76 let data = futures::try_ready!(read_all.poll());
77 match kind {
78 RequestKind::Json => {
79 let request = serde_json::from_slice(&*data)
80 .map_err(GraphQLParseError::ParseJson)?;
81 return Ok(Async::Ready((request,)));
82 }
83 RequestKind::GraphQL => {
84 return String::from_utf8(data.to_vec())
85 .map(|query| {
86 Async::Ready((GraphQLRequest::single(query, None, None),))
87 })
88 .map_err(|e| GraphQLParseError::DecodeUtf8(e.utf8_error()).into())
89 }
90 }
91 }
92 };
93 })
94 })
95}
96
97fn parse_query_request<S>(input: &mut Input<'_>) -> tsukuyomi::Result<GraphQLRequest<S>>
98where
99 S: ScalarValue,
100 for<'a> &'a S: ScalarRefValue<'a>,
101{
102 let query_str = input
103 .request
104 .uri()
105 .query()
106 .ok_or_else(|| GraphQLParseError::MissingQuery)?;
107 parse_query_str(query_str).map_err(Into::into)
108}
109
110fn parse_query_str<S>(s: &str) -> Result<GraphQLRequest<S>, GraphQLParseError>
111where
112 S: ScalarValue,
113 for<'a> &'a S: ScalarRefValue<'a>,
114{
115 #[derive(Debug, serde::Deserialize)]
116 struct ParsedQuery {
117 query: String,
118 operation_name: Option<String>,
119 variables: Option<String>,
120 }
121 let parsed: ParsedQuery =
122 serde_urlencoded::from_str(s).map_err(GraphQLParseError::ParseQuery)?;
123
124 let query = percent_decode(parsed.query.as_ref())
125 .decode_utf8()
126 .map_err(GraphQLParseError::DecodeUtf8)?
127 .into_owned();
128
129 let operation_name = parsed.operation_name.map_or(Ok(None), |s| {
130 percent_decode(s.as_ref())
131 .decode_utf8()
132 .map_err(GraphQLParseError::DecodeUtf8)
133 .map(|s| s.into_owned())
134 .map(Some)
135 })?;
136
137 let variables = parsed
138 .variables
139 .map_or(Ok(None), |s| -> Result<_, GraphQLParseError> {
140 let decoded = percent_decode(s.as_ref())
141 .decode_utf8()
142 .map_err(GraphQLParseError::DecodeUtf8)?;
143 let variables = serde_json::from_str(&*decoded)
144 .map(Some)
145 .map_err(GraphQLParseError::ParseJson)?;
146 Ok(variables)
147 })?;
148
149 Ok(GraphQLRequest::single(query, operation_name, variables))
150}
151
152#[derive(Debug, serde::Deserialize)]
154#[serde(bound = "InputValue<S>: Deserialize<'de>")]
155pub struct GraphQLRequest<S: ScalarValue = DefaultScalarValue>(GraphQLRequestKind<S>);
156
157#[derive(Debug, Deserialize)]
158#[serde(untagged, bound = "InputValue<S>: Deserialize<'de>")]
159enum GraphQLRequestKind<S: ScalarValue> {
160 Single(juniper::http::GraphQLRequest<S>),
161 Batch(Vec<juniper::http::GraphQLRequest<S>>),
162}
163
164impl<S> GraphQLRequest<S>
165where
166 S: ScalarValue,
167 for<'a> &'a S: ScalarRefValue<'a>,
168{
169 fn single(
170 query: String,
171 operation_name: Option<String>,
172 variables: Option<InputValue<S>>,
173 ) -> Self {
174 GraphQLRequest(GraphQLRequestKind::Single(
175 juniper::http::GraphQLRequest::new(query, operation_name, variables),
176 ))
177 }
178
179 pub fn execute<T, CtxT>(self, schema: T, context: CtxT) -> GraphQLResponse<T, CtxT, S>
181 where
182 T: Schema<S> + Send + 'static,
183 CtxT: AsRef<T::Context> + Send + 'static,
184 S: Send + 'static,
185 {
186 GraphQLResponse {
187 request: self,
188 schema,
189 context,
190 }
191 }
192}
193
194#[derive(Debug)]
196pub struct GraphQLResponse<T, CtxT, S: ScalarValue = DefaultScalarValue> {
197 request: GraphQLRequest<S>,
198 schema: T,
199 context: CtxT,
200}
201
202impl<T, CtxT, S> Responder for GraphQLResponse<T, CtxT, S>
203where
204 T: Schema<S> + Send + 'static,
205 CtxT: AsRef<T::Context> + Send + 'static,
206 S: ScalarValue + Send + 'static,
207 for<'a> &'a S: ScalarRefValue<'a>,
208{
209 type Response = Response<Vec<u8>>;
210 type Error = Error;
211 type Respond = GraphQLRespond;
212
213 fn respond(self) -> Self::Respond {
214 let Self {
215 request,
216 schema,
217 context,
218 } = self;
219 let handle = tsukuyomi_server::rt::spawn_fn(move || -> tsukuyomi::Result<_> {
220 use self::GraphQLRequestKind::*;
221 match request.0 {
222 Single(request) => {
223 let response = request.execute(schema.as_root_node(), context.as_ref());
224 let status = if response.is_ok() {
225 StatusCode::OK
226 } else {
227 StatusCode::BAD_REQUEST
228 };
229 let body = serde_json::to_vec(&response)
230 .map_err(tsukuyomi::error::internal_server_error)?;
231 Ok(Response::builder()
232 .status(status)
233 .header("content-type", "application/json")
234 .body(body)
235 .expect("should be a valid response"))
236 }
237 Batch(requests) => {
238 let responses: Vec<_> = requests
239 .iter()
240 .map(|request| request.execute(schema.as_root_node(), context.as_ref()))
241 .collect();
242 let status = if responses.iter().all(|response| response.is_ok()) {
243 StatusCode::OK
244 } else {
245 StatusCode::BAD_REQUEST
246 };
247 let body = serde_json::to_vec(&responses)
248 .map_err(tsukuyomi::error::internal_server_error)?;
249 Ok(Response::builder()
250 .status(status)
251 .header("content-type", "application/json")
252 .body(body)
253 .expect("should be a valid response"))
254 }
255 }
256 });
257
258 GraphQLRespond { handle }
259 }
260}
261
262#[doc(hidden)]
263#[allow(missing_debug_implementations)]
264pub struct GraphQLRespond {
265 handle: tsukuyomi_server::rt::SpawnHandle<
266 tsukuyomi::Result<Response<Vec<u8>>>,
267 tsukuyomi_server::rt::BlockingError,
268 >,
269}
270
271impl TryFuture for GraphQLRespond {
272 type Ok = Response<Vec<u8>>;
273 type Error = Error;
274
275 #[inline]
276 fn poll_ready(&mut self, _: &mut Input<'_>) -> Poll<Self::Ok, Self::Error> {
277 futures::try_ready!(self
278 .handle
279 .poll()
280 .map_err(tsukuyomi::error::internal_server_error))
281 .map(Into::into)
282 }
283}