1#![cfg(feature = "async-graphql")]
9#![cfg_attr(docsrs, doc(cfg(feature = "async-graphql")))]
10
11use std::{future::Future, str::FromStr, time::Duration};
12
13use async_graphql::{
14 BatchRequest as GqlBatchRequest, BatchResponse as GqlBatchResponse, Data, Executor,
15 Result as GqlResult,
16 http::{
17 DefaultOnConnInitType, DefaultOnPingType, MultipartOptions, WebSocket as GqlWebSocket,
18 WebSocketProtocols, WsMessage, default_on_connection_init, default_on_ping,
19 },
20};
21use futures_util::{Sink, SinkExt as _, Stream, StreamExt as _};
22use http::{HeaderValue, StatusCode, header};
23use http_body_util::BodyExt;
24use hyper_util::rt::TokioIo;
25use tokio_tungstenite::{WebSocketStream, tungstenite::protocol::Role};
26
27use crate::{
28 body::TakoBody,
29 extractors::{FromRequest, FromRequestParts},
30 responder::Responder,
31 types::{Request, Response},
32};
33
34#[cfg(feature = "graphiql")]
35pub use crate::graphiql::{GraphiQL, graphiql};
36
37pub struct GraphQLRequest(pub async_graphql::Request);
39
40impl GraphQLRequest {
41 pub fn into_inner(self) -> async_graphql::Request {
42 self.0
43 }
44}
45
46pub struct GraphQLBatchRequest(pub GqlBatchRequest);
48
49impl GraphQLBatchRequest {
50 pub fn into_inner(self) -> GqlBatchRequest {
51 self.0
52 }
53}
54
/// Errors produced while extracting a GraphQL request from an HTTP request.
///
/// The `Display` text of each variant is the same text that the `Responder`
/// impl for this type sends as a `400 Bad Request` body.
#[derive(Debug)]
pub enum GraphQLError {
  /// Rendered as "Missing GraphQL query".
  MissingQuery,
  /// The request body could not be read; holds the underlying error text.
  BodyRead(String),
  /// The body was not valid JSON; holds the error text.
  InvalidJson(String),
  /// async-graphql rejected the query string or body; holds its error text.
  Parse(String),
}

impl std::fmt::Display for GraphQLError {
  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    match self {
      Self::MissingQuery => f.write_str("Missing GraphQL query"),
      Self::BodyRead(e) => write!(f, "Failed to read body: {e}"),
      Self::InvalidJson(e) => write!(f, "Invalid JSON: {e}"),
      Self::Parse(e) => write!(f, "Invalid request: {e}"),
    }
  }
}

