Skip to main content

tfserver/server/
server.rs

1use crate::log_macros::{tf_debug, tf_error, tf_info, tf_warn};
2use crate::server::server_router::TfServerRouter;
3use crate::structures::s_type;
4use crate::structures::s_type::ServerErrorEn::InternalError;
5use crate::structures::s_type::{PacketMeta, ServerErrorEn};
6use std::fmt;
7use std::net::SocketAddr;
8use std::ops::Deref;
9use std::sync::Arc;
10
11use tokio::sync::{Mutex, Notify, RwLock};
12
13use crate::codec::codec_trait::TfCodec;
14use crate::server::handler::Handler;
15use crate::structures::traffic_proc::TrafficProcessorHolder;
16use crate::structures::transport::Transport;
17use futures_util::SinkExt;
18use tokio::io;
19use tokio::io::AsyncWriteExt;
20use tokio::net::{TcpListener, TcpStream};
21use tokio::sync::mpsc::{Receiver, Sender};
22use tokio::task::JoinHandle;
23use tokio_rustls::TlsAcceptor;
24use tokio_rustls::rustls::ServerConfig;
25use tokio_util::bytes::{Bytes, BytesMut};
26use tokio_util::codec::Framed;
27
28/// The request channel, used to move out tcp stream out of server control.
29///
30/// When the stream is moved, the server does not own it anymore.
31///
32/// If there is a need to return the stream, only reconnect is available.
33pub type RequestChannel<C> = (
34    Sender<Arc<Mutex<dyn Handler<Codec = C>>>>,
35    Receiver<Arc<Mutex<dyn Handler<Codec = C>>>>,
36);
37
/// Transport mode applied to every accepted connection.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ServerMode {
    /// Plain TCP or TLS stream.
    Tcp,
    /// WebSocket upgrade performed over plain TCP or TLS.
    WebSocket,
}
45
46/// Base binary TCP server.
47///
48/// `C` is the codec used to encode/decode data.
49///
50/// Recommended default codec is `LengthDelimitedCodec` from the server codec module.
51pub struct TfServer<C>
52where
53    C: TfCodec,
54{
55    router: Arc<TfServerRouter<C>>,
56    socket: Arc<TcpListener>,
57    shutdown_sig: Arc<Notify>,
58    processor: Option<TrafficProcessorHolder<C>>,
59    codec: C,
60    config: Option<ServerConfig>,
61    mode: ServerMode,
62}
63
64impl<C> TfServer<C>
65where
66    C: TfCodec,
67{
68    /// Creates a new server instance bound to `bind_address`.
69    ///
70    /// Returns an error if the address cannot be bound.
71    pub async fn new(
72        bind_address: String,
73        router: Arc<TfServerRouter<C>>,
74        processor: Option<TrafficProcessorHolder<C>>,
75        codec: C,
76        config: Option<ServerConfig>,
77        mode: ServerMode,
78    ) -> Result<Self, io::Error> {
79        let socket = TcpListener::bind(&bind_address).await.map_err(|e| {
80            tf_error!("Failed to bind to {}: {}", bind_address, e);
81            e
82        })?;
83        tf_info!("Server bound to {}", bind_address);
84        Ok(Self {
85            router,
86            socket: Arc::new(socket),
87            shutdown_sig: Arc::new(Notify::new()),
88            processor,
89            codec,
90            config,
91            mode,
92        })
93    }
94
95    /// Start the task for handling connections.
96    ///
97    /// Returns the join handle for the acceptor task.
98    pub async fn start(&mut self) -> JoinHandle<()> {
99        let (listener, router, shutdown_sig) = (
100            self.socket.clone(),
101            self.router.clone(),
102            self.shutdown_sig.clone(),
103        );
104        let mut processor = if let Some(proc) = self.processor.take() {
105            proc
106        } else {
107            TrafficProcessorHolder::new()
108        };
109        let codec = self.codec.clone();
110        let config = self.config.clone();
111        let mode = self.mode.clone();
112
113        tokio::spawn(async move {
114            loop {
115                tokio::select! {
116                    res = listener.accept() => {
117                        match res {
118                            Ok((stream, addr)) => {
119                                tf_debug!("Accepted connection from {}", addr);
120                                let _ = stream.set_nodelay(true);
121                                let codec = codec.clone();
122                                let mode = mode.clone();
123                                let transport = Self::initial_accept(stream, config.clone(), codec, &mode).await;
124
125                                if let Some(mut transport) = transport {
126                                    if processor.initial_connect(&mut transport.0).await {
127                                        let mut framed = Framed::new(transport.0, transport.1);
128                                        if processor.initial_framed_connect(&mut framed).await {
129                                            let router = router.clone();
130                                            let prc_clone = processor.clone();
131                                            tokio::spawn(async move {
132                                                Self::handle_connection(addr, framed, router.as_ref(), prc_clone).await;
133                                            });
134                                        } else {
135                                            tf_warn!("Framed processor rejected connection from {}", addr);
136                                        }
137                                    } else {
138                                        tf_warn!("Processor rejected connection from {}", addr);
139                                        let _ = transport.0.shutdown().await;
140                                    }
141                                }
142                            }
143                            Err(e) => {
144                                tf_warn!("Accept error: {}", e);
145                            }
146                        }
147                    }
148                    _ = shutdown_sig.notified() => {
149                        tf_info!("Server shutting down");
150                        break;
151                    }
152                }
153            }
154        })
155    }
156
157    async fn initial_accept(
158        stream: TcpStream,
159        config: Option<ServerConfig>,
160        mut codec_setup: C,
161        mode: &ServerMode,
162    ) -> Option<(Transport, C)> {
163        let transport = match &config {
164            None => Transport::plain(stream),
165            Some(cfg) => {
166                let acceptor = TlsAcceptor::from(Arc::new(cfg.clone()));
167                match acceptor.accept(stream).await {
168                    Ok(tls) => Transport::tls_server(tls),
169                    Err(e) => {
170                        tf_warn!("TLS handshake failed: {}", e);
171                        return None;
172                    }
173                }
174            }
175        };
176
177        let mut transport = match mode {
178            ServerMode::Tcp => transport,
179            ServerMode::WebSocket => match Transport::accept_websocket(transport).await {
180                Ok(ws_stream) => ws_stream,
181                Err(e) => {
182                    tf_warn!("WebSocket handshake failed: {}", e);
183                    return None;
184                }
185            },
186        };
187
188        if !codec_setup.initial_setup(&mut transport).await {
189            tf_warn!("Codec initial_setup rejected connection");
190            return None;
191        }
192
193        Some((transport, codec_setup))
194    }
195
196    /// Signals the acceptor task to stop.
197    pub fn send_stop(&self) {
198        self.shutdown_sig.notify_waiters();
199    }
200
201    async fn handle_connection(
202        addr: SocketAddr,
203        mut stream: Framed<Transport, C>,
204        router: &TfServerRouter<C>,
205        mut processor: TrafficProcessorHolder<C>,
206    ) {
207        use futures_util::SinkExt;
208        let move_sig = tokio::sync::oneshot::channel::<Arc<RwLock<dyn Handler<Codec = C>>>>();
209        let mut move_sig = (Some(move_sig.0), move_sig.1);
210        loop {
211            let meta_data: Result<Option<BytesMut>, bool> =
212                Self::receive_message(addr, &mut stream, &mut processor).await;
213            if meta_data.is_err() {
214                if meta_data.unwrap_err() {
215                    stream.close().await.unwrap_or_else(|e| {
216                        tf_warn!("Error closing stream for {}: {}", addr, e);
217                    });
218                    return;
219                }
220                continue;
221            }
222
223            let meta_data = meta_data.unwrap();
224            if meta_data.is_none() {
225                continue;
226            }
227            let meta_data = meta_data.unwrap();
228            let has_payload = match s_type::from_slice::<PacketMeta>(meta_data.deref()) {
229                Ok(meta) => meta.has_payload,
230                Err(e) => {
231                    tf_warn!("Failed to deserialize PacketMeta from {}: {}", addr, e);
232                    false
233                }
234            };
235
236            let mut payload: BytesMut = BytesMut::new();
237            if has_payload {
238                let payload_res =
239                    Self::receive_message(addr, &mut stream, &mut processor).await;
240                if payload_res.is_err() {
241                    if payload_res.unwrap_err() {
242                        stream.close().await.unwrap_or_else(|e| {
243                            tf_warn!("Error closing stream for {}: {}", addr, e);
244                        });
245                        return;
246                    }
247                    continue;
248                }
249                let payload_opt = payload_res.unwrap();
250                if payload_opt.is_none() {
251                    let _ = stream.close().await;
252                    return;
253                }
254                payload = payload_opt.unwrap();
255            }
256            let res = router
257                .serve_packet(meta_data, payload, (addr, &mut move_sig.0))
258                .await;
259
260            let message = res.unwrap_or_else(|err| s_type::to_vec(&err).unwrap());
261            let res = Self::send_message(&mut stream, message, &mut processor).await;
262
263            if let Ok(requester) = move_sig.1.try_recv() {
264                requester
265                    .write()
266                    .await
267                    .accept_stream(addr, (stream, processor.clone()))
268                    .await;
269                return;
270            }
271
272            if let Err(e) = res {
273                tf_warn!("Send error for {}: {}", addr, e);
274                let _ = stream.close();
275                return;
276            }
277        }
278    }
279
280    async fn send_message(
281        stream: &mut Framed<Transport, C>,
282        message: Vec<u8>,
283        processor: &mut TrafficProcessorHolder<C>,
284    ) -> Result<(), io::Error> {
285        let message = Bytes::from(processor.post_process_traffic(message).await);
286        stream.send(message).await
287    }
288
289    async fn receive_message(
290        addr: SocketAddr,
291        stream: &mut Framed<Transport, C>,
292        processor: &mut TrafficProcessorHolder<C>,
293    ) -> Result<Option<BytesMut>, bool> {
294        use futures_util::StreamExt;
295        match stream.next().await {
296            Some(data) => match data {
297                Ok(mut data) => {
298                    data = processor.pre_process_traffic(data).await;
299                    Ok(Some(data))
300                }
301                Err(e) => match e.kind() {
302                    std::io::ErrorKind::ConnectionReset
303                    | std::io::ErrorKind::ConnectionAborted
304                    | std::io::ErrorKind::BrokenPipe
305                    | std::io::ErrorKind::UnexpectedEof => {
306                        tf_debug!("Client {} disconnected", addr);
307                        Err(true)
308                    }
309                    std::io::ErrorKind::InvalidData => {
310                        tf_warn!("Frame exceeded maximum size from {}: {}", addr, e);
311                        Err(false)
312                    }
313                    _ => {
314                        tf_warn!("IO error reading frame from {}: {}", addr, e);
315                        Err(false)
316                    }
317                },
318            },
319            None => Err(true),
320        }
321    }
322}
323
324impl fmt::Display for ServerErrorEn {
325    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
326        match self {
327            ServerErrorEn::MalformedMetaInfo(Some(msg)) => {
328                write!(f, "Malformed meta info: {}", msg)
329            }
330            ServerErrorEn::MalformedMetaInfo(None) => write!(f, "Malformed meta info!"),
331            ServerErrorEn::NoSuchHandler(Some(msg)) => write!(f, "No such handler: {}", msg),
332            ServerErrorEn::NoSuchHandler(None) => write!(f, "No such handler!"),
333            InternalError(Some(data)) => {
334                write!(
335                    f,
336                    "{}",
337                    String::from_utf8(data.clone())
338                        .unwrap_or_else(|_| "Internal server error!".to_owned())
339                )
340            }
341            InternalError(None) => {
342                write!(f, "Internal server error!")
343            }
344            ServerErrorEn::PayloadLost => {
345                write!(f, "Payload lost!")
346            }
347        }
348    }
349}
350
351impl std::error::Error for ServerErrorEn {}