winmsg_executor/
lib.rs

1#![doc = include_str!("../README.md")]
2
3pub mod util;
4
5use std::{
6    any::Any,
7    cell::Cell,
8    future::Future,
9    mem::{ManuallyDrop, MaybeUninit},
10    panic::{self, AssertUnwindSafe},
11    pin::{pin, Pin},
12    ptr::{self, NonNull},
13    task::{Context, Poll, Waker},
14};
15
16use async_task::Runnable;
17use util::{Window, WindowType};
18use windows_sys::Win32::UI::WindowsAndMessaging::*;
19
20use crate::util::MsgFilterHook;
21
22const MSG_ID_WAKE: u32 = WM_USER;
23
24thread_local! {
25    static PANIC_PAYLOAD: Cell<Option<Box<dyn Any + Send + 'static>>> = const { Cell::new(None) };
26    static EXECUTOR_WINDOW: Window<()> = Window::new(WindowType::MessageOnly, (), |_, msg| {
27        if msg.msg == MSG_ID_WAKE {
28            let runnable = unsafe {
29                let runnable_ptr = NonNull::new_unchecked(msg.lparam as *mut _);
30                Runnable::<()>::from_raw(runnable_ptr)
31            };
32            if let Err(panic_payload) = panic::catch_unwind(|| runnable.run()) {
33                PANIC_PAYLOAD.set(Some(panic_payload));
34            }
35            Some(0)
36        } else {
37            None
38        }
39    })
40    .unwrap();
41}
42
43/// An owned permission to join on a task (await its termination).
44///
45/// If a `JoinHandle` is dropped, then its task continues running in the
46/// background and its return value is lost.
47pub struct JoinHandle<T> {
48    task: ManuallyDrop<async_task::Task<T>>,
49}
50
51// Keep the task running when dropped.
52impl<T> Drop for JoinHandle<T> {
53    fn drop(&mut self) {
54        let task = unsafe { ManuallyDrop::take(&mut self.task) };
55        task.detach();
56    }
57}
58
59impl<T> Future for JoinHandle<T> {
60    type Output = T;
61
62    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
63        pin!(&mut *self.task).poll(cx)
64    }
65}
66
67unsafe fn spawn_unchecked_lifetime<T>(future: impl Future<Output = T>) -> JoinHandle<T> {
68    let hwnd = EXECUTOR_WINDOW.with(|w| w.hwnd());
69
70    // SAFETY: The `future` does not need to be `Send` because the thread that
71    // receives the runnable is our own, meaning the runnable is also dropped
72    // on the original thread.
73    let (runnable, task) = unsafe {
74        async_task::spawn_unchecked(future, move |runnable: Runnable| {
75            PostMessageA(hwnd, MSG_ID_WAKE, 0, runnable.into_raw().as_ptr() as _);
76        })
77    };
78
79    // Trigger initial poll.
80    runnable.schedule();
81
82    JoinHandle {
83        task: ManuallyDrop::new(task),
84    }
85}
86
87/// Spawns a new future on the current thread.
88///
89/// This function may be used to spawn tasks when the message loop is not
90/// running. The provided future will start running once the message loop
91/// is entered with [`block_on`] or [`MessageLoop::run`].
92pub fn spawn_local<T>(future: impl Future<Output = T> + 'static) -> JoinHandle<T> {
93    // SAFETY: future is `'static`
94    unsafe { spawn_unchecked_lifetime(future) }
95}
96
97/// Runs a future to completion on the calling thread's message loop.
98///
99/// This runs the provided future on the current thread, blocking until it is
100/// complete. Also runs any tasks [`spawn`]ed from the same thread. Note that
101/// any spawned tasks will be suspended after `block_on` returns. Calling
102/// `block_on` again will resume previously spawned tasks.
103///
104/// # Panics
105///
106/// Panics when quitting out of the message loop without the future being
107/// ready. This can happen when the future or any spawned task calls the
108/// `PostQuitMessage()` WinAPI function.
109pub fn block_on<'a, T: 'a>(future: impl Future<Output = T> + 'a) -> T {
110    let msg_loop = &MessageLoop::new();
111
112    // Wrap the future so it quits the message loop when finished.
113    // SAFETY: All borrowed variables outlive the task itself because we only
114    // return from this function after the task has finished.
115    let task = unsafe {
116        spawn_unchecked_lifetime(async move {
117            let result = future.await;
118            msg_loop.quit();
119            result
120        })
121    };
122
123    msg_loop.run_loop(|_| FilterResult::Forward);
124
125    poll_ready(task).expect("received unexpected quit message")
126}
127
/// Polls `future` exactly once using a waker that does nothing.
///
/// Returns `Ok` with the output if the future completes on that poll, and
/// `Err(())` if it is still pending.
fn poll_ready<T>(future: impl Future<Output = T>) -> Result<T, ()> {
    let mut cx = Context::from_waker(Waker::noop());
    let pinned = pin!(future);
    if let Poll::Ready(output) = pinned.poll(&mut cx) {
        Ok(output)
    } else {
        Err(())
    }
}
135
/// Verdict returned by the filter closure of [`MessageLoop::run`] for each
/// message it inspects.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FilterResult {
    /// Hand the message on to its window procedure as usual.
    Forward,

    /// Discard the message instead of forwarding it to its window procedure.
    Drop,
}
145
146/// Abstract representation of a message loop.
147///
148/// Not directly constructible, use [`MessageLoop::run`] to create a message
149/// loop. The message loop struct is used to control the message loop behavior
150/// by passing it as an argument to the filter closure of [`MessageLoop::run`].
151pub struct MessageLoop {
152    quit: Cell<bool>,
153}
154
155impl MessageLoop {
156    fn new() -> Self {
157        Self {
158            quit: Cell::new(false),
159        }
160    }
161
162    fn run_loop(&self, filter: impl Fn(&MSG) -> FilterResult) {
163        let executor_hwnd = EXECUTOR_WINDOW.with(|ew| ew.hwnd());
164
165        while !self.quit.get() {
166            unsafe {
167                let mut msg = MaybeUninit::uninit();
168                if GetMessageA(msg.as_mut_ptr(), ptr::null_mut(), 0, 0) == 0 {
169                    return;
170                }
171                let msg = msg.assume_init();
172
173                // Do not allow the filter to drop our wake messages.
174                let is_wake_message = msg.hwnd == executor_hwnd && msg.message == MSG_ID_WAKE;
175                if is_wake_message || filter(&msg) == FilterResult::Forward {
176                    TranslateMessage(&msg);
177                    DispatchMessageA(&msg);
178                }
179
180                if let Some(panic_payload) = PANIC_PAYLOAD.take() {
181                    panic::resume_unwind(panic_payload)
182                }
183            }
184        }
185    }
186
187    /// Runs the message loop with a filter closure to inspect and drop messages
188    /// before they are dispatched to their respective window procedure.
189    ///
190    /// Use the [`FilterResult`] return value to control how the message is
191    /// handled. The first argument to the filter closure is the [`MessageLoop`]
192    /// struct itself, which can be used to quit out of the message loop.
193    ///
194    /// Like [`block_on`], this function runs any tasks [`spawn`]ed from the
195    /// same thread. Any spawned tasks will be suspended when `run_message_loop`
196    /// returns.
197    /// Be careful not to drop messages not belonging to a window you
198    /// control or you might risk suspending a task indefinitely when dropping
199    /// its wake message.
200    ///
201    /// `run_message_loop` installs a [`WH_MSGFILTER`] hook to allow inspections
202    /// of messages while modal windows are open.
203    ///
204    /// # Panics and Reentrancy
205    ///
206    /// Panics when called from within another `run_message_loop` filter closure.
207    ///
208    /// A call to [`block_on()`] from within the filter closure creates a nested
209    /// message loop which causes the filter closure to be reentered when a modal
210    /// window is open.
211    ///
212    /// [`WH_MSGFILTER`]: (https://learn.microsoft.com/en-us/windows/win32/winmsg/about-hooks#wh_msgfilter-and-wh_sysmsgfilter)
213    pub fn run(filter: impl Fn(&MessageLoop, &MSG) -> FilterResult) {
214        let msg_loop = MessageLoop::new();
215
216        // Any modal window (i.e. a right-click menu) blocks the main message loop
217        // and dispatches messages internally. To keep the executor running use a
218        // hook to get access to modal windows' internal message loop.
219        // SAFETY: The Drop implementation of MsgFilterHook unregisters the hook,
220        // ensuring that dispatchers will not be called after the end of the scope.
221        let _hook = unsafe {
222            MsgFilterHook::register(|msg| {
223                panic::catch_unwind(AssertUnwindSafe(|| {
224                    let filter_result = filter(&msg_loop, msg);
225                    // When `quit()` is called, it has no real effect because we
226                    // are running in a modal loop. Post a quit message to exit
227                    // the modal message loop to store the panic payload.
228                    if msg_loop.quit.get() {
229                        PostMessageA(msg.hwnd, WM_QUIT, 0, 0);
230                    }
231                    filter_result == FilterResult::Drop
232                }))
233                .unwrap_or_else(|payload| {
234                    PANIC_PAYLOAD.with(|panic_payload| {
235                        panic_payload.set(Some(payload));
236                    });
237                    // Also exit the modal loop ASAP when a panic occurs.
238                    PostMessageA(msg.hwnd, WM_QUIT, 0, 0);
239                    false
240                })
241            })
242        };
243        msg_loop.run_loop(|msg| filter(&msg_loop, msg));
244    }
245
246    /// Quits the message loop as soon as possible.
247    pub fn quit(&self) {
248        self.quit.set(true);
249    }
250
251    /// Quits the message loop when there are no more messages to process.
252    pub fn quit_when_idle(&self) {
253        unsafe { PostQuitMessage(0) };
254    }
255}
256
#[cfg(test)]
mod test {
    use std::{ffi::CStr, future::poll_fn};

