1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
use std::io::{Error, ErrorKind};
use std::net::SocketAddr;
use std::ops::Add;
use std::time::Duration;
use log::{debug, error, info, trace};
use tcp_handler::bytes::{Buf, Bytes, BytesMut};
use tcp_handler::common::{AesCipher, PacketError};
use tcp_handler::compress_encrypt::{server_init, server_start};
use tcp_handler::flate2::Compression;
use tcp_handler::variable_len_reader::asynchronous::AsyncVariableWritable;
use tcp_handler::variable_len_reader::VariableReadable;
use tokio::signal::ctrl_c;
use tokio::time::{Instant, sleep};
use tokio::net::{TcpListener, TcpStream};
use tokio::{select, spawn};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio_util::sync::CancellationToken;
use tokio_util::task::TaskTracker;
use crate::configuration::{get_addr, get_connect_sec, get_idle_sec};
use crate::Server;

/// Sends `message` on `stream` via `tcp_handler::compress_encrypt::send`.
///
/// `cipher` is moved into the library call and `level` is forwarded unchanged as the
/// compression level. On success the `AesCipher` produced by the library call is returned
/// to the caller; any `PacketError` from the library is passed through as-is.
#[inline]
pub async fn send<W: AsyncWriteExt + Unpin + Send>(stream: &mut W, message: &Bytes, cipher: AesCipher, level: Compression) -> Result<AesCipher, PacketError> {
    let next_cipher = tcp_handler::compress_encrypt::send(stream, message, cipher, level).await?;
    Ok(next_cipher)
}

/// Receives one message from `stream` via `tcp_handler::compress_encrypt::recv`.
///
/// With `timeout == None` the call waits for as long as the library call takes.
/// With `Some((peer, limit))` the receive is raced against a `limit`-long timer. If the
/// timer completes first, a `PacketError::IO` of kind `ErrorKind::TimedOut` is returned
/// whose message names `peer` and `limit`.
///
/// On success, returns the received bytes together with the `AesCipher` produced by the
/// library call.
#[inline]
pub async fn recv<R: AsyncReadExt + Unpin + Send>(stream: &mut R, cipher: AesCipher, timeout: Option<(SocketAddr, Duration)>) -> Result<(BytesMut, AesCipher), PacketError> {
    match timeout {
        None => tcp_handler::compress_encrypt::recv(stream, cipher).await,
        Some((peer, limit)) => {
            let receiving = tcp_handler::compress_encrypt::recv(stream, cipher);
            select! {
                received = receiving => received,
                _ = sleep(limit) => {
                    let reason = format!("Recv timeout: {}. timeout: {:?}", peer, limit);
                    Err(PacketError::IO(Error::new(ErrorKind::TimedOut, reason)))
                }
            }
        }
    }
}

pub(super) async fn start_server<S: Server + ?Sized + Sync>(s: &'static S, identifier: &'static str) -> anyhow::Result<()> {
    let cancel_token = CancellationToken::new();
    let canceller = cancel_token.clone();
    spawn(async move {
        if let Err(e) = ctrl_c().await {
            error!("Failed to listen for shutdown signal: {}", e);
        } else {
            canceller.cancel();
        }
    });
    let server = TcpListener::bind(get_addr()).await?;
    info!("Listening on {}.", server.local_addr()?);
    let tasks = TaskTracker::new();
    select! {
            _ = cancel_token.cancelled() => {
                info!("Shutting down the server gracefully...");
            }
            _ = async { loop {
                let (client, address) = match server.accept().await {
                    Ok(pair) => pair,
                    Err(e) => {
                        error!("Failed to accept connection: {}", e);
                        continue;
                    }
                };
                let canceller = cancel_token.clone();
                tasks.spawn(async move {
                    trace!("TCP stream connected from {}.", address);
                    if let Err(e) = handle_client(s, client, address, canceller, identifier).await {
                        error!("Failed to handle connection. address: {}, err: {}", address, e);
                    }
                    trace!("TCP stream disconnected from {}.", address);
                });
            } } => {}
        }
    tasks.close();
    tasks.wait().await;
    Ok(())
}

async fn handle_client<S: Server + ?Sized>(server: &'static S, client: TcpStream, address: SocketAddr, cancel_token: CancellationToken, identifier: &str) -> anyhow::Result<()> {
    let (mut receiver, mut sender)= client.into_split();
    let mut version = None;
    let connect_sec = get_connect_sec();
    let mut cipher = match select! {
        _ = cancel_token.cancelled() => { Err(()) },
        _ = sleep(Duration::from_secs(connect_sec)) => {
            debug!("Connection timeout: {}, {} secs.", address, connect_sec);
            Err(())
        },
        c = async {
            let init = server_init(&mut receiver, identifier, |v| {
                version = Some(v.to_string());
                server.check_version(v)
            }).await;
            server_start(&mut sender, init).await.map_err(|e| {
                trace!("Error connection client. address: {}, err: {:?}", address, e)
            })
        } => c,
    } { Ok(c) => c, Err(_) => return Ok(()), };
    let version = version.unwrap();
    debug!("Client connected from {}. version: {}", address, version);
    let mut last_time = Instant::now();
    loop {
        let idle_sec = get_idle_sec();
        let mut data = select! {
            _ = cancel_token.cancelled() => { return Ok(()); },
            d = recv(&mut receiver, cipher, Some((address, Instant::now().duration_since(last_time).add(Duration::from_secs(idle_sec))))) => match d {
                Ok((d, c)) => { cipher = c; d.reader() },
                Err(e) => { trace!("Error receiving data. address: {}, err: {:?}", address, e); return Ok(()); }
            },
        };
        if let Some(func) = server.get_function(&data.read_string()?) {
            sender.write_bool(true).await?;
            sender.flush().await?;
            cipher = func.handle(&mut receiver, &mut sender, cipher, address).await?;
            last_time = Instant::now();
        } else {
            sender.write_bool(false).await?;
            sender.flush().await?;
        }
    }
}