// tcp_server/lib.rs

1#![doc = include_str!("../README.md")]
2#![warn(missing_docs)]
3#![forbid(unsafe_code)]
4#![cfg_attr(docsrs, feature(doc_cfg))]
5
6pub mod config;
7pub mod network;
8mod mutable_cipher;
9pub mod handler_base;
10
11pub extern crate anyhow;
12pub extern crate async_trait;
13pub extern crate tokio;
14pub extern crate tcp_handler;
15
16use anyhow::Error;
17use async_trait::async_trait;
18use log::error;
19use tokio::io::{AsyncReadExt, AsyncWriteExt};
20use crate::handler_base::{FuncHandler, IOStream};
21use crate::network::{NetworkError, start_server};
22
23/// The basic trait for a server.
24/// # Example
25/// ```rust,no_run
26/// use tcp_server::{func_handler, Server};
27/// use tcp_server::anyhow::anyhow;
28/// use tcp_server::handler_base::FuncHandler;
29/// use tcp_server::tcp_handler::bytes::{Buf, BufMut, BytesMut};
30/// use tcp_server::tcp_handler::variable_len_reader::{VariableReader, VariableWriter};
31/// use tcp_server::tokio::io::{AsyncReadExt, AsyncWriteExt};
32///
33/// struct MyServer;
34///
35/// impl Server for MyServer {
36///     fn get_identifier(&self) -> &'static str {
37///         "MyTcpApplication"
38///     }
39///
40///     fn check_version(&self, version: &str) -> bool {
41///         version == env!("CARGO_PKG_VERSION")
42///     }
43///
44///     fn get_function<R, W>(&self, func: &str) -> Option<Box<dyn FuncHandler<R, W>>>
45///         where R: AsyncReadExt + Unpin + Send, W: AsyncWriteExt + Unpin + Send {
46///         match func {
47///             "hello" => Some(Box::new(HelloHandler)),
48///             _ => None,
49///         }
50///     }
51/// }
52///
53/// func_handler!(HelloHandler, |stream| {
54///     let mut reader = stream.recv().await?.reader();
55///     if "hello server" != reader.read_string()? {
56///         return Err(anyhow!("Invalid message."));
57///     }
58///     let mut writer = BytesMut::new().writer();
59///     writer.write_string("hello client")?;
60///     stream.send(&mut writer.into_inner()).await?;
61///     Ok(())
62/// });
63///
64/// #[tokio::main]
65/// async fn main() {
66///     MyServer.start().await.unwrap();
67/// }
68/// ```
69#[async_trait]
70pub trait Server {
71    /// Get the identifier of your application.
72    /// # Note
73    /// This should be a const.
74    fn get_identifier(&self) -> &'static str;
75
76    /// Check the version of the client.
77    /// You can reject the client if the version is not supported.
78    /// # Note
79    /// This should be no side effect.
80    /// # Example
81    /// ```rust,ignore
82    /// version == env!("CARGO_PKG_VERSION")
83    /// ```
84    fn check_version(&self, version: &str) -> bool;
85
86    /// Return the function handler. See [`Server`] for example.
87    /// # Note
88    /// This should be no side effect.
89    fn get_function<R, W>(&self, func: &str) -> Option<Box<dyn FuncHandler<R, W>>>
90        where R: AsyncReadExt + Unpin + Send, W: AsyncWriteExt + Unpin + Send;
91
92    /// Handle the error which returned by the [`FuncHandler`].
93    /// Default only print the error message.
94    async fn handle_error<R, W>(&self, func: &str, error: Error, _stream: &mut IOStream<R, W>) -> Result<(), NetworkError>
95        where R: AsyncReadExt + Unpin + Send, W: AsyncWriteExt + Unpin + Send {
96        error!("Failed to handle in {}: {}", func, error);
97        Ok(())
98    }
99
100    /// Start the server. This method will block the caller thread.
101    ///
102    /// It **only** will return an error if the server cannot start.
103    /// If you want to handle errors returned by [`FuncHandler`], you should override [`Server::handle_error`].
104    async fn start(&'static self) -> std::io::Result<()> {
105        start_server(self).await
106    }
107}
108
#[cfg(test)]
mod tests {
    use std::time::Duration;
    use anyhow::Result;
    use env_logger::Target;
    use tcp_client::client_base::ClientBase;
    use tcp_client::config::{ClientConfig, set_config as set_client_config};
    use tcp_client::{client_factory, ClientFactory};
    use tcp_client::network::NetworkError;
    use tcp_handler::bytes::{Buf, BufMut, BytesMut};
    use tcp_handler::variable_len_reader::{VariableReader, VariableWriter};
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::spawn;
    use tokio::time::sleep;
    use crate::{func_handler, Server};
    use crate::config::{ServerConfig, set_config as set_server_config};
    use crate::handler_base::FuncHandler;

