#![feature(optin_builtin_traits)]
#[macro_use]
extern crate lazy_static;
#[cfg(feature = "futures")]
mod futures;
#[cfg(feature = "futures")]
pub use self::futures::*;
mod windows;
use crossbeam_channel;
use std::collections::{HashMap, LinkedList};
use std::sync::{Arc, Mutex};
use std::thread;
use std::{error, fmt, process};
use winapi::shared::minwindef::{BOOL, DWORD, FALSE, LPARAM, LRESULT, TRUE, UINT, WPARAM};
use winapi::shared::windef::HWND;
use winapi::um::wincon::{CTRL_BREAK_EVENT, CTRL_CLOSE_EVENT, CTRL_C_EVENT};
use winapi::um::winuser::{DefWindowProcW, WM_CLOSE, WM_QUIT};
/// Runs `body` with `handler` registered for every signal in `signals`.
///
/// The registration is undone when `body` returns (or unwinds), because the
/// guard created here is dropped at the end of this function. While several
/// traps cover the same signal, only the most recently registered handler is
/// invoked. The handler is called from the trap thread, not the caller's
/// thread, which is why it must be `Send + Sync + 'static`.
///
/// # Errors
///
/// Returns an [`Error`] if the trap machinery cannot be set up; `body` is not
/// run in that case.
pub fn trap<RT: Sized>(
    signals: &'static [Signal],
    handler: impl Fn(Signal) + Send + Sync + 'static,
    body: impl FnOnce() -> RT,
) -> Result<RT, Error> {
    let shared_handler: Arc<dyn Fn(Signal) + Send + Sync + 'static> = Arc::new(handler);
    // Must be a named binding (not `_`) so the registration lives until we return.
    let _registration = Trap::new(signals, shared_handler)?;
    let output = body();
    Ok(output)
}
/// An event that can be intercepted with [`trap`].
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Signal {
    /// Ctrl+C in the console (`CTRL_C_EVENT`).
    CtrlC,
    /// Ctrl+Break in the console (`CTRL_BREAK_EVENT`).
    CtrlBreak,
    /// The console is being closed (`CTRL_CLOSE_EVENT`).
    CloseConsole,
    /// `WM_CLOSE` delivered to the trap window.
    CloseWindow,
}
impl Signal {
    /// Translates a console control event code into a [`Signal`].
    ///
    /// Only `CTRL_C_EVENT`, `CTRL_BREAK_EVENT` and `CTRL_CLOSE_EVENT` are
    /// recognised; any other code yields `None`.
    fn from_console_ctrl_event(event: DWORD) -> Option<Self> {
        let signal = match event {
            CTRL_C_EVENT => Signal::CtrlC,
            CTRL_BREAK_EVENT => Signal::CtrlBreak,
            CTRL_CLOSE_EVENT => Signal::CloseConsole,
            _ => return None,
        };
        Some(signal)
    }

