1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
use crate::proto::{RecvMeta, SocketType, Transmit, UdpCapabilities};
use async_io::Async;
use futures_lite::future::poll_fn;
use std::io::{IoSliceMut, Result};
use std::net::SocketAddr;
use std::task::{Context, Poll};

/// A UDP socket wrapped in `async_io::Async` so it can be polled for readiness.
#[derive(Debug)]
pub struct UdpSocket {
    /// The underlying standard-library socket, wrapped for readable/writable polling.
    inner: Async<std::net::UdpSocket>,
    /// Socket type fixed at bind time: from `crate::unix::init` on Unix, otherwise
    /// inferred from the bind address family.
    ty: SocketType,
}

impl UdpSocket {
    pub fn capabilities() -> Result<UdpCapabilities> {
        Ok(UdpCapabilities {
            max_gso_segments: if cfg!(unix) {
                crate::unix::max_gso_segments()?
            } else {
                1
            },
        })
    }

    pub fn bind(addr: SocketAddr) -> Result<Self> {
        let socket = std::net::UdpSocket::bind(addr)?;
        let ty = if cfg!(unix) {
            crate::unix::init(&socket)?
        } else if addr.is_ipv4() {
            SocketType::Ipv4
        } else {
            SocketType::Ipv6Only
        };
        Ok(Self {
            inner: Async::new(socket)?,
            ty,
        })
    }

    pub fn socket_type(&self) -> SocketType {
        self.ty
    }

    pub fn local_addr(&self) -> Result<SocketAddr> {
        self.inner.get_ref().local_addr()
    }

    pub fn ttl(&self) -> Result<u8> {
        let ttl = self.inner.get_ref().ttl()?;
        Ok(ttl as u8)
    }

    pub fn set_ttl(&self, ttl: u8) -> Result<()> {
        self.inner.get_ref().set_ttl(ttl as u32)
    }

    pub fn poll_send(&self, cx: &mut Context, transmits: &[Transmit]) -> Poll<Result<usize>> {
        match self.inner.poll_writable(cx) {
            Poll::Ready(Ok(())) => {}
            Poll::Pending => return Poll::Pending,
            Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
        }
        let socket = self.inner.get_ref();
        let res = if cfg!(unix) {
            crate::unix::send(socket, transmits)
        } else {
            fallback_send(socket, transmits)
        };
        match res {
            Ok(len) => Poll::Ready(Ok(len)),
            Err(err) => Poll::Ready(Err(err)),
        }
    }

    pub fn poll_recv(
        &self,
        cx: &mut Context,
        buffers: &mut [IoSliceMut<'_>],
        meta: &mut [RecvMeta],
    ) -> Poll<Result<usize>> {
        match self.inner.poll_readable(cx) {
            Poll::Ready(Ok(())) => {}
            Poll::Pending => return Poll::Pending,
            Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
        }
        let socket = self.inner.get_ref();
        let res = if cfg!(unix) {
            crate::unix::recv(socket, buffers, meta)
        } else {
            fallback_recv(socket, buffers, meta)
        };
        Poll::Ready(res)
    }

    pub async fn send(&self, transmits: &[Transmit]) -> Result<usize> {
        let mut i = 0;
        while i < transmits.len() {
            i += poll_fn(|cx| self.poll_send(cx, &transmits[i..])).await?;
        }
        Ok(i)
    }

    pub async fn recv(
        &self,
        buffers: &mut [IoSliceMut<'_>],
        meta: &mut [RecvMeta],
    ) -> Result<usize> {
        poll_fn(|cx| self.poll_recv(cx, buffers, meta)).await
    }
}

/// Sends each transmit with a plain `send_to`, one datagram per call.
///
/// Returns the number of datagrams sent. An error on the very first datagram is
/// returned as-is; an error after at least one success is swallowed and the partial
/// count is reported instead.
fn fallback_send(socket: &std::net::UdpSocket, transmits: &[Transmit]) -> Result<usize> {
    for (index, transmit) in transmits.iter().enumerate() {
        if let Err(e) = socket.send_to(&transmit.contents, &transmit.destination) {
            // We need to report that some packets were sent in this case, so we rely on
            // errors being either harmlessly transient (in the case of WouldBlock) or
            // recurring on the next call.
            return if index == 0 { Err(e) } else { Ok(index) };
        }
    }
    Ok(transmits.len())
}

/// Receives a single datagram into the first buffer and records its metadata.
///
/// Only `buffers[0]` and `meta[0]` are used; ECN and destination IP are not available
/// on this path and are reported as `None`. Always reports one datagram on success.
fn fallback_recv(
    socket: &std::net::UdpSocket,
    buffers: &mut [IoSliceMut<'_>],
    meta: &mut [RecvMeta],
) -> Result<usize> {
    let first_buf: &mut [u8] = &mut buffers[0];
    let (len, source) = socket.recv_from(first_buf)?;
    let slot = &mut meta[0];
    *slot = RecvMeta {
        len,
        source,
        ecn: None,
        dst_ip: None,
    };
    Ok(1)
}