utp_rs/
socket.rs

1use std::collections::HashMap;
2use std::io;
3use std::net::SocketAddr;
4use std::sync::{Arc, RwLock};
5use std::time::Duration;
6
7use delay_map::HashMapDelay;
8use futures::StreamExt;
9use rand::{thread_rng, Rng};
10use tokio::net::UdpSocket;
11use tokio::sync::mpsc::UnboundedSender;
12use tokio::sync::{mpsc, oneshot};
13
14use crate::cid::ConnectionId;
15use crate::conn::ConnectionConfig;
16use crate::event::{SocketEvent, StreamEvent};
17use crate::packet::{Packet, PacketBuilder, PacketType};
18use crate::peer::{ConnectionPeer, Peer};
19use crate::stream::UtpStream;
20use crate::udp::AsyncUdpSocket;
21
22type ConnChannel = UnboundedSender<StreamEvent>;
23
24struct Accept<P: ConnectionPeer> {
25    stream: oneshot::Sender<io::Result<UtpStream<P>>>,
26    config: ConnectionConfig,
27}
28
29struct AcceptWithCidPeer<P: ConnectionPeer> {
30    cid: ConnectionId<P::Id>,
31    peer: Peer<P>,
32    accept: Accept<P>,
33}
34
/// Size of the receive buffer: `u16::MAX` bytes, enough to hold any UDP datagram payload.
const MAX_UDP_PAYLOAD_SIZE: usize = u16::MAX as usize;
/// Threshold of failed connection-ID generation attempts at which an error is logged.
const CID_GENERATION_TRY_WARNING_COUNT: usize = 10;