    /// Translates a message seen by the trap window's message loop into a
    /// [`Signal`].
    ///
    /// `WM_CLOSE` maps to `Signal::CloseWindow`. The registered
    /// `WM_CONSOLE_CTRL` message carries a console control event code in
    /// `wparam`, decoded with `from_console_ctrl_event`. Anything else yields
    /// `None`. `_lparam` is accepted for symmetry and ignored.
    fn from_window_message(msg: UINT, wparam: WPARAM, _lparam: LPARAM) -> Option<Self> {
        match msg {
            WM_CLOSE => Some(Signal::CloseWindow),
            // The guard only runs when `msg` is not `WM_CLOSE`, so the lazy
            // message id is touched in the same situations as before.
            other if other == *WM_CONSOLE_CTRL => {
                Self::from_console_ctrl_event(wparam as DWORD)
            }
            _ => None,
        }
    }
}
/// Errors that can occur while setting up a [`trap`].
///
/// Each variant carries the Windows error code of the failing call; the
/// `Display` implementation turns it into a system message.
#[derive(Debug)]
pub enum Error {
    /// Creating a window failed (displayed as "Error creating Window").
    CreateWindow(DWORD),
    /// Installing the console control handler failed.
    SetConsoleCtrlHandler(DWORD),
}
impl error::Error for Error {}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Error::SetConsoleCtrlHandler(code) => write!(
f,
"Error setting console control handler: {}",
windows::format_error(*code).unwrap()
),
Error::CreateWindow(code) => write!(
f,
"Error creating Window: {}",
windows::format_error(*code).unwrap()
),
}
}
}
lazy_static! {
    /// The only thread allowed to create traps. Lazily initialised to the id of
    /// whichever thread first dereferences it (in this file, the `assert_eq!`
    /// in `Trap::new`).
    static ref TRAP_OWNER_THREAD_ID: thread::ThreadId = thread::current().id();
    /// Global registry of active traps, their callbacks and the trap thread.
    static ref TRAP_STACK: Mutex<TrapStack> = Mutex::new(TrapStack::new());
    /// Registered message id used to forward console control events to the
    /// trap window; the event code travels in `wParam`. Panics on first use if
    /// registration fails.
    static ref WM_CONSOLE_CTRL: UINT =
        windows::register_window_message("WINSIG_WM_CONSOLE_CTRL").unwrap();
}
/// Guard for one registration on the global trap stack.
///
/// `Trap::new` pushes a handler for `signals`; dropping the guard pops the
/// newest handler for those same signals again.
struct Trap {
    // The signals this trap registered a handler for; needed again on drop.
    signals: &'static [Signal],
}
impl Trap {
    /// Registers `handler` for each signal in `signals` on the global trap
    /// stack and returns the guard that unregisters them when dropped.
    ///
    /// # Errors
    ///
    /// Propagates any error from `TrapStack::push_trap` (starting the trap
    /// thread can fail).
    ///
    /// # Panics
    ///
    /// Panics when called from a thread other than the one recorded in
    /// `TRAP_OWNER_THREAD_ID`, or if `TRAP_STACK` is poisoned.
    fn new(
        signals: &'static [Signal],
        handler: Arc<dyn Fn(Signal) + Send + Sync + 'static>,
    ) -> Result<Self, Error> {
        // `pop_trap` removes the newest callback per signal, which is only
        // correct if traps are created and dropped in LIFO order on one thread.
        assert_eq!(*TRAP_OWNER_THREAD_ID, thread::current().id());
        TRAP_STACK.lock().unwrap().push_trap(signals, handler)?;
        Ok(Self { signals })
    }
}
impl Drop for Trap {
    /// Unregisters this trap's handlers from the global trap stack.
    ///
    /// When this is the last live trap, the trap thread data is detached from
    /// the stack while `TRAP_STACK` is locked, but it is dropped only after the
    /// lock has been released. Dropping it joins the trap thread, and that
    /// thread locks `TRAP_STACK` for every signal message it processes.
    /// Joining while still holding the lock therefore deadlocks whenever such a
    /// message is pending (e.g. Ctrl+C pressed during teardown).
    fn drop(&mut self) {
        let retired_thread_data = {
            let mut trap_stack = TRAP_STACK.lock().unwrap();
            let retired = if trap_stack.num_traps == 1 {
                trap_stack.trap_thread_data.take()
            } else {
                None
            };
            trap_stack.pop_trap(self.signals);
            retired
        };
        // `TRAP_STACK` is unlocked here, so the trap thread can drain its queue
        // and exit while we join it.
        drop(retired_thread_data);
    }
}
// A `Trap` guard must stay on the thread that created it: `Trap::new` asserts a
// single owner thread, and dropping pops the newest callback per signal, which
// relies on guards being dropped in LIFO order on that thread.
impl !Send for Trap {}
impl !Sync for Trap {}
/// Registered handlers per signal. Each list holds one entry per live trap that
/// covers the signal, in push order; the trap thread invokes the last entry.
type TrapCallbacks = HashMap<Signal, LinkedList<Arc<dyn Fn(Signal) + Send + Sync + 'static>>>;
/// Global bookkeeping for all active traps; the single instance lives in
/// `TRAP_STACK`.
struct TrapStack {
    /// Number of live `Trap` guards.
    num_traps: usize,
    /// Console handler registration and trap thread. Created when the first
    /// trap is pushed and dropped when the count returns to zero.
    trap_thread_data: Option<TrapThreadData>,
    /// Registered handlers per signal.
    callbacks: TrapCallbacks,
}
impl TrapStack {
fn new() -> TrapStack {
TrapStack {
num_traps: 0,
trap_thread_data: None,
callbacks: HashMap::new(),
}
}
fn increment_trap_count(&mut self) -> Result<(), Error> {
self.num_traps += 1;
if self.num_traps == 1 {
self.trap_thread_data = Some(TrapThreadData::new()?);
}
Ok(())
}
fn decrement_trap_count(&mut self) {
self.num_traps -= 1;
if self.num_traps == 0 {
self.trap_thread_data = None;
}
}
fn push_trap(
&mut self,
signals: &[Signal],
handler: Arc<dyn Fn(Signal) + Send + Sync + 'static>,
) -> Result<(), Error> {
self.increment_trap_count()?;
for signal in signals.iter() {
self.callbacks
.entry(*signal)
.or_insert_with(LinkedList::new)
.push_back(handler.clone());
}
Ok(())
}
fn pop_trap(&mut self, signals: &[Signal]) {
self.decrement_trap_count();
for signal in signals.iter() {
let callbacks = self.callbacks.get_mut(signal).unwrap();
callbacks.pop_back().unwrap();
if callbacks.is_empty() {
self.callbacks.remove(signal);
}
}
}
fn has_handler_for(&self, signal: Signal) -> bool {
self.callbacks.contains_key(&signal)
}
fn exit_if_only_window(&self) {
if let Some(ref trap_thread_data) = self.trap_thread_data {
struct EnumWindowsData {
hwnd: HWND,
process_id: DWORD,
};
let enum_windows_data = EnumWindowsData {
hwnd: trap_thread_data.window_handle.hwnd,
process_id: process::id(),
};
unsafe extern "system" fn enum_windows_proc(hwnd: HWND, lparam: LPARAM) -> BOOL {
let enum_windows_data = &*(lparam as *const EnumWindowsData);
if enum_windows_data.hwnd == hwnd {
TRUE
} else {
let (_, process_id) = windows::get_window_thread_process_id(hwnd);
if enum_windows_data.process_id == process_id {
FALSE
} else {
TRUE
}
}
}
if !windows::enum_windows(
enum_windows_proc,
(&enum_windows_data as *const EnumWindowsData) as LPARAM,
) {
process::exit(0);
}
} else {
unreachable!();
}
}
}
/// Resources that exist while at least one trap is active: the trap thread
/// running the hidden window's message loop, plus what is needed to reach it.
struct TrapThreadData {
    /// Join handle of the trap thread; an `Option` so `Drop` can take and join it.
    thread: Option<thread::JoinHandle<()>>,
    /// Id reported by `windows::get_current_thread_id` on the trap thread; the
    /// target of the `WM_QUIT` posted on drop.
    thread_id: DWORD,
    /// The hidden window created on the trap thread; console events are posted
    /// to it as `WM_CONSOLE_CTRL`.
    window_handle: windows::WindowHandle,
}
impl TrapThreadData {
    /// Installs `console_ctrl_handler` and starts the trap thread. That thread
    /// creates the hidden trap window and runs its message loop until `WM_QUIT`.
    ///
    /// Each message the loop sees is decoded with `Signal::from_window_message`.
    /// For a trapped signal, the newest registered handler is called while
    /// `TRAP_STACK` is locked. An untrapped `WM_CLOSE` falls back to
    /// `TrapStack::exit_if_only_window`.
    ///
    /// # Errors
    ///
    /// Returns `Error::SetConsoleCtrlHandler` if the console handler cannot be
    /// installed.
    ///
    /// # Panics
    ///
    /// The trap thread may die before reporting its window (e.g. because
    /// `windows::Window::new` failed). In that case the console handler is
    /// uninstalled again and that thread's panic is re-raised here. Previously
    /// the handler stayed installed and the caller saw an opaque `RecvError`.
    fn new() -> Result<TrapThreadData, Error> {
        windows::set_console_ctrl_handler(console_ctrl_handler, true)
            .map_err(Error::SetConsoleCtrlHandler)?;

        // `HWND` is a raw pointer and not `Send`, so it crosses the channel as
        // `usize`. Both ids travel in a single message: either the window
        // exists, or the sender is dropped without sending anything.
        let (ready_tx, ready_rx) = crossbeam_channel::bounded::<(DWORD, usize)>(1);
        let worker = thread::spawn(move || {
            let thread_id = windows::get_current_thread_id() as DWORD;
            let mut window = windows::Window::new(window_proc).unwrap();
            ready_tx.send((thread_id, window.hwnd as usize)).unwrap();
            window
                .run_event_loop(|&msg| {
                    if let Some(signal) =
                        Signal::from_window_message(msg.message, msg.wParam, msg.lParam)
                    {
                        // The handler runs with `TRAP_STACK` locked, so the trap
                        // that registered it cannot be popped mid-call.
                        let trap_stack = TRAP_STACK.lock().unwrap();
                        if let Some(callback_list) = trap_stack.callbacks.get(&signal) {
                            callback_list.back().unwrap()(signal);
                        } else if msg.message == WM_CLOSE {
                            trap_stack.exit_if_only_window();
                        }
                    }
                })
                .unwrap();
        });

        match ready_rx.recv() {
            Ok((thread_id, hwnd)) => Ok(TrapThreadData {
                thread: Some(worker),
                thread_id,
                window_handle: windows::WindowHandle { hwnd: hwnd as HWND },
            }),
            Err(_) => {
                // The trap thread died during start-up. Undo the handler
                // registration (best effort: we are about to panic anyway),
                // then surface the trap thread's own panic.
                let _ = windows::set_console_ctrl_handler(console_ctrl_handler, false);
                match worker.join() {
                    Err(payload) => std::panic::resume_unwind(payload),
                    Ok(()) => unreachable!("trap thread exited without reporting its window"),
                }
            }
        }
    }