    /// Address used by both the test server and the test client.
    const ADDR: &str = "localhost:25565";

    /// How many times the client tries to connect while the spawned server may
    /// still be starting up.
    const CONNECT_ATTEMPTS: u32 = 20;

    /// Delay between two connection attempts.
    const CONNECT_RETRY_DELAY: Duration = Duration::from_millis(50);

    /// Minimal server exposing a single `"test"` function.
    struct TestServer;

    impl Server for TestServer {
        fn get_identifier(&self) -> &'static str {
            "tester"
        }

        fn check_version(&self, version: &str) -> bool {
            version == env!("CARGO_PKG_VERSION")
        }

        fn get_function<R, W>(&self, func: &str) -> Option<Box<dyn FuncHandler<R, W>>>
            where R: AsyncReadExt + Unpin + Send, W: AsyncWriteExt + Unpin + Send {
            if func == "test" {
                func_handler!(TestHandler, |stream| {
                    let mut reader = stream.recv().await?.reader();
                    assert_eq!("hello server", reader.read_string()?);
                    let mut writer = BytesMut::new().writer();
                    writer.write_string("hello client")?;
                    stream.send(&mut writer.into_inner()).await?;
                    Ok(())
                });
                // The handler carries no state, so boxing it should not allocate.
                assert_eq!(std::mem::size_of::<TestHandler>(), 0);
                Some(Box::new(TestHandler))
            } else {
                None
            }
        }
    }

    client_factory!(TestClientFactory, TestClient, "tester");

    impl TestClient {
        /// Calls the server's `"test"` function and checks the greeting round trip.
        async fn test_method(&mut self) -> Result<(), NetworkError> {
            self.check_func("test").await?;
            let mut writer = BytesMut::new().writer();
            writer.write_string("hello server")?;
            let mut reader = self.send_recv(&mut writer.into_inner()).await?.reader();
            assert_eq!("hello client", reader.read_string()?);
            Ok(())
        }
    }

    #[tokio::test]
    async fn test() -> Result<()> {
        // `try_init` fails if a logger is already installed (e.g. by another test
        // in the same binary). That is harmless here, so the result is ignored.
        let _ = env_logger::builder().parse_filters("trace").target(Target::Stderr).try_init();
        set_server_config(ServerConfig {
            addr: ADDR.to_string(),
            connect_sec: 10,
            idle_sec: 3,
        });
        set_client_config(ClientConfig {
            connect_sec: 10,
            idle_sec: 3,
        });

        let server = spawn(TestServer.start());
        // The server binds inside the spawned task, so the first connection
        // attempt can race the bind and be refused. Retry briefly before giving up.
        let mut attempts = 1;
        let mut client = loop {
            match TestClientFactory.connect(ADDR).await {
                Ok(client) => break client,
                Err(_) if attempts < CONNECT_ATTEMPTS => {
                    attempts += 1;
                    sleep(CONNECT_RETRY_DELAY).await;
                }
                Err(error) => return Err(error.into()),
            }
        };
        client.test_method().await?;
        drop(client);

        sleep(Duration::from_millis(100)).await; // Waiting for all log printing to complete.
        server.abort();
        Ok(())
    }
}