1use crate::{
34 Code, Encoding, Status,
35 encoding::DEFAULT_MAX_MESSAGE_SIZE,
36 frame::{
37 reader::{ReadState, poll_read_message},
38 writer::encode_payload,
39 },
40 metadata::Metadata,
41 timeout::format_grpc_timeout,
42};
43use bytes::Bytes;
44use futures_lite::{AsyncWriteExt, future::poll_fn};
45use std::{
46 future::Future,
47 marker::PhantomData,
48 pin::Pin,
49 task::Poll,
50 time::{Duration, Instant},
51};
52use trillium::{Headers, KnownHeaderName, Status as HttpStatus, Transport};
53use trillium_client::{Body, Client, Conn, ConnExt, Version};
54use trillium_http::Upgrade as HttpUpgrade;
55use trillium_server_common::Runtime;
56
/// Transport of a full-duplex call: the upgraded request `Conn` is converted into this
/// (`conn.into()` in `materialize_duplex`) and then used both to write request frames and to
/// read response frames.
type Upgrade = HttpUpgrade<Box<dyn Transport>>;
58
/// State of a call whose HTTP request has not been issued yet: everything needed to build
/// the request, plus the request frames buffered so far.
struct Pending {
    /// Client the request is issued through (`build_request` calls `post` on it).
    client: Client,
    /// Request path passed to `post`.
    path: String,
    /// Value sent as the request `content-type` header.
    content_type: String,
    /// Metadata written into the request headers when the request is built.
    request_metadata: Metadata,
    /// Concatenated, already-framed request messages; moved out by `take_body` when the
    /// request is built.
    body: Vec<u8>,
    /// Set by `close_send`; once true, `send` rejects further messages.
    send_closed: bool,
}
69
/// State of a call whose request has been issued and whose response head has been validated.
///
/// `R` is the object response frames are read from: the `Conn` itself for a regular call, or
/// the upgraded transport for a full-duplex call (which is also written to).
struct Live<R> {
    /// Source of response frames (and, for duplex calls, sink for request frames).
    reader: R,
    /// Copy of the response headers captured when the head was processed.
    response_headers: Headers,
    /// Frame-decoding state carried across polls of `poll_read_message`.
    read_state: ReadState,
    /// Decompression for response frames, taken from the response `grpc-encoding` header.
    response_encoding: Encoding,
    /// `Some` when the response headers already carried `grpc-status` (a trailers-only
    /// response); holds the status parsed from those headers.
    head_status: Option<Result<(), Status>>,
}
84
/// Lifecycle state of a [`GrpcClientConn`].
enum Inner {
    /// Nothing sent yet; request messages are buffered.
    Pending(Pending),
    /// Request sent with its complete body; response messages are read from the `Conn`.
    Reading(Box<Live<Conn>>),
    /// Request upgraded to a bidirectional stream; request frames are written and response
    /// frames read on the same transport.
    Duplex(Box<Live<Upgrade>>),
    /// Terminal state after the response has ended (or a read failed).
    Done {
        /// Response headers, or empty when none were received.
        response_headers: Headers,
        /// Response trailers; for a trailers-only response this is a copy of the headers.
        trailers: Headers,
    },
    /// Materializing the request failed. The next `recv` returns this status and moves the
    /// call to `Done`.
    Failed(Status),
}
104
105#[derive(Clone)]
110pub struct CancelHandle(async_channel::Sender<()>);
111
112impl CancelHandle {
113 pub fn cancel(&self) {
115 self.0.close();
117 }
118}
119
/// A client-side gRPC call sending `Req` messages and receiving `Resp` messages.
///
/// The call starts pending: messages given to `send` (or `buffer_request`) are encoded and
/// buffered, and the HTTP/2 request is only issued when the call is materialized by
/// `open_head`, `close_send` or `recv`. A call opened with `full_duplex` upgrades the request
/// into a bidirectional stream, so `send` can keep writing while responses are read.
/// Otherwise the buffered body is sent as the complete request. Issuing the request, sending
/// on a duplex stream and reading responses are bounded by the optional deadline and by
/// cancellation through a [`CancelHandle`].
pub struct GrpcClientConn<Req, Resp> {
    /// Current lifecycle state.
    inner: Inner,
    /// Converts a response payload into a `Resp`.
    decode: fn(&[u8]) -> Result<Resp, Status>,
    /// Converts a `Req` into its payload bytes.
    encode: fn(&Req) -> Result<Bytes, Status>,
    /// Encoding applied to every request frame; advertised via `grpc-encoding` when it is not
    /// identity.
    outbound_encoding: Encoding,
    /// Maximum response message size handed to the frame reader.
    max_message_size: usize,
    /// Whether the call uses the upgraded bidirectional transport.
    full_duplex: bool,
    /// Absolute deadline bounding raced awaits; also sent as `grpc-timeout`.
    deadline: Option<Instant>,
    /// Runtime used to arm deadline timers.
    runtime: Runtime,
    /// Cancellation signal: resolves when the channel is closed by a `CancelHandle`.
    cancel_rx: async_channel::Receiver<()>,
    /// Kept so `cancel_handle` can hand out new senders.
    cancel_tx: async_channel::Sender<()>,
    /// First error recorded by a builder-time operation; reported when the call is
    /// materialized instead of issuing the request.
    init_error: Option<Status>,
    /// Ties `Req`/`Resp` to the type without storing them; the `fn() -> _` form keeps the
    /// struct's auto traits independent of `Req` and `Resp`.
    _marker: PhantomData<fn() -> (Req, Resp)>,
}
145
146impl<Req, Resp> GrpcClientConn<Req, Resp>
147where
148 Req: Send + 'static,
149 Resp: Send + 'static,
150{
151 #[allow(clippy::too_many_arguments)]
156 pub(crate) fn open(
157 client: &Client,
158 path: &str,
159 content_type: String,
160 request_metadata: Metadata,
161 timeout: Option<Duration>,
162 encode: fn(&Req) -> Result<Bytes, Status>,
163 decode: fn(&[u8]) -> Result<Resp, Status>,
164 outbound_encoding: Encoding,
165 full_duplex: bool,
166 ) -> Self {
167 let (cancel_tx, cancel_rx) = async_channel::bounded(1);
168 Self {
169 inner: Inner::Pending(Pending {
170 client: client.clone(),
171 path: path.to_string(),
172 content_type,
173 request_metadata,
174 body: Vec::new(),
175 send_closed: false,
176 }),
177 decode,
178 encode,
179 outbound_encoding,
180 max_message_size: DEFAULT_MAX_MESSAGE_SIZE,
181 full_duplex,
182 deadline: timeout.map(|d| Instant::now() + d),
183 runtime: client.connector().runtime(),
184 cancel_rx,
185 cancel_tx,
186 init_error: None,
187 _marker: PhantomData,
188 }
189 }
190
191 pub(crate) fn add_ascii_metadata(&mut self, key: &str, value: &str) {
196 let result = match &mut self.inner {
197 Inner::Pending(pending) => pending.request_metadata.insert_ascii(key, value),
198 _ => Ok(()),
199 };
200 if let Err(e) = result {
201 self.init_error
202 .get_or_insert_with(|| Status::invalid_argument(format!("invalid metadata: {e}")));
203 }
204 }
205
206 pub(crate) fn add_binary_metadata(&mut self, key: &str, value: Vec<u8>) {
209 let result = match &mut self.inner {
210 Inner::Pending(pending) => pending.request_metadata.insert_binary(key, value),
211 _ => Ok(()),
212 };
213 if let Err(e) = result {
214 self.init_error
215 .get_or_insert_with(|| Status::invalid_argument(format!("invalid metadata: {e}")));
216 }
217 }
218
219 pub(crate) fn set_deadline_from_now(&mut self, timeout: Duration) {
221 self.deadline = Some(Instant::now() + timeout);
222 }
223
224 pub(crate) fn buffer_request(&mut self, message: Req) {
229 match (self.encode)(&message)
230 .and_then(|payload| encode_payload(&payload, self.outbound_encoding))
231 {
232 Ok(frame) => {
233 if let Inner::Pending(pending) = &mut self.inner {
234 pending.body.extend_from_slice(&frame);
235 }
236 }
237 Err(status) => {
238 self.init_error.get_or_insert(status);
239 }
240 }
241 }
242
243 pub(crate) async fn open_head(&mut self) -> Result<(), Status> {
247 if matches!(self.inner, Inner::Pending(_)) {
248 if self.full_duplex {
249 self.materialize_duplex().await
250 } else {
251 self.materialize_reading().await
252 }
253 } else {
254 Ok(())
255 }
256 }
257
258 pub fn new<C>(
263 client: &Client,
264 path: &str,
265 metadata: Metadata,
266 timeout: Option<Duration>,
267 full_duplex: bool,
268 ) -> Self
269 where
270 C: crate::Codec<Req> + crate::Codec<Resp>,
271 {
272 let content_type = format!(
273 "application/grpc+{}",
274 <C as crate::Codec<Req>>::content_type_suffix()
275 );
276 let outbound_encoding = client
277 .default_headers()
278 .get_str("grpc-encoding")
279 .and_then(Encoding::from_grpc_encoding)
280 .unwrap_or(Encoding::Identity);
281 Self::open(
282 client,
283 path,
284 content_type,
285 metadata,
286 timeout,
287 <C as crate::Codec<Req>>::encode,
288 decode_response::<C, Resp>,
289 outbound_encoding,
290 full_duplex,
291 )
292 }
293
294 pub fn cancel_handle(&self) -> CancelHandle {
297 CancelHandle(self.cancel_tx.clone())
298 }
299
300 pub fn headers(&self) -> Option<&Headers> {
304 match &self.inner {
305 Inner::Reading(live) => Some(&live.response_headers),
306 Inner::Duplex(live) => Some(&live.response_headers),
307 Inner::Done {
308 response_headers, ..
309 } => Some(response_headers),
310 Inner::Pending(_) | Inner::Failed(_) => None,
311 }
312 }
313
314 pub fn trailers(&self) -> Option<&Headers> {
317 match &self.inner {
318 Inner::Done { trailers, .. } => Some(trailers),
319 _ => None,
320 }
321 }
322
323 pub async fn send(&mut self, message: Req) -> Result<(), Status> {
329 let frame = (self.encode)(&message)
330 .and_then(|payload| encode_payload(&payload, self.outbound_encoding))?;
331
332 match &mut self.inner {
333 Inner::Pending(pending) => {
334 if pending.send_closed {
335 return Err(Status::internal("send after close_send"));
336 }
337 pending.body.extend_from_slice(&frame);
338 Ok(())
339 }
340 Inner::Duplex(live) => {
341 let (deadline, runtime, cancel_rx) =
342 (self.deadline, self.runtime.clone(), self.cancel_rx.clone());
343 let write = async {
344 live.reader
345 .write_all(&frame)
346 .await
347 .map_err(|e| Status::unavailable(format!("write error: {e}")))
348 };
349 race(deadline, &runtime, &cancel_rx, write).await
350 }
351 _ => Err(Status::internal("send after response started")),
352 }
353 }
354
355 pub async fn close_send(&mut self) -> Result<(), Status> {
362 match &mut self.inner {
363 Inner::Pending(pending) => {
364 pending.send_closed = true;
365 if self.full_duplex {
366 self.materialize_duplex().await?;
367 if let Inner::Duplex(live) = &mut self.inner {
368 live.reader
369 .close()
370 .await
371 .map_err(|e| Status::unavailable(format!("close error: {e}")))?;
372 }
373 Ok(())
374 } else {
375 self.materialize_reading().await
376 }
377 }
378 Inner::Duplex(live) => live
379 .reader
380 .close()
381 .await
382 .map_err(|e| Status::unavailable(format!("close error: {e}"))),
383 _ => Ok(()),
384 }
385 }
386
387 pub async fn recv(&mut self) -> Result<Option<Resp>, Status> {
392 if matches!(self.inner, Inner::Pending(_)) {
393 if self.full_duplex {
394 self.materialize_duplex().await?;
395 } else {
396 self.materialize_reading().await?;
399 }
400 }
401
402 match &self.inner {
403 Inner::Reading(_) | Inner::Duplex(_) => self.read_one().await,
404 Inner::Done { .. } => Ok(None),
405 Inner::Failed(_) => {
406 let Inner::Failed(status) = std::mem::replace(
407 &mut self.inner,
408 Inner::Done {
409 response_headers: Headers::new(),
410 trailers: Headers::new(),
411 },
412 ) else {
413 unreachable!()
414 };
415 Err(status)
416 }
417 Inner::Pending(_) => unreachable!("materialized above"),
418 }
419 }
420
421 async fn read_one(&mut self) -> Result<Option<Resp>, Status> {
423 let decode = self.decode;
424 let max = self.max_message_size;
425 let (deadline, runtime, cancel_rx) =
426 (self.deadline, self.runtime.clone(), self.cancel_rx.clone());
427
428 let read = poll_fn(|cx| match &mut self.inner {
429 Inner::Reading(live) => {
430 let enc = live.response_encoding;
433 let mut body = live.reader.response_body();
434 poll_read_message(
435 Pin::new(&mut body),
436 &mut live.read_state,
437 cx,
438 decode,
439 enc,
440 max,
441 )
442 }
443 Inner::Duplex(live) => poll_read_message(
444 Pin::new(&mut live.reader),
445 &mut live.read_state,
446 cx,
447 decode,
448 live.response_encoding,
449 max,
450 ),
451 _ => Poll::Ready(None),
452 });
453
454 match race(deadline, &runtime, &cancel_rx, async { Ok(read.await) }).await {
455 Ok(Some(Ok(msg))) => {
456 if self.is_trailers_only() {
457 let _ = self.finish_from_trailers();
461 Err(Status::internal(
462 "trailers-only response (grpc-status in headers) carried a message body",
463 ))
464 } else {
465 Ok(Some(msg))
466 }
467 }
468 Ok(Some(Err(status))) => {
469 let _ = self.finish_from_trailers();
470 Err(status)
471 }
472 Ok(None) => self.finish_from_trailers().map(|()| None),
473 Err(status) => Err(status), }
475 }
476
477 fn is_trailers_only(&self) -> bool {
480 match &self.inner {
481 Inner::Reading(live) => live.head_status.is_some(),
482 Inner::Duplex(live) => live.head_status.is_some(),
483 _ => false,
484 }
485 }
486
487 fn finish_from_trailers(&mut self) -> Result<(), Status> {
496 let (response_headers, head_status, mut trailers) = match &mut self.inner {
497 Inner::Reading(live) => (
498 live.response_headers.clone(),
499 live.head_status.clone(),
500 live.reader.response_trailers().cloned().unwrap_or_default(),
501 ),
502 Inner::Duplex(live) => (
503 live.response_headers.clone(),
504 live.head_status.clone(),
505 live.reader.received_trailers().cloned().unwrap_or_default(),
506 ),
507 _ => (Headers::new(), None, Headers::new()),
508 };
509
510 let status = if trailers.get_str("grpc-status").is_some() {
511 Status::from_trailers(&trailers)
512 } else if let Some(head_status) = head_status {
513 trailers = response_headers.clone();
515 head_status
516 } else {
517 Status::from_trailers(&trailers)
518 };
519
520 self.inner = Inner::Done {
521 response_headers,
522 trailers,
523 };
524 status
525 }
526
527 async fn materialize_reading(&mut self) -> Result<(), Status> {
530 if let Some(status) = self.init_error.take() {
531 return self.fail(status);
532 }
533 let body = self.take_body();
534 let request = self.build_request(Body::from(body));
535 let (deadline, runtime, cancel_rx) =
536 (self.deadline, self.runtime.clone(), self.cancel_rx.clone());
537 let conn = match race(deadline, &runtime, &cancel_rx, async {
538 request.await.map_err(transport_error)
539 })
540 .await
541 {
542 Ok(conn) => conn,
543 Err(status) => return self.fail(status),
544 };
545
546 match process_head(&conn) {
547 Ok(Head {
548 response_headers,
549 response_encoding,
550 head_status,
551 }) => {
552 self.inner = Inner::Reading(Box::new(Live {
553 reader: conn,
554 response_headers,
555 read_state: ReadState::new(),
556 response_encoding,
557 head_status,
558 }));
559 Ok(())
560 }
561 Err(status) => self.fail(status),
562 }
563 }
564
565 fn fail(&mut self, status: Status) -> Result<(), Status> {
568 self.inner = Inner::Failed(status.clone());
569 Err(status)
570 }
571
572 async fn materialize_duplex(&mut self) -> Result<(), Status> {
575 if let Some(status) = self.init_error.take() {
576 return self.fail(status);
577 }
578 let body = self.take_body();
579 let request = self.build_request(Body::from(body));
580 let (deadline, runtime, cancel_rx) =
581 (self.deadline, self.runtime.clone(), self.cancel_rx.clone());
582 let conn = match race(deadline, &runtime, &cancel_rx, async {
583 request.upgrade().await.map_err(transport_error)
584 })
585 .await
586 {
587 Ok(conn) => conn,
588 Err(status) => return self.fail(status),
589 };
590
591 match process_head(&conn) {
592 Ok(Head {
593 response_headers,
594 response_encoding,
595 head_status,
596 }) => {
597 self.inner = Inner::Duplex(Box::new(Live {
598 reader: conn.into(),
599 response_headers,
600 read_state: ReadState::new(),
601 response_encoding,
602 head_status,
603 }));
604 Ok(())
605 }
606 Err(status) => self.fail(status),
607 }
608 }
609
610 fn take_body(&mut self) -> Vec<u8> {
612 match &mut self.inner {
613 Inner::Pending(pending) => std::mem::take(&mut pending.body),
614 _ => Vec::new(),
615 }
616 }
617
618 fn build_request(&self, body: Body) -> Conn {
621 let Inner::Pending(pending) = &self.inner else {
622 unreachable!("build_request requires Pending");
623 };
624 let mut conn = pending
625 .client
626 .post(pending.path.as_str())
627 .with_http_version(Version::Http2)
628 .with_request_header(KnownHeaderName::ContentType, pending.content_type.clone())
629 .with_request_header(KnownHeaderName::Te, "trailers")
630 .with_request_header("grpc-accept-encoding", Encoding::accepted_encodings());
631
632 if !matches!(self.outbound_encoding, Encoding::Identity) {
633 conn.request_headers_mut()
634 .insert("grpc-encoding", self.outbound_encoding.as_grpc_encoding());
635 }
636 if let Some(deadline) = self.deadline {
637 let remaining = deadline.saturating_duration_since(Instant::now());
638 conn.request_headers_mut()
639 .insert("grpc-timeout", format_grpc_timeout(remaining));
640 }
641 pending
642 .request_metadata
643 .write_into(conn.request_headers_mut());
644 conn.with_body(body)
645 }
646}
647
648async fn race<T, F>(
650 deadline: Option<Instant>,
651 runtime: &Runtime,
652 cancel_rx: &async_channel::Receiver<()>,
653 fut: F,
654) -> Result<T, Status>
655where
656 F: Future<Output = Result<T, Status>>,
657{
658 let cancel = {
659 let rx = cancel_rx.clone();
660 async move {
661 let _ = rx.recv().await;
663 Err(Status::cancelled("call cancelled"))
664 }
665 };
666
667 match deadline {
672 None => futures_lite::future::or(cancel, fut).await,
673 Some(deadline) => {
674 let Some(remaining) = deadline.checked_duration_since(Instant::now()) else {
675 return Err(Status::deadline_exceeded("deadline elapsed"));
676 };
677 let runtime = runtime.clone();
678 let timer = async move {
679 runtime.delay(remaining).await;
680 Err(Status::deadline_exceeded("deadline elapsed"))
681 };
682 futures_lite::future::or(futures_lite::future::or(cancel, timer), fut).await
683 }
684 }
685}
686
/// The parts of a validated response head that a live call keeps.
struct Head {
    /// Copy of the full response headers.
    response_headers: Headers,
    /// Decompression for response frames (identity when `grpc-encoding` is absent).
    response_encoding: Encoding,
    /// Status parsed from the headers when they already contained `grpc-status`
    /// (trailers-only response); `None` otherwise.
    head_status: Option<Result<(), Status>>,
}
696
697fn process_head(conn: &Conn) -> Result<Head, Status> {
700 let http_status = conn.status();
701 if http_status != Some(HttpStatus::Ok) {
702 return Err(http_to_grpc_status(
703 http_status.map(|s| s as u16).unwrap_or(0),
704 ));
705 }
706
707 let ct = conn
708 .response_headers()
709 .get_str(KnownHeaderName::ContentType);
710 if ct
711 .and_then(crate::content_type::parse_grpc_content_type)
712 .is_none()
713 {
714 return Err(Status::unknown(format!(
717 "unexpected response content-type: {ct:?}"
718 )));
719 }
720
721 let response_encoding = match conn.response_headers().get_str("grpc-encoding") {
722 None => Encoding::Identity,
723 Some(s) => Encoding::from_grpc_encoding(s)
724 .ok_or_else(|| Status::internal(format!("unsupported grpc-encoding {s:?}")))?,
725 };
726
727 let response_headers = conn.response_headers().clone();
728
729 let head_status = response_headers
735 .get_str("grpc-status")
736 .is_some()
737 .then(|| Status::from_trailers(&response_headers));
738
739 Ok(Head {
740 response_headers,
741 response_encoding,
742 head_status,
743 })
744}
745
746fn transport_error(err: trillium_client::Error) -> Status {
747 Status::unavailable(format!("transport error: {err}"))
748}
749
750fn decode_response<C, Resp>(bytes: &[u8]) -> Result<Resp, Status>
755where
756 C: crate::Codec<Resp>,
757{
758 <C as crate::Codec<Resp>>::decode(bytes)
759 .map_err(|status| Status::new(Code::Internal, status.message))
760}
761
762fn http_to_grpc_status(http: u16) -> Status {
764 let code = match http {
765 400 => Code::Internal,
766 401 => Code::Unauthenticated,
767 403 => Code::PermissionDenied,
768 404 => Code::Unimplemented,
769 429 | 502 | 503 | 504 => Code::Unavailable,
770 _ => Code::Unknown,
771 };
772 Status::new(code, format!("HTTP {http}"))
773}