// udx/socket.rs

1use bytes::BytesMut;
2use futures::Future;
3use std::collections::HashMap;
4use std::collections::VecDeque;
5use std::fmt;
6use std::fmt::Debug;
7use std::io;
8use std::io::IoSliceMut;
9use std::mem::MaybeUninit;
10use std::net::IpAddr;
11use std::net::Ipv4Addr;
12use std::net::SocketAddr;
13use std::net::ToSocketAddrs;
14use std::pin::Pin;
15use std::sync::Arc;
16use std::task::Waker;
17use std::task::{Context, Poll};
18use std::time::Duration;
19use std::time::Instant;
20use tokio::sync::mpsc::{self, UnboundedReceiver as Receiver, UnboundedSender as Sender};
21use tokio::time::Sleep;
22use tracing::{debug, trace};
23
24use crate::constants::UDX_HEADER_SIZE;
25use crate::constants::UDX_MTU;
26use crate::mutex::Mutex;
27use crate::packet::{Dgram, Header, IncomingPacket, PacketSet};
28use crate::stream::UdxStream;
29use crate::udp::{BATCH_SIZE, RecvMeta, Transmit, UdpSocket, UdpState};
30
/// Bound on the receive loop in `UdxSocketInner::poll_recv`: once the iteration counter
/// reaches this value the loop stops and asks the driver to be polled again.
const MAX_LOOP: usize = 60;

