1use std::io::ErrorKind;
4use std::net::SocketAddr;
5use std::time::Duration;
6use log::{debug, error, info, trace};
7use tcp_handler::bytes::{Buf, BufMut, BytesMut};
8use tcp_handler::common::{AesCipher, PacketError, StarterError};
9use tcp_handler::compress_encrypt::{server_init, server_start};
10use tcp_handler::flate2::Compression;
11use tcp_handler::variable_len_reader::{VariableReader, VariableWriter};
12use thiserror::Error;
13use tokio::signal::ctrl_c;
14use tokio::time::timeout;
15use tokio::net::{TcpListener, TcpStream};
16use tokio::{select, spawn};
17use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader, BufWriter};
18use tokio_util::sync::CancellationToken;
19use tokio_util::task::TaskTracker;
20use crate::config::{get_addr, get_connect_sec, get_idle_sec};
21use crate::handler_base::IOStream;
22use crate::Server;
23
/// Errors produced while serving a client connection in this module.
#[derive(Error, Debug)]
pub enum NetworkError {
    /// A network operation exceeded its deadline.
    ///
    /// Field 0 selects the wording of the message: `1` = sending, `2` = receiving,
    /// any other value = connecting (the handshake in `handle_client` uses `3`).
    /// Field 1 is the timeout that elapsed, in seconds.
    #[error("Network timeout: {} after {1} sec.", match .0 { 1 => "Sending", 2 => "Receiving", _ => "Connecting" })]
    Timeout(u8, u64),

