1use std::io::Error;
4use bytes::{Buf, BufMut, BytesMut};
5use thiserror::Error;
6use tokio::io::{AsyncRead, AsyncWrite};
7use variable_len_reader::{AsyncVariableReader, AsyncVariableWriter};
8use variable_len_reader::helper::{AsyncReaderHelper, AsyncWriterHelper};
9use crate::config::get_max_packet_size;
10
11#[derive(Error, Debug)]
13pub enum PacketError {
14 #[error("Packet size {0} is larger than the maximum allowed packet size {1}.")]
21 TooLarge(usize, usize),
22
23 #[error("During io bytes.")]
25 IO(#[from] Error),
26
27 #[cfg(feature = "encryption")]
29 #[cfg_attr(docsrs, doc(cfg(feature = "encryption")))]
30 #[error("During encrypting/decrypting bytes.")]
31 AES(#[from] aes_gcm::aead::Error),
32
33 #[cfg(feature = "encryption")]
38 #[cfg_attr(docsrs, doc(cfg(feature = "encryption")))]
39 #[error("Broken stream.")]
40 Broken(),
41}
42
43#[derive(Error, Debug)]
45pub enum StarterError {
46 #[error("Invalid stream. MAGIC is not matched.")]
49 InvalidStream(),
50
51 #[error("Incompatible protocol. received protocol: {0:?}")]
55 InvalidProtocol(ProtocolVariant),
56
57 #[error("Invalid identifier. received: {0}")]
62 InvalidIdentifier(String),
63
64 #[error("Invalid version. received: {0}")]
68 InvalidVersion(String),
69
70 #[error("During io bytes.")]
72 IO(#[from] Error),
73
74 #[cfg(feature = "encryption")]
76 #[cfg_attr(docsrs, doc(cfg(feature = "encryption")))]
77 #[error("During generating/encrypting/decrypting rsa key.")]
78 RSA(#[from] rsa::Error),
79}
80
81
/// Magic bytes that open every handshake head (`write_head` sends them first,
/// `read_head` rejects a stream that starts with anything else).
static MAGIC_BYTES: [u8; 4] = [0xD0, 0x08, 0xA6, 0x68];

/// Protocol version sent right after [`MAGIC_BYTES`]; `read_head` requires an exact match.
static MAGIC_VERSION: u16 = 0x0001;
97
/// The protocol variant both sides must agree on during the handshake.
///
/// It converts to and from a `[bool; 2]` flag pair for transmission.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ProtocolVariant {
    /// Flag pair `[false, false]`.
    Raw,
    /// Flag pair `[false, true]`.
    Compression,
    /// Flag pair `[true, false]`.
    Encryption,
    /// Flag pair `[true, true]`.
    CompressEncryption,
}
111
112impl From<[bool; 2]> for ProtocolVariant {
113 fn from(value: [bool; 2]) -> Self {
114 match value {
115 [false, false] => ProtocolVariant::Raw,
116 [false, true] => ProtocolVariant::Compression,
117 [true, false] => ProtocolVariant::Encryption,
118 [true, true] => ProtocolVariant::CompressEncryption,
119 }
120 }
121}
122
123impl From<ProtocolVariant> for [bool; 2] {
124 fn from(value: ProtocolVariant) -> Self {
125 match value {
126 ProtocolVariant::Raw => [false, false],
127 ProtocolVariant::Compression => [false, true],
128 ProtocolVariant::Encryption => [true, false],
129 ProtocolVariant::CompressEncryption => [true, true],
130 }
131 }
132}
133
134
135pub(crate) async fn write_head<W: AsyncWrite + Unpin>(stream: &mut W, protocol: ProtocolVariant, identifier: &str, version: &str) -> Result<(), StarterError> {
148 stream.write_more(&MAGIC_BYTES).await?;
149 stream.write_u16_raw_be(MAGIC_VERSION).await?;
150 stream.write_bools_2(protocol.into()).await?;
151 AsyncWriterHelper(stream).help_write_string(identifier).await?;
152 AsyncWriterHelper(stream).help_write_string(version).await?;
153 Ok(())
154}
155
156pub(crate) async fn read_head<R: AsyncRead + Unpin, P: FnOnce(&str) -> bool>(stream: &mut R, protocol: ProtocolVariant, identifier: &str, version: P) -> Result<(u16, String), StarterError> {
159 let mut magic = [0; 4];
160 stream.read_more(&mut magic).await?;
161 if magic != MAGIC_BYTES { return Err(StarterError::InvalidStream()); }
162 let protocol_version = stream.read_u16_raw_be().await?;
163 if protocol_version != MAGIC_VERSION { return Err(StarterError::InvalidStream()); }
164 let protocol_read = stream.read_bools_2().await?.into();
165 if protocol_read != protocol { return Err(StarterError::InvalidProtocol(protocol_read)); }
166 let identifier_read = AsyncReaderHelper(stream).help_read_string().await?;
167 if identifier_read != identifier { return Err(StarterError::InvalidIdentifier(identifier_read)); }
168 let version_read = AsyncReaderHelper(stream).help_read_string().await?;
169 if !version(&version_read) { return Err(StarterError::InvalidVersion(version_read)); }
170 Ok((protocol_version, version_read))
171}
172
173pub(crate) async fn write_last<W: AsyncWrite + Unpin, E>(stream: &mut W, protocol: ProtocolVariant, identifier: &str, version: &str, last: Result<E, StarterError>) -> Result<E, StarterError> {
183 match last {
184 Err(e) => {
185 match &e {
186 StarterError::InvalidProtocol(_) => {
187 stream.write_bools_2([false, false]).await?;
188 stream.write_bools_2(protocol.into()).await?;
189 }
190 StarterError::InvalidIdentifier(_) => {
191 stream.write_bools_2([false, true]).await?;
192 AsyncWriterHelper(stream).help_write_string(identifier).await?;
193 }
194 StarterError::InvalidVersion(_) => {
195 stream.write_bools_2([true, false]).await?;
196 AsyncWriterHelper(stream).help_write_string(version).await?;
197 }
198 _ => {}
199 }
200 return Err(e);
201 },
202 Ok(k) => {
203 stream.write_bools_2([true, true]).await?;
204 Ok(k)
205 }
206 }
207}
208
209pub(crate) async fn read_last<R: AsyncRead + Unpin, E>(stream: &mut R, last: Result<E, StarterError>) -> Result<E, StarterError> {
212 let extra = last?;
213 match stream.read_bools_2().await? {
214 [true, true] => Ok(extra),
215 [false, false] => Err(StarterError::InvalidProtocol(stream.read_bools_2().await?.into())),
216 [false, true] => Err(StarterError::InvalidIdentifier(AsyncReaderHelper(stream).help_read_string().await?)),
217 [true, false] => Err(StarterError::InvalidVersion(AsyncReaderHelper(stream).help_read_string().await?)),
218 }
219}
220
221
222#[inline]
223fn check_bytes_len(len: usize) -> Result<(), PacketError> {
224 let config = get_max_packet_size();
225 if len > config { Err(PacketError::TooLarge(len, config)) } else { Ok(()) }
226}
227
228pub(crate) async fn write_packet<W: AsyncWrite + Unpin, B: Buf>(stream: &mut W, bytes: &mut B) -> Result<(), PacketError> {
237 check_bytes_len(bytes.remaining())?;
238 stream.write_usize_varint_ap(bytes.remaining()).await?;
239 stream.write_more_buf(bytes).await?;
240 Ok(())
241}
242
243pub(crate) async fn read_packet<R: AsyncRead + Unpin>(stream: &mut R) -> Result<BytesMut, PacketError> {
245 let len = stream.read_usize_varint_ap().await?;
246 check_bytes_len(len)?;
247 let mut buf = BytesMut::with_capacity(len).limit(len);
248 stream.read_more_buf(&mut buf).await?;
249 Ok(buf.into_inner())
250}
251
252#[inline]
254pub(crate) async fn flush<W: AsyncWrite + Unpin>(stream: &mut W) -> Result<(), std::io::Error> {
255 #[cfg(feature = "auto_flush")] {
256 use tokio::io::AsyncWriteExt;
257 stream.flush().await
258 }
259 #[cfg(not(feature = "auto_flush"))] {
260 let _ = stream;
261 Ok(())
262 }
263}
264
265
/// Generates a fresh 2048-bit RSA private key.
///
/// # Returns
/// The key together with its modulus `n` and public exponent `e`,
/// both encoded as little-endian bytes.
///
/// # Errors
/// Returns [`StarterError::RSA`] if the key generation fails.
#[cfg(feature = "encryption")]
pub(crate) fn generate_rsa_private() -> Result<(rsa::RsaPrivateKey, Vec<u8>, Vec<u8>), StarterError> {
    use rsa::traits::PublicKeyParts;

    let mut rng = rand::thread_rng();
    let private_key = rsa::RsaPrivateKey::new(&mut rng, 2048)?;
    let modulus = private_key.n().to_bytes_le();
    let exponent = private_key.e().to_bytes_le();
    Ok((private_key, modulus, exponent))
}
274
/// Builds an RSA public key from its modulus `n` and public exponent `e`,
/// both given as little-endian bytes.
///
/// # Errors
/// Returns [`StarterError::RSA`] if `rsa::RsaPublicKey::new` rejects the components.
#[cfg(feature = "encryption")]
pub(crate) fn compose_rsa_public(n: Vec<u8>, e: Vec<u8>) -> Result<rsa::RsaPublicKey, StarterError> {
    let modulus = rsa::BigUint::from_bytes_le(n.as_slice());
    let exponent = rsa::BigUint::from_bytes_le(e.as_slice());
    rsa::RsaPublicKey::new(modulus, exponent).map_err(StarterError::from)
}
281
/// An AES-256-GCM cipher paired with a nonce of 12 bytes (`U12`).
#[cfg(feature = "encryption")]
pub(crate) type InnerAesCipher = (
    aes_gcm::Aes256Gcm,
    aes_gcm::Nonce<aes_gcm::aead::consts::U12>,
);
286
/// Holder of the AES cipher used by an encrypted stream.
///
/// The cipher is taken out of the holder while it is in use and has to be
/// put back afterwards; while it is missing the holder counts as broken.
#[cfg(feature = "encryption")]
#[cfg_attr(docsrs, doc(cfg(feature = "encryption")))]
pub struct Cipher {
    /// `Some` while the cipher is available, `None` after it has been taken
    /// out and not (yet) been put back.
    cipher: std::sync::Mutex<Option<InnerAesCipher>>,
}
293
#[cfg(feature = "encryption")]
impl std::fmt::Debug for Cipher {
    /// Formats the holder without ever exposing key material.
    ///
    /// The `cipher` field is rendered as one of:
    /// * `"<locked>"` - the mutex is currently held by someone else.
    /// * `"<unlocked>"` - the cipher is present.
    /// * `"<broken>"` - the cipher has been taken out and is missing.
    ///
    /// This never blocks: the mutex is only probed with `try_lock`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = |present: bool| if present { "<unlocked>" } else { "<broken>" };
        let state = match self.cipher.try_lock() {
            Ok(guard) => label(guard.is_some()),
            // A poisoned mutex is not held by anybody, so calling it "<locked>"
            // would be misleading; report the state of the slot instead.
            Err(std::sync::TryLockError::Poisoned(poisoned)) => label(poisoned.into_inner().is_some()),
            Err(std::sync::TryLockError::WouldBlock) => "<locked>",
        };
        f.debug_struct("Cipher").field("cipher", &state).finish()
    }
}
304
#[cfg(feature = "encryption")]
impl Cipher {
    /// Creates a holder that owns `cipher`.
    #[inline]
    pub(crate) fn new(cipher: InnerAesCipher) -> Self {
        Self {
            cipher: std::sync::Mutex::new(Some(cipher)),
        }
    }

