udp_listener/
lib.rs

1#![warn(missing_debug_implementations, rust_2018_idioms)]
2
3use async_channel::{unbounded, Receiver, Sender, TrySendError};
4use async_mutex::Mutex;
5use std::collections::HashMap;
6use std::io;
7use std::net::SocketAddr;
8use std::sync::Arc;
9use std::fmt;
10use tokio::net::{udp, ToSocketAddrs, UdpSocket};
11
/// Payload of one received datagram, owned as a byte vector.
type Packet = Vec<u8>;
13
/// Wraps an arbitrary error in an [`io::Error`] whose kind is
/// [`io::ErrorKind::Other`], keeping the original error as the payload.
fn other<E>(e: E) -> io::Error
where
    E: std::error::Error + Send + Sync + 'static,
{
    let kind = io::ErrorKind::Other;
    io::Error::new(kind, e)
}
17
18struct Inner {
19    sender: Sender<UdpStream>,
20    rx: Mutex<udp::RecvHalf>,
21    tx: Mutex<udp::SendHalf>,
22    children: Mutex<HashMap<SocketAddr, Sender<Packet>>>,
23}
24
25impl Inner {
26    async fn send_to(&self, buf: &[u8], target: &SocketAddr) -> io::Result<usize> {
27        self.tx.lock().await.send_to(buf, target).await
28    }
29    async fn serve(self: Arc<Inner>) -> io::Result<()> {
30        let socket = &mut self.rx.lock().await;
31        loop {
32            let mut buf = vec![0u8; 65536];
33            let (size, addr) = socket.recv_from(&mut buf).await?;
34            buf.truncate(size);
35
36            let mut children = self.children.lock().await;
37            let sender = match children.get(&addr) {
38                Some(sender) => sender.clone(),
39                None => {
40                    let (tx, rx) = unbounded();
41                    let stream = UdpStream::new(self.clone(), addr, rx);
42                    children.insert(addr, tx.clone());
43                    self.sender.try_send(stream).map_err(other)?;
44                    tx
45                }
46            };
47            match sender.try_send(buf) {
48                Ok(_) => {}
49                Err(TrySendError::Closed(_)) => {
50                    children.remove(&addr);
51                }
52                _ => unreachable!(),
53            };
54        }
55    }
56}
57
58pub struct SendHalf {
59    inner: Arc<Inner>,
60    target: SocketAddr,
61}
62
63impl fmt::Debug for SendHalf {
64    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
65        f.debug_struct("SendHalf")
66            .field("target", &self.target)
67            .finish()
68    }
69}
70
71#[derive(Debug)]
72pub struct RecvHalf {
73    receiver: Receiver<Packet>,
74}
75
76impl SendHalf {
77    pub async fn send(&mut self, buf: &[u8]) -> io::Result<usize> {
78        self.inner.send_to(buf, &self.target).await
79    }
80}
81
82impl RecvHalf {
83    pub async fn recv(&mut self, buf: &mut [u8]) -> io::Result<usize> {
84        let p = self.receiver.recv().await.map_err(other)?;
85        let len = std::cmp::min(buf.len(), p.len());
86        buf.copy_from_slice(&p[..len]);
87        Ok(len)
88    }
89}
90
91pub struct UdpStream {
92    tx: SendHalf,
93    rx: RecvHalf,
94}
95
96impl fmt::Debug for UdpStream {
97    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
98        f.debug_struct("UdpStream")
99            .field("target", &self.tx.target)
100            .finish()
101    }
102}
103
104impl UdpStream {
105    fn new(inner: Arc<Inner>, target: SocketAddr, receiver: Receiver<Packet>) -> UdpStream {
106        UdpStream {
107            tx: SendHalf { inner, target },
108            rx: RecvHalf { receiver },
109        }
110    }
111    pub async fn send(&mut self, buf: &[u8]) -> io::Result<usize> {
112        self.tx.send(buf).await
113    }
114    pub async fn recv(&mut self, buf: &mut [u8]) -> io::Result<usize> {
115        self.rx.recv(buf).await
116    }
117    pub fn split(self) -> (RecvHalf, SendHalf) {
118        (self.rx, self.tx)
119    }
120}
121
122
123pub struct UdpListener {
124    receiver: Receiver<UdpStream>,
125}
126
127impl fmt::Debug for UdpListener {
128    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
129        f.debug_struct("UdpListener")
130            .finish()
131    }
132}
133
134impl UdpListener {
135    pub async fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<UdpListener> {
136        Self::from_tokio(UdpSocket::bind(addr).await?)
137    }
138    pub fn from_tokio(udp: UdpSocket) -> io::Result<UdpListener> {
139        let (rx, tx) = udp.split();
140        let (sender, receiver) = unbounded();
141        let inner = Arc::new(Inner {
142            sender,
143            rx: Mutex::new(rx),
144            tx: Mutex::new(tx),
145            children: Mutex::new(HashMap::new()),
146        });
147        tokio::spawn(inner.clone().serve());
148        Ok(UdpListener { receiver })
149    }
150    pub fn from_std(socket: std::net::UdpSocket) -> io::Result<UdpListener> {
151        Self::from_tokio(UdpSocket::from_std(socket)?)
152    }
153    pub async fn next(&mut self) -> io::Result<UdpStream> {
154        self.receiver.recv().await.map_err(other)
155    }
156}