    /// Posts console control `event` to the trap window as `WM_CONSOLE_CTRL`
    /// (event code in `wParam`), returning `windows::post_message`'s error
    /// code on failure.
    fn enqueue_ctrl_event(&self, event: DWORD) -> Result<(), DWORD> {
        windows::post_message(self.window_handle, *WM_CONSOLE_CTRL, event as WPARAM, 0)
    }
}
impl Drop for TrapThreadData {
    /// Tears the trap machinery down in three steps: uninstall the console
    /// handler, ask the trap thread's message loop to quit, then wait for it.
    ///
    /// # Panics
    ///
    /// Panics if any step reports an error, if the join handle was already
    /// taken, or if the trap thread itself panicked.
    fn drop(&mut self) {
        // Stop forwarding console events before the window goes away.
        let unhooked = windows::set_console_ctrl_handler(console_ctrl_handler, false);
        unhooked.unwrap();

        let quit_posted = windows::post_thread_message(self.thread_id, WM_QUIT, 0, 0);
        quit_posted.unwrap();

        let worker = self.thread.take().unwrap();
        worker.join().unwrap();
    }
}
/// Console control handler installed while the trap thread exists.
///
/// The system calls this on a separate thread, so no user handler runs here.
/// Trapped events are forwarded to the trap window via
/// `TrapThreadData::enqueue_ctrl_event`, and `TRUE` is returned to report them
/// handled. `FALSE` lets the next handler in the chain run. It is returned for
/// unrecognised events, for signals without a registered handler, when no trap
/// thread exists, when posting fails, and when `TRAP_STACK` is poisoned.
unsafe extern "system" fn console_ctrl_handler(event: DWORD) -> BOOL {
    let signal = match Signal::from_console_ctrl_event(event) {
        Some(signal) => signal,
        None => return FALSE,
    };
    // Unwinding out of an `extern "system"` callback is undefined behaviour (or
    // an abort, depending on the toolchain). The trap thread calls user
    // handlers with this lock held, so a panicking handler poisons it. Decline
    // the event instead of unwrapping.
    let trap_stack = match TRAP_STACK.lock() {
        Ok(trap_stack) => trap_stack,
        Err(_) => return FALSE,
    };
    if !trap_stack.has_handler_for(signal) {
        return FALSE;
    }
    let forwarded = match trap_stack.trap_thread_data {
        Some(ref trap_thread_data) => trap_thread_data.enqueue_ctrl_event(event).is_ok(),
        None => false,
    };
    if forwarded {
        TRUE
    } else {
        FALSE
    }
}
/// Window procedure for the hidden trap window.
///
/// `WM_CLOSE` and the registered `WM_CONSOLE_CTRL` message are answered with 0
/// and never reach `DefWindowProcW`, whose default `WM_CLOSE` handling would
/// destroy the window. The trap thread's message-loop callback acts on them
/// instead. Every other message gets default processing.
unsafe extern "system" fn window_proc(
    hwnd: HWND,
    msg: UINT,
    wparam: WPARAM,
    lparam: LPARAM,
) -> LRESULT {
    // Short-circuit keeps the lazy `WM_CONSOLE_CTRL` untouched for `WM_CLOSE`.
    let reserved_for_trap = msg == WM_CLOSE || msg == *WM_CONSOLE_CTRL;
    if !reserved_for_trap {
        return DefWindowProcW(hwnd, msg, wparam, lparam);
    }
    0
}
#[cfg(test)]
mod tests {
    use super::*;

