1use bytes::BytesMut;
2use futures::Future;
3use std::collections::HashMap;
4use std::collections::VecDeque;
5use std::fmt;
6use std::fmt::Debug;
7use std::io;
8use std::io::IoSliceMut;
9use std::mem::MaybeUninit;
10use std::net::IpAddr;
11use std::net::Ipv4Addr;
12use std::net::SocketAddr;
13use std::net::ToSocketAddrs;
14use std::pin::Pin;
15use std::sync::Arc;
16use std::task::Waker;
17use std::task::{Context, Poll};
18use std::time::Duration;
19use std::time::Instant;
20use tokio::sync::mpsc::{self, UnboundedReceiver as Receiver, UnboundedSender as Sender};
21use tokio::time::Sleep;
22use tracing::{debug, trace};
23
24use crate::constants::UDX_HEADER_SIZE;
25use crate::constants::UDX_MTU;
26use crate::mutex::Mutex;
27use crate::packet::{Dgram, Header, IncomingPacket, PacketSet};
28use crate::stream::UdxStream;
29use crate::udp::{BATCH_SIZE, RecvMeta, Transmit, UdpSocket, UdpState};
30
/// Iteration cap for the receive loop in `UdxSocketInner::poll_recv`: once the
/// loop counter reaches this value the loop stops and reports that more work
/// may be pending.
const MAX_LOOP: usize = 60;

