Skip to main content

tonic/transport/server/
io_stream.rs

1#[cfg(feature = "_tls-any")]
2use std::future::Future;
3#[cfg(feature = "_tls-any")]
4use std::pin::pin;
5use std::{
6    io,
7    ops::ControlFlow,
8    pin::Pin,
9    task::{Context, Poll, ready},
10};
11
12use pin_project::pin_project;
13use tokio::io::{AsyncRead, AsyncWrite};
14#[cfg(feature = "_tls-any")]
15use tokio::task::JoinSet;
16use tokio_stream::Stream;
17#[cfg(feature = "_tls-any")]
18use tokio_stream::StreamExt as _;
19
20use super::service::ServerIo;
21#[cfg(feature = "_tls-any")]
22use super::service::TlsAcceptor;
23
/// TLS bookkeeping kept by [`ServerIoStream`] when an acceptor was supplied.
///
/// * `.0` is the acceptor that is cloned into every handshake task.
/// * `.1` is the set of spawned handshake tasks; each one resolves to either a
///   ready [`ServerIo`] or the error that aborted the handshake.
#[cfg(feature = "_tls-any")]
struct State<IO>(
    TlsAcceptor,
    JoinSet<Result<ServerIo<IO>, crate::BoxError>>,
);
26
27#[pin_project]
28pub(crate) struct ServerIoStream<S, IO, IE>
29where
30    S: Stream<Item = Result<IO, IE>>,
31{
32    #[pin]
33    inner: S,
34    #[cfg(feature = "_tls-any")]
35    state: Option<State<IO>>,
36}
37
38impl<S, IO, IE> ServerIoStream<S, IO, IE>
39where
40    S: Stream<Item = Result<IO, IE>>,
41{
42    pub(crate) fn new(incoming: S, #[cfg(feature = "_tls-any")] tls: Option<TlsAcceptor>) -> Self {
43        Self {
44            inner: incoming,
45            #[cfg(feature = "_tls-any")]
46            state: tls.map(|tls| State(tls, JoinSet::new())),
47        }
48    }
49
50    fn poll_next_without_tls(
51        mut self: Pin<&mut Self>,
52        cx: &mut Context<'_>,
53    ) -> Poll<Option<Result<ServerIo<IO>, crate::BoxError>>>
54    where
55        IE: Into<crate::BoxError>,
56    {
57        match ready!(self.as_mut().project().inner.poll_next(cx)) {
58            Some(Ok(io)) => Poll::Ready(Some(Ok(ServerIo::new_io(io)))),
59            Some(Err(e)) => match handle_tcp_accept_error(e) {
60                ControlFlow::Continue(()) => {
61                    cx.waker().wake_by_ref();
62                    Poll::Pending
63                }
64                ControlFlow::Break(e) => Poll::Ready(Some(Err(e))),
65            },
66            None => Poll::Ready(None),
67        }
68    }
69}
70
71impl<S, IO, IE> Stream for ServerIoStream<S, IO, IE>
72where
73    S: Stream<Item = Result<IO, IE>>,
74    IO: AsyncRead + AsyncWrite + Unpin + Send + 'static,
75    IE: Into<crate::BoxError>,
76{
77    type Item = Result<ServerIo<IO>, crate::BoxError>;
78
79    #[cfg(not(feature = "_tls-any"))]
80    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
81        self.poll_next_without_tls(cx)
82    }
83
84    #[cfg(feature = "_tls-any")]
85    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
86        let mut projected = self.as_mut().project();
87
88        let Some(State(tls, tasks)) = projected.state else {
89            return self.poll_next_without_tls(cx);
90        };
91
92        let select_output = ready!(pin!(select(&mut projected.inner, tasks)).poll(cx));
93
94        match select_output {
95            SelectOutput::Incoming(stream) => {
96                let tls = tls.clone();
97                tasks.spawn(async move {
98                    let io = tls.accept(stream).await?;
99                    Ok(ServerIo::new_tls_io(io))
100                });
101                cx.waker().wake_by_ref();
102                Poll::Pending
103            }
104
105            SelectOutput::Io(io) => Poll::Ready(Some(Ok(io))),
106
107            SelectOutput::TcpErr(e) => match handle_tcp_accept_error(e) {
108                ControlFlow::Continue(()) => {
109                    cx.waker().wake_by_ref();
110                    Poll::Pending
111                }
112                ControlFlow::Break(e) => Poll::Ready(Some(Err(e))),
113            },
114
115            SelectOutput::TlsErr(e) => {
116                tracing::debug!(error = %e, "tls accept error");
117                cx.waker().wake_by_ref();
118                Poll::Pending
119            }
120
121            SelectOutput::Done => Poll::Ready(None),
122        }
123    }
124}
125
126fn handle_tcp_accept_error(e: impl Into<crate::BoxError>) -> ControlFlow<crate::BoxError> {
127    let e = e.into();
128    tracing::debug!(error = %e, "accept loop error");
129    if let Some(e) = e.downcast_ref::<io::Error>() {
130        if matches!(
131            e.kind(),
132            io::ErrorKind::ConnectionAborted
133                | io::ErrorKind::ConnectionReset
134                | io::ErrorKind::BrokenPipe
135                | io::ErrorKind::Interrupted
136                | io::ErrorKind::WouldBlock
137                | io::ErrorKind::TimedOut
138        ) {
139            return ControlFlow::Continue(());
140        }
141    }
142
143    ControlFlow::Break(e)
144}
145
/// Waits for whichever happens first: the next item of `incoming`, or the
/// completion of one of the handshake tasks in `tasks`.
///
/// When `tasks` is empty only `incoming` is awaited. A task that failed to
/// join and a task that returned an error are both reported as
/// [`SelectOutput::TlsErr`].
#[cfg(feature = "_tls-any")]
async fn select<IO: 'static, IE>(
    incoming: &mut (impl Stream<Item = Result<IO, IE>> + Unpin),
    tasks: &mut JoinSet<Result<ServerIo<IO>, crate::BoxError>>,
) -> SelectOutput<IO>
where
    IE: Into<crate::BoxError>,
{
    let next_connection = async {
        match incoming.next().await {
            Some(Ok(stream)) => SelectOutput::Incoming(stream),
            Some(Err(err)) => SelectOutput::TcpErr(err.into()),
            None => SelectOutput::Done,
        }
    };

    if !tasks.is_empty() {
        return tokio::select! {
            output = next_connection => output,
            finished = tasks.join_next() => {
                // The set was checked to be non-empty just above.
                match finished.expect("JoinSet should never end") {
                    Err(join_err) => SelectOutput::TlsErr(join_err.into()),
                    Ok(Err(tls_err)) => SelectOutput::TlsErr(tls_err),
                    Ok(Ok(io)) => SelectOutput::Io(io),
                }
            }
        };
    }

    // No handshake is in flight, so only the listener can make progress.
    next_connection.await
}
177
/// Outcome of one [`select`] round.
#[cfg(feature = "_tls-any")]
enum SelectOutput<IO> {
    /// The incoming stream produced a new connection.
    Incoming(IO),
    /// A handshake task finished successfully.
    Io(ServerIo<IO>),
    /// The incoming stream produced an error.
    TcpErr(crate::BoxError),
    /// A handshake task returned an error or could not be joined.
    TlsErr(crate::BoxError),
    /// The incoming stream ended.
    Done,
}