Skip to main content

uni_stream/
udp.rs

1//! Provides a `tokio::TcpStream`-like UDP stream implementation based on `tokio::UdpSocket`.
2
3use std::fmt::Debug;
4use std::future::Future;
5use std::io::{self};
6use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
7use std::pin::Pin;
8use std::sync::Arc;
9use std::task::{Context, Poll};
10
11use bytes::{Buf, Bytes, BytesMut};
12use futures::future::poll_fn;
13use futures::Stream;
14use hashbrown::HashMap;
15use kanal_plus::{AsyncReceiver, AsyncSender, ReceiveStreamOwned};
16use socket2::SockRef;
17use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
18use tokio::net::UdpSocket;
19#[cfg(feature = "udp-timeout")]
20use tokio::time::Instant;
21#[cfg(feature = "udp-timeout")]
22use tokio::time::Sleep;
23
24use self::impl_inner::{UdpStreamReadContext, UdpStreamWriteContext};
25use super::addr::{each_addr, ToSocketAddrs};
26#[cfg(feature = "udp-timeout")]
27use crate::udp::impl_inner::get_sleep;
28
29const UDP_CHANNEL_LEN: usize = 100;
30const UDP_BUFFER_SIZE: usize = 65_507;
31const UDP_SOCKET_BUFFER_BYTES: usize = 4 * 1024 * 1024;
32
33type Result<T, E = std::io::Error> = std::result::Result<T, E>;
34
35fn receiver_stream<T: Send + 'static>(
36    receiver: AsyncReceiver<T>,
37) -> Pin<Box<ReceiveStreamOwned<T>>> {
38    Box::pin(receiver.into_stream())
39}
40
41#[cfg(not(target_os = "windows"))]
42/// Tune UDP socket buffer sizes for better throughput.
43pub fn tune_udp_socket(socket: &UdpSocket) {
44    let sock_ref = SockRef::from(socket);
45    if let Err(err) = sock_ref.set_recv_buffer_size(UDP_SOCKET_BUFFER_BYTES) {
46        tracing::warn!("failed to set udp recv buffer size: {err}");
47    }
48    if let Err(err) = sock_ref.set_send_buffer_size(UDP_SOCKET_BUFFER_BYTES) {
49        tracing::warn!("failed to set udp send buffer size: {err}");
50    }
51}
52
53#[cfg(target_os = "windows")]
54/// Tune UDP socket buffer sizes for better throughput.
55pub fn tune_udp_socket(socket: &UdpSocket) {
56    let sock_ref = SockRef::from(socket);
57    if let Err(err) = sock_ref.set_recv_buffer_size(UDP_SOCKET_BUFFER_BYTES) {
58        tracing::warn!("failed to set udp recv buffer size: {err}");
59    }
60    if let Err(err) = sock_ref.set_send_buffer_size(UDP_SOCKET_BUFFER_BYTES) {
61        tracing::warn!("failed to set udp send buffer size: {err}");
62    }
63}
64
65macro_rules! error_get_or_continue {
66    ($func_call:expr, $msg:expr) => {
67        match $func_call {
68            Ok(v) => v,
69            Err(e) => {
70                tracing::error!("{}, detail:{e}", $msg);
71                continue;
72            }
73        }
74    };
75}
76
77mod impl_inner {
78    #[cfg(feature = "udp-timeout")]
79    use std::time::Duration;
80
81    #[cfg(feature = "udp-timeout")]
82    use futures::FutureExt;
83    #[cfg(feature = "udp-timeout")]
84    use once_cell::sync::Lazy;
85    #[cfg(feature = "udp-timeout")]
86    use tokio::time::{sleep, Instant};
87
88    use super::*;
89
90    pub(super) trait UdpStreamReadContext {
91        fn get_mut_remaining_bytes(&mut self) -> &mut Option<Bytes>;
92        fn get_receiver_stream(&mut self) -> &mut Pin<Box<ReceiveStreamOwned<Bytes>>>;
93        #[cfg(feature = "udp-timeout")]
94        fn get_timeout(&mut self) -> &mut Pin<Box<Sleep>>;
95    }
96
97    pub(super) trait UdpStreamWriteContext {
98        fn is_connect(&self) -> bool;
99        fn get_socket(&self) -> &tokio::net::UdpSocket;
100        fn get_peer_addr(&self) -> &SocketAddr;
101    }
102
103    pub(super) fn poll_read<T: UdpStreamReadContext>(
104        mut read_ctx: T,
105        cx: &mut Context,
106        buf: &mut ReadBuf,
107    ) -> Poll<Result<()>> {
108        // timeout
109        #[cfg(feature = "udp-timeout")]
110        if read_ctx.get_timeout().poll_unpin(cx).is_ready() {
111            buf.clear();
112            return Poll::Ready(Err(io::Error::new(
113                io::ErrorKind::TimedOut,
114                format!(
115                    "UdpStream timeout with duration:{:?}",
116                    get_timeout_duration()
117                ),
118            )));
119        }
120
121        #[cfg(feature = "udp-timeout")]
122        #[inline]
123        fn update_timeout(timeout: &mut Pin<Box<Sleep>>) {
124            timeout
125                .as_mut()
126                .reset(Instant::now() + get_timeout_duration())
127        }
128
129        let is_consume_remaining = if let Some(remaining) = read_ctx.get_mut_remaining_bytes() {
130            if buf.remaining() < remaining.len() {
131                buf.put_slice(&remaining.split_to(buf.remaining())[..]);
132            } else {
133                buf.put_slice(&remaining[..]);
134                *read_ctx.get_mut_remaining_bytes() = None;
135            }
136            true
137        } else {
138            false
139        };
140
141        if is_consume_remaining {
142            #[cfg(feature = "udp-timeout")]
143            update_timeout(read_ctx.get_timeout());
144            return Poll::Ready(Ok(()));
145        }
146
147        let remaining = match read_ctx.get_receiver_stream().as_mut().poll_next(cx) {
148            Poll::Ready(Some(mut inner_buf)) => {
149                let remaining = if buf.remaining() < inner_buf.len() {
150                    Some(inner_buf.split_off(buf.remaining()))
151                } else {
152                    None
153                };
154                buf.put_slice(&inner_buf[..]);
155                remaining
156            }
157            Poll::Ready(None) => {
158                return Poll::Ready(Err(io::Error::new(
159                    io::ErrorKind::BrokenPipe,
160                    "Broken pipe",
161                )));
162            }
163            Poll::Pending => return Poll::Pending,
164        };
165        #[cfg(feature = "udp-timeout")]
166        update_timeout(read_ctx.get_timeout());
167        *read_ctx.get_mut_remaining_bytes() = remaining;
168        Poll::Ready(Ok(()))
169    }
170
171    pub(super) fn poll_write<T: UdpStreamWriteContext>(
172        write_ctx: T,
173        cx: &mut Context,
174        buf: &[u8],
175    ) -> Poll<Result<usize>> {
176        if write_ctx.is_connect() {
177            write_ctx.get_socket().poll_send(cx, buf)
178        } else {
179            write_ctx
180                .get_socket()
181                .poll_send_to(cx, buf, *write_ctx.get_peer_addr())
182        }
183    }
184
185    #[cfg(feature = "udp-timeout")]
186    const DEFAULT_TIMEOUT: Duration = Duration::from_secs(20);
187
188    #[cfg(feature = "udp-timeout")]
189    static mut CUSTOM_TIMEOUT: Option<Duration> = None;
190
191    /// Set custom timeout.
192    /// Note that this function can only be called before the [`TIMEOUT`] lazy variable is created.
193    #[cfg(feature = "udp-timeout")]
194    pub fn set_custom_timeout(timeout: Duration) {
195        unsafe { CUSTOM_TIMEOUT = Some(timeout) }
196    }
197
198    #[cfg(feature = "udp-timeout")]
199    static TIMEOUT: Lazy<Duration> = Lazy::new(|| match unsafe { CUSTOM_TIMEOUT } {
200        Some(dur) => dur,
201        None => DEFAULT_TIMEOUT,
202    });
203
204    #[cfg(feature = "udp-timeout")]
205    #[inline]
206    pub(super) fn get_timeout_duration() -> Duration {
207        *TIMEOUT
208    }
209
210    #[cfg(feature = "udp-timeout")]
211    #[inline]
212    pub(super) fn get_sleep() -> Sleep {
213        sleep(get_timeout_duration())
214    }
215}
216
217#[cfg(feature = "udp-timeout")]
218pub use impl_inner::set_custom_timeout;
219
220/// An I/O object representing a UDP socket listening for incoming connections.
221///
222/// This object can be converted into a stream of incoming connections for
223/// various forms of processing.
224pub struct UdpListener {
225    handler: tokio::task::JoinHandle<()>,
226    receiver: AsyncReceiver<(UdpStream, SocketAddr)>,
227    local_addr: SocketAddr,
228}
229
230impl Drop for UdpListener {
231    fn drop(&mut self) {
232        self.handler.abort();
233    }
234}
235
236impl UdpListener {
237    /// Usage is exactly the same as [`tokio::net::TcpListener::bind`]
238    pub async fn bind<A: ToSocketAddrs>(addr: A) -> Result<Self> {
239        each_addr(addr, UdpListener::bind_inner).await
240    }
241
242    async fn bind_inner(local_addr: SocketAddr) -> Result<Self> {
243        let (listener_tx, listener_rx) = kanal_plus::bounded_async(UDP_CHANNEL_LEN);
244        let udp_socket = UdpSocket::bind(local_addr).await?;
245        tune_udp_socket(&udp_socket);
246        let local_addr = udp_socket.local_addr()?;
247
248        let handler = tokio::spawn(async move {
249            let mut streams: HashMap<SocketAddr, AsyncSender<Bytes>> = HashMap::new();
250            let socket = Arc::new(udp_socket);
251            let (drop_tx, drop_rx) = kanal_plus::bounded_async(10);
252            let mut drop_buf = Vec::with_capacity(10);
253
254            let mut buf = BytesMut::with_capacity(UDP_BUFFER_SIZE * 3);
255            loop {
256                if buf.capacity() < UDP_BUFFER_SIZE {
257                    buf.reserve(UDP_BUFFER_SIZE * 3);
258                }
259                buf.clear();
260                tokio::select! {
261                    result = drop_rx.drain_into_blocking(&mut drop_buf) => {
262                        match result {
263                            Ok(_) => {
264                                for peer_addr in drop_buf.drain(..) {
265                                    streams.remove(&peer_addr);
266                                }
267                            }
268                            Err(err) => {
269                                tracing::error!("UdpListener cleanup recv error: {err}");
270                                drop_buf.clear();
271                            }
272                        }
273                    }
274                    ret = socket.recv_buf_from(&mut buf) => {
275                        let (len,peer_addr) = error_get_or_continue!(ret,"UdpListener `recv_buf_from`");
276                        tracing::debug!("udp listener recv {len} bytes from {peer_addr}");
277                        match streams.get(&peer_addr) {
278                            Some(tx) => {
279                                if let Err(err) =  tx.send(buf.copy_to_bytes(len)).await{
280                                    tracing::error!("UDPListener send msg to conn, detail:{err}");
281                                    streams.remove(&peer_addr);
282                                    continue;
283                                }
284                            }
285                            None => {
286                                let (child_tx, child_rx) = kanal_plus::bounded_async(UDP_CHANNEL_LEN);
287                                // pre send msg
288                                error_get_or_continue!(
289                                    child_tx.send(buf.copy_to_bytes(len)).await,
290                                    "new conn pre send msg"
291                                );
292
293                                let udp_stream = UdpStream {
294                                    is_connect: false,
295                                    local_addr,
296                                    peer_addr,
297                                    #[cfg(feature = "udp-timeout")]
298                                    timeout: Box::pin(get_sleep()),
299                                    recv_stream: receiver_stream(child_rx.clone()),
300                                    receiver: child_rx,
301                                    socket: socket.clone(),
302                                    _handler_guard: None,
303                                    _listener_guard: Some(ListenerCleanGuard {
304                                        sender: drop_tx.clone(),
305                                        peer_addr,
306                                    }),
307                                    remaining: None,
308                                };
309                                error_get_or_continue!(
310                                    listener_tx.send((udp_stream, peer_addr)).await,
311                                    "register UDPStream"
312                                );
313                                streams.insert(peer_addr, child_tx);
314                            }
315                        }
316                    }
317                }
318            }
319        });
320        Ok(Self {
321            handler,
322            receiver: listener_rx,
323            local_addr,
324        })
325    }
326
327    /// Returns the local address that this socket is bound to.
328    pub fn local_addr(&self) -> io::Result<SocketAddr> {
329        Ok(self.local_addr)
330    }
331
332    /// Accepts a new incoming UDP connection.
333    pub async fn accept(&self) -> io::Result<(UdpStream, SocketAddr)> {
334        self.receiver
335            .recv()
336            .await
337            .map_err(|e| io::Error::new(io::ErrorKind::BrokenPipe, e))
338    }
339}
340
341#[derive(Debug)]
342struct TaskJoinHandleGuard(tokio::task::JoinHandle<()>);
343
344#[derive(Debug, Clone)]
345struct ListenerCleanGuard {
346    sender: AsyncSender<SocketAddr>,
347    peer_addr: SocketAddr,
348}
349
350impl Drop for ListenerCleanGuard {
351    fn drop(&mut self) {
352        let _ = self.sender.try_send(self.peer_addr);
353    }
354}
355
356impl Drop for TaskJoinHandleGuard {
357    fn drop(&mut self) {
358        self.0.abort();
359    }
360}
361
362/// An I/O object representing a UDP stream connected to a remote endpoint.
363///
364/// A UDP stream can either be created by connecting to an endpoint, via the
365/// [`UdpStream::connect`] method, or by [UdpListener::accept] a connection from a listener.
366pub struct UdpStream {
367    is_connect: bool,
368    local_addr: SocketAddr,
369    peer_addr: SocketAddr,
370    socket: Arc<tokio::net::UdpSocket>,
371    receiver: AsyncReceiver<Bytes>,
372    #[cfg(feature = "udp-timeout")]
373    timeout: Pin<Box<Sleep>>,
374    recv_stream: Pin<Box<ReceiveStreamOwned<Bytes>>>,
375    remaining: Option<Bytes>,
376    _handler_guard: Option<TaskJoinHandleGuard>,
377    _listener_guard: Option<ListenerCleanGuard>,
378}
379
380impl UdpStream {
381    /// Create a new UDP stream connected to the specified address.
382    ///
383    /// This function will create a new UDP socket and attempt to connect it to
384    /// the `addr` provided. The returned future will be resolved once the
385    /// stream has successfully connected, or it will return an error if one
386    /// occurs.
387    pub async fn connect<A: ToSocketAddrs>(addr: A) -> Result<Self> {
388        each_addr(addr, UdpStream::connect_inner).await
389    }
390
391    async fn connect_inner(addr: SocketAddr) -> Result<Self> {
392        let local_addr: SocketAddr = if addr.is_ipv4() {
393            SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)
394        } else {
395            SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0)
396        };
397        let socket = UdpSocket::bind(local_addr).await?;
398        tune_udp_socket(&socket);
399        socket.connect(&addr).await?;
400        Self::from_tokio(socket, true).await
401    }
402
403    /// Creates a new UdpStream from a tokio::net::UdpSocket.
404    /// This function is intended to be used to wrap a UDP socket from the tokio library.
405    /// Note: The UdpSocket must have the UdpSocket::connect method called before invoking this
406    /// function.
407    async fn from_tokio(socket: UdpSocket, is_connect: bool) -> Result<Self> {
408        tune_udp_socket(&socket);
409        let socket = Arc::new(socket);
410
411        let local_addr = socket.local_addr()?;
412        let peer_addr = socket.peer_addr()?;
413
414        let (tx, rx) = kanal_plus::bounded_async(UDP_CHANNEL_LEN);
415
416        let socket_inner = socket.clone();
417
418        let handler = tokio::spawn(async move {
419            let mut buf = BytesMut::with_capacity(UDP_BUFFER_SIZE);
420            loop {
421                if buf.capacity() < UDP_BUFFER_SIZE {
422                    buf.reserve(UDP_BUFFER_SIZE * 3);
423                }
424                buf.clear();
425                let (len, received_addr) = match socket_inner.recv_buf_from(&mut buf).await {
426                    Ok(v) => v,
427                    Err(_) => break,
428                };
429                if received_addr != peer_addr {
430                    continue;
431                }
432                if tx.send(buf.copy_to_bytes(len)).await.is_err() {
433                    drop(tx);
434                    break;
435                }
436            }
437        });
438
439        Ok(UdpStream {
440            local_addr,
441            peer_addr,
442            #[cfg(feature = "udp-timeout")]
443            timeout: Box::pin(get_sleep()),
444            recv_stream: receiver_stream(rx.clone()),
445            receiver: rx,
446            socket,
447            _handler_guard: Some(TaskJoinHandleGuard(handler)),
448            _listener_guard: None,
449            remaining: None,
450            is_connect,
451        })
452    }
453
454    /// Return peer address
455    pub fn peer_addr(&self) -> std::io::Result<SocketAddr> {
456        Ok(self.peer_addr)
457    }
458
459    /// Return local address
460    pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
461        Ok(self.local_addr)
462    }
463
464    /// Split into read side and write side to avoid borrow check, note that ownership is not
465    /// transferred
466    pub fn split(&self) -> (UdpStreamReadHalf, UdpStreamWriteHalf<'_>) {
467        (
468            UdpStreamReadHalf {
469                recv_stream: receiver_stream(self.receiver.clone()),
470                remaining: self.remaining.clone(),
471                #[cfg(feature = "udp-timeout")]
472                timeout: Box::pin(get_sleep()),
473            },
474            UdpStreamWriteHalf {
475                is_connect: self.is_connect,
476                socket: &self.socket,
477                peer_addr: self.peer_addr,
478            },
479        )
480    }
481
482    /// Split into owned read and write halves.
483    pub fn into_split(self) -> (UdpStreamOwnedReadHalf, UdpStreamOwnedWriteHalf) {
484        let guard = Arc::new(UdpStreamGuard {
485            _handler_guard: self._handler_guard,
486            _listener_guard: self._listener_guard,
487        });
488        (
489            UdpStreamOwnedReadHalf {
490                recv_stream: self.recv_stream,
491                remaining: self.remaining,
492                #[cfg(feature = "udp-timeout")]
493                timeout: self.timeout,
494                _guard: guard.clone(),
495            },
496            UdpStreamOwnedWriteHalf {
497                is_connect: self.is_connect,
498                socket: self.socket,
499                peer_addr: self.peer_addr,
500                _guard: guard,
501            },
502        )
503    }
504
505    /// Send a single UDP datagram to the connected peer.
506    pub async fn send_datagram(&self, data: &[u8]) -> io::Result<()> {
507        let sent = if self.is_connect {
508            self.socket.send(data).await?
509        } else {
510            self.socket.send_to(data, self.peer_addr).await?
511        };
512        if sent != data.len() {
513            return Err(io::Error::new(
514                io::ErrorKind::WriteZero,
515                "udp datagram truncated",
516            ));
517        }
518        Ok(())
519    }
520}
521
522impl UdpStreamReadContext for std::pin::Pin<&mut UdpStream> {
523    fn get_mut_remaining_bytes(&mut self) -> &mut Option<Bytes> {
524        &mut self.remaining
525    }
526
527    fn get_receiver_stream(&mut self) -> &mut Pin<Box<ReceiveStreamOwned<Bytes>>> {
528        &mut self.recv_stream
529    }
530
531    #[cfg(feature = "udp-timeout")]
532    fn get_timeout(&mut self) -> &mut Pin<Box<Sleep>> {
533        &mut self.timeout
534    }
535}
536
537impl UdpStreamWriteContext for std::pin::Pin<&mut UdpStream> {
538    fn get_socket(&self) -> &tokio::net::UdpSocket {
539        &self.socket
540    }
541
542    fn get_peer_addr(&self) -> &SocketAddr {
543        &self.peer_addr
544    }
545
546    fn is_connect(&self) -> bool {
547        self.is_connect
548    }
549}
550
551impl AsyncRead for UdpStream {
552    fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut ReadBuf) -> Poll<Result<()>> {
553        impl_inner::poll_read(self, cx, buf)
554    }
555}
556
557impl AsyncWrite for UdpStream {
558    fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
559        impl_inner::poll_write(self, cx, buf)
560    }
561
562    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
563        Poll::Ready(Ok(()))
564    }
565
566    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
567        Poll::Ready(Ok(()))
568    }
569}
570
571/// [`UdpStream`] read-side implementation
572pub struct UdpStreamReadHalf {
573    recv_stream: Pin<Box<ReceiveStreamOwned<Bytes>>>,
574    remaining: Option<Bytes>,
575    #[cfg(feature = "udp-timeout")]
576    timeout: Pin<Box<Sleep>>,
577}
578
579impl UdpStreamReadContext for std::pin::Pin<&mut UdpStreamReadHalf> {
580    fn get_mut_remaining_bytes(&mut self) -> &mut Option<Bytes> {
581        &mut self.remaining
582    }
583
584    fn get_receiver_stream(&mut self) -> &mut Pin<Box<ReceiveStreamOwned<Bytes>>> {
585        &mut self.recv_stream
586    }
587
588    #[cfg(feature = "udp-timeout")]
589    fn get_timeout(&mut self) -> &mut Pin<Box<Sleep>> {
590        &mut self.timeout
591    }
592}
593
594impl UdpStreamReadContext for std::pin::Pin<&mut UdpStreamOwnedReadHalf> {
595    fn get_mut_remaining_bytes(&mut self) -> &mut Option<Bytes> {
596        &mut self.remaining
597    }
598
599    fn get_receiver_stream(&mut self) -> &mut Pin<Box<ReceiveStreamOwned<Bytes>>> {
600        &mut self.recv_stream
601    }
602
603    #[cfg(feature = "udp-timeout")]
604    fn get_timeout(&mut self) -> &mut Pin<Box<Sleep>> {
605        &mut self.timeout
606    }
607}
608
609impl AsyncRead for UdpStreamReadHalf {
610    fn poll_read(
611        self: Pin<&mut Self>,
612        cx: &mut Context<'_>,
613        buf: &mut ReadBuf<'_>,
614    ) -> Poll<Result<()>> {
615        impl_inner::poll_read(self, cx, buf)
616    }
617}
618
619impl AsyncRead for UdpStreamOwnedReadHalf {
620    fn poll_read(
621        self: Pin<&mut Self>,
622        cx: &mut Context<'_>,
623        buf: &mut ReadBuf<'_>,
624    ) -> Poll<Result<()>> {
625        impl_inner::poll_read(self, cx, buf)
626    }
627}
628
629impl UdpStreamReadHalf {
630    /// Receive a single UDP datagram as an owned buffer.
631    pub async fn recv_datagram(&mut self) -> io::Result<Bytes> {
632        if self.remaining.is_some() {
633            return Err(io::Error::new(
634                io::ErrorKind::InvalidData,
635                "udp stream has buffered bytes; cannot recv datagram",
636            ));
637        }
638
639        #[cfg(feature = "udp-timeout")]
640        let result = poll_fn(|cx| {
641            if self.timeout.as_mut().poll(cx).is_ready() {
642                return Poll::Ready(Err(io::Error::new(
643                    io::ErrorKind::TimedOut,
644                    format!(
645                        "UdpStream timeout with duration:{:?}",
646                        impl_inner::get_timeout_duration()
647                    ),
648                )));
649            }
650            match self.recv_stream.as_mut().poll_next(cx) {
651                Poll::Ready(Some(msg)) => Poll::Ready(Ok(msg)),
652                Poll::Ready(None) => Poll::Ready(Err(io::Error::new(
653                    io::ErrorKind::BrokenPipe,
654                    "Broken pipe",
655                ))),
656                Poll::Pending => Poll::Pending,
657            }
658        })
659        .await;
660
661        #[cfg(not(feature = "udp-timeout"))]
662        let result = poll_fn(|cx| match self.recv_stream.as_mut().poll_next(cx) {
663            Poll::Ready(Some(msg)) => Poll::Ready(Ok(msg)),
664            Poll::Ready(None) => Poll::Ready(Err(io::Error::new(
665                io::ErrorKind::BrokenPipe,
666                "Broken pipe",
667            ))),
668            Poll::Pending => Poll::Pending,
669        })
670        .await;
671
672        #[cfg(feature = "udp-timeout")]
673        if result.is_ok() {
674            self.timeout
675                .as_mut()
676                .reset(Instant::now() + impl_inner::get_timeout_duration());
677        }
678
679        result
680    }
681}
682
683/// [`UdpStream`] owned read-side implementation.
684pub struct UdpStreamOwnedReadHalf {
685    recv_stream: Pin<Box<ReceiveStreamOwned<Bytes>>>,
686    remaining: Option<Bytes>,
687    #[cfg(feature = "udp-timeout")]
688    timeout: Pin<Box<Sleep>>,
689    _guard: Arc<UdpStreamGuard>,
690}
691
692impl UdpStreamOwnedReadHalf {
693    /// Receive a single UDP datagram as an owned buffer.
694    pub async fn recv_datagram(&mut self) -> io::Result<Bytes> {
695        if self.remaining.is_some() {
696            return Err(io::Error::new(
697                io::ErrorKind::InvalidData,
698                "udp stream has buffered bytes; cannot recv datagram",
699            ));
700        }
701
702        #[cfg(feature = "udp-timeout")]
703        let result = poll_fn(|cx| {
704            if self.timeout.as_mut().poll(cx).is_ready() {
705                return Poll::Ready(Err(io::Error::new(
706                    io::ErrorKind::TimedOut,
707                    format!(
708                        "UdpStream timeout with duration:{:?}",
709                        impl_inner::get_timeout_duration()
710                    ),
711                )));
712            }
713            match self.recv_stream.as_mut().poll_next(cx) {
714                Poll::Ready(Some(msg)) => Poll::Ready(Ok(msg)),
715                Poll::Ready(None) => Poll::Ready(Err(io::Error::new(
716                    io::ErrorKind::BrokenPipe,
717                    "Broken pipe",
718                ))),
719                Poll::Pending => Poll::Pending,
720            }
721        })
722        .await;
723
724        #[cfg(not(feature = "udp-timeout"))]
725        let result = poll_fn(|cx| match self.recv_stream.as_mut().poll_next(cx) {
726            Poll::Ready(Some(msg)) => Poll::Ready(Ok(msg)),
727            Poll::Ready(None) => Poll::Ready(Err(io::Error::new(
728                io::ErrorKind::BrokenPipe,
729                "Broken pipe",
730            ))),
731            Poll::Pending => Poll::Pending,
732        })
733        .await;
734
735        #[cfg(feature = "udp-timeout")]
736        if result.is_ok() {
737            self.timeout
738                .as_mut()
739                .reset(Instant::now() + impl_inner::get_timeout_duration());
740        }
741
742        result
743    }
744}
745
746/// [`UdpStream`] write-side implementation
747pub struct UdpStreamWriteHalf<'a> {
748    is_connect: bool,
749    socket: &'a tokio::net::UdpSocket,
750    peer_addr: SocketAddr,
751}
752
753impl UdpStreamWriteContext for std::pin::Pin<&mut UdpStreamWriteHalf<'_>> {
754    fn get_socket(&self) -> &tokio::net::UdpSocket {
755        self.socket
756    }
757
758    fn get_peer_addr(&self) -> &SocketAddr {
759        &self.peer_addr
760    }
761
762    fn is_connect(&self) -> bool {
763        self.is_connect
764    }
765}
766
767impl UdpStreamWriteContext for std::pin::Pin<&mut UdpStreamOwnedWriteHalf> {
768    fn get_socket(&self) -> &tokio::net::UdpSocket {
769        &self.socket
770    }
771
772    fn get_peer_addr(&self) -> &SocketAddr {
773        &self.peer_addr
774    }
775
776    fn is_connect(&self) -> bool {
777        self.is_connect
778    }
779}
780
781impl AsyncWrite for UdpStreamWriteHalf<'_> {
782    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
783        impl_inner::poll_write(self, cx, buf)
784    }
785
786    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
787        Poll::Ready(Ok(()))
788    }
789
790    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
791        Poll::Ready(Ok(()))
792    }
793}
794
795impl AsyncWrite for UdpStreamOwnedWriteHalf {
796    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
797        impl_inner::poll_write(self, cx, buf)
798    }
799
800    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
801        Poll::Ready(Ok(()))
802    }
803
804    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
805        Poll::Ready(Ok(()))
806    }
807}
808
809impl UdpStreamWriteHalf<'_> {
810    /// Send a single UDP datagram.
811    pub async fn send_datagram(&self, data: &[u8]) -> io::Result<()> {
812        let sent = if self.is_connect {
813            self.socket.send(data).await?
814        } else {
815            self.socket.send_to(data, self.peer_addr).await?
816        };
817        if sent != data.len() {
818            return Err(io::Error::new(
819                io::ErrorKind::WriteZero,
820                "udp datagram truncated",
821            ));
822        }
823        Ok(())
824    }
825}
826
827/// [`UdpStream`] owned write-side implementation.
828pub struct UdpStreamOwnedWriteHalf {
829    is_connect: bool,
830    socket: Arc<tokio::net::UdpSocket>,
831    peer_addr: SocketAddr,
832    _guard: Arc<UdpStreamGuard>,
833}
834
835impl UdpStreamOwnedWriteHalf {
836    /// Send a single UDP datagram.
837    pub async fn send_datagram(&self, data: &[u8]) -> io::Result<()> {
838        let sent = if self.is_connect {
839            self.socket.send(data).await?
840        } else {
841            self.socket.send_to(data, self.peer_addr).await?
842        };
843        if sent != data.len() {
844            return Err(io::Error::new(
845                io::ErrorKind::WriteZero,
846                "udp datagram truncated",
847            ));
848        }
849        Ok(())
850    }
851}
852
853#[derive(Debug)]
854struct UdpStreamGuard {
855    _handler_guard: Option<TaskJoinHandleGuard>,
856    _listener_guard: Option<ListenerCleanGuard>,
857}