wintrap/
lib.rs

1//! The `wintrap` crate allows a Windows process to trap one or more abstracted
2//! "signals", running a callback function in a dedicated thread whenever they
3//! are caught while active.
4//!
5//! # Examples
6//!
7//! ```
8//! wintrap::trap(&[wintrap::Signal::CtrlC, wintrap::Signal::CloseWindow], |signal| {
9//!     // handle signal here
10//!     println!("Caught a signal: {:?}", signal);
11//! }, || {
12//!     // do work
13//!     println!("Doing work");
14//! }).unwrap();
15//! ```
16//!
17//! # Caveats
18//!
19//! Please note that it is not possible to correctly trap Ctrl-C signals when
20//! running programs via `cargo run`. You will have to run them directly via
21//! the target directory after building.
22
23#[macro_use]
24extern crate lazy_static;
25
26#[cfg(feature = "futures")]
27mod futures;
28#[cfg(feature = "futures")]
29pub use self::futures::*;
30mod windows;
31use crossbeam_channel;
32use std::collections::{HashMap, LinkedList};
33use std::sync::{Arc, Mutex};
34use std::thread;
35use std::{error, fmt, process};
36use winapi::shared::minwindef::{BOOL, DWORD, FALSE, LPARAM, LRESULT, TRUE, UINT, WPARAM};
37use winapi::shared::windef::HWND;
38use winapi::um::wincon::{CTRL_BREAK_EVENT, CTRL_CLOSE_EVENT, CTRL_C_EVENT};
39use winapi::um::winuser::{DefWindowProcW, WM_CLOSE, WM_QUIT};
40
41/// Associates one or more [Signal]s to an callback function to be executed in
42/// a dedicated thread while `body` is executing. A caveat of its usage is that
43/// *only one thread* is ever able to trap signals throughout the entire
44/// execution of your program. You are free to nest traps freely, however, only
45/// the innermost signal handlers will be executed.
46///
47/// # Arguments
48///
49/// * `signals` - A list of signals to trap during the execution of `body`.
50///
51/// * `handler` - The handler to execute whenever a signal is trapped. These
52/// signals will be trapped and handled in the order that they are received in
53/// a dedicated thread. The handler will *override* the default behavior of the
54/// signal, in which most cases, is to end the process.
55///
56/// * `body` - The code to execute while the trap is active. The return value
57/// will be used as the `Ok` value of the result of the trap call.
58pub fn trap<RT: Sized>(
59    signals: &'static [Signal],
60    handler: impl Fn(Signal) + Send + Sync + 'static,
61    body: impl FnOnce() -> RT,
62) -> Result<RT, Error> {
63    let _trap_guard = Trap::new(signals, Arc::new(handler))?;
64    Ok(body())
65}
66
/// Represents one of several abstracted "signals" available to Windows
/// processes. A number of these signals may be associated with a single [trap]
/// call.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Signal {
    /// `SetConsoleCtrlHandler`-generated `CTRL_C_EVENT`. Equivalent to
    /// `SIGINT` on Unix. It is typically generated by the user pressing Ctrl+C
    /// in the console. However, the Restart Manager may also trigger this
    /// signal; see the
    /// [MSDN](https://docs.microsoft.com/en-us/windows/desktop/RstMgr/guidelines-for-applications)
    /// documentation for more details.
    CtrlC,

    /// `SetConsoleCtrlHandler`-generated `CTRL_BREAK_EVENT`. Roughly analogous
    /// to `SIGQUIT` on Unix. It is generated by the user pressing Ctrl+Break
    /// in the console.
    CtrlBreak,

    /// `SetConsoleCtrlHandler`-generated `CTRL_CLOSE_EVENT`. Roughly analogous
    /// to `SIGHUP` on Unix. It is generated by the user closing the console
    /// window.
    CloseConsole,

    /// A `WM_CLOSE` window message. Roughly analogous to `SIGTERM` on Unix. It
    /// is generated by sending `WM_CLOSE` to the top-level windows in the
    /// process. The Windows command line tool `taskkill` (when run without
    /// `/F`) does this, among others.
    ///
    /// Note that [std::process::Child::kill()] does *not* generate this
    /// signal. On Windows it terminates the child with `TerminateProcess`,
    /// which cannot be trapped.
    CloseWindow,
}
96
97impl Signal {
98    fn from_console_ctrl_event(event: DWORD) -> Option<Self> {
99        match event {
100            CTRL_C_EVENT => Some(Signal::CtrlC),
101            CTRL_BREAK_EVENT => Some(Signal::CtrlBreak),
102            CTRL_CLOSE_EVENT => Some(Signal::CloseConsole),
103            _ => None,
104        }
105    }
106
107    fn from_window_message(msg: UINT, wparam: WPARAM, _lparam: LPARAM) -> Option<Self> {
108        if msg == WM_CLOSE {
109            Some(Signal::CloseWindow)
110        } else if msg == *WM_CONSOLE_CTRL {
111            Signal::from_console_ctrl_event(wparam as DWORD)
112        } else {
113            None
114        }
115    }
116}
117
118/// An error that may potentially be generated by [trap]. These errors will
119/// rarely ever be produced, and you can unwrap `Result`s safely in most cases.
120#[derive(Debug)]
121pub enum Error {
122    /// An error setting the console control handler. The DWORD is the Windows
123    /// error code; see the [MSDN
124    /// documentation](https://docs.microsoft.com/en-us/windows/console/setconsolectrlhandler)
125    /// for details.
126    SetConsoleCtrlHandler(DWORD),
127
128    /// An error occurred when creating a window or registering its window
129    /// class. The DWORD is the Windows error code; see the MSDN documentation
130    /// on
131    /// [RegisterClassW](https://docs.microsoft.com/en-us/windows/desktop/api/winuser/nf-winuser-registerclassw)
132    /// and
133    /// [CreateWindowExW](https://docs.microsoft.com/en-us/windows/desktop/api/winuser/nf-winuser-createwindowexw)
134    /// for more details.
135    CreateWindow(DWORD),
136}
137
138impl error::Error for Error {}
139
140impl fmt::Display for Error {
141    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
142        match self {
143            Error::SetConsoleCtrlHandler(code) => write!(
144                f,
145                "Error setting console control handler: {}",
146                windows::format_error(*code).unwrap()
147            ),
148            Error::CreateWindow(code) => write!(
149                f,
150                "Error creating Window: {}",
151                windows::format_error(*code).unwrap()
152            ),
153        }
154    }
155}
156
lazy_static! {
    // Window message id, registered by name, that carries console control
    // events to the window thread (the event code travels in `wParam`).
    // Registration happens on first dereference; the `unwrap` panics there if
    // it fails.
    static ref WM_CONSOLE_CTRL: UINT =
        windows::register_window_message("WINSIG_WM_CONSOLE_CTRL").unwrap();
    // Process-wide registry of active traps and their per-signal handlers.
    static ref TRAP_STACK: Mutex<TrapStack> = Mutex::new(TrapStack::new());
    // Id of the thread that first dereferences this value. In this file that
    // happens in `Trap::new`, which then requires every later trap to be
    // created on the same thread.
    static ref TRAP_OWNER_THREAD_ID: thread::ThreadId = thread::current().id();
}
163
164struct Trap {
165    signals: &'static [Signal],
166    _phantom: std::marker::PhantomData<std::rc::Rc<u8>>,
167}
168
169impl Trap {
170    fn new(
171        signals: &'static [Signal],
172        handler: Arc<dyn Fn(Signal) + Send + Sync + 'static>,
173    ) -> Result<Self, Error> {
174        assert_eq!(*TRAP_OWNER_THREAD_ID, thread::current().id());
175        let mut trap_stack = TRAP_STACK.lock().unwrap();
176        trap_stack.push_trap(signals, handler)?;
177        Ok(Trap {
178            signals,
179            _phantom: std::marker::PhantomData,
180        })
181    }
182}
183
184impl Drop for Trap {
185    fn drop(&mut self) {
186        let mut trap_stack = TRAP_STACK.lock().unwrap();
187        trap_stack.pop_trap(self.signals.as_ref());
188    }
189}
190
191type TrapCallbacks = HashMap<Signal, LinkedList<Arc<dyn Fn(Signal) + Send + Sync + 'static>>>;
192
193struct TrapStack {
194    num_traps: usize,
195    trap_thread_data: Option<TrapThreadData>,
196    callbacks: TrapCallbacks,
197}
198
199impl TrapStack {
200    fn new() -> TrapStack {
201        TrapStack {
202            num_traps: 0,
203            trap_thread_data: None,
204            callbacks: HashMap::new(),
205        }
206    }
207
208    fn increment_trap_count(&mut self) -> Result<(), Error> {
209        self.num_traps += 1;
210        if self.num_traps == 1 {
211            // Initialize the active trap data
212            self.trap_thread_data = Some(TrapThreadData::new()?);
213        }
214        Ok(())
215    }
216
217    fn decrement_trap_count(&mut self) {
218        self.num_traps -= 1;
219        if self.num_traps == 0 {
220            // Drop the active trap data
221            self.trap_thread_data = None;
222        }
223    }
224
225    fn push_trap(
226        &mut self,
227        signals: &[Signal],
228        handler: Arc<dyn Fn(Signal) + Send + Sync + 'static>,
229    ) -> Result<(), Error> {
230        self.increment_trap_count()?;
231        for signal in signals.iter() {
232            self.callbacks
233                .entry(*signal)
234                .or_insert_with(LinkedList::new)
235                .push_back(handler.clone());
236        }
237        Ok(())
238    }
239
240    fn pop_trap(&mut self, signals: &[Signal]) {
241        self.decrement_trap_count();
242        for signal in signals.iter() {
243            let callbacks = self.callbacks.get_mut(signal).unwrap();
244            callbacks.pop_back().unwrap();
245            if callbacks.is_empty() {
246                self.callbacks.remove(signal);
247            }
248        }
249    }
250
251    fn has_handler_for(&self, signal: Signal) -> bool {
252        self.callbacks.contains_key(&signal)
253    }
254
255    fn exit_if_only_window(&self) {
256        if let Some(ref trap_thread_data) = self.trap_thread_data {
257            // If we get a WM_CLOSE event and we don't have a handler for it, AND if
258            // this process does not own any other windows, quit.
259            struct EnumWindowsData {
260                hwnd: HWND,
261                process_id: DWORD,
262            }
263            let enum_windows_data = EnumWindowsData {
264                hwnd: trap_thread_data.window_handle.hwnd,
265                process_id: process::id(),
266            };
267            unsafe extern "system" fn enum_windows_proc(hwnd: HWND, lparam: LPARAM) -> BOOL {
268                let enum_windows_data = &*(lparam as *const EnumWindowsData);
269                if enum_windows_data.hwnd == hwnd {
270                    TRUE
271                } else {
272                    let (_, process_id) = windows::get_window_thread_process_id(hwnd);
273                    if enum_windows_data.process_id == process_id {
274                        FALSE
275                    } else {
276                        TRUE
277                    }
278                }
279            }
280            // If we get through all windows during enumeration, then we didn't
281            // find any other windows that we own.
282            if !windows::enum_windows(
283                enum_windows_proc,
284                (&enum_windows_data as *const EnumWindowsData) as LPARAM,
285            ) {
286                process::exit(0);
287            }
288        } else {
289            unreachable!();
290        }
291    }
292}
293
/// State backing the active traps: the dedicated window thread that dispatches
/// signals to handlers, plus what is needed to reach and stop it.
struct TrapThreadData {
    // Join handle of the window thread; an `Option` so `Drop` can `take()` it.
    thread: Option<thread::JoinHandle<()>>,
    // Win32 id of the window thread; `Drop` posts `WM_QUIT` to it.
    thread_id: DWORD,
    // Window created by the window thread. Console control events are posted
    // to it as `WM_CONSOLE_CTRL` messages.
    window_handle: windows::WindowHandle,
}