    use windows_sys::Win32::Foundation::HWND;

    use super::*;

    fn post_thread_message(msg: u32) {
        unsafe { PostMessageA(ptr::null_mut(), msg, 0, 0) };
    }

    #[test]
    #[should_panic]
    fn panic_in_dispatcher() {
        post_thread_message(WM_USER);
        MessageLoop::run(|_, _| panic!());
    }

    #[test]
    fn message_loop_quit() {
        for i in 0..10 {
            post_thread_message(WM_USER + i);
        }
        MessageLoop::run(|msg_loop, msg| {
            // This is the only message we observe because we quit the
            // loop right after it is received.
            assert_eq!(msg.message, WM_USER);
            msg_loop.quit();
            FilterResult::Drop
        });
    }

    #[test]
    fn message_loop_quit_when_idle() {
        for i in 0..10 {
            post_thread_message(WM_USER + i);
        }
        let expected_msg = Cell::new(0);
        MessageLoop::run(|msg_loop, msg| {
            assert_eq!(msg.message, WM_USER + expected_msg.get());
            expected_msg.set(expected_msg.get() + 1);
            msg_loop.quit_when_idle();
            FilterResult::Drop
        });
        assert_eq!(expected_msg.get(), 10);
    }

    #[test]
    fn nested_block_on() {
        let count: Cell<usize> = Cell::new(0);

        block_on(async {
            assert_eq!(count.get(), 0);
            count.set(count.get() + 1);

            block_on(async {
                assert_eq!(count.get(), 1);
                count.set(count.get() + 1);
            });

            assert_eq!(count.get(), 2);
            count.set(count.get() + 1);
        });

        assert_eq!(count.get(), 3);
    }

