1use crate::proto::{RecvMeta, SocketType, Transmit, UdpCapabilities};
2use async_io::Async;
3use futures_lite::future::poll_fn;
4use std::io::{IoSliceMut, Result};
5use std::net::SocketAddr;
6use std::task::{Context, Poll};
7
8#[cfg(unix)]
9use crate::unix as platform;
10#[cfg(not(unix))]
11use fallback as platform;
12
13#[derive(Debug)]
14pub struct UdpSocket {
15 inner: Async<std::net::UdpSocket>,
16 ty: SocketType,
17}
18
19impl UdpSocket {
20 pub fn capabilities() -> Result<UdpCapabilities> {
21 Ok(UdpCapabilities {
22 max_gso_segments: platform::max_gso_segments()?,
23 })
24 }
25
26 pub fn bind(addr: SocketAddr) -> Result<Self> {
27 let socket = std::net::UdpSocket::bind(addr)?;
28 let ty = platform::init(&socket)?;
29 Ok(Self {
30 inner: Async::new(socket)?,
31 ty,
32 })
33 }
34
35 pub fn socket_type(&self) -> SocketType {
36 self.ty
37 }
38
39 pub fn local_addr(&self) -> Result<SocketAddr> {
40 self.inner.get_ref().local_addr()
41 }
42
43 pub fn ttl(&self) -> Result<u8> {
44 let ttl = self.inner.get_ref().ttl()?;
45 Ok(ttl as u8)
46 }
47
48 pub fn set_ttl(&self, ttl: u8) -> Result<()> {
49 self.inner.get_ref().set_ttl(ttl as u32)
50 }
51
52 pub fn poll_send(&self, cx: &mut Context, transmits: &[Transmit]) -> Poll<Result<usize>> {
53 match self.inner.poll_writable(cx) {
54 Poll::Ready(Ok(())) => {}
55 Poll::Pending => return Poll::Pending,
56 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
57 }
58 let socket = self.inner.get_ref();
59 match platform::send(socket, transmits) {
60 Ok(len) => Poll::Ready(Ok(len)),
61 Err(err) => Poll::Ready(Err(err)),
62 }
63 }
64
65 pub fn poll_recv(
66 &self,
67 cx: &mut Context,
68 buffers: &mut [IoSliceMut<'_>],
69 meta: &mut [RecvMeta],
70 ) -> Poll<Result<usize>> {
71 match self.inner.poll_readable(cx) {
72 Poll::Ready(Ok(())) => {}
73 Poll::Pending => return Poll::Pending,
74 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
75 }
76 let socket = self.inner.get_ref();
77 Poll::Ready(platform::recv(socket, buffers, meta))
78 }
79
80 pub async fn send(&self, transmits: &[Transmit]) -> Result<usize> {
81 let mut i = 0;
82 while i < transmits.len() {
83 i += poll_fn(|cx| self.poll_send(cx, &transmits[i..])).await?;
84 }
85 Ok(i)
86 }
87
88 pub async fn recv(
89 &self,
90 buffers: &mut [IoSliceMut<'_>],
91 meta: &mut [RecvMeta],
92 ) -> Result<usize> {
93 poll_fn(|cx| self.poll_recv(cx, buffers, meta)).await
94 }
95}
96
/// Platform layer used on non-Unix targets.
///
/// It relies only on `std::net::UdpSocket::send_to` / `recv_from`: one
/// datagram per call, and no ECN or destination-address metadata.
#[cfg(not(unix))]
mod fallback {
    use super::*;

    /// Reports the maximum number of GSO segments; this implementation
    /// always reports 1.
    pub fn max_gso_segments() -> Result<usize> {
        Ok(1)
    }

    /// Classifies `socket` by its local address: `SocketType::Ipv4` for an
    /// IPv4 address, `SocketType::Ipv6Only` otherwise.
    ///
    /// # Errors
    ///
    /// Propagates a failure to query the local address.
    pub fn init(socket: &std::net::UdpSocket) -> Result<SocketType> {
        let ty = if socket.local_addr()?.is_ipv4() {
            SocketType::Ipv4
        } else {
            SocketType::Ipv6Only
        };
        Ok(ty)
    }

    /// Sends `transmits` in order, one `send_to` call each, and returns how
    /// many were sent.
    ///
    /// # Errors
    ///
    /// An error is returned only when the very first transmit fails.
    pub fn send(socket: &std::net::UdpSocket, transmits: &[Transmit]) -> Result<usize> {
        let mut sent = 0;
        for transmit in transmits {
            match socket.send_to(&transmit.contents, &transmit.destination) {
                Ok(_) => sent += 1,
                // Some transmits already went out, so report that progress
                // and drop this error; the caller retries the remainder and
                // presumably sees the error again if it persists.
                Err(_) if sent != 0 => return Ok(sent),
                Err(e) => return Err(e),
            }
        }
        Ok(sent)
    }

    /// Receives a single datagram into `buffers[0]` and records its source
    /// and length in `meta[0]` (`ecn` and `dst_ip` are always `None`).
    ///
    /// Returns `Ok(1)` after a successful receive. If `buffers` or `meta` is
    /// empty there is nowhere to store a datagram, so `Ok(0)` is returned
    /// without reading from the socket instead of panicking on the index.
    ///
    /// # Errors
    ///
    /// Propagates any error from `recv_from`.
    pub fn recv(
        socket: &std::net::UdpSocket,
        buffers: &mut [IoSliceMut<'_>],
        meta: &mut [RecvMeta],
    ) -> Result<usize> {
        let (buf, slot) = match (buffers.first_mut(), meta.first_mut()) {
            (Some(buf), Some(slot)) => (buf, slot),
            _ => return Ok(0),
        };
        let (len, source) = socket.recv_from(buf)?;
        *slot = RecvMeta {
            source,
            len,
            ecn: None,
            dst_ip: None,
        };
        Ok(1)
    }
}