impl TrapThreadData {
    /// Installs the console control handler, then spawns the window thread.
    /// It blocks until that thread has reported its thread id and window
    /// handle.
    ///
    /// Returns `Error::SetConsoleCtrlHandler` if the console handler cannot be
    /// installed.
    fn new() -> Result<TrapThreadData, Error> {
        // Install the console control handler first; it forwards console
        // events to the window created below.
        windows::set_console_ctrl_handler(console_ctrl_handler, true)
            .map_err(Error::SetConsoleCtrlHandler)?;

        // Window message loop.
        // The spawned thread sends back two values, in order: its Win32
        // thread id, then its window handle. Both cross the channel as `usize`
        // (HWND is a raw pointer and cannot be sent between threads as-is).
        let (s, r) = crossbeam_channel::bounded(2);
        let thread = Some(thread::spawn(move || {
            s.send(windows::get_current_thread_id() as usize).unwrap();
            // NOTE(review): if window creation fails, this `unwrap` panics the
            // thread. The second `recv` below then panics too, instead of
            // `Error::CreateWindow` being returned, and the console handler
            // installed above stays registered.
            let mut window = windows::Window::new(window_proc).unwrap();
            s.send(window.hwnd as usize).unwrap();
            window
                .run_event_loop(|&msg| {
                    if let Some(signal) =
                        Signal::from_window_message(msg.message, msg.wParam, msg.lParam)
                    {
                        // The TRAP_STACK lock is held while the user handler
                        // runs. `Trap` creation/removal and
                        // `console_ctrl_handler` wait until it returns.
                        let trap_stack = TRAP_STACK.lock().unwrap();
                        if let Some(callback_list) = trap_stack.callbacks.get(&signal) {
                            // Only the innermost handler for this signal runs.
                            callback_list.back().unwrap()(signal);
                        } else if msg.message == WM_CLOSE {
                            // Exit the process if we don't own any other windows.
                            trap_stack.exit_if_only_window();
                        }
                    }
                })
                .unwrap();
        }));
        let thread_id = r.recv().unwrap() as DWORD;
        let hwnd = r.recv().unwrap() as HWND;
        Ok(TrapThreadData {
            thread,
            thread_id,
            window_handle: windows::WindowHandle { hwnd },
        })
    }

