1#![warn(missing_debug_implementations, rust_2018_idioms)]
2
3use async_channel::{unbounded, Receiver, Sender, TrySendError};
4use async_mutex::Mutex;
5use std::collections::HashMap;
6use std::io;
7use std::net::SocketAddr;
8use std::sync::Arc;
9use std::fmt;
10use tokio::net::{udp, ToSocketAddrs, UdpSocket};
11
/// A single datagram payload, as routed from the shared socket to one stream.
type Packet = Vec<u8>;

/// Wraps an arbitrary error as an [`io::Error`] of kind [`io::ErrorKind::Other`].
fn other<E>(err: E) -> io::Error
where
    E: std::error::Error + Send + Sync + 'static,
{
    let kind = io::ErrorKind::Other;
    io::Error::new(kind, err)
}
17
18struct Inner {
19 sender: Sender<UdpStream>,
20 rx: Mutex<udp::RecvHalf>,
21 tx: Mutex<udp::SendHalf>,
22 children: Mutex<HashMap<SocketAddr, Sender<Packet>>>,
23}
24
25impl Inner {
26 async fn send_to(&self, buf: &[u8], target: &SocketAddr) -> io::Result<usize> {
27 self.tx.lock().await.send_to(buf, target).await
28 }
29 async fn serve(self: Arc<Inner>) -> io::Result<()> {
30 let socket = &mut self.rx.lock().await;
31 loop {
32 let mut buf = vec![0u8; 65536];
33 let (size, addr) = socket.recv_from(&mut buf).await?;
34 buf.truncate(size);
35
36 let mut children = self.children.lock().await;
37 let sender = match children.get(&addr) {
38 Some(sender) => sender.clone(),
39 None => {
40 let (tx, rx) = unbounded();
41 let stream = UdpStream::new(self.clone(), addr, rx);
42 children.insert(addr, tx.clone());
43 self.sender.try_send(stream).map_err(other)?;
44 tx
45 }
46 };
47 match sender.try_send(buf) {
48 Ok(_) => {}
49 Err(TrySendError::Closed(_)) => {
50 children.remove(&addr);
51 }
52 _ => unreachable!(),
53 };
54 }
55 }
56}
57
58pub struct SendHalf {
59 inner: Arc<Inner>,
60 target: SocketAddr,
61}
62
63impl fmt::Debug for SendHalf {
64 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
65 f.debug_struct("SendHalf")
66 .field("target", &self.target)
67 .finish()
68 }
69}
70
71#[derive(Debug)]
72pub struct RecvHalf {
73 receiver: Receiver<Packet>,
74}
75
76impl SendHalf {
77 pub async fn send(&mut self, buf: &[u8]) -> io::Result<usize> {
78 self.inner.send_to(buf, &self.target).await
79 }
80}
81
82impl RecvHalf {
83 pub async fn recv(&mut self, buf: &mut [u8]) -> io::Result<usize> {
84 let p = self.receiver.recv().await.map_err(other)?;
85 let len = std::cmp::min(buf.len(), p.len());
86 buf.copy_from_slice(&p[..len]);
87 Ok(len)
88 }
89}
90
91pub struct UdpStream {
92 tx: SendHalf,
93 rx: RecvHalf,
94}
95
96impl fmt::Debug for UdpStream {
97 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
98 f.debug_struct("UdpStream")
99 .field("target", &self.tx.target)
100 .finish()
101 }
102}
103
104impl UdpStream {
105 fn new(inner: Arc<Inner>, target: SocketAddr, receiver: Receiver<Packet>) -> UdpStream {
106 UdpStream {
107 tx: SendHalf { inner, target },
108 rx: RecvHalf { receiver },
109 }
110 }
111 pub async fn send(&mut self, buf: &[u8]) -> io::Result<usize> {
112 self.tx.send(buf).await
113 }
114 pub async fn recv(&mut self, buf: &mut [u8]) -> io::Result<usize> {
115 self.rx.recv(buf).await
116 }
117 pub fn split(self) -> (RecvHalf, SendHalf) {
118 (self.rx, self.tx)
119 }
120}
121
122
123pub struct UdpListener {
124 receiver: Receiver<UdpStream>,
125}
126
127impl fmt::Debug for UdpListener {
128 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
129 f.debug_struct("UdpListener")
130 .finish()
131 }
132}
133
134impl UdpListener {
135 pub async fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<UdpListener> {
136 Self::from_tokio(UdpSocket::bind(addr).await?)
137 }
138 pub fn from_tokio(udp: UdpSocket) -> io::Result<UdpListener> {
139 let (rx, tx) = udp.split();
140 let (sender, receiver) = unbounded();
141 let inner = Arc::new(Inner {
142 sender,
143 rx: Mutex::new(rx),
144 tx: Mutex::new(tx),
145 children: Mutex::new(HashMap::new()),
146 });
147 tokio::spawn(inner.clone().serve());
148 Ok(UdpListener { receiver })
149 }
150 pub fn from_std(socket: std::net::UdpSocket) -> io::Result<UdpListener> {
151 Self::from_tokio(UdpSocket::from_std(socket)?)
152 }
153 pub async fn next(&mut self) -> io::Result<UdpStream> {
154 self.receiver.recv().await.map_err(other)
155 }
156}