uni_stream/
udp.rs

1//! Provides a `tokio::TcpStream` like UdpStream implementation based on `tokio::UdpSocket`.
2
3use std::fmt::Debug;
4use std::io::{self};
5use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
6use std::pin::Pin;
7use std::sync::Arc;
8use std::task::{Context, Poll};
9
10use bytes::{Buf, Bytes, BytesMut};
11use hashbrown::HashMap;
12use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
13use tokio::net::UdpSocket;
14#[cfg(feature = "udp-timeout")]
15use tokio::time::Sleep;
16
17use self::impl_inner::{UdpStreamReadContext, UdpStreamWriteContext};
18use super::addr::{each_addr, ToSocketAddrs};
19#[cfg(feature = "udp-timeout")]
20use crate::udp::impl_inner::get_sleep;
21
/// Capacity of the bounded channels that carry datagram payloads and accepted streams.
const UDP_CHANNEL_LEN: usize = 100;
/// Minimum spare capacity (16 KiB) kept in a receive buffer before the next `recv`.
const UDP_BUFFER_SIZE: usize = 1024 * 16;

/// `Result` alias whose error type defaults to [`std::io::Error`].
type Result<T, E = std::io::Error> = std::result::Result<T, E>;
26
/// Unwraps an `Ok` value, or logs the error with `tracing::error!` and `continue`s the
/// enclosing loop.
///
/// `$context` is printed in front of the error detail.
macro_rules! error_get_or_continue {
    ($result:expr, $context:expr) => {
        match $result {
            Err(e) => {
                tracing::error!("{}, detail:{e}", $context);
                continue;
            }
            Ok(value) => value,
        }
    };
}
38
39mod impl_inner {
40
41    #[cfg(feature = "udp-timeout")]
42    use std::time::Duration;
43
44    #[cfg(feature = "udp-timeout")]
45    use futures::FutureExt;
46    use futures::StreamExt;
47    #[cfg(feature = "udp-timeout")]
48    use once_cell::sync::Lazy;
49    #[cfg(feature = "udp-timeout")]
50    use tokio::time::{sleep, Instant};
51
52    use super::*;
53
54    pub(super) trait UdpStreamReadContext {
55        fn get_mut_remaining_bytes(&mut self) -> &mut Option<Bytes>;
56        fn get_receiver_stream(&mut self) -> &mut flume::r#async::RecvStream<'static, Bytes>;
57        #[cfg(feature = "udp-timeout")]
58        fn get_timeout(&mut self) -> &mut Pin<Box<Sleep>>;
59    }
60
61    pub(super) trait UdpStreamWriteContext {
62        fn is_connect(&self) -> bool;
63        fn get_socket(&self) -> &tokio::net::UdpSocket;
64        fn get_peer_addr(&self) -> &SocketAddr;
65    }
66
67    pub(super) fn poll_read<T: UdpStreamReadContext>(
68        mut read_ctx: T,
69        cx: &mut Context,
70        buf: &mut ReadBuf,
71    ) -> Poll<Result<()>> {
72        // timeout
73        #[cfg(feature = "udp-timeout")]
74        if read_ctx.get_timeout().poll_unpin(cx).is_ready() {
75            buf.clear();
76            return Poll::Ready(Err(io::Error::new(
77                io::ErrorKind::TimedOut,
78                format!(
79                    "UdpStream timeout with duration:{:?}",
80                    get_timeout_duration()
81                ),
82            )));
83        }
84
85        #[cfg(feature = "udp-timeout")]
86        #[inline]
87        fn update_timeout(timeout: &mut Pin<Box<Sleep>>) {
88            timeout
89                .as_mut()
90                .reset(Instant::now() + get_timeout_duration())
91        }
92
93        let is_consume_remaining = if let Some(remaining) = read_ctx.get_mut_remaining_bytes() {
94            if buf.remaining() < remaining.len() {
95                buf.put_slice(&remaining.split_to(buf.remaining())[..]);
96            } else {
97                buf.put_slice(&remaining[..]);
98                *read_ctx.get_mut_remaining_bytes() = None;
99            }
100            true
101        } else {
102            false
103        };
104
105        if is_consume_remaining {
106            #[cfg(feature = "udp-timeout")]
107            update_timeout(read_ctx.get_timeout());
108            return Poll::Ready(Ok(()));
109        }
110
111        let remaining = match read_ctx.get_receiver_stream().poll_next_unpin(cx) {
112            Poll::Ready(Some(mut inner_buf)) => {
113                let remaining = if buf.remaining() < inner_buf.len() {
114                    Some(inner_buf.split_off(buf.remaining()))
115                } else {
116                    None
117                };
118                buf.put_slice(&inner_buf[..]);
119                remaining
120            }
121            Poll::Ready(None) => {
122                return Poll::Ready(Err(io::Error::new(
123                    io::ErrorKind::BrokenPipe,
124                    "Broken pipe",
125                )));
126            }
127            Poll::Pending => return Poll::Pending,
128        };
129        #[cfg(feature = "udp-timeout")]
130        update_timeout(read_ctx.get_timeout());
131        *read_ctx.get_mut_remaining_bytes() = remaining;
132        Poll::Ready(Ok(()))
133    }
134
135    pub(super) fn poll_write<T: UdpStreamWriteContext>(
136        write_ctx: T,
137        cx: &mut Context,
138        buf: &[u8],
139    ) -> Poll<Result<usize>> {
140        if write_ctx.is_connect() {
141            write_ctx.get_socket().poll_send(cx, buf)
142        } else {
143            write_ctx
144                .get_socket()
145                .poll_send_to(cx, buf, *write_ctx.get_peer_addr())
146        }
147    }
148
149    #[cfg(feature = "udp-timeout")]
150    const DEFAULT_TIMEOUT: Duration = Duration::from_secs(20);
151
152    #[cfg(feature = "udp-timeout")]
153    static mut CUSTOM_TIMEOUT: Option<Duration> = None;
154
155    /// Set custom timeout.
156    /// Note that this function can only be called before the [`TIMEOUT`] lazy variable is created.
157    #[cfg(feature = "udp-timeout")]
158    pub fn set_custom_timeout(timeout: Duration) {
159        unsafe { CUSTOM_TIMEOUT = Some(timeout) }
160    }
161
162    #[cfg(feature = "udp-timeout")]
163    static TIMEOUT: Lazy<Duration> = Lazy::new(|| match unsafe { CUSTOM_TIMEOUT } {
164        Some(dur) => dur,
165        None => DEFAULT_TIMEOUT,
166    });
167
168    #[cfg(feature = "udp-timeout")]
169    #[inline]
170    pub(super) fn get_timeout_duration() -> Duration {
171        *TIMEOUT
172    }
173
174    #[cfg(feature = "udp-timeout")]
175    #[inline]
176    pub(super) fn get_sleep() -> Sleep {
177        sleep(get_timeout_duration())
178    }
179}
180
// Public entry point for configuring the idle timeout (only with `udp-timeout`).
#[cfg(feature = "udp-timeout")]
pub use self::impl_inner::set_custom_timeout;
183
184/// An I/O object representing a UDP socket listening for incoming connections.
185///
186/// This object can be converted into a stream of incoming connections for
187/// various forms of processing.
188///
189/// # Examples
190///
191/// ```no_run
192/// async fn process_socket<T>(_socket: T) {}
193///
194/// #[tokio::main]
195/// async fn main() -> Result<(), Box<dyn Error>> {
196///     let mut listener = UdpListener::bind(SocketAddr::from_str("127.0.0.1:8080")?).await?;
197///
198///     loop {
199///         let (socket, _) = listener.accept().await?;
200///         process_socket(socket).await;
201///     }
202/// }
203/// ```
204pub struct UdpListener {
205    handler: tokio::task::JoinHandle<()>,
206    receiver: flume::Receiver<(UdpStream, SocketAddr)>,
207    local_addr: SocketAddr,
208}
209
210impl Drop for UdpListener {
211    fn drop(&mut self) {
212        self.handler.abort();
213    }
214}
215
216impl UdpListener {
217    /// Usage is exactly the same as [`tokio::net::TcpListener::bind`]
218    pub async fn bind<A: ToSocketAddrs>(addr: A) -> Result<Self> {
219        each_addr(addr, UdpListener::bind_inner).await
220    }
221
222    async fn bind_inner(local_addr: SocketAddr) -> Result<Self> {
223        let (listener_tx, listener_rx) = flume::bounded(UDP_CHANNEL_LEN);
224        let udp_socket = UdpSocket::bind(local_addr).await?;
225        let local_addr = udp_socket.local_addr()?;
226
227        let handler = tokio::spawn(async move {
228            let mut streams: HashMap<SocketAddr, flume::Sender<Bytes>> = HashMap::new();
229            let socket = Arc::new(udp_socket);
230            let (drop_tx, drop_rx) = flume::bounded(10);
231
232            let mut buf = BytesMut::with_capacity(UDP_BUFFER_SIZE * 3);
233            loop {
234                if buf.capacity() < UDP_BUFFER_SIZE {
235                    buf.reserve(UDP_BUFFER_SIZE * 3);
236                }
237                tokio::select! {
238                    ret = drop_rx.recv_async() => {
239                        let peer_addr = error_get_or_continue!(ret,"UDPListener clean conn");
240                        streams.remove(&peer_addr);
241                    }
242                    ret = socket.recv_buf_from(&mut buf) => {
243                        let (len,peer_addr) = error_get_or_continue!(ret,"UdpListener `recv_buf_from`");
244                        match streams.get(&peer_addr) {
245                            Some(tx) => {
246                                if let Err(err) =  tx.send_async(buf.copy_to_bytes(len)).await{
247                                    tracing::error!("UDPListener send msg to conn, detail:{err}");
248                                    streams.remove(&peer_addr);
249                                    continue;
250                                }
251                            }
252                            None => {
253                                let (child_tx, child_rx) = flume::bounded(UDP_CHANNEL_LEN);
254                                // pre send msg
255                                error_get_or_continue!(
256                                    child_tx.send_async(buf.copy_to_bytes(len)).await,
257                                    "new conn pre send msg"
258                                );
259
260                                let udp_stream = UdpStream {
261                                    is_connect:false,
262                                    local_addr,
263                                    peer_addr,
264                                    #[cfg(feature = "udp-timeout")]
265                                    timeout: Box::pin(get_sleep()),
266                                    recv_stream: child_rx.into_stream(),
267                                    socket: socket.clone(),
268                                    _handler_guard: None,
269                                    _listener_guard: Some(ListenerCleanGuard{sender:drop_tx.clone(),peer_addr}),
270                                    remaining: None,
271                                };
272                                error_get_or_continue!(
273                                    listener_tx.send_async((udp_stream, peer_addr)).await,
274                                    "register UDPStream"
275                                );
276                                streams.insert(peer_addr, child_tx);
277                            }
278                        }
279                    }
280                }
281            }
282        });
283        Ok(Self {
284            handler,
285            receiver: listener_rx,
286            local_addr,
287        })
288    }
289
290    /// Returns the local address that this socket is bound to.
291    pub fn local_addr(&self) -> io::Result<SocketAddr> {
292        Ok(self.local_addr)
293    }
294
295    /// Accepts a new incoming UDP connection.
296    pub async fn accept(&self) -> io::Result<(UdpStream, SocketAddr)> {
297        self.receiver
298            .recv_async()
299            .await
300            .map_err(|e| io::Error::new(io::ErrorKind::BrokenPipe, e))
301    }
302}
303
304#[derive(Debug)]
305struct TaskJoinHandleGuard(tokio::task::JoinHandle<()>);
306
307#[derive(Debug, Clone)]
308struct ListenerCleanGuard {
309    sender: flume::Sender<SocketAddr>,
310    peer_addr: SocketAddr,
311}
312
313impl Drop for ListenerCleanGuard {
314    fn drop(&mut self) {
315        _ = self.sender.try_send(self.peer_addr);
316    }
317}
318
319impl Drop for TaskJoinHandleGuard {
320    fn drop(&mut self) {
321        self.0.abort();
322    }
323}
324
325/// An I/O object representing a UDP stream connected to a remote endpoint.
326///
327/// A UDP stream can either be created by connecting to an endpoint, via the
328/// [`UdpStream::connect`] method, or by [UdpListener::accept] a connection from a listener.
329pub struct UdpStream {
330    is_connect: bool,
331    local_addr: SocketAddr,
332    peer_addr: SocketAddr,
333    socket: Arc<tokio::net::UdpSocket>,
334    #[cfg(feature = "udp-timeout")]
335    timeout: Pin<Box<Sleep>>,
336    recv_stream: flume::r#async::RecvStream<'static, Bytes>,
337    remaining: Option<Bytes>,
338    _handler_guard: Option<TaskJoinHandleGuard>,
339    _listener_guard: Option<ListenerCleanGuard>,
340}
341
342impl UdpStream {
343    /// Create a new UDP stream connected to the specified address.
344    ///
345    /// This function will create a new UDP socket and attempt to connect it to
346    /// the `addr` provided. The returned future will be resolved once the
347    /// stream has successfully connected, or it will return an error if one
348    /// occurs.
349    pub async fn connect<A: ToSocketAddrs>(addr: A) -> Result<Self> {
350        each_addr(addr, UdpStream::connect_inner).await
351    }
352
353    async fn connect_inner(addr: SocketAddr) -> Result<Self> {
354        let local_addr: SocketAddr = if addr.is_ipv4() {
355            SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)
356        } else {
357            SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0)
358        };
359        let socket = UdpSocket::bind(local_addr).await?;
360        socket.connect(&addr).await?;
361        Self::from_tokio(socket, true).await
362    }
363
364    /// Creates a new UdpStream from a tokio::net::UdpSocket.
365    /// This function is intended to be used to wrap a UDP socket from the tokio library.
366    /// Note: The UdpSocket must have the UdpSocket::connect method called before invoking this
367    /// function.
368    async fn from_tokio(socket: UdpSocket, is_connect: bool) -> Result<Self> {
369        let socket = Arc::new(socket);
370
371        let local_addr = socket.local_addr()?;
372        let peer_addr = socket.peer_addr()?;
373
374        let (tx, rx) = flume::bounded(UDP_CHANNEL_LEN);
375
376        let socket_inner = socket.clone();
377
378        let handler = tokio::spawn(async move {
379            let mut buf = BytesMut::with_capacity(UDP_BUFFER_SIZE);
380            while let Ok((len, received_addr)) = socket_inner.recv_buf_from(&mut buf).await {
381                if received_addr != peer_addr {
382                    continue;
383                }
384                if tx.send_async(buf.copy_to_bytes(len)).await.is_err() {
385                    drop(tx);
386                    break;
387                }
388
389                if buf.capacity() < UDP_BUFFER_SIZE {
390                    buf.reserve(UDP_BUFFER_SIZE * 3);
391                }
392            }
393        });
394
395        Ok(UdpStream {
396            local_addr,
397            peer_addr,
398            #[cfg(feature = "udp-timeout")]
399            timeout: Box::pin(get_sleep()),
400            recv_stream: rx.into_stream(),
401            socket,
402            _handler_guard: Some(TaskJoinHandleGuard(handler)),
403            _listener_guard: None,
404            remaining: None,
405            is_connect,
406        })
407    }
408
409    /// Return peer address
410    pub fn peer_addr(&self) -> std::io::Result<SocketAddr> {
411        Ok(self.peer_addr)
412    }
413
414    /// Return local address
415    pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
416        Ok(self.local_addr)
417    }
418
419    /// Split into read side and write side to avoid borrow check, note that ownership is not
420    /// transferred
421    pub fn split(&self) -> (UdpStreamReadHalf<'static>, UdpStreamWriteHalf) {
422        (
423            UdpStreamReadHalf {
424                recv_stream: self.recv_stream.clone(),
425                remaining: self.remaining.clone(),
426                #[cfg(feature = "udp-timeout")]
427                timeout: Box::pin(get_sleep()),
428            },
429            UdpStreamWriteHalf {
430                is_connect: self.is_connect,
431                socket: &self.socket,
432                peer_addr: self.peer_addr,
433            },
434        )
435    }
436}
437
438impl UdpStreamReadContext for std::pin::Pin<&mut UdpStream> {
439    fn get_mut_remaining_bytes(&mut self) -> &mut Option<Bytes> {
440        &mut self.remaining
441    }
442
443    fn get_receiver_stream(&mut self) -> &mut flume::r#async::RecvStream<'static, Bytes> {
444        &mut self.recv_stream
445    }
446
447    #[cfg(feature = "udp-timeout")]
448    fn get_timeout(&mut self) -> &mut Pin<Box<Sleep>> {
449        &mut self.timeout
450    }
451}
452
453impl UdpStreamWriteContext for std::pin::Pin<&mut UdpStream> {
454    fn get_socket(&self) -> &tokio::net::UdpSocket {
455        &self.socket
456    }
457
458    fn get_peer_addr(&self) -> &SocketAddr {
459        &self.peer_addr
460    }
461
462    fn is_connect(&self) -> bool {
463        self.is_connect
464    }
465}
466
467impl AsyncRead for UdpStream {
468    fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut ReadBuf) -> Poll<Result<()>> {
469        impl_inner::poll_read(self, cx, buf)
470    }
471}
472
473impl AsyncWrite for UdpStream {
474    fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
475        impl_inner::poll_write(self, cx, buf)
476    }
477
478    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
479        Poll::Ready(Ok(()))
480    }
481
482    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
483        Poll::Ready(Ok(()))
484    }
485}
486
487/// [`UdpStream`] read-side implementation
488pub struct UdpStreamReadHalf<'a> {
489    recv_stream: flume::r#async::RecvStream<'a, Bytes>,
490    remaining: Option<Bytes>,
491    #[cfg(feature = "udp-timeout")]
492    timeout: Pin<Box<Sleep>>,
493}
494
495impl UdpStreamReadContext for std::pin::Pin<&mut UdpStreamReadHalf<'static>> {
496    fn get_mut_remaining_bytes(&mut self) -> &mut Option<Bytes> {
497        &mut self.remaining
498    }
499
500    fn get_receiver_stream(&mut self) -> &mut flume::r#async::RecvStream<'static, Bytes> {
501        &mut self.recv_stream
502    }
503
504    #[cfg(feature = "udp-timeout")]
505    fn get_timeout(&mut self) -> &mut Pin<Box<Sleep>> {
506        &mut self.timeout
507    }
508}
509
510impl AsyncRead for UdpStreamReadHalf<'static> {
511    fn poll_read(
512        self: Pin<&mut Self>,
513        cx: &mut Context<'_>,
514        buf: &mut ReadBuf<'_>,
515    ) -> Poll<Result<()>> {
516        impl_inner::poll_read(self, cx, buf)
517    }
518}
519
520/// [`UdpStream`] write-side implementation
521pub struct UdpStreamWriteHalf<'a> {
522    is_connect: bool,
523    socket: &'a tokio::net::UdpSocket,
524    peer_addr: SocketAddr,
525}
526
527impl UdpStreamWriteContext for std::pin::Pin<&mut UdpStreamWriteHalf<'_>> {
528    fn get_socket(&self) -> &tokio::net::UdpSocket {
529        self.socket
530    }
531
532    fn get_peer_addr(&self) -> &SocketAddr {
533        &self.peer_addr
534    }
535
536    fn is_connect(&self) -> bool {
537        self.is_connect
538    }
539}
540
541impl AsyncWrite for UdpStreamWriteHalf<'_> {
542    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
543        impl_inner::poll_write(self, cx, buf)
544    }
545
546    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
547        Poll::Ready(Ok(()))
548    }
549
550    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
551        Poll::Ready(Ok(()))
552    }
553}