    /// Posts `event` to the trap window as a `WM_CONSOLE_CTRL` message, with
    /// the event code in `wParam`. Returns the error code from
    /// `windows::post_message` on failure.
    fn enqueue_ctrl_event(&self, event: DWORD) -> Result<(), DWORD> {
        windows::post_message(self.window_handle, *WM_CONSOLE_CTRL, event as WPARAM, 0)
    }
}

impl Drop for TrapThreadData {
    /// Unregisters the console control handler, posts `WM_QUIT` to the window
    /// thread (presumably ending its `run_event_loop`), and joins the thread.
    /// Any failure along the way panics.
    fn drop(&mut self) {
        windows::set_console_ctrl_handler(console_ctrl_handler, false).unwrap();
        windows::post_thread_message(self.thread_id, WM_QUIT, 0, 0).unwrap();
        self.thread.take().unwrap().join().unwrap();
    }
}
349
350unsafe extern "system" fn console_ctrl_handler(event: DWORD) -> BOOL {
351    match Signal::from_console_ctrl_event(event) {
352        Some(signal) => {
353            let trap_stack = TRAP_STACK.lock().unwrap();
354            if trap_stack.has_handler_for(signal) {
355                // A handler exists, so queue the signal to be handled in the
356                // window thread
357                match trap_stack.trap_thread_data {
358                    Some(ref trap_thread_data) => {
359                        match trap_thread_data.enqueue_ctrl_event(event) {
360                            Ok(_) => TRUE,
361                            Err(_) => FALSE,
362                        }
363                    }
364                    None => FALSE,
365                }
366            } else {
367                FALSE
368            }
369        }
370        None => FALSE,
371    }
372}
373
374unsafe extern "system" fn window_proc(
375    hwnd: HWND,
376    msg: UINT,
377    wparam: WPARAM,
378    lparam: LPARAM,
379) -> LRESULT {
380    // Don't dare calling any user callbacks over the C function boundry. This
381    // function should just simulate having processed the message by returning
382    // the correct result. The actual processing happens in the callback to
383    // `run_event_loop`.
384
385    // Don't ever run the default handler for WM_CLOSE, as it destroys the
386    // window.
387    if msg == WM_CLOSE || msg == *WM_CONSOLE_CTRL {
388        0
389    } else {
390        DefWindowProcW(hwnd, msg, wparam, lparam)
391    }
392}
393
#[cfg(test)]
mod tests {
    use super::*;

