1#![allow(dead_code)]
2
3use std::{
4 cell::{Ref, RefCell, RefMut},
5 net::SocketAddr,
6 rc::{Rc, Weak},
7 time::Duration,
8};
9
10use bytes::Bytes;
11use tokio::sync::{oneshot::Receiver, Notify};
12use tokio_uring::net::UdpSocket;
13
14use crate::{
15 reorder_buffer::ReorderBuffer,
16 utp_packet::{get_microseconds, Packet, PacketHeader, PacketType, HEADER_SIZE},
17};
18
/// Handshake/lifecycle state of a uTP stream.
///
/// Equality (see the `PartialEq` impl below) compares only the variant and ignores
/// any payload, so `SynSent { .. } == SynSent { .. }` regardless of the notifier.
#[derive(Debug)]
pub(crate) enum ConnectionState {
    /// Initial state of an outgoing stream (`UtpStream::new`). It is also used as a
    /// temporary placeholder while `handle_inorder_packet` owns the real state.
    Idle,
    /// Initial state of a stream created by `UtpStream::new_incoming`.
    SynReceived,
    /// Set by `syn_header`: our SYN is out and we are waiting for the peer's STATE reply.
    SynSent {
        /// Fired when the SYN is acknowledged, which resolves the pending `connect`.
        connect_notifier: tokio::sync::oneshot::Sender<()>,
    },
    /// Handshake complete; data packets are accepted.
    Connected,
    /// NOTE(review): not assigned anywhere in this file; presumably meant for after a FIN is sent.
    FinSent,
}
29
30impl PartialEq for ConnectionState {
31 fn eq(&self, other: &Self) -> bool {
32 core::mem::discriminant(self) == core::mem::discriminant(other)
33 }
34}
35
36impl Eq for ConnectionState {}
37
/// Mutable per-connection state shared (via `Rc<RefCell<_>>`) by all `UtpStream` handles.
#[derive(Debug)]
pub(crate) struct StreamState {
    /// Current handshake/lifecycle state.
    pub(crate) connection_state: ConnectionState,
    /// Sequence number of the last packet we sent. It starts random, and `data()`
    /// advances it before building each data header.
    pub(crate) seq_nr: u16,
    /// Sequence number of the last in-order packet accepted from the peer; echoed in
    /// outgoing headers.
    pub(crate) ack_nr: u16,
    /// Connection id expected on incoming non-SYN packets (also carried by our SYN).
    pub(crate) conn_id_recv: u16,
    /// Connection id placed in outgoing ACK and DATA headers.
    pub(crate) conn_id_send: u16,
    /// Payload bytes counted as in flight; increased when sending, decreased on ack.
    pub(crate) cur_window: u32,
    /// Local upper bound on the send window (initialised to `MTU`).
    pub(crate) max_window: u32,
    /// Receive window last advertised by the peer (`wnd_size` of incoming packets).
    pub(crate) their_advertised_window: u32,
    /// Timestamp difference echoed back to the peer in outgoing headers.
    pub(crate) reply_micro: u32,
    /// Sequence number of the FIN received from the peer, if any.
    pub(crate) eof_pkt: Option<u16>,
    /// Out-of-order packets waiting for the gap before them to be filled.
    pub(crate) incoming_buffer: ReorderBuffer,
    /// Packets sent but not yet acknowledged; retransmitted on every flush.
    pub(crate) outgoing_buffer: ReorderBuffer,
    /// In-order payload bytes waiting to be returned by `read`.
    pub(crate) receive_buf: Box<[u8]>,
    /// Number of valid bytes at the front of `receive_buf`.
    receive_buf_cursor: usize,

