// tfserver/server/server.rs

1use crate::log_macros::{tf_debug, tf_error, tf_info, tf_warn};
2use crate::server::server_router::TfServerRouter;
3use crate::structures::s_type;
4use crate::structures::s_type::{PacketMeta};
5use std::net::SocketAddr;
6use std::ops::Deref;
7use std::sync::Arc;
8
9use tokio::sync::{Mutex, Notify, RwLock};
10
11use crate::codec::codec_trait::TfCodec;
12use crate::server::handler::Handler;
13use crate::structures::traffic_proc::TrafficProcessorHolder;
14use crate::structures::transport::Transport;
15use futures_util::SinkExt;
16use tokio::io;
17use tokio::io::AsyncWriteExt;
18use tokio::net::{TcpListener, TcpStream};
19use tokio::sync::mpsc::{Receiver, Sender};
20use tokio::task::JoinHandle;
21use tokio_rustls::TlsAcceptor;
22use tokio_rustls::rustls::ServerConfig;
23use tokio_util::bytes::{Bytes, BytesMut};
24use tokio_util::codec::Framed;
25
26/// The request channel, used to move out tcp stream out of server control.
27///
28/// When the stream is moved, the server does not own it anymore.
29///
30/// If there is a need to return the stream, only reconnect is available.
31pub type RequestChannel<C> = (
32    Sender<Arc<Mutex<dyn Handler<Codec = C>>>>,
33    Receiver<Arc<Mutex<dyn Handler<Codec = C>>>>,
34);
35
/// Protocol spoken on top of each accepted (plain or TLS) TCP connection.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ServerMode {
    /// Frames are exchanged directly over the plain TCP or TLS stream.
    Tcp,
    /// The plain TCP or TLS stream is first upgraded to a WebSocket.
    WebSocket,
}
43
44/// Base binary TCP server.
45///
46/// `C` is the codec used to encode/decode data.
47///
48/// Recommended default codec is `LengthDelimitedCodec` from the server codec module.
49pub struct TfServer<C>
50where
51    C: TfCodec,
52{
53    router: Arc<TfServerRouter<C>>,
54    socket: Arc<TcpListener>,
55    shutdown_sig: Arc<Notify>,
56    processor: Option<TrafficProcessorHolder<C>>,
57    codec: C,
58    config: Option<ServerConfig>,
59    mode: ServerMode,
60}
61
62impl<C> TfServer<C>
63where
64    C: TfCodec,
65{
66    /// Creates a new server instance bound to `bind_address`.
67    ///
68    /// Returns an error if the address cannot be bound.
69    pub async fn new(
70        bind_address: String,
71        router: Arc<TfServerRouter<C>>,
72        processor: Option<TrafficProcessorHolder<C>>,
73        codec: C,
74        config: Option<ServerConfig>,
75        mode: ServerMode,
76    ) -> Result<Self, io::Error> {
77        let socket = TcpListener::bind(&bind_address).await.map_err(|e| {
78            tf_error!("Failed to bind to {}: {}", bind_address, e);
79            e
80        })?;
81        tf_info!("Server bound to {}", bind_address);
82        Ok(Self {
83            router,
84            socket: Arc::new(socket),
85            shutdown_sig: Arc::new(Notify::new()),
86            processor,
87            codec,
88            config,
89            mode,
90        })
91    }
92
93    /// Start the task for handling connections.
94    ///
95    /// Returns the join handle for the acceptor task.
96    pub async fn start(&mut self) -> JoinHandle<()> {
97        let (listener, router, shutdown_sig) = (
98            self.socket.clone(),
99            self.router.clone(),
100            self.shutdown_sig.clone(),
101        );
102        let mut processor = if let Some(proc) = self.processor.take() {
103            proc
104        } else {
105            TrafficProcessorHolder::new()
106        };
107        let codec = self.codec.clone();
108        let config = self.config.clone();
109        let mode = self.mode.clone();
110
111        tokio::spawn(async move {
112            loop {
113                tokio::select! {
114                    res = listener.accept() => {
115                        match res {
116                            Ok((stream, addr)) => {
117                                tf_debug!("Accepted connection from {}", addr);
118                                let _ = stream.set_nodelay(true);
119                                let codec = codec.clone();
120                                let mode = mode.clone();
121                                let transport = Self::initial_accept(stream, config.clone(), codec, &mode).await;
122
123                                if let Some(mut transport) = transport {
124                                    if processor.initial_connect(&mut transport.0).await {
125                                        let mut framed = Framed::new(transport.0, transport.1);
126                                        if processor.initial_framed_connect(&mut framed).await {
127                                            let router = router.clone();
128                                            let prc_clone = processor.clone();
129                                            tokio::spawn(async move {
130                                                Self::handle_connection(addr, framed, router.as_ref(), prc_clone).await;
131                                            });
132                                        } else {
133                                            tf_warn!("Framed processor rejected connection from {}", addr);
134                                        }
135                                    } else {
136                                        tf_warn!("Processor rejected connection from {}", addr);
137                                        let _ = transport.0.shutdown().await;
138                                    }
139                                }
140                            }
141                            Err(e) => {
142                                tf_warn!("Accept error: {}", e);
143                            }
144                        }
145                    }
146                    _ = shutdown_sig.notified() => {
147                        tf_info!("Server shutting down");
148                        break;
149                    }
150                }
151            }
152        })
153    }
154
155    async fn initial_accept(
156        stream: TcpStream,
157        config: Option<ServerConfig>,
158        mut codec_setup: C,
159        mode: &ServerMode,
160    ) -> Option<(Transport, C)> {
161        let transport = match &config {
162            None => Transport::plain(stream),
163            Some(cfg) => {
164                let acceptor = TlsAcceptor::from(Arc::new(cfg.clone()));
165                match acceptor.accept(stream).await {
166                    Ok(tls) => Transport::tls_server(tls),
167                    Err(e) => {
168                        tf_warn!("TLS handshake failed: {}", e);
169                        return None;
170                    }
171                }
172            }
173        };
174
175        let mut transport = match mode {
176            ServerMode::Tcp => transport,
177            ServerMode::WebSocket => match Transport::accept_websocket(transport).await {
178                Ok(ws_stream) => ws_stream,
179                Err(e) => {
180                    tf_warn!("WebSocket handshake failed: {}", e);
181                    return None;
182                }
183            },
184        };
185
186        if !codec_setup.initial_setup(&mut transport).await {
187            tf_warn!("Codec initial_setup rejected connection");
188            return None;
189        }
190
191        Some((transport, codec_setup))
192    }
193
194    /// Signals the acceptor task to stop.
195    pub fn send_stop(&self) {
196        self.shutdown_sig.notify_waiters();
197    }
198
199    async fn handle_connection(
200        addr: SocketAddr,
201        mut stream: Framed<Transport, C>,
202        router: &TfServerRouter<C>,
203        mut processor: TrafficProcessorHolder<C>,
204    ) {
205        use futures_util::SinkExt;
206        let move_sig = tokio::sync::oneshot::channel::<Arc<RwLock<dyn Handler<Codec = C>>>>();
207        let mut move_sig = (Some(move_sig.0), move_sig.1);
208        loop {
209            let meta_data: Result<Option<BytesMut>, bool> =
210                Self::receive_message(addr, &mut stream, &mut processor).await;
211            if meta_data.is_err() {
212                if meta_data.unwrap_err() {
213                    stream.close().await.unwrap_or_else(|e| {
214                        tf_warn!("Error closing stream for {}: {}", addr, e);
215                    });
216                    return;
217                }
218                continue;
219            }
220
221            let meta_data = meta_data.unwrap();
222            if meta_data.is_none() {
223                continue;
224            }
225            let meta_data = meta_data.unwrap();
226            let has_payload = match s_type::from_slice::<PacketMeta>(meta_data.deref()) {
227                Ok(meta) => meta.has_payload,
228                Err(e) => {
229                    tf_warn!("Failed to deserialize PacketMeta from {}: {}", addr, e);
230                    false
231                }
232            };
233
234            let mut payload: BytesMut = BytesMut::new();
235            if has_payload {
236                let payload_res =
237                    Self::receive_message(addr, &mut stream, &mut processor).await;
238                if payload_res.is_err() {
239                    if payload_res.unwrap_err() {
240                        stream.close().await.unwrap_or_else(|e| {
241                            tf_warn!("Error closing stream for {}: {}", addr, e);
242                        });
243                        return;
244                    }
245                    continue;
246                }
247                let payload_opt = payload_res.unwrap();
248                if payload_opt.is_none() {
249                    let _ = stream.close().await;
250                    return;
251                }
252                payload = payload_opt.unwrap();
253            }
254            let res = router
255                .serve_packet(meta_data, payload, (addr, &mut move_sig.0))
256                .await;
257
258            let message = res.unwrap_or_else(|err| s_type::to_vec(&err).unwrap());
259            let res = Self::send_message(&mut stream, message, &mut processor).await;
260
261            if let Ok(requester) = move_sig.1.try_recv() {
262                requester
263                    .write()
264                    .await
265                    .accept_stream(addr, (stream, processor.clone()))
266                    .await;
267                return;
268            }
269
270            if let Err(e) = res {
271                tf_warn!("Send error for {}: {}", addr, e);
272                let _ = stream.close();
273                return;
274            }
275        }
276    }
277
278    async fn send_message(
279        stream: &mut Framed<Transport, C>,
280        message: Vec<u8>,
281        processor: &mut TrafficProcessorHolder<C>,
282    ) -> Result<(), io::Error> {
283        let message = Bytes::from(processor.post_process_traffic(message).await);
284        stream.send(message).await
285    }
286
287    async fn receive_message(
288        addr: SocketAddr,
289        stream: &mut Framed<Transport, C>,
290        processor: &mut TrafficProcessorHolder<C>,
291    ) -> Result<Option<BytesMut>, bool> {
292        use futures_util::StreamExt;
293        match stream.next().await {
294            Some(data) => match data {
295                Ok(mut data) => {
296                    data = processor.pre_process_traffic(data).await;
297                    Ok(Some(data))
298                }
299                Err(e) => match e.kind() {
300                    std::io::ErrorKind::ConnectionReset
301                    | std::io::ErrorKind::ConnectionAborted
302                    | std::io::ErrorKind::BrokenPipe
303                    | std::io::ErrorKind::UnexpectedEof => {
304                        tf_debug!("Client {} disconnected", addr);
305                        Err(true)
306                    }
307                    std::io::ErrorKind::InvalidData => {
308                        tf_warn!("Frame exceeded maximum size from {}: {}", addr, e);
309                        Err(false)
310                    }
311                    _ => {
312                        tf_warn!("IO error reading frame from {}: {}", addr, e);
313                        Err(false)
314                    }
315                },
316            },
317            None => Err(true),
318        }
319    }
320}