    /// Locks the holder and takes the cipher out of it.
    ///
    /// The returned guard keeps the holder locked; hand it to [`Cipher::reset`]
    /// together with the (updated) cipher to put the cipher back. If the guard
    /// is dropped without a reset, the holder stays empty.
    ///
    /// # Errors
    /// Returns [`PacketError::Broken`] when the holder is empty.
    #[inline]
    pub(crate) fn get(&self) -> Result<(InnerAesCipher, std::sync::MutexGuard<'_, Option<InnerAesCipher>>), PacketError> {
        // A poisoned mutex only tells that some thread panicked while holding
        // the guard. Recover the guard instead of panicking here as well: the
        // slot itself shows whether the cipher is still usable, and an empty
        // slot is reported as `Broken` below.
        let mut guard = self.cipher.lock().unwrap_or_else(std::sync::PoisonError::into_inner);
        let cipher = guard.take().ok_or(PacketError::Broken())?;
        Ok((cipher, guard))
    }

    /// Puts `cipher` back into the holder behind `guard`.
    ///
    /// The guard is consumed, so the lock is released when this returns.
    #[inline]
    pub(crate) fn reset(mut guard: std::sync::MutexGuard<'_, Option<InnerAesCipher>>, cipher: InnerAesCipher) {
        *guard = Some(cipher);
    }
}
326
327
#[cfg(test)]
pub(super) mod tests {
    use anyhow::Result;
    use bytes::{Buf, Bytes};
    use tokio::io::{duplex, AsyncRead, AsyncWrite};
    use crate::protocols::common::{read_packet, write_packet};

    /// Creates the two connected ends of an in-memory duplex stream
    /// (created with `duplex(1024)`).
    pub(crate) async fn create() -> Result<(impl AsyncRead + AsyncWrite + Unpin, impl AsyncRead + AsyncWrite + Unpin)> {
        Ok(duplex(1024))
    }

    /// A packet written on one end is read back unchanged on the other end.
    #[tokio::test]
    async fn packet() -> Result<()> {
        const PAYLOAD: &[u8] = &[1, 2, 3, 4, 5];
        let (mut client, mut server) = create().await?;

        let mut outgoing = Bytes::from_static(PAYLOAD);
        write_packet(&mut client, &mut outgoing).await?;
        let received = read_packet(&mut server).await?;
        assert_eq!(PAYLOAD, received.chunk());

        Ok(())
    }
}