    #[test]
    #[should_panic]
    fn nested_message_loop() {
        post_thread_message(WM_USER);
        MessageLoop::run(|_, _| {
            MessageLoop::run(|_, _| FilterResult::Drop);
            FilterResult::Drop
        });
    }

    async fn yield_now() {
        let mut yielded = false;
        poll_fn(|cx| {
            if yielded {
                Poll::Ready(())
            } else {
                yielded = true;
                cx.waker().wake_by_ref();
                Poll::Pending
            }
        })
        .await;
    }

    #[test]
    fn nested_message_loop_block_on() {
        let inner_executed = Cell::new(false);

        post_thread_message(WM_USER);
        MessageLoop::run(|msg_loop, _| {
            block_on(async {
                inner_executed.set(true);
            });
            msg_loop.quit();
            FilterResult::Forward
        });

        assert!(inner_executed.get());
    }

    #[test]
    fn nested_message_loop_block_on_quit() {
        post_thread_message(WM_USER);
        MessageLoop::run(|msg_loop, _| {
            block_on(async {
                msg_loop.quit();
            });
            FilterResult::Forward
        });
    }

    fn window_by_name(name: &CStr) -> HWND {
        unsafe { FindWindowA(ptr::null_mut(), name.as_ptr() as _) }
    }

    #[test]
    fn running_spawned_with_modal_dialog() {
        // The window name must be unique for each test because cargo runs tests
        // in parallel and we do not want to close the window of another test.
        let window_name = c"running_spawned_with_modal_dialog";

        let task = spawn_local(async {
            // Wait for modal window to be open.
            while window_by_name(window_name).is_null() {
                yield_now().await;
            }

            // Do some async work with modal dialog open.
            for _ in 0..10 {
                yield_now().await;
            }

            // Close the modal window.
            unsafe {
                SendMessageA(window_by_name(window_name), WM_CLOSE, 0, 0);
            }
        });

        block_on(async {
            unsafe {
                MessageBoxA(
                    ptr::null_mut(),
                    ptr::null_mut(),
                    window_name.as_ptr() as _,
                    0,
                );
            }
            task.await;
        });
    }

