1use bytes::{Buf, Bytes, BytesMut};
2use std::{
3 collections::HashMap,
4 future::Future,
5 io,
6 net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
7 pin::Pin,
8 sync::Arc,
9 task::{Context, Poll},
10};
11use tokio::{
12 io::{AsyncRead, AsyncWrite, ReadBuf},
13 net::UdpSocket,
14 sync::{mpsc, Mutex},
15};
16
/// Minimum free space kept in a receive buffer before each `recv_buf_from` call;
/// when less is available the buffer is grown by three times this amount.
const UDP_BUFFER_SIZE: usize = 17_480;
/// Capacity of the bounded `mpsc` channels used for queued streams and datagrams.
const CHANNEL_LEN: usize = 100;
20
/// A UDP "listener" that demultiplexes datagrams by source address into
/// per-peer [`UdpStream`]s.
///
/// A background task (started by `bind`) owns the socket; the first datagram
/// from a new peer produces a stream that is handed out through `accept`.
pub struct UdpListener {
    /// Background routing task; aborted when the listener is dropped.
    handler: tokio::task::JoinHandle<()>,
    /// Queue of newly created streams (with their peer address) awaiting `accept`.
    receiver: Arc<Mutex<mpsc::Receiver<(UdpStream, SocketAddr)>>>,
    /// Address the socket is actually bound to, as reported after binding.
    local_addr: SocketAddr,
}
49
50impl Drop for UdpListener {
51 fn drop(&mut self) {
52 self.handler.abort();
53 }
54}
55
56impl UdpListener {
57 pub async fn bind(local_addr: SocketAddr) -> io::Result<Self> {
58 let (tx, rx) = mpsc::channel(CHANNEL_LEN);
59 let udp_socket = UdpSocket::bind(local_addr).await?;
60 let local_addr = udp_socket.local_addr()?;
61
62 let handler = tokio::spawn(async move {
63 let mut streams: HashMap<SocketAddr, mpsc::Sender<Bytes>> = HashMap::new();
64 let socket = Arc::new(udp_socket);
65 let (drop_tx, mut drop_rx) = mpsc::channel(1);
66
67 let mut buf = BytesMut::with_capacity(UDP_BUFFER_SIZE * 3);
68 loop {
69 if buf.capacity() < UDP_BUFFER_SIZE {
70 buf.reserve(UDP_BUFFER_SIZE * 3);
71 }
72 tokio::select! {
73 Some(peer_addr) = drop_rx.recv() => {
74 streams.remove(&peer_addr);
75 }
76 Ok((len, peer_addr)) = socket.recv_buf_from(&mut buf) => {
77 match streams.get_mut(&peer_addr) {
78 Some(child_tx) => {
79 if let Err(err) = child_tx.send(buf.copy_to_bytes(len)).await {
80 log::error!("child_tx.send {:?}", err);
81 child_tx.closed().await;
82 streams.remove(&peer_addr);
83 continue;
84 }
85 }
86 None => {
87 let (child_tx, child_rx) = mpsc::channel(CHANNEL_LEN);
88 if let Err(err) = child_tx.send(buf.copy_to_bytes(len)).await {
89 log::error!("child_tx.send {:?}", err);
90 continue;
91 }
92 let udp_stream = UdpStream {
93 local_addr,
94 peer_addr,
95 receiver: Arc::new(Mutex::new(child_rx)),
96 socket: socket.clone(),
97 handler: None,
98 drop: Some(drop_tx.clone()),
99 remaining: None,
100 };
101 if let Err(err) = tx.send((udp_stream, peer_addr)).await {
102 log::error!("tx.send {:?}", err);
103 continue;
104 }
105 streams.insert(peer_addr, child_tx.clone());
106 }
107 }
108 }
109 }
110 }
111 });
112 Ok(Self {
113 handler,
114 receiver: Arc::new(Mutex::new(rx)),
115 local_addr,
116 })
117 }
118
119 pub fn local_addr(&self) -> io::Result<SocketAddr> {
121 Ok(self.local_addr)
122 }
123
124 pub async fn accept(&self) -> io::Result<(UdpStream, SocketAddr)> {
126 self.receiver
127 .lock()
128 .await
129 .recv()
130 .await
131 .ok_or(io::Error::from(io::ErrorKind::BrokenPipe))
132 }
133}
134
/// One peer's view of a UDP socket, exposed through `AsyncRead`/`AsyncWrite`.
///
/// Produced either by `UdpListener::accept` (datagrams routed by the listener's
/// task) or by `UdpStream::connect`/`UdpStream::from_tokio` (own receive task).
#[derive(Debug)]
pub struct UdpStream {
    /// Local address of the underlying socket.
    local_addr: SocketAddr,
    /// Remote address this stream sends to and receives from.
    peer_addr: SocketAddr,
    /// Queue of received datagram payloads for this peer.
    receiver: Arc<Mutex<mpsc::Receiver<Bytes>>>,
    /// Socket used for sending; shared with the listener or the receive task.
    socket: Arc<tokio::net::UdpSocket>,
    /// Receive task owned by this stream (set by `from_tokio` only); aborted on drop.
    handler: Option<tokio::task::JoinHandle<()>>,
    /// Channel used to tell the listener to forget this peer (listener streams only).
    drop: Option<mpsc::Sender<SocketAddr>>,
    /// Unread tail of a datagram that did not fit into the caller's read buffer.
    remaining: Option<Bytes>,
}
153
154impl Drop for UdpStream {
155 fn drop(&mut self) {
156 if let Some(handler) = &self.handler {
157 handler.abort()
158 }
159
160 if let Some(drop) = &self.drop {
161 let _ = drop.try_send(self.peer_addr);
162 };
163 }
164}
165
166impl UdpStream {
167 pub async fn connect(addr: SocketAddr) -> Result<Self, tokio::io::Error> {
174 let local_addr: SocketAddr = if addr.is_ipv4() {
175 SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)
176 } else {
177 SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0)
178 };
179 let socket = UdpSocket::bind(local_addr).await?;
180 Self::from_tokio(socket, addr).await
181 }
182 pub async fn from_tokio(
186 socket: UdpSocket,
187 peer_addr: SocketAddr,
188 ) -> Result<Self, tokio::io::Error> {
189 let socket = Arc::new(socket);
190
191 let local_addr = socket.local_addr()?;
192
193 let (child_tx, child_rx) = mpsc::channel(CHANNEL_LEN);
194
195 let socket_inner = socket.clone();
196
197 let handler = tokio::spawn(async move {
198 let mut buf = BytesMut::with_capacity(UDP_BUFFER_SIZE);
199 while let Ok((len, received_addr)) = socket_inner.clone().recv_buf_from(&mut buf).await
200 {
201 if received_addr != peer_addr {
202 continue;
203 }
204 if child_tx.send(buf.copy_to_bytes(len)).await.is_err() {
205 child_tx.closed().await;
206 break;
207 }
208
209 if buf.capacity() < UDP_BUFFER_SIZE {
210 buf.reserve(UDP_BUFFER_SIZE * 3);
211 }
212 }
213 });
214
215 Ok(UdpStream {
216 local_addr,
217 peer_addr,
218 receiver: Arc::new(Mutex::new(child_rx)),
219 socket: socket.clone(),
220 handler: Some(handler),
221 drop: None,
222 remaining: None,
223 })
224 }
225
226 pub fn peer_addr(&self) -> std::io::Result<SocketAddr> {
227 Ok(self.peer_addr)
228 }
229 pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
230 Ok(self.local_addr)
231 }
232 pub fn shutdown(&self) {
233 if let Some(drop) = &self.drop {
234 let _ = drop.try_send(self.peer_addr);
235 };
236 }
237}
238
239impl AsyncRead for UdpStream {
240 fn poll_read(
241 mut self: Pin<&mut Self>,
242 cx: &mut Context,
243 buf: &mut ReadBuf,
244 ) -> Poll<io::Result<()>> {
245 if let Some(remaining) = self.remaining.as_mut() {
246 if buf.remaining() < remaining.len() {
247 buf.put_slice(&remaining.split_to(buf.remaining())[..]);
248 } else {
249 buf.put_slice(&remaining[..]);
250 self.remaining = None;
251 }
252 return Poll::Ready(Ok(()));
253 }
254
255 let receiver = self.receiver.clone();
256 let mut socket = match Pin::new(&mut Box::pin(receiver.lock())).poll(cx) {
257 Poll::Ready(socket) => socket,
258 Poll::Pending => return Poll::Pending,
259 };
260
261 match socket.poll_recv(cx) {
262 Poll::Ready(Some(mut inner_buf)) => {
263 if buf.remaining() < inner_buf.len() {
264 self.remaining = Some(inner_buf.split_off(buf.remaining()));
265 };
266 buf.put_slice(&inner_buf[..]);
267 Poll::Ready(Ok(()))
268 }
269 Poll::Ready(None) => Poll::Ready(Err(io::Error::from(io::ErrorKind::BrokenPipe))),
270 Poll::Pending => Poll::Pending,
271 }
272 }
273}
274
275impl AsyncWrite for UdpStream {
276 fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
277 match self.socket.poll_send_to(cx, buf, self.peer_addr) {
278 Poll::Ready(Ok(r)) => Poll::Ready(Ok(r)),
279 Poll::Ready(Err(e)) => {
280 if let Some(drop) = &self.drop {
281 let _ = drop.try_send(self.peer_addr);
282 };
283 Poll::Ready(Err(e))
284 }
285 Poll::Pending => Poll::Pending,
286 }
287 }
288 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
289 Poll::Ready(Ok(()))
290 }
291 fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
292 Poll::Ready(Ok(()))
293 }
294}