tiny_web/sys/
go.rs

1use std::{
2    collections::BTreeMap,
3    io::ErrorKind,
4    process,
5    sync::{
6        atomic::{AtomicBool, Ordering},
7        Arc,
8    },
9    time::Duration,
10};
11
12use tokio::{
13    io::{AsyncReadExt, AsyncWriteExt},
14    net::{TcpListener, TcpStream},
15    runtime::Builder,
16    sync::{oneshot, Mutex},
17    task::JoinHandle,
18    time,
19};
20
21#[cfg(debug_assertions)]
22use tokio::sync::RwLock;
23
24use super::{
25    action::ActMap,
26    cache::CacheSys,
27    dbs::adapter::DB,
28    html::Html,
29    init::{AcceptAddr, Addr, Config, Init},
30    lang::Lang,
31    log::Log,
32    mail::Mail,
33    worker::{Worker, WorkerData},
34};
35
36/// Server management
37pub(crate) struct Go;
38
39impl Go {
40    /// Run server in tokio runtime
41    pub fn run(init: &Init, func: &impl Fn() -> ActMap) {
42        let runtime = match Builder::new_multi_thread().worker_threads(init.conf.max).enable_all().build() {
43            Ok(r) => r,
44            Err(e) => {
45                Log::stop(1, Some(e.to_string()));
46                return;
47            }
48        };
49
50        // Start tokio runtime
51        runtime.block_on(async move {
52            let stop = Arc::new(AtomicBool::new(false));
53
54            // Start listening to incoming clients
55            if let Some(main) = Go::listen(init, Arc::clone(&stop), func).await {
56                if !main.is_finished() {
57                    Go::listen_rpc(&init.conf, stop, main).await;
58                }
59            };
60        });
61    }
62
63    /// Listens for clients on the bind port
64    ///
65    /// # Return
66    ///
67    /// `None` - The server cannot listen on the bind port
68    /// `Some(JoinHandle)` - Handler for main tokio thread
69    async fn listen(init: &Init, stop: Arc<AtomicBool>, func: &impl Fn() -> ActMap) -> Option<JoinHandle<()>> {
70        // Open bind port
71        let bind = match &init.conf.bind {
72            Addr::SocketAddr(a) => TcpListener::bind(a).await,
73            #[cfg(not(target_family = "windows"))]
74            Addr::Uds(s) => TcpListener::bind(s).await,
75        };
76        let bind = match bind {
77            Ok(i) => i,
78            Err(e) => {
79                Log::stop(500, Some(e.to_string()));
80                return None;
81            }
82        };
83        let root_path = Arc::clone(&init.root_path);
84        let db = Arc::clone(&init.conf.db);
85        let lang = Arc::clone(&init.conf.lang);
86        let bind_accept = Arc::clone(&init.conf.bind_accept);
87        let session_key = Arc::clone(&init.conf.session);
88        let salt = Arc::clone(&init.conf.salt);
89        let engine_data = func();
90
91        let action_index = Arc::clone(&init.conf.action_index);
92        let action_not_found = Arc::clone(&init.conf.action_not_found);
93        let action_err = Arc::clone(&init.conf.action_err);
94
95        let max = db.max;
96        let mut db = DB::new(max, db).await?;
97
98        let signal_stop = if db.in_use() { None } else { Some((Arc::clone(&init.conf.rpc), init.conf.stop_signal)) };
99
100        #[cfg(feature = "https")]
101        let acceptor = match Worker::load_cert(Arc::clone(&root_path)) {
102            Ok(acceptor) => acceptor,
103            Err(e) => {
104                Log::stop(507, Some(e.to_string()));
105                return None;
106            }
107        };
108
109        let main = tokio::spawn(async move {
110            #[cfg(not(debug_assertions))]
111            let lang = Arc::new(Lang::new(Arc::clone(&root_path), &lang, &mut db).await);
112
113            #[cfg(debug_assertions)]
114            let lang = Arc::new(RwLock::new(Lang::new(Arc::clone(&root_path), &lang, &mut db).await));
115
116            #[cfg(not(debug_assertions))]
117            let html = Arc::new(Html::new(Arc::clone(&root_path)).await);
118            #[cfg(debug_assertions)]
119            let html = Arc::new(RwLock::new(Html::new(Arc::clone(&root_path)).await));
120
121            let cache = CacheSys::new().await;
122            let engine = Arc::new(engine_data);
123
124            let db = Arc::new(db);
125            let session_key = Arc::clone(&session_key);
126            let salt = Arc::clone(&salt);
127            let mail = Arc::new(Mutex::new(Mail::new(Arc::clone(&db)).await));
128
129            let action_index = Arc::clone(&action_index);
130            let action_not_found = Arc::clone(&action_not_found);
131            let action_err = Arc::clone(&action_err);
132
133            let signal_stop = match signal_stop {
134                Some((ref rpc, stop)) => Some((Arc::clone(rpc), stop)),
135                None => None,
136            };
137
138            let root_path = Arc::clone(&root_path);
139            #[cfg(feature = "https")]
140            let acceptor = Arc::clone(&acceptor);
141
142            // Started (accepted) threads
143            let handles = Arc::new(Mutex::new(BTreeMap::new()));
144            let mut counter: u64 = 0;
145            loop {
146                let (stream, addr) = match bind.accept().await {
147                    Ok((stream, addr)) => (stream, addr),
148                    Err(e) => {
149                        // Check no critical error
150                        match e.kind() {
151                            ErrorKind::ConnectionRefused
152                            | ErrorKind::ConnectionReset
153                            | ErrorKind::Interrupted
154                            | ErrorKind::TimedOut
155                            | ErrorKind::WouldBlock
156                            | ErrorKind::UnexpectedEof => continue,
157                            _ => {
158                                Log::stop(504, Some(e.to_string()));
159                                break;
160                            }
161                        }
162                    }
163                };
164                // Check stop signal
165                if stop.load(Ordering::Relaxed) {
166                    break;
167                }
168
169                let (tx, rx) = oneshot::channel();
170
171                let lang = Arc::clone(&lang);
172                let html = Arc::clone(&html);
173                let cache = Arc::clone(&cache);
174                let engine = Arc::clone(&engine);
175                let db = Arc::clone(&db);
176                let bind_accept = Arc::clone(&bind_accept);
177                let session_key = Arc::clone(&session_key);
178                let salt = Arc::clone(&salt);
179                let mail = Arc::clone(&mail);
180                let action_index = Arc::clone(&action_index);
181                let action_not_found = Arc::clone(&action_not_found);
182                let action_err = Arc::clone(&action_err);
183                let signal_stop = match signal_stop {
184                    Some((ref rpc, stop)) => Some((Arc::clone(rpc), stop)),
185                    None => None,
186                };
187                let root_path = Arc::clone(&root_path);
188                #[cfg(feature = "https")]
189                let acceptor = Arc::clone(&acceptor);
190
191                let handle = tokio::spawn(async move {
192                    let id = counter;
193                    if let Err(e) = stream.set_nodelay(true) {
194                        Log::warning(506, Some(e.to_string()));
195                        return;
196                    }
197                    // Check accept ip
198                    if let AcceptAddr::IpAddr(ip) = &*bind_accept {
199                        if &addr.ip() != ip {
200                            Log::warning(501, Some(addr.ip().to_string()));
201                            return;
202                        }
203                    }
204
205                    // Starting one main thread from the client connection
206                    let data = WorkerData {
207                        engine,
208                        lang,
209                        html,
210                        cache,
211                        db,
212                        session_key,
213                        salt,
214                        mail,
215                        action_index,
216                        action_not_found,
217                        action_err,
218                        stop: signal_stop,
219                        root: root_path,
220                        #[cfg(any(feature = "http", feature = "https"))]
221                        ip: addr.ip(),
222                        #[cfg(feature = "https")]
223                        acceptor,
224                    };
225                    Worker::run(stream, data).await;
226                    if let Err(i) = tx.send(id) {
227                        Log::error(502, Some(i.to_string()));
228                    }
229                });
230                let handles_clone = Arc::clone(&handles);
231                // Handle the termination of the main thread from the client connection
232                tokio::spawn(async move {
233                    handles_clone.lock().await.insert(counter, handle);
234                    if let Ok(id) = rx.await {
235                        handles_clone.lock().await.remove(&id);
236                    };
237                });
238                counter += 1;
239                // Check stop signal
240                if stop.load(Ordering::Relaxed) {
241                    break;
242                }
243            }
244
245            for (_, handle) in handles.lock().await.iter() {
246                handle.abort()
247            }
248            for (_, handle) in handles.lock().await.iter_mut() {
249                if let Err(e) = handle.await {
250                    if !e.is_cancelled() {
251                        Log::stop(505, Some(e.to_string()));
252                    }
253                }
254            }
255        });
256        Some(main)
257    }
258
259    /// Listens for rcp connection
260    async fn listen_rpc(conf: &Config, stop: Arc<AtomicBool>, main: JoinHandle<()>) {
261        // Open rpc port
262        let rpc = match conf.rpc.as_ref() {
263            Addr::SocketAddr(a) => TcpListener::bind(a).await,
264            #[cfg(not(target_family = "windows"))]
265            Addr::Uds(s) => TcpListener::bind(s).await,
266        };
267        let rpc = match rpc {
268            Ok(i) => i,
269            Err(e) => {
270                Log::stop(202, Some(e.to_string()));
271                return;
272            }
273        };
274        loop {
275            // accept rpc
276            let (mut stream, addr) = match rpc.accept().await {
277                Ok(acpt) => acpt,
278                Err(e) => {
279                    Log::warning(231, Some(e.to_string()));
280                    continue;
281                }
282            };
283            if let AcceptAddr::IpAddr(ip) = conf.rpc_accept {
284                if addr.ip() != ip {
285                    Log::warning(203, Some(addr.ip().to_string()));
286                    continue;
287                }
288            }
289            if let Err(e) = stream.set_nodelay(true) {
290                Log::warning(219, Some(e.to_string()));
291                continue;
292            }
293            // read stop key
294            let signal = stream.read_i64();
295            let signal = match time::timeout(Duration::from_secs(2), signal).await {
296                Ok(signal) => match signal {
297                    Ok(signal) => signal,
298                    Err(e) => {
299                        Log::warning(205, Some(e.to_string()));
300                        continue;
301                    }
302                },
303                Err(_) => {
304                    Log::warning(204, None);
305                    continue;
306                }
307            };
308            if signal == conf.stop_signal {
309                // set stop
310                stop.store(true, Ordering::Relaxed);
311                // push current thread id
312                Log::info(207, None);
313                let pid = process::id() as u64;
314                if let Err(e) = stream.write_u64(pid).await {
315                    Log::warning(215, Some(e.to_string()));
316                }
317                // send stop signal
318                Go::send_stop(&conf.bind).await;
319                // wait all threads stop
320                if let Err(e) = main.await {
321                    Log::stop(220, Some(e.to_string()));
322                }
323                break;
324            } else if signal == conf.status_signal {
325                Log::info(227, None);
326                let pid = process::id() as u64;
327                if let Err(e) = stream.write_u64(pid).await {
328                    Log::warning(215, Some(e.to_string()));
329                } else if let Err(e) = stream.write_all("Working...".as_bytes()).await {
330                    Log::warning(215, Some(e.to_string()));
331                }
332            } else {
333                Log::warning(206, Some(signal.to_string()));
334            }
335        }
336    }
337
338    /// Send stop signal to bind port
339    async fn send_stop(addr: &Addr) {
340        #[allow(clippy::infallible_destructuring_match)]
341        match addr {
342            Addr::SocketAddr(s) => match time::timeout(Duration::from_secs(1), TcpStream::connect(s)).await {
343                Ok(stream) => {
344                    if let Err(e) = stream {
345                        Log::warning(222, Some(e.to_string()));
346                    }
347                }
348                Err(_) => {
349                    Log::warning(221, None);
350                }
351            },
352            #[cfg(not(target_family = "windows"))]
353            Addr::Uds(s) => match time::timeout(Duration::from_secs(1), TcpStream::connect(s)).await {
354                Ok(stream) => {
355                    if let Err(e) = stream {
356                        Log::warning(222, Some(e.to_string()));
357                    }
358                }
359                Err(_) => {
360                    Log::warning(221, None);
361                }
362            },
363        }
364    }
365}