1use crate::log_macros::{tf_debug, tf_error, tf_info, tf_warn};
2use crate::server::server_router::TfServerRouter;
3use crate::structures::s_type;
4use crate::structures::s_type::{PacketMeta};
5use std::net::SocketAddr;
6use std::ops::Deref;
7use std::sync::Arc;
8
9use tokio::sync::{Mutex, Notify, RwLock};
10
11use crate::codec::codec_trait::TfCodec;
12use crate::server::handler::Handler;
13use crate::structures::traffic_proc::TrafficProcessorHolder;
14use crate::structures::transport::Transport;
15use futures_util::SinkExt;
16use tokio::io;
17use tokio::io::AsyncWriteExt;
18use tokio::net::{TcpListener, TcpStream};
19use tokio::sync::mpsc::{Receiver, Sender};
20use tokio::task::JoinHandle;
21use tokio_rustls::TlsAcceptor;
22use tokio_rustls::rustls::ServerConfig;
23use tokio_util::bytes::{Bytes, BytesMut};
24use tokio_util::codec::Framed;
25
26pub type RequestChannel<C> = (
32 Sender<Arc<Mutex<dyn Handler<Codec = C>>>>,
33 Receiver<Arc<Mutex<dyn Handler<Codec = C>>>>,
34);
35
/// Protocol spoken on top of the (optionally TLS-wrapped) TCP connection.
///
/// A plain, field-less enum, so it is `Copy` and comparable; `Debug` makes it loggable.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ServerMode {
    /// Frames are read and written directly on the TCP (or TLS) stream.
    Tcp,
    /// A WebSocket handshake is performed first; the resulting WebSocket transport
    /// then carries the frames.
    WebSocket,
}
43
44pub struct TfServer<C>
50where
51 C: TfCodec,
52{
53 router: Arc<TfServerRouter<C>>,
54 socket: Arc<TcpListener>,
55 shutdown_sig: Arc<Notify>,
56 processor: Option<TrafficProcessorHolder<C>>,
57 codec: C,
58 config: Option<ServerConfig>,
59 mode: ServerMode,
60}
61
62impl<C> TfServer<C>
63where
64 C: TfCodec,
65{
66 pub async fn new(
70 bind_address: String,
71 router: Arc<TfServerRouter<C>>,
72 processor: Option<TrafficProcessorHolder<C>>,
73 codec: C,
74 config: Option<ServerConfig>,
75 mode: ServerMode,
76 ) -> Result<Self, io::Error> {
77 let socket = TcpListener::bind(&bind_address).await.map_err(|e| {
78 tf_error!("Failed to bind to {}: {}", bind_address, e);
79 e
80 })?;
81 tf_info!("Server bound to {}", bind_address);
82 Ok(Self {
83 router,
84 socket: Arc::new(socket),
85 shutdown_sig: Arc::new(Notify::new()),
86 processor,
87 codec,
88 config,
89 mode,
90 })
91 }
92
93 pub async fn start(&mut self) -> JoinHandle<()> {
97 let (listener, router, shutdown_sig) = (
98 self.socket.clone(),
99 self.router.clone(),
100 self.shutdown_sig.clone(),
101 );
102 let mut processor = if let Some(proc) = self.processor.take() {
103 proc
104 } else {
105 TrafficProcessorHolder::new()
106 };
107 let codec = self.codec.clone();
108 let config = self.config.clone();
109 let mode = self.mode.clone();
110
111 tokio::spawn(async move {
112 loop {
113 tokio::select! {
114 res = listener.accept() => {
115 match res {
116 Ok((stream, addr)) => {
117 tf_debug!("Accepted connection from {}", addr);
118 let _ = stream.set_nodelay(true);
119 let codec = codec.clone();
120 let mode = mode.clone();
121 let transport = Self::initial_accept(stream, config.clone(), codec, &mode).await;
122
123 if let Some(mut transport) = transport {
124 if processor.initial_connect(&mut transport.0).await {
125 let mut framed = Framed::new(transport.0, transport.1);
126 if processor.initial_framed_connect(&mut framed).await {
127 let router = router.clone();
128 let prc_clone = processor.clone();
129 tokio::spawn(async move {
130 Self::handle_connection(addr, framed, router.as_ref(), prc_clone).await;
131 });
132 } else {
133 tf_warn!("Framed processor rejected connection from {}", addr);
134 }
135 } else {
136 tf_warn!("Processor rejected connection from {}", addr);
137 let _ = transport.0.shutdown().await;
138 }
139 }
140 }
141 Err(e) => {
142 tf_warn!("Accept error: {}", e);
143 }
144 }
145 }
146 _ = shutdown_sig.notified() => {
147 tf_info!("Server shutting down");
148 break;
149 }
150 }
151 }
152 })
153 }
154
155 async fn initial_accept(
156 stream: TcpStream,
157 config: Option<ServerConfig>,
158 mut codec_setup: C,
159 mode: &ServerMode,
160 ) -> Option<(Transport, C)> {
161 let transport = match &config {
162 None => Transport::plain(stream),
163 Some(cfg) => {
164 let acceptor = TlsAcceptor::from(Arc::new(cfg.clone()));
165 match acceptor.accept(stream).await {
166 Ok(tls) => Transport::tls_server(tls),
167 Err(e) => {
168 tf_warn!("TLS handshake failed: {}", e);
169 return None;
170 }
171 }
172 }
173 };
174
175 let mut transport = match mode {
176 ServerMode::Tcp => transport,
177 ServerMode::WebSocket => match Transport::accept_websocket(transport).await {
178 Ok(ws_stream) => ws_stream,
179 Err(e) => {
180 tf_warn!("WebSocket handshake failed: {}", e);
181 return None;
182 }
183 },
184 };
185
186 if !codec_setup.initial_setup(&mut transport).await {
187 tf_warn!("Codec initial_setup rejected connection");
188 return None;
189 }
190
191 Some((transport, codec_setup))
192 }
193
194 pub fn send_stop(&self) {
196 self.shutdown_sig.notify_waiters();
197 }
198
199 async fn handle_connection(
200 addr: SocketAddr,
201 mut stream: Framed<Transport, C>,
202 router: &TfServerRouter<C>,
203 mut processor: TrafficProcessorHolder<C>,
204 ) {
205 use futures_util::SinkExt;
206 let move_sig = tokio::sync::oneshot::channel::<Arc<RwLock<dyn Handler<Codec = C>>>>();
207 let mut move_sig = (Some(move_sig.0), move_sig.1);
208 loop {
209 let meta_data: Result<Option<BytesMut>, bool> =
210 Self::receive_message(addr, &mut stream, &mut processor).await;
211 if meta_data.is_err() {
212 if meta_data.unwrap_err() {
213 stream.close().await.unwrap_or_else(|e| {
214 tf_warn!("Error closing stream for {}: {}", addr, e);
215 });
216 return;
217 }
218 continue;
219 }
220
221 let meta_data = meta_data.unwrap();
222 if meta_data.is_none() {
223 continue;
224 }
225 let meta_data = meta_data.unwrap();
226 let has_payload = match s_type::from_slice::<PacketMeta>(meta_data.deref()) {
227 Ok(meta) => meta.has_payload,
228 Err(e) => {
229 tf_warn!("Failed to deserialize PacketMeta from {}: {}", addr, e);
230 false
231 }
232 };
233
234 let mut payload: BytesMut = BytesMut::new();
235 if has_payload {
236 let payload_res =
237 Self::receive_message(addr, &mut stream, &mut processor).await;
238 if payload_res.is_err() {
239 if payload_res.unwrap_err() {
240 stream.close().await.unwrap_or_else(|e| {
241 tf_warn!("Error closing stream for {}: {}", addr, e);
242 });
243 return;
244 }
245 continue;
246 }
247 let payload_opt = payload_res.unwrap();
248 if payload_opt.is_none() {
249 let _ = stream.close().await;
250 return;
251 }
252 payload = payload_opt.unwrap();
253 }
254 let res = router
255 .serve_packet(meta_data, payload, (addr, &mut move_sig.0))
256 .await;
257
258 let message = res.unwrap_or_else(|err| s_type::to_vec(&err).unwrap());
259 let res = Self::send_message(&mut stream, message, &mut processor).await;
260
261 if let Ok(requester) = move_sig.1.try_recv() {
262 requester
263 .write()
264 .await
265 .accept_stream(addr, (stream, processor.clone()))
266 .await;
267 return;
268 }
269
270 if let Err(e) = res {
271 tf_warn!("Send error for {}: {}", addr, e);
272 let _ = stream.close();
273 return;
274 }
275 }
276 }
277
278 async fn send_message(
279 stream: &mut Framed<Transport, C>,
280 message: Vec<u8>,
281 processor: &mut TrafficProcessorHolder<C>,
282 ) -> Result<(), io::Error> {
283 let message = Bytes::from(processor.post_process_traffic(message).await);
284 stream.send(message).await
285 }
286
287 async fn receive_message(
288 addr: SocketAddr,
289 stream: &mut Framed<Transport, C>,
290 processor: &mut TrafficProcessorHolder<C>,
291 ) -> Result<Option<BytesMut>, bool> {
292 use futures_util::StreamExt;
293 match stream.next().await {
294 Some(data) => match data {
295 Ok(mut data) => {
296 data = processor.pre_process_traffic(data).await;
297 Ok(Some(data))
298 }
299 Err(e) => match e.kind() {
300 std::io::ErrorKind::ConnectionReset
301 | std::io::ErrorKind::ConnectionAborted
302 | std::io::ErrorKind::BrokenPipe
303 | std::io::ErrorKind::UnexpectedEof => {
304 tf_debug!("Client {} disconnected", addr);
305 Err(true)
306 }
307 std::io::ErrorKind::InvalidData => {
308 tf_warn!("Frame exceeded maximum size from {}: {}", addr, e);
309 Err(false)
310 }
311 _ => {
312 tf_warn!("IO error reading frame from {}: {}", addr, e);
313 Err(false)
314 }
315 },
316 },
317 None => Err(true),
318 }
319 }
320}