tcp_handler/protocols/common.rs

1//! Common utilities for protocols.
2
3use std::io::Error;
4use bytes::{Buf, BufMut, BytesMut};
5use thiserror::Error;
6use tokio::io::{AsyncRead, AsyncWrite};
7use variable_len_reader::{AsyncVariableReader, AsyncVariableWriter};
8use variable_len_reader::helper::{AsyncReaderHelper, AsyncWriterHelper};
9use crate::config::get_max_packet_size;
10
11/// Error when send/recv packets.
12#[derive(Error, Debug)]
13pub enum PacketError {
14    /// The packet size is larger than the maximum allowed packet size.
15    /// This is due to you sending too much data at once,
16    /// resulting in triggering memory safety limit.
17    ///
18    /// You can reduce the size of data packet sent each time.
19    /// Or you can change the maximum packet size by call [tcp_handler::config::set_config].
20    #[error("Packet size {0} is larger than the maximum allowed packet size {1}.")]
21    TooLarge(usize, usize),
22
23    /// During io bytes.
24    #[error("During io bytes.")]
25    IO(#[from] Error),
26
27    /// During encrypting/decrypting bytes.
28    #[cfg(feature = "encryption")]
29    #[cfg_attr(docsrs, doc(cfg(feature = "encryption")))]
30    #[error("During encrypting/decrypting bytes.")]
31    AES(#[from] aes_gcm::aead::Error),
32
33    /// Broken stream cipher. This is a fatal error.
34    ///
35    /// When another error returned during send/recv, the stream is broken because no [`Cipher`] received.
36    /// In order not to panic, marks this stream as broken and returns this error.
37    #[cfg(feature = "encryption")]
38    #[cfg_attr(docsrs, doc(cfg(feature = "encryption")))]
39    #[error("Broken stream.")]
40    Broken(),
41}
42
43/// Error when init/start protocol.
44#[derive(Error, Debug)]
45pub enum StarterError {
46    /// [`MAGIC_BYTES`] isn't matched. Or the [`MAGIC_VERSION`] is no longer supported.
47    /// Please confirm that you are connected to the correct address.
48    #[error("Invalid stream. MAGIC is not matched.")]
49    InvalidStream(),
50
51    /// Incompatible tcp-handler protocol.
52    /// The param came from the other side.
53    /// Please check whether you use the same protocol between client and server.
54    #[error("Incompatible protocol. received protocol: {0:?}")]
55    InvalidProtocol(ProtocolVariant),
56
57    /// Invalid application identifier.
58    /// The param came from the other side.
59    /// Please confirm that you are connected to the correct application,
60    /// or that there are no spelling errors in the server and client identifiers.
61    #[error("Invalid identifier. received: {0}")]
62    InvalidIdentifier(String),
63
64    /// Invalid application version.
65    /// The param came from the other side.
66    /// This is usually caused by the low version of the client application.
67    #[error("Invalid version. received: {0}")]
68    InvalidVersion(String),
69
70    /// During io bytes.
71    #[error("During io bytes.")]
72    IO(#[from] Error),
73
74    /// During generating/encrypting/decrypting rsa key.
75    #[cfg(feature = "encryption")]
76    #[cfg_attr(docsrs, doc(cfg(feature = "encryption")))]
77    #[error("During generating/encrypting/decrypting rsa key.")]
78    RSA(#[from] rsa::Error),
79}
80
81
/// Four-byte signature every tcp-handler stream starts with.
///
/// The values were produced in a `jshell` session:
/// ```java
/// var r = new Random("tcp-handler".hashCode());
/// r.nextInt(0, 255); r.nextInt(0, 255);
/// r.nextInt(0, 255); r.nextInt(0, 255);
/// ```
const MAGIC_BYTES: [u8; 4] = [208, 8, 166, 104];
89
/// Wire-protocol version of tcp-handler, sent right after the magic bytes.
///
/// | crate version | protocol version |
/// |---------------|------------------|
/// | \>=0.6.0      | 1                |
/// | <0.6.0        | 0                |
const MAGIC_VERSION: u16 = 1;
97
/// The variants of the protocol.
///
/// Client and server must use the same variant; a mismatch during the handshake is
/// reported as [`StarterError::InvalidProtocol`].
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ProtocolVariant {
    /// See [crate::raw].
    Raw,
    /// See [crate::compress].
    Compression,
    /// See [crate::encrypt].
    Encryption,
    /// See [crate::compress_encrypt].
    CompressEncryption,
}
111
112impl From<[bool; 2]> for ProtocolVariant {
113    fn from(value: [bool; 2]) -> Self {
114        match value {
115            [false, false] => ProtocolVariant::Raw,
116            [false, true] => ProtocolVariant::Compression,
117            [true, false] => ProtocolVariant::Encryption,
118            [true, true] => ProtocolVariant::CompressEncryption,
119        }
120    }
121}
122
123impl From<ProtocolVariant> for [bool; 2] {
124    fn from(value: ProtocolVariant) -> Self {
125        match value {
126            ProtocolVariant::Raw => [false, false],
127            ProtocolVariant::Compression => [false, true],
128            ProtocolVariant::Encryption => [true, false],
129            ProtocolVariant::CompressEncryption => [true, true],
130        }
131    }
132}
133
134
135/// In client side.
136/// ```text
137///   ┌─ Magic bytes
138///   │     ┌─ Magic version
139///   │     │    ┌─ Protocol variant
140///   │     │    │    ┌─ Application identifier
141///   │     │    │    │       ┌─ Application version
142///   v     v    v    v       v
143/// ┌─────┬────┬────┬───────┬───────┐
144/// │ *** │ ** │ ** │ ***** │ ***** │
145/// └─────┴────┴────┴───────┴───────┘
146/// ```
147pub(crate) async fn write_head<W: AsyncWrite + Unpin>(stream: &mut W, protocol: ProtocolVariant, identifier: &str, version: &str) -> Result<(), StarterError> {
148    stream.write_more(&MAGIC_BYTES).await?;
149    stream.write_u16_raw_be(MAGIC_VERSION).await?;
150    stream.write_bools_2(protocol.into()).await?;
151    AsyncWriterHelper(stream).help_write_string(identifier).await?;
152    AsyncWriterHelper(stream).help_write_string(version).await?;
153    Ok(())
154}
155
156/// In server side.
157/// See [`write_head`].
158pub(crate) async fn read_head<R: AsyncRead + Unpin, P: FnOnce(&str) -> bool>(stream: &mut R, protocol: ProtocolVariant, identifier: &str, version: P) -> Result<(u16, String), StarterError> {
159    let mut magic = [0; 4];
160    stream.read_more(&mut magic).await?;
161    if magic != MAGIC_BYTES { return Err(StarterError::InvalidStream()); }
162    let protocol_version = stream.read_u16_raw_be().await?;
163    if protocol_version != MAGIC_VERSION { return Err(StarterError::InvalidStream()); }
164    let protocol_read = stream.read_bools_2().await?.into();
165    if protocol_read != protocol { return Err(StarterError::InvalidProtocol(protocol_read)); }
166    let identifier_read = AsyncReaderHelper(stream).help_read_string().await?;
167    if identifier_read != identifier { return Err(StarterError::InvalidIdentifier(identifier_read)); }
168    let version_read = AsyncReaderHelper(stream).help_read_string().await?;
169    if !version(&version_read) { return Err(StarterError::InvalidVersion(version_read)); }
170    Ok((protocol_version, version_read))
171}
172
173/// In server side.
174/// ```text
175///   ┌─ State bytes
176///   │   ┌─ Error information.
177///   v   v
178/// ┌───┬───────┐
179/// │ * │ ***** │
180/// └───┴───────┘
181/// ```
182pub(crate) async fn write_last<W: AsyncWrite + Unpin, E>(stream: &mut W, protocol: ProtocolVariant, identifier: &str, version: &str, last: Result<E, StarterError>) -> Result<E, StarterError> {
183    match last {
184        Err(e) => {
185            match &e {
186                StarterError::InvalidProtocol(_) => {
187                    stream.write_bools_2([false, false]).await?;
188                    stream.write_bools_2(protocol.into()).await?;
189                }
190                StarterError::InvalidIdentifier(_) => {
191                    stream.write_bools_2([false, true]).await?;
192                    AsyncWriterHelper(stream).help_write_string(identifier).await?;
193                }
194                StarterError::InvalidVersion(_) => {
195                    stream.write_bools_2([true, false]).await?;
196                    AsyncWriterHelper(stream).help_write_string(version).await?;
197                }
198                _ => {}
199            }
200            return Err(e);
201        },
202        Ok(k) => {
203            stream.write_bools_2([true, true]).await?;
204            Ok(k)
205        }
206    }
207}
208
209/// In client side.
210/// See [`write_last`].
211pub(crate) async fn read_last<R: AsyncRead + Unpin, E>(stream: &mut R, last: Result<E, StarterError>) -> Result<E, StarterError> {
212    let extra = last?;
213    match stream.read_bools_2().await? {
214        [true, true] => Ok(extra),
215        [false, false] => Err(StarterError::InvalidProtocol(stream.read_bools_2().await?.into())),
216        [false, true] => Err(StarterError::InvalidIdentifier(AsyncReaderHelper(stream).help_read_string().await?)),
217        [true, false] => Err(StarterError::InvalidVersion(AsyncReaderHelper(stream).help_read_string().await?)),
218    }
219}
220
221
222#[inline]
223fn check_bytes_len(len: usize) -> Result<(), PacketError> {
224    let config = get_max_packet_size();
225    if len > config { Err(PacketError::TooLarge(len, config)) } else { Ok(()) }
226}
227
228/// ```text
229///   ┌─ Packet length (in varint)
230///   │    ┌─ Packet message
231///   v    v
232/// ┌────┬────────┐
233/// │ ** │ ****** │
234/// └────┴────────┘
235/// ```
236pub(crate) async fn write_packet<W: AsyncWrite + Unpin, B: Buf>(stream: &mut W, bytes: &mut B) -> Result<(), PacketError> {
237    check_bytes_len(bytes.remaining())?;
238    stream.write_usize_varint_ap(bytes.remaining()).await?;
239    stream.write_more_buf(bytes).await?;
240    Ok(())
241}
242
243/// See [`write_packet`].
244pub(crate) async fn read_packet<R: AsyncRead + Unpin>(stream: &mut R) -> Result<BytesMut, PacketError> {
245    let len = stream.read_usize_varint_ap().await?;
246    check_bytes_len(len)?;
247    let mut buf = BytesMut::with_capacity(len).limit(len);
248    stream.read_more_buf(&mut buf).await?;
249    Ok(buf.into_inner())
250}
251
252/// Flush if the `auto_flush` feature is enabled.
253#[inline]
254pub(crate) async fn flush<W: AsyncWrite + Unpin>(stream: &mut W) -> Result<(), std::io::Error> {
255    #[cfg(feature = "auto_flush")] {
256        use tokio::io::AsyncWriteExt;
257        stream.flush().await
258    }
259    #[cfg(not(feature = "auto_flush"))] {
260        let _ = stream;
261        Ok(())
262    }
263}
264
265
/// Generates a fresh 2048-bit RSA private key.
///
/// Returns the key plus its public modulus `n` and exponent `e`, each as little-endian bytes.
#[cfg(feature = "encryption")]
pub(crate) fn generate_rsa_private() -> Result<(rsa::RsaPrivateKey, Vec<u8>, Vec<u8>), StarterError> {
    use rsa::traits::PublicKeyParts;
    let mut rng = rand::thread_rng();
    let private_key = rsa::RsaPrivateKey::new(&mut rng, 2048)?;
    let (modulus, exponent) = (private_key.n().to_bytes_le(), private_key.e().to_bytes_le());
    Ok((private_key, modulus, exponent))
}
274
/// Rebuilds an RSA public key from the little-endian modulus `n` and exponent `e` bytes.
#[cfg(feature = "encryption")]
pub(crate) fn compose_rsa_public(n: Vec<u8>, e: Vec<u8>) -> Result<rsa::RsaPublicKey, StarterError> {
    let modulus = rsa::BigUint::from_bytes_le(&n);
    let exponent = rsa::BigUint::from_bytes_le(&e);
    rsa::RsaPublicKey::new(modulus, exponent).map_err(StarterError::from)
}
281
/// Encryption-mode state: the AES-256-GCM cipher paired with its 12-byte (`U12`) nonce.
/// You **must** store the updated value after each call to the send/recv function.
#[cfg(feature = "encryption")]
pub(crate) type InnerAesCipher = (
    aes_gcm::Aes256Gcm,
    aes_gcm::Nonce<aes_gcm::aead::consts::U12>,
);
286
/// The cipher in encryption mode.
#[cfg(feature = "encryption")]
#[cfg_attr(docsrs, doc(cfg(feature = "encryption")))]
pub struct Cipher {
    // `Some` while the state is available; it is taken out during use and put back afterwards.
    // `None` when nothing was put back, which the accessors treat as a broken stream.
    cipher: std::sync::Mutex<Option<InnerAesCipher>>,
}
293
#[cfg(feature = "encryption")]
impl std::fmt::Debug for Cipher {
    /// Never prints key material; only reports whether the state is available, in use, or broken.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `try_lock` fails both when the lock is held and when it is poisoned; both show as "<locked>".
        let state = match self.cipher.try_lock() {
            Ok(slot) => if slot.is_some() { "<unlocked>" } else { "<broken>" },
            Err(_) => "<locked>",
        };
        f.debug_struct("Cipher").field("cipher", &state).finish()
    }
}
304
#[cfg(feature = "encryption")]
impl Cipher {
    /// Wraps a cipher/nonce pair, starting in the available (non-broken) state.
    #[inline]
    pub(crate) fn new(cipher: InnerAesCipher) -> Self {
        Self {
            cipher: std::sync::Mutex::new(Some(cipher)),
        }
    }

    /// Takes the cipher out of its slot and returns it together with the still-held lock guard.
    ///
    /// The slot stays empty until [`Cipher::reset`] puts a cipher back. If the caller bails out
    /// before that, later calls fail with [`PacketError::Broken`].
    ///
    /// # Errors
    /// [`PacketError::Broken`] if the slot is empty, or if the mutex was poisoned by a panic
    /// while another caller held the guard.
    #[inline]
    pub(crate) fn get(&self) -> Result<(InnerAesCipher, std::sync::MutexGuard<'_, Option<InnerAesCipher>>), PacketError> {
        // A poisoned lock means an operation panicked while holding the cipher. Report the stream
        // as broken instead of panicking again here, as `PacketError::Broken` promises.
        let mut guard = self.cipher.lock().map_err(|_| PacketError::Broken())?;
        let cipher = guard.take().ok_or(PacketError::Broken())?;
        Ok((cipher, guard))
    }

    /// Stores the (updated) cipher back into the slot; consuming `guard` releases the lock.
    #[inline]
    pub(crate) fn reset(mut guard: std::sync::MutexGuard<'_, Option<InnerAesCipher>>, cipher: InnerAesCipher) {
        *guard = Some(cipher);
    }
}
326
327
#[cfg(test)]
pub(super) mod tests {
    use anyhow::Result;
    use bytes::{Buf, Bytes};
    use tokio::io::{AsyncRead, AsyncWrite, duplex};
    use crate::protocols::common::{read_packet, write_packet};

    /// Creates an in-memory client/server pipe (1024-byte buffer) for protocol tests.
    pub(crate) async fn create() -> Result<(impl AsyncRead + AsyncWrite + Unpin, impl AsyncRead + AsyncWrite + Unpin)> {
        Ok(duplex(1024))
    }

    /// A packet written on one end is read back unchanged on the other.
    #[tokio::test]
    async fn packet() -> Result<()> {
        let (mut writer_side, mut reader_side) = create().await?;

        let payload: &'static [u8] = &[1, 2, 3, 4, 5];
        let mut outgoing = Bytes::from_static(payload);
        write_packet(&mut writer_side, &mut outgoing).await?;
        let incoming = read_packet(&mut reader_side).await?;
        assert_eq!(payload, incoming.chunk());

        Ok(())
    }
}