1use std::collections::VecDeque;
29use std::io;
30use std::net::SocketAddr;
31use std::sync::Arc;
32use std::sync::atomic::{AtomicBool, Ordering};
33use std::task::{Context, Poll, Waker};
34
35use bytes::Bytes;
36use parking_lot::Mutex;
37use tokio::net::UdpSocket;
38
/// Inbound queue bound used by `VirtualUdpSocket::new`: 256 datagrams may be
/// buffered before further arrivals are dropped.
pub const DEFAULT_INBOUND_CAPACITY: usize = 1 << 8;
43
/// A virtual UDP socket layered on a shared physical `UdpSocket`.
///
/// Sends go straight to the physical socket. Inbound datagrams are supplied by
/// the caller through `enqueue_inbound` and consumed with `try_dequeue` or
/// `poll_dequeue`.
pub struct VirtualUdpSocket {
    /// Shared physical socket used for sends, send readiness and `local_addr`.
    physical: Arc<UdpSocket>,
    /// Bounded inbound queue plus the waker of a parked `poll_dequeue` caller.
    inbound: Mutex<Inbound>,
    /// Set by `close()`; nothing in this type ever clears it.
    closed: AtomicBool,
}
60
/// Mutex-protected inbound state of a `VirtualUdpSocket`.
struct Inbound {
    /// FIFO of `(peer, datagram)` pairs: pushed at the back, popped from the front.
    queue: VecDeque<(SocketAddr, Bytes)>,
    /// Waker stored by the most recent `poll_dequeue` that returned `Pending`;
    /// taken and woken by `enqueue_inbound` and `close`. Only one slot exists.
    waker: Option<Waker>,
    /// Maximum number of queued datagrams; `enqueue_inbound` drops anything beyond it.
    capacity: usize,
}
66
67impl std::fmt::Debug for VirtualUdpSocket {
68 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69 f.debug_struct("VirtualUdpSocket")
70 .field("closed", &self.closed.load(Ordering::Relaxed))
71 .finish_non_exhaustive()
72 }
73}
74
75impl VirtualUdpSocket {
76 #[must_use]
79 pub fn new(physical: Arc<UdpSocket>) -> Arc<Self> {
80 Self::new_with_capacity(physical, DEFAULT_INBOUND_CAPACITY)
81 }
82
83 #[must_use]
87 pub fn new_with_capacity(physical: Arc<UdpSocket>, capacity: usize) -> Arc<Self> {
88 Arc::new(Self {
89 physical,
90 inbound: Mutex::new(Inbound { queue: VecDeque::new(), waker: None, capacity }),
91 closed: AtomicBool::new(false),
92 })
93 }
94
95 pub fn local_addr(&self) -> io::Result<SocketAddr> {
101 self.physical.local_addr()
102 }
103
104 pub fn enqueue_inbound(&self, peer: SocketAddr, datagram: Bytes) {
111 if self.closed.load(Ordering::Relaxed) {
112 tracing::warn!(?peer, "virtual udp socket closed; dropping inbound datagram");
113 return;
114 }
115 let mut inbound = self.inbound.lock();
116 if inbound.queue.len() >= inbound.capacity {
117 tracing::warn!(?peer, "virtual udp socket inbound queue full; dropping datagram");
118 return;
119 }
120 inbound.queue.push_back((peer, datagram));
121 if let Some(w) = inbound.waker.take() {
122 w.wake();
123 }
124 }
125
126 pub fn try_dequeue(&self) -> Option<(SocketAddr, Bytes)> {
129 self.inbound.lock().queue.pop_front()
130 }
131
132 pub fn poll_dequeue(&self, cx: &mut Context<'_>) -> Poll<Option<(SocketAddr, Bytes)>> {
143 let mut inbound = self.inbound.lock();
144 if let Some(item) = inbound.queue.pop_front() {
145 return Poll::Ready(Some(item));
146 }
147 if self.closed.load(Ordering::Relaxed) {
148 return Poll::Ready(None);
149 }
150 inbound.waker = Some(cx.waker().clone());
151 Poll::Pending
152 }
153
154 pub fn try_send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
163 self.physical.try_send_to(buf, target)
164 }
165
166 pub fn poll_send_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
173 self.physical.poll_send_ready(cx)
174 }
175
176 #[must_use]
180 pub fn physical(&self) -> &Arc<UdpSocket> {
181 &self.physical
182 }
183
184 pub fn close(&self) {
190 let already = self.closed.swap(true, Ordering::Relaxed);
191 if !already {
192 if let Some(w) = self.inbound.lock().waker.take() {
195 w.wake();
196 }
197 }
198 }
199
200 #[must_use]
202 pub fn is_closed(&self) -> bool {
203 self.closed.load(Ordering::Relaxed)
204 }
205
206 #[must_use]
208 pub fn inbound_len(&self) -> usize {
209 self.inbound.lock().queue.len()
210 }
211}
212
#[cfg(test)]
mod tests {
    use std::future::poll_fn;
    use std::net::Ipv4Addr;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::task::{Context, Poll, Wake, Waker};

