// udp_socket/socket.rs

1use crate::proto::{RecvMeta, SocketType, Transmit, UdpCapabilities};
2use async_io::Async;
3use futures_lite::future::poll_fn;
4use std::io::{IoSliceMut, Result};
5use std::net::SocketAddr;
6use std::task::{Context, Poll};
7
8#[cfg(unix)]
9use crate::unix as platform;
10#[cfg(not(unix))]
11use fallback as platform;
12
/// An asynchronous UDP socket driven by the `async-io` reactor.
///
/// Wraps a `std::net::UdpSocket` in [`Async`] and keeps the [`SocketType`] that the
/// platform backend reported when the socket was bound.
#[derive(Debug)]
pub struct UdpSocket {
    /// The underlying socket, registered with the `async-io` reactor.
    inner: Async<std::net::UdpSocket>,
    /// Socket type returned by `platform::init` at bind time.
    ty: SocketType,
}
18
19impl UdpSocket {
20    pub fn capabilities() -> Result<UdpCapabilities> {
21        Ok(UdpCapabilities {
22            max_gso_segments: platform::max_gso_segments()?,
23        })
24    }
25
26    pub fn bind(addr: SocketAddr) -> Result<Self> {
27        let socket = std::net::UdpSocket::bind(addr)?;
28        let ty = platform::init(&socket)?;
29        Ok(Self {
30            inner: Async::new(socket)?,
31            ty,
32        })
33    }
34
35    pub fn socket_type(&self) -> SocketType {
36        self.ty
37    }
38
39    pub fn local_addr(&self) -> Result<SocketAddr> {
40        self.inner.get_ref().local_addr()
41    }
42
43    pub fn ttl(&self) -> Result<u8> {
44        let ttl = self.inner.get_ref().ttl()?;
45        Ok(ttl as u8)
46    }
47
48    pub fn set_ttl(&self, ttl: u8) -> Result<()> {
49        self.inner.get_ref().set_ttl(ttl as u32)
50    }
51
52    pub fn poll_send(&self, cx: &mut Context, transmits: &[Transmit]) -> Poll<Result<usize>> {
53        match self.inner.poll_writable(cx) {
54            Poll::Ready(Ok(())) => {}
55            Poll::Pending => return Poll::Pending,
56            Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
57        }
58        let socket = self.inner.get_ref();
59        match platform::send(socket, transmits) {
60            Ok(len) => Poll::Ready(Ok(len)),
61            Err(err) => Poll::Ready(Err(err)),
62        }
63    }
64
65    pub fn poll_recv(
66        &self,
67        cx: &mut Context,
68        buffers: &mut [IoSliceMut<'_>],
69        meta: &mut [RecvMeta],
70    ) -> Poll<Result<usize>> {
71        match self.inner.poll_readable(cx) {
72            Poll::Ready(Ok(())) => {}
73            Poll::Pending => return Poll::Pending,
74            Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
75        }
76        let socket = self.inner.get_ref();
77        Poll::Ready(platform::recv(socket, buffers, meta))
78    }
79
80    pub async fn send(&self, transmits: &[Transmit]) -> Result<usize> {
81        let mut i = 0;
82        while i < transmits.len() {
83            i += poll_fn(|cx| self.poll_send(cx, &transmits[i..])).await?;
84        }
85        Ok(i)
86    }
87
88    pub async fn recv(
89        &self,
90        buffers: &mut [IoSliceMut<'_>],
91        meta: &mut [RecvMeta],
92    ) -> Result<usize> {
93        poll_fn(|cx| self.poll_recv(cx, buffers, meta)).await
94    }
95}
96
#[cfg(not(unix))]
mod fallback {
    //! Portable backend used on targets without a dedicated platform implementation.
    //!
    //! Uses one `send_to` / `recv_from` call per datagram and reports no segmentation offload.
    use super::*;

    /// Maximum GSO segments per send. This backend never batches, so it reports 1.
    pub fn max_gso_segments() -> Result<usize> {
        let single_segment: usize = 1;
        Ok(single_segment)
    }

    /// Classifies a freshly bound socket by the address family of its local address.
    ///
    /// IPv4 addresses map to `SocketType::Ipv4`. Every IPv6 address is reported as
    /// `SocketType::Ipv6Only`; no socket option is read or changed here.
    pub fn init(socket: &std::net::UdpSocket) -> Result<SocketType> {
        let local = socket.local_addr()?;
        let ty = match local {
            SocketAddr::V4(_) => SocketType::Ipv4,
            SocketAddr::V6(_) => SocketType::Ipv6Only,
        };
        Ok(ty)
    }

    /// Sends each transmit with its own `send_to` call, stopping at the first failure.
    ///
    /// Returns the number of transmits sent. The error itself is returned only when the
    /// very first send fails.
    pub fn send(socket: &std::net::UdpSocket, transmits: &[Transmit]) -> Result<usize> {
        for (already_sent, transmit) in transmits.iter().enumerate() {
            if let Err(err) = socket.send_to(&transmit.contents, &transmit.destination) {
                // Partial progress has to be reported, so a later error is dropped. This
                // relies on it being harmlessly transient (e.g. WouldBlock) or recurring
                // on the next call.
                return if already_sent == 0 {
                    Err(err)
                } else {
                    Ok(already_sent)
                };
            }
        }
        Ok(transmits.len())
    }

    /// Receives one datagram into `buffers[0]` and records its metadata in `meta[0]`.
    ///
    /// Always reports exactly one datagram on success. `recv_from` provides only the
    /// length and source address, so `ecn` and `dst_ip` are left as `None`.
    ///
    /// # Panics
    /// Panics if `buffers` is empty. Also panics if `meta` is empty, but only after a
    /// datagram has been received.
    pub fn recv(
        socket: &std::net::UdpSocket,
        buffers: &mut [IoSliceMut<'_>],
        meta: &mut [RecvMeta],
    ) -> Result<usize> {
        let target = &mut buffers[0];
        let (len, source) = socket.recv_from(target)?;
        let slot = &mut meta[0];
        *slot = RecvMeta {
            source,
            len,
            ecn: None,
            dst_ip: None,
        };
        Ok(1)
    }
}