1#[macro_use]
24extern crate lazy_static;
25
26#[cfg(feature = "futures")]
27mod futures;
28#[cfg(feature = "futures")]
29pub use self::futures::*;
30mod windows;
31use crossbeam_channel;
32use std::collections::{HashMap, LinkedList};
33use std::sync::{Arc, Mutex};
34use std::thread;
35use std::{error, fmt, process};
36use winapi::shared::minwindef::{BOOL, DWORD, FALSE, LPARAM, LRESULT, TRUE, UINT, WPARAM};
37use winapi::shared::windef::HWND;
38use winapi::um::wincon::{CTRL_BREAK_EVENT, CTRL_CLOSE_EVENT, CTRL_C_EVENT};
39use winapi::um::winuser::{DefWindowProcW, WM_CLOSE, WM_QUIT};
40
41pub fn trap<RT: Sized>(
59 signals: &'static [Signal],
60 handler: impl Fn(Signal) + Send + Sync + 'static,
61 body: impl FnOnce() -> RT,
62) -> Result<RT, Error> {
63 let _trap_guard = Trap::new(signals, Arc::new(handler))?;
64 Ok(body())
65}
66
/// An event that [`trap`] can intercept.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Signal {
    /// The console control event `CTRL_C_EVENT`.
    CtrlC,

    /// The console control event `CTRL_BREAK_EVENT`.
    CtrlBreak,

    /// The console control event `CTRL_CLOSE_EVENT`.
    CloseConsole,

    /// A `WM_CLOSE` message received by the window created for the active traps.
    CloseWindow,
}
96
97impl Signal {
98 fn from_console_ctrl_event(event: DWORD) -> Option<Self> {
99 match event {
100 CTRL_C_EVENT => Some(Signal::CtrlC),
101 CTRL_BREAK_EVENT => Some(Signal::CtrlBreak),
102 CTRL_CLOSE_EVENT => Some(Signal::CloseConsole),
103 _ => None,
104 }
105 }
106
107 fn from_window_message(msg: UINT, wparam: WPARAM, _lparam: LPARAM) -> Option<Self> {
108 if msg == WM_CLOSE {
109 Some(Signal::CloseWindow)
110 } else if msg == *WM_CONSOLE_CTRL {
111 Signal::from_console_ctrl_event(wparam as DWORD)
112 } else {
113 None
114 }
115 }
116}
117
118#[derive(Debug)]
121pub enum Error {
122 SetConsoleCtrlHandler(DWORD),
127
128 CreateWindow(DWORD),
136}
137
138impl error::Error for Error {}
139
140impl fmt::Display for Error {
141 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
142 match self {
143 Error::SetConsoleCtrlHandler(code) => write!(
144 f,
145 "Error setting console control handler: {}",
146 windows::format_error(*code).unwrap()
147 ),
148 Error::CreateWindow(code) => write!(
149 f,
150 "Error creating Window: {}",
151 windows::format_error(*code).unwrap()
152 ),
153 }
154 }
155}
156
157lazy_static! {
158 static ref WM_CONSOLE_CTRL: UINT =
159 windows::register_window_message("WINSIG_WM_CONSOLE_CTRL").unwrap();
160 static ref TRAP_STACK: Mutex<TrapStack> = Mutex::new(TrapStack::new());
161 static ref TRAP_OWNER_THREAD_ID: thread::ThreadId = thread::current().id();
162}
163
164struct Trap {
165 signals: &'static [Signal],
166 _phantom: std::marker::PhantomData<std::rc::Rc<u8>>,
167}
168
169impl Trap {
170 fn new(
171 signals: &'static [Signal],
172 handler: Arc<dyn Fn(Signal) + Send + Sync + 'static>,
173 ) -> Result<Self, Error> {
174 assert_eq!(*TRAP_OWNER_THREAD_ID, thread::current().id());
175 let mut trap_stack = TRAP_STACK.lock().unwrap();
176 trap_stack.push_trap(signals, handler)?;
177 Ok(Trap {
178 signals,
179 _phantom: std::marker::PhantomData,
180 })
181 }
182}
183
184impl Drop for Trap {
185 fn drop(&mut self) {
186 let mut trap_stack = TRAP_STACK.lock().unwrap();
187 trap_stack.pop_trap(self.signals.as_ref());
188 }
189}
190
191type TrapCallbacks = HashMap<Signal, LinkedList<Arc<dyn Fn(Signal) + Send + Sync + 'static>>>;
192
193struct TrapStack {
194 num_traps: usize,
195 trap_thread_data: Option<TrapThreadData>,
196 callbacks: TrapCallbacks,
197}
198
199impl TrapStack {
200 fn new() -> TrapStack {
201 TrapStack {
202 num_traps: 0,
203 trap_thread_data: None,
204 callbacks: HashMap::new(),
205 }
206 }
207
208 fn increment_trap_count(&mut self) -> Result<(), Error> {
209 self.num_traps += 1;
210 if self.num_traps == 1 {
211 self.trap_thread_data = Some(TrapThreadData::new()?);
213 }
214 Ok(())
215 }
216
217 fn decrement_trap_count(&mut self) {
218 self.num_traps -= 1;
219 if self.num_traps == 0 {
220 self.trap_thread_data = None;
222 }
223 }
224
225 fn push_trap(
226 &mut self,
227 signals: &[Signal],
228 handler: Arc<dyn Fn(Signal) + Send + Sync + 'static>,
229 ) -> Result<(), Error> {
230 self.increment_trap_count()?;
231 for signal in signals.iter() {
232 self.callbacks
233 .entry(*signal)
234 .or_insert_with(LinkedList::new)
235 .push_back(handler.clone());
236 }
237 Ok(())
238 }
239
240 fn pop_trap(&mut self, signals: &[Signal]) {
241 self.decrement_trap_count();
242 for signal in signals.iter() {
243 let callbacks = self.callbacks.get_mut(signal).unwrap();
244 callbacks.pop_back().unwrap();
245 if callbacks.is_empty() {
246 self.callbacks.remove(signal);
247 }
248 }
249 }
250
251 fn has_handler_for(&self, signal: Signal) -> bool {
252 self.callbacks.contains_key(&signal)
253 }
254
255 fn exit_if_only_window(&self) {
256 if let Some(ref trap_thread_data) = self.trap_thread_data {
257 struct EnumWindowsData {
260 hwnd: HWND,
261 process_id: DWORD,
262 }
263 let enum_windows_data = EnumWindowsData {
264 hwnd: trap_thread_data.window_handle.hwnd,
265 process_id: process::id(),
266 };
267 unsafe extern "system" fn enum_windows_proc(hwnd: HWND, lparam: LPARAM) -> BOOL {
268 let enum_windows_data = &*(lparam as *const EnumWindowsData);
269 if enum_windows_data.hwnd == hwnd {
270 TRUE
271 } else {
272 let (_, process_id) = windows::get_window_thread_process_id(hwnd);
273 if enum_windows_data.process_id == process_id {
274 FALSE
275 } else {
276 TRUE
277 }
278 }
279 }
280 if !windows::enum_windows(
283 enum_windows_proc,
284 (&enum_windows_data as *const EnumWindowsData) as LPARAM,
285 ) {
286 process::exit(0);
287 }
288 } else {
289 unreachable!();
290 }
291 }
292}
293
294struct TrapThreadData {
295 thread: Option<thread::JoinHandle<()>>,
296 thread_id: DWORD,
297 window_handle: windows::WindowHandle,
298}
299
300impl TrapThreadData {
301 fn new() -> Result<TrapThreadData, Error> {
302 windows::set_console_ctrl_handler(console_ctrl_handler, true)
304 .map_err(Error::SetConsoleCtrlHandler)?;
305
306 let (s, r) = crossbeam_channel::bounded(2);
308 let thread = Some(thread::spawn(move || {
309 s.send(windows::get_current_thread_id() as usize).unwrap();
310 let mut window = windows::Window::new(window_proc).unwrap();
311 s.send(window.hwnd as usize).unwrap();
312 window
313 .run_event_loop(|&msg| {
314 if let Some(signal) =
315 Signal::from_window_message(msg.message, msg.wParam, msg.lParam)
316 {
317 let trap_stack = TRAP_STACK.lock().unwrap();
318 if let Some(callback_list) = trap_stack.callbacks.get(&signal) {
319 callback_list.back().unwrap()(signal);
320 } else if msg.message == WM_CLOSE {
321 trap_stack.exit_if_only_window();
323 }
324 }
325 })
326 .unwrap();
327 }));
328 let thread_id = r.recv().unwrap() as DWORD;
329 let hwnd = r.recv().unwrap() as HWND;
330 Ok(TrapThreadData {
331 thread,
332 thread_id,
333 window_handle: windows::WindowHandle { hwnd },
334 })
335 }
336
337 fn enqueue_ctrl_event(&self, event: DWORD) -> Result<(), DWORD> {
338 windows::post_message(self.window_handle, *WM_CONSOLE_CTRL, event as WPARAM, 0)
339 }
340}
341
342impl Drop for TrapThreadData {
343 fn drop(&mut self) {
344 windows::set_console_ctrl_handler(console_ctrl_handler, false).unwrap();
345 windows::post_thread_message(self.thread_id, WM_QUIT, 0, 0).unwrap();
346 self.thread.take().unwrap().join().unwrap();
347 }
348}
349
350unsafe extern "system" fn console_ctrl_handler(event: DWORD) -> BOOL {
351 match Signal::from_console_ctrl_event(event) {
352 Some(signal) => {
353 let trap_stack = TRAP_STACK.lock().unwrap();
354 if trap_stack.has_handler_for(signal) {
355 match trap_stack.trap_thread_data {
358 Some(ref trap_thread_data) => {
359 match trap_thread_data.enqueue_ctrl_event(event) {
360 Ok(_) => TRUE,
361 Err(_) => FALSE,
362 }
363 }
364 None => FALSE,
365 }
366 } else {
367 FALSE
368 }
369 }
370 None => FALSE,
371 }
372}
373
374unsafe extern "system" fn window_proc(
375 hwnd: HWND,
376 msg: UINT,
377 wparam: WPARAM,
378 lparam: LPARAM,
379) -> LRESULT {
380 if msg == WM_CLOSE || msg == *WM_CONSOLE_CTRL {
388 0
389 } else {
390 DefWindowProcW(hwnd, msg, wparam, lparam)
391 }
392}
393
#[cfg(test)]
mod tests {
    use super::*;