    static_assertions::assert_not_impl_any!(Trap: Send, Sync);

    /// Nested traps: the inner trap re-traps `CtrlC` and adds `CtrlBreak`,
    /// while `CloseWindow` stays with the outer trap.
    fn nested_traps() {
        trap(
            &[Signal::CtrlC, Signal::CloseWindow],
            |_| {},
            || {
                println!("Trap 1");
                trap(
                    &[Signal::CtrlC, Signal::CtrlBreak],
                    |_| {},
                    || {
                        println!("Trap 2");
                    },
                )
                .unwrap();
            },
        )
        .unwrap();
    }

    /// Sequential traps: the trap machinery is torn down after the first trap
    /// and set up again for the second.
    fn trap_exit_and_reenter() {
        trap(
            &[Signal::CtrlC],
            |_| {},
            || {
                println!("Trap 1");
            },
        )
        .unwrap();
        trap(
            &[Signal::CtrlC],
            |_| {},
            || {
                println!("Trap 2");
            },
        )
        .unwrap();
    }

    // Every trap in the process must be created on one thread (`Trap::new`
    // asserts this against `TRAP_OWNER_THREAD_ID`), but the test harness runs
    // each `#[test]` on its own thread. With two separate tests, whichever ran
    // second would always fail that assertion, so all scenarios run from this
    // single test.
    #[test]
    fn test_traps() {
        nested_traps();
        trap_exit_and_reenter();
    }
}