Skip to main content

tfserver/server/
server.rs

1use crate::server::server_router::TfServerRouter;
2use crate::structures::s_type;
3use crate::structures::s_type::ServerErrorEn::InternalError;
4use crate::structures::s_type::{PacketMeta, ServerErrorEn};
5use std::fmt;
6use std::net::SocketAddr;
7use std::ops::Deref;
8use std::sync::Arc;
9
10use tokio::sync::{Mutex, Notify, RwLock};
11
12use crate::codec::codec_trait::TfCodec;
13use crate::server::handler::Handler;
14use crate::structures::traffic_proc::TrafficProcessorHolder;
15use crate::structures::transport::Transport;
16use futures_util::SinkExt;
17use tokio::io;
18use tokio::io::AsyncWriteExt;
19use tokio::net::{TcpListener, TcpStream};
20use tokio::sync::mpsc::{Receiver, Sender};
21use tokio::task::JoinHandle;
22use tokio_rustls::TlsAcceptor;
23use tokio_rustls::rustls::ServerConfig;
24use tokio_util::bytes::{Bytes, BytesMut};
25use tokio_util::codec::Framed;
26
27///The request channel, used to move out tcp stream out of server control.
28///
29///When the stream is moved, the server does not owns it anymore.
30///
31///If is there need to return stream, only reconnect is available.
32pub type RequestChannel<C> = (
33    Sender<Arc<Mutex<dyn Handler<Codec = C>>>>,
34    Receiver<Arc<Mutex<dyn Handler<Codec = C>>>>,
35);
36
/// Transport flavour used for every connection accepted by a [`TfServer`].
///
/// Both variants work over plain TCP or TLS; TLS is used when the server is given a
/// `ServerConfig`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ServerMode {
    /// Raw framed traffic over plain TCP or TLS.
    Tcp,
    /// WebSocket upgrade over plain TCP or TLS.
    WebSocket,
}
44
45///Base binary tcp server.
46///
47/// 'C' is you codec, that you want to use to encode/decode data.
48///
49///Recommended default codec is LengthDelimitedCodec, from the server codec module.
50
51pub struct TfServer<C>
52where
53    C: TfCodec,
54{
55    router: Arc<TfServerRouter<C>>,
56    socket: Arc<TcpListener>,
57    shutdown_sig: Arc<Notify>,
58    processor: Option<TrafficProcessorHolder<C>>,
59    codec: C,
60    config: Option<ServerConfig>,
61    mode: ServerMode,
62}
63
64impl<C> TfServer<C>
65where
66    C: TfCodec,
67{
68    ///Creates a new instance of a server.
69    ///
70    /// 'bind_address' is a target address to bind current server. E.g: 0.0.0.0:8080
71    /// 'router' setted up router with handlers. Must be called commit_routes before using.
72    /// 'processor' Custom traffic processor, used for all streams.
73    /// 'codec' basically codec used for every stream with it's own instance, when the codec is applied to stream, first call is clone, the second call is initial_setup.
74    /// 'config' optional config for tls connection, when None the tls is not using, when some all connections are passed behind tls.
75    pub async fn new(
76        bind_address: String,
77        router: Arc<TfServerRouter<C>>,
78        processor: Option<TrafficProcessorHolder<C>>,
79        codec: C,
80        config: Option<ServerConfig>,
81        mode: ServerMode,
82    ) -> Self {
83        Self {
84            router,
85            socket: Arc::new(
86                TcpListener::bind(&bind_address)
87                    .await
88                    .expect("Failed to bind to address"),
89            ),
90            shutdown_sig: Arc::new(Notify::new()),
91            processor,
92            codec,
93            config,
94            mode,
95        }
96    }
97
98    ///Start the task for handling connections.
99    ///
100    ///Return the join handle, of this task.
101    pub async fn start(&mut self) -> JoinHandle<()> {
102        let (listener, router, shutdown_sig) = (
103            self.socket.clone(),
104            self.router.clone(),
105            self.shutdown_sig.clone(),
106        );
107        let mut processor = if let Some(proc) = self.processor.take() {
108            proc
109        } else {
110            TrafficProcessorHolder::new()
111        };
112        let codec = self.codec.clone();
113        let config = self.config.clone();
114        let mode = self.mode.clone(); // ← new
115
116        tokio::spawn(async move {
117            loop {
118                tokio::select! {
119                    res = listener.accept() => {
120                        if let Ok((stream, addr)) = res {
121                            let _ = stream.set_nodelay(true);
122                            let codec = codec.clone();
123                            let mode = mode.clone();    
124                            let transport = Self::initial_accept(stream, config.clone(), codec, &mode).await;
125
126                            if let Some(mut transport) = transport {
127                                if processor.initial_connect(&mut transport.0).await {
128                                    let mut framed = Framed::new(transport.0, transport.1);
129                                    if processor.initial_framed_connect(&mut framed).await {
130                                        let router = router.clone();
131                                        let prc_clone = processor.clone();
132                                        tokio::spawn(async move {
133                                            Self::handle_connection(addr, framed, router.as_ref(), prc_clone).await;
134                                        });
135                                    }
136                                } else {
137                                    let _ = transport.0.shutdown().await;
138                                }
139                            }
140                        }
141                    }
142                    _ = shutdown_sig.notified() => break,
143                }
144            }
145        })
146    }
147
148    ///Initial accept called for every connection, on connected event.
149    async fn initial_accept(
150        stream: TcpStream,
151        config: Option<ServerConfig>,
152        mut codec_setup: C,
153        mode: &ServerMode,
154    ) -> Option<(Transport, C)> {
155        let transport = match &config {
156            None => Transport::plain(stream),
157            Some(cfg) => {
158                let acceptor = TlsAcceptor::from(Arc::new(cfg.clone()));
159                match acceptor.accept(stream).await {
160                    Ok(tls) => Transport::tls_server(tls),
161                    Err(_) => return None,
162                }
163            }
164        };
165
166        let mut transport = match mode {
167            ServerMode::Tcp => transport,
168            ServerMode::WebSocket => match Transport::accept_websocket(transport).await {
169                Ok(ws_stream) => ws_stream,
170                Err(e) => {
171                    eprintln!("WebSocket handshake failed: {e}");
172                    return None;
173                }
174            },
175        };
176
177        if !codec_setup.initial_setup(&mut transport).await {
178            return None;
179        }
180
181        Some((transport, codec_setup))
182    }
183    ///Stops the acceptor task.
184    pub fn send_stop(&self) {
185        self.shutdown_sig.notify_waiters();
186    }
187
188    ///Main function for every connection
189    async fn handle_connection(
190        addr: SocketAddr,
191        mut stream: Framed<Transport, C>,
192        router: &TfServerRouter<C>,
193        mut processor: TrafficProcessorHolder<C>,
194    ) {
195        use futures_util::SinkExt;
196        let move_sig = tokio::sync::oneshot::channel::<Arc<RwLock<dyn Handler<Codec = C>>>>();
197        let mut move_sig = (Some(move_sig.0), move_sig.1);
198        loop {
199            let meta_data: Result<Option<BytesMut>, bool> =
200                Self::receive_message(addr.clone(), &mut stream, &mut processor).await;
201            if meta_data.is_err() {
202                if meta_data.unwrap_err() {
203                    stream.close().await.unwrap();
204                    return;
205                }
206                continue;
207            }
208
209            let meta_data = meta_data.unwrap();
210            if meta_data.is_none() {
211                continue;
212            }
213            let meta_data = meta_data.unwrap();
214            let has_payload = match s_type::from_slice::<PacketMeta>(meta_data.deref()) {
215                Ok(meta) => meta.has_payload,
216                Err(_) => false,
217            };
218
219            let mut payload: BytesMut = BytesMut::new();
220            if has_payload {
221                let payload_res =
222                    Self::receive_message(addr.clone(), &mut stream, &mut processor).await;
223                if payload_res.is_err() {
224                    if payload_res.unwrap_err() {
225                        stream.close().await.unwrap();
226                        return;
227                    }
228                    continue;
229                }
230                let payload_opt = payload_res.unwrap();
231                if payload_opt.is_none() {
232                    let _ = stream.close().await;
233                    return;
234                }
235                payload = payload_opt.unwrap();
236            }
237            let res = router
238                .serve_packet(meta_data, payload, (addr, &mut move_sig.0))
239                .await;
240
241            let message = res.unwrap_or_else(|err| s_type::to_vec(&err).unwrap());
242            let res = Self::send_message(&mut stream, message, &mut processor).await;
243
244            if let Ok(requester) = move_sig.1.try_recv() {
245                requester
246                    .write()
247                    .await
248                    .accept_stream(addr, (stream, processor.clone()))
249                    .await;
250                return;
251            }
252
253            match res {
254                Err(_) => {
255                    let _ = stream.close();
256                    return;
257                }
258                _ => {}
259            }
260        }
261    }
262    async fn send_message(
263        stream: &mut Framed<Transport, C>,
264        message: Vec<u8>,
265        processor: &mut TrafficProcessorHolder<C>,
266    ) -> Result<(), io::Error> {
267        let message = Bytes::from(processor.post_process_traffic(message).await);
268        stream.send(message).await
269    }
270
271    async fn receive_message(
272        _: SocketAddr,
273        stream: &mut Framed<Transport, C>,
274        processor: &mut TrafficProcessorHolder<C>,
275    ) -> Result<Option<BytesMut>, bool> {
276        use futures_util::StreamExt;
277        match stream.next().await {
278            Some(data) => match data {
279                Ok(mut data) => {
280                    data = processor.pre_process_traffic(data).await;
281                    return Ok(Some(data));
282                }
283                Err(e) => {
284                    // This is where codec-level decoding errors happen
285                    match e.kind() {
286                        // IO errors usually mean the connection is broken
287                        std::io::ErrorKind::ConnectionReset
288                        | std::io::ErrorKind::ConnectionAborted
289                        | std::io::ErrorKind::BrokenPipe
290                        | std::io::ErrorKind::UnexpectedEof => {
291                            println!("Client disconnected");
292                            return Err(true);
293                        }
294
295                        // Frame too large (if you set max_frame_length)
296                        std::io::ErrorKind::InvalidData => {
297                            eprintln!("Frame exceeded maximum size: {e}");
298                            return Err(false);
299                        }
300
301                        // Other IO errors
302                        _ => {
303                            eprintln!("IO error while reading frame: {e}");
304                            return Err(false);
305                        }
306                    }
307                }
308            },
309            None => {
310                return Err(true);
311            }
312        }
313    }
314}
315
316// Custom Error Display
317impl fmt::Display for ServerErrorEn {
318    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
319        match self {
320            ServerErrorEn::MalformedMetaInfo(Some(msg)) => {
321                write!(f, "Malformed meta info: {}", msg)
322            }
323            ServerErrorEn::MalformedMetaInfo(None) => write!(f, "Malformed meta info!"),
324            ServerErrorEn::NoSuchHandler(Some(msg)) => write!(f, "No such handler: {}", msg),
325            ServerErrorEn::NoSuchHandler(None) => write!(f, "No such handler!"),
326            InternalError(Some(data)) => {
327                write!(
328                    f,
329                    "{}",
330                    String::from_utf8(data.clone())
331                        .unwrap_or_else(|_| "Internal server error!".to_owned())
332                )
333            }
334            InternalError(None) => {
335                write!(f, "Internal server error!")
336            }
337            ServerErrorEn::PayloadLost => {
338                write!(f, "Payload lost!")
339            }
340        }
341    }
342}
343
/// Lets `ServerErrorEn` be used wherever a `std::error::Error` is expected (for example as
/// `Box<dyn std::error::Error>`); its message comes from the `Display` implementation.
impl std::error::Error for ServerErrorEn {}