//! Connection-style streams on top of a single UDP socket.
//!
//! Incoming datagrams are demultiplexed by peer address: each new peer yields a
//! `UdpStream` from `UdpListener::next`, and later packets from that peer are
//! queued on that stream.
use tokio::net::{ToSocketAddrs, UdpSocket, udp::{SendHalf, RecvHalf}};
use std::io;
use std::sync::{Arc};
use std::net::SocketAddr;
use std::collections::HashMap;
use async_mutex::Mutex;
use async_channel::{unbounded, Sender, Receiver, TrySendError};

/// Payload of one received datagram, already truncated to the number of bytes read.
type Packet = Vec<u8>;

/// Wraps an arbitrary error in an [`io::Error`] whose kind is `ErrorKind::Other`.
///
/// The original error is kept as the inner payload, so its `Display` output
/// and source remain reachable through the returned `io::Error`.
fn other<E>(e: E) -> io::Error
where
    E: std::error::Error + Send + Sync + 'static,
{
    let kind = io::ErrorKind::Other;
    io::Error::new(kind, e)
}

/// State shared (behind an `Arc`) between the background receive task and
/// every `UdpStream` created from it.
struct Inner {
    /// Channel on which newly created `UdpStream`s are handed to the listener.
    sender: Sender<UdpStream>,
    /// Receive half of the split socket; locked by `serve` while it reads datagrams.
    rx: Mutex<RecvHalf>,
    /// Send half of the split socket; locked for the duration of each `send_to`.
    tx: Mutex<SendHalf>,
    /// Packet queue sender for each known peer, keyed by the peer's address.
    children: Mutex<HashMap<SocketAddr, Sender<Packet>>>,
}

impl Inner {
    async fn send_to(&self, buf: &[u8], target: &SocketAddr) -> io::Result<usize> {
        self.tx.lock().await.send_to(buf, target).await
    }
    async fn serve(self: Arc<Inner>) -> io::Result<()> {
        let socket = &mut self.rx.lock().await;
        loop {
            let mut buf = vec![0u8; 65536];
            let (size, addr) = socket.recv_from(&mut buf).await?;
            buf.truncate(size);

            let mut children = self.children.lock().await;
            let sender = match children.get(&addr) {
                Some(sender) => {
                    sender.clone()
                }
                None => {
                    let (tx, rx) = unbounded();
                    let stream = UdpStream {
                        receiver: rx,
                        inner: self.clone(),
                        target: addr,
                    };
                    children.insert(addr, tx.clone());
                    self.sender.try_send(stream).map_err(other)?;
                    tx
                }
            };
            match sender.try_send(buf) {
                Ok(_) => {}
                Err(TrySendError::Closed(_)) => {
                    children.remove(&addr);
                }
                _ => unreachable!()
            };
        }
    }
}

/// One peer's view of the shared UDP socket: receives only the datagrams that
/// came from `target` and sends datagrams to `target`.
pub struct UdpStream {
    /// Queue of packets received from `target`.
    receiver: Receiver<Packet>,
    /// Shared socket state, used for sending.
    inner: Arc<Inner>,
    /// Address of the remote peer this stream talks to.
    target: SocketAddr,
}

impl UdpStream {
    /// Sends `buf` as a single datagram to this stream's peer.
    ///
    /// Returns the number of bytes reported as sent by the underlying socket.
    ///
    /// # Errors
    ///
    /// Propagates any I/O error from the underlying socket send.
    pub async fn send(&mut self, buf: &[u8]) -> io::Result<usize> {
        self.inner.send_to(buf, &self.target).await
    }

    /// Waits for the next packet from this stream's peer and copies it into `buf`.
    ///
    /// Returns the number of bytes copied, which is the smaller of `buf.len()`
    /// and the packet length. If the packet is longer than `buf`, the excess
    /// bytes are discarded.
    ///
    /// # Errors
    ///
    /// Returns an `ErrorKind::Other` error if the packet channel feeding this
    /// stream has been closed.
    pub async fn recv(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let p = self.receiver.recv().await.map_err(other)?;
        let len = std::cmp::min(buf.len(), p.len());
        // `copy_from_slice` panics unless both slices have equal length, so the
        // destination must be narrowed to `len` as well as the source.
        buf[..len].copy_from_slice(&p[..len]);
        Ok(len)
    }
}

/// Yields a `UdpStream` for each new peer address seen on the bound socket.
pub struct UdpListener {
    /// Channel on which the background receive task delivers new streams.
    receiver: Receiver<UdpStream>,
}

impl UdpListener {
    pub async fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<UdpListener> {
        let (rx, tx) = UdpSocket::bind(addr).await?.split();
        let (sender, receiver) = unbounded();
        let inner = Arc::new(Inner {
            sender,
            rx: Mutex::new(rx),
            tx: Mutex::new(tx),
            children: Mutex::new(HashMap::new()),
        });
        tokio::spawn(inner.clone().serve());
        Ok(UdpListener {
            receiver,
        })
    }
    pub async fn next(&mut self) -> io::Result<UdpStream> {
        self.receiver.recv().await.map_err(other)
    }
}