    static_assertions::assert_not_impl_any!(Trap: Send, Sync);

    // `Trap::new` asserts that every trap is created on the thread that first evaluated
    // `TRAP_OWNER_THREAD_ID`, while libtest runs tests on separate threads by default. With
    // one `#[test]` per scenario, whichever ran second would panic on that assertion, so all
    // scenarios run sequentially from this single test.
    #[test]
    fn test_traps() {
        nested_traps();
        trap_exit_and_reenter();
    }

    /// An inner trap sharing `CtrlC` with an outer one can be entered and left.
    fn nested_traps() {
        trap(
            &[Signal::CtrlC, Signal::CloseWindow],
            |_| {},
            || {
                println!("Trap 1");
                trap(
                    &[Signal::CtrlC, Signal::CtrlBreak],
                    |_| {},
                    || {
                        println!("Trap 2");
                    },
                )
                .unwrap();
            },
        )
        .unwrap();
    }

    /// The last trap can be torn down completely and a new one installed afterwards.
    fn trap_exit_and_reenter() {
        trap(
            &[Signal::CtrlC],
            |_| {},
            || {
                println!("Trap 1");
            },
        )
        .unwrap();
        trap(
            &[Signal::CtrlC],
            |_| {},
            || {
                println!("Trap 2");
            },
        )
        .unwrap();
    }
}