use std::io::Error;
use bytes::{Buf, BufMut, BytesMut};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use variable_len_reader::{AsyncVariableReader, AsyncVariableWriter};
use variable_len_reader::helper::{AsyncReaderHelper, AsyncWriterHelper};
use crate::config::get_max_packet_size;
#[derive(Error, Debug)]
pub enum PacketError {
#[error("Packet size {0} is larger than the maximum allowed packet size {1}.")]
TooLarge(usize, usize),
#[error("During io bytes.")]
IO(#[from] Error),
#[cfg(feature = "encryption")]
#[cfg_attr(docsrs, doc(cfg(feature = "encryption")))]
#[error("During encrypting/decrypting bytes.")]
AES(#[from] aes_gcm::aead::Error),
#[cfg(feature = "encryption")]
#[cfg_attr(docsrs, doc(cfg(feature = "encryption")))]
#[error("Broken stream.")]
Broken(),
}
#[derive(Error, Debug)]
pub enum StarterError {
#[error("Invalid stream. MAGIC is not matched.")]
InvalidStream(),
#[error("Incompatible protocol. received protocol: {0:?}")]
InvalidProtocol(ProtocolVariant),
#[error("Invalid identifier. received: {0}")]
InvalidIdentifier(String),
#[error("Invalid version. received: {0}")]
InvalidVersion(String),
#[error("During io bytes.")]
IO(#[from] Error),
#[cfg(feature = "encryption")]
#[cfg_attr(docsrs, doc(cfg(feature = "encryption")))]
#[error("During generating/encrypting/decrypting rsa key.")]
RSA(#[from] rsa::Error),
}
/// Four-byte preamble every handshake header starts with; a mismatch on read
/// means the peer is not speaking this protocol.
///
/// `const` rather than `static`: these are plain values whose address is never
/// needed, which is the case the Rust Reference recommends constants for.
const MAGIC_BYTES: [u8; 4] = [208, 8, 166, 104];
/// Version of the handshake header layout, sent right after [`MAGIC_BYTES`];
/// the reader rejects any other value.
const MAGIC_VERSION: u16 = 1;
/// Which transport features a connection uses.
///
/// On the wire this is two booleans, `[encryption, compression]`; the two
/// `From` impls below convert between the enum and that flag pair.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ProtocolVariant {
    Raw,
    Compression,
    Encryption,
    CompressEncryption,
}

impl From<[bool; 2]> for ProtocolVariant {
    /// Decodes an `[encryption, compression]` flag pair.
    fn from([encryption, compression]: [bool; 2]) -> Self {
        match (encryption, compression) {
            (false, false) => Self::Raw,
            (false, true) => Self::Compression,
            (true, false) => Self::Encryption,
            (true, true) => Self::CompressEncryption,
        }
    }
}

impl From<ProtocolVariant> for [bool; 2] {
    /// Encodes a variant as its `[encryption, compression]` flag pair.
    fn from(variant: ProtocolVariant) -> Self {
        let (encryption, compression) = match variant {
            ProtocolVariant::Raw => (false, false),
            ProtocolVariant::Compression => (false, true),
            ProtocolVariant::Encryption => (true, false),
            ProtocolVariant::CompressEncryption => (true, true),
        };
        [encryption, compression]
    }
}
/// Sends the local handshake header.
///
/// Field order, which [`read_head`] consumes in the same order: the four
/// `MAGIC_BYTES`, `MAGIC_VERSION` (via `write_u16_raw_be`), the two protocol
/// flags, then `identifier` and `version` via `help_write_string`.
///
/// # Errors
/// Any I/O failure from `stream`, converted into [`StarterError`].
pub(crate) async fn write_head<W>(stream: &mut W, protocol: ProtocolVariant, identifier: &str, version: &str) -> Result<(), StarterError>
where
    W: AsyncWrite + Unpin,
{
    let flags: [bool; 2] = protocol.into();
    stream.write_more(&MAGIC_BYTES).await?;
    stream.write_u16_raw_be(MAGIC_VERSION).await?;
    stream.write_bools_2(flags).await?;
    // Identifier first, then version, matching the read order in `read_head`.
    for text in [identifier, version] {
        AsyncWriterHelper(stream).help_write_string(text).await?;
    }
    Ok(())
}
/// Reads and validates the peer's handshake header written by [`write_head`].
///
/// Each field is checked as soon as it is read, so a mismatch stops the read
/// early. `version` is a predicate that decides whether the peer's version
/// string is acceptable.
///
/// Returns the header version (always `MAGIC_VERSION` on success) and the
/// peer's version string.
///
/// # Errors
/// - [`StarterError::InvalidStream`] if the magic bytes or magic version differ.
/// - [`StarterError::InvalidProtocol`] / [`StarterError::InvalidIdentifier`] /
///   [`StarterError::InvalidVersion`] carrying the value the peer sent.
/// - [`StarterError::IO`] on read failure.
///
/// NOTE(review): the identifier/version strings come from the peer; whether
/// `help_read_string` bounds their length is not visible here — confirm.
pub(crate) async fn read_head<R, P>(stream: &mut R, protocol: ProtocolVariant, identifier: &str, version: P) -> Result<(u16, String), StarterError>
where
    R: AsyncRead + Unpin,
    P: FnOnce(&str) -> bool,
{
    let mut received_magic = [0u8; 4];
    stream.read_more(&mut received_magic).await?;
    if received_magic != MAGIC_BYTES {
        return Err(StarterError::InvalidStream());
    }

    let header_version = stream.read_u16_raw_be().await?;
    if header_version != MAGIC_VERSION {
        return Err(StarterError::InvalidStream());
    }

    let received_protocol: ProtocolVariant = stream.read_bools_2().await?.into();
    if received_protocol != protocol {
        return Err(StarterError::InvalidProtocol(received_protocol));
    }

    let received_identifier = AsyncReaderHelper(stream).help_read_string().await?;
    if received_identifier.as_str() != identifier {
        return Err(StarterError::InvalidIdentifier(received_identifier));
    }

    let peer_version = AsyncReaderHelper(stream).help_read_string().await?;
    if !version(peer_version.as_str()) {
        return Err(StarterError::InvalidVersion(peer_version));
    }

    Ok((header_version, peer_version))
}
/// Reports the outcome of a header check back to the peer, in the format
/// [`read_last`] decodes.
///
/// - `Ok` writes the status flags `[true, true]` and passes the value through.
/// - `InvalidProtocol` writes `[false, false]` followed by our `protocol` flags.
/// - `InvalidIdentifier` writes `[false, true]` followed by our `identifier`.
/// - `InvalidVersion` writes `[true, false]` followed by our `version`.
/// - Any other error writes nothing.
///
/// In every error case the original error is returned, unless writing the
/// report itself fails, in which case that I/O error is returned instead.
pub(crate) async fn write_last<W, E>(stream: &mut W, protocol: ProtocolVariant, identifier: &str, version: &str, last: Result<E, StarterError>) -> Result<E, StarterError>
where
    W: AsyncWrite + Unpin,
{
    let failure = match last {
        Ok(value) => {
            stream.write_bools_2([true, true]).await?;
            return Ok(value);
        }
        Err(failure) => failure,
    };

    match &failure {
        StarterError::InvalidProtocol(_) => {
            stream.write_bools_2([false, false]).await?;
            stream.write_bools_2(protocol.into()).await?;
        }
        StarterError::InvalidIdentifier(_) => {
            stream.write_bools_2([false, true]).await?;
            AsyncWriterHelper(stream).help_write_string(identifier).await?;
        }
        StarterError::InvalidVersion(_) => {
            stream.write_bools_2([true, false]).await?;
            AsyncWriterHelper(stream).help_write_string(version).await?;
        }
        // Errors the peer has no status code for are not reported.
        _ => {}
    }

    Err(failure)
}
/// Reads the peer's verdict written by [`write_last`].
///
/// If our own check (`last`) already failed, that error is returned without
/// reading anything. Otherwise the two status flags are read:
/// - `[true, true]` means the peer accepted us, so `last`'s value is returned.
/// - Any other pair yields the matching `Invalid*` error, carrying the peer's
///   own protocol, identifier or version.
pub(crate) async fn read_last<R, E>(stream: &mut R, last: Result<E, StarterError>) -> Result<E, StarterError>
where
    R: AsyncRead + Unpin,
{
    let extra = last?;
    let status = stream.read_bools_2().await?;
    let rejection = match status {
        [true, true] => return Ok(extra),
        [false, false] => StarterError::InvalidProtocol(stream.read_bools_2().await?.into()),
        [false, true] => StarterError::InvalidIdentifier(AsyncReaderHelper(stream).help_read_string().await?),
        [true, false] => StarterError::InvalidVersion(AsyncReaderHelper(stream).help_read_string().await?),
    };
    Err(rejection)
}
/// Rejects a packet length above the configured maximum (`get_max_packet_size`).
///
/// # Errors
/// [`PacketError::TooLarge`] carrying `(len, maximum)`.
#[inline]
fn check_bytes_len(len: usize) -> Result<(), PacketError> {
    let maximum = get_max_packet_size();
    if len > maximum {
        return Err(PacketError::TooLarge(len, maximum));
    }
    Ok(())
}
/// Writes one length-prefixed packet: the remaining byte count as a varint
/// (`write_usize_varint_ap`), then the bytes themselves.
///
/// The size limit is checked before anything is written, so an oversized
/// packet leaves the stream untouched.
///
/// # Errors
/// [`PacketError::TooLarge`] if `bytes` exceeds the limit; [`PacketError::IO`]
/// on write failure.
pub(crate) async fn write_packet<W, B>(stream: &mut W, bytes: &mut B) -> Result<(), PacketError>
where
    W: AsyncWrite + Unpin,
    B: Buf,
{
    let length = bytes.remaining();
    check_bytes_len(length)?;
    stream.write_usize_varint_ap(length).await?;
    stream.write_more_buf(bytes).await?;
    Ok(())
}
/// Reads one packet written by [`write_packet`].
///
/// The length prefix is validated against the size limit before any buffer is
/// allocated, so a hostile length cannot force a huge allocation.
///
/// # Errors
/// [`PacketError::TooLarge`] if the announced length exceeds the limit;
/// [`PacketError::IO`] on read failure.
pub(crate) async fn read_packet<R>(stream: &mut R) -> Result<BytesMut, PacketError>
where
    R: AsyncRead + Unpin,
{
    let length = stream.read_usize_varint_ap().await?;
    check_bytes_len(length)?;
    // `limit` caps the writable space at exactly `length` bytes, so the read
    // stops at the packet boundary instead of consuming the next packet.
    let mut sink = BytesMut::with_capacity(length).limit(length);
    stream.read_more_buf(&mut sink).await?;
    let packet = sink.into_inner();
    Ok(packet)
}
/// Generates a fresh 2048-bit RSA private key.
///
/// Returns the key together with its modulus `n` and public exponent `e`, each
/// as little-endian bytes, ready to send to the peer.
///
/// # Errors
/// [`StarterError::RSA`] if key generation fails.
#[cfg(feature = "encryption")]
pub(crate) fn generate_rsa_private() -> Result<(rsa::RsaPrivateKey, Vec<u8>, Vec<u8>), StarterError> {
    use rsa::traits::PublicKeyParts;
    const KEY_BITS: usize = 2048;
    let mut rng = rand::thread_rng();
    let private_key = rsa::RsaPrivateKey::new(&mut rng, KEY_BITS)?;
    let (modulus, exponent) = (private_key.n().to_bytes_le(), private_key.e().to_bytes_le());
    Ok((private_key, modulus, exponent))
}
/// Rebuilds an RSA public key from little-endian modulus (`n`) and exponent
/// (`e`) bytes, the format produced by `generate_rsa_private`.
///
/// # Errors
/// [`StarterError::RSA`] if the `rsa` crate rejects the components.
#[cfg(feature = "encryption")]
pub(crate) fn compose_rsa_public(n: Vec<u8>, e: Vec<u8>) -> Result<rsa::RsaPublicKey, StarterError> {
    let modulus = rsa::BigUint::from_bytes_le(&n);
    let exponent = rsa::BigUint::from_bytes_le(&e);
    rsa::RsaPublicKey::new(modulus, exponent).map_err(StarterError::from)
}
/// The AES-256-GCM cipher paired with a 12-byte (`U12`) nonce.
#[cfg(feature = "encryption")]
pub(crate) type InnerAesCipher = (aes_gcm::Aes256Gcm, aes_gcm::Nonce<aes_gcm::aead::consts::U12>);
/// Shared AES-GCM state for a connection, guarded by a `std::sync::Mutex`.
///
/// The slot holds `Some` while the cipher is available. `Cipher::get` takes it
/// out, leaving `None`, and `Cipher::reset` puts it back.
#[cfg(feature = "encryption")]
#[cfg_attr(docsrs, doc(cfg(feature = "encryption")))]
pub struct Cipher {
cipher: std::sync::Mutex<Option<InnerAesCipher>>,
}
/// Prints the state of the cipher slot, never the key material:
/// `"<locked>"` if the lock cannot be taken right now, `"<unlocked>"` if the
/// cipher is present, `"<broken>"` if the slot is empty.
#[cfg(feature = "encryption")]
impl std::fmt::Debug for Cipher {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let state = match self.cipher.try_lock() {
            Ok(slot) => {
                if slot.is_some() {
                    "<unlocked>"
                } else {
                    "<broken>"
                }
            }
            Err(_) => "<locked>",
        };
        f.debug_struct("Cipher").field("cipher", &state).finish()
    }
}
#[cfg(feature = "encryption")]
impl Cipher {
    /// Wraps a ready-to-use cipher/nonce pair.
    #[inline]
    pub(crate) fn new(cipher: InnerAesCipher) -> Self {
        Self {
            cipher: std::sync::Mutex::new(Some(cipher)),
        }
    }

    /// Takes the cipher out of its slot and returns it with the lock guard.
    ///
    /// The caller is expected to hand both back to [`Cipher::reset`]. Until
    /// then the slot is `None`, and the lock stays held for as long as the
    /// guard is alive.
    ///
    /// # Errors
    /// [`PacketError::Broken`] if the slot is empty (the cipher was taken and
    /// never reset).
    #[inline]
    pub(crate) fn get(&self) -> Result<(InnerAesCipher, std::sync::MutexGuard<Option<InnerAesCipher>>), PacketError> {
        // Poisoning only means another thread panicked while holding the guard.
        // The slot is a plain `Option`, so recover it instead of panicking. If
        // that thread had taken the cipher, the slot is `None` and is reported
        // as `Broken` below, so a half-used cipher is never handed out again.
        let mut guard = self
            .cipher
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let cipher = guard.take().ok_or(PacketError::Broken())?;
        Ok((cipher, guard))
    }

    /// Puts the (possibly updated) cipher back into the slot and releases the
    /// lock when `guard` is dropped at the end of this call.
    #[inline]
    pub(crate) fn reset(mut guard: std::sync::MutexGuard<Option<InnerAesCipher>>, cipher: InnerAesCipher) {
        *guard = Some(cipher);
    }
}
#[cfg(test)]
pub(crate) mod tests {
    use anyhow::Result;
    use bytes::{Buf, Bytes};
    use tokio::io::{duplex, AsyncRead, AsyncWrite};
    use crate::common::{read_packet, write_packet};

    /// Returns a connected pair of in-memory streams (1024-byte buffer) for tests.
    pub(crate) async fn create() -> Result<(impl AsyncRead + AsyncWrite + Unpin, impl AsyncRead + AsyncWrite + Unpin)> {
        Ok(duplex(1024))
    }

    /// A packet written on one end is read back unchanged on the other.
    #[tokio::test]
    async fn packet() -> Result<()> {
        const PAYLOAD: &[u8] = &[1, 2, 3, 4, 5];
        let (mut client, mut server) = create().await?;
        let mut outgoing = Bytes::from_static(PAYLOAD);
        write_packet(&mut client, &mut outgoing).await?;
        let received = read_packet(&mut server).await?;
        assert_eq!(PAYLOAD, received.chunk());
        Ok(())
    }
}