// Lets callers propagate extraction failures with `?` into `Box<dyn Error>`
// and friends; the variants carry only message text, so there is no `source`.
impl std::error::Error for GraphQLError {}
63
64#[derive(Clone)]
66pub struct GraphQLOptions {
67 pub multipart: MultipartOptions,
68}
69
70impl Default for GraphQLOptions {
71 fn default() -> Self {
72 Self {
73 multipart: MultipartOptions::default(),
74 }
75 }
76}
77
78impl Responder for GraphQLError {
79 fn into_response(self) -> Response {
80 match self {
81 GraphQLError::MissingQuery => {
82 (StatusCode::BAD_REQUEST, "Missing GraphQL query").into_response()
83 }
84 GraphQLError::BodyRead(e) => {
85 (StatusCode::BAD_REQUEST, format!("Failed to read body: {e}")).into_response()
86 }
87 GraphQLError::InvalidJson(e) => {
88 (StatusCode::BAD_REQUEST, format!("Invalid JSON: {e}")).into_response()
89 }
90 GraphQLError::Parse(e) => {
91 (StatusCode::BAD_REQUEST, format!("Invalid request: {e}")).into_response()
92 }
93 }
94 }
95}
96
97pub struct GraphQLProtocol(pub WebSocketProtocols);
99
100#[derive(Debug)]
101pub struct GraphQLProtocolRejection;
102
103impl Responder for GraphQLProtocolRejection {
104 fn into_response(self) -> Response {
105 (
106 StatusCode::BAD_REQUEST,
107 "Missing or invalid Sec-WebSocket-Protocol",
108 )
109 .into_response()
110 }
111}
112
113impl<'a> FromRequestParts<'a> for GraphQLProtocol {
114 type Error = GraphQLProtocolRejection;
115
116 fn from_request_parts(
117 parts: &'a mut http::request::Parts,
118 ) -> impl core::future::Future<Output = core::result::Result<Self, Self::Error>> + Send + 'a {
119 std::future::ready(
120 parts
121 .headers
122 .get(header::SEC_WEBSOCKET_PROTOCOL)
123 .and_then(|v| v.to_str().ok())
124 .and_then(|protocols| {
125 protocols
126 .split(',')
127 .find_map(|p| WebSocketProtocols::from_str(p.trim()).ok())
128 })
129 .map(GraphQLProtocol)
130 .ok_or(GraphQLProtocolRejection),
131 )
132 }
133}
134
135impl<'a> FromRequest<'a> for GraphQLProtocol {
136 type Error = GraphQLProtocolRejection;
137
138 fn from_request(
139 req: &'a mut Request,
140 ) -> impl core::future::Future<Output = core::result::Result<Self, Self::Error>> + Send + 'a {
141 std::future::ready(
142 req
143 .headers()
144 .get(header::SEC_WEBSOCKET_PROTOCOL)
145 .and_then(|v| v.to_str().ok())
146 .and_then(|protocols| {
147 protocols
148 .split(',')
149 .find_map(|p| WebSocketProtocols::from_str(p.trim()).ok())
150 })
151 .map(GraphQLProtocol)
152 .ok_or(GraphQLProtocolRejection),
153 )
154 }
155}
156
157#[inline]
158fn resolve_opts(req: &Request) -> MultipartOptions {
159 if let Some(opts) = req.extensions().get::<GraphQLOptions>() {
161 return opts.multipart.clone();
162 }
163 if let Some(global) = crate::state::get_state::<GraphQLOptions>() {
165 return global.as_ref().multipart.clone();
166 }
167 MultipartOptions::default()
168}
169
170fn parse_get_request(req: &Request) -> Result<async_graphql::Request, GraphQLError> {
171 let qs = req.uri().query().unwrap_or("");
172 async_graphql::http::parse_query_string(qs).map_err(|e| GraphQLError::Parse(e.to_string()))
173}
174
175async fn read_body_bytes(req: &mut Request) -> Result<bytes::Bytes, GraphQLError> {
176 req
177 .body_mut()
178 .collect()
179 .await
180 .map_err(|e| GraphQLError::BodyRead(e.to_string()))
181 .map(|collected| collected.to_bytes())
182}
183
184impl<'a> FromRequest<'a> for GraphQLRequest {
185 type Error = GraphQLError;
186
187 fn from_request(
188 req: &'a mut Request,
189 ) -> impl core::future::Future<Output = core::result::Result<Self, Self::Error>> + Send + 'a {
190 async move {
191 if req.method() == http::Method::GET {
192 return Ok(GraphQLRequest(parse_get_request(req)?));
193 }
194
195 let opts = resolve_opts(req);
197
198 let body = read_body_bytes(req).await?;
199 let content_type = req
200 .headers()
201 .get(http::header::CONTENT_TYPE)
202 .and_then(|v| v.to_str().ok())
203 .map(|s| s.to_string());
204
205 let reader = futures_util::io::Cursor::new(body.to_vec());
206 let req = async_graphql::http::receive_body(content_type.as_deref(), reader, opts)
207 .await
208 .map_err(|e| GraphQLError::Parse(e.to_string()))?;
209 Ok(GraphQLRequest(req))
210 }
211 }
212}
213
214pub fn attach_graphql_options(req: &mut Request, opts: GraphQLOptions) {
217 req.extensions_mut().insert(opts);
218}
219
220pub fn set_global_graphql_options(opts: GraphQLOptions) {
222 crate::state::set_state::<GraphQLOptions>(opts);
223}
224
225pub async fn receive_graphql(
226 req: &mut Request,
227 opts: MultipartOptions,
228) -> Result<async_graphql::Request, GraphQLError> {
229 if req.method() == http::Method::GET {
230 return parse_get_request(req);
231 }
232 let body = read_body_bytes(req).await?;
233 let content_type = req
234 .headers()
235 .get(http::header::CONTENT_TYPE)
236 .and_then(|v| v.to_str().ok())
237 .map(|s| s.to_string());
238 let reader = futures_util::io::Cursor::new(body.to_vec());
239 async_graphql::http::receive_body(content_type.as_deref(), reader, opts)
240 .await
241 .map_err(|e| GraphQLError::Parse(e.to_string()))
242}
243
244pub async fn receive_graphql_batch(
246 req: &mut Request,
247 opts: MultipartOptions,
248) -> Result<GqlBatchRequest, GraphQLError> {
249 if req.method() == http::Method::GET {
250 let single = parse_get_request(req)?;
251 return Ok(GqlBatchRequest::Single(single));
252 }
253 let body = read_body_bytes(req).await?;
254 let content_type = req
255 .headers()
256 .get(http::header::CONTENT_TYPE)
257 .and_then(|v| v.to_str().ok())
258 .map(|s| s.to_string());
259 let reader = futures_util::io::Cursor::new(body.to_vec());
260 async_graphql::http::receive_batch_body(content_type.as_deref(), reader, opts)
261 .await
262 .map_err(|e| GraphQLError::Parse(e.to_string()))
263}
264
265impl<'a> FromRequest<'a> for GraphQLBatchRequest {
266 type Error = GraphQLError;
267
268 fn from_request(
269 req: &'a mut Request,
270 ) -> impl core::future::Future<Output = core::result::Result<Self, Self::Error>> + Send + 'a {
271 async move {
272 if req.method() == http::Method::GET {
273 let single = parse_get_request(req)?;
275 return Ok(GraphQLBatchRequest(GqlBatchRequest::Single(single)));
276 }
277
278 let opts = resolve_opts(req);
280
281 let body = read_body_bytes(req).await?;
282 let content_type = req
283 .headers()
284 .get(http::header::CONTENT_TYPE)
285 .and_then(|v| v.to_str().ok())
286 .map(|s| s.to_string());
287 let reader = futures_util::io::Cursor::new(body.to_vec());
288 let batch = async_graphql::http::receive_batch_body(content_type.as_deref(), reader, opts)
289 .await
290 .map_err(|e| GraphQLError::Parse(e.to_string()))?;
291 Ok(GraphQLBatchRequest(batch))
292 }
293 }
294}
295
296pub struct GraphQLResponse(pub async_graphql::Response);
298
299impl From<async_graphql::Response> for GraphQLResponse {
300 fn from(value: async_graphql::Response) -> Self {
301 Self(value)
302 }
303}
304
305impl Responder for GraphQLResponse {
306 fn into_response(self) -> Response {
307 match serde_json::to_vec(&self.0) {
308 Ok(buf) => {
309 let mut res = Response::new(TakoBody::from(buf));
310 res.headers_mut().insert(
311 header::CONTENT_TYPE,
312 HeaderValue::from_static(mime::APPLICATION_JSON.as_ref()),
313 );
314 res
315 }
316 Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response(),
317 }
318 }
319}
320
321pub struct GraphQLBatchResponse(pub GqlBatchResponse);
323
324impl From<GqlBatchResponse> for GraphQLBatchResponse {
325 fn from(value: GqlBatchResponse) -> Self {
326 Self(value)
327 }
328}
329
330impl Responder for GraphQLBatchResponse {
331 fn into_response(self) -> Response {
332 match serde_json::to_vec(&self.0) {
333 Ok(buf) => {
334 let mut res = Response::new(TakoBody::from(buf));
335 res.headers_mut().insert(
336 header::CONTENT_TYPE,
337 HeaderValue::from_static(mime::APPLICATION_JSON.as_ref()),
338 );
339 res
340 }
341 Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response(),
342 }
343 }
344}
345
346pub struct GraphQLSubscription<E, OnConnInit = DefaultOnConnInitType, OnPing = DefaultOnPingType>
358where
359 E: Executor,
360{
361 request: Request,
362 executor: E,
363 data: Data,
364 on_connection_init: OnConnInit,
365 on_ping: OnPing,
366 keepalive_timeout: Option<Duration>,
367}
368
369impl<E> GraphQLSubscription<E, DefaultOnConnInitType, DefaultOnPingType>
370where
371 E: Executor,
372{
373 pub fn new(request: Request, executor: E) -> Self {
374 Self {
375 request,
376 executor,
377 data: Data::default(),
378 on_connection_init: default_on_connection_init,
379 on_ping: default_on_ping,
380 keepalive_timeout: None,
381 }
382 }
383}
384
385impl<E, OnConnInit, OnPing> GraphQLSubscription<E, OnConnInit, OnPing>
386where
387 E: Executor,
388{
389 pub fn with_data(mut self, data: Data) -> Self {
390 self.data = data;
391 self
392 }
393
394 pub fn keepalive_timeout(mut self, timeout: impl Into<Option<Duration>>) -> Self {
395 self.keepalive_timeout = timeout.into();
396 self
397 }
398
399 pub fn on_connection_init<F, Fut>(self, f: F) -> GraphQLSubscription<E, F, OnPing>
400 where
401 F: FnOnce(serde_json::Value) -> Fut + Send + 'static,
402 Fut: Future<Output = GqlResult<Data>> + Send + 'static,
403 {
404 GraphQLSubscription {
405 request: self.request,
406 executor: self.executor,
407 data: self.data,
408 on_connection_init: f,
409 on_ping: self.on_ping,
410 keepalive_timeout: self.keepalive_timeout,
411 }
412 }
413
414 pub fn on_ping<F, Fut>(self, f: F) -> GraphQLSubscription<E, OnConnInit, F>
415 where
416 F: FnOnce(Option<&Data>, Option<serde_json::Value>) -> Fut + Clone + Send + 'static,
417 Fut: Future<Output = GqlResult<Option<serde_json::Value>>> + Send + 'static,
418 {
419 GraphQLSubscription {
420 request: self.request,
421 executor: self.executor,
422 data: self.data,
423 on_connection_init: self.on_connection_init,
424 on_ping: f,
425 keepalive_timeout: self.keepalive_timeout,
426 }
427 }
428}
429
430impl<E, OnConnInit, OnConnInitFut, OnPing, OnPingFut> Responder
431 for GraphQLSubscription<E, OnConnInit, OnPing>
432where
433 E: Executor + Send + Sync + Clone + 'static,
434 OnConnInit: FnOnce(serde_json::Value) -> OnConnInitFut + Send + 'static,
435 OnConnInitFut: Future<Output = GqlResult<Data>> + Send + 'static,
436 OnPing: FnOnce(Option<&Data>, Option<serde_json::Value>) -> OnPingFut + Clone + Send + 'static,
437 OnPingFut: Future<Output = GqlResult<Option<serde_json::Value>>> + Send + 'static,
438{
439 fn into_response(self) -> Response {
440 let (parts, body) = self.request.into_parts();
442 let req = http::Request::from_parts(parts, body);
443
444 let selected_protocol = req
446 .headers()
447 .get(header::SEC_WEBSOCKET_PROTOCOL)
448 .and_then(|v| v.to_str().ok())
449 .and_then(|protocols| {
450 protocols
451 .split(',')
452 .find_map(|p| WebSocketProtocols::from_str(p.trim()).ok())
453 });
454
455 let Some(protocol) = selected_protocol else {
456 return (
457 StatusCode::BAD_REQUEST,
458 "Missing or invalid Sec-WebSocket-Protocol",
459 )
460 .into_response();
461 };
462
463 let key = match req.headers().get("Sec-WebSocket-Key") {
465 Some(k) => k,
466 None => {
467 return (
468 StatusCode::BAD_REQUEST,
469 "Missing Sec-WebSocket-Key for WebSocket upgrade",
470 )
471 .into_response();
472 }
473 };
474
475 let accept = {
476 use base64::{Engine as _, engine::general_purpose::STANDARD};
477 use sha1::{Digest, Sha1};
478 let mut sha1 = Sha1::new();
479 sha1.update(key.as_bytes());
480 sha1.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
481 STANDARD.encode(sha1.finalize())
482 };
483
484 let builder = http::Response::builder()
486 .status(StatusCode::SWITCHING_PROTOCOLS)
487 .header(header::UPGRADE, "websocket")
488 .header(header::CONNECTION, "Upgrade")
489 .header("Sec-WebSocket-Accept", accept)
490 .header(
491 header::SEC_WEBSOCKET_PROTOCOL,
492 HeaderValue::from_static(protocol.sec_websocket_protocol()),
493 );
494
495 let response = builder.body(TakoBody::empty()).unwrap();
496
497 if let Some(on_upgrade) = req.extensions().get::<hyper::upgrade::OnUpgrade>().cloned() {
499 let executor = self.executor.clone();
500 let data = self.data;
501 let on_conn_init = self.on_connection_init;
502 let on_ping = self.on_ping;
503 let keepalive = self.keepalive_timeout;
504
505 tokio::spawn(async move {
506 if let Ok(upgraded) = on_upgrade.await {
507 let upgraded = TokioIo::new(upgraded);
508 let ws = WebSocketStream::from_raw_socket(upgraded, Role::Server, None).await;
509 let (mut sink, stream) = ws.split();
510
511 let input = stream
512 .take_while(|res| futures_util::future::ready(res.is_ok()))
513 .map(Result::unwrap)
514 .filter_map(|msg| match msg {
515 tokio_tungstenite::tungstenite::Message::Text(_)
516 | tokio_tungstenite::tungstenite::Message::Binary(_) => {
517 futures_util::future::ready(Some(msg))
518 }
519 _ => futures_util::future::ready(None),
520 })
521 .map(|msg| msg.into_data());
522
523 let mut stream = GqlWebSocket::new(executor, input, protocol)
524 .connection_data(data)
525 .on_connection_init(on_conn_init)
526 .on_ping(on_ping.clone())
527 .keepalive_timeout(keepalive)
528 .map(|msg| match msg {
529 WsMessage::Text(text) => tokio_tungstenite::tungstenite::Message::Text(text.into()),
530 WsMessage::Close(_code, _status) => {
531 tokio_tungstenite::tungstenite::Message::Close(None)
533 }
534 });
535
536 while let Some(item) = stream.next().await {
537 if sink.send(item).await.is_err() {
538 break;
539 }
540 }
541 }
542 });
543 }
544
545 response
546 }
547}
548
549pub struct GraphQLWebSocket<SinkT, StreamT, E, OnConnInit, OnPing>
554where
555 E: Executor,
556{
557 sink: SinkT,
558 stream: StreamT,
559 executor: E,
560 data: Data,
561 on_connection_init: OnConnInit,
562 on_ping: OnPing,
563 protocol: WebSocketProtocols,
564 keepalive_timeout: Option<Duration>,
565}
566
567impl<S, E>
568 GraphQLWebSocket<
569 futures_util::stream::SplitSink<S, tokio_tungstenite::tungstenite::Message>,
570 futures_util::stream::SplitStream<S>,
571 E,
572 DefaultOnConnInitType,
573 DefaultOnPingType,
574 >
575where
576 S: Stream<
577 Item = Result<tokio_tungstenite::tungstenite::Message, tokio_tungstenite::tungstenite::Error>,
578 > + Sink<tokio_tungstenite::tungstenite::Message>,
579 E: Executor,
580{
581 pub fn new(stream: S, executor: E, protocol: WebSocketProtocols) -> Self {
583 let (sink, stream) = stream.split();
584 GraphQLWebSocket::new_with_pair(sink, stream, executor, protocol)
585 }
586}
587
588impl<SinkT, StreamT, E>
589 GraphQLWebSocket<SinkT, StreamT, E, DefaultOnConnInitType, DefaultOnPingType>
590where
591 SinkT: Sink<tokio_tungstenite::tungstenite::Message>,
592 StreamT: Stream<
593 Item = Result<tokio_tungstenite::tungstenite::Message, tokio_tungstenite::tungstenite::Error>,
594 >,
595 E: Executor,
596{
597 pub fn new_with_pair(
599 sink: SinkT,
600 stream: StreamT,
601 executor: E,
602 protocol: WebSocketProtocols,
603 ) -> Self {
604 Self {
605 sink,
606 stream,
607 executor,
608 data: Data::default(),
609 on_connection_init: default_on_connection_init,
610 on_ping: default_on_ping,
611 protocol,
612 keepalive_timeout: None,
613 }
614 }
615}
616
617impl<SinkT, StreamT, E, OnConnInit, OnPing> GraphQLWebSocket<SinkT, StreamT, E, OnConnInit, OnPing>
618where
619 SinkT: Sink<tokio_tungstenite::tungstenite::Message>,
620 StreamT: Stream<
621 Item = Result<tokio_tungstenite::tungstenite::Message, tokio_tungstenite::tungstenite::Error>,
622 >,
623 E: Executor,
624{
625 pub fn with_data(self, data: Data) -> Self {
626 Self { data, ..self }
627 }
628
629 pub fn keepalive_timeout(self, timeout: impl Into<Option<Duration>>) -> Self {
630 Self {
631 keepalive_timeout: timeout.into(),
632 ..self
633 }
634 }
635}
636
637impl<SinkT, StreamT, E, OnConnInit, OnConnInitFut, OnPing, OnPingFut>
638 GraphQLWebSocket<SinkT, StreamT, E, OnConnInit, OnPing>
639where
640 SinkT: Sink<tokio_tungstenite::tungstenite::Message> + Unpin,
641 StreamT: Stream<
642 Item = Result<tokio_tungstenite::tungstenite::Message, tokio_tungstenite::tungstenite::Error>,
643 > + Unpin,
644 E: Executor,
645 OnConnInit: FnOnce(serde_json::Value) -> OnConnInitFut + Send + 'static,
646 OnConnInitFut: Future<Output = GqlResult<Data>> + Send + 'static,
647 OnPing: FnOnce(Option<&Data>, Option<serde_json::Value>) -> OnPingFut + Clone + Send + 'static,
648 OnPingFut: Future<Output = GqlResult<Option<serde_json::Value>>> + Send + 'static,
649{
650 pub fn on_connection_init<F, Fut>(
651 self,
652 callback: F,
653 ) -> GraphQLWebSocket<SinkT, StreamT, E, F, OnPing>
654 where
655 F: FnOnce(serde_json::Value) -> Fut + Send + 'static,
656 Fut: Future<Output = GqlResult<Data>> + Send + 'static,
657 {
658 GraphQLWebSocket {
659 sink: self.sink,
660 stream: self.stream,
661 executor: self.executor,
662 data: self.data,
663 on_connection_init: callback,
664 on_ping: self.on_ping,
665 protocol: self.protocol,
666 keepalive_timeout: self.keepalive_timeout,
667 }
668 }
669
670 pub fn on_ping<F, Fut>(self, callback: F) -> GraphQLWebSocket<SinkT, StreamT, E, OnConnInit, F>
671 where
672 F: FnOnce(Option<&Data>, Option<serde_json::Value>) -> Fut + Clone + Send + 'static,
673 Fut: Future<Output = GqlResult<Option<serde_json::Value>>> + Send + 'static,
674 {
675 GraphQLWebSocket {
676 sink: self.sink,
677 stream: self.stream,
678 executor: self.executor,
679 data: self.data,
680 on_connection_init: self.on_connection_init,
681 on_ping: callback,
682 protocol: self.protocol,
683 keepalive_timeout: self.keepalive_timeout,
684 }
685 }
686
687 pub async fn serve(mut self) {
689 let input = self
690 .stream
691 .take_while(|res| futures_util::future::ready(res.is_ok()))
692 .map(Result::unwrap)
693 .filter_map(|msg| match msg {
694 tokio_tungstenite::tungstenite::Message::Text(_)
695 | tokio_tungstenite::tungstenite::Message::Binary(_) => {
696 futures_util::future::ready(Some(msg))
697 }
698 _ => futures_util::future::ready(None),
699 })
700 .map(|msg| msg.into_data());
701
702 let mut out_stream = GqlWebSocket::new(self.executor, input, self.protocol)
703 .connection_data(self.data)
704 .on_connection_init(self.on_connection_init)
705 .on_ping(self.on_ping.clone())
706 .keepalive_timeout(self.keepalive_timeout)
707 .map(|msg| match msg {
708 WsMessage::Text(text) => tokio_tungstenite::tungstenite::Message::Text(text.into()),
709 WsMessage::Close(_code, _status) => tokio_tungstenite::tungstenite::Message::Close(None),
710 });
711
712 while let Some(item) = out_stream.next().await {
713 if self.sink.send(item).await.is_err() {
714 break;
715 }
716 }
717 }
718}