/// Maximum number of plain datagrams buffered in `recv_dgrams`; datagrams
/// arriving while the queue is at this length are dropped.
const RECV_QUEUE_MAX_LEN: usize = 1024;
34
/// Event delivered from the socket to a stream through the stream's
/// `recv_tx` channel.
#[derive(Debug)]
pub(crate) enum EventIncoming {
    /// A packet addressed to the stream; the header has been parsed and
    /// split off the payload buffer.
    Packet(IncomingPacket),
}
39
/// Event sent to the socket over its `send_tx` channel and handled in
/// `UdxSocketInner::poll_transmit`.
#[derive(Debug)]
pub(crate) enum EventOutgoing {
    /// A set of stream packets to transmit; the packets get their send time
    /// stamped once the socket has sent them.
    Transmit(PacketSet),
    /// A plain datagram to transmit.
    TransmitDgram(Dgram),
    /// The stream with this local id should be removed from the socket's
    /// stream table.
    StreamDropped(u32),
}
47
/// Socket-side handle of an open stream.
#[derive(Debug)]
struct StreamHandle {
    /// Sender used to forward incoming packets to the stream.
    recv_tx: Sender<EventIncoming>,
}
52
53#[derive(Clone, Debug)]
54pub struct UdxSocket(Arc<Mutex<UdxSocketInner>>);
55
56impl std::ops::Deref for UdxSocket {
57 type Target = Mutex<UdxSocketInner>;
58 fn deref(&self) -> &Self::Target {
59 &self.0
60 }
61}
62
63impl UdxSocket {
64 pub fn bind_rnd() -> io::Result<Self> {
65 Self::bind_port(0)
66 }
67 pub fn bind_port(port: u16) -> io::Result<Self> {
68 let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), port);
69 Self::bind(addr)
70 }
71 pub fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
77 let inner = UdxSocketInner::bind(addr)?;
78 let socket = Self(Arc::new(Mutex::new(inner)));
79 let driver = SocketDriver(socket.clone());
80 tokio::spawn(async {
81 if let Err(e) = driver.await {
82 tracing::error!("Socket I/O error: {}", e);
83 }
84 });
85 Ok(socket)
86 }
87
88 pub fn local_addr(&self) -> io::Result<SocketAddr> {
89 self.0.lock("UdxSocket::local_addr").socket.local_addr()
90 }
91
92 pub fn create_stream(&self, local_id: u32) -> io::Result<HalfOpenStreamHandle> {
93 self.0.lock("UdxSocket::make_stream").streams.insert(
94 local_id,
95 MaybeOpenStream::HalfOpen(HalfOpenStream {
96 socket: self.clone(),
97 local_id,
98 rx_messages: vec![],
99 }),
100 );
101 Ok(HalfOpenStreamHandle {
102 socket: self.clone(),
103 local_id,
104 })
105 }
106 pub fn connect(
107 &self,
108 dest: SocketAddr,
109 local_id: u32,
110 remote_id: u32,
111 ) -> io::Result<UdxStream> {
112 self.0
113 .lock("UdxSocket::connect")
114 .connect(dest, local_id, remote_id)
115 }
116
117 pub fn stats(&self) -> SocketStats {
118 self.0.lock("UdxSocket::stats").stats.clone()
119 }
120
121 pub fn send(&self, dest: SocketAddr, buf: &[u8]) {
122 let dgram = Dgram::new(dest, buf.to_vec());
123 let ev = EventOutgoing::TransmitDgram(dgram);
124 self.0.lock("UdxSocket::send").send_tx.send(ev).unwrap();
125 }
126
127 pub fn recv(&self) -> RecvFuture {
128 RecvFuture(self.clone())
129 }
130}
131
132pub struct RecvFuture(UdxSocket);
133impl Future for RecvFuture {
134 type Output = io::Result<(SocketAddr, Vec<u8>)>;
135 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
136 let mut socket = self.0.lock("UdxSocket::recv");
137 if let Some(dgram) = socket.recv_dgrams.pop_front() {
138 if !socket.recv_dgrams.is_empty() {
139 cx.waker().wake_by_ref();
140 }
141 Poll::Ready(Ok((dgram.dest, dgram.buf)))
142 } else {
143 socket.recv_waker = Some(cx.waker().clone());
144 Poll::Pending
145 }
146 }
147}
148
149impl Drop for UdxSocket {
150 fn drop(&mut self) {
151 if Arc::strong_count(&self.0) == 2 {
153 let mut socket = self.0.lock("UdxSocket::drop");
154 socket.has_refs = false;
155 if let Some(waker) = socket.drive_waker.take() {
156 waker.wake();
157 }
158 }
159 }
160}
161
162pub struct SocketDriver(UdxSocket);
163
164impl Future for SocketDriver {
165 type Output = io::Result<()>;
166 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
167 let mut socket = self.0.lock("UdxSocket::poll_drive");
168 let mut should_continue = false;
169 should_continue |= socket.poll_recv(cx)?;
170 if let Some(send_overflow_timer) = socket.send_overflow_timer.as_mut() {
171 match send_overflow_timer.as_mut().poll(cx) {
172 Poll::Pending => {}
173 Poll::Ready(()) => {
174 log::warn!("send overflow timer clear!");
175 socket.send_overflow_timer = None;
176 should_continue = true;
177 }
178 }
179 } else {
180 should_continue |= socket.poll_transmit(cx)?;
181 }
182
183 if !should_continue && !socket.has_refs && socket.streams.is_empty() {
184 return Poll::Ready(Ok(()));
185 }
186 if should_continue {
187 drop(socket);
188 cx.waker().wake_by_ref();
189 } else {
190 socket.drive_waker = Some(cx.waker().clone());
191 }
192 Poll::Pending
193 }
194}
195
/// Handle returned by `UdxSocket::create_stream` for a stream that has been
/// registered but not yet connected.
#[derive(Debug)]
pub struct HalfOpenStreamHandle {
    /// Socket the stream is registered on.
    socket: UdxSocket,
    /// Local id under which the stream is stored in the socket's stream table.
    local_id: u32,
}
201
202impl HalfOpenStreamHandle {
203 pub fn connect(self, dest: SocketAddr, remote_id: u32) -> io::Result<UdxStream> {
204 let Some(MaybeOpenStream::HalfOpen(ds)) = self
205 .socket
206 .0
207 .lock("HalfOpenStreamHandle::connect get stream")
208 .streams
209 .remove(&self.local_id)
210 else {
211 todo!()
212 };
213 let (stream, handle) = ds.connect(dest, remote_id)?;
214 for event in ds.rx_messages {
215 if let Err(_packet) = handle.recv_tx.send(event) {
216 todo!()
218 }
219 }
220 self.socket
221 .0
222 .lock("HalfOpenStreamHandle:: put stream")
223 .streams
224 .insert(self.local_id, MaybeOpenStream::Open(handle));
225 Ok(stream)
226 }
227}
228
/// A stream that is registered on the socket but not yet connected to a peer.
#[derive(Debug)]
struct HalfOpenStream {
    /// Socket the stream belongs to.
    socket: UdxSocket,
    /// Local id of the stream.
    local_id: u32,
    /// Packets received for this stream before it was connected.
    rx_messages: Vec<EventIncoming>,
}
235
236impl HalfOpenStream {
237 fn connect(&self, dest: SocketAddr, remote_id: u32) -> io::Result<(UdxStream, StreamHandle)> {
238 let inner = self.socket.0.lock("DisconnectedStream::connect");
239 let (recv_tx, recv_rx) = mpsc::unbounded_channel();
240 let stream = UdxStream::connect(
241 recv_rx,
242 inner.send_tx.clone(),
243 inner.udp_state.clone(),
244 dest,
245 remote_id,
246 self.local_id,
247 );
248 let handle = StreamHandle { recv_tx };
250 Ok((stream, handle))
251 }
252}
253
/// Entry in the socket's stream table.
#[derive(Debug)]
enum MaybeOpenStream {
    /// Registered but not yet connected; incoming packets are buffered.
    HalfOpen(HalfOpenStream),
    /// Connected; incoming packets are forwarded through the handle.
    Open(StreamHandle),
}
259
/// Shared state of a UDX socket, guarded by the mutex inside `UdxSocket`.
pub struct UdxSocketInner {
    /// Underlying UDP socket.
    socket: UdpSocket,
    /// Receiving end of the outgoing-event channel, drained by `poll_transmit`.
    send_rx: Receiver<EventOutgoing>,
    /// Sending end of the outgoing-event channel; cloned into each stream.
    send_tx: Sender<EventOutgoing>,
    /// Streams keyed by their local id.
    streams: HashMap<u32, MaybeOpenStream>,
    /// Transmits queued for the next `poll_send`.
    outgoing_transmits: VecDeque<Transmit>,
    /// Packet sets whose packets get their send time stamped after sending.
    outgoing_packet_sets: VecDeque<PacketSet>,
    /// Receive buffer; taken out while `poll_recv` runs and put back afterwards.
    recv_buf: Option<Box<[u8]>>,
    /// UDP state passed to `poll_send` and shared with the streams.
    udp_state: Arc<UdpState>,
    /// Traffic counters.
    stats: SocketStats,
    /// Cleared by `Drop for UdxSocket`; the driver may finish once this is
    /// false and no streams remain.
    has_refs: bool,
    /// Waker of the driver task, stored when the driver goes idle.
    drive_waker: Option<Waker>,