    use bytes::Bytes;
    use tokio::net::UdpSocket;

    use super::*;

    /// Binds a fresh loopback socket on an ephemeral port; returns it with its address.
    async fn bound() -> (Arc<UdpSocket>, SocketAddr) {
        let s = UdpSocket::bind((Ipv4Addr::LOCALHOST, 0)).await.expect("bind");
        let a = s.local_addr().expect("local_addr");
        (Arc::new(s), a)
    }

    /// Waker that only records that it fired, so tests can drive `poll_dequeue`
    /// by hand instead of relying on scheduler interleaving.
    struct FlagWaker(AtomicBool);

    impl Wake for FlagWaker {
        fn wake(self: Arc<Self>) {
            self.0.store(true, Ordering::SeqCst);
        }
    }

    /// Returns a `FlagWaker` plus a `Waker` that signals it.
    fn flag_waker() -> (Arc<FlagWaker>, Waker) {
        let flag = Arc::new(FlagWaker(AtomicBool::new(false)));
        let waker = Waker::from(Arc::clone(&flag));
        (flag, waker)
    }

    #[tokio::test]
    async fn try_dequeue_returns_none_when_empty() {
        let (phys, _) = bound().await;
        let v = VirtualUdpSocket::new(phys);
        assert!(v.try_dequeue().is_none());
    }

    #[tokio::test]
    async fn enqueue_then_dequeue_roundtrip() {
        let (phys, _) = bound().await;
        let v = VirtualUdpSocket::new(phys);
        let peer: SocketAddr = "192.0.2.1:443".parse().unwrap();
        v.enqueue_inbound(peer, Bytes::from_static(b"hello"));
        v.enqueue_inbound(peer, Bytes::from_static(b"world"));
        let (p1, d1) = v.try_dequeue().unwrap();
        assert_eq!(p1, peer);
        assert_eq!(&*d1, b"hello");
        let (_, d2) = v.try_dequeue().unwrap();
        assert_eq!(&*d2, b"world");
        assert!(v.try_dequeue().is_none());
    }

    #[tokio::test]
    async fn poll_dequeue_pending_then_woken_on_enqueue() {
        let (phys, _) = bound().await;
        let v = VirtualUdpSocket::new(phys);
        let peer: SocketAddr = "192.0.2.2:443".parse().unwrap();
        let v_for_task = Arc::clone(&v);
        let waker_task = tokio::spawn(async move { poll_fn(|cx| v_for_task.poll_dequeue(cx)).await });
        // Give the spawned task a chance to park in poll_dequeue before enqueuing.
        tokio::task::yield_now().await;
        v.enqueue_inbound(peer, Bytes::from_static(b"X"));
        let (got_peer, got_data) = waker_task.await.unwrap().expect("dequeue ok");
        assert_eq!(got_peer, peer);
        assert_eq!(&*got_data, b"X");
    }