    /// Failure while establishing the encrypted session (`server_init` / `server_start`).
    // NOTE(review): shares the "During io packet" message with `PacketError`; consider a
    // handshake-specific message so the two are distinguishable in logs.
    #[error("During io packet: {0:?}")]
    StarterError(#[from] StarterError),

    /// Failure while sending or receiving an encrypted/compressed packet.
    #[error("During io packet: {0:?}")]
    PacketError(#[from] PacketError),

    /// Plain I/O failure, e.g. while decoding/encoding fields of a packet payload.
    #[error("During read/write data: {0:?}")]
    BufError(#[from] std::io::Error),

    /// The connection's cipher state is unusable.
    // NOTE(review): not constructed in this file; presumably returned when the stored
    // cipher is missing (e.g. lost after a failed exchange) — confirm at the producer.
    #[error("Broken client.")]
    BrokenCipher(),
}
50
51#[inline]
52pub(crate) async fn send<W: AsyncWriteExt + Unpin + Send, B: Buf + Send>(stream: &mut W, message: &mut B, cipher: AesCipher, level: Compression) -> Result<AesCipher, NetworkError> {
53 let idle = get_idle_sec();
54 timeout(Duration::from_secs(idle), tcp_handler::compress_encrypt::send(stream, message, cipher, level)).await
55 .map_err(|_| NetworkError::Timeout(1, idle))?.map_err(|e| e.into())
56}
57
58#[inline]
59pub(crate) async fn recv<R: AsyncReadExt + Unpin + Send>(stream: &mut R, cipher: AesCipher) -> Result<(BytesMut, AesCipher), NetworkError> {
60 let idle = get_idle_sec();
61 timeout(Duration::from_secs(idle), tcp_handler::compress_encrypt::recv(stream, cipher)).await
62 .map_err(|_| NetworkError::Timeout(2, idle))?.map_err(|e| e.into())
63}
64
65pub(super) async fn start_server<S: Server + Sync + ?Sized>(s: &'static S) -> std::io::Result<()> {
66 let cancel_token = CancellationToken::new();
67 let canceller = cancel_token.clone();
68 spawn(async move {
69 if let Err(e) = ctrl_c().await {
70 error!("Failed to listen for shutdown signal: {}", e);
71 } else {
72 canceller.cancel();
73 }
74 });
75 let server = TcpListener::bind(get_addr()).await?;
76 info!("Listening on {}.", server.local_addr()?);
77 let tasks = TaskTracker::new();
78 select! {
79 _ = cancel_token.cancelled() => {
80 info!("Shutting down the server gracefully...");
81 }
82 _ = async { loop {
83 let (client, address) = match server.accept().await {
84 Ok(pair) => pair,
85 Err(e) => {
86 error!("Failed to accept connection: {}", e);
87 continue;
88 }
89 };
90 let canceller = cancel_token.clone();
91 tasks.spawn(async move {
92 trace!("TCP stream connected from {}.", address);
93 if let Err(e) = handle_client(s, client, address, canceller).await {
94 error!("Failed to handle connection. address: {}, err: {}", address, e);
95 }
96 trace!("TCP stream disconnected from {}.", address);
97 });
98 } } => {}
99 }
100 tasks.close();
101 tasks.wait().await;
102 Ok(())
103}
104
/// Serves one client connection until it disconnects, fails, or the server is cancelled.
///
/// 1. Handshake: `server_init` reads from the client. It is given `server.get_identifier()`
///    and a callback that records the client's version string and checks it with
///    `server.check_version`. `server_start` then writes to the client and yields the
///    session cipher. The handshake is bounded by `get_connect_sec()` seconds and aborts
///    early on cancellation.
/// 2. Request loop: each packet starts with a function name. The server replies with a
///    single bool saying whether `server.get_function` knows that name. If it does, the
///    function runs against the connection's `IOStream`, and any error it returns is
///    passed to `server.handle_error`.
///
/// Returns `Ok(())` on cancellation, or when the peer closes the connection (EOF) during
/// the handshake or while the server waits for the next request.
///
/// # Errors
/// * `NetworkError::Timeout(3, connect)` if the handshake exceeds the connect timeout.
/// * Any other handshake, packet, or I/O error.
/// * Any error returned by `server.handle_error`.
async fn handle_client<S: Server + Sync + ?Sized>(server: &S, client: TcpStream, address: SocketAddr, cancel_token: CancellationToken) -> Result<(), NetworkError> {
    let (receiver, sender)= client.into_split();
    let mut receiver = BufReader::new(receiver);
    let mut sender = BufWriter::new(sender);
    // Filled in by the version-check callback during the handshake.
    let mut version = None;
    let connect = get_connect_sec();
    let cipher = match select! {
        _ = cancel_token.cancelled() => { return Ok(()); },
        c = timeout(Duration::from_secs(connect), async {
            let init = server_init(&mut receiver, server.get_identifier(), |v| {
                version = Some(v.to_string());
                server.check_version(v)
            }).await;
            // `init` is passed through as-is; the Ok/Err outcome comes back via `server_start`.
            server_start(&mut sender, init).await
        }) => c.map_err(|_| NetworkError::Timeout(3, connect))?,
    } { Ok(cipher) => cipher, Err(e) => {
        // A client hanging up mid-handshake is treated as a normal disconnect, not an error.
        if let StarterError::IO(ref e) = e {
            if e.kind() == ErrorKind::UnexpectedEof {
                return Ok(()); }
        }
        return Err(e.into());
    } };
    // NOTE(review): assumes `server_init` always invokes the version callback before a
    // successful handshake — TODO confirm; otherwise this panics the connection task.
    let version = version.unwrap();
    debug!("Client connected from {}. version: {}", address, version);
    let mut stream = IOStream::new(receiver, sender, cipher, address, version);
    loop {
        let receiver = &mut stream.receiver;
        let sender = &mut stream.sender;
        // Take the cipher out of the shared slot; `guard` keeps the slot locked until the
        // updated cipher is stored back below. On the early-return paths the slot is left
        // without the new cipher, but the connection is being closed anyway.
        let (mut cipher, mut guard) = stream.cipher.get().await?;
        // Waiting for the next request uses the raw `recv` with no idle timeout (unlike the
        // module's `recv` wrapper), so an idle client is only dropped on EOF/error/cancel.
        let mut data = match select! {
            _ = cancel_token.cancelled() => { return Ok(()); },
            d = tcp_handler::compress_encrypt::recv(receiver, cipher) => d, } {
            Ok((d, c)) => { cipher = c; d.reader() },
            Err(e) => {
                // Peer closed the connection between requests: normal disconnect.
                if let PacketError::IO(ref e) = e {
                    if e.kind() == ErrorKind::UnexpectedEof {
                        return Ok(()); }
                }
                return Err(e.into());
            }
        };
        let func = data.read_string()?;
        let function = server.get_function(&func);
        // Tell the client whether the requested function exists before running it.
        let mut writer = BytesMut::new().writer();
        writer.write_bool(function.is_some())?;
        cipher = send(sender, &mut writer.into_inner(), cipher, Compression::fast()).await?;
        (*guard).replace(cipher);
        // Release the cipher slot so the handler can use `stream` for its own packets.
        drop(guard);
        if let Some(function) = function {
            if let Err(error) = function.handle(&mut stream).await {
                server.handle_error(&func, error, &mut stream).await?;
            }
        }
    }
}