    /// Stops the background flush task; taken (once) when the last handle is dropped.
    shutdown_signal: Option<tokio::sync::oneshot::Sender<()>>,
}
77
78impl StreamState {
79 fn syn_header(&mut self) -> (PacketHeader, Receiver<()>) {
80 let (tx, rc) = tokio::sync::oneshot::channel();
81 self.connection_state = ConnectionState::SynSent {
83 connect_notifier: tx,
84 };
85
86 let header = PacketHeader {
87 seq_nr: self.seq_nr,
88 ack_nr: 0,
89 conn_id: self.conn_id_recv,
90 packet_type: PacketType::Syn,
91 timestamp_microseconds: get_microseconds() as u32,
92 timestamp_difference_microseconds: self.reply_micro,
93 wnd_size: 0,
95 extension: 0,
96 };
97 (header, rc)
98 }
99
100 fn ack(&self) -> PacketHeader {
101 let timestamp_microseconds = get_microseconds();
103 PacketHeader {
104 seq_nr: self.seq_nr,
105 ack_nr: self.ack_nr,
106 conn_id: self.conn_id_send,
107 packet_type: PacketType::State,
108 timestamp_microseconds: timestamp_microseconds as u32,
109 timestamp_difference_microseconds: self.reply_micro,
110 wnd_size: self.our_advertised_window(),
111 extension: 0,
112 }
113 }
114
115 fn data(&mut self) -> PacketHeader {
116 let timestamp_microseconds = get_microseconds();
118 self.seq_nr += 1;
119 PacketHeader {
120 seq_nr: self.seq_nr,
121 ack_nr: self.ack_nr,
122 conn_id: self.conn_id_send,
123 packet_type: PacketType::Data,
124 timestamp_microseconds: timestamp_microseconds as u32,
125 timestamp_difference_microseconds: self.reply_micro,
126 wnd_size: self.our_advertised_window(),
127 extension: 0,
128 }
129 }
130
131 fn try_consume(&mut self, data: &[u8]) -> bool {
132 if data.len() <= (self.receive_buf.len() - self.receive_buf_cursor) {
134 let cursor = self.receive_buf_cursor;
135 self.receive_buf[cursor..cursor + data.len()].copy_from_slice(data);
138 self.receive_buf_cursor += data.len();
139 true
140 } else {
141 log::warn!("Receive buf full, packet dropped");
142 false
143 }
144 }
145
146 #[inline(always)]
147 pub(crate) fn our_advertised_window(&self) -> u32 {
148 let wnd_size = (self.receive_buf.len() - self.receive_buf_cursor) as i32
149 - self.incoming_buffer.size() as i32;
150 std::cmp::max(wnd_size, 0) as u32
151 }
152
153 #[inline(always)]
154 fn stream_window_size(&self) -> u32 {
155 std::cmp::min(self.max_window, self.their_advertised_window)
156 }
157}
158
/// Cheaply clonable handle to a single uTP connection.
///
/// All clones share the same `StreamState`; one clone is owned by the background
/// flush task spawned by the constructors.
#[derive(Clone)]
pub struct UtpStream {
    /// Shared, single-threaded mutable connection state.
    inner: Rc<RefCell<StreamState>>,
    /// Remote peer address every packet is sent to.
    addr: SocketAddr,
    /// Socket used for sending; held weakly, so sends fail once the socket is dropped.
    weak_socket: Weak<UdpSocket>,
    /// Woken via `notify_waiters` when new in-order data reaches the receive buffer.
    data_available: Rc<Notify>,
}
171
/// Non-owning counterpart of `UtpStream`: it does not keep the stream state alive.
///
/// Call `try_upgrade` to obtain a usable `UtpStream` while the state still exists.
pub(crate) struct WeakUtpStream {
    /// Weak reference to the shared stream state.
    inner: Weak<RefCell<StreamState>>,
    /// Remote peer address, copied from the originating stream.
    addr: SocketAddr,
    /// Socket used for sending, also held weakly.
    weak_socket: Weak<UdpSocket>,
    /// Strong handle to the notifier shared with the originating stream.
    data_available: Rc<Notify>,
}
185
186impl WeakUtpStream {
187 pub(crate) fn try_upgrade(&self) -> Option<UtpStream> {
188 self.inner.upgrade().map(|inner| UtpStream {
189 inner,
190 addr: self.addr,
191 weak_socket: self.weak_socket.clone(),
192 data_available: self.data_available.clone(),
193 })
194 }
195}
196
197impl From<UtpStream> for WeakUtpStream {
198 fn from(stream: UtpStream) -> Self {
199 WeakUtpStream {
200 inner: Rc::downgrade(&stream.inner),
201 addr: stream.addr,
202 weak_socket: stream.weak_socket.clone(),
204 data_available: stream.data_available.clone(),
205 }
206 }
207}
208
209impl std::fmt::Debug for UtpStream {
210 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
211 f.debug_struct("UtpStream")
212 .field("state", &self.inner)
213 .field("addr", &self.addr)
214 .field("data_available", &self.data_available)
215 .finish()
216 }
217}
218
/// Packet size limit in bytes. It is the initial `max_window` and
/// `their_advertised_window` of every stream, and the size bound checked by
/// `UtpStream::write`.
const MTU: u32 = 1250;
220
221impl UtpStream {
222 pub(crate) fn new(conn_id: u16, addr: SocketAddr, weak_socket: Weak<UdpSocket>) -> Self {
223 let (shutdown_signal, mut shutdown_receiver) = tokio::sync::oneshot::channel();
224 let stream = UtpStream {
225 inner: Rc::new(RefCell::new(StreamState {
226 connection_state: ConnectionState::Idle,
227 seq_nr: rand::random::<u16>(),
229 conn_id_recv: conn_id,
230 cur_window: 0,
231 max_window: MTU,
232 ack_nr: 0,
233 conn_id_send: conn_id + 1,
234 reply_micro: 0,
235 eof_pkt: None,
236 their_advertised_window: MTU,
238 incoming_buffer: ReorderBuffer::new(256),
239 outgoing_buffer: ReorderBuffer::new(256),
240 receive_buf: vec![0; 1024 * 1024].into_boxed_slice(),
241 receive_buf_cursor: 0,
242 shutdown_signal: Some(shutdown_signal),
243 })),
244 weak_socket,
245 data_available: Rc::new(Notify::new()),
246 addr,
247 };
248
249 let stream_clone = stream.clone();
250 tokio_uring::spawn(async move {
252 let mut tick_interval = tokio::time::interval(Duration::from_millis(250));
253 tick_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
254 loop {
255 tokio::select! {
256 _ = tick_interval.tick() => {
257 if let Err(err) = stream_clone.flush_outbuf().await {
258 log::error!("Error: {err}, shutting down stream send loop");
259 break;
260 }
261 },
262 _ = &mut shutdown_receiver => {
263 log::info!("Shutting down stream send loop");
264 break;
265 },
266 }
267 }
268 });
269 stream
270 }
271
272 pub(crate) fn new_incoming(
273 seq_nr: u16,
274 conn_id: u16,
275 addr: SocketAddr,
276 weak_socket: Weak<UdpSocket>,
277 ) -> Self {
278 let (shutdown_signal, mut shutdown_receiver) = tokio::sync::oneshot::channel();
279 let stream = UtpStream {
280 inner: Rc::new(RefCell::new(StreamState {
281 connection_state: ConnectionState::SynReceived,
282 seq_nr: rand::random::<u16>(),
284 conn_id_recv: conn_id + 1,
285 cur_window: 0,
286 max_window: MTU,
287 ack_nr: seq_nr - 1,
289 conn_id_send: conn_id,
290 reply_micro: 0,
291 eof_pkt: None,
292 their_advertised_window: MTU,
294 incoming_buffer: ReorderBuffer::new(256),
295 outgoing_buffer: ReorderBuffer::new(256),
296 receive_buf: vec![0; 1024 * 1024].into_boxed_slice(),
297 receive_buf_cursor: 0,
298 shutdown_signal: Some(shutdown_signal),
299 })),
300 weak_socket,
301 data_available: Rc::new(Notify::new()),
302 addr,
303 };
304
305 let stream_clone = stream.clone();
306 tokio_uring::spawn(async move {
308 let mut tick_interval = tokio::time::interval(Duration::from_millis(250));
309 tick_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
310 loop {
311 tokio::select! {
312 _ = tick_interval.tick() => {
313 if let Err(err) = stream_clone.flush_outbuf().await {
314 log::error!("Error: {err}, shutting down stream send loop");
315 break;
316 }
317 },
318 _ = &mut shutdown_receiver => {
319 log::info!("Shutting down stream send loop");
320 break;
321 },
322 }
323 }
324 });
325
326 stream
327 }
328
329 pub async fn connect(&self) -> anyhow::Result<()> {
331 let (header, rc) = { self.state_mut().syn_header() };
333
334 log::debug!("Sending SYN");
335 self.send_packet(
336 Packet {
337 header,
338 data: Bytes::new(),
339 },
340 true,
341 )
342 .await?;
343 rc.await?;
344 Ok(())
345 }
346
347 #[cfg(test)]
348 pub async fn send_syn(&self) -> anyhow::Result<()> {
349 let (header, _rc) = { self.state_mut().syn_header() };
351
352 log::debug!("Sending SYN");
353 self.send_packet(
354 Packet {
355 header,
356 data: Bytes::new(),
357 },
358 true,
359 )
360 .await?;
361 Ok(())
362 }
363
364 async fn flush_outbuf(&self) -> anyhow::Result<()> {
365 let packets = {
369 let state = self.state();
370 if state.connection_state != ConnectionState::Connected
371 && !matches!(state.connection_state, ConnectionState::SynSent { .. })
373 {
374 log::debug!("Not yet connected, holding on to outgoing buffer");
376 return Ok(());
377 }
378 let packets: Vec<Packet> = state.outgoing_buffer.iter().cloned().collect();
379 log::debug!("Flushing outgoing buffer len: {}", packets.len());
380 packets
383 };
384 if let Some(socket) = self.weak_socket.upgrade() {
385 for packet in packets.into_iter() {
386 {
387 let state = self.state();
388 if state.cur_window + packet.size() > state.stream_window_size() {
389 log::warn!("Window to small to send packet, skipping");
390 continue;
391 }
392 }
393 let mut packet_bytes = vec![0; HEADER_SIZE as usize + packet.data.len()];
394 packet_bytes[..HEADER_SIZE as usize].copy_from_slice(&packet.header.to_bytes());
395 packet_bytes[HEADER_SIZE as usize..].copy_from_slice(&packet.data);
396 let bytes_sent = packet_bytes.len();
397 log::debug!(
398 "Sending {:?} bytes: {} to addr: {}",
399 packet.header.packet_type,
400 bytes_sent,
401 self.addr,
402 );
403 let (result, _buf) = socket.send_to(packet_bytes, self.addr).await;
405 let _ = result?;
406 let mut state = self.state_mut();
407 debug_assert!(bytes_sent < u32::MAX as usize);
409 state.cur_window += packet.data.len() as u32;
410 }
411 } else {
412 anyhow::bail!("Failed to send packets, socket dropped");
413 }
414 Ok(())
415 }
416
417 async fn send_packet(&self, packet: Packet, only_once: bool) -> anyhow::Result<()> {
418 let seq_nr = packet.header.seq_nr;
419 self.state_mut().outgoing_buffer.insert(packet);
420 self.flush_outbuf().await?;
421 if only_once {
423 self.state_mut().outgoing_buffer.remove(seq_nr);
424 }
425 Ok(())
426 }
427
428 async fn ack_packet(&self, seq_nr: u16) -> anyhow::Result<()> {
429 if let Some(socket) = self.weak_socket.upgrade() {
430 let ack_header = {
437 let mut state = self.state_mut();
438 state.ack_nr = seq_nr;
439 state.ack()
440 };
441 let packet_bytes = ack_header.to_bytes();
442 log::debug!(
443 "Sending Ack bytes: {} to addr: {}",
444 packet_bytes.len(),
445 self.addr,
446 );
447 let (result, _buf) = socket.send_to(packet_bytes, self.addr).await;
449 let _ = result?;
450 } else {
451 anyhow::bail!("Failed to ack packets, socket dropped");
452 }
453 Ok(())
454 }
455
456 pub(crate) async fn process_incoming(&self, packet: Packet) -> anyhow::Result<()> {
460 let packet_header = packet.header;
461
462 let conn_id = if packet_header.packet_type == PacketType::Syn {
466 packet_header.conn_id + 1
467 } else {
468 packet_header.conn_id
469 };
470 if self.state().conn_id_recv != conn_id && packet_header.packet_type != PacketType::Syn {
471 anyhow::bail!(
472 "Received invalid packet connection id: {}, expected: {}",
473 packet_header.conn_id,
474 self.state().conn_id_recv
475 )
476 }
477
478 let dist_from_expected = {
479 let mut state = self.state_mut();
480 if state.seq_nr < packet_header.ack_nr {
484 log::warn!("Incoming ack_nr was invalid, packet acked has never been sent");
487 return Ok(());
488 }
489
490 let their_delay = if packet_header.timestamp_microseconds == 0 {
493 0
496 } else {
497 let time = get_microseconds();
498 time - packet_header.timestamp_microseconds as u64
499 };
500 state.reply_micro = their_delay as u32;
501 state.their_advertised_window = packet_header.wnd_size;
502
503 if packet.header.packet_type == PacketType::State {
504 0
508 } else {
509 debug_assert!(state.ack_nr != 0);
511 packet_header.seq_nr as i32 - state.ack_nr as i32 - 1
515 }
516 };
517
518 match dist_from_expected.cmp(&0) {
519 std::cmp::Ordering::Less => {
520 log::info!("Got packet already acked: {:?}", packet.header.packet_type);
521 Ok(())
522 }
523 std::cmp::Ordering::Equal => {
524 let mut data_available = packet.header.packet_type == PacketType::Data;
527 self.handle_inorder_packet(packet).await?;
528
529 let mut seq_nr = packet_header.seq_nr;
530 let get_next = |seq_nr: u16| self.state_mut().incoming_buffer.remove(seq_nr);
532 while let Some(packet) = get_next(seq_nr) {
533 data_available |= packet.header.packet_type == PacketType::Data;
534 self.handle_inorder_packet(packet).await?;
535 seq_nr += 1;
536 }
537 if data_available {
538 self.data_available.notify_waiters();
539 }
540 Ok(())
541 }
542 std::cmp::Ordering::Greater => {
543 log::debug!("Got out of order packet");
544 let mut state = self.state_mut();
545 if packet.data.len() <= state.our_advertised_window() as usize {
546 state.incoming_buffer.insert(packet);
547 } else {
548 log::warn!("Stream window not respected, packet dropped");
549 }
550 Ok(())
552 }
553 }
554 }
555
556 pub async fn read(&self, buffer: &mut [u8]) -> usize {
560 loop {
564 let data_available = { self.state().receive_buf_cursor };
565 if data_available == 0 {
568 self.data_available.notified().await;
569 } else {
570 break;
571 }
572 }
573
574 let mut state = self.state_mut();
575 if buffer.len() <= state.receive_buf_cursor {
576 let len = buffer.len();
577 buffer[..].copy_from_slice(&state.receive_buf[..len]);
578 state.receive_buf.copy_within(len.., 0);
579 state.receive_buf_cursor -= len;
580 buffer.len()
581 } else {
582 let data_read = state.receive_buf_cursor;
583 buffer[0..state.receive_buf_cursor]
584 .copy_from_slice(&state.receive_buf[..state.receive_buf_cursor]);
585 state.receive_buf_cursor = 0;
586 data_read
587 }
588 }
589
590 pub async fn write(&self, data: Vec<u8>) -> anyhow::Result<()> {
591 if (data.len() as i32 - HEADER_SIZE) > MTU as i32 {
592 log::warn!("Fragmentation is not supported yet");
593 Ok(())
594 } else {
595 let packet = {
596 let mut state = self.state_mut();
597 let header = state.data();
598 Packet {
599 header,
600 data: data.into(),
601 }
602 };
603 self.send_packet(packet, false).await
604 }
605 }
606
607 async fn handle_inorder_packet(&self, packet: Packet) -> anyhow::Result<()> {
608 let conn_state = std::mem::replace(
609 &mut self.state_mut().connection_state,
610 ConnectionState::Idle,
611 );
612 match (packet.header.packet_type, conn_state) {
613 (PacketType::State, conn_state) => {
615 let mut state = self.state_mut();
616 state.ack_nr = packet.header.seq_nr;
617
618 if let ConnectionState::SynSent { connect_notifier } = conn_state {
619 state.connection_state = ConnectionState::Connected;
620 if connect_notifier.send(()).is_err() {
621 log::warn!("Connect notify receiver dropped");
622 }
623 log::debug!("SYN_ACK");
625 } else {
626 if let Some(pkt) = state.outgoing_buffer.remove(packet.header.ack_nr) {
627 state.cur_window -= pkt.data.len() as u32;
629 } else {
630 log::error!("Recevied ack for packet not inside the outgoing_buffer");
631 }
632 state.connection_state = conn_state;
634 }
635 }
636 (PacketType::Data, ConnectionState::Connected) => {
637 let was_consumed = self.state_mut().try_consume(&packet.data);
638 if was_consumed {
639 self.ack_packet(packet.header.seq_nr).await?;
640 }
641 self.state_mut().connection_state = ConnectionState::Connected;
643 }
644 (PacketType::Fin, conn_state) => {
645 let mut state = self.state_mut();
646 log::trace!("Received FIN: {}", self.addr);
647 state.eof_pkt = Some(packet.header.seq_nr);
648 log::info!("Connection closed: {}", self.addr);
649
650 state.connection_state = conn_state;
654 }
655 (PacketType::Syn, ConnectionState::SynReceived) => {
656 log::debug!("Acking received SYN");
660 self.ack_packet(packet.header.seq_nr).await?;
661 self.state_mut().connection_state = ConnectionState::SynReceived;
663 }
664 (PacketType::Data, ConnectionState::SynReceived) => {
665 let was_consumed = {
670 let mut state = self.state_mut();
671 if state.try_consume(&packet.data) {
672 log::info!("Incoming connection established!");
674 state.connection_state = ConnectionState::Connected;
675 true
676 } else {
677 false
678 }
679 };
680 if was_consumed {
681 self.ack_packet(packet.header.seq_nr).await?;
682 } else {
683 anyhow::bail!(
684 "Initial data packet doesn't fit receive buffer, stream is misconfigured"
685 );
686 }
687 }
688 (p_type, conn_state) => {
689 let mut state = self.state_mut();
690 log::error!("Unhandled packet type!: {:?}", p_type);
691 state.connection_state = conn_state;
693 }
694 }
695 Ok(())
696 }
697
698 pub(crate) fn state_mut(&self) -> RefMut<'_, StreamState> {
699 self.inner.borrow_mut()
700 }
701
702 pub(crate) fn state(&self) -> Ref<'_, StreamState> {
703 self.inner.borrow()
704 }
705}
706
707impl Drop for UtpStream {
708 fn drop(&mut self) {
709 if Rc::strong_count(&self.inner) == 2 {
711 self.state_mut()
714 .shutdown_signal
715 .take()
716 .unwrap()
717 .send(())
718 .unwrap();
719 }
720 }
721}