tcp_handler/protocols/
encrypt.rs

1//! Encryption protocol. Without compression.
2//!
3//! With encryption, you can keep the data safe from being intercepted by others.
4//!
5//! # Example
6//! ```rust
7//! use anyhow::Result;
8//! use bytes::{Buf, BufMut, BytesMut};
9//! use tcp_handler::protocols::encrypt::*;
10//! use tokio::net::{TcpListener, TcpStream};
11//! use variable_len_reader::{VariableReader, VariableWriter};
12//!
13//! #[tokio::main]
14//! async fn main() -> Result<()> {
15//!     let server = TcpListener::bind("localhost:0").await?;
16//!     let mut client = TcpStream::connect(server.local_addr()?).await?;
17//!     let (mut server, _) = server.accept().await?;
18//!
19//!     let c_init = client_init(&mut client, "test", "0").await;
20//!     let s_init = server_init(&mut server, "test", |v| v == "0").await;
21//!     let (s_cipher, protocol_version, client_version) = server_start(&mut server, "test", "0", s_init).await?;
22//!     let c_cipher = client_start(&mut client, c_init).await?;
23//!     # let _ = protocol_version;
24//!     # let _ = client_version;
25//!
26//!     let mut writer = BytesMut::new().writer();
27//!     writer.write_string("hello server.")?;
28//!     let mut bytes = writer.into_inner();
29//!     send(&mut client, &mut bytes, &c_cipher).await?;
30//!
31//!     let mut reader = recv(&mut server, &s_cipher).await?.reader();
32//!     let message = reader.read_string()?;
33//!     assert_eq!("hello server.", message);
34//!
35//!     let mut writer = BytesMut::new().writer();
36//!     writer.write_string("hello client.")?;
37//!     let mut bytes = writer.into_inner();
38//!     send(&mut server, &mut bytes, &s_cipher).await?;
39//!
40//!     let mut reader = recv(&mut client, &c_cipher).await?.reader();
41//!     let message = reader.read_string()?;
42//!     assert_eq!("hello client.", message);
43//!
44//!     Ok(())
45//! }
46//! ```
47//!
48//! The send process:
49//! ```text
50//!         ┌─────┬────────┬────────────┐ (It may not be in contiguous memory.)
51//! in  --> │ *** │ ****** │ ********** │
52//!         └─────┴────────┴────────────┘
53//!           └─────┐
54//!          +Nonce │
55//!           │     │─ Copy once.
56//!           v     v
57//!         ┌─────┬─────────────────────┐ (In contiguous memory.)
58//!         │ *** │ ******************* │
59//!         └─────┴─────────────────────┘
60//!           │
61//!           │─ Encrypt in-place
62//!           v
63//!         ┌────────────────────────┐ (Encrypted bytes.)
64//! out <-- │ ********************** │
65//!         └────────────────────────┘
66//! ```
67//! The recv process:
68//! ```text
69//!         ┌────────────────────────┐ (Packet data.)
70//! in  --> │ ********************** │
71//!         └────────────────────────┘
72//!           │
73//!           │─ Decrypt in-place
74//!           v
75//!         ┌─────┬─────────────────────┐ (Decrypted bytes.)
76//!         │ *** │ ******************* │
77//!         └─────┴─────────────────────┘
78//!           │     │
79//!          -Nonce │
80//! out <--  ───────┘
81//! ```
82
83use bytes::{Buf, BufMut, BytesMut};
84use tokio::io::{AsyncRead, AsyncWrite};
85use tokio::task::block_in_place;
86use variable_len_reader::{AsyncVariableReader, AsyncVariableWriter};
87use variable_len_reader::helper::{AsyncReaderHelper, AsyncWriterHelper};
88use crate::protocols::common::*;
89
90/// Init the client side in tcp-handler encrypt protocol.
91///
92/// Must be used in conjunction with [`client_start`].
93///
94/// # Runtime
/// Because it calls [`block_in_place`] internally,
/// this function cannot be called in a `current_thread` runtime.
97///
98/// # Arguments
99///  * `stream` - The tcp stream or `WriteHalf`.
100///  * `identifier` - The identifier of your application.
101///  * `version` - Current version of your application.
102///
103/// # Example
104/// ```rust,no_run
105/// use anyhow::Result;
106/// use tcp_handler::protocols::encrypt::{client_init, client_start};
107/// use tokio::net::TcpStream;
108///
109/// #[tokio::main]
110/// async fn main() -> Result<()> {
111///     let mut client = TcpStream::connect("localhost:25564").await?;
112///     let c_init = client_init(&mut client, "test", "0").await;
113///     let cipher = client_start(&mut client, c_init).await?;
114///     // Now the client is ready to use.
115///     # let _ = cipher;
116///     Ok(())
117/// }
118/// ```
119pub async fn client_init<W: AsyncWrite + Unpin>(stream: &mut W, identifier: &str, version: &str) -> Result<rsa::RsaPrivateKey, StarterError> {
120    let (key, n, e) = block_in_place(|| generate_rsa_private())?;
121    write_head(stream, ProtocolVariant::Encryption, identifier, version).await?;
122    AsyncWriterHelper(stream).help_write_u8_vec(&n).await?;
123    AsyncWriterHelper(stream).help_write_u8_vec(&e).await?;
124    flush(stream).await?;
125    Ok(key)
126}
127
128/// Init the server side in tcp-handler encrypt protocol.
129///
130/// Must be used in conjunction with [`server_start`].
131///
132/// # Runtime
/// Because it calls [`block_in_place`] internally,
/// this function cannot be called in a `current_thread` runtime.
135///
136/// # Arguments
137///  * `stream` - The tcp stream or `ReadHalf`.
138///  * `identifier` - The identifier of your application.
///  * `version` - A predicate that decides whether the client version is allowed.
140///
141/// # Example
142/// ```rust,no_run
143/// use anyhow::Result;
144/// use tcp_handler::protocols::encrypt::{server_init, server_start};
145/// use tokio::net::TcpListener;
146///
147/// #[tokio::main]
148/// async fn main() -> Result<()> {
149///     let server = TcpListener::bind("localhost:25564").await?;
150///     let (mut server, _) = server.accept().await?;
151///     let s_init = server_init(&mut server, "test", |v| v == "0").await;
152///     let (cipher, protocol_version, client_version) = server_start(&mut server, "test", "0", s_init).await?;
153///     // Now the server is ready to use.
154///     # let _ = cipher;
155///     # let _ = protocol_version;
156///     # let _ = client_version;
157///     Ok(())
158/// }
159/// ```
160pub async fn server_init<R: AsyncRead + Unpin, P: FnOnce(&str) -> bool>(stream: &mut R, identifier: &str, version: P) -> Result<((u16, String), rsa::RsaPublicKey), StarterError> {
161    let versions = read_head(stream, ProtocolVariant::Encryption, identifier, version).await?;
162    let n = AsyncReaderHelper(stream).help_read_u8_vec().await?;
163    let e = AsyncReaderHelper(stream).help_read_u8_vec().await?;
164    let key = block_in_place(move || compose_rsa_public(n, e))?;
165    Ok((versions, key))
166}
167
168/// Make sure the client side is ready to use in tcp-handler encrypt protocol.
169///
170/// Must be used in conjunction with [`client_init`].
171///
172/// # Runtime
/// Because it calls [`block_in_place`] internally,
/// this function cannot be called in a `current_thread` runtime.
175///
176/// # Arguments
177///  * `stream` - The tcp stream or `ReadHalf`.
178///  * `last` - The return value of [`client_init`].
179///
180/// # Example
181/// ```rust,no_run
182/// use anyhow::Result;
183/// use tcp_handler::protocols::encrypt::{client_init, client_start};
184/// use tokio::net::TcpStream;
185///
186/// #[tokio::main]
187/// async fn main() -> Result<()> {
188///     let mut client = TcpStream::connect("localhost:25564").await?;
189///     let c_init = client_init(&mut client, "test", "0").await;
190///     let cipher = client_start(&mut client, c_init).await?;
191///     // Now the client and cipher are ready to use.
192///     # let _ = cipher;
193///     Ok(())
194/// }
195/// ```
196pub async fn client_start<R: AsyncRead + Unpin>(stream: &mut R, last: Result<rsa::RsaPrivateKey, StarterError>) -> Result<Cipher, StarterError> {
197    let rsa = read_last(stream, last).await?;
198    let encrypted_aes = AsyncReaderHelper(stream).help_read_u8_vec().await?;
199    let mut nonce = [0; 12];
200    stream.read_more(&mut nonce).await?;
201    let cipher = block_in_place(move || {
202        use aes_gcm::aead::KeyInit;
203        let aes = rsa.decrypt(rsa::Oaep::new::<rsa::sha2::Sha512>(), &encrypted_aes)?;
204        let cipher = aes_gcm::Aes256Gcm::new_from_slice(&aes).unwrap();
205        Ok::<_, StarterError>((cipher, aes_gcm::Nonce::from(nonce)))
206    })?;
207    Ok(Cipher::new(cipher))
208}
209
210/// Make sure the server side is ready to use in tcp-handler encrypt protocol.
211///
212/// Must be used in conjunction with [`server_init`].
213///
214/// # Runtime
/// Because it calls [`block_in_place`] internally,
/// this function cannot be called in a `current_thread` runtime.
217///
218/// # Arguments
219///  * `stream` - The tcp stream or `WriteHalf`.
///  * `identifier` - The application identifier sent back to the client.
///    (Should be the same as the `identifier` passed to [`server_init`].)
///  * `version` - The recommended application version sent back to the client.
///    (It should satisfy the `version` predicate passed to [`server_init`].)
224///  * `last` - The return value of [`server_init`].
225///
226/// # Example
227/// ```rust,no_run
228/// use anyhow::Result;
229/// use tcp_handler::protocols::encrypt::{server_init, server_start};
230/// use tokio::net::TcpListener;
231///
232/// #[tokio::main]
233/// async fn main() -> Result<()> {
234///     let server = TcpListener::bind("localhost:25564").await?;
235///     let (mut server, _) = server.accept().await?;
236///     let s_init = server_init(&mut server, "test", |v| v == "0").await;
237///     let (cipher, protocol_version, client_version) = server_start(&mut server, "test", "0", s_init).await?;
238///     // Now the server is ready to use.
239///     # let _ = cipher;
240///     # let _ = protocol_version;
241///     # let _ = client_version;
242///     Ok(())
243/// }
244/// ```
245pub async fn server_start<W: AsyncWrite + Unpin>(stream: &mut W, identifier: &str, version: &str, last: Result<((u16, String), rsa::RsaPublicKey), StarterError>) -> Result<(Cipher, u16, String), StarterError> {
246    let ((va, vb), rsa) = write_last(stream, ProtocolVariant::Encryption, identifier, version, last).await?;
247    let (cipher, nonce, encrypted_aes) = block_in_place(move || {
248        use aes_gcm::aead::{KeyInit, AeadCore};
249        let aes = aes_gcm::Aes256Gcm::generate_key(&mut rand::thread_rng());
250        let nonce = aes_gcm::Aes256Gcm::generate_nonce(&mut rand::thread_rng());
251        debug_assert_eq!(12, nonce.len());
252        let encrypted_aes = rsa.encrypt(&mut rand::thread_rng(), rsa::oaep::Oaep::new::<rsa::sha2::Sha512>(), &aes)?;
253        let cipher = aes_gcm::Aes256Gcm::new(&aes);
254        Ok::<_, StarterError>((cipher, nonce, encrypted_aes))
255    })?;
256    AsyncWriterHelper(stream).help_write_u8_vec(&encrypted_aes).await?;
257    stream.write_more(&nonce).await?;
258    flush(stream).await?;
259    Ok((Cipher::new((cipher, nonce)), va, vb))
260}
261
262/// Send the message in tcp-handler encrypt protocol.
263///
264/// # Runtime
/// Because it calls [`block_in_place`] internally,
/// this function cannot be called in a `current_thread` runtime.
267///
268/// # Arguments
269///  * `stream` - The tcp stream or `WriteHalf`.
270///  * `message` - The message to send.
271///  * `cipher` - The cipher returned from [`server_start`] or [`client_start`].
272///
273/// # Example
274/// ```rust,no_run
275/// # use anyhow::Result;
276/// # use bytes::{BufMut, BytesMut};
277/// # use tcp_handler::protocols::encrypt::{client_init, client_start};
278/// use tcp_handler::protocols::encrypt::send;
279/// # use tokio::net::TcpStream;
280/// # use variable_len_reader::VariableWriter;
281///
282/// # #[tokio::main]
283/// # async fn main() -> Result<()> {
284/// #     let mut client = TcpStream::connect("localhost:25564").await?;
285/// #     let c_init = client_init(&mut client, "test", "0").await;
286/// #     let cipher = client_start(&mut client, c_init).await?;
287/// let mut writer = BytesMut::new().writer();
288/// writer.write_string("hello server.")?;
289/// send(&mut client, &mut writer.into_inner(), &cipher).await?;
290/// #     Ok(())
291/// # }
292/// ```
293pub async fn send<W: AsyncWrite + Unpin, B: Buf>(stream: &mut W, message: &mut B, cipher: &Cipher) -> Result<(), PacketError> {
294    let mut bytes = block_in_place(|| {
295        use aes_gcm::aead::{AeadCore, AeadMutInPlace};
296        use aes_gcm::aes::cipher::Unsigned;
297        use variable_len_reader::VariableWritable;
298        let new_nonce = aes_gcm::Aes256Gcm::generate_nonce(&mut rand::thread_rng());
299        debug_assert_eq!(12, new_nonce.len());
300        let mut bytes = BytesMut::with_capacity(12 + message.remaining() + aes_gcm::aead::consts::U12::to_usize());
301        bytes.extend_from_slice(&new_nonce);
302        let mut writer = bytes.writer();
303        writer.write_more_buf(message)?;
304        let mut bytes = writer.into_inner();
305        let ((mut cipher, nonce), lock) = Cipher::get(cipher)?;
306        cipher.encrypt_in_place(&nonce, &[], &mut bytes)?;
307        Cipher::reset(lock, (cipher, new_nonce));
308        Ok::<_, PacketError>(bytes)
309    })?;
310    write_packet(stream, &mut bytes).await?;
311    flush(stream).await?;
312    Ok(())
313}
314
315/// Recv the message in tcp-handler encrypt protocol.
316///
317/// # Runtime
/// Because it calls [`block_in_place`] internally,
/// this function cannot be called in a `current_thread` runtime.
320///
321/// # Arguments
322///  * `stream` - The tcp stream or `ReadHalf`.
323///  * `cipher` - The cipher returned from [`server_start`] or [`client_start`].
324///
325/// # Example
326/// ```rust,no_run
327/// # use anyhow::Result;
328/// # use bytes::Buf;
329/// # use tcp_handler::protocols::encrypt::{server_init, server_start};
330/// use tcp_handler::protocols::encrypt::recv;
331/// # use tokio::net::TcpListener;
332/// # use variable_len_reader::VariableReader;
333///
334/// # #[tokio::main]
335/// # async fn main() -> Result<()> {
336/// #     let server = TcpListener::bind("localhost:25564").await?;
337/// #     let (mut server, _) = server.accept().await?;
338/// #     let s_init = server_init(&mut server, "test", |v| v == "0").await;
339/// #     let (cipher, _, _) = server_start(&mut server, "test", "0", s_init).await?;
340/// let mut reader = recv(&mut server, &cipher).await?.reader();
341/// let message = reader.read_string()?;
342/// #     let _ = message;
343/// #     Ok(())
344/// # }
345/// ```
346pub async fn recv<R: AsyncRead + Unpin>(stream: &mut R, cipher: &Cipher) -> Result<BytesMut, PacketError> {
347    let mut bytes = read_packet(stream).await?;
348    let message = block_in_place(move || {
349        use aes_gcm::aead::AeadMutInPlace;
350        use variable_len_reader::VariableReadable;
351        let ((mut cipher, nonce), lock) = Cipher::get(cipher)?;
352        cipher.decrypt_in_place(&nonce, &[], &mut bytes)?;
353        let mut reader = bytes.reader();
354        let mut new_nonce = [0; 12];
355        reader.read_more(&mut new_nonce)?;
356        let new_nonce = aes_gcm::Nonce::from(new_nonce);
357        Cipher::reset(lock, (cipher, new_nonce));
358        Ok::<_, PacketError>(reader.into_inner())
359    })?;
360    Ok(message)
361}
362
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use variable_len_reader::{VariableReader, VariableWriter};
    use crate::protocols::common::tests::create;
    use crate::protocols::encrypt::*;

    /// Performs one handshake, then exchanges messages in both directions for
    /// several rounds so the per-message nonce chain is exercised repeatedly.
    ///
    /// Uses a `multi_thread` runtime because the protocol functions call
    /// `block_in_place`, which cannot run on a `current_thread` runtime.
    #[tokio::test(flavor = "multi_thread")]
    async fn connect() -> Result<()> {
        let (mut client, mut server) = create().await?;
        let c = client_init(&mut client, "a", "1").await;
        let s = server_init(&mut server, "a", |v| v == "1").await;
        let (s_cipher, _, _) = server_start(&mut server, "a", "1", s).await?;
        let c_cipher = client_start(&mut client, c).await?;

        for _round in 0..10 {
            // client -> server
            let mut outgoing = BytesMut::new().writer();
            outgoing.write_string("hello server in encrypt.")?;
            send(&mut client, &mut outgoing.into_inner(), &c_cipher).await?;

            let mut incoming = recv(&mut server, &s_cipher).await?.reader();
            let text = incoming.read_string()?;
            assert_eq!("hello server in encrypt.", text);

            // server -> client
            let mut outgoing = BytesMut::new().writer();
            outgoing.write_string("hello client in encrypt.")?;
            send(&mut server, &mut outgoing.into_inner(), &s_cipher).await?;

            let mut incoming = recv(&mut client, &c_cipher).await?.reader();
            let text = incoming.read_string()?;
            assert_eq!("hello client in encrypt.", text);
        }
        Ok(())
    }
}