1use crate::log_macros::{tf_debug, tf_error, tf_info, tf_warn};
2use crate::server::server_router::TfServerRouter;
3use crate::structures::s_type;
4use crate::structures::s_type::ServerErrorEn::InternalError;
5use crate::structures::s_type::{PacketMeta, ServerErrorEn};
6use std::fmt;
7use std::net::SocketAddr;
8use std::ops::Deref;
9use std::sync::Arc;
10
11use tokio::sync::{Mutex, Notify, RwLock};
12
13use crate::codec::codec_trait::TfCodec;
14use crate::server::handler::Handler;
15use crate::structures::traffic_proc::TrafficProcessorHolder;
16use crate::structures::transport::Transport;
17use futures_util::SinkExt;
18use tokio::io;
19use tokio::io::AsyncWriteExt;
20use tokio::net::{TcpListener, TcpStream};
21use tokio::sync::mpsc::{Receiver, Sender};
22use tokio::task::JoinHandle;
23use tokio_rustls::TlsAcceptor;
24use tokio_rustls::rustls::ServerConfig;
25use tokio_util::bytes::{Bytes, BytesMut};
26use tokio_util::codec::Framed;
27
28pub type RequestChannel<C> = (
34 Sender<Arc<Mutex<dyn Handler<Codec = C>>>>,
35 Receiver<Arc<Mutex<dyn Handler<Codec = C>>>>,
36);
37
/// Selects the transport layered on top of each accepted (optionally TLS-wrapped) TCP stream.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum ServerMode {
    /// Frames are exchanged directly over the accepted stream.
    Tcp,
    /// A WebSocket handshake is performed on the accepted stream before any frames flow.
    WebSocket,
}
45
46pub struct TfServer<C>
52where
53 C: TfCodec,
54{
55 router: Arc<TfServerRouter<C>>,
56 socket: Arc<TcpListener>,
57 shutdown_sig: Arc<Notify>,
58 processor: Option<TrafficProcessorHolder<C>>,
59 codec: C,
60 config: Option<ServerConfig>,
61 mode: ServerMode,
62}
63
64impl<C> TfServer<C>
65where
66 C: TfCodec,
67{
68 pub async fn new(
72 bind_address: String,
73 router: Arc<TfServerRouter<C>>,
74 processor: Option<TrafficProcessorHolder<C>>,
75 codec: C,
76 config: Option<ServerConfig>,
77 mode: ServerMode,
78 ) -> Result<Self, io::Error> {
79 let socket = TcpListener::bind(&bind_address).await.map_err(|e| {
80 tf_error!("Failed to bind to {}: {}", bind_address, e);
81 e
82 })?;
83 tf_info!("Server bound to {}", bind_address);
84 Ok(Self {
85 router,
86 socket: Arc::new(socket),
87 shutdown_sig: Arc::new(Notify::new()),
88 processor,
89 codec,
90 config,
91 mode,
92 })
93 }
94
95 pub async fn start(&mut self) -> JoinHandle<()> {
99 let (listener, router, shutdown_sig) = (
100 self.socket.clone(),
101 self.router.clone(),
102 self.shutdown_sig.clone(),
103 );
104 let mut processor = if let Some(proc) = self.processor.take() {
105 proc
106 } else {
107 TrafficProcessorHolder::new()
108 };
109 let codec = self.codec.clone();
110 let config = self.config.clone();
111 let mode = self.mode.clone();
112
113 tokio::spawn(async move {
114 loop {
115 tokio::select! {
116 res = listener.accept() => {
117 match res {
118 Ok((stream, addr)) => {
119 tf_debug!("Accepted connection from {}", addr);
120 let _ = stream.set_nodelay(true);
121 let codec = codec.clone();
122 let mode = mode.clone();
123 let transport = Self::initial_accept(stream, config.clone(), codec, &mode).await;
124
125 if let Some(mut transport) = transport {
126 if processor.initial_connect(&mut transport.0).await {
127 let mut framed = Framed::new(transport.0, transport.1);
128 if processor.initial_framed_connect(&mut framed).await {
129 let router = router.clone();
130 let prc_clone = processor.clone();
131 tokio::spawn(async move {
132 Self::handle_connection(addr, framed, router.as_ref(), prc_clone).await;
133 });
134 } else {
135 tf_warn!("Framed processor rejected connection from {}", addr);
136 }
137 } else {
138 tf_warn!("Processor rejected connection from {}", addr);
139 let _ = transport.0.shutdown().await;
140 }
141 }
142 }
143 Err(e) => {
144 tf_warn!("Accept error: {}", e);
145 }
146 }
147 }
148 _ = shutdown_sig.notified() => {
149 tf_info!("Server shutting down");
150 break;
151 }
152 }
153 }
154 })
155 }
156
157 async fn initial_accept(
158 stream: TcpStream,
159 config: Option<ServerConfig>,
160 mut codec_setup: C,
161 mode: &ServerMode,
162 ) -> Option<(Transport, C)> {
163 let transport = match &config {
164 None => Transport::plain(stream),
165 Some(cfg) => {
166 let acceptor = TlsAcceptor::from(Arc::new(cfg.clone()));
167 match acceptor.accept(stream).await {
168 Ok(tls) => Transport::tls_server(tls),
169 Err(e) => {
170 tf_warn!("TLS handshake failed: {}", e);
171 return None;
172 }
173 }
174 }
175 };
176
177 let mut transport = match mode {
178 ServerMode::Tcp => transport,
179 ServerMode::WebSocket => match Transport::accept_websocket(transport).await {
180 Ok(ws_stream) => ws_stream,
181 Err(e) => {
182 tf_warn!("WebSocket handshake failed: {}", e);
183 return None;
184 }
185 },
186 };
187
188 if !codec_setup.initial_setup(&mut transport).await {
189 tf_warn!("Codec initial_setup rejected connection");
190 return None;
191 }
192
193 Some((transport, codec_setup))
194 }
195
196 pub fn send_stop(&self) {
198 self.shutdown_sig.notify_waiters();
199 }
200
201 async fn handle_connection(
202 addr: SocketAddr,
203 mut stream: Framed<Transport, C>,
204 router: &TfServerRouter<C>,
205 mut processor: TrafficProcessorHolder<C>,
206 ) {
207 use futures_util::SinkExt;
208 let move_sig = tokio::sync::oneshot::channel::<Arc<RwLock<dyn Handler<Codec = C>>>>();
209 let mut move_sig = (Some(move_sig.0), move_sig.1);
210 loop {
211 let meta_data: Result<Option<BytesMut>, bool> =
212 Self::receive_message(addr, &mut stream, &mut processor).await;
213 if meta_data.is_err() {
214 if meta_data.unwrap_err() {
215 stream.close().await.unwrap_or_else(|e| {
216 tf_warn!("Error closing stream for {}: {}", addr, e);
217 });
218 return;
219 }
220 continue;
221 }
222
223 let meta_data = meta_data.unwrap();
224 if meta_data.is_none() {
225 continue;
226 }
227 let meta_data = meta_data.unwrap();
228 let has_payload = match s_type::from_slice::<PacketMeta>(meta_data.deref()) {
229 Ok(meta) => meta.has_payload,
230 Err(e) => {
231 tf_warn!("Failed to deserialize PacketMeta from {}: {}", addr, e);
232 false
233 }
234 };
235
236 let mut payload: BytesMut = BytesMut::new();
237 if has_payload {
238 let payload_res =
239 Self::receive_message(addr, &mut stream, &mut processor).await;
240 if payload_res.is_err() {
241 if payload_res.unwrap_err() {
242 stream.close().await.unwrap_or_else(|e| {
243 tf_warn!("Error closing stream for {}: {}", addr, e);
244 });
245 return;
246 }
247 continue;
248 }
249 let payload_opt = payload_res.unwrap();
250 if payload_opt.is_none() {
251 let _ = stream.close().await;
252 return;
253 }
254 payload = payload_opt.unwrap();
255 }
256 let res = router
257 .serve_packet(meta_data, payload, (addr, &mut move_sig.0))
258 .await;
259
260 let message = res.unwrap_or_else(|err| s_type::to_vec(&err).unwrap());
261 let res = Self::send_message(&mut stream, message, &mut processor).await;
262
263 if let Ok(requester) = move_sig.1.try_recv() {
264 requester
265 .write()
266 .await
267 .accept_stream(addr, (stream, processor.clone()))
268 .await;
269 return;
270 }
271
272 if let Err(e) = res {
273 tf_warn!("Send error for {}: {}", addr, e);
274 let _ = stream.close();
275 return;
276 }
277 }
278 }
279
280 async fn send_message(
281 stream: &mut Framed<Transport, C>,
282 message: Vec<u8>,
283 processor: &mut TrafficProcessorHolder<C>,
284 ) -> Result<(), io::Error> {
285 let message = Bytes::from(processor.post_process_traffic(message).await);
286 stream.send(message).await
287 }
288
289 async fn receive_message(
290 addr: SocketAddr,
291 stream: &mut Framed<Transport, C>,
292 processor: &mut TrafficProcessorHolder<C>,
293 ) -> Result<Option<BytesMut>, bool> {
294 use futures_util::StreamExt;
295 match stream.next().await {
296 Some(data) => match data {
297 Ok(mut data) => {
298 data = processor.pre_process_traffic(data).await;
299 Ok(Some(data))
300 }
301 Err(e) => match e.kind() {
302 std::io::ErrorKind::ConnectionReset
303 | std::io::ErrorKind::ConnectionAborted
304 | std::io::ErrorKind::BrokenPipe
305 | std::io::ErrorKind::UnexpectedEof => {
306 tf_debug!("Client {} disconnected", addr);
307 Err(true)
308 }
309 std::io::ErrorKind::InvalidData => {
310 tf_warn!("Frame exceeded maximum size from {}: {}", addr, e);
311 Err(false)
312 }
313 _ => {
314 tf_warn!("IO error reading frame from {}: {}", addr, e);
315 Err(false)
316 }
317 },
318 },
319 None => Err(true),
320 }
321 }
322}
323
324impl fmt::Display for ServerErrorEn {
325 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
326 match self {
327 ServerErrorEn::MalformedMetaInfo(Some(msg)) => {
328 write!(f, "Malformed meta info: {}", msg)
329 }
330 ServerErrorEn::MalformedMetaInfo(None) => write!(f, "Malformed meta info!"),
331 ServerErrorEn::NoSuchHandler(Some(msg)) => write!(f, "No such handler: {}", msg),
332 ServerErrorEn::NoSuchHandler(None) => write!(f, "No such handler!"),
333 InternalError(Some(data)) => {
334 write!(
335 f,
336 "{}",
337 String::from_utf8(data.clone())
338 .unwrap_or_else(|_| "Internal server error!".to_owned())
339 )
340 }
341 InternalError(None) => {
342 write!(f, "Internal server error!")
343 }
344 ServerErrorEn::PayloadLost => {
345 write!(f, "Payload lost!")
346 }
347 }
348 }
349}
350
351impl std::error::Error for ServerErrorEn {}