1use std::{
2 cell::RefCell,
3 collections::{hash_map::Entry, HashMap},
4 net::SocketAddr,
5 rc::Rc,
6};
7
8use anyhow::Context;
9use bytes::Bytes;
10use tokio_uring::net::UdpSocket;
11
12use crate::{
13 utp_packet::{Packet, PacketHeader, PacketType, HEADER_SIZE},
14 utp_stream::{UtpStream, WeakUtpStream},
15};
16
17pub struct UtpSocket {
22 socket: Rc<UdpSocket>,
23 shutdown_signal: Option<tokio::sync::oneshot::Sender<()>>,
24 accept_chan: Rc<RefCell<Option<tokio::sync::oneshot::Sender<UtpStream>>>>,
25 streams: Rc<RefCell<HashMap<StreamKey, WeakUtpStream>>>,
26}
27
/// Lookup key for a registered stream.
///
/// A connection id alone is not enough: ids are chosen per peer, so the same id
/// may legitimately appear for different remote addresses.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
struct StreamKey {
    /// Connection id carried in the headers of packets belonging to this stream.
    conn_id: u16,
    /// Address of the remote peer.
    addr: SocketAddr,
}
33
34impl UtpSocket {
37 pub async fn bind(bind_addr: SocketAddr) -> anyhow::Result<Self> {
38 let socket = Rc::new(UdpSocket::bind(bind_addr).await?);
39 let net_loop_socket = socket.clone();
40
41 let (shutdown_signal, mut shutdown_receiver) = tokio::sync::oneshot::channel();
42 let utp_socket = UtpSocket {
44 socket,
45 shutdown_signal: Some(shutdown_signal),
46 accept_chan: Default::default(),
47 streams: Default::default(),
48 };
49
50 let streams_clone = utp_socket.streams.clone();
51 let accept_chan = utp_socket.accept_chan.clone();
52 tokio_uring::spawn(async move {
54 let mut recv_buf = vec![0; 1024 * 1024];
56 loop {
57 tokio::select! {
60 buf = process_incomming(&net_loop_socket, &streams_clone, &accept_chan, std::mem::take(&mut recv_buf)) => {
61 recv_buf = buf;
62 }
63 _ = &mut shutdown_receiver => {
64 log::info!("Shutting down network loop");
65 break;
67 }
68 }
69 }
70 });
71
72 Ok(utp_socket)
73 }
74
75 pub async fn connect(&self, addr: SocketAddr) -> anyhow::Result<UtpStream> {
76 let mut stream_key = StreamKey {
77 conn_id: rand::random(),
78 addr,
79 };
80
81 while self.streams.borrow().contains_key(&stream_key) {
82 log::debug!("Stream with same conn_id and addr already exists, regenerating conn_id");
83 stream_key = StreamKey {
84 conn_id: rand::random::<u16>(),
85 addr,
86 }
87 }
88
89 let stream = UtpStream::new(stream_key.conn_id, addr, Rc::downgrade(&self.socket));
90 self.streams
91 .borrow_mut()
92 .insert(stream_key, stream.clone().into());
93
94 stream.connect().await?;
95
96 Ok(stream)
97 }
98
99 pub async fn accept(&self) -> anyhow::Result<UtpStream> {
100 let (tx, rc) = tokio::sync::oneshot::channel();
101 {
102 let mut chan = self.accept_chan.borrow_mut();
103 *chan = Some(tx);
104 }
105 rc.await.context("Net loop exited")
106 }
107}
108
109impl Drop for UtpSocket {
110 fn drop(&mut self) {
111 println!("dropping");
112 self.shutdown_signal.take().unwrap().send(()).unwrap();
113 }
114}
115
116async fn process_incomming(
127 socket: &Rc<UdpSocket>,
128 connections: &Rc<RefCell<HashMap<StreamKey, WeakUtpStream>>>,
129 accept_chan: &Rc<RefCell<Option<tokio::sync::oneshot::Sender<UtpStream>>>>,
130 recv_buf: Vec<u8>,
131) -> Vec<u8> {
132 let (result, buf) = socket.recv_from(recv_buf).await;
133 match result {
134 Ok((recv, addr)) => {
135 log::info!("Received {recv} from {addr}");
136 match PacketHeader::try_from(&buf[..recv]) {
137 Ok(packet_header) => {
138 let key = StreamKey {
139 conn_id: packet_header.conn_id,
140 addr,
141 };
142
143 let packet = Packet {
144 header: packet_header,
145 data: Bytes::copy_from_slice(&buf[HEADER_SIZE as usize..recv]),
146 };
147
148 let maybe_stream = { connections.borrow_mut().remove(&key) };
149 if let Some(weak_stream) = maybe_stream {
150 if let Some(stream) = weak_stream.try_upgrade() {
151 match stream.process_incoming(packet).await {
152 Ok(()) => {
153 connections.borrow_mut().insert(key, stream.into());
154 }
155 Err(err) => {
156 log::error!("Error: Failed processing incoming packet: {err}");
157 }
158 }
159 }
160 } else if packet_header.packet_type == PacketType::Syn {
161 let maybe_chan = { accept_chan.borrow_mut().take() };
162 if let Some(chan) = maybe_chan {
163 let stream = UtpStream::new_incoming(
164 packet_header.seq_nr,
165 packet_header.conn_id,
166 addr,
167 Rc::downgrade(socket),
168 );
169 let stream_key = StreamKey {
170 conn_id: packet_header.conn_id + 1,
173 addr,
174 };
175 {
178 let mut connections = connections.borrow_mut();
179 let entry = connections.entry(stream_key);
180 match entry {
181 Entry::Occupied(_) => {
182 log::warn!("Connection with id: {} already exists. Dropping connection",
183 packet_header.conn_id + 1
184 );
185 return buf;
186 }
187 Entry::Vacant(entry) => {
188 log::info!("New incoming connection!");
189 entry.insert(stream.clone().into());
190 }
191 }
192 }
193
194 if let Err(err) = stream.process_incoming(packet).await {
196 log::error!("Error accepting connection: {err}");
197 *accept_chan.borrow_mut() = Some(chan);
201 connections.borrow_mut().remove(&stream_key);
202 } else {
203 chan.send(stream).unwrap();
204 }
205 }
206 } else {
207 log::warn!("Connection not established prior");
208 }
209 }
210 Err(err) => log::error!("Error parsing packet: {err}"),
211 }
212 }
213 Err(err) => log::error!("Failed to receive on utp socket: {err}"),
214 }
215 buf
216}
217
218