    // `Trap::new` asserts that every trap is created on the thread recorded in
    // `TRAP_OWNER_THREAD_ID`. When tests run in parallel, libtest gives each
    // `#[test]` its own thread, so separate trap tests would trip that
    // assertion. All scenarios therefore run from this single test.
    #[test]
    fn test_traps() {
        nested_traps();
        trap_exit_and_reenter();
        trap_returns_body_value();
    }

    /// An inner trap can be pushed while an outer one is active, including for
    /// an overlapping signal, and everything is unwound afterwards.
    fn nested_traps() {
        trap(
            &[Signal::CtrlC, Signal::CloseWindow],
            |_| {},
            || {
                println!("Trap 1");
                trap(
                    &[Signal::CtrlC, Signal::CtrlBreak],
                    |_| {},
                    || {
                        println!("Trap 2");
                    },
                )
                .unwrap();
            },
        )
        .unwrap();
        assert_trap_stack_empty();
    }

    /// The trap machinery can be torn down and set up again.
    fn trap_exit_and_reenter() {
        trap(&[Signal::CtrlC], |_| {}, || println!("Trap 1")).unwrap();
        assert_trap_stack_empty();
        trap(&[Signal::CtrlC], |_| {}, || println!("Trap 2")).unwrap();
        assert_trap_stack_empty();
    }

    /// `trap` passes the body's return value through.
    fn trap_returns_body_value() {
        assert_eq!(trap(&[Signal::CtrlBreak], |_| {}, || 42).unwrap(), 42);
        assert_trap_stack_empty();
    }

    /// Checks that no trap state is left behind once every trap has returned.
    fn assert_trap_stack_empty() {
        let trap_stack = TRAP_STACK.lock().unwrap();
        assert_eq!(trap_stack.num_traps, 0);
        assert!(trap_stack.trap_thread_data.is_none());
        assert!(trap_stack.callbacks.is_empty());
    }
}