1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
pub mod configuration;

pub extern crate async_trait;
pub extern crate tcp_handler;
#[cfg(feature = "serde")]
pub extern crate serde;

use std::net::SocketAddr;
use std::ops::Add;
use std::time::Duration;
use anyhow::Result;
use async_trait::async_trait;
use log::{debug, error, info, trace};
use tcp_handler::bytes::{Buf, Bytes, BytesMut};
use tcp_handler::common::{AesCipher, PacketError};
use tcp_handler::compress_encrypt::{server_init, server_start};
use tcp_handler::flate2::Compression;
use tcp_handler::variable_len_reader::asynchronous::AsyncVariableWritable;
use tcp_handler::variable_len_reader::VariableReadable;
use tokio::signal::ctrl_c;
use tokio::time::{Instant, sleep};
use tokio::net::{TcpListener, TcpStream};
use tokio::{select, spawn};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio_util::sync::CancellationToken;
use tokio_util::task::TaskTracker;
use crate::configuration::{get_addr, get_connect_sec, get_idle_sec};

/// A handler for a single named remote function.
///
/// Handlers are looked up by name through [`Server::get_function`] and invoked by the connection
/// loop after the function name has been acknowledged to the client. A handler gets the read and
/// write streams of the connection plus the current [`AesCipher`], and returns the cipher the
/// connection must keep using afterwards.
#[async_trait]
pub trait FuncHandler<R: AsyncReadExt + Unpin + Send, W: AsyncWriteExt + Unpin + Send>: Send {
    /// Serves one invocation of this function over `receiver` / `sender`.
    ///
    /// # Errors
    /// Any error returned here propagates out of the connection handler and ends the connection.
    async fn handle(&self, receiver: &mut R, sender: &mut W, cipher: AesCipher) -> Result<AesCipher>;
}

#[async_trait]
pub trait Server {
    fn check_version(&self, version: &str) -> bool;

    fn get_function<R, W>(&self, func: &str) -> Option<Box<dyn FuncHandler<R, W>>>
        where R: AsyncReadExt + Unpin + Send, W: AsyncWriteExt + Unpin + Send;

    async fn start(&'static self) -> Result<()> {
        let cancel_token = CancellationToken::new();
        let canceller = cancel_token.clone();
        spawn(async move {
            if let Err(e) = ctrl_c().await {
                error!("Failed to listen for shutdown signal: {}", e);
            } else {
                canceller.cancel();
            }
        });
        let server = TcpListener::bind(get_addr()).await?;
        info!("Listening on {}.", server.local_addr()?);
        let tasks = TaskTracker::new();
        select! {
            _ = cancel_token.cancelled() => {
                info!("Shutting down the server gracefully...");
            }
            _ = async { loop {
                let (client, address) = match server.accept().await {
                    Ok(pair) => pair,
                    Err(e) => {
                        error!("Failed to accept connection: {}", e);
                        continue;
                    }
                };
                let canceller = cancel_token.clone();
                tasks.spawn(async move {
                    trace!("TCP stream connected from {}.", address);
                    if let Err(e) = handle(self, client, address, canceller).await {
                        error!("Failed to handle connection. address: {}, err: {}", address, e);
                    }
                    trace!("TCP stream disconnected from {}.", address);
                });
            } } => {}
        }
        tasks.close();
        tasks.wait().await;
        Ok(())
    }
}

/// Sends `message` over `stream` using `cipher` and the given compression `level`.
///
/// Thin wrapper over `tcp_handler::compress_encrypt::send`. On success it yields an
/// [`AesCipher`] that presumably should be used for the next packet on this connection.
/// Verify against the `tcp_handler` documentation.
#[inline]
pub async fn send<W>(stream: &mut W, message: &Bytes, cipher: AesCipher, level: Compression) -> std::result::Result<AesCipher, PacketError>
where
    W: AsyncWriteExt + Unpin + Send,
{
    use tcp_handler::compress_encrypt::send as send_packet;
    send_packet(stream, message, cipher, level).await
}

#[inline]
pub async fn recv<R: AsyncReadExt + Unpin + Send>(stream: &mut R, cipher: AesCipher, time: Duration) -> Option<std::result::Result<(BytesMut, AesCipher), PacketError>> {
    select! {
        c = tcp_handler::compress_encrypt::recv(stream, cipher) => Some(c),
        _ = sleep(time) => None,
    }
}

async fn handle(server: &'static (impl Server + ?Sized), client: TcpStream, address: SocketAddr, cancel_token: CancellationToken) -> Result<()> {
    let (mut receiver, mut sender)= client.into_split();
    let mut version = None;
    let connect_sec = get_connect_sec();
    let mut cipher = match select! {
        _ = cancel_token.cancelled() => { Err(()) },
        _ = sleep(Duration::from_secs(connect_sec)) => {
            debug!("Connection timeout: {}, {} secs.", address, connect_sec);
            Err(())
        },
        c = async {
            let init = server_init(&mut receiver, &"Wlist-server", |v| {
                version = Some(v.to_string());
                server.check_version(v)
            }).await;
            server_start(&mut sender, init).await.map_err(|e| {
                trace!("Error connection client. address: {}, err: {:?}", address, e)
            })
        } => c,
    } { Ok(c) => c, Err(_) => return Ok(()), };
    let version = version.unwrap();
    debug!("Client connected from {}. version: {}", address, version);
    let mut last_time = Instant::now();
    loop {
        let idle_sec = get_idle_sec();
        let mut data = select! {
            _ = cancel_token.cancelled() => { return Ok(()); },
            d = recv(&mut receiver, cipher, Instant::now().duration_since(last_time).add(Duration::from_secs(idle_sec))) => match d {
                Some(d) => match d {
                    Ok((d, c)) => { cipher = c; d.reader() },
                    Err(e) => { trace!("Error receiving data. address: {}, err: {:?}", address, e); return Ok(()); }
                },
                None => {
                    debug!("Read timeout: {}. duration: {} secs.", address, idle_sec);
                    return Ok(());
                }
            },
        };
        if let Some(func) = server.get_function(&data.read_string()?) {
            sender.write_bool(true).await?;
            sender.flush().await?;
            cipher = func.handle(&mut receiver, &mut sender, cipher).await?;
            last_time = Instant::now();
        } else {
            sender.write_bool(false).await?;
            sender.flush().await?;
        }
    }
}


#[cfg(test)]
mod tests {
    /// Smoke test: installs the test logger and completes successfully.
    ///
    /// Fails if a global logger has already been installed, because `try_init` errors are
    /// propagated.
    #[tokio::test]
    async fn test() -> anyhow::Result<()> {
        let mut logger = env_logger::builder();
        logger.is_test(true).try_init()?;
        Ok(())
    }
}