wind_core/
udp.rs

1use std::{
2	fmt::Debug,
3	future::Future,
4	io::{IoSliceMut, Result as IoResult},
5	net::{IpAddr, Ipv6Addr, SocketAddr},
6	pin::Pin,
7	sync::Arc,
8	task::{Context, Poll, ready},
9};
10
11use bytes::Bytes;
12use futures::future::poll_fn;
13#[cfg(feature = "quic")]
14pub use quinn::UdpPoller;
15pub use quinn_udp::{EcnCodepoint, RecvMeta as QuinnRecvMeta, Transmit, UdpSocketState};
16// Re-export quinn-udp's RecvMeta directly
17// pub use quinn_udp::RecvMeta;
18use tokio::io::Interest;
19
20use crate::types::TargetAddr;
21
/// Fallback writability-poller trait, defined locally only when the `quic`
/// feature is disabled (with `quic` enabled, `quinn::UdpPoller` is
/// re-exported under the same name instead).
///
/// Implementors are expected to report, through
/// [`poll_writable`](UdpPoller::poll_writable), when the underlying socket
/// can accept another datagram.
#[cfg(not(feature = "quic"))]
pub trait UdpPoller: Send + Sync + Debug + 'static {
	/// Polls the associated socket for writability.
	///
	/// `Poll::Ready(Ok(()))` signals that a send may be attempted; an
	/// implementation returning `Poll::Pending` is expected to arrange for
	/// `ctx`'s waker to be notified later.
	fn poll_writable(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<IoResult<()>>;
}
26
27/// Metadata for a single buffer filled with bytes received from the network
28///
29/// This is our custom version of RecvMeta that includes destination information
30/// for better packet routing support.
31#[derive(Debug, Clone)]
32pub struct RecvMeta {
33	/// The source address of the datagram(s) contained in the buffer
34	pub addr:        SocketAddr,
35	/// The number of bytes the associated buffer has
36	pub len:         usize,
37	/// The size of a single datagram in the associated buffer
38	///
39	/// When GRO (Generic Receive Offload) is used this indicates the size of a
40	/// single datagram inside the buffer. If the buffer is larger, that is if
41	/// [`len`] is greater then this value, then the individual datagrams
42	/// contained have their boundaries at `stride` increments from the start.
43	/// The last datagram could be smaller than `stride`.
44	pub stride:      usize,
45	/// The Explicit Congestion Notification bits for the datagram(s) in the
46	/// buffer
47	pub ecn:         Option<EcnCodepoint>,
48	/// The destination IP address which was encoded in this datagram
49	///
50	/// Populated on platforms: Windows, Linux, Android (API level > 25),
51	/// FreeBSD, OpenBSD, NetBSD, macOS, and iOS.
52	pub dst_ip:      Option<IpAddr>,
53	/// The destination address that this packet is intended for
54	/// This is our custom field for better packet routing
55	pub destination: Option<TargetAddr>,
56}
57
58impl Default for RecvMeta {
59	/// Constructs a value with arbitrary fields, intended to be overwritten
60	fn default() -> Self {
61		Self {
62			addr:        SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 0),
63			len:         0,
64			stride:      0,
65			ecn:         None,
66			dst_ip:      None,
67			destination: None,
68		}
69	}
70}
71
72impl From<QuinnRecvMeta> for RecvMeta {
73	fn from(meta: QuinnRecvMeta) -> Self {
74		Self {
75			addr:        meta.addr,
76			len:         meta.len,
77			stride:      meta.stride,
78			ecn:         meta.ecn,
79			dst_ip:      meta.dst_ip,
80			destination: None,
81		}
82	}
83}
84
/// A UDP datagram together with its addressing information.
#[derive(Debug, Clone)]
pub struct UdpPacket {
	/// Origin of the datagram; optional, may be `None`.
	/// NOTE(review): confirm which producers leave this unset.
	pub source:  Option<TargetAddr>,
	/// Address the datagram is associated with as its target.
	pub target:  TargetAddr,
	/// The datagram's payload bytes.
	pub payload: Bytes,
}
91
92// TODO impl quinn::AsyncUdpSocket for AbstractUdpSocket
93
94pub trait AbstractUdpSocket: Send + Sync {
95	/// Required methods
96	/// Creates a UDP socket I/O poller.
97	fn create_io_poller(self: Arc<Self>) -> Pin<Box<dyn UdpPoller>>;
98
99	/// Tries to send a UDP datagram to the specified destination.
100	fn try_send(&self, transmit: &Transmit) -> IoResult<()>;
101
102	/// Poll to receive a UDP datagram.
103	fn poll_recv(&self, cx: &mut Context, bufs: &mut [IoSliceMut<'_>], meta: &mut [RecvMeta]) -> Poll<IoResult<usize>>;
104
105	/// Returns the local socket address.
106	fn local_addr(&self) -> IoResult<SocketAddr>;
107
108	/// Maximum number of segments that can be transmitted in one call.
109	fn max_transmit_segments(&self) -> usize {
110		1
111	}
112
113	/// Maximum number of segments that can be received in one call.
114	fn max_receive_segments(&self) -> usize {
115		1
116	}
117
118	/// Returns whether the socket may fragment packets.
119	fn may_fragment(&self) -> bool {
120		true
121	}
122
123	/// Supplied methods
124	/// Receive a UDP datagram.
125	/// `meta` is the returned metadata for each buffer in `bufs`.
126	fn recv(&self, bufs: &mut [IoSliceMut<'_>], meta: &mut [RecvMeta]) -> impl Future<Output = IoResult<usize>> + Send {
127		poll_fn(|cx| self.poll_recv(cx, bufs, meta))
128	}
129
130	/// Sends data on the socket to the given address.
131	fn poll_send(&self, _cx: &mut Context<'_>, buf: &[u8], target: SocketAddr) -> Poll<IoResult<usize>> {
132		let transmit = Transmit {
133			destination:  target,
134			contents:     buf,
135			ecn:          None,
136			segment_size: None,
137			src_ip:       None,
138		};
139		match self.try_send(&transmit) {
140			Ok(_) => Poll::Ready(Ok(buf.len())),
141			Err(e) => Poll::Ready(Err(e)),
142		}
143	}
144
145	/// Sends data on the socket to the given address.
146	fn send<'a>(&'a self, buf: &'a [u8], target: SocketAddr) -> impl Future<Output = IoResult<usize>> + Send + 'a {
147		poll_fn(move |cx| self.poll_send(cx, buf, target))
148	}
149}
150
151#[derive(Debug)]
152pub struct TokioUdpSocket {
153	io:    tokio::net::UdpSocket,
154	inner: UdpSocketState,
155}
156impl TokioUdpSocket {
157	pub fn new(sock: std::net::UdpSocket) -> std::io::Result<Self> {
158		Ok(Self {
159			inner: UdpSocketState::new((&sock).into())?,
160			io:    tokio::net::UdpSocket::from_std(sock)?,
161		})
162	}
163}
164impl AbstractUdpSocket for TokioUdpSocket {
165	fn create_io_poller(self: Arc<Self>) -> Pin<Box<dyn UdpPoller>> {
166		Box::pin(UdpPollHelper::new(move || {
167			let socket = self.clone();
168			async move { socket.io.writable().await }
169		}))
170	}
171
172	fn try_send(&self, transmit: &Transmit) -> std::io::Result<()> {
173		self.io
174			.try_io(Interest::WRITABLE, || self.inner.send((&self.io).into(), transmit))
175	}
176
177	fn poll_recv(
178		&self,
179		cx: &mut Context,
180		bufs: &mut [std::io::IoSliceMut<'_>],
181		meta: &mut [RecvMeta],
182	) -> Poll<std::io::Result<usize>> {
183		loop {
184			ready!(self.io.poll_recv_ready(cx))?;
185			// First, receive into quinn's RecvMeta
186			let mut quinn_meta = vec![QuinnRecvMeta::default(); meta.len()];
187			if let Ok(res) = self.io.try_io(Interest::READABLE, || {
188				self.inner.recv((&self.io).into(), bufs, &mut quinn_meta)
189			}) {
190				// Convert quinn's RecvMeta to our RecvMeta
191				for (i, qmeta) in quinn_meta.iter().enumerate().take(res) {
192					if i < meta.len() {
193						meta[i] = RecvMeta::from(*qmeta);
194					}
195				}
196				return Poll::Ready(Ok(res));
197			}
198		}
199	}
200
201	fn local_addr(&self) -> std::io::Result<std::net::SocketAddr> {
202		self.io.local_addr()
203	}
204
205	fn may_fragment(&self) -> bool {
206		self.inner.may_fragment()
207	}
208
209	fn max_transmit_segments(&self) -> usize {
210		self.inner.max_gso_segments()
211	}
212
213	fn max_receive_segments(&self) -> usize {
214		self.inner.gro_segments()
215	}
216}
217
pin_project_lite::pin_project! {
	/// Adapts a factory of writability futures into a [`UdpPoller`].
	///
	/// `make_fut` is called to create a new future whenever none is stored,
	/// and the current future is kept in `fut` across polls until it
	/// completes.
	pub struct UdpPollHelper<MakeFut, Fut> {
		// Factory called to produce a new future for each wait.
		make_fut: MakeFut,
		// The in-flight future, if any; structurally pinned since `Fut` is
		// not required to be `Unpin`.
		#[pin]
		fut: Option<Fut>,
	}
}
225
226impl<MakeFut, Fut> UdpPollHelper<MakeFut, Fut> {
227	pub fn new(make_fut: MakeFut) -> Self {
228		Self { make_fut, fut: None }
229	}
230}
231
232impl<MakeFut, Fut> UdpPoller for UdpPollHelper<MakeFut, Fut>
233where
234	MakeFut: Fn() -> Fut + Send + Sync + 'static,
235	Fut: Future<Output = std::io::Result<()>> + Send + Sync + 'static,
236{
237	fn poll_writable(self: Pin<&mut Self>, cx: &mut Context) -> Poll<std::io::Result<()>> {
238		let mut this = self.project();
239		if this.fut.is_none() {
240			this.fut.set(Some((this.make_fut)()));
241		}
242		let result = this.fut.as_mut().as_pin_mut().unwrap().poll(cx);
243		if result.is_ready() {
244			this.fut.set(None);
245		}
246		result
247	}
248}
249
250impl<MakeFut, Fut> Debug for UdpPollHelper<MakeFut, Fut> {
251	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
252		f.debug_struct("UdpPollHelper").finish_non_exhaustive()
253	}
254}