// tcp_handler/protocols/compress_encrypt.rs

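//! Compression and encryption protocol.
//!
//! This is the recommended protocol: every packet is deflate-compressed and
//! then encrypted with AES-256-GCM, whose key is exchanged via RSA-OAEP during the handshake.
//!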
//! # Example
//! ```rust
//! use anyhow::Result;
//! use bytes::{Buf, BufMut, BytesMut};
//! use tcp_handler::protocols::compress_encrypt::*;
//! use tokio::net::{TcpListener, TcpStream};
//! use variable_len_reader::{VariableReader, VariableWriter};
//!
//! #[tokio::main]
//! async fn main() -> Result<()> {
//!     let server = TcpListener::bind("localhost:0").await?;
//!     let mut client = TcpStream::connect(server.local_addr()?).await?;
//!     let (mut server, _) = server.accept().await?;
//!
//!     let c_init = client_init(&mut client, "test", "0").await;
//!     let s_init = server_init(&mut server, "test", |v| v == "0").await;
//!     let (s_cipher, protocol_version, client_version) = server_start(&mut server, "test", "0", s_init).await?;
//!     let c_cipher = client_start(&mut client, c_init).await?;
//!     # let _ = protocol_version;
//!     # let _ = client_version;
//!
//!     let mut writer = BytesMut::new().writer();
//!     writer.write_string("hello server.")?;
//!     let mut bytes = writer.into_inner();
//!     send(&mut client, &mut bytes, &c_cipher).await?;
//!
//!     let mut reader = recv(&mut server, &s_cipher).await?.reader();
//!     let message = reader.read_string()?;
//!     assert_eq!("hello server.", message);
//!
//!     let mut writer = BytesMut::new().writer();
//!     writer.write_string("hello client.")?;
//!     let mut bytes = writer.into_inner();
//!     send(&mut server, &mut bytes, &s_cipher).await?;
//!
//!     let mut reader = recv(&mut client, &c_cipher).await?.reader();
//!     let message = reader.read_string()?;
//!     assert_eq!("hello client.", message);
//!
//!     Ok(())
//! }
//! ```
//!
//! The send process (the prepended nonce is the one the peer must use to
//! decrypt the *next* packet):
//! ```text
//!         ┌────┬────────┬────────────┐ (It may not be in contiguous memory.)
//! in  --> │ ** │ ****** │ ********** │
//!         └────┴────────┴────────────┘
//!           └─────┐
//!          +Nonce │
//!           │     │─ Chain
//!           v     v
//!         ┌─────┬────┬────────┬────────────┐ (Zero copy. Not in contiguous memory.)
//!         │ *** │ ** │ ****** │ ********** │
//!         └─────┴────┴────────┴────────────┘
//!           │
//!           │─ DeflateEncoder
//!           v
//!         ┌─────────────────────┐ (Compressed bytes. In contiguous memory.)
//!         │ ******************* │
//!         └─────────────────────┘
//!           │
//!           │─ Encrypt in-place
//!           v
//!         ┌─────────────────────┐ (Compressed and encrypted bytes.)
//! out <-- │ ******************* │
//!         └─────────────────────┘
//! ```
//! The recv process (the stripped nonce becomes the one used to decrypt the
//! next packet):
//! ```text
//!         ┌─────────────────────┐ (Packet data.)
//! in  --> │ ******************* │
//!         └─────────────────────┘
//!           │
//!           │─ Decrypt in-place
//!           v
//!         ┌─────────────────────┐ (Decrypted bytes.)
//!         │ ******************* │
//!         └─────────────────────┘
//!           │
//!           │─ DeflateDecoder
//!           v
//!         ┌─────┬──────────────────┐ (Decrypted and decompressed bytes.)
//!         │ *** │ **************** │
//!         └─────┴──────────────────┘
//!           │     │
//!          -Nonce │
//! out <--  ───────┘
//! ```
94
95use bytes::{Buf, BufMut, BytesMut};
96use flate2::write::{DeflateDecoder, DeflateEncoder};
97use tokio::io::{AsyncRead, AsyncWrite};
98use tokio::task::block_in_place;
99use variable_len_reader::{AsyncVariableReader, AsyncVariableWriter};
100use variable_len_reader::helper::{AsyncReaderHelper, AsyncWriterHelper};
101use crate::config::get_compression;
102use crate::protocols::common::*;
103
104/// Init the client side in tcp-handler compress_encrypt protocol.
105///
106/// Must be used in conjunction with [`client_start`].
107///
108/// # Runtime
109/// Due to call [`block_in_place`] internally,
110/// this function cannot be called in a `current_thread` runtime.
111///
112/// # Arguments
113///  * `stream` - The tcp stream or `WriteHalf`.
114///  * `identifier` - The identifier of your application.
115///  * `version` - Current version of your application.
116///
117/// # Example
118/// ```rust,no_run
119/// use anyhow::Result;
120/// use tcp_handler::protocols::compress_encrypt::{client_init, client_start};
121/// use tokio::net::TcpStream;
122///
123/// #[tokio::main]
124/// async fn main() -> Result<()> {
125///     let mut client = TcpStream::connect("localhost:25564").await?;
126///     let c_init = client_init(&mut client, "test", "0").await;
127///     let cipher = client_start(&mut client, c_init).await?;
128///     // Now the client is ready to use.
129///     # let _ = cipher;
130///     Ok(())
131/// }
132/// ```
133pub async fn client_init<W: AsyncWrite + Unpin>(stream: &mut W, identifier: &str, version: &str) -> Result<rsa::RsaPrivateKey, StarterError> {
134    let (key, n, e) = block_in_place(|| generate_rsa_private())?;
135    write_head(stream, ProtocolVariant::CompressEncryption, identifier, version).await?;
136    AsyncWriterHelper(stream).help_write_u8_vec(&n).await?;
137    AsyncWriterHelper(stream).help_write_u8_vec(&e).await?;
138    flush(stream).await?;
139    Ok(key)
140}
141
142/// Init the server side in tcp-handler compress_encrypt protocol.
143///
144/// Must be used in conjunction with [`server_start`].
145///
146/// # Runtime
147/// Due to call [`block_in_place`] internally,
148/// this function cannot be called in a `current_thread` runtime.
149///
150/// # Arguments
151///  * `stream` - The tcp stream or `ReadHalf`.
152///  * `identifier` - The identifier of your application.
153///  * `version` - A prediction to determine whether the client version is allowed.
154///
155/// # Example
156/// ```rust,no_run
157/// use anyhow::Result;
158/// use tcp_handler::protocols::compress_encrypt::{server_init, server_start};
159/// use tokio::net::TcpListener;
160///
161/// #[tokio::main]
162/// async fn main() -> Result<()> {
163///     let server = TcpListener::bind("localhost:25564").await?;
164///     let (mut server, _) = server.accept().await?;
165///     let s_init = server_init(&mut server, "test", |v| v == "0").await;
166///     let (cipher, protocol_version, client_version) = server_start(&mut server, "test", "0", s_init).await?;
167///     // Now the server is ready to use.
168///     # let _ = cipher;
169///     # let _ = protocol_version;
170///     # let _ = client_version;
171///     Ok(())
172/// }
173/// ```
174pub async fn server_init<R: AsyncRead + Unpin, P: FnOnce(&str) -> bool>(stream: &mut R, identifier: &str, version: P) -> Result<((u16, String), rsa::RsaPublicKey), StarterError> {
175    let versions = read_head(stream, ProtocolVariant::CompressEncryption, identifier, version).await?;
176    let n = AsyncReaderHelper(stream).help_read_u8_vec().await?;
177    let e = AsyncReaderHelper(stream).help_read_u8_vec().await?;
178    let key = block_in_place(move || compose_rsa_public(n, e))?;
179    Ok((versions, key))
180}
181
182/// Make sure the client side is ready to use in tcp-handler compress_encrypt protocol.
183///
184/// Must be used in conjunction with [`client_init`].
185///
186/// # Runtime
187/// Due to call [`block_in_place`] internally,
188/// this function cannot be called in a `current_thread` runtime.
189///
190/// # Arguments
191///  * `stream` - The tcp stream or `ReadHalf`.
192///  * `last` - The return value of [`client_init`].
193///
194/// # Example
195/// ```rust,no_run
196/// use anyhow::Result;
197/// use tcp_handler::protocols::compress_encrypt::{client_init, client_start};
198/// use tokio::net::TcpStream;
199///
200/// #[tokio::main]
201/// async fn main() -> Result<()> {
202///     let mut client = TcpStream::connect("localhost:25564").await?;
203///     let c_init = client_init(&mut client, "test", "0").await;
204///     let cipher = client_start(&mut client, c_init).await?;
205///     // Now the client is ready to use.
206///     # let _ = cipher;
207///     Ok(())
208/// }
209/// ```
210pub async fn client_start<R: AsyncRead + Unpin>(stream: &mut R, last: Result<rsa::RsaPrivateKey, StarterError>) -> Result<Cipher, StarterError> {
211    let rsa = read_last(stream, last).await?;
212    let encrypted_aes = AsyncReaderHelper(stream).help_read_u8_vec().await?;
213    let mut nonce = [0; 12];
214    stream.read_more(&mut nonce).await?;
215    let cipher = block_in_place(move || {
216        use aes_gcm::aead::KeyInit;
217        let aes = rsa.decrypt(rsa::Oaep::new::<rsa::sha2::Sha512>(), &encrypted_aes)?;
218        let cipher = aes_gcm::Aes256Gcm::new_from_slice(&aes).unwrap();
219        Ok::<_, StarterError>((cipher, aes_gcm::Nonce::from(nonce)))
220    })?;
221    Ok(Cipher::new(cipher))
222}
223
224/// Make sure the server side is ready to use in tcp-handler compress_encrypt protocol.
225///
226/// Must be used in conjunction with [`server_init`].
227///
228/// # Runtime
229/// Due to call [`block_in_place`] internally,
230/// this function cannot be called in a `current_thread` runtime.
231///
232/// # Arguments
233///  * `stream` - The tcp stream or `WriteHalf`.
234///  * `identifier` - The returned application identifier.
235/// (Should be same with the para in [`server_init`].)
236///  * `version` - The returned recommended application version.
237/// (Should be passed the prediction in [`server_init`].)
238///  * `last` - The return value of [`server_init`].
239///
240/// # Example
241/// ```rust,no_run
242/// use anyhow::Result;
243/// use tcp_handler::protocols::compress_encrypt::{server_init, server_start};
244/// use tokio::net::TcpListener;
245///
246/// #[tokio::main]
247/// async fn main() -> Result<()> {
248///     let server = TcpListener::bind("localhost:25564").await?;
249///     let (mut server, _) = server.accept().await?;
250///     let s_init = server_init(&mut server, "test", |v| v == "0").await;
251///     let (cipher, protocol_version, client_version) = server_start(&mut server, "test", "0", s_init).await?;
252///     // Now the server is ready to use.
253///     # let _ = cipher;
254///     # let _ = protocol_version;
255///     # let _ = client_version;
256///     Ok(())
257/// }
258/// ```
259pub async fn server_start<W: AsyncWrite + Unpin>(stream: &mut W, identifier: &str, version: &str, last: Result<((u16, String), rsa::RsaPublicKey), StarterError>) -> Result<(Cipher, u16, String), StarterError> {
260    let ((va, vb), rsa) = write_last(stream, ProtocolVariant::CompressEncryption, identifier, version, last).await?;
261    let (cipher, nonce, encrypted_aes) = block_in_place(move || {
262        use aes_gcm::aead::{KeyInit, AeadCore};
263        let aes = aes_gcm::Aes256Gcm::generate_key(&mut rand::thread_rng());
264        let nonce = aes_gcm::Aes256Gcm::generate_nonce(&mut rand::thread_rng());
265        debug_assert_eq!(12, nonce.len());
266        let encrypted_aes = rsa.encrypt(&mut rand::thread_rng(), rsa::oaep::Oaep::new::<rsa::sha2::Sha512>(), &aes)?;
267        let cipher = aes_gcm::Aes256Gcm::new(&aes);
268        Ok::<_, StarterError>((cipher, nonce, encrypted_aes))
269    })?;
270    AsyncWriterHelper(stream).help_write_u8_vec(&encrypted_aes).await?;
271    stream.write_more(&nonce).await?;
272    flush(stream).await?;
273    Ok((Cipher::new((cipher, nonce)), va, vb))
274}
275
276/// Send the message in tcp-handler compress_encrypt protocol.
277///
278/// # Runtime
279/// Due to call [`block_in_place`] internally,
280/// this function cannot be called in a `current_thread` runtime.
281///
282/// # Arguments
283///  * `stream` - The tcp stream or `WriteHalf`.
284///  * `message` - The message to send.
285///  * `cipher` - The cipher returned from [`server_start`] or [`client_start`].
286///
287/// # Example
288/// ```rust,no_run
289/// # use anyhow::Result;
290/// # use bytes::{BufMut, BytesMut};
291/// # use tcp_handler::protocols::compress_encrypt::{client_init, client_start};
292/// use tcp_handler::protocols::compress_encrypt::send;
293/// # use tokio::net::TcpStream;
294/// # use variable_len_reader::VariableWriter;
295///
296/// # #[tokio::main]
297/// # async fn main() -> Result<()> {
298/// #     let mut client = TcpStream::connect("localhost:25564").await?;
299/// #     let c_init = client_init(&mut client, "test", "0").await;
300/// #     let cipher = client_start(&mut client, c_init).await?;
301/// let mut writer = BytesMut::new().writer();
302/// writer.write_string("hello server.")?;
303/// send(&mut client, &mut writer.into_inner(), &cipher).await?;
304/// #     Ok(())
305/// # }
306/// ```
307pub async fn send<W: AsyncWrite + Unpin, B: Buf>(stream: &mut W, message: &mut B, cipher: &Cipher) -> Result<(), PacketError> {
308    let level = get_compression();
309    let mut bytes = block_in_place(|| {
310        use aes_gcm::aead::{AeadCore, AeadMutInPlace};
311        use variable_len_reader::VariableWritable;
312        let new_nonce = aes_gcm::Aes256Gcm::generate_nonce(&mut rand::thread_rng());
313        debug_assert_eq!(12, new_nonce.len());
314        let mut encoder = DeflateEncoder::new(BytesMut::new().writer(), level);
315        encoder.write_more(&new_nonce)?;
316        encoder.write_more_buf(message)?;
317        let mut bytes = encoder.finish()?.into_inner();
318        let ((mut cipher, nonce), lock) = Cipher::get(cipher)?;
319        cipher.encrypt_in_place(&nonce, &[], &mut bytes)?;
320        Cipher::reset(lock, (cipher, new_nonce));
321        Ok::<_, PacketError>(bytes)
322    })?;
323    write_packet(stream, &mut bytes).await?;
324    flush(stream).await?;
325    Ok(())
326}
327
328/// Recv the message in tcp-handler compress_encrypt protocol.
329///
330/// # Runtime
331/// Due to call [`block_in_place`] internally,
332/// this function cannot be called in a `current_thread` runtime.
333///
334/// # Arguments
335///  * `stream` - The tcp stream or `ReadHalf`.
336///  * `cipher` - The cipher returned from [`server_start`] or [`client_start`].
337///
338/// # Example
339/// ```rust,no_run
340/// # use anyhow::Result;
341/// # use bytes::Buf;
342/// # use tcp_handler::protocols::compress_encrypt::{server_init, server_start};
343/// use tcp_handler::protocols::compress_encrypt::recv;
344/// # use tokio::net::TcpListener;
345/// # use variable_len_reader::VariableReader;
346///
347/// # #[tokio::main]
348/// # async fn main() -> Result<()> {
349/// #     let server = TcpListener::bind("localhost:25564").await?;
350/// #     let (mut server, _) = server.accept().await?;
351/// #     let s_init = server_init(&mut server, "test", |v| v == "0").await;
352/// #     let (cipher, _, _) = server_start(&mut server, "test", "0", s_init).await?;
353/// let mut reader = recv(&mut server, &cipher).await?.reader();
354/// let message = reader.read_string()?;
355/// #     let _ = message;
356/// #     Ok(())
357/// # }
358/// ```
359pub async fn recv<R: AsyncRead + Unpin>(stream: &mut R, cipher: &Cipher) -> Result<BytesMut, PacketError> {
360    let mut buffer = read_packet(stream).await?;
361    let message = block_in_place(move || {
362        use aes_gcm::aead::AeadMutInPlace;
363        use variable_len_reader::{VariableReadable, VariableWritable};
364        let ((mut cipher, nonce), lock) = Cipher::get(cipher)?;
365        cipher.decrypt_in_place(&nonce, &[], &mut buffer)?;
366        let mut decoder = DeflateDecoder::new(BytesMut::new().writer());
367        decoder.write_more_buf(&mut buffer)?;
368        let mut reader = decoder.finish()?.into_inner().reader();
369        let mut new_nonce = [0; 12];
370        reader.read_more(&mut new_nonce)?;
371        let new_nonce = aes_gcm::Nonce::from(new_nonce);
372        Cipher::reset(lock, (cipher, new_nonce));
373        Ok::<_, PacketError>(reader.into_inner())
374    })?;
375    Ok(message)
376}
377
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use variable_len_reader::{VariableReader, VariableWriter};
    use crate::protocols::common::tests::create;
    use crate::protocols::compress_encrypt::*;

    /// Alternating request/response exchange over several rounds.
    #[tokio::test(flavor = "multi_thread")]
    async fn connect() -> Result<()> {
        let (mut client, mut server) = create().await?;
        let c = client_init(&mut client, "a", "1").await;
        let s = server_init(&mut server, "a", |v| v == "1").await;
        let (s_cipher, _, _) = server_start(&mut server, "a", "1", s).await?;
        let c_cipher = client_start(&mut client, c).await?;
        // Repeat so that the per-packet nonce rotation is exercised many times.
        for _ in 0..10 {
            let mut writer = BytesMut::new().writer();
            writer.write_string("hello server in encrypt.")?;
            send(&mut client, &mut writer.into_inner(), &c_cipher).await?;

            let mut reader = recv(&mut server, &s_cipher).await?.reader();
            let message = reader.read_string()?;
            assert_eq!("hello server in encrypt.", message);

            let mut writer = BytesMut::new().writer();
            writer.write_string("hello client in encrypt.")?;
            send(&mut server, &mut writer.into_inner(), &s_cipher).await?;

            let mut reader = recv(&mut client, &c_cipher).await?.reader();
            let message = reader.read_string()?;
            assert_eq!("hello client in encrypt.", message);
        }
        Ok(())
    }

    /// Several packets in one direction before any reply: each packet must carry
    /// the nonce that the following packet is encrypted with.
    #[tokio::test(flavor = "multi_thread")]
    async fn consecutive() -> Result<()> {
        let (mut client, mut server) = create().await?;
        let c = client_init(&mut client, "a", "1").await;
        let s = server_init(&mut server, "a", |v| v == "1").await;
        let (s_cipher, _, _) = server_start(&mut server, "a", "1", s).await?;
        let c_cipher = client_start(&mut client, c).await?;
        for i in 0..3 {
            let mut writer = BytesMut::new().writer();
            writer.write_string(&format!("message {i}"))?;
            send(&mut client, &mut writer.into_inner(), &c_cipher).await?;
        }
        for i in 0..3 {
            let mut reader = recv(&mut server, &s_cipher).await?.reader();
            assert_eq!(format!("message {i}"), reader.read_string()?);
        }
        Ok(())
    }
}