1use std::time::Duration;
5
6pub trait Listener: Send + 'static {
8 type Io: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static;
10
11 type Addr: Clone + Send + Sync + 'static;
14
15 fn accept(&mut self) -> impl std::future::Future<Output = (Self::Io, Self::Addr)> + Send;
20
21 fn local_addr(&self) -> std::io::Result<Self::Addr>;
23}
24
25pub trait ListenerExt: Listener + Sized {
27 fn tap_io<F>(self, tap_fn: F) -> TapIo<Self, F>
47 where
48 F: FnMut(&mut Self::Io) + Send + 'static,
49 {
50 TapIo {
51 listener: self,
52 tap_fn,
53 }
54 }
55}
56
57impl<L: Listener> ListenerExt for L {}
58
59impl Listener for tokio::net::TcpListener {
60 type Io = tokio::net::TcpStream;
61 type Addr = std::net::SocketAddr;
62
63 async fn accept(&mut self) -> (Self::Io, Self::Addr) {
64 let mut backoff = AcceptBackoff::new();
65 loop {
66 match Self::accept(self).await {
67 Ok(tup) => return tup,
68 Err(e) => backoff.handle_accept_error(e).await,
69 }
70 }
71 }
72
73 #[inline]
74 fn local_addr(&self) -> std::io::Result<Self::Addr> {
75 Self::local_addr(self)
76 }
77}
78
79#[derive(Debug)]
80pub struct TcpListenerWithOptions {
81 inner: tokio::net::TcpListener,
82 nodelay: bool,
83 keepalive: Option<Duration>,
84}
85
86impl TcpListenerWithOptions {
87 pub fn new<A: std::net::ToSocketAddrs>(
88 addr: A,
89 nodelay: bool,
90 keepalive: Option<Duration>,
91 ) -> Result<Self, crate::BoxError> {
92 let std_listener = std::net::TcpListener::bind(addr)?;
93 std_listener.set_nonblocking(true)?;
94 let listener = tokio::net::TcpListener::from_std(std_listener)?;
95
96 Ok(Self::from_listener(listener, nodelay, keepalive))
97 }
98
99 pub fn from_listener(
101 listener: tokio::net::TcpListener,
102 nodelay: bool,
103 keepalive: Option<Duration>,
104 ) -> Self {
105 Self {
106 inner: listener,
107 nodelay,
108 keepalive,
109 }
110 }
111
112 fn set_accepted_socket_options(&self, stream: &tokio::net::TcpStream) {
114 if self.nodelay && let Err(e) = stream.set_nodelay(true) {
115 tracing::warn!("error trying to set TCP nodelay: {}", e);
116 }
117
118 if let Some(timeout) = self.keepalive {
119 let sock_ref = socket2::SockRef::from(&stream);
120 let sock_keepalive = socket2::TcpKeepalive::new().with_time(timeout);
121
122 if let Err(e) = sock_ref.set_tcp_keepalive(&sock_keepalive) {
123 tracing::warn!("error trying to set TCP keepalive: {}", e);
124 }
125 }
126 }
127}
128
129impl Listener for TcpListenerWithOptions {
130 type Io = tokio::net::TcpStream;
131 type Addr = std::net::SocketAddr;
132
133 async fn accept(&mut self) -> (Self::Io, Self::Addr) {
134 let (io, addr) = Listener::accept(&mut self.inner).await;
135 self.set_accepted_socket_options(&io);
136 (io, addr)
137 }
138
139 #[inline]
140 fn local_addr(&self) -> std::io::Result<Self::Addr> {
141 Listener::local_addr(&self.inner)
142 }
143}
144
/// A [`Listener`] wrapper that runs a closure on every accepted I/O stream.
///
/// Created by [`ListenerExt::tap_io`].
pub struct TapIo<L, F> {
    /// The wrapped listener that actually accepts connections.
    listener: L,
    /// Invoked with each accepted stream before it is returned.
    tap_fn: F,
}

/// Formats the wrapped listener and elides the tap closure (shown as `..`),
/// since closures do not implement `Debug`.
///
/// Only `L: Debug` is required. Printing the wrapper does not need the inner
/// type to be a [`Listener`].
impl<L, F> std::fmt::Debug for TapIo<L, F>
where
    L: std::fmt::Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TapIo")
            .field("listener", &self.listener)
            .finish_non_exhaustive()
    }
}
184
185impl<L, F> Listener for TapIo<L, F>
186where
187 L: Listener,
188 F: FnMut(&mut L::Io) + Send + 'static,
189{
190 type Io = L::Io;
191 type Addr = L::Addr;
192
193 async fn accept(&mut self) -> (Self::Io, Self::Addr) {
194 let (mut io, addr) = self.listener.accept().await;
195 (self.tap_fn)(&mut io);
196 (io, addr)
197 }
198
199 fn local_addr(&self) -> std::io::Result<Self::Addr> {
200 self.listener.local_addr()
201 }
202}
203
204struct AcceptBackoff {
217 next_delay: Duration,
218}
219
220impl AcceptBackoff {
221 const MIN: Duration = Duration::from_millis(5);
222 const MAX: Duration = Duration::from_secs(1);
223
224 fn new() -> Self {
225 Self {
226 next_delay: Self::MIN,
227 }
228 }
229
230 async fn handle_accept_error(&mut self, e: std::io::Error) {
231 if is_connection_error(&e) {
232 return;
233 }
234
235 tracing::error!(backoff = ?self.next_delay, "accept error: {e}");
236 tokio::time::sleep(self.next_delay).await;
237 self.next_delay = (self.next_delay * 2).min(Self::MAX);
238 }
239}
240
/// Returns `true` for errors that only affect the single connection being
/// accepted. For these, retrying immediately (without backoff) is appropriate.
///
/// Besides the reset/abort/refused family and transient kinds, this covers the
/// network-level errors that Linux `accept(2)` may report for an
/// already-pending connection (`ENETDOWN`, `EHOSTUNREACH`, `ENETUNREACH`). The
/// man page advises treating these like `EAGAIN` and retrying, rather than
/// treating them as listener-wide failures.
fn is_connection_error(e: &std::io::Error) -> bool {
    use std::io::ErrorKind;

    matches!(
        e.kind(),
        ErrorKind::ConnectionRefused
            | ErrorKind::ConnectionAborted
            | ErrorKind::ConnectionReset
            | ErrorKind::BrokenPipe
            | ErrorKind::Interrupted
            | ErrorKind::WouldBlock
            | ErrorKind::TimedOut
            // Pending per-connection network errors surfaced by accept(2).
            | ErrorKind::NetworkDown
            | ErrorKind::NetworkUnreachable
            | ErrorKind::HostUnreachable
    )
}