// tfserver/server/tcp_server.rs

1use crate::server::server_router::TcpServerRouter;
2use crate::structures::s_type;
3use crate::structures::s_type::ServerErrorEn::InternalError;
4use crate::structures::s_type::{PacketMeta, ServerErrorEn};
5use std::fmt;
6use std::net::SocketAddr;
7use std::ops::Deref;
8use std::sync::Arc;
9
10use tokio::sync::{Mutex, Notify};
11
12use crate::server::handler::Handler;
13use crate::structures::traffic_proc::TrafficProcessorHolder;
14use crate::structures::transport::Transport;
15use futures_util::{SinkExt};
16use tokio::io;
17use tokio::io::AsyncWriteExt;
18use tokio::net::{TcpListener, TcpStream};
19use tokio::sync::mpsc::{Receiver, Sender};
20use tokio::task::JoinHandle;
21use tokio_rustls::TlsAcceptor;
22use tokio_rustls::rustls::ServerConfig;
23use tokio_util::bytes::{Bytes, BytesMut};
24use tokio_util::codec::{Decoder, Encoder, Framed};
25use crate::codec::codec_trait::TfCodec;
26
27///The request channel, used to move out tcp stream out of server control.
28///
29///When the stream is moved, the server does not owns it anymore.
30///
31///If is there need to return stream, only reconnect is available.
32pub type RequestChannel<C>
33 = (
34    Sender<Arc<Mutex<dyn Handler<Codec = C>>>>,
35    Receiver<Arc<Mutex<dyn Handler<Codec = C>>>>,
36);
37
38///Base binary tcp server.
39///
40/// 'C' is you codec, that you want to use to encode/decode data.
41///
42///Recommended default codec is LengthDelimitedCodec, from the server codec module.
43
44pub struct TcpServer<C>
45where
46    C: Encoder<Bytes, Error = io::Error>
47    + Decoder<Item = BytesMut, Error = io::Error>
48    + Send
49    + Sync
50    + Clone
51    + 'static + TfCodec,
52{
53    router: Arc<TcpServerRouter<C>>,
54    socket: Arc<TcpListener>,
55    shutdown_sig: Arc<Notify>,
56    processor: Option<TrafficProcessorHolder<C>>,
57    codec: C,
58    config: Option<ServerConfig>,
59}
60
61impl<C> TcpServer<C>
62where
63    C: Encoder<Bytes, Error = io::Error>
64    + Decoder<Item = BytesMut, Error = io::Error>
65    + Send
66    + Sync
67    + Clone
68    + 'static + TfCodec,
69{
70    ///Creates a new instance of a server.
71    ///
72    /// 'bind_address' is a target address to bind current server. E.g: 0.0.0.0:8080
73    /// 'router' setted up router with handlers. Must be called commit_routes before using.
74    /// 'processor' Custom traffic processor, used for all streams.
75    /// 'codec' basically codec used for every stream with it's own instance, when the codec is applied to stream, first call is clone, the second call is initial_setup.
76    /// 'config' optional config for tls connection, when None the tls is not using, when some all connections are passed behind tls.
77    pub async fn new(
78        bind_address: String,
79        router: Arc<TcpServerRouter<C>>,
80        processor: Option<TrafficProcessorHolder<C>>,
81        codec: C,
82        config: Option<ServerConfig>,
83    ) -> Self {
84        Self {
85            router,
86            socket: Arc::new(
87                TcpListener::bind(&bind_address)
88                    .await
89                    .expect("Failed to bind to address"),
90            ),
91            shutdown_sig: Arc::new(Notify::new()),
92            processor,
93            codec,
94            config,
95        }
96    }
97
98    ///Start the task for handling connections.
99    ///
100    ///Return the join handle, of this task.
101    pub async fn start(&mut self) -> JoinHandle<()>{
102        let (listener, router, shutdown_sig) = {
103            (
104                self.socket.clone(),
105                self.router.clone(),
106                self.shutdown_sig.clone(),
107            )
108        };
109        let mut processor = if let Some(proc) = self.processor.take() {
110            proc
111        } else {
112            TrafficProcessorHolder::new()
113        };
114        let codec = self.codec.clone();
115        let config = self.config.clone();
116        tokio::spawn(async move {
117            loop {
118                tokio::select! {
119                            res = listener.accept() => {
120                            if res.is_ok() {
121                                let stream = res.unwrap();
122                                stream.0.set_nodelay(true).unwrap();
123                                let codec = codec.clone();
124
125
126                                let transport = Self::initial_accept(stream.0, config.clone(), codec).await;
127                                if let Some(mut transport) = transport {
128                                 if processor.initial_connect(&mut transport.0).await {
129                                    let mut framed = Framed::new(transport.0, transport.1);
130
131                                if processor.initial_framed_connect(&mut framed).await {
132                                    let router = router.clone();
133                                    let prc_clone = processor.clone();
134
135                                    tokio::spawn(async move {
136                                        Self::handle_connection(
137                                            stream.1,
138                                            framed,
139                                            router.as_ref(),
140                                            prc_clone,
141                                        )
142                                        .await;
143                                    });
144                                }
145                                } else {
146                                    let _ = transport.0.shutdown().await;
147                                }
148                            }
149
150                        }
151                    }
152                    _ = shutdown_sig.notified() => break,
153                }
154            }
155        })
156    }
157
158    ///Initial accept called for every connection, on connected event.
159    async fn initial_accept(stream: TcpStream, config: Option<ServerConfig>, mut codec_setup: C) -> Option<(Transport, C)> {
160        if config.is_none() {
161            let mut res = Transport::plain(stream);
162            if !codec_setup.initial_setup(&mut res).await {
163                return None;
164            }
165            return Some((res, codec_setup));
166        } else {
167            let cfg = config.unwrap();
168            let acceptor = TlsAcceptor::from(Arc::new(cfg));
169            let res = acceptor.accept(stream).await;
170            if res.is_err() {
171                return None;
172            }
173            let mut res = Transport::tls_server(res.unwrap());
174            if !codec_setup.initial_setup(&mut res).await {
175                return None;
176            }
177            return Some((res, codec_setup));
178        }
179    }
180    ///Stops the acceptor task.
181    pub fn send_stop(&self) {
182        self.shutdown_sig.notify_one();
183    }
184    
185    ///Main function for every connection
186    async fn handle_connection(
187        addr: SocketAddr,
188        mut stream: Framed<Transport, C>,
189        router: &TcpServerRouter<C>,
190        mut processor: TrafficProcessorHolder<C>,
191    ) {
192        use futures_util::SinkExt;
193        let move_sig = tokio::sync::oneshot::channel::<Arc<Mutex<dyn Handler<Codec = C>>>>();
194        let mut move_sig = (Some(move_sig.0), move_sig.1);
195        loop {
196            let meta_data: Result<Option<BytesMut>, bool> =
197                Self::receive_message(addr.clone(), &mut stream, &mut processor).await;
198            if meta_data.is_err() {
199                if meta_data.unwrap_err() {
200                    stream.close().await.unwrap();
201                    return;
202                }
203                continue;
204            }
205
206            let meta_data = meta_data.unwrap();
207            if meta_data.is_none() {
208                continue;
209            }
210            let meta_data = meta_data.unwrap();
211            let has_payload = match s_type::from_slice::<PacketMeta>(meta_data.deref()) {
212                Ok(meta) => meta.has_payload,
213                Err(_) => false,
214            };
215
216            let mut payload: BytesMut = BytesMut::new();
217            if has_payload {
218                let payload_res =
219                    Self::receive_message(addr.clone(), &mut stream, &mut processor).await;
220                if payload_res.is_err() {
221                    if payload_res.unwrap_err() {
222                        stream.close().await.unwrap();
223                        return;
224                    }
225                    continue;
226                }
227                let payload_opt = payload_res.unwrap();
228                if payload_opt.is_none() {
229                    let _ = stream.close().await;
230                    return;
231                }
232                payload = payload_opt.unwrap();
233            }
234            let res = router
235                .serve_packet(meta_data, payload, (addr, &mut move_sig.0))
236                .await;
237
238            let message = res.unwrap_or_else(|err| s_type::to_vec(&err).unwrap());
239            let res = Self::send_message(&mut stream, message, &mut processor).await;
240
241            if let Ok(requester) = move_sig.1.try_recv() {
242                requester
243                    .lock()
244                    .await
245                    .accept_stream(addr, (stream, processor.clone()))
246                    .await;
247                return;
248            }
249
250            match res {
251                Err(_) => {
252                    let _ = stream.close();
253                    return;
254                }
255                _ => {}
256            }
257        }
258    }
259    async fn send_message(
260        stream: &mut Framed<Transport, C>,
261        message: Vec<u8>,
262        processor: &mut TrafficProcessorHolder<C>,
263    ) -> Result<(), io::Error> {
264        let message = Bytes::from(processor.post_process_traffic(message).await);
265        stream.send(message).await
266    }
267
268    async fn receive_message(
269        _: SocketAddr,
270        stream: &mut Framed<Transport, C>,
271        processor: &mut TrafficProcessorHolder<C>,
272    ) -> Result<Option<BytesMut>, bool> {
273        use futures_util::StreamExt;
274        match stream.next().await {
275            Some(data) => match data {
276                Ok(mut data) => {
277                    data = processor.pre_process_traffic(data).await;
278                    return Ok(Some(data));
279                }
280                Err(e) => {
281                    // This is where codec-level decoding errors happen
282                    match e.kind() {
283                        // IO errors usually mean the connection is broken
284                        std::io::ErrorKind::ConnectionReset
285                        | std::io::ErrorKind::ConnectionAborted
286                        | std::io::ErrorKind::BrokenPipe
287                        | std::io::ErrorKind::UnexpectedEof => {
288                            println!("Client disconnected");
289                            return Err(true);
290                        }
291
292                        // Frame too large (if you set max_frame_length)
293                        std::io::ErrorKind::InvalidData => {
294                            eprintln!("Frame exceeded maximum size: {e}");
295                            return Err(false);
296                        }
297
298                        // Other IO errors
299                        _ => {
300                            eprintln!("IO error while reading frame: {e}");
301                            return Err(false);
302                        }
303                    }
304                }
305            },
306            None => {
307                return Err(true);
308            }
309        }
310    }
311}
312
313// Custom Error Display
314impl fmt::Display for ServerErrorEn {
315    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
316        match self {
317            ServerErrorEn::MalformedMetaInfo(Some(msg)) => {
318                write!(f, "Malformed meta info: {}", msg)
319            }
320            ServerErrorEn::MalformedMetaInfo(None) => write!(f, "Malformed meta info!"),
321            ServerErrorEn::NoSuchHandler(Some(msg)) => write!(f, "No such handler: {}", msg),
322            ServerErrorEn::NoSuchHandler(None) => write!(f, "No such handler!"),
323            InternalError(Some(data)) => {
324                write!(
325                    f,
326                    "{}",
327                    String::from_utf8(data.clone())
328                        .unwrap_or_else(|_| "Internal server error!".to_owned())
329                )
330            }
331            InternalError(None) => {
332                write!(f, "Internal server error!")
333            }
334            ServerErrorEn::PayloadLost => {
335                write!(f, "Payload lost!")
336            }
337        }
338    }
339}
340
341impl std::error::Error for ServerErrorEn {}