1use tokio::{
16 io::{AsyncRead, AsyncWriteExt},
17 net::TcpStream,
18};
19use tokio_stream::StreamExt;
20use tokio_util::{
21 bytes::BytesMut,
22 codec::{FramedRead, LengthDelimitedCodec},
23};
24
/// Version number of the client/server protocol.
///
/// NOTE(review): presumably exchanged via `ClientServerPacket::ProtocolVersion` during the
/// handshake and bumped whenever the wire format changes — confirm against callers.
pub const PROTOCOL_VERSION: u32 = 2;

/// A `FramedRead` that splits an async byte stream into frames using `LengthDelimitedCodec`.
/// Build one with `new_framed_reader` so the length-prefix settings match the senders here.
pub type FramedReader<T> = FramedRead<T, LengthDelimitedCodec>;
28
29#[derive(Clone, bincode::Encode, bincode::Decode)]
30pub enum ClientServerPacket {
31 Ping,
34 ProtocolVersion(u32),
36 PubKey(Vec<u8>),
38 ClientId(u64),
40 Challenge(Vec<u8>),
42 ChallengeResponse(Vec<u8>),
44}
45
46impl ClientServerPacket {
47 pub fn into_vec(self) -> Result<Vec<u8>, bincode::error::EncodeError> {
48 bincode::encode_to_vec(self, bincode::config::standard())
49 }
50
51 pub fn from_slice(data: &[u8]) -> Result<Self, bincode::error::DecodeError> {
52 bincode::decode_from_slice(&data, bincode::config::standard()).map(|(packet, _)| packet)
53 }
54}
55
/// A packet addressed to (or coming from) a specific client id.
///
/// [`TaggedPacket::into_vec`] produces the following layout, which `recv_tagged_packet`
/// parses (the outer 4-byte length prefix is added by `send_tagged_packet`):
///
/// `client_id` as 8 little-endian bytes, then a one-byte tag, then the variant payload.
///
/// | tag    | variant        | payload                 |
/// |--------|----------------|-------------------------|
/// | `0x00` | `Data`         | the raw bytes of `data` |
/// | `0x01` | `Failure`      | UTF-8 bytes of `error`  |
/// | `0x02` | `Kick`         | none                    |
/// | `0x03` | `Reconnection` | none                    |
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum TaggedPacket {
    /// Opaque payload bytes for `client_id`.
    Data { client_id: u64, data: Vec<u8> },
    /// A textual error message for `client_id`.
    Failure { client_id: u64, error: String },
    /// Payload-less packet carrying only `client_id` (presumably a kick notice).
    Kick { client_id: u64 },
    /// Payload-less packet carrying only `client_id` (presumably a reconnection notice).
    Reconnection { client_id: u64 },
}

impl TaggedPacket {
    /// Returns the client id carried by this packet, whatever its variant.
    pub fn client_id(&self) -> u64 {
        match self {
            TaggedPacket::Data { client_id, .. }
            | TaggedPacket::Failure { client_id, .. }
            | TaggedPacket::Kick { client_id }
            | TaggedPacket::Reconnection { client_id } => *client_id,
        }
    }

    /// Serializes the packet into its tagged byte layout (see the type docs), consuming it.
    ///
    /// The result has no outer length prefix.
    pub fn into_vec(self) -> Vec<u8> {
        let (client_id, tag, payload): (u64, u8, &[u8]) = match &self {
            TaggedPacket::Data { client_id, data } => (*client_id, 0x00, data.as_slice()),
            TaggedPacket::Failure { client_id, error } => (*client_id, 0x01, error.as_bytes()),
            TaggedPacket::Kick { client_id } => (*client_id, 0x02, &[]),
            TaggedPacket::Reconnection { client_id } => (*client_id, 0x03, &[]),
        };

        // Header is the 8-byte id plus the 1-byte tag; size the buffer once up front.
        let mut buf = Vec::with_capacity(std::mem::size_of::<u64>() + 1 + payload.len());
        buf.extend_from_slice(&client_id.to_le_bytes());
        buf.push(tag);
        buf.extend_from_slice(payload);
        buf
    }
}
99
100pub fn configure_performance_tcp_socket(stream: &mut TcpStream) -> std::io::Result<()> {
102 stream.set_nodelay(true)?;
103 stream.set_linger(Some(std::time::Duration::from_secs(5)))?;
104 Ok(())
105}
106
107pub fn new_framed_reader<T: AsyncRead + Unpin>(stream: T) -> FramedReader<T> {
108 LengthDelimitedCodec::builder()
109 .length_field_type::<u32>()
111 .little_endian()
112 .new_read(stream)
113}
114
115pub async fn recv_size_prefixed<T: AsyncRead + Unpin>(
119 read: &mut FramedReader<T>,
120) -> anyhow::Result<BytesMut> {
121 Ok(read
122 .next()
123 .await
124 .ok_or_else(|| anyhow::format_err!("Connection closed or Eof"))??)
125}
126
127pub async fn send_size_prefixed<T: AsyncWriteExt + Unpin>(
128 stream: &mut T,
129 message: &[u8],
130) -> anyhow::Result<()> {
131 let size = message.len() as u32;
132 let size_bytes = size.to_le_bytes();
133 let mut combined_message = Vec::with_capacity(4 + message.len());
134 combined_message.extend_from_slice(&size_bytes);
135 combined_message.extend_from_slice(message);
136 stream.write_all(&combined_message).await?;
137 Ok(())
138}
139
140pub async fn recv_tagged_packet<T: AsyncRead + Unpin>(
147 read: &mut FramedReader<T>,
148) -> anyhow::Result<TaggedPacket> {
149 let buffer = recv_size_prefixed(read).await?;
150 if buffer.len() < 8 {
151 return Err(anyhow::format_err!("Packet too small"));
152 }
153 let client_id = u64::from_le_bytes(buffer[0..8].try_into().unwrap());
154 let buf: &[u8] = buffer[8..].into();
155
156 match buf[0] {
157 0x00 => {
158 Ok(TaggedPacket::Data {
160 client_id,
161 data: buf[1..].into(),
162 })
163 }
164 0x01 => {
165 let error = String::from_utf8_lossy(&buf[1..]).to_string();
167 Ok(TaggedPacket::Failure { client_id, error })
168 }
169 0x02 => {
170 Ok(TaggedPacket::Kick { client_id })
172 }
173 0x03 => {
174 Ok(TaggedPacket::Reconnection { client_id })
176 }
177 _ => {
178 return Err(anyhow::format_err!("Unknown packet type"));
179 }
180 }
181}
182
183pub async fn send_tagged_packet<T: AsyncWriteExt + Unpin>(
189 stream: &mut T,
190 packet: TaggedPacket,
191) -> anyhow::Result<()> {
192 let data = packet.into_vec();
193 let size = data.len() as u32;
194 let size_bytes = size.to_le_bytes();
195
196 let mut combined_message = Vec::with_capacity(size as usize + 4);
197 combined_message.extend_from_slice(&size_bytes);
198 combined_message.extend_from_slice(&data);
199
200 stream.write_all(&combined_message).await?;
201 Ok(())
202}