1use crate::server::server_router::TcpServerRouter;
2use crate::structures::s_type;
3use crate::structures::s_type::ServerErrorEn::InternalError;
4use crate::structures::s_type::{PacketMeta, ServerErrorEn};
5use std::fmt;
6use std::net::SocketAddr;
7use std::ops::Deref;
8use std::sync::Arc;
9
10use tokio::sync::{Mutex, Notify};
11
12use crate::server::handler::Handler;
13use crate::structures::traffic_proc::TrafficProcessorHolder;
14use crate::structures::transport::Transport;
15use futures_util::{SinkExt};
16use tokio::io;
17use tokio::io::AsyncWriteExt;
18use tokio::net::{TcpListener, TcpStream};
19use tokio::sync::mpsc::{Receiver, Sender};
20use tokio::task::JoinHandle;
21use tokio_rustls::TlsAcceptor;
22use tokio_rustls::rustls::ServerConfig;
23use tokio_util::bytes::{Bytes, BytesMut};
24use tokio_util::codec::{Decoder, Encoder, Framed};
25use crate::codec::codec_trait::TfCodec;
26
27pub type RequestChannel<C>
33 = (
34 Sender<Arc<Mutex<dyn Handler<Codec = C>>>>,
35 Receiver<Arc<Mutex<dyn Handler<Codec = C>>>>,
36);
37
38pub struct TcpServer<C>
45where
46 C: Encoder<Bytes, Error = io::Error>
47 + Decoder<Item = BytesMut, Error = io::Error>
48 + Send
49 + Sync
50 + Clone
51 + 'static + TfCodec,
52{
53 router: Arc<TcpServerRouter<C>>,
54 socket: Arc<TcpListener>,
55 shutdown_sig: Arc<Notify>,
56 processor: Option<TrafficProcessorHolder<C>>,
57 codec: C,
58 config: Option<ServerConfig>,
59}
60
61impl<C> TcpServer<C>
62where
63 C: Encoder<Bytes, Error = io::Error>
64 + Decoder<Item = BytesMut, Error = io::Error>
65 + Send
66 + Sync
67 + Clone
68 + 'static + TfCodec,
69{
70 pub async fn new(
78 bind_address: String,
79 router: Arc<TcpServerRouter<C>>,
80 processor: Option<TrafficProcessorHolder<C>>,
81 codec: C,
82 config: Option<ServerConfig>,
83 ) -> Self {
84 Self {
85 router,
86 socket: Arc::new(
87 TcpListener::bind(&bind_address)
88 .await
89 .expect("Failed to bind to address"),
90 ),
91 shutdown_sig: Arc::new(Notify::new()),
92 processor,
93 codec,
94 config,
95 }
96 }
97
98 pub async fn start(&mut self) -> JoinHandle<()>{
102 let (listener, router, shutdown_sig) = {
103 (
104 self.socket.clone(),
105 self.router.clone(),
106 self.shutdown_sig.clone(),
107 )
108 };
109 let mut processor = if let Some(proc) = self.processor.take() {
110 proc
111 } else {
112 TrafficProcessorHolder::new()
113 };
114 let codec = self.codec.clone();
115 let config = self.config.clone();
116 tokio::spawn(async move {
117 loop {
118 tokio::select! {
119 res = listener.accept() => {
120 if res.is_ok() {
121 let stream = res.unwrap();
122 stream.0.set_nodelay(true).unwrap();
123 let codec = codec.clone();
124
125
126 let transport = Self::initial_accept(stream.0, config.clone(), codec).await;
127 if let Some(mut transport) = transport {
128 if processor.initial_connect(&mut transport.0).await {
129 let mut framed = Framed::new(transport.0, transport.1);
130
131 if processor.initial_framed_connect(&mut framed).await {
132 let router = router.clone();
133 let prc_clone = processor.clone();
134
135 tokio::spawn(async move {
136 Self::handle_connection(
137 stream.1,
138 framed,
139 router.as_ref(),
140 prc_clone,
141 )
142 .await;
143 });
144 }
145 } else {
146 let _ = transport.0.shutdown().await;
147 }
148 }
149
150 }
151 }
152 _ = shutdown_sig.notified() => break,
153 }
154 }
155 })
156 }
157
158 async fn initial_accept(stream: TcpStream, config: Option<ServerConfig>, mut codec_setup: C) -> Option<(Transport, C)> {
160 if config.is_none() {
161 let mut res = Transport::plain(stream);
162 if !codec_setup.initial_setup(&mut res).await {
163 return None;
164 }
165 return Some((res, codec_setup));
166 } else {
167 let cfg = config.unwrap();
168 let acceptor = TlsAcceptor::from(Arc::new(cfg));
169 let res = acceptor.accept(stream).await;
170 if res.is_err() {
171 return None;
172 }
173 let mut res = Transport::tls_server(res.unwrap());
174 if !codec_setup.initial_setup(&mut res).await {
175 return None;
176 }
177 return Some((res, codec_setup));
178 }
179 }
180 pub fn send_stop(&self) {
182 self.shutdown_sig.notify_one();
183 }
184
185 async fn handle_connection(
187 addr: SocketAddr,
188 mut stream: Framed<Transport, C>,
189 router: &TcpServerRouter<C>,
190 mut processor: TrafficProcessorHolder<C>,
191 ) {
192 use futures_util::SinkExt;
193 let move_sig = tokio::sync::oneshot::channel::<Arc<Mutex<dyn Handler<Codec = C>>>>();
194 let mut move_sig = (Some(move_sig.0), move_sig.1);
195 loop {
196 let meta_data: Result<Option<BytesMut>, bool> =
197 Self::receive_message(addr.clone(), &mut stream, &mut processor).await;
198 if meta_data.is_err() {
199 if meta_data.unwrap_err() {
200 stream.close().await.unwrap();
201 return;
202 }
203 continue;
204 }
205
206 let meta_data = meta_data.unwrap();
207 if meta_data.is_none() {
208 continue;
209 }
210 let meta_data = meta_data.unwrap();
211 let has_payload = match s_type::from_slice::<PacketMeta>(meta_data.deref()) {
212 Ok(meta) => meta.has_payload,
213 Err(_) => false,
214 };
215
216 let mut payload: BytesMut = BytesMut::new();
217 if has_payload {
218 let payload_res =
219 Self::receive_message(addr.clone(), &mut stream, &mut processor).await;
220 if payload_res.is_err() {
221 if payload_res.unwrap_err() {
222 stream.close().await.unwrap();
223 return;
224 }
225 continue;
226 }
227 let payload_opt = payload_res.unwrap();
228 if payload_opt.is_none() {
229 let _ = stream.close().await;
230 return;
231 }
232 payload = payload_opt.unwrap();
233 }
234 let res = router
235 .serve_packet(meta_data, payload, (addr, &mut move_sig.0))
236 .await;
237
238 let message = res.unwrap_or_else(|err| s_type::to_vec(&err).unwrap());
239 let res = Self::send_message(&mut stream, message, &mut processor).await;
240
241 if let Ok(requester) = move_sig.1.try_recv() {
242 requester
243 .lock()
244 .await
245 .accept_stream(addr, (stream, processor.clone()))
246 .await;
247 return;
248 }
249
250 match res {
251 Err(_) => {
252 let _ = stream.close();
253 return;
254 }
255 _ => {}
256 }
257 }
258 }
259 async fn send_message(
260 stream: &mut Framed<Transport, C>,
261 message: Vec<u8>,
262 processor: &mut TrafficProcessorHolder<C>,
263 ) -> Result<(), io::Error> {
264 let message = Bytes::from(processor.post_process_traffic(message).await);
265 stream.send(message).await
266 }
267
268 async fn receive_message(
269 _: SocketAddr,
270 stream: &mut Framed<Transport, C>,
271 processor: &mut TrafficProcessorHolder<C>,
272 ) -> Result<Option<BytesMut>, bool> {
273 use futures_util::StreamExt;
274 match stream.next().await {
275 Some(data) => match data {
276 Ok(mut data) => {
277 data = processor.pre_process_traffic(data).await;
278 return Ok(Some(data));
279 }
280 Err(e) => {
281 match e.kind() {
283 std::io::ErrorKind::ConnectionReset
285 | std::io::ErrorKind::ConnectionAborted
286 | std::io::ErrorKind::BrokenPipe
287 | std::io::ErrorKind::UnexpectedEof => {
288 println!("Client disconnected");
289 return Err(true);
290 }
291
292 std::io::ErrorKind::InvalidData => {
294 eprintln!("Frame exceeded maximum size: {e}");
295 return Err(false);
296 }
297
298 _ => {
300 eprintln!("IO error while reading frame: {e}");
301 return Err(false);
302 }
303 }
304 }
305 },
306 None => {
307 return Err(true);
308 }
309 }
310 }
311}
312
313impl fmt::Display for ServerErrorEn {
315 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
316 match self {
317 ServerErrorEn::MalformedMetaInfo(Some(msg)) => {
318 write!(f, "Malformed meta info: {}", msg)
319 }
320 ServerErrorEn::MalformedMetaInfo(None) => write!(f, "Malformed meta info!"),
321 ServerErrorEn::NoSuchHandler(Some(msg)) => write!(f, "No such handler: {}", msg),
322 ServerErrorEn::NoSuchHandler(None) => write!(f, "No such handler!"),
323 InternalError(Some(data)) => {
324 write!(
325 f,
326 "{}",
327 String::from_utf8(data.clone())
328 .unwrap_or_else(|_| "Internal server error!".to_owned())
329 )
330 }
331 InternalError(None) => {
332 write!(f, "Internal server error!")
333 }
334 ServerErrorEn::PayloadLost => {
335 write!(f, "Payload lost!")
336 }
337 }
338 }
339}
340
341impl std::error::Error for ServerErrorEn {}