    /// Back-off timer set after `poll_send` fails with `Interrupted`; while it
    /// is set the driver does not transmit.
    send_overflow_timer: Option<Pin<Box<Sleep>>>,
    /// Received datagrams that were not routed to a stream.
    recv_dgrams: VecDeque<Dgram>,
    /// Waker of the task waiting in `RecvFuture`.
    recv_waker: Option<Waker>,
}
277
278impl fmt::Debug for UdxSocketInner {
279 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
280 f.debug_struct("UdxSocketInner")
281 .field("socket", &self.socket)
282 .field("streams", &self.streams)
283 .field("pending_transmits", &self.outgoing_transmits.len())
284 .field("udp_state", &self.udp_state)
285 .field("stats", &self.stats)
286 .field("has_refs", &self.has_refs)
287 .field("drive_waker", &self.drive_waker)
288 .finish()
289 }
290}
291
/// Traffic counters of a socket.
#[derive(Default, Clone, Debug)]
pub struct SocketStats {
    /// Number of transmits recorded by `track_tx`.
    tx_transmits: usize,
    /// Number of sent datagrams (sum of the transmits' segment counts).
    tx_dgrams: usize,
    /// Number of sent bytes.
    tx_bytes: usize,
    /// Number of received bytes.
    rx_bytes: usize,
    /// Number of received datagrams.
    rx_dgrams: usize,
    /// Start of the current throughput measurement window.
    tx_window_start: Option<Instant>,
    /// Bytes sent within the current window.
    tx_window_bytes: usize,
    /// Datagrams sent within the current window.
    tx_window_dgrams: usize,
}
303
304impl SocketStats {
305 fn track_tx(&mut self, transmit: &Transmit) {
306 self.tx_bytes += transmit.contents.len();
307 self.tx_transmits += 1;
308 self.tx_dgrams += transmit.num_segments();
309 self.tx_window_bytes += transmit.contents.len();
310 self.tx_window_dgrams += transmit.num_segments();
311 if self.tx_window_start.is_none() {
312 self.tx_window_start = Some(Instant::now());
313 }
314 if self.tx_window_start.as_ref().unwrap().elapsed() > Duration::from_millis(1000) {
315 let elapsed = self
316 .tx_window_start
317 .as_ref()
318 .unwrap()
319 .elapsed()
320 .as_secs_f32();
321 trace!(
322 "{} MB/s {} pps",
323 self.tx_window_bytes as f32 / (1024. * 1024.) / elapsed,
324 self.tx_window_dgrams as f32 / elapsed
325 );
326 self.tx_window_bytes = 0;
327 self.tx_window_dgrams = 0;
328 self.tx_window_start = Some(Instant::now());
329 }
330 }
331}
332
333impl UdxSocketInner {
334 pub fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
335 let socket = std::net::UdpSocket::bind(addr)?;
336 let socket = UdpSocket::from_std(socket)?;
337 let (send_tx, send_rx) = mpsc::unbounded_channel();
338 let recv_buf = vec![0; UDX_MTU * BATCH_SIZE];
339 Ok(Self {
340 socket,
341 send_rx,
342 send_tx,
343 streams: HashMap::new(),
344 recv_buf: Some(recv_buf.into()),
345 udp_state: Arc::new(UdpState::new()),
346 outgoing_transmits: VecDeque::with_capacity(BATCH_SIZE),
347 outgoing_packet_sets: VecDeque::with_capacity(BATCH_SIZE),
348 stats: SocketStats::default(),
349 has_refs: true,
350 drive_waker: None,
351 send_overflow_timer: None,
352 recv_waker: None,
353 recv_dgrams: VecDeque::new(),
354 })
355 }
356
357 pub fn local_addr(&self) -> io::Result<SocketAddr> {
358 self.socket.local_addr()
359 }
360
361 pub fn connect(
362 &mut self,
363 dest: SocketAddr,
364 local_id: u32,
365 remote_id: u32,
366 ) -> io::Result<UdxStream> {
367 debug!(
368 "UdxSocketInner::connect {} [{}] -> {} [{}])",
369 self.local_addr().unwrap(),
370 local_id,
371 dest,
372 remote_id
373 );
374 let (recv_tx, recv_rx) = mpsc::unbounded_channel();
375 let stream = UdxStream::connect(
376 recv_rx,
377 self.send_tx.clone(),
378 self.udp_state.clone(),
379 dest,
380 remote_id,
381 local_id,
382 );
383 let handle = StreamHandle { recv_tx };
384 self.streams.insert(local_id, MaybeOpenStream::Open(handle));
385 Ok(stream)
386 }
387
388 fn poll_transmit(&mut self, cx: &mut Context<'_>) -> io::Result<bool> {
389 let mut iters = 0;
390 loop {
391 iters += 1;
392 let mut send_rx_pending = false;
393 while self.outgoing_transmits.len() < BATCH_SIZE {
394 match Pin::new(&mut self.send_rx).poll_recv(cx) {
395 Poll::Pending => {
396 send_rx_pending = true;
397 break;
398 }
399 Poll::Ready(None) => unreachable!(),
400 Poll::Ready(Some(event)) => match event {
401 EventOutgoing::StreamDropped(local_id) => {
402 let _ = self.streams.remove(&local_id);
403 }
404 EventOutgoing::TransmitDgram(dgram) => {
405 self.outgoing_transmits.push_back(dgram.into_transmit());
406 }
407 EventOutgoing::Transmit(packet_set) => {
408 let transmit = packet_set.to_transmit();
409 trace!("send {:?}", packet_set);
410 self.outgoing_transmits.push_back(transmit);
411 self.outgoing_packet_sets.push_back(packet_set);
412 }
413 },
414 }
415 }
416 if self.outgoing_transmits.is_empty() {
417 break Ok(false);
418 }
419
420 match self
421 .socket
422 .poll_send(&self.udp_state, cx, self.outgoing_transmits.as_slices().0)
423 {
424 Poll::Pending => break Ok(false),
425 Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::Interrupted => {
426 self.send_overflow_timer =
428 Some(Box::pin(tokio::time::sleep(Duration::from_millis(20))));
429 log::warn!("send overflow timer set!");
430 break Ok(false);
431 }
432 Poll::Ready(Err(err)) => break Err(err),
433 Poll::Ready(Ok(n)) => {
434 for transmit in self.outgoing_transmits.drain(..n) {
435 self.stats.track_tx(&transmit);
436 }
437 let n = n.min(self.outgoing_packet_sets.len());
439 for packet_set in self.outgoing_packet_sets.drain(..n) {
440 for packet in packet_set.iter_shared() {
441 packet.time_sent.set_now();
442 }
443 }
444 }
445 }
446 if send_rx_pending {
447 break Ok(false);
448 }
449 if iters > 0 {
450 break Ok(true);
451 }
452 }
453 }
454
455 fn poll_recv(&mut self, cx: &mut Context<'_>) -> io::Result<bool> {
456 let mut metas = [RecvMeta::default(); BATCH_SIZE];
457 let mut recv_buf = self.recv_buf.take().unwrap();
458 let mut iovs = unsafe { iovectors_from_buf::<BATCH_SIZE>(&mut recv_buf) };
459
460 let mut iters = 0;
462 let res = loop {
463 iters += 1;
464 if iters == MAX_LOOP {
465 break Ok(true);
466 }
467 match self.socket.poll_recv(cx, &mut iovs, &mut metas) {
468 Poll::Ready(Ok(msgs)) => {
469 for (meta, buf) in metas.iter().zip(iovs.iter()).take(msgs) {
470 let data: BytesMut = buf[0..meta.len].into();
471 if let Err(data) = self.process_packet(data, meta) {
472 self.on_recv_dgram(data, meta);
475 }
476 }
477 }
478 Poll::Pending => break Ok(false),
479 Poll::Ready(Err(e)) => break Err(e),
480 }
481 };
482 self.recv_buf = Some(recv_buf);
483 res
484 }
485
486 fn on_recv_dgram(&mut self, data: BytesMut, meta: &RecvMeta) {
487 if self.recv_dgrams.len() < RECV_QUEUE_MAX_LEN {
488 self.recv_dgrams
489 .push_back(Dgram::new(meta.addr, data.to_vec()));
490 if let Some(waker) = self.recv_waker.take() {
491 waker.wake()
492 }
493 } else {
494 drop(data)
495 }
496 }
497
498 fn process_packet(&mut self, mut data: BytesMut, meta: &RecvMeta) -> Result<(), BytesMut> {
499 let local_addr = self.local_addr().unwrap();
500 let len = data.len();
501 self.stats.rx_bytes += len;
502 self.stats.rx_dgrams += 1;
503
504 let header = match Header::from_bytes(&data) {
506 Ok(header) => header,
507 Err(_) => return Err(data),
508 };
509 let stream_id = header.stream_id;
510 trace!(
511 to = stream_id,
512 "[{}] recv from :{} typ {} seq {} ack {} len {}",
513 local_addr.port(),
514 meta.addr.port(),
515 header.typ,
516 header.seq,
517 header.ack,
518 len
519 );
520 match self.streams.get_mut(&stream_id) {
521 Some(stream) => {
522 let _ = data.split_to(UDX_HEADER_SIZE);
523 let incoming = IncomingPacket {
524 header,
525 buf: data.into(),
526 read_offset: 0,
527 from: meta.addr,
528 };
529 let event = EventIncoming::Packet(incoming);
530 match stream {
531 MaybeOpenStream::Open(handle) => {
532 if let Err(_p) = handle.recv_tx.send(event) {
533 self.streams.remove(&stream_id);
536 }
537 }
538 MaybeOpenStream::HalfOpen(ds) => {
539 ds.rx_messages.push(event);
540 }
541 }
542 }
543 None => {
544 return Err(data);
546 }
547 }
548 Ok(())
549 }
550}
551
/// Socket-level event.
///
/// NOTE(review): nothing in this part of the file constructs or matches on
/// this enum — confirm where it is used.
pub enum SocketEvent {
    /// Presumably signals a packet for a stream id that is not registered —
    /// TODO confirm against the users of this type.
    UnknownStream,
}
555
/// Splits `buf` into `N` equally sized, non-overlapping I/O vectors.
///
/// Every vector covers `buf.len() / N` bytes, in order from the start of
/// `buf`. Trailing bytes that do not fill a whole chunk are left unused.
///
/// # Safety
/// The body checks its own preconditions, so any `buf` is sound to pass; the
/// function stays `unsafe` only to keep existing call sites unchanged.
///
/// # Panics
/// Panics if `N` is zero or if `buf` is shorter than `N` bytes.
unsafe fn iovectors_from_buf<const N: usize>(buf: &mut [u8]) -> [IoSliceMut<'_>; N] {
    let chunk_len = buf.len() / N;
    assert!(
        chunk_len > 0,
        "a buffer of {} bytes cannot be split into {} non-empty chunks",
        buf.len(),
        N
    );
    let mut iovs = MaybeUninit::<[IoSliceMut; N]>::uninit();
    let slots = iovs.as_mut_ptr().cast::<IoSliceMut>();
    for (i, chunk) in buf.chunks_exact_mut(chunk_len).take(N).enumerate() {
        // SAFETY: `take(N)` keeps `i < N`, so the slot lies inside the array.
        unsafe { slots.add(i).write(IoSliceMut::new(chunk)) };
    }
    // SAFETY: `chunk_len * N <= buf.len()`, so `chunks_exact_mut` yields at
    // least `N` chunks and the loop above initialised all `N` slots.
    unsafe { iovs.assume_init() }
}