tcp_server/
network.rs

1//! Some network utility functions.
2
3use std::io::ErrorKind;
4use std::net::SocketAddr;
5use std::time::Duration;
6use log::{debug, error, info, trace};
7use tcp_handler::bytes::{Buf, BufMut, BytesMut};
8use tcp_handler::common::{AesCipher, PacketError, StarterError};
9use tcp_handler::compress_encrypt::{server_init, server_start};
10use tcp_handler::flate2::Compression;
11use tcp_handler::variable_len_reader::{VariableReader, VariableWriter};
12use thiserror::Error;
13use tokio::signal::ctrl_c;
14use tokio::time::timeout;
15use tokio::net::{TcpListener, TcpStream};
16use tokio::{select, spawn};
17use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader, BufWriter};
18use tokio_util::sync::CancellationToken;
19use tokio_util::task::TaskTracker;
20use crate::config::{get_addr, get_connect_sec, get_idle_sec};
21use crate::handler_base::IOStream;
22use crate::Server;
23
24/// Error in send/recv message.
25#[derive(Error, Debug)]
26pub enum NetworkError {
27    /// Sending/receiving timeout. See [`tcp_server::config::get_idle_sec`].
28    #[error("Network timeout: {} after {1} sec.", match .0 { 1 => "Sending", 2 => "Receiving", _ => "Connecting" })]
29    Timeout(u8, u64),
30
31    /// During init protocol. From [`tcp_handler`][crate::tcp_handler].
32    #[error("During io packet: {0:?}")]
33    StarterError(#[from] StarterError),
34
35    /// During io packet. From [`tcp_handler`][crate::tcp_handler].
36    #[error("During io packet: {0:?}")]
37    PacketError(#[from] PacketError),
38
39    /// During read/write data from [`bytes`][crate::bytes].
40    #[error("During read/write data: {0:?}")]
41    BufError(#[from] std::io::Error),
42
43    /// Broken cipher. This is a fatal error.
44    ///
45    /// When another error returned during send/recv, the stream is broken because no [`AesCipher`] received.
46    /// In order not to panic, the stream marks as broken and this error is returned.
47    #[error("Broken client.")]
48    BrokenCipher(),
49}
50
51#[inline]
52pub(crate) async fn send<W: AsyncWriteExt + Unpin + Send, B: Buf + Send>(stream: &mut W, message: &mut B, cipher: AesCipher, level: Compression) -> Result<AesCipher, NetworkError> {
53    let idle = get_idle_sec();
54    timeout(Duration::from_secs(idle), tcp_handler::compress_encrypt::send(stream, message, cipher, level)).await
55        .map_err(|_| NetworkError::Timeout(1, idle))?.map_err(|e| e.into())
56}
57
58#[inline]
59pub(crate) async fn recv<R: AsyncReadExt + Unpin + Send>(stream: &mut R, cipher: AesCipher) -> Result<(BytesMut, AesCipher), NetworkError> {
60    let idle = get_idle_sec();
61    timeout(Duration::from_secs(idle), tcp_handler::compress_encrypt::recv(stream, cipher)).await
62        .map_err(|_| NetworkError::Timeout(2, idle))?.map_err(|e| e.into())
63}
64
65pub(super) async fn start_server<S: Server + Sync + ?Sized>(s: &'static S) -> std::io::Result<()> {
66    let cancel_token = CancellationToken::new();
67    let canceller = cancel_token.clone();
68    spawn(async move {
69        if let Err(e) = ctrl_c().await {
70            error!("Failed to listen for shutdown signal: {}", e);
71        } else {
72            canceller.cancel();
73        }
74    });
75    let server = TcpListener::bind(get_addr()).await?;
76    info!("Listening on {}.", server.local_addr()?);
77    let tasks = TaskTracker::new();
78    select! {
79        _ = cancel_token.cancelled() => {
80            info!("Shutting down the server gracefully...");
81        }
82        _ = async { loop {
83            let (client, address) = match server.accept().await {
84                Ok(pair) => pair,
85                Err(e) => {
86                    error!("Failed to accept connection: {}", e);
87                    continue;
88                }
89            };
90            let canceller = cancel_token.clone();
91            tasks.spawn(async move {
92                trace!("TCP stream connected from {}.", address);
93                if let Err(e) = handle_client(s, client, address, canceller).await {
94                    error!("Failed to handle connection. address: {}, err: {}", address, e);
95                }
96                trace!("TCP stream disconnected from {}.", address);
97            });
98        } } => {}
99    }
100    tasks.close();
101    tasks.wait().await;
102    Ok(())
103}
104
105async fn handle_client<S: Server + Sync + ?Sized>(server: &S, client: TcpStream, address: SocketAddr, cancel_token: CancellationToken) -> Result<(), NetworkError> {
106    let (receiver, sender)= client.into_split();
107    let mut receiver = BufReader::new(receiver);
108    let mut sender = BufWriter::new(sender);
109    let mut version = None;
110    let connect = get_connect_sec();
111    let cipher = match select! {
112        _ = cancel_token.cancelled() => { return Ok(()); },
113        c = timeout(Duration::from_secs(connect), async {
114            let init = server_init(&mut receiver, server.get_identifier(), |v| {
115                version = Some(v.to_string());
116                server.check_version(v)
117            }).await;
118            server_start(&mut sender, init).await
119        }) => c.map_err(|_| NetworkError::Timeout(3, connect))?,
120    } { Ok(cipher) => cipher, Err(e) => {
121        if let StarterError::IO(ref e) = e {
122            if e.kind() == ErrorKind::UnexpectedEof {
123                return Ok(()); // Ignore 'early eof'.
124            }
125        }
126        return Err(e.into());
127    } };
128    let version = version.unwrap();
129    debug!("Client connected from {}. version: {}", address, version);
130    let mut stream = IOStream::new(receiver, sender, cipher, address, version);
131    loop {
132        let receiver = &mut stream.receiver;
133        let sender = &mut stream.sender;
134        let (mut cipher, mut guard) = stream.cipher.get().await?;
135        let mut data = match select! {
136            _ = cancel_token.cancelled() => { return Ok(()); },
137            d = tcp_handler::compress_encrypt::recv(receiver, cipher) => d, // No timeout here.
138        } {
139            Ok((d, c)) => { cipher = c; d.reader() },
140            Err(e) => {
141                if let PacketError::IO(ref e) = e {
142                    if e.kind() == ErrorKind::UnexpectedEof {
143                        return Ok(()); // Ignore 'early eof'.
144                    }
145                }
146                return Err(e.into());
147            }
148        };
149        let func = data.read_string()?;
150        let function = server.get_function(&func);
151        let mut writer = BytesMut::new().writer();
152        writer.write_bool(function.is_some())?;
153        cipher = send(sender, &mut writer.into_inner(), cipher, Compression::fast()).await?;
154        (*guard).replace(cipher);
155        drop(guard);
156        if let Some(function) = function {
157            if let Err(error) = function.handle(&mut stream).await {
158                server.handle_error(&func, error, &mut stream).await?;
159            }
160        }
161    }
162}