1use std::fmt::Debug;
4use std::future::Future;
5use std::io::{self};
6use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
7use std::pin::Pin;
8use std::sync::Arc;
9use std::task::{Context, Poll};
10
11use bytes::{Buf, Bytes, BytesMut};
12use futures::future::poll_fn;
13use futures::Stream;
14use hashbrown::HashMap;
15use kanal_plus::{AsyncReceiver, AsyncSender, ReceiveStreamOwned};
16use socket2::SockRef;
17use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
18use tokio::net::UdpSocket;
19#[cfg(feature = "udp-timeout")]
20use tokio::time::Instant;
21#[cfg(feature = "udp-timeout")]
22use tokio::time::Sleep;
23
24use self::impl_inner::{UdpStreamReadContext, UdpStreamWriteContext};
25use super::addr::{each_addr, ToSocketAddrs};
26#[cfg(feature = "udp-timeout")]
27use crate::udp::impl_inner::get_sleep;
28
/// Capacity of every bounded channel in this module: per-stream datagram queues, the
/// listener's accept queue, and the client receive queue.
const UDP_CHANNEL_LEN: usize = 100;
/// Minimum free space kept in a receive buffer before each `recv_buf_from`.
///
/// Must be at least the largest (non-jumbogram) UDP payload on either IP version:
/// 65,507 bytes over IPv4 but 65,527 bytes over IPv6. The previous value (the IPv4 limit)
/// could truncate large IPv6 datagrams, so round up to the 16-bit maximum.
const UDP_BUFFER_SIZE: usize = 65_535;
/// Requested kernel send/receive buffer size applied by `tune_udp_socket`.
const UDP_SOCKET_BUFFER_BYTES: usize = 4 * 1024 * 1024;

