// udp_stream/lib.rs

1use bytes::{Buf, Bytes, BytesMut};
2use std::{
3    collections::HashMap,
4    future::Future,
5    io,
6    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
7    pin::Pin,
8    sync::Arc,
9    task::{Context, Poll},
10};
11use tokio::{
12    io::{AsyncRead, AsyncWrite, ReadBuf},
13    net::UdpSocket,
14    sync::{mpsc, Mutex},
15};
16
/// Minimum spare capacity, in bytes, that the receive buffers are topped up to
/// before reading from the socket (17480 bytes, about 17 KiB).
///
/// A receive can only store as many bytes as the buffer has spare capacity, so
/// the excess of a longer datagram may be discarded by the OS.
const UDP_BUFFER_SIZE: usize = 17480;
/// Capacity of the bounded channels that queue accepted streams and the
/// datagrams waiting to be read by each stream.
const CHANNEL_LEN: usize = 100;
20
21/// An I/O object representing a UDP socket listening for incoming connections.
22///
23/// This object can be converted into a stream of incoming connections for
24/// various forms of processing.
25///
26/// # Examples
27///
28/// ```no_run
29/// use udp_stream::UdpListener;
30///
31/// use std::{io, net::SocketAddr, error::Error, str::FromStr};
32/// # async fn process_socket<T>(_socket: T) {}
33///
34/// #[tokio::main]
35/// async fn main() -> Result<(), Box<dyn Error>> {
36///     let mut listener = UdpListener::bind(SocketAddr::from_str("127.0.0.1:8080")?).await?;
37///
38///     loop {
39///         let (socket, _) = listener.accept().await?;
40///         process_socket(socket).await;
41///     }
42/// }
43/// ```
44pub struct UdpListener {
45    handler: tokio::task::JoinHandle<()>,
46    receiver: Arc<Mutex<mpsc::Receiver<(UdpStream, SocketAddr)>>>,
47    local_addr: SocketAddr,
48}
49
50impl Drop for UdpListener {
51    fn drop(&mut self) {
52        self.handler.abort();
53    }
54}
55
56impl UdpListener {
57    pub async fn bind(local_addr: SocketAddr) -> io::Result<Self> {
58        let (tx, rx) = mpsc::channel(CHANNEL_LEN);
59        let udp_socket = UdpSocket::bind(local_addr).await?;
60        let local_addr = udp_socket.local_addr()?;
61
62        let handler = tokio::spawn(async move {
63            let mut streams: HashMap<SocketAddr, mpsc::Sender<Bytes>> = HashMap::new();
64            let socket = Arc::new(udp_socket);
65            let (drop_tx, mut drop_rx) = mpsc::channel(1);
66
67            let mut buf = BytesMut::with_capacity(UDP_BUFFER_SIZE * 3);
68            loop {
69                if buf.capacity() < UDP_BUFFER_SIZE {
70                    buf.reserve(UDP_BUFFER_SIZE * 3);
71                }
72                tokio::select! {
73                    Some(peer_addr) = drop_rx.recv() => {
74                        streams.remove(&peer_addr);
75                    }
76                    Ok((len, peer_addr)) = socket.recv_buf_from(&mut buf) => {
77                        match streams.get_mut(&peer_addr) {
78                            Some(child_tx) => {
79                                if let Err(err) = child_tx.send(buf.copy_to_bytes(len)).await {
80                                    log::error!("child_tx.send {:?}", err);
81                                    child_tx.closed().await;
82                                    streams.remove(&peer_addr);
83                                    continue;
84                                }
85                            }
86                            None => {
87                                let (child_tx, child_rx) = mpsc::channel(CHANNEL_LEN);
88                                if let Err(err) = child_tx.send(buf.copy_to_bytes(len)).await {
89                                    log::error!("child_tx.send {:?}", err);
90                                    continue;
91                                }
92                                let udp_stream = UdpStream {
93                                    local_addr,
94                                    peer_addr,
95                                    receiver: Arc::new(Mutex::new(child_rx)),
96                                    socket: socket.clone(),
97                                    handler: None,
98                                    drop: Some(drop_tx.clone()),
99                                    remaining: None,
100                                };
101                                if let Err(err) = tx.send((udp_stream, peer_addr)).await {
102                                    log::error!("tx.send {:?}", err);
103                                    continue;
104                                }
105                                streams.insert(peer_addr, child_tx.clone());
106                            }
107                        }
108                    }
109                }
110            }
111        });
112        Ok(Self {
113            handler,
114            receiver: Arc::new(Mutex::new(rx)),
115            local_addr,
116        })
117    }
118
119    ///Returns the local address that this socket is bound to.
120    pub fn local_addr(&self) -> io::Result<SocketAddr> {
121        Ok(self.local_addr)
122    }
123
124    /// Accepts a new incoming UDP connection.
125    pub async fn accept(&self) -> io::Result<(UdpStream, SocketAddr)> {
126        self.receiver
127            .lock()
128            .await
129            .recv()
130            .await
131            .ok_or(io::Error::from(io::ErrorKind::BrokenPipe))
132    }
133}
134
135/// An I/O object representing a UDP stream connected to a remote endpoint.
136///
137/// A UDP stream can either be created by connecting to an endpoint, via the
138/// [`connect`] method, or by [accepting] a connection from a [listener].
139///
140/// [`connect`]: struct.UdpStream.html#method.connect
141/// [accepting]: struct.UdpListener.html#method.accept
142/// [listener]: struct.UdpListener.html
143#[derive(Debug)]
144pub struct UdpStream {
145    local_addr: SocketAddr,
146    peer_addr: SocketAddr,
147    receiver: Arc<Mutex<mpsc::Receiver<Bytes>>>,
148    socket: Arc<tokio::net::UdpSocket>,
149    handler: Option<tokio::task::JoinHandle<()>>,
150    drop: Option<mpsc::Sender<SocketAddr>>,
151    remaining: Option<Bytes>,
152}
153
154impl Drop for UdpStream {
155    fn drop(&mut self) {
156        if let Some(handler) = &self.handler {
157            handler.abort()
158        }
159
160        if let Some(drop) = &self.drop {
161            let _ = drop.try_send(self.peer_addr);
162        };
163    }
164}
165
166impl UdpStream {
167    /// Create a new UDP stream connected to the specified address.
168    ///
169    /// This function will create a new UDP socket and attempt to connect it to
170    /// the `addr` provided. The returned future will be resolved once the
171    /// stream has successfully connected, or it will return an error if one
172    /// occurs.
173    pub async fn connect(addr: SocketAddr) -> Result<Self, tokio::io::Error> {
174        let local_addr: SocketAddr = if addr.is_ipv4() {
175            SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)
176        } else {
177            SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0)
178        };
179        let socket = UdpSocket::bind(local_addr).await?;
180        Self::from_tokio(socket, addr).await
181    }
182    /// Creates a new UdpStream from a tokio::net::UdpSocket.
183    /// This function is intended to be used to wrap a UDP socket from the tokio library.
184    /// Note: The UdpSocket must have the UdpSocket::connect method called before invoking this function.
185    pub async fn from_tokio(
186        socket: UdpSocket,
187        peer_addr: SocketAddr,
188    ) -> Result<Self, tokio::io::Error> {
189        let socket = Arc::new(socket);
190
191        let local_addr = socket.local_addr()?;
192
193        let (child_tx, child_rx) = mpsc::channel(CHANNEL_LEN);
194
195        let socket_inner = socket.clone();
196
197        let handler = tokio::spawn(async move {
198            let mut buf = BytesMut::with_capacity(UDP_BUFFER_SIZE);
199            while let Ok((len, received_addr)) = socket_inner.clone().recv_buf_from(&mut buf).await
200            {
201                if received_addr != peer_addr {
202                    continue;
203                }
204                if child_tx.send(buf.copy_to_bytes(len)).await.is_err() {
205                    child_tx.closed().await;
206                    break;
207                }
208
209                if buf.capacity() < UDP_BUFFER_SIZE {
210                    buf.reserve(UDP_BUFFER_SIZE * 3);
211                }
212            }
213        });
214
215        Ok(UdpStream {
216            local_addr,
217            peer_addr,
218            receiver: Arc::new(Mutex::new(child_rx)),
219            socket: socket.clone(),
220            handler: Some(handler),
221            drop: None,
222            remaining: None,
223        })
224    }
225
226    pub fn peer_addr(&self) -> std::io::Result<SocketAddr> {
227        Ok(self.peer_addr)
228    }
229    pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
230        Ok(self.local_addr)
231    }
232    pub fn shutdown(&self) {
233        if let Some(drop) = &self.drop {
234            let _ = drop.try_send(self.peer_addr);
235        };
236    }
237}
238
239impl AsyncRead for UdpStream {
240    fn poll_read(
241        mut self: Pin<&mut Self>,
242        cx: &mut Context,
243        buf: &mut ReadBuf,
244    ) -> Poll<io::Result<()>> {
245        if let Some(remaining) = self.remaining.as_mut() {
246            if buf.remaining() < remaining.len() {
247                buf.put_slice(&remaining.split_to(buf.remaining())[..]);
248            } else {
249                buf.put_slice(&remaining[..]);
250                self.remaining = None;
251            }
252            return Poll::Ready(Ok(()));
253        }
254
255        let receiver = self.receiver.clone();
256        let mut socket = match Pin::new(&mut Box::pin(receiver.lock())).poll(cx) {
257            Poll::Ready(socket) => socket,
258            Poll::Pending => return Poll::Pending,
259        };
260
261        match socket.poll_recv(cx) {
262            Poll::Ready(Some(mut inner_buf)) => {
263                if buf.remaining() < inner_buf.len() {
264                    self.remaining = Some(inner_buf.split_off(buf.remaining()));
265                };
266                buf.put_slice(&inner_buf[..]);
267                Poll::Ready(Ok(()))
268            }
269            Poll::Ready(None) => Poll::Ready(Err(io::Error::from(io::ErrorKind::BrokenPipe))),
270            Poll::Pending => Poll::Pending,
271        }
272    }
273}
274
275impl AsyncWrite for UdpStream {
276    fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
277        match self.socket.poll_send_to(cx, buf, self.peer_addr) {
278            Poll::Ready(Ok(r)) => Poll::Ready(Ok(r)),
279            Poll::Ready(Err(e)) => {
280                if let Some(drop) = &self.drop {
281                    let _ = drop.try_send(self.peer_addr);
282                };
283                Poll::Ready(Err(e))
284            }
285            Poll::Pending => Poll::Pending,
286        }
287    }
288    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
289        Poll::Ready(Ok(()))
290    }
291    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
292        Poll::Ready(Ok(()))
293    }
294}