1use std::{
2 fmt::Debug,
3 future::Future,
4 io::{IoSliceMut, Result as IoResult},
5 net::{IpAddr, Ipv6Addr, SocketAddr},
6 pin::Pin,
7 sync::Arc,
8 task::{Context, Poll, ready},
9};
10
11use bytes::Bytes;
12use futures::future::poll_fn;
13#[cfg(feature = "quic")]
14pub use quinn::UdpPoller;
15pub use quinn_udp::{EcnCodepoint, RecvMeta as QuinnRecvMeta, Transmit, UdpSocketState};
16use tokio::io::Interest;
19
20use crate::types::TargetAddr;
21
/// Stand-in for `quinn::UdpPoller`, defined locally when the `quic` feature is disabled.
///
/// Implementors report when the underlying socket can accept another write.
#[cfg(not(feature = "quic"))]
pub trait UdpPoller: Send + Sync + Debug + 'static {
    /// Polls for write readiness.
    ///
    /// Implementors should resolve to `Poll::Ready(Ok(()))` once a send may be attempted,
    /// return `Poll::Pending` (having arranged for `cx`'s waker to be notified) otherwise,
    /// or resolve to an I/O error.
    fn poll_writable(self: Pin<&mut Self>, cx: &mut Context) -> Poll<IoResult<()>>;
}
26
/// Metadata describing one receive buffer filled by [`AbstractUdpSocket::poll_recv`].
///
/// The first five fields are copied from `quinn_udp::RecvMeta` by the
/// `From<QuinnRecvMeta>` impl; `destination` is an additional, optional address.
#[derive(Debug, Clone)]
pub struct RecvMeta {
    /// Source address of the datagram(s) in the buffer (quinn-udp's `addr`).
    pub addr: SocketAddr,
    /// Number of bytes held by the associated buffer (quinn-udp's `len`).
    pub len: usize,
    /// Size of a single datagram within the buffer (quinn-udp's `stride`); per quinn-udp,
    /// this can be smaller than `len` when several datagrams were coalesced on receive.
    pub stride: usize,
    /// ECN codepoint reported for the received datagram(s), if any.
    pub ecn: Option<EcnCodepoint>,
    /// Destination IP the datagram was addressed to, when the platform reports it.
    pub dst_ip: Option<IpAddr>,
    /// NOTE(review): presumably the logical (e.g. proxied) target of this datagram; it is
    /// always `None` after conversion from quinn-udp metadata and in `Default` — confirm
    /// who populates it.
    pub destination: Option<TargetAddr>,
}
57
58impl Default for RecvMeta {
59 fn default() -> Self {
61 Self {
62 addr: SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 0),
63 len: 0,
64 stride: 0,
65 ecn: None,
66 dst_ip: None,
67 destination: None,
68 }
69 }
70}
71
72impl From<QuinnRecvMeta> for RecvMeta {
73 fn from(meta: QuinnRecvMeta) -> Self {
74 Self {
75 addr: meta.addr,
76 len: meta.len,
77 stride: meta.stride,
78 ecn: meta.ecn,
79 dst_ip: meta.dst_ip,
80 destination: None,
81 }
82 }
83}
84
/// A UDP datagram payload together with its addressing information.
#[derive(Debug, Clone)]
pub struct UdpPacket {
    /// NOTE(review): presumably the sender's address, when known — inferred from the name; confirm.
    pub source: Option<TargetAddr>,
    /// NOTE(review): presumably the address the datagram is destined for — inferred from the name; confirm.
    pub target: TargetAddr,
    /// Datagram payload bytes.
    pub payload: Bytes,
}
91
92pub trait AbstractUdpSocket: Send + Sync {
95 fn create_io_poller(self: Arc<Self>) -> Pin<Box<dyn UdpPoller>>;
98
99 fn try_send(&self, transmit: &Transmit) -> IoResult<()>;
101
102 fn poll_recv(&self, cx: &mut Context, bufs: &mut [IoSliceMut<'_>], meta: &mut [RecvMeta]) -> Poll<IoResult<usize>>;
104
105 fn local_addr(&self) -> IoResult<SocketAddr>;
107
108 fn max_transmit_segments(&self) -> usize {
110 1
111 }
112
113 fn max_receive_segments(&self) -> usize {
115 1
116 }
117
118 fn may_fragment(&self) -> bool {
120 true
121 }
122
123 fn recv(&self, bufs: &mut [IoSliceMut<'_>], meta: &mut [RecvMeta]) -> impl Future<Output = IoResult<usize>> + Send {
127 poll_fn(|cx| self.poll_recv(cx, bufs, meta))
128 }
129
130 fn poll_send(&self, _cx: &mut Context<'_>, buf: &[u8], target: SocketAddr) -> Poll<IoResult<usize>> {
132 let transmit = Transmit {
133 destination: target,
134 contents: buf,
135 ecn: None,
136 segment_size: None,
137 src_ip: None,
138 };
139 match self.try_send(&transmit) {
140 Ok(_) => Poll::Ready(Ok(buf.len())),
141 Err(e) => Poll::Ready(Err(e)),
142 }
143 }
144
145 fn send<'a>(&'a self, buf: &'a [u8], target: SocketAddr) -> impl Future<Output = IoResult<usize>> + Send + 'a {
147 poll_fn(move |cx| self.poll_send(cx, buf, target))
148 }
149}
150
/// [`AbstractUdpSocket`] backed by a Tokio UDP socket, with datagram I/O performed by quinn-udp.
#[derive(Debug)]
pub struct TokioUdpSocket {
    /// Tokio socket: provides readiness tracking (`try_io`, `poll_recv_ready`, `writable`)
    /// and is the handle passed to quinn-udp for the actual syscalls.
    io: tokio::net::UdpSocket,
    /// quinn-udp per-socket state that performs the send/recv calls.
    inner: UdpSocketState,
}
156impl TokioUdpSocket {
157 pub fn new(sock: std::net::UdpSocket) -> std::io::Result<Self> {
158 Ok(Self {
159 inner: UdpSocketState::new((&sock).into())?,
160 io: tokio::net::UdpSocket::from_std(sock)?,
161 })
162 }
163}
164impl AbstractUdpSocket for TokioUdpSocket {
165 fn create_io_poller(self: Arc<Self>) -> Pin<Box<dyn UdpPoller>> {
166 Box::pin(UdpPollHelper::new(move || {
167 let socket = self.clone();
168 async move { socket.io.writable().await }
169 }))
170 }
171
172 fn try_send(&self, transmit: &Transmit) -> std::io::Result<()> {
173 self.io
174 .try_io(Interest::WRITABLE, || self.inner.send((&self.io).into(), transmit))
175 }
176
177 fn poll_recv(
178 &self,
179 cx: &mut Context,
180 bufs: &mut [std::io::IoSliceMut<'_>],
181 meta: &mut [RecvMeta],
182 ) -> Poll<std::io::Result<usize>> {
183 loop {
184 ready!(self.io.poll_recv_ready(cx))?;
185 let mut quinn_meta = vec![QuinnRecvMeta::default(); meta.len()];
187 if let Ok(res) = self.io.try_io(Interest::READABLE, || {
188 self.inner.recv((&self.io).into(), bufs, &mut quinn_meta)
189 }) {
190 for (i, qmeta) in quinn_meta.iter().enumerate().take(res) {
192 if i < meta.len() {
193 meta[i] = RecvMeta::from(*qmeta);
194 }
195 }
196 return Poll::Ready(Ok(res));
197 }
198 }
199 }
200
201 fn local_addr(&self) -> std::io::Result<std::net::SocketAddr> {
202 self.io.local_addr()
203 }
204
205 fn may_fragment(&self) -> bool {
206 self.inner.may_fragment()
207 }
208
209 fn max_transmit_segments(&self) -> usize {
210 self.inner.max_gso_segments()
211 }
212
213 fn max_receive_segments(&self) -> usize {
214 self.inner.gro_segments()
215 }
216}
217
pin_project_lite::pin_project! {
    /// Adapts a future factory into a [`UdpPoller`].
    ///
    /// `make_fut` is called to start a new readiness future whenever none is in flight;
    /// `fut` holds the future currently being polled.
    pub struct UdpPollHelper<MakeFut, Fut> {
        // Factory producing a fresh readiness future.
        make_fut: MakeFut,
        // Readiness future currently in flight, if any; pinned because it is polled in place.
        #[pin]
        fut: Option<Fut>,
    }
}
225
226impl<MakeFut, Fut> UdpPollHelper<MakeFut, Fut> {
227 pub fn new(make_fut: MakeFut) -> Self {
228 Self { make_fut, fut: None }
229 }
230}
231
232impl<MakeFut, Fut> UdpPoller for UdpPollHelper<MakeFut, Fut>
233where
234 MakeFut: Fn() -> Fut + Send + Sync + 'static,
235 Fut: Future<Output = std::io::Result<()>> + Send + Sync + 'static,
236{
237 fn poll_writable(self: Pin<&mut Self>, cx: &mut Context) -> Poll<std::io::Result<()>> {
238 let mut this = self.project();
239 if this.fut.is_none() {
240 this.fut.set(Some((this.make_fut)()));
241 }
242 let result = this.fut.as_mut().as_pin_mut().unwrap().poll(cx);
243 if result.is_ready() {
244 this.fut.set(None);
245 }
246 result
247 }
248}
249
250impl<MakeFut, Fut> Debug for UdpPollHelper<MakeFut, Fut> {
251 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
252 f.debug_struct("UdpPollHelper").finish_non_exhaustive()
253 }
254}