1use std::{
2 fmt,
3 future::Future,
4 io,
5 net::SocketAddr,
6 pin::Pin,
7 sync::Arc,
8 task::{Context, Poll}
9};
10
11use futures_core::Stream;
12use tokio::net::UdpSocket;
13
/// Boxed, pinned receive future driven by [`UDPSocketStream`].
///
/// Resolves to the receive buffer, the number of bytes written into it, and
/// the address the datagram came from.
pub(crate) type RecvFuture = Pin<
    Box<dyn Future<Output = io::Result<(Vec<u8>, usize, SocketAddr)>> + Send + Sync>,
>;
15
16pub(crate) struct UDPSocketStream {
17 pub(crate) socket: Arc<UdpSocket>,
18 future: Option<RecvFuture>,
19 buf: Option<Vec<u8>>
20}
21
22unsafe impl Send for UDPSocketStream {}
23
24impl Clone for UDPSocketStream {
25 fn clone(&self) -> Self {
26 Self::from_arc(self.socket.clone())
27 }
28}
29
30impl fmt::Debug for UDPSocketStream {
31 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
32 f.debug_struct("UdpSocketStream").field("socket", &*self.socket).finish()
33 }
34}
35
36impl UDPSocketStream {
37 pub fn new(socket: UdpSocket) -> Self {
38 let socket = Arc::new(socket);
39 Self::from_arc(socket)
40 }
41
42 pub fn from_arc(socket: Arc<UdpSocket>) -> Self {
43 let buf = vec![0u8; 1024 * 64];
44 Self { socket, future: None, buf: Some(buf) }
45 }
46
47 pub fn get_ref(&self) -> &UdpSocket {
48 &self.socket
49 }
50
51 pub fn clone_inner(&self) -> Arc<UdpSocket> {
52 Arc::clone(&self.socket)
53 }
54}
55
56impl Stream for UDPSocketStream {
57 type Item = io::Result<(Vec<u8>, SocketAddr)>;
58
59 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
60 loop {
61 if self.future.is_none() {
62 let buf = self.buf.take().unwrap();
63 let future = recv_next(Arc::clone(&self.socket), buf);
64 self.future = Some(Box::pin(future));
65 }
66
67 if let Some(f) = &mut self.future {
68 let res = match f.as_mut().poll(cx) {
69 Poll::Ready(t) => t,
70 Poll::Pending => return Poll::Pending
71 };
72 self.future = None;
73 return match res {
74 Err(e) => Poll::Ready(Some(Err(e))),
75 Ok((buf, n, addr)) => {
76 let res_buf = buf[..n].to_vec();
77 self.buf = Some(buf);
78 Poll::Ready(Some(Ok((res_buf, addr))))
79 }
80 };
81 }
82 }
83 }
84}
85
86async fn recv_next(socket: Arc<UdpSocket>, mut buf: Vec<u8>) -> io::Result<(Vec<u8>, usize, SocketAddr)> {
87 let res = socket.recv_from(&mut buf).await;
88 match res {
89 Err(e) => Err(e),
90 Ok((n, addr)) => Ok((buf, n, addr))
91 }
92}