1use std::collections::HashMap;
2use std::io;
3use std::net::SocketAddr;
4use std::sync::{Arc, RwLock};
5use std::time::Duration;
6
7use delay_map::HashMapDelay;
8use futures::StreamExt;
9use rand::{thread_rng, Rng};
10use tokio::net::UdpSocket;
11use tokio::sync::mpsc::UnboundedSender;
12use tokio::sync::{mpsc, oneshot};
13
14use crate::cid::ConnectionId;
15use crate::conn::ConnectionConfig;
16use crate::event::{SocketEvent, StreamEvent};
17use crate::packet::{Packet, PacketBuilder, PacketType};
18use crate::peer::{ConnectionPeer, Peer};
19use crate::stream::UtpStream;
20use crate::udp::AsyncUdpSocket;
21
22type ConnChannel = UnboundedSender<StreamEvent>;
23
24struct Accept<P: ConnectionPeer> {
25 stream: oneshot::Sender<io::Result<UtpStream<P>>>,
26 config: ConnectionConfig,
27}
28
29struct AcceptWithCidPeer<P: ConnectionPeer> {
30 cid: ConnectionId<P::Id>,
31 peer: Peer<P>,
32 accept: Accept<P>,
33}
34
/// Size in bytes of the buffer into which incoming datagrams are received.
const MAX_UDP_PAYLOAD_SIZE: usize = u16::MAX as usize;

/// Number of failed connection-ID generation attempts after which every further attempt is
/// logged as an error.
const CID_GENERATION_TRY_WARNING_COUNT: usize = 10;

