1use crate::server::server_router::TfServerRouter;
2use crate::structures::s_type;
3use crate::structures::s_type::ServerErrorEn::InternalError;
4use crate::structures::s_type::{PacketMeta, ServerErrorEn};
5use std::fmt;
6use std::net::SocketAddr;
7use std::ops::Deref;
8use std::sync::Arc;
9
10use tokio::sync::{Mutex, Notify, RwLock};
11
12use crate::codec::codec_trait::TfCodec;
13use crate::server::handler::Handler;
14use crate::structures::traffic_proc::TrafficProcessorHolder;
15use crate::structures::transport::Transport;
16use futures_util::SinkExt;
17use tokio::io;
18use tokio::io::AsyncWriteExt;
19use tokio::net::{TcpListener, TcpStream};
20use tokio::sync::mpsc::{Receiver, Sender};
21use tokio::task::JoinHandle;
22use tokio_rustls::TlsAcceptor;
23use tokio_rustls::rustls::ServerConfig;
24use tokio_util::bytes::{Bytes, BytesMut};
25use tokio_util::codec::Framed;
26
27pub type RequestChannel<C> = (
33 Sender<Arc<Mutex<dyn Handler<Codec = C>>>>,
34 Receiver<Arc<Mutex<dyn Handler<Codec = C>>>>,
35);
36
/// Selects what the server layers on top of each accepted TCP (or TLS) stream
/// before the codec takes over.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ServerMode {
    /// Frames are exchanged directly over the (optionally TLS-wrapped) stream.
    Tcp,
    /// A WebSocket handshake is performed on the stream before codec setup.
    WebSocket,
}
44
45pub struct TfServer<C>
52where
53 C: TfCodec,
54{
55 router: Arc<TfServerRouter<C>>,
56 socket: Arc<TcpListener>,
57 shutdown_sig: Arc<Notify>,
58 processor: Option<TrafficProcessorHolder<C>>,
59 codec: C,
60 config: Option<ServerConfig>,
61 mode: ServerMode,
62}
63
64impl<C> TfServer<C>
65where
66 C: TfCodec,
67{
68 pub async fn new(
76 bind_address: String,
77 router: Arc<TfServerRouter<C>>,
78 processor: Option<TrafficProcessorHolder<C>>,
79 codec: C,
80 config: Option<ServerConfig>,
81 mode: ServerMode,
82 ) -> Self {
83 Self {
84 router,
85 socket: Arc::new(
86 TcpListener::bind(&bind_address)
87 .await
88 .expect("Failed to bind to address"),
89 ),
90 shutdown_sig: Arc::new(Notify::new()),
91 processor,
92 codec,
93 config,
94 mode,
95 }
96 }
97
98 pub async fn start(&mut self) -> JoinHandle<()> {
102 let (listener, router, shutdown_sig) = (
103 self.socket.clone(),
104 self.router.clone(),
105 self.shutdown_sig.clone(),
106 );
107 let mut processor = if let Some(proc) = self.processor.take() {
108 proc
109 } else {
110 TrafficProcessorHolder::new()
111 };
112 let codec = self.codec.clone();
113 let config = self.config.clone();
114 let mode = self.mode.clone(); tokio::spawn(async move {
117 loop {
118 tokio::select! {
119 res = listener.accept() => {
120 if let Ok((stream, addr)) = res {
121 let _ = stream.set_nodelay(true);
122 let codec = codec.clone();
123 let mode = mode.clone();
124 let transport = Self::initial_accept(stream, config.clone(), codec, &mode).await;
125
126 if let Some(mut transport) = transport {
127 if processor.initial_connect(&mut transport.0).await {
128 let mut framed = Framed::new(transport.0, transport.1);
129 if processor.initial_framed_connect(&mut framed).await {
130 let router = router.clone();
131 let prc_clone = processor.clone();
132 tokio::spawn(async move {
133 Self::handle_connection(addr, framed, router.as_ref(), prc_clone).await;
134 });
135 }
136 } else {
137 let _ = transport.0.shutdown().await;
138 }
139 }
140 }
141 }
142 _ = shutdown_sig.notified() => break,
143 }
144 }
145 })
146 }
147
148 async fn initial_accept(
150 stream: TcpStream,
151 config: Option<ServerConfig>,
152 mut codec_setup: C,
153 mode: &ServerMode,
154 ) -> Option<(Transport, C)> {
155 let transport = match &config {
156 None => Transport::plain(stream),
157 Some(cfg) => {
158 let acceptor = TlsAcceptor::from(Arc::new(cfg.clone()));
159 match acceptor.accept(stream).await {
160 Ok(tls) => Transport::tls_server(tls),
161 Err(_) => return None,
162 }
163 }
164 };
165
166 let mut transport = match mode {
167 ServerMode::Tcp => transport,
168 ServerMode::WebSocket => match Transport::accept_websocket(transport).await {
169 Ok(ws_stream) => ws_stream,
170 Err(e) => {
171 eprintln!("WebSocket handshake failed: {e}");
172 return None;
173 }
174 },
175 };
176
177 if !codec_setup.initial_setup(&mut transport).await {
178 return None;
179 }
180
181 Some((transport, codec_setup))
182 }
183 pub fn send_stop(&self) {
185 self.shutdown_sig.notify_waiters();
186 }
187
188 async fn handle_connection(
190 addr: SocketAddr,
191 mut stream: Framed<Transport, C>,
192 router: &TfServerRouter<C>,
193 mut processor: TrafficProcessorHolder<C>,
194 ) {
195 use futures_util::SinkExt;
196 let move_sig = tokio::sync::oneshot::channel::<Arc<RwLock<dyn Handler<Codec = C>>>>();
197 let mut move_sig = (Some(move_sig.0), move_sig.1);
198 loop {
199 let meta_data: Result<Option<BytesMut>, bool> =
200 Self::receive_message(addr.clone(), &mut stream, &mut processor).await;
201 if meta_data.is_err() {
202 if meta_data.unwrap_err() {
203 stream.close().await.unwrap();
204 return;
205 }
206 continue;
207 }
208
209 let meta_data = meta_data.unwrap();
210 if meta_data.is_none() {
211 continue;
212 }
213 let meta_data = meta_data.unwrap();
214 let has_payload = match s_type::from_slice::<PacketMeta>(meta_data.deref()) {
215 Ok(meta) => meta.has_payload,
216 Err(_) => false,
217 };
218
219 let mut payload: BytesMut = BytesMut::new();
220 if has_payload {
221 let payload_res =
222 Self::receive_message(addr.clone(), &mut stream, &mut processor).await;
223 if payload_res.is_err() {
224 if payload_res.unwrap_err() {
225 stream.close().await.unwrap();
226 return;
227 }
228 continue;
229 }
230 let payload_opt = payload_res.unwrap();
231 if payload_opt.is_none() {
232 let _ = stream.close().await;
233 return;
234 }
235 payload = payload_opt.unwrap();
236 }
237 let res = router
238 .serve_packet(meta_data, payload, (addr, &mut move_sig.0))
239 .await;
240
241 let message = res.unwrap_or_else(|err| s_type::to_vec(&err).unwrap());
242 let res = Self::send_message(&mut stream, message, &mut processor).await;
243
244 if let Ok(requester) = move_sig.1.try_recv() {
245 requester
246 .write()
247 .await
248 .accept_stream(addr, (stream, processor.clone()))
249 .await;
250 return;
251 }
252
253 match res {
254 Err(_) => {
255 let _ = stream.close();
256 return;
257 }
258 _ => {}
259 }
260 }
261 }
262 async fn send_message(
263 stream: &mut Framed<Transport, C>,
264 message: Vec<u8>,
265 processor: &mut TrafficProcessorHolder<C>,
266 ) -> Result<(), io::Error> {
267 let message = Bytes::from(processor.post_process_traffic(message).await);
268 stream.send(message).await
269 }
270
271 async fn receive_message(
272 _: SocketAddr,
273 stream: &mut Framed<Transport, C>,
274 processor: &mut TrafficProcessorHolder<C>,
275 ) -> Result<Option<BytesMut>, bool> {
276 use futures_util::StreamExt;
277 match stream.next().await {
278 Some(data) => match data {
279 Ok(mut data) => {
280 data = processor.pre_process_traffic(data).await;
281 return Ok(Some(data));
282 }
283 Err(e) => {
284 match e.kind() {
286 std::io::ErrorKind::ConnectionReset
288 | std::io::ErrorKind::ConnectionAborted
289 | std::io::ErrorKind::BrokenPipe
290 | std::io::ErrorKind::UnexpectedEof => {
291 println!("Client disconnected");
292 return Err(true);
293 }
294
295 std::io::ErrorKind::InvalidData => {
297 eprintln!("Frame exceeded maximum size: {e}");
298 return Err(false);
299 }
300
301 _ => {
303 eprintln!("IO error while reading frame: {e}");
304 return Err(false);
305 }
306 }
307 }
308 },
309 None => {
310 return Err(true);
311 }
312 }
313 }
314}
315
316impl fmt::Display for ServerErrorEn {
318 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
319 match self {
320 ServerErrorEn::MalformedMetaInfo(Some(msg)) => {
321 write!(f, "Malformed meta info: {}", msg)
322 }
323 ServerErrorEn::MalformedMetaInfo(None) => write!(f, "Malformed meta info!"),
324 ServerErrorEn::NoSuchHandler(Some(msg)) => write!(f, "No such handler: {}", msg),
325 ServerErrorEn::NoSuchHandler(None) => write!(f, "No such handler!"),
326 InternalError(Some(data)) => {
327 write!(
328 f,
329 "{}",
330 String::from_utf8(data.clone())
331 .unwrap_or_else(|_| "Internal server error!".to_owned())
332 )
333 }
334 InternalError(None) => {
335 write!(f, "Internal server error!")
336 }
337 ServerErrorEn::PayloadLost => {
338 write!(f, "Payload lost!")
339 }
340 }
341 }
342}
343
344impl std::error::Error for ServerErrorEn {}