tonic/transport/server/
io_stream.rs1#[cfg(feature = "_tls-any")]
2use std::future::Future;
3#[cfg(feature = "_tls-any")]
4use std::pin::pin;
5use std::{
6 io,
7 ops::ControlFlow,
8 pin::Pin,
9 task::{Context, Poll, ready},
10};
11
12use pin_project::pin_project;
13use tokio::io::{AsyncRead, AsyncWrite};
14#[cfg(feature = "_tls-any")]
15use tokio::task::JoinSet;
16use tokio_stream::Stream;
17#[cfg(feature = "_tls-any")]
18use tokio_stream::StreamExt as _;
19
20use super::service::ServerIo;
21#[cfg(feature = "_tls-any")]
22use super::service::TlsAcceptor;
23
24#[cfg(feature = "_tls-any")]
25struct State<IO>(TlsAcceptor, JoinSet<Result<ServerIo<IO>, crate::BoxError>>);
26
27#[pin_project]
28pub(crate) struct ServerIoStream<S, IO, IE>
29where
30 S: Stream<Item = Result<IO, IE>>,
31{
32 #[pin]
33 inner: S,
34 #[cfg(feature = "_tls-any")]
35 state: Option<State<IO>>,
36}
37
38impl<S, IO, IE> ServerIoStream<S, IO, IE>
39where
40 S: Stream<Item = Result<IO, IE>>,
41{
42 pub(crate) fn new(incoming: S, #[cfg(feature = "_tls-any")] tls: Option<TlsAcceptor>) -> Self {
43 Self {
44 inner: incoming,
45 #[cfg(feature = "_tls-any")]
46 state: tls.map(|tls| State(tls, JoinSet::new())),
47 }
48 }
49
50 fn poll_next_without_tls(
51 mut self: Pin<&mut Self>,
52 cx: &mut Context<'_>,
53 ) -> Poll<Option<Result<ServerIo<IO>, crate::BoxError>>>
54 where
55 IE: Into<crate::BoxError>,
56 {
57 match ready!(self.as_mut().project().inner.poll_next(cx)) {
58 Some(Ok(io)) => Poll::Ready(Some(Ok(ServerIo::new_io(io)))),
59 Some(Err(e)) => match handle_tcp_accept_error(e) {
60 ControlFlow::Continue(()) => {
61 cx.waker().wake_by_ref();
62 Poll::Pending
63 }
64 ControlFlow::Break(e) => Poll::Ready(Some(Err(e))),
65 },
66 None => Poll::Ready(None),
67 }
68 }
69}
70
71impl<S, IO, IE> Stream for ServerIoStream<S, IO, IE>
72where
73 S: Stream<Item = Result<IO, IE>>,
74 IO: AsyncRead + AsyncWrite + Unpin + Send + 'static,
75 IE: Into<crate::BoxError>,
76{
77 type Item = Result<ServerIo<IO>, crate::BoxError>;
78
79 #[cfg(not(feature = "_tls-any"))]
80 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
81 self.poll_next_without_tls(cx)
82 }
83
84 #[cfg(feature = "_tls-any")]
85 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
86 let mut projected = self.as_mut().project();
87
88 let Some(State(tls, tasks)) = projected.state else {
89 return self.poll_next_without_tls(cx);
90 };
91
92 let select_output = ready!(pin!(select(&mut projected.inner, tasks)).poll(cx));
93
94 match select_output {
95 SelectOutput::Incoming(stream) => {
96 let tls = tls.clone();
97 tasks.spawn(async move {
98 let io = tls.accept(stream).await?;
99 Ok(ServerIo::new_tls_io(io))
100 });
101 cx.waker().wake_by_ref();
102 Poll::Pending
103 }
104
105 SelectOutput::Io(io) => Poll::Ready(Some(Ok(io))),
106
107 SelectOutput::TcpErr(e) => match handle_tcp_accept_error(e) {
108 ControlFlow::Continue(()) => {
109 cx.waker().wake_by_ref();
110 Poll::Pending
111 }
112 ControlFlow::Break(e) => Poll::Ready(Some(Err(e))),
113 },
114
115 SelectOutput::TlsErr(e) => {
116 tracing::debug!(error = %e, "tls accept error");
117 cx.waker().wake_by_ref();
118 Poll::Pending
119 }
120
121 SelectOutput::Done => Poll::Ready(None),
122 }
123 }
124}
125
126fn handle_tcp_accept_error(e: impl Into<crate::BoxError>) -> ControlFlow<crate::BoxError> {
127 let e = e.into();
128 tracing::debug!(error = %e, "accept loop error");
129 if let Some(e) = e.downcast_ref::<io::Error>() {
130 if matches!(
131 e.kind(),
132 io::ErrorKind::ConnectionAborted
133 | io::ErrorKind::ConnectionReset
134 | io::ErrorKind::BrokenPipe
135 | io::ErrorKind::Interrupted
136 | io::ErrorKind::WouldBlock
137 | io::ErrorKind::TimedOut
138 ) {
139 return ControlFlow::Continue(());
140 }
141 }
142
143 ControlFlow::Break(e)
144}
145
/// Waits for the next event on either side of the TLS accept pipeline.
///
/// The returned `SelectOutput` is one of:
/// - `Incoming`: a new raw connection from `incoming`;
/// - `TcpErr`: an error from `incoming`;
/// - `Done`: `incoming` returned `None`;
/// - `Io`: a handshake task in `tasks` finished successfully;
/// - `TlsErr`: a handshake task failed or could not be joined.
///
/// `tasks` is only raced against `incoming` when it is non-empty, because
/// `JoinSet::join_next` on an empty set resolves to `None` immediately.
#[cfg(feature = "_tls-any")]
async fn select<IO: 'static, IE>(
    incoming: &mut (impl Stream<Item = Result<IO, IE>> + Unpin),
    tasks: &mut JoinSet<Result<ServerIo<IO>, crate::BoxError>>,
) -> SelectOutput<IO>
where
    IE: Into<crate::BoxError>,
{
    let next_incoming = async {
        match incoming.try_next().await {
            Err(err) => SelectOutput::TcpErr(err.into()),
            Ok(None) => SelectOutput::Done,
            Ok(Some(conn)) => SelectOutput::Incoming(conn),
        }
    };

    if !tasks.is_empty() {
        tokio::select! {
            output = next_incoming => output,
            joined = tasks.join_next() => {
                match joined.expect("JoinSet should never end") {
                    Ok(handshake) => handshake.map_or_else(SelectOutput::TlsErr, SelectOutput::Io),
                    Err(join_err) => SelectOutput::TlsErr(join_err.into()),
                }
            }
        }
    } else {
        next_incoming.await
    }
}
177
/// What a single `select` round produced.
#[cfg(feature = "_tls-any")]
enum SelectOutput<IO> {
    /// A raw connection from the incoming stream that still has to go through
    /// the TLS handshake.
    Incoming(IO),
    /// A handshake task completed successfully.
    Io(ServerIo<IO>),
    /// The incoming stream yielded an error.
    TcpErr(crate::BoxError),
    /// A handshake task returned an error or could not be joined.
    TlsErr(crate::BoxError),
    /// The incoming stream returned `None`.
    Done,
}