/// How long a pending `accept_with_cid()` request, or an inbound SYN that nobody has
/// accepted yet, is kept before it is discarded.
///
/// `accept()` simply takes the next queued inbound connection, whereas
/// `accept_with_cid()` only completes when a SYN for its specific connection ID arrives.
/// Without a timeout, requests whose SYN never shows up (and SYNs that nobody accepts)
/// would pile up forever, so both are expired after this long.
///
/// The 20-second value is arbitrary. It can be replaced by a configuration setting once
/// the uTP config refactor lands, but that refactor is currently very low priority.
const AWAITING_CONNECTION_TIMEOUT: Duration = Duration::new(20, 0);
45
46pub struct UtpSocket<P: ConnectionPeer> {
47    conns: Arc<RwLock<HashMap<ConnectionId<P::Id>, ConnChannel>>>,
48    accepts: UnboundedSender<Accept<P>>,
49    accepts_with_cid: UnboundedSender<AcceptWithCidPeer<P>>,
50    socket_events: UnboundedSender<SocketEvent<P>>,
51}
52
53impl UtpSocket<SocketAddr> {
54    pub async fn bind(addr: SocketAddr) -> io::Result<Self> {
55        let socket = UdpSocket::bind(addr).await?;
56        let socket = Self::with_socket(socket);
57        Ok(socket)
58    }
59}
60
61impl<P> UtpSocket<P>
62where
63    P: ConnectionPeer<Id: Unpin> + Unpin + 'static,
64{
65    pub fn with_socket<S>(mut socket: S) -> Self
66    where
67        S: AsyncUdpSocket<P> + 'static,
68    {
69        let conns = HashMap::new();
70        let conns = Arc::new(RwLock::new(conns));
71
72        let mut awaiting: HashMapDelay<ConnectionId<P::Id>, AcceptWithCidPeer<P>> =
73            HashMapDelay::new(AWAITING_CONNECTION_TIMEOUT);
74
75        let mut incoming_conns: HashMapDelay<ConnectionId<P::Id>, (Peer<P>, Packet)> =
76            HashMapDelay::new(AWAITING_CONNECTION_TIMEOUT);
77
78        let (socket_event_tx, mut socket_event_rx) = mpsc::unbounded_channel();
79        let (accepts_tx, mut accepts_rx) = mpsc::unbounded_channel();
80        let (accepts_with_cid_tx, mut accepts_with_cid_rx) = mpsc::unbounded_channel();
81
82        let utp = Self {
83            conns: Arc::clone(&conns),
84            accepts: accepts_tx,
85            accepts_with_cid: accepts_with_cid_tx,
86            socket_events: socket_event_tx.clone(),
87        };
88
89        tokio::spawn(async move {
90            let mut buf = [0; MAX_UDP_PAYLOAD_SIZE];
91            loop {
92                tokio::select! {
93                    biased;
94                    Ok((n, mut peer)) = socket.recv_from(&mut buf) => {
95                        let peer_id = peer.id();
96                        let packet = match Packet::decode(&buf[..n]) {
97                            Ok(pkt) => pkt,
98                            Err(..) => {
99                                tracing::warn!(?peer, "unable to decode uTP packet");
100                                continue;
101                            }
102                        };
103
104                        let peer_init_cid = cid_from_packet::<P>(&packet, peer_id, IdType::SendIdPeerInitiated);
105                        let we_init_cid = cid_from_packet::<P>(&packet, peer_id, IdType::SendIdWeInitiated);
106                        let acc_cid = cid_from_packet::<P>(&packet, peer_id, IdType::RecvId);
107                        let mut conns = conns.write().unwrap();
108                        let conn = conns
109                            .get(&acc_cid)
110                            .or_else(|| conns.get(&we_init_cid))
111                            .or_else(|| conns.get(&peer_init_cid));
112                        match conn {
113                            Some(conn) => {
114                                let _ = conn.send(StreamEvent::Incoming(packet));
115                            }
116                            None => {
117                                if std::matches!(packet.packet_type(), PacketType::Syn) {
118                                    let cid = acc_cid;
119
120                                    // If there was an awaiting connection with the CID, then
121                                    // create a new stream for that connection. Otherwise, add the
122                                    // connection to the incoming connections.
123                                    if let Some(accept_with_cid) = awaiting.remove(&cid) {
124                                        peer.consolidate(accept_with_cid.peer);
125
126                                        let (connected_tx, connected_rx) = oneshot::channel();
127                                        let (events_tx, events_rx) = mpsc::unbounded_channel();
128
129                                        conns.insert(cid.clone(), events_tx);
130
131                                        let stream = UtpStream::new(
132                                            cid,
133                                            peer,
134                                            accept_with_cid.accept.config,
135                                            Some(packet),
136                                            socket_event_tx.clone(),
137                                            events_rx,
138                                            connected_tx
139                                        );
140
141                                        tokio::spawn(async move {
142                                            Self::await_connected(stream, accept_with_cid.accept.stream, connected_rx).await
143                                        });
144                                    } else {
145                                        incoming_conns.insert(cid, (peer, packet));
146                                    }
147                                } else {
148                                    tracing::debug!(
149                                        cid = %packet.conn_id(),
150                                        packet = ?packet.packet_type(),
151                                        seq = %packet.seq_num(),
152                                        ack = %packet.ack_num(),
153                                        peer_init_cid = ?peer_init_cid,
154                                        we_init_cid = ?we_init_cid,
155                                        acc_cid = ?acc_cid,
156                                        "received uTP packet for non-existing conn"
157                                    );
158                                    // don't send a reset if we are receiving a reset
159                                    if packet.packet_type() != PacketType::Reset {
160                                        // if we get a packet from an unknown source send a reset packet.
161                                        let random_seq_num = thread_rng().gen_range(0..=65535);
162                                        let reset_packet =
163                                            PacketBuilder::new(PacketType::Reset, packet.conn_id(), crate::time::now_micros(), 100_000, random_seq_num)
164                                                .build();
165                                        let event = SocketEvent::Outgoing((reset_packet, peer));
166                                        if socket_event_tx.send(event).is_err() {
167                                            tracing::warn!("Cannot transmit reset packet: socket closed channel");
168                                            return;
169                                        }
170                                    }
171                                }
172                            },
173                        }
174                    }
175                    Some(accept_with_cid) = accepts_with_cid_rx.recv() => {
176                        let Some((mut peer, syn)) = incoming_conns.remove(&accept_with_cid.cid) else {
177                            awaiting.insert(accept_with_cid.cid.clone(), accept_with_cid);
178                            continue;
179                        };
180                        peer.consolidate(accept_with_cid.peer);
181                        Self::select_accept_helper(accept_with_cid.cid, peer, syn, conns.clone(), accept_with_cid.accept, socket_event_tx.clone());
182                    }
183                    Some(accept) = accepts_rx.recv(), if !incoming_conns.is_empty() => {
184                        let cid = incoming_conns.keys().next().expect("at least one incoming connection");
185                        let cid = cid.clone();
186                        let (peer, packet) = incoming_conns.remove(&cid).expect("to delete incoming connection");
187                        Self::select_accept_helper(cid, peer, packet, conns.clone(), accept, socket_event_tx.clone());
188                    }
189                    Some(event) = socket_event_rx.recv() => {
190                        match event {
191                            SocketEvent::Outgoing((packet, dst)) => {
192                                let encoded = packet.encode();
193                                if let Err(err) = socket.send_to(&encoded, &dst).await {
194                                    tracing::debug!(
195                                        %err,
196                                        cid = %packet.conn_id(),
197                                        packet = ?packet.packet_type(),
198                                        seq = %packet.seq_num(),
199                                        ack = %packet.ack_num(),
200                                        "unable to send uTP packet over socket"
201                                    );
202                                }
203                            }
204                            SocketEvent::Shutdown(cid) => {
205                                tracing::debug!(%cid.send, %cid.recv, "uTP conn shutdown");
206                                conns.write().unwrap().remove(&cid);
207                            }
208                        }
209                    }
210                    Some(Ok((cid, accept_with_cid))) = awaiting.next() => {
211                        // accept_with_cid didn't receive an inbound connection within the timeout period
212                        // log it and return a timeout error
213                        tracing::debug!(%cid.send, %cid.recv, "accept_with_cid timed out");
214                        let _ = accept_with_cid.accept
215                            .stream
216                            .send(Err(io::Error::from(io::ErrorKind::TimedOut)));
217                    }
218                    Some(Ok((cid, _packet))) = incoming_conns.next() => {
219                        // didn't handle inbound connection within the timeout period
220                        // log it and return a timeout error
221                        tracing::debug!(%cid.send, %cid.recv, "inbound connection timed out");
222                    }
223                }
224            }
225        });
226
227        utp
228    }
229
230    /// Internal cid generation
231    fn generate_cid(
232        &self,
233        peer_id: P::Id,
234        is_initiator: bool,
235        event_tx: Option<UnboundedSender<StreamEvent>>,
236    ) -> ConnectionId<P::Id> {
237        let mut cid = ConnectionId {
238            send: 0,
239            recv: 0,
240            peer_id,
241        };
242        let mut generation_attempt_count = 0;
243        loop {
244            if generation_attempt_count > CID_GENERATION_TRY_WARNING_COUNT {
245                tracing::error!("cid() tried to generate a cid {generation_attempt_count} times")
246            }
247            let recv: u16 = rand::random();
248            let send = if is_initiator {
249                recv.wrapping_add(1)
250            } else {
251                recv.wrapping_sub(1)
252            };
253            cid.send = send;
254            cid.recv = recv;
255
256            if !self.conns.read().unwrap().contains_key(&cid) {
257                if let Some(event_tx) = event_tx {
258                    self.conns.write().unwrap().insert(cid.clone(), event_tx);
259                }
260                return cid;
261            }
262            generation_attempt_count += 1;
263        }
264    }
265
266    pub fn cid(&self, peer_id: P::Id, is_initiator: bool) -> ConnectionId<P::Id> {
267        self.generate_cid(peer_id, is_initiator, None)
268    }
269
270    /// Returns the number of connections currently open, both inbound and outbound.
271    pub fn num_connections(&self) -> usize {
272        self.conns.read().unwrap().len()
273    }
274
275    /// WARNING: only accept() or accept_with_cid() can be used in an application.
276    /// they aren't compatible to use interchangeably in a program
277    pub async fn accept(&self, config: ConnectionConfig) -> io::Result<UtpStream<P>> {
278        let (stream_tx, stream_rx) = oneshot::channel();
279        let accept = Accept {
280            stream: stream_tx,
281            config,
282        };
283        self.accepts
284            .send(accept)
285            .map_err(|_| io::Error::from(io::ErrorKind::NotConnected))?;
286        match stream_rx.await {
287            Ok(stream) => Ok(stream?),
288            Err(..) => Err(io::Error::from(io::ErrorKind::TimedOut)),
289        }
290    }
291
292    /// WARNING: only accept() or accept_with_cid() can be used in an application.
293    /// they aren't compatible to use interchangeably in a program
294    pub async fn accept_with_cid(
295        &self,
296        cid: ConnectionId<P::Id>,
297        peer: Peer<P>,
298        config: ConnectionConfig,
299    ) -> io::Result<UtpStream<P>> {
300        let (stream_tx, stream_rx) = oneshot::channel();
301        let accept = AcceptWithCidPeer {
302            cid,
303            peer,
304            accept: Accept {
305                stream: stream_tx,
306                config,
307            },
308        };
309        self.accepts_with_cid
310            .send(accept)
311            .map_err(|_| io::Error::from(io::ErrorKind::NotConnected))?;
312        match stream_rx.await {
313            Ok(stream) => Ok(stream?),
314            Err(..) => Err(io::Error::from(io::ErrorKind::TimedOut)),
315        }
316    }
317
318    pub async fn connect(
319        &self,
320        peer: Peer<P>,
321        config: ConnectionConfig,
322    ) -> io::Result<UtpStream<P>> {
323        let (connected_tx, connected_rx) = oneshot::channel();
324        let (events_tx, events_rx) = mpsc::unbounded_channel();
325        let cid = self.generate_cid(peer.id().clone(), true, Some(events_tx));
326
327        let stream = UtpStream::new(
328            cid,
329            peer,
330            config,
331            None,
332            self.socket_events.clone(),
333            events_rx,
334            connected_tx,
335        );
336
337        match connected_rx.await {
338            Ok(Ok(..)) => Ok(stream),
339            Ok(Err(err)) => Err(err),
340            Err(..) => Err(io::Error::from(io::ErrorKind::TimedOut)),
341        }
342    }
343
344    pub async fn connect_with_cid(
345        &self,
346        cid: ConnectionId<P::Id>,
347        peer: Peer<P>,
348        config: ConnectionConfig,
349    ) -> io::Result<UtpStream<P>> {
350        if self.conns.read().unwrap().contains_key(&cid) {
351            return Err(io::Error::new(
352                io::ErrorKind::Other,
353                "connection ID unavailable".to_string(),
354            ));
355        }
356
357        let (connected_tx, connected_rx) = oneshot::channel();
358        let (events_tx, events_rx) = mpsc::unbounded_channel();
359
360        {
361            self.conns.write().unwrap().insert(cid.clone(), events_tx);
362        }
363
364        let stream = UtpStream::new(
365            cid.clone(),
366            peer,
367            config,
368            None,
369            self.socket_events.clone(),
370            events_rx,
371            connected_tx,
372        );
373
374        match connected_rx.await {
375            Ok(Ok(..)) => Ok(stream),
376            Ok(Err(err)) => {
377                tracing::error!(%err, "failed to open connection with {cid:?}");
378                Err(err)
379            }
380            Err(err) => {
381                tracing::error!(%err, "failed to open connection with {cid:?}");
382                Err(io::Error::from(io::ErrorKind::TimedOut))
383            }
384        }
385    }
386
387    async fn await_connected(
388        stream: UtpStream<P>,
389        callback: oneshot::Sender<io::Result<UtpStream<P>>>,
390        connected: oneshot::Receiver<io::Result<()>>,
391    ) {
392        match connected.await {
393            Ok(Ok(..)) => {
394                let _ = callback.send(Ok(stream));
395            }
396            Ok(Err(err)) => {
397                let _ = callback.send(Err(err));
398            }
399            Err(..) => {
400                let _ = callback.send(Err(io::Error::from(io::ErrorKind::ConnectionAborted)));
401            }
402        }
403    }
404
405    fn select_accept_helper(
406        cid: ConnectionId<P::Id>,
407        peer: Peer<P>,
408        syn: Packet,
409        conns: Arc<RwLock<HashMap<ConnectionId<P::Id>, ConnChannel>>>,
410        accept: Accept<P>,
411        socket_event_tx: UnboundedSender<SocketEvent<P>>,
412    ) {
413        if conns.read().unwrap().contains_key(&cid) {
414            let _ = accept.stream.send(Err(io::Error::new(
415                io::ErrorKind::Other,
416                "connection ID unavailable".to_string(),
417            )));
418            return;
419        }
420
421        let (connected_tx, connected_rx) = oneshot::channel();
422        let (events_tx, events_rx) = mpsc::unbounded_channel();
423
424        {
425            conns.write().unwrap().insert(cid.clone(), events_tx);
426        }
427
428        let stream = UtpStream::new(
429            cid,
430            peer,
431            accept.config,
432            Some(syn),
433            socket_event_tx,
434            events_rx,
435            connected_tx,
436        );
437
438        tokio::spawn(
439            async move { Self::await_connected(stream, accept.stream, connected_rx).await },
440        );
441    }
442}
443
/// Selects how `cid_from_packet` maps a packet's `conn_id` onto a `(send, recv)` pair.
#[derive(Clone, Copy, Debug)]
enum IdType {
    /// For a SYN: `(conn_id, conn_id + 1)`. For any other packet type: `(conn_id - 1, conn_id)`.
    RecvId,
    /// `(conn_id + 1, conn_id)`: the `send = recv + 1` shape of a connection we initiated.
    SendIdWeInitiated,
    /// `(conn_id, conn_id - 1)`: the `send = recv - 1` shape of a peer-initiated connection.
    SendIdPeerInitiated,
}
450
451fn cid_from_packet<P: ConnectionPeer>(
452    packet: &Packet,
453    peer_id: &P::Id,
454    id_type: IdType,
455) -> ConnectionId<P::Id> {
456    let peer_id = peer_id.clone();
457    match id_type {
458        IdType::RecvId => {
459            let (send, recv) = match packet.packet_type() {
460                PacketType::Syn => (packet.conn_id(), packet.conn_id().wrapping_add(1)),
461                PacketType::State | PacketType::Data | PacketType::Fin | PacketType::Reset => {
462                    (packet.conn_id().wrapping_sub(1), packet.conn_id())
463                }
464            };
465            ConnectionId {
466                send,
467                recv,
468                peer_id,
469            }
470        }
471        IdType::SendIdWeInitiated => {
472            let (send, recv) = (packet.conn_id().wrapping_add(1), packet.conn_id());
473            ConnectionId {
474                send,
475                recv,
476                peer_id,
477            }
478        }
479        IdType::SendIdPeerInitiated => {
480            let (send, recv) = (packet.conn_id(), packet.conn_id().wrapping_sub(1));
481            ConnectionId {
482                send,
483                recv,
484                peer_id,
485            }
486        }
487    }
488}
489
490impl<P: ConnectionPeer> Drop for UtpSocket<P> {
491    fn drop(&mut self) {
492        for conn in self.conns.read().unwrap().values() {
493            let _ = conn.send(StreamEvent::Shutdown);
494        }
495    }
496}