signal_handler/handler/impl_tokio.rs

1use std::{collections::HashMap, panic};
2
3use tokio::{spawn, sync::mpsc::unbounded_channel};
4
5use crate::{
6    callback::{Callback, CallbackInfo, CallbackType},
7    handler::{builder::Builder, HandleError, Handler},
8    register::RegisterType,
9};
10
11//
12impl Handler {
13    pub async fn handle_async_with_tokio(self) -> Result<(), HandleError> {
14        let Builder {
15            callbacks,
16            registers,
17        } = self.builder;
18
19        //
20        //
21        //
22        let (register_tx, mut register_rx) = unbounded_channel::<RegisterType>();
23
24        let _sig_id_map = registers
25            .register(register_tx)
26            .map_err(HandleError::RegisterFailed)?;
27
28        //
29        //
30        //
31        let mut initialized_cb = None;
32        let mut wait_for_stop_cb = None;
33
34        let mut callback_tx_map = HashMap::new();
35        let mut callback_join_handle_map = HashMap::new();
36
37        for (tp, cb) in callbacks.into_inner() {
38            match tp {
39                CallbackType::Initialized => {
40                    initialized_cb = Some(cb);
41                    continue;
42                }
43                CallbackType::ReloadConfig => {}
44                CallbackType::WaitForStop => {
45                    wait_for_stop_cb = Some(cb);
46                    continue;
47                }
48                CallbackType::PrintStats => {}
49            }
50
51            let (tx, mut rx) = unbounded_channel::<CallbackInfo>();
52
53            let join_handle = spawn(async move {
54                let mut latest_finish_time = None;
55
56                #[allow(clippy::while_let_loop)]
57                loop {
58                    match rx.recv().await {
59                        Some(info) => {
60                            if let Some(latest_finish_time) = latest_finish_time {
61                                if latest_finish_time > *info.time() {
62                                    continue;
63                                }
64                            }
65
66                            match &cb {
67                                Callback::Sync(cb) => cb(info),
68                                Callback::Async(cb) => cb(info).await,
69                            }
70
71                            latest_finish_time = Some(CallbackInfo::time_now());
72                        }
73                        None => {
74                            break;
75                        }
76                    }
77                }
78            });
79
80            callback_tx_map.insert(tp, tx);
81            callback_join_handle_map.insert(tp, join_handle);
82        }
83
84        //
85        //
86        //
87        if let Some(cb) = initialized_cb {
88            match &cb {
89                Callback::Sync(cb) => cb(CallbackInfo::new()),
90                Callback::Async(cb) => cb(CallbackInfo::new()).await,
91            }
92        }
93
94        //
95        //
96        //
97        loop {
98            match register_rx.recv().await {
99                #[cfg(not(windows))]
100                Some(RegisterType::ReloadConfig) => {
101                    if let Some(tx_callback) = callback_tx_map.get(&CallbackType::ReloadConfig) {
102                        #[allow(clippy::single_match)]
103                        match tx_callback.send(CallbackInfo::new()) {
104                            Ok(_) => {}
105                            Err(_) => {
106                                // Ignore, disconnected
107                            }
108                        }
109                    }
110                    continue;
111                }
112                Some(RegisterType::WaitForStop) => {
113                    if let Some(cb) = wait_for_stop_cb {
114                        match &cb {
115                            Callback::Sync(cb) => cb(CallbackInfo::new()),
116                            Callback::Async(cb) => cb(CallbackInfo::new()).await,
117                        }
118                    }
119
120                    drop(register_rx);
121
122                    break;
123                }
124                #[cfg(not(windows))]
125                Some(RegisterType::PrintStats) => {
126                    if let Some(tx_callback) = callback_tx_map.get(&CallbackType::PrintStats) {
127                        #[allow(clippy::single_match)]
128                        match tx_callback.send(CallbackInfo::new()) {
129                            Ok(_) => {}
130                            Err(_) => {
131                                // Ignore, disconnected
132                            }
133                        }
134                    }
135                    continue;
136                }
137                None => break,
138            }
139        }
140
141        //
142        //
143        //
144        for (_, tx) in callback_tx_map {
145            drop(tx);
146        }
147
148        for (_, join_handle) in callback_join_handle_map {
149            match join_handle.await {
150                Ok(_) => {}
151                Err(err) => {
152                    if let Ok(err) = err.try_into_panic() {
153                        panic::resume_unwind(err);
154                    }
155                }
156            }
157        }
158
159        Ok(())
160    }
161}