/// Maximum number of unrouted datagrams kept for `UdxSocket::recv`; extra ones are dropped.
const RECV_QUEUE_MAX_LEN: usize = 1024;
34
35#[derive(Debug)]
36pub(crate) enum EventIncoming {
37    Packet(IncomingPacket),
38}
39
40#[derive(Debug)]
41pub(crate) enum EventOutgoing {
42    Transmit(PacketSet),
43    TransmitDgram(Dgram),
44    // TransmitOne(PacketRef),
45    StreamDropped(u32),
46}
47
48#[derive(Debug)]
49struct StreamHandle {
50    recv_tx: Sender<EventIncoming>,
51}
52
53#[derive(Clone, Debug)]
54pub struct UdxSocket(Arc<Mutex<UdxSocketInner>>);
55
56impl std::ops::Deref for UdxSocket {
57    type Target = Mutex<UdxSocketInner>;
58    fn deref(&self) -> &Self::Target {
59        &self.0
60    }
61}
62
63impl UdxSocket {
64    pub fn bind_rnd() -> io::Result<Self> {
65        Self::bind_port(0)
66    }
67    pub fn bind_port(port: u16) -> io::Result<Self> {
68        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), port);
69        Self::bind(addr)
70    }
71    // TODO FIXME this is not async but requires tokio running. Which will cause a runtime failure.
72    // rm this depndence
73    /// Create a socket on the given `addr`. Note `addr` is a *local* address normally it would
74    /// look like `127.0.0.1:8080` which creates a socket on port `8080`. To connect to any random
75    /// port pass `:0` as the port.
76    pub fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
77        let inner = UdxSocketInner::bind(addr)?;
78        let socket = Self(Arc::new(Mutex::new(inner)));
79        let driver = SocketDriver(socket.clone());
80        tokio::spawn(async {
81            if let Err(e) = driver.await {
82                tracing::error!("Socket I/O error: {}", e);
83            }
84        });
85        Ok(socket)
86    }
87
88    pub fn local_addr(&self) -> io::Result<SocketAddr> {
89        self.0.lock("UdxSocket::local_addr").socket.local_addr()
90    }
91
92    pub fn create_stream(&self, local_id: u32) -> io::Result<HalfOpenStreamHandle> {
93        self.0.lock("UdxSocket::make_stream").streams.insert(
94            local_id,
95            MaybeOpenStream::HalfOpen(HalfOpenStream {
96                socket: self.clone(),
97                local_id,
98                rx_messages: vec![],
99            }),
100        );
101        Ok(HalfOpenStreamHandle {
102            socket: self.clone(),
103            local_id,
104        })
105    }
106    pub fn connect(
107        &self,
108        dest: SocketAddr,
109        local_id: u32,
110        remote_id: u32,
111    ) -> io::Result<UdxStream> {
112        self.0
113            .lock("UdxSocket::connect")
114            .connect(dest, local_id, remote_id)
115    }
116
117    pub fn stats(&self) -> SocketStats {
118        self.0.lock("UdxSocket::stats").stats.clone()
119    }
120
121    pub fn send(&self, dest: SocketAddr, buf: &[u8]) {
122        let dgram = Dgram::new(dest, buf.to_vec());
123        let ev = EventOutgoing::TransmitDgram(dgram);
124        self.0.lock("UdxSocket::send").send_tx.send(ev).unwrap();
125    }
126
127    pub fn recv(&self) -> RecvFuture {
128        RecvFuture(self.clone())
129    }
130}
131
132pub struct RecvFuture(UdxSocket);
133impl Future for RecvFuture {
134    type Output = io::Result<(SocketAddr, Vec<u8>)>;
135    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
136        let mut socket = self.0.lock("UdxSocket::recv");
137        if let Some(dgram) = socket.recv_dgrams.pop_front() {
138            if !socket.recv_dgrams.is_empty() {
139                cx.waker().wake_by_ref();
140            }
141            Poll::Ready(Ok((dgram.dest, dgram.buf)))
142        } else {
143            socket.recv_waker = Some(cx.waker().clone());
144            Poll::Pending
145        }
146    }
147}
148
149impl Drop for UdxSocket {
150    fn drop(&mut self) {
151        // Only the driver is left, shutdown.
152        if Arc::strong_count(&self.0) == 2 {
153            let mut socket = self.0.lock("UdxSocket::drop");
154            socket.has_refs = false;
155            if let Some(waker) = socket.drive_waker.take() {
156                waker.wake();
157            }
158        }
159    }
160}
161
162pub struct SocketDriver(UdxSocket);
163
164impl Future for SocketDriver {
165    type Output = io::Result<()>;
166    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
167        let mut socket = self.0.lock("UdxSocket::poll_drive");
168        let mut should_continue = false;
169        should_continue |= socket.poll_recv(cx)?;
170        if let Some(send_overflow_timer) = socket.send_overflow_timer.as_mut() {
171            match send_overflow_timer.as_mut().poll(cx) {
172                Poll::Pending => {}
173                Poll::Ready(()) => {
174                    log::warn!("send overflow timer clear!");
175                    socket.send_overflow_timer = None;
176                    should_continue = true;
177                }
178            }
179        } else {
180            should_continue |= socket.poll_transmit(cx)?;
181        }
182
183        if !should_continue && !socket.has_refs && socket.streams.is_empty() {
184            return Poll::Ready(Ok(()));
185        }
186        if should_continue {
187            drop(socket);
188            cx.waker().wake_by_ref();
189        } else {
190            socket.drive_waker = Some(cx.waker().clone());
191        }
192        Poll::Pending
193    }
194}
195
196#[derive(Debug)]
197pub struct HalfOpenStreamHandle {
198    socket: UdxSocket,
199    local_id: u32,
200}
201
202impl HalfOpenStreamHandle {
203    pub fn connect(self, dest: SocketAddr, remote_id: u32) -> io::Result<UdxStream> {
204        let Some(MaybeOpenStream::HalfOpen(ds)) = self
205            .socket
206            .0
207            .lock("HalfOpenStreamHandle::connect get stream")
208            .streams
209            .remove(&self.local_id)
210        else {
211            todo!()
212        };
213        let (stream, handle) = ds.connect(dest, remote_id)?;
214        for event in ds.rx_messages {
215            if let Err(_packet) = handle.recv_tx.send(event) {
216                // stream dropped?
217                todo!()
218            }
219        }
220        self.socket
221            .0
222            .lock("HalfOpenStreamHandle:: put stream")
223            .streams
224            .insert(self.local_id, MaybeOpenStream::Open(handle));
225        Ok(stream)
226    }
227}
228
229#[derive(Debug)]
230struct HalfOpenStream {
231    socket: UdxSocket,
232    local_id: u32,
233    rx_messages: Vec<EventIncoming>,
234}
235
236impl HalfOpenStream {
237    fn connect(&self, dest: SocketAddr, remote_id: u32) -> io::Result<(UdxStream, StreamHandle)> {
238        let inner = self.socket.0.lock("DisconnectedStream::connect");
239        let (recv_tx, recv_rx) = mpsc::unbounded_channel();
240        let stream = UdxStream::connect(
241            recv_rx,
242            inner.send_tx.clone(),
243            inner.udp_state.clone(),
244            dest,
245            remote_id,
246            self.local_id,
247        );
248        // replay messages
249        let handle = StreamHandle { recv_tx };
250        Ok((stream, handle))
251    }
252}
253
254#[derive(Debug)]
255enum MaybeOpenStream {
256    HalfOpen(HalfOpenStream),
257    Open(StreamHandle),
258}
259
260pub struct UdxSocketInner {
261    socket: UdpSocket,
262    send_rx: Receiver<EventOutgoing>,
263    send_tx: Sender<EventOutgoing>,
264    streams: HashMap<u32, MaybeOpenStream>,
265    outgoing_transmits: VecDeque<Transmit>,
266    outgoing_packet_sets: VecDeque<PacketSet>,
267    recv_buf: Option<Box<[u8]>>,
268    udp_state: Arc<UdpState>,
269    stats: SocketStats,
270    has_refs: bool,
271    drive_waker: Option<Waker>,
272
273    send_overflow_timer: Option<Pin<Box<Sleep>>>,
274    recv_dgrams: VecDeque<Dgram>,
275    recv_waker: Option<Waker>,
276}
277
278impl fmt::Debug for UdxSocketInner {
279    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
280        f.debug_struct("UdxSocketInner")
281            .field("socket", &self.socket)
282            .field("streams", &self.streams)
283            .field("pending_transmits", &self.outgoing_transmits.len())
284            .field("udp_state", &self.udp_state)
285            .field("stats", &self.stats)
286            .field("has_refs", &self.has_refs)
287            .field("drive_waker", &self.drive_waker)
288            .finish()
289    }
290}
291
/// Traffic counters of a socket, as returned by [`UdxSocket::stats`].
#[derive(Default, Clone, Debug)]
pub struct SocketStats {
    /// Transmits handed to the UDP socket.
    tx_transmits: usize,
    /// Datagrams (segments) sent.
    tx_dgrams: usize,
    /// Bytes sent.
    tx_bytes: usize,
    /// Bytes received, counted before any UDX header check.
    rx_bytes: usize,
    /// Datagrams received.
    rx_dgrams: usize,
    /// Start of the current throughput-measurement window.
    tx_window_start: Option<Instant>,
    /// Bytes sent within the current window.
    tx_window_bytes: usize,
    /// Datagrams sent within the current window.
    tx_window_dgrams: usize,
}
303
304impl SocketStats {
305    fn track_tx(&mut self, transmit: &Transmit) {
306        self.tx_bytes += transmit.contents.len();
307        self.tx_transmits += 1;
308        self.tx_dgrams += transmit.num_segments();
309        self.tx_window_bytes += transmit.contents.len();
310        self.tx_window_dgrams += transmit.num_segments();
311        if self.tx_window_start.is_none() {
312            self.tx_window_start = Some(Instant::now());
313        }
314        if self.tx_window_start.as_ref().unwrap().elapsed() > Duration::from_millis(1000) {
315            let elapsed = self
316                .tx_window_start
317                .as_ref()
318                .unwrap()
319                .elapsed()
320                .as_secs_f32();
321            trace!(
322                "{} MB/s {} pps",
323                self.tx_window_bytes as f32 / (1024. * 1024.) / elapsed,
324                self.tx_window_dgrams as f32 / elapsed
325            );
326            self.tx_window_bytes = 0;
327            self.tx_window_dgrams = 0;
328            self.tx_window_start = Some(Instant::now());
329        }
330    }
331}
332
333impl UdxSocketInner {
334    pub fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
335        let socket = std::net::UdpSocket::bind(addr)?;
336        let socket = UdpSocket::from_std(socket)?;
337        let (send_tx, send_rx) = mpsc::unbounded_channel();
338        let recv_buf = vec![0; UDX_MTU * BATCH_SIZE];
339        Ok(Self {
340            socket,
341            send_rx,
342            send_tx,
343            streams: HashMap::new(),
344            recv_buf: Some(recv_buf.into()),
345            udp_state: Arc::new(UdpState::new()),
346            outgoing_transmits: VecDeque::with_capacity(BATCH_SIZE),
347            outgoing_packet_sets: VecDeque::with_capacity(BATCH_SIZE),
348            stats: SocketStats::default(),
349            has_refs: true,
350            drive_waker: None,
351            send_overflow_timer: None,
352            recv_waker: None,
353            recv_dgrams: VecDeque::new(),
354        })
355    }
356
357    pub fn local_addr(&self) -> io::Result<SocketAddr> {
358        self.socket.local_addr()
359    }
360
361    pub fn connect(
362        &mut self,
363        dest: SocketAddr,
364        local_id: u32,
365        remote_id: u32,
366    ) -> io::Result<UdxStream> {
367        debug!(
368            "UdxSocketInner::connect {} [{}] -> {} [{}])",
369            self.local_addr().unwrap(),
370            local_id,
371            dest,
372            remote_id
373        );
374        let (recv_tx, recv_rx) = mpsc::unbounded_channel();
375        let stream = UdxStream::connect(
376            recv_rx,
377            self.send_tx.clone(),
378            self.udp_state.clone(),
379            dest,
380            remote_id,
381            local_id,
382        );
383        let handle = StreamHandle { recv_tx };
384        self.streams.insert(local_id, MaybeOpenStream::Open(handle));
385        Ok(stream)
386    }
387
388    fn poll_transmit(&mut self, cx: &mut Context<'_>) -> io::Result<bool> {
389        let mut iters = 0;
390        loop {
391            iters += 1;
392            let mut send_rx_pending = false;
393            while self.outgoing_transmits.len() < BATCH_SIZE {
394                match Pin::new(&mut self.send_rx).poll_recv(cx) {
395                    Poll::Pending => {
396                        send_rx_pending = true;
397                        break;
398                    }
399                    Poll::Ready(None) => unreachable!(),
400                    Poll::Ready(Some(event)) => match event {
401                        EventOutgoing::StreamDropped(local_id) => {
402                            let _ = self.streams.remove(&local_id);
403                        }
404                        EventOutgoing::TransmitDgram(dgram) => {
405                            self.outgoing_transmits.push_back(dgram.into_transmit());
406                        }
407                        EventOutgoing::Transmit(packet_set) => {
408                            let transmit = packet_set.to_transmit();
409                            trace!("send {:?}", packet_set);
410                            self.outgoing_transmits.push_back(transmit);
411                            self.outgoing_packet_sets.push_back(packet_set);
412                        }
413                    },
414                }
415            }
416            if self.outgoing_transmits.is_empty() {
417                break Ok(false);
418            }
419
420            match self
421                .socket
422                .poll_send(&self.udp_state, cx, self.outgoing_transmits.as_slices().0)
423            {
424                Poll::Pending => break Ok(false),
425                Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::Interrupted => {
426                    // Send overflow! Scale back write rate.
427                    self.send_overflow_timer =
428                        Some(Box::pin(tokio::time::sleep(Duration::from_millis(20))));
429                    log::warn!("send overflow timer set!");
430                    break Ok(false);
431                }
432                Poll::Ready(Err(err)) => break Err(err),
433                Poll::Ready(Ok(n)) => {
434                    for transmit in self.outgoing_transmits.drain(..n) {
435                        self.stats.track_tx(&transmit);
436                    }
437                    // update packet sent time for data packets.
438                    let n = n.min(self.outgoing_packet_sets.len());
439                    for packet_set in self.outgoing_packet_sets.drain(..n) {
440                        for packet in packet_set.iter_shared() {
441                            packet.time_sent.set_now();
442                        }
443                    }
444                }
445            }
446            if send_rx_pending {
447                break Ok(false);
448            }
449            if iters > 0 {
450                break Ok(true);
451            }
452        }
453    }
454
455    fn poll_recv(&mut self, cx: &mut Context<'_>) -> io::Result<bool> {
456        let mut metas = [RecvMeta::default(); BATCH_SIZE];
457        let mut recv_buf = self.recv_buf.take().unwrap();
458        let mut iovs = unsafe { iovectors_from_buf::<BATCH_SIZE>(&mut recv_buf) };
459
460        // process recv
461        let mut iters = 0;
462        let res = loop {
463            iters += 1;
464            if iters == MAX_LOOP {
465                break Ok(true);
466            }
467            match self.socket.poll_recv(cx, &mut iovs, &mut metas) {
468                Poll::Ready(Ok(msgs)) => {
469                    for (meta, buf) in metas.iter().zip(iovs.iter()).take(msgs) {
470                        let data: BytesMut = buf[0..meta.len].into();
471                        if let Err(data) = self.process_packet(data, meta) {
472                            // received invalid header. emit as message on socket.
473                            // TODO: Remove the queue and invoke a poll handler directly?
474                            self.on_recv_dgram(data, meta);
475                        }
476                    }
477                }
478                Poll::Pending => break Ok(false),
479                Poll::Ready(Err(e)) => break Err(e),
480            }
481        };
482        self.recv_buf = Some(recv_buf);
483        res
484    }
485
486    fn on_recv_dgram(&mut self, data: BytesMut, meta: &RecvMeta) {
487        if self.recv_dgrams.len() < RECV_QUEUE_MAX_LEN {
488            self.recv_dgrams
489                .push_back(Dgram::new(meta.addr, data.to_vec()));
490            if let Some(waker) = self.recv_waker.take() {
491                waker.wake()
492            }
493        } else {
494            drop(data)
495        }
496    }
497
498    fn process_packet(&mut self, mut data: BytesMut, meta: &RecvMeta) -> Result<(), BytesMut> {
499        let local_addr = self.local_addr().unwrap();
500        let len = data.len();
501        self.stats.rx_bytes += len;
502        self.stats.rx_dgrams += 1;
503
504        // try to decode the udx header
505        let header = match Header::from_bytes(&data) {
506            Ok(header) => header,
507            Err(_) => return Err(data),
508        };
509        let stream_id = header.stream_id;
510        trace!(
511            to = stream_id,
512            "[{}] recv from :{} typ {} seq {} ack {} len {}",
513            local_addr.port(),
514            meta.addr.port(),
515            header.typ,
516            header.seq,
517            header.ack,
518            len
519        );
520        match self.streams.get_mut(&stream_id) {
521            Some(stream) => {
522                let _ = data.split_to(UDX_HEADER_SIZE);
523                let incoming = IncomingPacket {
524                    header,
525                    buf: data.into(),
526                    read_offset: 0,
527                    from: meta.addr,
528                };
529                let event = EventIncoming::Packet(incoming);
530                match stream {
531                    MaybeOpenStream::Open(handle) => {
532                        if let Err(_p) = handle.recv_tx.send(event) {
533                            // stream was dropped.
534                            // remove stream?
535                            self.streams.remove(&stream_id);
536                        }
537                    }
538                    MaybeOpenStream::HalfOpen(ds) => {
539                        ds.rx_messages.push(event);
540                    }
541                }
542            }
543            None => {
544                // received packet for nonexisting stream.
545                return Err(data);
546            }
547        }
548        Ok(())
549    }
550}
551
/// Events a socket may report. NOTE(review): not referenced elsewhere in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SocketEvent {
    /// Presumably: a packet arrived for a stream id that is not registered; confirm at use site.
    UnknownStream,
}
555
/// Split `buf` into `N` equally sized, non-overlapping I/O slices.
///
/// Each slice is `buf.len() / N` bytes long. Trailing `buf.len() % N` bytes are not covered.
/// Adapted from quinn (`quinn/src/endpoint.rs`).
///
/// # Panics
///
/// Panics if `N == 0` or `buf.len() < N`.
///
/// # Safety
///
/// The body no longer depends on caller guarantees for memory safety: it never writes more than
/// `N` array elements, even when `buf.len()` is not a multiple of `N`. The signature stays
/// `unsafe` for compatibility with existing callers. Callers should only read bytes of a slice
/// that have actually been written (e.g. the received length).
unsafe fn iovectors_from_buf<const N: usize>(buf: &mut [u8]) -> [IoSliceMut<'_>; N] {
    assert!(
        N > 0 && buf.len() >= N,
        "buffer of {} bytes cannot be split into {} non-empty slices",
        buf.len(),
        N
    );
    let chunk_len = buf.len() / N;
    let mut iovs = MaybeUninit::<[IoSliceMut<'_>; N]>::uninit();
    let first = iovs.as_mut_ptr().cast::<IoSliceMut<'_>>();
    let mut initialized = 0;
    // `chunks_exact_mut` yields `buf.len() / chunk_len >= N` chunks. `take(N)` caps the number
    // of writes at the array length (plain `chunks_mut` could yield N + 1 chunks).
    for chunk in buf.chunks_exact_mut(chunk_len).take(N) {
        // SAFETY: `initialized < N` because of `take(N)`, so the write stays inside `iovs`.
        unsafe { first.add(initialized).write(IoSliceMut::new(chunk)) };
        initialized += 1;
    }
    assert_eq!(initialized, N, "not enough chunks to fill every I/O slice");
    // SAFETY: all `N` elements were written in the loop above.
    unsafe { iovs.assume_init() }
}