/// `Result` defaulting to `std::io::Error`, as used throughout this module.
type Result<T, E = std::io::Error> = std::result::Result<T, E>;
34
35fn receiver_stream<T: Send + 'static>(
36 receiver: AsyncReceiver<T>,
37) -> Pin<Box<ReceiveStreamOwned<T>>> {
38 Box::pin(receiver.into_stream())
39}
40
41#[cfg(not(target_os = "windows"))]
42pub fn tune_udp_socket(socket: &UdpSocket) {
44 let sock_ref = SockRef::from(socket);
45 if let Err(err) = sock_ref.set_recv_buffer_size(UDP_SOCKET_BUFFER_BYTES) {
46 tracing::warn!("failed to set udp recv buffer size: {err}");
47 }
48 if let Err(err) = sock_ref.set_send_buffer_size(UDP_SOCKET_BUFFER_BYTES) {
49 tracing::warn!("failed to set udp send buffer size: {err}");
50 }
51}
52
53#[cfg(target_os = "windows")]
54pub fn tune_udp_socket(socket: &UdpSocket) {
56 let sock_ref = SockRef::from(socket);
57 if let Err(err) = sock_ref.set_recv_buffer_size(UDP_SOCKET_BUFFER_BYTES) {
58 tracing::warn!("failed to set udp recv buffer size: {err}");
59 }
60 if let Err(err) = sock_ref.set_send_buffer_size(UDP_SOCKET_BUFFER_BYTES) {
61 tracing::warn!("failed to set udp send buffer size: {err}");
62 }
63}
64
/// Unwraps an `Ok` value, or logs the error at `error` level with `$context` and `continue`s
/// the enclosing loop.
macro_rules! error_get_or_continue {
    ($result:expr, $context:expr) => {
        match $result {
            Err(e) => {
                tracing::error!("{}, detail:{e}", $context);
                continue;
            }
            Ok(value) => value,
        }
    };
}
76
77mod impl_inner {
78 #[cfg(feature = "udp-timeout")]
79 use std::time::Duration;
80
81 #[cfg(feature = "udp-timeout")]
82 use futures::FutureExt;
83 #[cfg(feature = "udp-timeout")]
84 use once_cell::sync::Lazy;
85 #[cfg(feature = "udp-timeout")]
86 use tokio::time::{sleep, Instant};
87
88 use super::*;
89
90 pub(super) trait UdpStreamReadContext {
91 fn get_mut_remaining_bytes(&mut self) -> &mut Option<Bytes>;
92 fn get_receiver_stream(&mut self) -> &mut Pin<Box<ReceiveStreamOwned<Bytes>>>;
93 #[cfg(feature = "udp-timeout")]
94 fn get_timeout(&mut self) -> &mut Pin<Box<Sleep>>;
95 }
96
97 pub(super) trait UdpStreamWriteContext {
98 fn is_connect(&self) -> bool;
99 fn get_socket(&self) -> &tokio::net::UdpSocket;
100 fn get_peer_addr(&self) -> &SocketAddr;
101 }
102
103 pub(super) fn poll_read<T: UdpStreamReadContext>(
104 mut read_ctx: T,
105 cx: &mut Context,
106 buf: &mut ReadBuf,
107 ) -> Poll<Result<()>> {
108 #[cfg(feature = "udp-timeout")]
110 if read_ctx.get_timeout().poll_unpin(cx).is_ready() {
111 buf.clear();
112 return Poll::Ready(Err(io::Error::new(
113 io::ErrorKind::TimedOut,
114 format!(
115 "UdpStream timeout with duration:{:?}",
116 get_timeout_duration()
117 ),
118 )));
119 }
120
121 #[cfg(feature = "udp-timeout")]
122 #[inline]
123 fn update_timeout(timeout: &mut Pin<Box<Sleep>>) {
124 timeout
125 .as_mut()
126 .reset(Instant::now() + get_timeout_duration())
127 }
128
129 let is_consume_remaining = if let Some(remaining) = read_ctx.get_mut_remaining_bytes() {
130 if buf.remaining() < remaining.len() {
131 buf.put_slice(&remaining.split_to(buf.remaining())[..]);
132 } else {
133 buf.put_slice(&remaining[..]);
134 *read_ctx.get_mut_remaining_bytes() = None;
135 }
136 true
137 } else {
138 false
139 };
140
141 if is_consume_remaining {
142 #[cfg(feature = "udp-timeout")]
143 update_timeout(read_ctx.get_timeout());
144 return Poll::Ready(Ok(()));
145 }
146
147 let remaining = match read_ctx.get_receiver_stream().as_mut().poll_next(cx) {
148 Poll::Ready(Some(mut inner_buf)) => {
149 let remaining = if buf.remaining() < inner_buf.len() {
150 Some(inner_buf.split_off(buf.remaining()))
151 } else {
152 None
153 };
154 buf.put_slice(&inner_buf[..]);
155 remaining
156 }
157 Poll::Ready(None) => {
158 return Poll::Ready(Err(io::Error::new(
159 io::ErrorKind::BrokenPipe,
160 "Broken pipe",
161 )));
162 }
163 Poll::Pending => return Poll::Pending,
164 };
165 #[cfg(feature = "udp-timeout")]
166 update_timeout(read_ctx.get_timeout());
167 *read_ctx.get_mut_remaining_bytes() = remaining;
168 Poll::Ready(Ok(()))
169 }
170
171 pub(super) fn poll_write<T: UdpStreamWriteContext>(
172 write_ctx: T,
173 cx: &mut Context,
174 buf: &[u8],
175 ) -> Poll<Result<usize>> {
176 if write_ctx.is_connect() {
177 write_ctx.get_socket().poll_send(cx, buf)
178 } else {
179 write_ctx
180 .get_socket()
181 .poll_send_to(cx, buf, *write_ctx.get_peer_addr())
182 }
183 }
184
185 #[cfg(feature = "udp-timeout")]
186 const DEFAULT_TIMEOUT: Duration = Duration::from_secs(20);
187
188 #[cfg(feature = "udp-timeout")]
189 static mut CUSTOM_TIMEOUT: Option<Duration> = None;
190
191 #[cfg(feature = "udp-timeout")]
194 pub fn set_custom_timeout(timeout: Duration) {
195 unsafe { CUSTOM_TIMEOUT = Some(timeout) }
196 }
197
198 #[cfg(feature = "udp-timeout")]
199 static TIMEOUT: Lazy<Duration> = Lazy::new(|| match unsafe { CUSTOM_TIMEOUT } {
200 Some(dur) => dur,
201 None => DEFAULT_TIMEOUT,
202 });
203
204 #[cfg(feature = "udp-timeout")]
205 #[inline]
206 pub(super) fn get_timeout_duration() -> Duration {
207 *TIMEOUT
208 }
209
210 #[cfg(feature = "udp-timeout")]
211 #[inline]
212 pub(super) fn get_sleep() -> Sleep {
213 sleep(get_timeout_duration())
214 }
215}
216
217#[cfg(feature = "udp-timeout")]
218pub use impl_inner::set_custom_timeout;
219
220pub struct UdpListener {
225 handler: tokio::task::JoinHandle<()>,
226 receiver: AsyncReceiver<(UdpStream, SocketAddr)>,
227 local_addr: SocketAddr,
228}
229
230impl Drop for UdpListener {
231 fn drop(&mut self) {
232 self.handler.abort();
233 }
234}
235
236impl UdpListener {
237 pub async fn bind<A: ToSocketAddrs>(addr: A) -> Result<Self> {
239 each_addr(addr, UdpListener::bind_inner).await
240 }
241
242 async fn bind_inner(local_addr: SocketAddr) -> Result<Self> {
243 let (listener_tx, listener_rx) = kanal_plus::bounded_async(UDP_CHANNEL_LEN);
244 let udp_socket = UdpSocket::bind(local_addr).await?;
245 tune_udp_socket(&udp_socket);
246 let local_addr = udp_socket.local_addr()?;
247
248 let handler = tokio::spawn(async move {
249 let mut streams: HashMap<SocketAddr, AsyncSender<Bytes>> = HashMap::new();
250 let socket = Arc::new(udp_socket);
251 let (drop_tx, drop_rx) = kanal_plus::bounded_async(10);
252 let mut drop_buf = Vec::with_capacity(10);
253
254 let mut buf = BytesMut::with_capacity(UDP_BUFFER_SIZE * 3);
255 loop {
256 if buf.capacity() < UDP_BUFFER_SIZE {
257 buf.reserve(UDP_BUFFER_SIZE * 3);
258 }
259 buf.clear();
260 tokio::select! {
261 result = drop_rx.drain_into_blocking(&mut drop_buf) => {
262 match result {
263 Ok(_) => {
264 for peer_addr in drop_buf.drain(..) {
265 streams.remove(&peer_addr);
266 }
267 }
268 Err(err) => {
269 tracing::error!("UdpListener cleanup recv error: {err}");
270 drop_buf.clear();
271 }
272 }
273 }
274 ret = socket.recv_buf_from(&mut buf) => {
275 let (len,peer_addr) = error_get_or_continue!(ret,"UdpListener `recv_buf_from`");
276 tracing::debug!("udp listener recv {len} bytes from {peer_addr}");
277 match streams.get(&peer_addr) {
278 Some(tx) => {
279 if let Err(err) = tx.send(buf.copy_to_bytes(len)).await{
280 tracing::error!("UDPListener send msg to conn, detail:{err}");
281 streams.remove(&peer_addr);
282 continue;
283 }
284 }
285 None => {
286 let (child_tx, child_rx) = kanal_plus::bounded_async(UDP_CHANNEL_LEN);
287 error_get_or_continue!(
289 child_tx.send(buf.copy_to_bytes(len)).await,
290 "new conn pre send msg"
291 );
292
293 let udp_stream = UdpStream {
294 is_connect: false,
295 local_addr,
296 peer_addr,
297 #[cfg(feature = "udp-timeout")]
298 timeout: Box::pin(get_sleep()),
299 recv_stream: receiver_stream(child_rx.clone()),
300 receiver: child_rx,
301 socket: socket.clone(),
302 _handler_guard: None,
303 _listener_guard: Some(ListenerCleanGuard {
304 sender: drop_tx.clone(),
305 peer_addr,
306 }),
307 remaining: None,
308 };
309 error_get_or_continue!(
310 listener_tx.send((udp_stream, peer_addr)).await,
311 "register UDPStream"
312 );
313 streams.insert(peer_addr, child_tx);
314 }
315 }
316 }
317 }
318 }
319 });
320 Ok(Self {
321 handler,
322 receiver: listener_rx,
323 local_addr,
324 })
325 }
326
327 pub fn local_addr(&self) -> io::Result<SocketAddr> {
329 Ok(self.local_addr)
330 }
331
332 pub async fn accept(&self) -> io::Result<(UdpStream, SocketAddr)> {
334 self.receiver
335 .recv()
336 .await
337 .map_err(|e| io::Error::new(io::ErrorKind::BrokenPipe, e))
338 }
339}
340
341#[derive(Debug)]
342struct TaskJoinHandleGuard(tokio::task::JoinHandle<()>);
343
344#[derive(Debug, Clone)]
345struct ListenerCleanGuard {
346 sender: AsyncSender<SocketAddr>,
347 peer_addr: SocketAddr,
348}
349
350impl Drop for ListenerCleanGuard {
351 fn drop(&mut self) {
352 let _ = self.sender.try_send(self.peer_addr);
353 }
354}
355
356impl Drop for TaskJoinHandleGuard {
357 fn drop(&mut self) {
358 self.0.abort();
359 }
360}
361
362pub struct UdpStream {
367 is_connect: bool,
368 local_addr: SocketAddr,
369 peer_addr: SocketAddr,
370 socket: Arc<tokio::net::UdpSocket>,
371 receiver: AsyncReceiver<Bytes>,
372 #[cfg(feature = "udp-timeout")]
373 timeout: Pin<Box<Sleep>>,
374 recv_stream: Pin<Box<ReceiveStreamOwned<Bytes>>>,
375 remaining: Option<Bytes>,
376 _handler_guard: Option<TaskJoinHandleGuard>,
377 _listener_guard: Option<ListenerCleanGuard>,
378}
379
380impl UdpStream {
381 pub async fn connect<A: ToSocketAddrs>(addr: A) -> Result<Self> {
388 each_addr(addr, UdpStream::connect_inner).await
389 }
390
391 async fn connect_inner(addr: SocketAddr) -> Result<Self> {
392 let local_addr: SocketAddr = if addr.is_ipv4() {
393 SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)
394 } else {
395 SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0)
396 };
397 let socket = UdpSocket::bind(local_addr).await?;
398 tune_udp_socket(&socket);
399 socket.connect(&addr).await?;
400 Self::from_tokio(socket, true).await
401 }
402
403 async fn from_tokio(socket: UdpSocket, is_connect: bool) -> Result<Self> {
408 tune_udp_socket(&socket);
409 let socket = Arc::new(socket);
410
411 let local_addr = socket.local_addr()?;
412 let peer_addr = socket.peer_addr()?;
413
414 let (tx, rx) = kanal_plus::bounded_async(UDP_CHANNEL_LEN);
415
416 let socket_inner = socket.clone();
417
418 let handler = tokio::spawn(async move {
419 let mut buf = BytesMut::with_capacity(UDP_BUFFER_SIZE);
420 loop {
421 if buf.capacity() < UDP_BUFFER_SIZE {
422 buf.reserve(UDP_BUFFER_SIZE * 3);
423 }
424 buf.clear();
425 let (len, received_addr) = match socket_inner.recv_buf_from(&mut buf).await {
426 Ok(v) => v,
427 Err(_) => break,
428 };
429 if received_addr != peer_addr {
430 continue;
431 }
432 if tx.send(buf.copy_to_bytes(len)).await.is_err() {
433 drop(tx);
434 break;
435 }
436 }
437 });
438
439 Ok(UdpStream {
440 local_addr,
441 peer_addr,
442 #[cfg(feature = "udp-timeout")]
443 timeout: Box::pin(get_sleep()),
444 recv_stream: receiver_stream(rx.clone()),
445 receiver: rx,
446 socket,
447 _handler_guard: Some(TaskJoinHandleGuard(handler)),
448 _listener_guard: None,
449 remaining: None,
450 is_connect,
451 })
452 }
453
454 pub fn peer_addr(&self) -> std::io::Result<SocketAddr> {
456 Ok(self.peer_addr)
457 }
458
459 pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
461 Ok(self.local_addr)
462 }
463
464 pub fn split(&self) -> (UdpStreamReadHalf, UdpStreamWriteHalf<'_>) {
467 (
468 UdpStreamReadHalf {
469 recv_stream: receiver_stream(self.receiver.clone()),
470 remaining: self.remaining.clone(),
471 #[cfg(feature = "udp-timeout")]
472 timeout: Box::pin(get_sleep()),
473 },
474 UdpStreamWriteHalf {
475 is_connect: self.is_connect,
476 socket: &self.socket,
477 peer_addr: self.peer_addr,
478 },
479 )
480 }
481
482 pub fn into_split(self) -> (UdpStreamOwnedReadHalf, UdpStreamOwnedWriteHalf) {
484 let guard = Arc::new(UdpStreamGuard {
485 _handler_guard: self._handler_guard,
486 _listener_guard: self._listener_guard,
487 });
488 (
489 UdpStreamOwnedReadHalf {
490 recv_stream: self.recv_stream,
491 remaining: self.remaining,
492 #[cfg(feature = "udp-timeout")]
493 timeout: self.timeout,
494 _guard: guard.clone(),
495 },
496 UdpStreamOwnedWriteHalf {
497 is_connect: self.is_connect,
498 socket: self.socket,
499 peer_addr: self.peer_addr,
500 _guard: guard,
501 },
502 )
503 }
504
505 pub async fn send_datagram(&self, data: &[u8]) -> io::Result<()> {
507 let sent = if self.is_connect {
508 self.socket.send(data).await?
509 } else {
510 self.socket.send_to(data, self.peer_addr).await?
511 };
512 if sent != data.len() {
513 return Err(io::Error::new(
514 io::ErrorKind::WriteZero,
515 "udp datagram truncated",
516 ));
517 }
518 Ok(())
519 }
520}
521
522impl UdpStreamReadContext for std::pin::Pin<&mut UdpStream> {
523 fn get_mut_remaining_bytes(&mut self) -> &mut Option<Bytes> {
524 &mut self.remaining
525 }
526
527 fn get_receiver_stream(&mut self) -> &mut Pin<Box<ReceiveStreamOwned<Bytes>>> {
528 &mut self.recv_stream
529 }
530
531 #[cfg(feature = "udp-timeout")]
532 fn get_timeout(&mut self) -> &mut Pin<Box<Sleep>> {
533 &mut self.timeout
534 }
535}
536
537impl UdpStreamWriteContext for std::pin::Pin<&mut UdpStream> {
538 fn get_socket(&self) -> &tokio::net::UdpSocket {
539 &self.socket
540 }
541
542 fn get_peer_addr(&self) -> &SocketAddr {
543 &self.peer_addr
544 }
545
546 fn is_connect(&self) -> bool {
547 self.is_connect
548 }
549}
550
551impl AsyncRead for UdpStream {
552 fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut ReadBuf) -> Poll<Result<()>> {
553 impl_inner::poll_read(self, cx, buf)
554 }
555}
556
557impl AsyncWrite for UdpStream {
558 fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
559 impl_inner::poll_write(self, cx, buf)
560 }
561
562 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
563 Poll::Ready(Ok(()))
564 }
565
566 fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
567 Poll::Ready(Ok(()))
568 }
569}
570
571pub struct UdpStreamReadHalf {
573 recv_stream: Pin<Box<ReceiveStreamOwned<Bytes>>>,
574 remaining: Option<Bytes>,
575 #[cfg(feature = "udp-timeout")]
576 timeout: Pin<Box<Sleep>>,
577}
578
579impl UdpStreamReadContext for std::pin::Pin<&mut UdpStreamReadHalf> {
580 fn get_mut_remaining_bytes(&mut self) -> &mut Option<Bytes> {
581 &mut self.remaining
582 }
583
584 fn get_receiver_stream(&mut self) -> &mut Pin<Box<ReceiveStreamOwned<Bytes>>> {
585 &mut self.recv_stream
586 }
587
588 #[cfg(feature = "udp-timeout")]
589 fn get_timeout(&mut self) -> &mut Pin<Box<Sleep>> {
590 &mut self.timeout
591 }
592}
593
594impl UdpStreamReadContext for std::pin::Pin<&mut UdpStreamOwnedReadHalf> {
595 fn get_mut_remaining_bytes(&mut self) -> &mut Option<Bytes> {
596 &mut self.remaining
597 }
598
599 fn get_receiver_stream(&mut self) -> &mut Pin<Box<ReceiveStreamOwned<Bytes>>> {
600 &mut self.recv_stream
601 }
602
603 #[cfg(feature = "udp-timeout")]
604 fn get_timeout(&mut self) -> &mut Pin<Box<Sleep>> {
605 &mut self.timeout
606 }
607}
608
609impl AsyncRead for UdpStreamReadHalf {
610 fn poll_read(
611 self: Pin<&mut Self>,
612 cx: &mut Context<'_>,
613 buf: &mut ReadBuf<'_>,
614 ) -> Poll<Result<()>> {
615 impl_inner::poll_read(self, cx, buf)
616 }
617}
618
619impl AsyncRead for UdpStreamOwnedReadHalf {
620 fn poll_read(
621 self: Pin<&mut Self>,
622 cx: &mut Context<'_>,
623 buf: &mut ReadBuf<'_>,
624 ) -> Poll<Result<()>> {
625 impl_inner::poll_read(self, cx, buf)
626 }
627}
628
629impl UdpStreamReadHalf {
630 pub async fn recv_datagram(&mut self) -> io::Result<Bytes> {
632 if self.remaining.is_some() {
633 return Err(io::Error::new(
634 io::ErrorKind::InvalidData,
635 "udp stream has buffered bytes; cannot recv datagram",
636 ));
637 }
638
639 #[cfg(feature = "udp-timeout")]
640 let result = poll_fn(|cx| {
641 if self.timeout.as_mut().poll(cx).is_ready() {
642 return Poll::Ready(Err(io::Error::new(
643 io::ErrorKind::TimedOut,
644 format!(
645 "UdpStream timeout with duration:{:?}",
646 impl_inner::get_timeout_duration()
647 ),
648 )));
649 }
650 match self.recv_stream.as_mut().poll_next(cx) {
651 Poll::Ready(Some(msg)) => Poll::Ready(Ok(msg)),
652 Poll::Ready(None) => Poll::Ready(Err(io::Error::new(
653 io::ErrorKind::BrokenPipe,
654 "Broken pipe",
655 ))),
656 Poll::Pending => Poll::Pending,
657 }
658 })
659 .await;
660
661 #[cfg(not(feature = "udp-timeout"))]
662 let result = poll_fn(|cx| match self.recv_stream.as_mut().poll_next(cx) {
663 Poll::Ready(Some(msg)) => Poll::Ready(Ok(msg)),
664 Poll::Ready(None) => Poll::Ready(Err(io::Error::new(
665 io::ErrorKind::BrokenPipe,
666 "Broken pipe",
667 ))),
668 Poll::Pending => Poll::Pending,
669 })
670 .await;
671
672 #[cfg(feature = "udp-timeout")]
673 if result.is_ok() {
674 self.timeout
675 .as_mut()
676 .reset(Instant::now() + impl_inner::get_timeout_duration());
677 }
678
679 result
680 }
681}
682
683pub struct UdpStreamOwnedReadHalf {
685 recv_stream: Pin<Box<ReceiveStreamOwned<Bytes>>>,
686 remaining: Option<Bytes>,
687 #[cfg(feature = "udp-timeout")]
688 timeout: Pin<Box<Sleep>>,
689 _guard: Arc<UdpStreamGuard>,
690}
691
692impl UdpStreamOwnedReadHalf {
693 pub async fn recv_datagram(&mut self) -> io::Result<Bytes> {
695 if self.remaining.is_some() {
696 return Err(io::Error::new(
697 io::ErrorKind::InvalidData,
698 "udp stream has buffered bytes; cannot recv datagram",
699 ));
700 }
701
702 #[cfg(feature = "udp-timeout")]
703 let result = poll_fn(|cx| {
704 if self.timeout.as_mut().poll(cx).is_ready() {
705 return Poll::Ready(Err(io::Error::new(
706 io::ErrorKind::TimedOut,
707 format!(
708 "UdpStream timeout with duration:{:?}",
709 impl_inner::get_timeout_duration()
710 ),
711 )));
712 }
713 match self.recv_stream.as_mut().poll_next(cx) {
714 Poll::Ready(Some(msg)) => Poll::Ready(Ok(msg)),
715 Poll::Ready(None) => Poll::Ready(Err(io::Error::new(
716 io::ErrorKind::BrokenPipe,
717 "Broken pipe",
718 ))),
719 Poll::Pending => Poll::Pending,
720 }
721 })
722 .await;
723
724 #[cfg(not(feature = "udp-timeout"))]
725 let result = poll_fn(|cx| match self.recv_stream.as_mut().poll_next(cx) {
726 Poll::Ready(Some(msg)) => Poll::Ready(Ok(msg)),
727 Poll::Ready(None) => Poll::Ready(Err(io::Error::new(
728 io::ErrorKind::BrokenPipe,
729 "Broken pipe",
730 ))),
731 Poll::Pending => Poll::Pending,
732 })
733 .await;
734
735 #[cfg(feature = "udp-timeout")]
736 if result.is_ok() {
737 self.timeout
738 .as_mut()
739 .reset(Instant::now() + impl_inner::get_timeout_duration());
740 }
741
742 result
743 }
744}
745
746pub struct UdpStreamWriteHalf<'a> {
748 is_connect: bool,
749 socket: &'a tokio::net::UdpSocket,
750 peer_addr: SocketAddr,
751}
752
753impl UdpStreamWriteContext for std::pin::Pin<&mut UdpStreamWriteHalf<'_>> {
754 fn get_socket(&self) -> &tokio::net::UdpSocket {
755 self.socket
756 }
757
758 fn get_peer_addr(&self) -> &SocketAddr {
759 &self.peer_addr
760 }
761
762 fn is_connect(&self) -> bool {
763 self.is_connect
764 }
765}
766
767impl UdpStreamWriteContext for std::pin::Pin<&mut UdpStreamOwnedWriteHalf> {
768 fn get_socket(&self) -> &tokio::net::UdpSocket {
769 &self.socket
770 }
771
772 fn get_peer_addr(&self) -> &SocketAddr {
773 &self.peer_addr
774 }
775
776 fn is_connect(&self) -> bool {
777 self.is_connect
778 }
779}
780
781impl AsyncWrite for UdpStreamWriteHalf<'_> {
782 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
783 impl_inner::poll_write(self, cx, buf)
784 }
785
786 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
787 Poll::Ready(Ok(()))
788 }
789
790 fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
791 Poll::Ready(Ok(()))
792 }
793}
794
795impl AsyncWrite for UdpStreamOwnedWriteHalf {
796 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
797 impl_inner::poll_write(self, cx, buf)
798 }
799
800 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
801 Poll::Ready(Ok(()))
802 }
803
804 fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
805 Poll::Ready(Ok(()))
806 }
807}
808
809impl UdpStreamWriteHalf<'_> {
810 pub async fn send_datagram(&self, data: &[u8]) -> io::Result<()> {
812 let sent = if self.is_connect {
813 self.socket.send(data).await?
814 } else {
815 self.socket.send_to(data, self.peer_addr).await?
816 };
817 if sent != data.len() {
818 return Err(io::Error::new(
819 io::ErrorKind::WriteZero,
820 "udp datagram truncated",
821 ));
822 }
823 Ok(())
824 }
825}
826
827pub struct UdpStreamOwnedWriteHalf {
829 is_connect: bool,
830 socket: Arc<tokio::net::UdpSocket>,
831 peer_addr: SocketAddr,
832 _guard: Arc<UdpStreamGuard>,
833}
834
835impl UdpStreamOwnedWriteHalf {
836 pub async fn send_datagram(&self, data: &[u8]) -> io::Result<()> {
838 let sent = if self.is_connect {
839 self.socket.send(data).await?
840 } else {
841 self.socket.send_to(data, self.peer_addr).await?
842 };
843 if sent != data.len() {
844 return Err(io::Error::new(
845 io::ErrorKind::WriteZero,
846 "udp datagram truncated",
847 ));
848 }
849 Ok(())
850 }
851}
852
853#[derive(Debug)]
854struct UdpStreamGuard {
855 _handler_guard: Option<TaskJoinHandleGuard>,
856 _listener_guard: Option<ListenerCleanGuard>,
857}