/// How long an unmatched `accept_with_cid` request, or an unclaimed inbound SYN, is kept
/// before it is expired.
const AWAITING_CONNECTION_TIMEOUT: Duration = Duration::from_secs(20);
45
46pub struct UtpSocket<P: ConnectionPeer> {
47 conns: Arc<RwLock<HashMap<ConnectionId<P::Id>, ConnChannel>>>,
48 accepts: UnboundedSender<Accept<P>>,
49 accepts_with_cid: UnboundedSender<AcceptWithCidPeer<P>>,
50 socket_events: UnboundedSender<SocketEvent<P>>,
51}
52
53impl UtpSocket<SocketAddr> {
54 pub async fn bind(addr: SocketAddr) -> io::Result<Self> {
55 let socket = UdpSocket::bind(addr).await?;
56 let socket = Self::with_socket(socket);
57 Ok(socket)
58 }
59}
60
61impl<P> UtpSocket<P>
62where
63 P: ConnectionPeer<Id: Unpin> + Unpin + 'static,
64{
65 pub fn with_socket<S>(mut socket: S) -> Self
66 where
67 S: AsyncUdpSocket<P> + 'static,
68 {
69 let conns = HashMap::new();
70 let conns = Arc::new(RwLock::new(conns));
71
72 let mut awaiting: HashMapDelay<ConnectionId<P::Id>, AcceptWithCidPeer<P>> =
73 HashMapDelay::new(AWAITING_CONNECTION_TIMEOUT);
74
75 let mut incoming_conns: HashMapDelay<ConnectionId<P::Id>, (Peer<P>, Packet)> =
76 HashMapDelay::new(AWAITING_CONNECTION_TIMEOUT);
77
78 let (socket_event_tx, mut socket_event_rx) = mpsc::unbounded_channel();
79 let (accepts_tx, mut accepts_rx) = mpsc::unbounded_channel();
80 let (accepts_with_cid_tx, mut accepts_with_cid_rx) = mpsc::unbounded_channel();
81
82 let utp = Self {
83 conns: Arc::clone(&conns),
84 accepts: accepts_tx,
85 accepts_with_cid: accepts_with_cid_tx,
86 socket_events: socket_event_tx.clone(),
87 };
88
89 tokio::spawn(async move {
90 let mut buf = [0; MAX_UDP_PAYLOAD_SIZE];
91 loop {
92 tokio::select! {
93 biased;
94 Ok((n, mut peer)) = socket.recv_from(&mut buf) => {
95 let peer_id = peer.id();
96 let packet = match Packet::decode(&buf[..n]) {
97 Ok(pkt) => pkt,
98 Err(..) => {
99 tracing::warn!(?peer, "unable to decode uTP packet");
100 continue;
101 }
102 };
103
104 let peer_init_cid = cid_from_packet::<P>(&packet, peer_id, IdType::SendIdPeerInitiated);
105 let we_init_cid = cid_from_packet::<P>(&packet, peer_id, IdType::SendIdWeInitiated);
106 let acc_cid = cid_from_packet::<P>(&packet, peer_id, IdType::RecvId);
107 let mut conns = conns.write().unwrap();
108 let conn = conns
109 .get(&acc_cid)
110 .or_else(|| conns.get(&we_init_cid))
111 .or_else(|| conns.get(&peer_init_cid));
112 match conn {
113 Some(conn) => {
114 let _ = conn.send(StreamEvent::Incoming(packet));
115 }
116 None => {
117 if std::matches!(packet.packet_type(), PacketType::Syn) {
118 let cid = acc_cid;
119
120 if let Some(accept_with_cid) = awaiting.remove(&cid) {
124 peer.consolidate(accept_with_cid.peer);
125
126 let (connected_tx, connected_rx) = oneshot::channel();
127 let (events_tx, events_rx) = mpsc::unbounded_channel();
128
129 conns.insert(cid.clone(), events_tx);
130
131 let stream = UtpStream::new(
132 cid,
133 peer,
134 accept_with_cid.accept.config,
135 Some(packet),
136 socket_event_tx.clone(),
137 events_rx,
138 connected_tx
139 );
140
141 tokio::spawn(async move {
142 Self::await_connected(stream, accept_with_cid.accept.stream, connected_rx).await
143 });
144 } else {
145 incoming_conns.insert(cid, (peer, packet));
146 }
147 } else {
148 tracing::debug!(
149 cid = %packet.conn_id(),
150 packet = ?packet.packet_type(),
151 seq = %packet.seq_num(),
152 ack = %packet.ack_num(),
153 peer_init_cid = ?peer_init_cid,
154 we_init_cid = ?we_init_cid,
155 acc_cid = ?acc_cid,
156 "received uTP packet for non-existing conn"
157 );
158 if packet.packet_type() != PacketType::Reset {
160 let random_seq_num = thread_rng().gen_range(0..=65535);
162 let reset_packet =
163 PacketBuilder::new(PacketType::Reset, packet.conn_id(), crate::time::now_micros(), 100_000, random_seq_num)
164 .build();
165 let event = SocketEvent::Outgoing((reset_packet, peer));
166 if socket_event_tx.send(event).is_err() {
167 tracing::warn!("Cannot transmit reset packet: socket closed channel");
168 return;
169 }
170 }
171 }
172 },
173 }
174 }
175 Some(accept_with_cid) = accepts_with_cid_rx.recv() => {
176 let Some((mut peer, syn)) = incoming_conns.remove(&accept_with_cid.cid) else {
177 awaiting.insert(accept_with_cid.cid.clone(), accept_with_cid);
178 continue;
179 };
180 peer.consolidate(accept_with_cid.peer);
181 Self::select_accept_helper(accept_with_cid.cid, peer, syn, conns.clone(), accept_with_cid.accept, socket_event_tx.clone());
182 }
183 Some(accept) = accepts_rx.recv(), if !incoming_conns.is_empty() => {
184 let cid = incoming_conns.keys().next().expect("at least one incoming connection");
185 let cid = cid.clone();
186 let (peer, packet) = incoming_conns.remove(&cid).expect("to delete incoming connection");
187 Self::select_accept_helper(cid, peer, packet, conns.clone(), accept, socket_event_tx.clone());
188 }
189 Some(event) = socket_event_rx.recv() => {
190 match event {
191 SocketEvent::Outgoing((packet, dst)) => {
192 let encoded = packet.encode();
193 if let Err(err) = socket.send_to(&encoded, &dst).await {
194 tracing::debug!(
195 %err,
196 cid = %packet.conn_id(),
197 packet = ?packet.packet_type(),
198 seq = %packet.seq_num(),
199 ack = %packet.ack_num(),
200 "unable to send uTP packet over socket"
201 );
202 }
203 }
204 SocketEvent::Shutdown(cid) => {
205 tracing::debug!(%cid.send, %cid.recv, "uTP conn shutdown");
206 conns.write().unwrap().remove(&cid);
207 }
208 }
209 }
210 Some(Ok((cid, accept_with_cid))) = awaiting.next() => {
211 tracing::debug!(%cid.send, %cid.recv, "accept_with_cid timed out");
214 let _ = accept_with_cid.accept
215 .stream
216 .send(Err(io::Error::from(io::ErrorKind::TimedOut)));
217 }
218 Some(Ok((cid, _packet))) = incoming_conns.next() => {
219 tracing::debug!(%cid.send, %cid.recv, "inbound connection timed out");
222 }
223 }
224 }
225 });
226
227 utp
228 }
229
230 fn generate_cid(
232 &self,
233 peer_id: P::Id,
234 is_initiator: bool,
235 event_tx: Option<UnboundedSender<StreamEvent>>,
236 ) -> ConnectionId<P::Id> {
237 let mut cid = ConnectionId {
238 send: 0,
239 recv: 0,
240 peer_id,
241 };
242 let mut generation_attempt_count = 0;
243 loop {
244 if generation_attempt_count > CID_GENERATION_TRY_WARNING_COUNT {
245 tracing::error!("cid() tried to generate a cid {generation_attempt_count} times")
246 }
247 let recv: u16 = rand::random();
248 let send = if is_initiator {
249 recv.wrapping_add(1)
250 } else {
251 recv.wrapping_sub(1)
252 };
253 cid.send = send;
254 cid.recv = recv;
255
256 if !self.conns.read().unwrap().contains_key(&cid) {
257 if let Some(event_tx) = event_tx {
258 self.conns.write().unwrap().insert(cid.clone(), event_tx);
259 }
260 return cid;
261 }
262 generation_attempt_count += 1;
263 }
264 }
265
266 pub fn cid(&self, peer_id: P::Id, is_initiator: bool) -> ConnectionId<P::Id> {
267 self.generate_cid(peer_id, is_initiator, None)
268 }
269
270 pub fn num_connections(&self) -> usize {
272 self.conns.read().unwrap().len()
273 }
274
275 pub async fn accept(&self, config: ConnectionConfig) -> io::Result<UtpStream<P>> {
278 let (stream_tx, stream_rx) = oneshot::channel();
279 let accept = Accept {
280 stream: stream_tx,
281 config,
282 };
283 self.accepts
284 .send(accept)
285 .map_err(|_| io::Error::from(io::ErrorKind::NotConnected))?;
286 match stream_rx.await {
287 Ok(stream) => Ok(stream?),
288 Err(..) => Err(io::Error::from(io::ErrorKind::TimedOut)),
289 }
290 }
291
292 pub async fn accept_with_cid(
295 &self,
296 cid: ConnectionId<P::Id>,
297 peer: Peer<P>,
298 config: ConnectionConfig,
299 ) -> io::Result<UtpStream<P>> {
300 let (stream_tx, stream_rx) = oneshot::channel();
301 let accept = AcceptWithCidPeer {
302 cid,
303 peer,
304 accept: Accept {
305 stream: stream_tx,
306 config,
307 },
308 };
309 self.accepts_with_cid
310 .send(accept)
311 .map_err(|_| io::Error::from(io::ErrorKind::NotConnected))?;
312 match stream_rx.await {
313 Ok(stream) => Ok(stream?),
314 Err(..) => Err(io::Error::from(io::ErrorKind::TimedOut)),
315 }
316 }
317
318 pub async fn connect(
319 &self,
320 peer: Peer<P>,
321 config: ConnectionConfig,
322 ) -> io::Result<UtpStream<P>> {
323 let (connected_tx, connected_rx) = oneshot::channel();
324 let (events_tx, events_rx) = mpsc::unbounded_channel();
325 let cid = self.generate_cid(peer.id().clone(), true, Some(events_tx));
326
327 let stream = UtpStream::new(
328 cid,
329 peer,
330 config,
331 None,
332 self.socket_events.clone(),
333 events_rx,
334 connected_tx,
335 );
336
337 match connected_rx.await {
338 Ok(Ok(..)) => Ok(stream),
339 Ok(Err(err)) => Err(err),
340 Err(..) => Err(io::Error::from(io::ErrorKind::TimedOut)),
341 }
342 }
343
344 pub async fn connect_with_cid(
345 &self,
346 cid: ConnectionId<P::Id>,
347 peer: Peer<P>,
348 config: ConnectionConfig,
349 ) -> io::Result<UtpStream<P>> {
350 if self.conns.read().unwrap().contains_key(&cid) {
351 return Err(io::Error::new(
352 io::ErrorKind::Other,
353 "connection ID unavailable".to_string(),
354 ));
355 }
356
357 let (connected_tx, connected_rx) = oneshot::channel();
358 let (events_tx, events_rx) = mpsc::unbounded_channel();
359
360 {
361 self.conns.write().unwrap().insert(cid.clone(), events_tx);
362 }
363
364 let stream = UtpStream::new(
365 cid.clone(),
366 peer,
367 config,
368 None,
369 self.socket_events.clone(),
370 events_rx,
371 connected_tx,
372 );
373
374 match connected_rx.await {
375 Ok(Ok(..)) => Ok(stream),
376 Ok(Err(err)) => {
377 tracing::error!(%err, "failed to open connection with {cid:?}");
378 Err(err)
379 }
380 Err(err) => {
381 tracing::error!(%err, "failed to open connection with {cid:?}");
382 Err(io::Error::from(io::ErrorKind::TimedOut))
383 }
384 }
385 }
386
387 async fn await_connected(
388 stream: UtpStream<P>,
389 callback: oneshot::Sender<io::Result<UtpStream<P>>>,
390 connected: oneshot::Receiver<io::Result<()>>,
391 ) {
392 match connected.await {
393 Ok(Ok(..)) => {
394 let _ = callback.send(Ok(stream));
395 }
396 Ok(Err(err)) => {
397 let _ = callback.send(Err(err));
398 }
399 Err(..) => {
400 let _ = callback.send(Err(io::Error::from(io::ErrorKind::ConnectionAborted)));
401 }
402 }
403 }
404
405 fn select_accept_helper(
406 cid: ConnectionId<P::Id>,
407 peer: Peer<P>,
408 syn: Packet,
409 conns: Arc<RwLock<HashMap<ConnectionId<P::Id>, ConnChannel>>>,
410 accept: Accept<P>,
411 socket_event_tx: UnboundedSender<SocketEvent<P>>,
412 ) {
413 if conns.read().unwrap().contains_key(&cid) {
414 let _ = accept.stream.send(Err(io::Error::new(
415 io::ErrorKind::Other,
416 "connection ID unavailable".to_string(),
417 )));
418 return;
419 }
420
421 let (connected_tx, connected_rx) = oneshot::channel();
422 let (events_tx, events_rx) = mpsc::unbounded_channel();
423
424 {
425 conns.write().unwrap().insert(cid.clone(), events_tx);
426 }
427
428 let stream = UtpStream::new(
429 cid,
430 peer,
431 accept.config,
432 Some(syn),
433 socket_event_tx,
434 events_rx,
435 connected_tx,
436 );
437
438 tokio::spawn(
439 async move { Self::await_connected(stream, accept.stream, connected_rx).await },
440 );
441 }
442}
443
/// Selects which connection-ID interpretation `cid_from_packet` derives from a packet.
#[derive(Clone, Copy, Debug)]
enum IdType {
    /// Treat the packet as addressed to our receive id (a SYN is handled specially).
    RecvId,
    /// Treat the packet's id as the receive id of a connection whose send id is one higher.
    SendIdWeInitiated,
    /// Treat the packet's id as the send id of a connection whose receive id is one lower.
    SendIdPeerInitiated,
}
450
451fn cid_from_packet<P: ConnectionPeer>(
452 packet: &Packet,
453 peer_id: &P::Id,
454 id_type: IdType,
455) -> ConnectionId<P::Id> {
456 let peer_id = peer_id.clone();
457 match id_type {
458 IdType::RecvId => {
459 let (send, recv) = match packet.packet_type() {
460 PacketType::Syn => (packet.conn_id(), packet.conn_id().wrapping_add(1)),
461 PacketType::State | PacketType::Data | PacketType::Fin | PacketType::Reset => {
462 (packet.conn_id().wrapping_sub(1), packet.conn_id())
463 }
464 };
465 ConnectionId {
466 send,
467 recv,
468 peer_id,
469 }
470 }
471 IdType::SendIdWeInitiated => {
472 let (send, recv) = (packet.conn_id().wrapping_add(1), packet.conn_id());
473 ConnectionId {
474 send,
475 recv,
476 peer_id,
477 }
478 }
479 IdType::SendIdPeerInitiated => {
480 let (send, recv) = (packet.conn_id(), packet.conn_id().wrapping_sub(1));
481 ConnectionId {
482 send,
483 recv,
484 peer_id,
485 }
486 }
487 }
488}
489
490impl<P: ConnectionPeer> Drop for UtpSocket<P> {
491 fn drop(&mut self) {
492 for conn in self.conns.read().unwrap().values() {
493 let _ = conn.send(StreamEvent::Shutdown);
494 }
495 }
496}