    // This test does not actually expect the library to panic.
    // The panic is rather a convenient way to signal if the filter closure is
    // reentered (which is the expected behaviour).
    #[test]
    #[should_panic]
    fn reenter_filter_closure_panic() {
        // The window name must be unique for each test because cargo runs tests
        // in parallel and we do not want to close the window of another test.
        let window_name = c"reenter_filter_closure_panic";

        post_thread_message(WM_USER);

        let running_filter_closure = Cell::new(false);
        MessageLoop::run(|_, msg| {
            assert!(
                !running_filter_closure.replace(true),
                "Filter closure reentered"
            );

            if msg.hwnd.is_null() && msg.message == WM_USER {
                unsafe {
                    MessageBoxA(
                        ptr::null_mut(),
                        ptr::null_mut(),
                        window_name.as_ptr() as _,
                        0,
                    );
                }
            }

            running_filter_closure.set(false);
            FilterResult::Forward
        });
    }

    #[test]
    fn reenter_filter_closure_quit() {
        // The window name must be unique for each test because cargo runs tests
        // in parallel and we do not want to close the window of another test.
        let window_name = c"reenter_filter_closure_quit";

        post_thread_message(WM_USER);

        let running_filter_closure = Cell::new(false);
        MessageLoop::run(|msg_loop, msg| {
            if running_filter_closure.replace(true) {
                msg_loop.quit();
            }

            if msg.hwnd.is_null() && msg.message == WM_USER {
                unsafe {
                    MessageBoxA(
                        ptr::null_mut(),
                        ptr::null_mut(),
                        window_name.as_ptr() as _,
                        0,
                    );
                }
            }

            running_filter_closure.set(false);
            FilterResult::Forward
        });
    }

    #[test]
    fn message_loop_with_modal_dialog() {
        // The window name must be unique for each test because cargo runs tests
        // in parallel and we do not want to close the window of another test.
        let window_name = c"message_loop_with_modal_dialog";

        spawn_local(async {
            unsafe {
                MessageBoxA(
                    ptr::null_mut(),
                    ptr::null_mut(),
                    window_name.as_ptr() as _,
                    0,
                );
            }
        });

        spawn_local(async {
            // Check if modal window is actually open.
            assert!(!window_by_name(window_name).is_null());

            for i in 0..10 {
                post_thread_message(WM_USER + i);
                yield_now().await;
            }

            // Close modal window again.
            unsafe { SendMessageA(window_by_name(window_name), WM_CLOSE, 0, 0) };
        });

        let expected_msg = Cell::new(0);
        MessageLoop::run(|msg_loop, msg| {
            if msg.hwnd.is_null() && msg.message >= WM_USER {
                assert_eq!(msg.message, WM_USER + expected_msg.get());
                expected_msg.set(expected_msg.get() + 1);
                msg_loop.quit_when_idle();
                FilterResult::Drop
            } else {
                FilterResult::Forward
            }
        });
        assert_eq!(expected_msg.get(), 10);
    }

    #[test]
    fn reenter_filter_closure_quit_when_idle() {
        // The window name must be unique for each test because cargo runs tests
        // in parallel and we do not want to close the window of another test.
        let window_name = c"reenter_filter_closure_quit_when_idle";

        post_thread_message(WM_USER);

        let running_filter_closure = Cell::new(false);
        MessageLoop::run(|msg_loop, msg| {
            if running_filter_closure.replace(true) {
                msg_loop.quit_when_idle();
            }

            if msg.hwnd.is_null() && msg.message == WM_USER {
                unsafe {
                    MessageBoxA(
                        ptr::null_mut(),
                        ptr::null_mut(),
                        window_name.as_ptr() as _,
                        0,
                    );
                }
            }

            running_filter_closure.set(false);
            FilterResult::Forward
        });
    }

    #[test]
    fn disallow_wake_message_filtering() {
        let msg_loop = MessageLoop::new();
        let msg_loop = Box::leak(Box::new(msg_loop));

        // The `MSG_ID_WAKE` message for the custom window should be filtered
        // by the run loop filter below and therefore never reach this window
        // procedure.
        let custom_wnd = Window::new(WindowType::MessageOnly, (), |_, msg| {
            assert_ne!(msg.msg, MSG_ID_WAKE);
            None
        })
        .unwrap();
        unsafe {
            PostMessageA(custom_wnd.hwnd(), MSG_ID_WAKE, 0, 0);
        }

        // Spawn a task to ensure that the executor window also has a wake message,
        // which must not be filtered.
        spawn_local(async {
            yield_now().await;
            yield_now().await;
            yield_now().await;
            msg_loop.quit();
        });

        msg_loop.run_loop(|msg| {
            // This test is to ensure that this callback is not even called for internal wake messages.
            if msg.message == MSG_ID_WAKE {
                assert_ne!(msg.hwnd, EXECUTOR_WINDOW.with(|ew| ew.hwnd()));
                FilterResult::Drop
            } else {
                FilterResult::Forward
            }
        });
    }
}