// utp_socket/utp_socket.rs

1use std::{
2    cell::RefCell,
3    collections::{hash_map::Entry, HashMap},
4    net::SocketAddr,
5    rc::Rc,
6};
7
8use anyhow::Context;
9use bytes::Bytes;
10use tokio_uring::net::UdpSocket;
11
12use crate::{
13    utp_packet::{Packet, PacketHeader, PacketType, HEADER_SIZE},
14    utp_stream::{UtpStream, WeakUtpStream},
15};
16
// Conceptually there is a single socket that handles multiple connections.
// The socket keeps a hash map of all connections, keyed by connection id and
// peer address. The network loop waits for incoming data and then finds the
// relevant connection based on the sender's address together with the
// connection id carried in the packet header.
21pub struct UtpSocket {
22    socket: Rc<UdpSocket>,
23    shutdown_signal: Option<tokio::sync::oneshot::Sender<()>>,
24    accept_chan: Rc<RefCell<Option<tokio::sync::oneshot::Sender<UtpStream>>>>,
25    streams: Rc<RefCell<HashMap<StreamKey, WeakUtpStream>>>,
26}
27
/// Lookup key for a stream: the connection id paired with the peer address.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct StreamKey {
    /// Connection id under which the stream is registered.
    conn_id: u16,
    /// Address of the remote peer.
    addr: SocketAddr,
}
33
// TODO: better error handling.
// NOTE: this type is more similar to a `TcpListener`.
36impl UtpSocket {
37    pub async fn bind(bind_addr: SocketAddr) -> anyhow::Result<Self> {
38        let socket = Rc::new(UdpSocket::bind(bind_addr).await?);
39        let net_loop_socket = socket.clone();
40
41        let (shutdown_signal, mut shutdown_receiver) = tokio::sync::oneshot::channel();
42        //connections: HashMap<SocketKey, Rc<RefCell<UtpStream>>>,
43        let utp_socket = UtpSocket {
44            socket,
45            shutdown_signal: Some(shutdown_signal),
46            accept_chan: Default::default(),
47            streams: Default::default(),
48        };
49
50        let streams_clone = utp_socket.streams.clone();
51        let accept_chan = utp_socket.accept_chan.clone();
52        // Net loop
53        tokio_uring::spawn(async move {
54            // TODO check how this relates to windows size and opt_rcvbuf
55            let mut recv_buf = vec![0; 1024 * 1024];
56            loop {
57                // Double check if this is cancellation safe
58                // (I don't think it is but shouldn't matter anyways)
59                tokio::select! {
60                    buf = process_incomming(&net_loop_socket, &streams_clone, &accept_chan, std::mem::take(&mut recv_buf)) => {
61                           recv_buf = buf;
62                        }
63                    _ = &mut shutdown_receiver =>  {
64                        log::info!("Shutting down network loop");
65                        // TODO shutdown all streams gracefully
66                        break;
67                    }
68                }
69            }
70        });
71
72        Ok(utp_socket)
73    }
74
75    pub async fn connect(&self, addr: SocketAddr) -> anyhow::Result<UtpStream> {
76        let mut stream_key = StreamKey {
77            conn_id: rand::random(),
78            addr,
79        };
80
81        while self.streams.borrow().contains_key(&stream_key) {
82            log::debug!("Stream with same conn_id and addr already exists, regenerating conn_id");
83            stream_key = StreamKey {
84                conn_id: rand::random::<u16>(),
85                addr,
86            }
87        }
88
89        let stream = UtpStream::new(stream_key.conn_id, addr, Rc::downgrade(&self.socket));
90        self.streams
91            .borrow_mut()
92            .insert(stream_key, stream.clone().into());
93
94        stream.connect().await?;
95
96        Ok(stream)
97    }
98
99    pub async fn accept(&self) -> anyhow::Result<UtpStream> {
100        let (tx, rc) = tokio::sync::oneshot::channel();
101        {
102            let mut chan = self.accept_chan.borrow_mut();
103            *chan = Some(tx);
104        }
105        rc.await.context("Net loop exited")
106    }
107}
108
109impl Drop for UtpSocket {
110    fn drop(&mut self) {
111        println!("dropping");
112        self.shutdown_signal.take().unwrap().send(()).unwrap();
113    }
114}
115
// Socket reads and parses out a vec of packets per read.
// The packets are then sent to the stream's incoming circular packet buffer.
// In the stream-specific packet handler they check if it's the expected packet or out of order:
// if it's out of order they just insert it into the buffer;
// if it's in order they handle it together with all other potential packets that are ordered
// after it and stored in the incoming buffer.
//
// The out buffer is written to by the stream and handled in a separate task, I think.
// The packets can get stored in that task if they need to be resent?
// The incoming task could keep a channel of received acks that can be removed from the resend buffer.
126async fn process_incomming(
127    socket: &Rc<UdpSocket>,
128    connections: &Rc<RefCell<HashMap<StreamKey, WeakUtpStream>>>,
129    accept_chan: &Rc<RefCell<Option<tokio::sync::oneshot::Sender<UtpStream>>>>,
130    recv_buf: Vec<u8>,
131) -> Vec<u8> {
132    let (result, buf) = socket.recv_from(recv_buf).await;
133    match result {
134        Ok((recv, addr)) => {
135            log::info!("Received {recv} from {addr}");
136            match PacketHeader::try_from(&buf[..recv]) {
137                Ok(packet_header) => {
138                    let key = StreamKey {
139                        conn_id: packet_header.conn_id,
140                        addr,
141                    };
142
143                    let packet = Packet {
144                        header: packet_header,
145                        data: Bytes::copy_from_slice(&buf[HEADER_SIZE as usize..recv]),
146                    };
147
148                    let maybe_stream = { connections.borrow_mut().remove(&key) };
149                    if let Some(weak_stream) = maybe_stream {
150                        if let Some(stream) = weak_stream.try_upgrade() {
151                            match stream.process_incoming(packet).await {
152                                Ok(()) => {
153                                    connections.borrow_mut().insert(key, stream.into());
154                                }
155                                Err(err) => {
156                                    log::error!("Error: Failed processing incoming packet: {err}");
157                                }
158                            }
159                        }
160                    } else if packet_header.packet_type == PacketType::Syn {
161                        let maybe_chan = { accept_chan.borrow_mut().take() };
162                        if let Some(chan) = maybe_chan {
163                            let stream = UtpStream::new_incoming(
164                                packet_header.seq_nr,
165                                packet_header.conn_id,
166                                addr,
167                                Rc::downgrade(socket),
168                            );
169                            let stream_key = StreamKey {
170                                // Special case for initial stream setup
171                                // Same as stream recv conn id
172                                conn_id: packet_header.conn_id + 1,
173                                addr,
174                            };
175                            // Ensure the connection doesn't conflict with an already
176                            // existing connection before accepting it.
177                            {
178                                let mut connections = connections.borrow_mut();
179                                let entry = connections.entry(stream_key);
180                                match entry {
181                                    Entry::Occupied(_) => {
182                                        log::warn!("Connection with id: {} already exists. Dropping connection",
183                                            packet_header.conn_id + 1
184                                        );
185                                        return buf;
186                                    }
187                                    Entry::Vacant(entry) => {
188                                        log::info!("New incoming connection!");
189                                        entry.insert(stream.clone().into());
190                                    }
191                                }
192                            }
193
194                            // Remove the connection if the inital syn couldn't be processed
195                            if let Err(err) = stream.process_incoming(packet).await {
196                                log::error!("Error accepting connection: {err}");
197                                // If the packet couldn't be processed
198                                // we the accept chan is reset and the stream is removed from
199                                // the connection map
200                                *accept_chan.borrow_mut() = Some(chan);
201                                connections.borrow_mut().remove(&stream_key);
202                            } else {
203                                chan.send(stream).unwrap();
204                            }
205                        }
206                    } else {
207                        log::warn!("Connection not established prior");
208                    }
209                }
210                Err(err) => log::error!("Error parsing packet: {err}"),
211            }
212        }
213        Err(err) => log::error!("Failed to receive on utp socket: {err}"),
214    }
215    buf
216}
217
218// Process outgoing