1use std::fmt::Debug;
4use std::io::{self};
5use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
6use std::pin::Pin;
7use std::sync::Arc;
8use std::task::{Context, Poll};
9
10use bytes::{Buf, Bytes, BytesMut};
11use hashbrown::HashMap;
12use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
13use tokio::net::UdpSocket;
14#[cfg(feature = "udp-timeout")]
15use tokio::time::Sleep;
16
17use self::impl_inner::{UdpStreamReadContext, UdpStreamWriteContext};
18use super::addr::{each_addr, ToSocketAddrs};
19#[cfg(feature = "udp-timeout")]
20use crate::udp::impl_inner::get_sleep;
21
/// Capacity, in datagrams, of every internal channel (accept backlog and
/// per-stream delivery queues).
const UDP_CHANNEL_LEN: usize = 100;
/// Receive-buffer sizing unit (16 KiB): buffers are grown by three times this
/// amount whenever their capacity drops below it.
const UDP_BUFFER_SIZE: usize = 16 << 10;

/// `Result` whose error type defaults to `std::io::Error`.
type Result<T, E = std::io::Error> = std::result::Result<T, E>;
26
/// Evaluates a `Result`. It yields the `Ok` value; on `Err` it logs the error with
/// `tracing::error!` and `continue`s the enclosing loop.
///
/// Only usable inside a loop body. `$context` is any `Display` value that is
/// printed in front of the error detail.
macro_rules! error_get_or_continue {
    ($result:expr, $context:expr) => {
        match $result {
            ::core::result::Result::Ok(value) => value,
            ::core::result::Result::Err(e) => {
                tracing::error!("{}, detail:{e}", $context);
                continue;
            }
        }
    };
}
38
39mod impl_inner {
40
41 #[cfg(feature = "udp-timeout")]
42 use std::time::Duration;
43
44 #[cfg(feature = "udp-timeout")]
45 use futures::FutureExt;
46 use futures::StreamExt;
47 #[cfg(feature = "udp-timeout")]
48 use once_cell::sync::Lazy;
49 #[cfg(feature = "udp-timeout")]
50 use tokio::time::{sleep, Instant};
51
52 use super::*;
53
54 pub(super) trait UdpStreamReadContext {
55 fn get_mut_remaining_bytes(&mut self) -> &mut Option<Bytes>;
56 fn get_receiver_stream(&mut self) -> &mut flume::r#async::RecvStream<'static, Bytes>;
57 #[cfg(feature = "udp-timeout")]
58 fn get_timeout(&mut self) -> &mut Pin<Box<Sleep>>;
59 }
60
61 pub(super) trait UdpStreamWriteContext {
62 fn is_connect(&self) -> bool;
63 fn get_socket(&self) -> &tokio::net::UdpSocket;
64 fn get_peer_addr(&self) -> &SocketAddr;
65 }
66
67 pub(super) fn poll_read<T: UdpStreamReadContext>(
68 mut read_ctx: T,
69 cx: &mut Context,
70 buf: &mut ReadBuf,
71 ) -> Poll<Result<()>> {
72 #[cfg(feature = "udp-timeout")]
74 if read_ctx.get_timeout().poll_unpin(cx).is_ready() {
75 buf.clear();
76 return Poll::Ready(Err(io::Error::new(
77 io::ErrorKind::TimedOut,
78 format!(
79 "UdpStream timeout with duration:{:?}",
80 get_timeout_duration()
81 ),
82 )));
83 }
84
85 #[cfg(feature = "udp-timeout")]
86 #[inline]
87 fn update_timeout(timeout: &mut Pin<Box<Sleep>>) {
88 timeout
89 .as_mut()
90 .reset(Instant::now() + get_timeout_duration())
91 }
92
93 let is_consume_remaining = if let Some(remaining) = read_ctx.get_mut_remaining_bytes() {
94 if buf.remaining() < remaining.len() {
95 buf.put_slice(&remaining.split_to(buf.remaining())[..]);
96 } else {
97 buf.put_slice(&remaining[..]);
98 *read_ctx.get_mut_remaining_bytes() = None;
99 }
100 true
101 } else {
102 false
103 };
104
105 if is_consume_remaining {
106 #[cfg(feature = "udp-timeout")]
107 update_timeout(read_ctx.get_timeout());
108 return Poll::Ready(Ok(()));
109 }
110
111 let remaining = match read_ctx.get_receiver_stream().poll_next_unpin(cx) {
112 Poll::Ready(Some(mut inner_buf)) => {
113 let remaining = if buf.remaining() < inner_buf.len() {
114 Some(inner_buf.split_off(buf.remaining()))
115 } else {
116 None
117 };
118 buf.put_slice(&inner_buf[..]);
119 remaining
120 }
121 Poll::Ready(None) => {
122 return Poll::Ready(Err(io::Error::new(
123 io::ErrorKind::BrokenPipe,
124 "Broken pipe",
125 )));
126 }
127 Poll::Pending => return Poll::Pending,
128 };
129 #[cfg(feature = "udp-timeout")]
130 update_timeout(read_ctx.get_timeout());
131 *read_ctx.get_mut_remaining_bytes() = remaining;
132 Poll::Ready(Ok(()))
133 }
134
135 pub(super) fn poll_write<T: UdpStreamWriteContext>(
136 write_ctx: T,
137 cx: &mut Context,
138 buf: &[u8],
139 ) -> Poll<Result<usize>> {
140 if write_ctx.is_connect() {
141 write_ctx.get_socket().poll_send(cx, buf)
142 } else {
143 write_ctx
144 .get_socket()
145 .poll_send_to(cx, buf, *write_ctx.get_peer_addr())
146 }
147 }
148
149 #[cfg(feature = "udp-timeout")]
150 const DEFAULT_TIMEOUT: Duration = Duration::from_secs(20);
151
152 #[cfg(feature = "udp-timeout")]
153 static mut CUSTOM_TIMEOUT: Option<Duration> = None;
154
155 #[cfg(feature = "udp-timeout")]
158 pub fn set_custom_timeout(timeout: Duration) {
159 unsafe { CUSTOM_TIMEOUT = Some(timeout) }
160 }
161
162 #[cfg(feature = "udp-timeout")]
163 static TIMEOUT: Lazy<Duration> = Lazy::new(|| match unsafe { CUSTOM_TIMEOUT } {
164 Some(dur) => dur,
165 None => DEFAULT_TIMEOUT,
166 });
167
168 #[cfg(feature = "udp-timeout")]
169 #[inline]
170 pub(super) fn get_timeout_duration() -> Duration {
171 *TIMEOUT
172 }
173
174 #[cfg(feature = "udp-timeout")]
175 #[inline]
176 pub(super) fn get_sleep() -> Sleep {
177 sleep(get_timeout_duration())
178 }
179}
180
181#[cfg(feature = "udp-timeout")]
182pub use impl_inner::set_custom_timeout;
183
/// Accepts UDP "connections". Datagrams arriving on one bound socket are grouped
/// by source address, and each new source is surfaced as a [`UdpStream`]
/// through [`UdpListener::accept`].
pub struct UdpListener {
    /// Background task that receives datagrams and routes them per peer.
    /// It is aborted when the listener is dropped.
    handler: tokio::task::JoinHandle<()>,
    /// Queue of `(stream, peer address)` pairs produced for new peers.
    receiver: flume::Receiver<(UdpStream, SocketAddr)>,
    /// Address actually bound. It is read back after binding, so a requested
    /// port 0 is reported as the real port.
    local_addr: SocketAddr,
}
209
210impl Drop for UdpListener {
211 fn drop(&mut self) {
212 self.handler.abort();
213 }
214}
215
216impl UdpListener {
217 pub async fn bind<A: ToSocketAddrs>(addr: A) -> Result<Self> {
219 each_addr(addr, UdpListener::bind_inner).await
220 }
221
222 async fn bind_inner(local_addr: SocketAddr) -> Result<Self> {
223 let (listener_tx, listener_rx) = flume::bounded(UDP_CHANNEL_LEN);
224 let udp_socket = UdpSocket::bind(local_addr).await?;
225 let local_addr = udp_socket.local_addr()?;
226
227 let handler = tokio::spawn(async move {
228 let mut streams: HashMap<SocketAddr, flume::Sender<Bytes>> = HashMap::new();
229 let socket = Arc::new(udp_socket);
230 let (drop_tx, drop_rx) = flume::bounded(10);
231
232 let mut buf = BytesMut::with_capacity(UDP_BUFFER_SIZE * 3);
233 loop {
234 if buf.capacity() < UDP_BUFFER_SIZE {
235 buf.reserve(UDP_BUFFER_SIZE * 3);
236 }
237 tokio::select! {
238 ret = drop_rx.recv_async() => {
239 let peer_addr = error_get_or_continue!(ret,"UDPListener clean conn");
240 streams.remove(&peer_addr);
241 }
242 ret = socket.recv_buf_from(&mut buf) => {
243 let (len,peer_addr) = error_get_or_continue!(ret,"UdpListener `recv_buf_from`");
244 match streams.get(&peer_addr) {
245 Some(tx) => {
246 if let Err(err) = tx.send_async(buf.copy_to_bytes(len)).await{
247 tracing::error!("UDPListener send msg to conn, detail:{err}");
248 streams.remove(&peer_addr);
249 continue;
250 }
251 }
252 None => {
253 let (child_tx, child_rx) = flume::bounded(UDP_CHANNEL_LEN);
254 error_get_or_continue!(
256 child_tx.send_async(buf.copy_to_bytes(len)).await,
257 "new conn pre send msg"
258 );
259
260 let udp_stream = UdpStream {
261 is_connect:false,
262 local_addr,
263 peer_addr,
264 #[cfg(feature = "udp-timeout")]
265 timeout: Box::pin(get_sleep()),
266 recv_stream: child_rx.into_stream(),
267 socket: socket.clone(),
268 _handler_guard: None,
269 _listener_guard: Some(ListenerCleanGuard{sender:drop_tx.clone(),peer_addr}),
270 remaining: None,
271 };
272 error_get_or_continue!(
273 listener_tx.send_async((udp_stream, peer_addr)).await,
274 "register UDPStream"
275 );
276 streams.insert(peer_addr, child_tx);
277 }
278 }
279 }
280 }
281 }
282 });
283 Ok(Self {
284 handler,
285 receiver: listener_rx,
286 local_addr,
287 })
288 }
289
290 pub fn local_addr(&self) -> io::Result<SocketAddr> {
292 Ok(self.local_addr)
293 }
294
295 pub async fn accept(&self) -> io::Result<(UdpStream, SocketAddr)> {
297 self.receiver
298 .recv_async()
299 .await
300 .map_err(|e| io::Error::new(io::ErrorKind::BrokenPipe, e))
301 }
302}
303
/// Owns a spawned task's handle and aborts that task when dropped.
#[derive(Debug)]
struct TaskJoinHandleGuard(tokio::task::JoinHandle<()>);
306
/// Held by streams that a `UdpListener` creates. On drop it sends `peer_addr`
/// back to the listener task, so that peer's routing entry can be removed.
///
/// NOTE(review): because this type is `Clone`, every clone sends its own
/// notification when dropped. Avoid cloning it while the stream is alive.
#[derive(Debug, Clone)]
struct ListenerCleanGuard {
    /// Listener-side channel for drop notifications.
    sender: flume::Sender<SocketAddr>,
    /// Peer whose stream this guard belongs to.
    peer_addr: SocketAddr,
}
312
313impl Drop for ListenerCleanGuard {
314 fn drop(&mut self) {
315 _ = self.sender.try_send(self.peer_addr);
316 }
317}
318
319impl Drop for TaskJoinHandleGuard {
320 fn drop(&mut self) {
321 self.0.abort();
322 }
323}
324
/// A byte-stream view of the datagrams exchanged with a single UDP peer.
///
/// It is created either by [`UdpStream::connect`] (own connected socket and
/// receive task) or by [`UdpListener::accept`] (socket shared with the listener).
pub struct UdpStream {
    /// `true` for a connected socket (writes use `send`), `false` for a
    /// listener-shared one (writes use `send_to(peer_addr)`).
    is_connect: bool,
    local_addr: SocketAddr,
    peer_addr: SocketAddr,
    /// Socket used for writes. It is shared (`Arc`) with the task that receives datagrams.
    socket: Arc<tokio::net::UdpSocket>,
    /// Idle read timer, reset after each successful read.
    #[cfg(feature = "udp-timeout")]
    timeout: Pin<Box<Sleep>>,
    /// Datagrams forwarded by the receive task (the stream's own or the listener's).
    recv_stream: flume::r#async::RecvStream<'static, Bytes>,
    /// Unread tail of a datagram that did not fit in the caller's last read buffer.
    remaining: Option<Bytes>,
    /// Set for `connect`ed streams: aborts the per-stream receive task on drop.
    _handler_guard: Option<TaskJoinHandleGuard>,
    /// Set for accepted streams: notifies the listener on drop.
    _listener_guard: Option<ListenerCleanGuard>,
}
341
342impl UdpStream {
343 pub async fn connect<A: ToSocketAddrs>(addr: A) -> Result<Self> {
350 each_addr(addr, UdpStream::connect_inner).await
351 }
352
353 async fn connect_inner(addr: SocketAddr) -> Result<Self> {
354 let local_addr: SocketAddr = if addr.is_ipv4() {
355 SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)
356 } else {
357 SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0)
358 };
359 let socket = UdpSocket::bind(local_addr).await?;
360 socket.connect(&addr).await?;
361 Self::from_tokio(socket, true).await
362 }
363
364 async fn from_tokio(socket: UdpSocket, is_connect: bool) -> Result<Self> {
369 let socket = Arc::new(socket);
370
371 let local_addr = socket.local_addr()?;
372 let peer_addr = socket.peer_addr()?;
373
374 let (tx, rx) = flume::bounded(UDP_CHANNEL_LEN);
375
376 let socket_inner = socket.clone();
377
378 let handler = tokio::spawn(async move {
379 let mut buf = BytesMut::with_capacity(UDP_BUFFER_SIZE);
380 while let Ok((len, received_addr)) = socket_inner.recv_buf_from(&mut buf).await {
381 if received_addr != peer_addr {
382 continue;
383 }
384 if tx.send_async(buf.copy_to_bytes(len)).await.is_err() {
385 drop(tx);
386 break;
387 }
388
389 if buf.capacity() < UDP_BUFFER_SIZE {
390 buf.reserve(UDP_BUFFER_SIZE * 3);
391 }
392 }
393 });
394
395 Ok(UdpStream {
396 local_addr,
397 peer_addr,
398 #[cfg(feature = "udp-timeout")]
399 timeout: Box::pin(get_sleep()),
400 recv_stream: rx.into_stream(),
401 socket,
402 _handler_guard: Some(TaskJoinHandleGuard(handler)),
403 _listener_guard: None,
404 remaining: None,
405 is_connect,
406 })
407 }
408
409 pub fn peer_addr(&self) -> std::io::Result<SocketAddr> {
411 Ok(self.peer_addr)
412 }
413
414 pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
416 Ok(self.local_addr)
417 }
418
419 pub fn split(&self) -> (UdpStreamReadHalf<'static>, UdpStreamWriteHalf) {
422 (
423 UdpStreamReadHalf {
424 recv_stream: self.recv_stream.clone(),
425 remaining: self.remaining.clone(),
426 #[cfg(feature = "udp-timeout")]
427 timeout: Box::pin(get_sleep()),
428 },
429 UdpStreamWriteHalf {
430 is_connect: self.is_connect,
431 socket: &self.socket,
432 peer_addr: self.peer_addr,
433 },
434 )
435 }
436}
437
438impl UdpStreamReadContext for std::pin::Pin<&mut UdpStream> {
439 fn get_mut_remaining_bytes(&mut self) -> &mut Option<Bytes> {
440 &mut self.remaining
441 }
442
443 fn get_receiver_stream(&mut self) -> &mut flume::r#async::RecvStream<'static, Bytes> {
444 &mut self.recv_stream
445 }
446
447 #[cfg(feature = "udp-timeout")]
448 fn get_timeout(&mut self) -> &mut Pin<Box<Sleep>> {
449 &mut self.timeout
450 }
451}
452
453impl UdpStreamWriteContext for std::pin::Pin<&mut UdpStream> {
454 fn get_socket(&self) -> &tokio::net::UdpSocket {
455 &self.socket
456 }
457
458 fn get_peer_addr(&self) -> &SocketAddr {
459 &self.peer_addr
460 }
461
462 fn is_connect(&self) -> bool {
463 self.is_connect
464 }
465}
466
467impl AsyncRead for UdpStream {
468 fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut ReadBuf) -> Poll<Result<()>> {
469 impl_inner::poll_read(self, cx, buf)
470 }
471}
472
473impl AsyncWrite for UdpStream {
474 fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
475 impl_inner::poll_write(self, cx, buf)
476 }
477
478 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
479 Poll::Ready(Ok(()))
480 }
481
482 fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
483 Poll::Ready(Ok(()))
484 }
485}
486
/// Read half produced by [`UdpStream::split`].
///
/// It owns a clone of the stream's receive channel, its own leftover buffer,
/// and (with `udp-timeout`) its own idle timer.
pub struct UdpStreamReadHalf<'a> {
    /// Cloned receive channel shared with the originating stream.
    recv_stream: flume::r#async::RecvStream<'a, Bytes>,
    /// Unread tail of a datagram that did not fit in the caller's last read buffer.
    remaining: Option<Bytes>,
    /// Idle read timer, reset after each successful read.
    #[cfg(feature = "udp-timeout")]
    timeout: Pin<Box<Sleep>>,
}
494
495impl UdpStreamReadContext for std::pin::Pin<&mut UdpStreamReadHalf<'static>> {
496 fn get_mut_remaining_bytes(&mut self) -> &mut Option<Bytes> {
497 &mut self.remaining
498 }
499
500 fn get_receiver_stream(&mut self) -> &mut flume::r#async::RecvStream<'static, Bytes> {
501 &mut self.recv_stream
502 }
503
504 #[cfg(feature = "udp-timeout")]
505 fn get_timeout(&mut self) -> &mut Pin<Box<Sleep>> {
506 &mut self.timeout
507 }
508}
509
510impl AsyncRead for UdpStreamReadHalf<'static> {
511 fn poll_read(
512 self: Pin<&mut Self>,
513 cx: &mut Context<'_>,
514 buf: &mut ReadBuf<'_>,
515 ) -> Poll<Result<()>> {
516 impl_inner::poll_read(self, cx, buf)
517 }
518}
519
/// Write half produced by [`UdpStream::split`]. It borrows the stream's socket.
pub struct UdpStreamWriteHalf<'a> {
    /// Copied from the stream: selects `send` (connected) vs `send_to` (shared).
    is_connect: bool,
    /// Socket borrowed from the originating stream.
    socket: &'a tokio::net::UdpSocket,
    /// Destination of every datagram written through this half.
    peer_addr: SocketAddr,
}
526
527impl UdpStreamWriteContext for std::pin::Pin<&mut UdpStreamWriteHalf<'_>> {
528 fn get_socket(&self) -> &tokio::net::UdpSocket {
529 self.socket
530 }
531
532 fn get_peer_addr(&self) -> &SocketAddr {
533 &self.peer_addr
534 }
535
536 fn is_connect(&self) -> bool {
537 self.is_connect
538 }
539}
540
541impl AsyncWrite for UdpStreamWriteHalf<'_> {
542 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
543 impl_inner::poll_write(self, cx, buf)
544 }
545
546 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
547 Poll::Ready(Ok(()))
548 }
549
550 fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
551 Poll::Ready(Ok(()))
552 }
553}