    #[tokio::test]
    async fn poll_dequeue_registers_waker_fired_by_enqueue() {
        let (phys, _) = bound().await;
        let v = VirtualUdpSocket::new(phys);
        let peer: SocketAddr = "192.0.2.5:443".parse().unwrap();
        let (flag, waker) = flag_waker();
        let mut cx = Context::from_waker(&waker);
        // Deterministically exercise the Pending path, then check enqueue fires the waker.
        assert!(v.poll_dequeue(&mut cx).is_pending());
        assert!(!flag.0.load(Ordering::SeqCst));
        v.enqueue_inbound(peer, Bytes::from_static(b"Y"));
        assert!(flag.0.load(Ordering::SeqCst));
        match v.poll_dequeue(&mut cx) {
            Poll::Ready(Some((got_peer, got_data))) => {
                assert_eq!(got_peer, peer);
                assert_eq!(&*got_data, b"Y");
            }
            other => panic!("expected a queued datagram, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn full_queue_drops_overflow() {
        let (phys, _) = bound().await;
        let v = VirtualUdpSocket::new_with_capacity(phys, 2);
        let peer: SocketAddr = "192.0.2.3:443".parse().unwrap();
        v.enqueue_inbound(peer, Bytes::from_static(&[1]));
        v.enqueue_inbound(peer, Bytes::from_static(&[2]));
        v.enqueue_inbound(peer, Bytes::from_static(&[3]));
        // The third datagram exceeded capacity and was dropped.
        assert_eq!(v.inbound_len(), 2);
        assert_eq!(&*v.try_dequeue().unwrap().1, &[1]);
        assert_eq!(&*v.try_dequeue().unwrap().1, &[2]);
        assert!(v.try_dequeue().is_none());
    }

    #[tokio::test]
    async fn zero_capacity_drops_every_datagram() {
        let (phys, _) = bound().await;
        let v = VirtualUdpSocket::new_with_capacity(phys, 0);
        let peer: SocketAddr = "192.0.2.6:443".parse().unwrap();
        v.enqueue_inbound(peer, Bytes::from_static(b"Z"));
        assert_eq!(v.inbound_len(), 0);
        assert!(v.try_dequeue().is_none());
    }

    #[tokio::test]
    async fn close_drops_subsequent_enqueues_and_yields_none_after_drain() {
        let (phys, _) = bound().await;
        let v = VirtualUdpSocket::new(phys);
        let peer: SocketAddr = "192.0.2.4:443".parse().unwrap();
        v.enqueue_inbound(peer, Bytes::from_static(b"A"));
        v.close();
        // Datagrams queued before close remain drainable.
        assert_eq!(&*v.try_dequeue().unwrap().1, b"A");
        // Enqueues after close are dropped.
        v.enqueue_inbound(peer, Bytes::from_static(b"B"));
        assert!(v.try_dequeue().is_none());
        let r = poll_fn(|cx| v.poll_dequeue(cx)).await;
        // Closed and drained: end-of-stream.
        assert!(r.is_none());
    }

    #[tokio::test]
    async fn close_wakes_pending_poller_with_end_of_stream() {
        let (phys, _) = bound().await;
        let v = VirtualUdpSocket::new(phys);
        let (flag, waker) = flag_waker();
        let mut cx = Context::from_waker(&waker);
        assert!(v.poll_dequeue(&mut cx).is_pending());
        v.close();
        assert!(flag.0.load(Ordering::SeqCst));
        assert!(v.is_closed());
        assert!(matches!(v.poll_dequeue(&mut cx), Poll::Ready(None)));
        // A second close is a harmless no-op.
        v.close();
        assert!(v.is_closed());
    }

    #[tokio::test]
    async fn try_send_to_proxies_physical() {
        let (phys_a, addr_a) = bound().await;
        let (phys_b, addr_b) = bound().await;
        let v = VirtualUdpSocket::new(phys_a);
        // Wait for writability so try_send_to does not spuriously report WouldBlock.
        poll_fn(|cx| v.poll_send_ready(cx)).await.expect("send_ready");
        let n = v.try_send_to(b"PING", addr_b).expect("send");
        assert_eq!(n, 4);
        let mut buf = [0u8; 16];
        let (got, from) = phys_b.recv_from(&mut buf).await.expect("recv");
        assert_eq!(&buf[..got], b"PING");
        assert_eq!(from, addr_a);
    }

    #[tokio::test]
    async fn local_addr_matches_physical() {
        let (phys, addr) = bound().await;
        let v = VirtualUdpSocket::new(phys);
        assert_eq!(v.local_addr().unwrap(), addr);
    }
}
320}