winmsg_executor/
lib.rs

1#![doc = include_str!("../README.md")]
2
3pub mod util;
4
5use std::{
6    any::Any,
7    cell::Cell,
8    future::Future,
9    mem::{ManuallyDrop, MaybeUninit},
10    panic::{self, AssertUnwindSafe},
11    pin::{pin, Pin},
12    ptr::{self, NonNull},
13    task::{Context, Poll, RawWaker, RawWakerVTable, Waker},
14};
15
16use async_task::Runnable;
17use util::{Window, WindowType};
18use windows_sys::Win32::UI::WindowsAndMessaging::*;
19
20use crate::util::MsgFilterHook;
21
22const MSG_ID_WAKE: u32 = WM_USER;
23
24thread_local! {
25    static PANIC_PAYLOAD: Cell<Option<Box<dyn Any + Send + 'static>>> = const { Cell::new(None) };
26    static EXECUTOR_WINDOW: Window<()> = Window::new(WindowType::MessageOnly, (), |_, msg| {
27        if msg.msg == MSG_ID_WAKE {
28            let runnable = unsafe {
29                let runnable_ptr = NonNull::new_unchecked(msg.lparam as *mut _);
30                Runnable::<()>::from_raw(runnable_ptr)
31            };
32            if let Err(panic_payload) = panic::catch_unwind(|| runnable.run()) {
33                PANIC_PAYLOAD.set(Some(panic_payload));
34            }
35            Some(0)
36        } else {
37            None
38        }
39    })
40    .unwrap();
41}
42
43/// An owned permission to join on a task (await its termination).
44///
45/// If a `JoinHandle` is dropped, then its task continues running in the
46/// background and its return value is lost.
47pub struct JoinHandle<T> {
48    task: ManuallyDrop<async_task::Task<T>>,
49}
50
51// Keep the task running when dropped.
52impl<T> Drop for JoinHandle<T> {
53    fn drop(&mut self) {
54        let task = unsafe { ManuallyDrop::take(&mut self.task) };
55        task.detach();
56    }
57}
58
59impl<T> Future for JoinHandle<T> {
60    type Output = T;
61
62    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
63        pin!(&mut *self.task).poll(cx)
64    }
65}
66
67unsafe fn spawn_unchecked_lifetime<T>(future: impl Future<Output = T>) -> JoinHandle<T> {
68    let hwnd = EXECUTOR_WINDOW.with(|w| w.hwnd());
69
70    // SAFETY: The `future` does not need to be `Send` because the thread that
71    // receives the runnable is our own, meaning the runnable is also dropped
72    // on the original thread.
73    let (runnable, task) = unsafe {
74        async_task::spawn_unchecked(future, move |runnable: Runnable| {
75            PostMessageA(hwnd, MSG_ID_WAKE, 0, runnable.into_raw().as_ptr() as _);
76        })
77    };
78
79    // Trigger initial poll.
80    runnable.schedule();
81
82    JoinHandle {
83        task: ManuallyDrop::new(task),
84    }
85}
86
87/// Spawns a new future on the current thread.
88///
89/// This function may be used to spawn tasks when the message loop is not
90/// running. The provided future will start running once the message loop
91/// is entered with [`block_on`] or [`MessageLoop::run`].
92pub fn spawn_local<T>(future: impl Future<Output = T> + 'static) -> JoinHandle<T> {
93    // SAFETY: future is `'static`
94    unsafe { spawn_unchecked_lifetime(future) }
95}
96
97/// Runs a future to completion on the calling thread's message loop.
98///
99/// This runs the provided future on the current thread, blocking until it is
100/// complete. Also runs any tasks [`spawn`]ed from the same thread. Note that
101/// any spawned tasks will be suspended after `block_on` returns. Calling
102/// `block_on` again will resume previously spawned tasks.
103///
104/// # Panics
105///
106/// Panics when quitting out of the message loop without the future being
107/// ready. This can happen when the future or any spawned task calls the
108/// `PostQuitMessage()` WinAPI function.
109pub fn block_on<'a, T: 'a>(future: impl Future<Output = T> + 'a) -> T {
110    let msg_loop = &MessageLoop::new();
111
112    // Wrap the future so it quits the message loop when finished.
113    // SAFETY: All borrowed variables outlive the task itself because we only
114    // return from this function after the task has finished.
115    let task = unsafe {
116        spawn_unchecked_lifetime(async move {
117            let result = future.await;
118            msg_loop.quit();
119            result
120        })
121    };
122
123    msg_loop.run_loop(|_| FilterResult::Forward);
124
125    poll_ready(task).expect("received unexpected quit message")
126}
127
/// Polls `future` exactly once with a waker that does nothing.
///
/// Returns `Ok(output)` if the future was already complete, `Err(())` if it
/// is still pending.
fn poll_ready<T>(future: impl Future<Output = T>) -> Result<T, ()> {
    // Every vtable entry ignores its data pointer and does nothing; nobody
    // needs to be woken because the future is only polled a single time.
    // TODO: switch to `Waker::noop()` once
    // https://github.com/rust-lang/rust/issues/98286 is usable here.
    const VTABLE: RawWakerVTable = RawWakerVTable::new(
        |_| RawWaker::new(ptr::null(), &VTABLE),
        |_| {},
        |_| {},
        |_| {},
    );

    // SAFETY: the vtable functions never dereference the (null) data pointer.
    let waker = unsafe { Waker::from_raw(RawWaker::new(ptr::null(), &VTABLE)) };
    let mut cx = Context::from_waker(&waker);

    let pinned = pin!(future);
    match pinned.poll(&mut cx) {
        Poll::Ready(output) => Ok(output),
        Poll::Pending => Err(()),
    }
}
144
/// Verdict returned by the filter closure of [`MessageLoop::run`] for each
/// message it inspects.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FilterResult {
    /// Let the message through to its window procedure.
    Forward,

    /// Swallow the message so it never reaches its window procedure.
    Drop,
}
154
/// Handle to a running message loop.
///
/// Values of this type cannot be created outside this crate. Call
/// [`MessageLoop::run`] to start a message loop; it passes this handle to
/// its filter closure, which can use it to control the loop (e.g. to quit).
pub struct MessageLoop {
    /// Set by [`MessageLoop::quit`] and checked before each message is fetched.
    quit: Cell<bool>,
}
163
164impl MessageLoop {
165    fn new() -> Self {
166        Self {
167            quit: Cell::new(false),
168        }
169    }
170
171    fn run_loop(&self, filter: impl Fn(&MSG) -> FilterResult) {
172        while !self.quit.get() {
173            unsafe {
174                let mut msg = MaybeUninit::uninit();
175                if GetMessageA(msg.as_mut_ptr(), ptr::null_mut(), 0, 0) == 0 {
176                    return;
177                }
178                let msg = msg.assume_init();
179
180                if filter(&msg) == FilterResult::Forward {
181                    TranslateMessage(&msg);
182                    DispatchMessageA(&msg);
183                }
184                if let Some(panic_payload) = PANIC_PAYLOAD.take() {
185                    panic::resume_unwind(panic_payload)
186                }
187            }
188        }
189    }
190
191    /// Runs the message loop with a filter closure to inspect and drop messages
192    /// before they are dispatched to their respective window procedure.
193    ///
194    /// Use the [`FilterResult`] return value to control how the message is
195    /// handled. The first argument to the filter closure is the [`MessageLoop`]
196    /// struct itself, which can be used to quit out of the message loop.
197    ///
198    /// Like [`block_on`], this function runs any tasks [`spawn`]ed from the
199    /// same thread. Any spawned tasks will be suspended when `run_message_loop`
200    /// returns.
201    /// Be careful not to drop messages not belonging to a window you
202    /// control or you might risk suspending a task indefinitely when dropping
203    /// its wake message.
204    ///
205    /// `run_message_loop` installs a [`WH_MSGFILTER`] hook to allow inspections
206    /// of messages while modal windows are open.
207    ///
208    /// # Panics and Reentrancy
209    ///
210    /// Panics when called from within another `run_message_loop` filter closure.
211    ///
212    /// A call to [`block_on()`] from within the filter closure creates a nested
213    /// message loop which causes the filter closure to be reentered when a modal
214    /// window is open.
215    ///
216    /// [`WH_MSGFILTER`]: (https://learn.microsoft.com/en-us/windows/win32/winmsg/about-hooks#wh_msgfilter-and-wh_sysmsgfilter)
217    pub fn run(filter: impl Fn(&MessageLoop, &MSG) -> FilterResult) {
218        let msg_loop = MessageLoop::new();
219
220        // Any modal window (i.e. a right-click menu) blocks the main message loop
221        // and dispatches messages internally. To keep the executor running use a
222        // hook to get access to modal windows' internal message loop.
223        // SAFETY: The Drop implementation of MsgFilterHook unregisters the hook,
224        // ensuring that dispatchers will not be called after the end of the scope.
225        let _hook = unsafe {
226            MsgFilterHook::register(|msg| {
227                panic::catch_unwind(AssertUnwindSafe(|| {
228                    let filter_result = filter(&msg_loop, msg);
229                    // When `quit()` is called, it has no real effect because we
230                    // are running in a modal loop. Post a quit message to exit
231                    // the modal message loop to store the panic payload.
232                    if msg_loop.quit.get() {
233                        PostMessageA(msg.hwnd, WM_QUIT, 0, 0);
234                    }
235                    filter_result == FilterResult::Drop
236                }))
237                .unwrap_or_else(|payload| {
238                    PANIC_PAYLOAD.with(|panic_payload| {
239                        panic_payload.set(Some(payload));
240                    });
241                    // Also exit the modal loop ASAP when a panic occurs.
242                    PostMessageA(msg.hwnd, WM_QUIT, 0, 0);
243                    false
244                })
245            })
246        };
247        msg_loop.run_loop(|msg| filter(&msg_loop, msg));
248    }
249
250    /// Quits the message loop as soon as possible.
251    pub fn quit(&self) {
252        self.quit.set(true);
253    }
254
255    /// Quits the message loop when there are no more messages to process.
256    pub fn quit_when_idle(&self) {
257        unsafe { PostQuitMessage(0) };
258    }
259}
260
#[cfg(test)]
mod test {
    use std::{ffi::CStr, future::poll_fn};

    use windows_sys::Win32::Foundation::HWND;

    use super::*;

    /// Posts `msg` to the calling thread's own queue (no target window).
    fn post_thread_message(msg: u32) {
        unsafe { PostMessageA(ptr::null_mut(), msg, 0, 0) };
    }

    #[test]
    #[should_panic]
    fn panic_in_dispatcher() {
        post_thread_message(WM_USER);
        MessageLoop::run(|_, _| panic!());
    }

    #[test]
    fn message_loop_quit() {
        for i in 0..10 {
            post_thread_message(WM_USER + i);
        }
        MessageLoop::run(|msg_loop, msg| {
            // This is the only message we observe because we quit the
            // loop right after it is received.
            assert_eq!(msg.message, WM_USER);
            msg_loop.quit();
            FilterResult::Drop
        });
    }

    #[test]
    fn message_loop_quit_when_idle() {
        for i in 0..10 {
            post_thread_message(WM_USER + i);
        }
        let expected_msg = Cell::new(0);
        MessageLoop::run(|msg_loop, msg| {
            assert_eq!(msg.message, WM_USER + expected_msg.get());
            expected_msg.set(expected_msg.get() + 1);
            msg_loop.quit_when_idle();
            FilterResult::Drop
        });
        assert_eq!(expected_msg.get(), 10);
    }

    #[test]
    fn nested_block_on() {
        let count: Cell<usize> = Cell::new(0);

        block_on(async {
            assert_eq!(count.get(), 0);
            count.set(count.get() + 1);

            block_on(async {
                assert_eq!(count.get(), 1);
                count.set(count.get() + 1);
            });

            assert_eq!(count.get(), 2);
            count.set(count.get() + 1);
        });

        assert_eq!(count.get(), 3);
    }

    #[test]
    #[should_panic]
    fn nested_message_loop() {
        post_thread_message(WM_USER);
        MessageLoop::run(|_, _| {
            MessageLoop::run(|_, _| FilterResult::Drop);
            FilterResult::Drop
        });
    }

    /// Returns `Pending` once (after waking itself), then `Ready`.
    async fn yield_now() {
        let mut yielded = false;
        poll_fn(|cx| {
            if yielded {
                Poll::Ready(())
            } else {
                yielded = true;
                cx.waker().wake_by_ref();
                Poll::Pending
            }
        })
        .await;
    }

    #[test]
    fn nested_message_loop_block_on() {
        let inner_executed = Cell::new(false);

        post_thread_message(WM_USER);
        MessageLoop::run(|msg_loop, _| {
            block_on(async {
                inner_executed.set(true);
            });
            msg_loop.quit();
            FilterResult::Forward
        });

        assert!(inner_executed.get());
    }

    #[test]
    fn nested_message_loop_block_on_quit() {
        post_thread_message(WM_USER);
        MessageLoop::run(|msg_loop, _| {
            block_on(async {
                msg_loop.quit();
            });
            FilterResult::Forward
        });
    }

    /// Looks up a top-level window by its title; null if none is found.
    fn window_by_name(name: &CStr) -> HWND {
        unsafe { FindWindowA(ptr::null_mut(), name.as_ptr() as _) }
    }

    #[test]
    fn running_spawned_with_modal_dialog() {
        // The window name must be unique for each test because cargo runs tests
        // in parallel and we do not want to close the window of another test.
        let window_name = c"running_spawned_with_modal_dialog";

        let task = spawn_local(async {
            // Wait for modal window to be open.
            while window_by_name(window_name).is_null() {
                yield_now().await;
            }

            // Do some async work with modal dialog open.
            for _ in 0..10 {
                yield_now().await;
            }

            // Close the modal window.
            unsafe {
                SendMessageA(window_by_name(window_name), WM_CLOSE, 0, 0);
            }
        });

        block_on(async {
            unsafe {
                MessageBoxA(
                    ptr::null_mut(),
                    ptr::null_mut(),
                    window_name.as_ptr() as _,
                    0,
                );
            }
            task.await;
        });
    }

    // This test does not actually expect the library to panic.
    // The panic is rather a convenient way to signal if the filter closure is
    // reentered (which is the expected behaviour).
    #[test]
    #[should_panic]
    fn reenter_filter_closure_panic() {
        // The window name must be unique for each test because cargo runs tests
        // in parallel and we do not want to close the window of another test.
        let window_name = c"reenter_filter_closure_panic";

        post_thread_message(WM_USER);

        let running_filter_closure = Cell::new(false);
        MessageLoop::run(|_, msg| {
            assert!(
                !running_filter_closure.replace(true),
                "Filter closure reentered"
            );

            if msg.hwnd.is_null() && msg.message == WM_USER {
                unsafe {
                    MessageBoxA(
                        ptr::null_mut(),
                        ptr::null_mut(),
                        window_name.as_ptr() as _,
                        0,
                    );
                }
            }

            running_filter_closure.set(false);
            FilterResult::Forward
        });
    }

    #[test]
    fn reenter_filter_closure_quit() {
        // The window name must be unique for each test because cargo runs tests
        // in parallel and we do not want to close the window of another test.
        let window_name = c"reenter_filter_closure_quit";

        post_thread_message(WM_USER);

        let running_filter_closure = Cell::new(false);
        MessageLoop::run(|msg_loop, msg| {
            if running_filter_closure.replace(true) {
                msg_loop.quit();
            }

            if msg.hwnd.is_null() && msg.message == WM_USER {
                unsafe {
                    MessageBoxA(
                        ptr::null_mut(),
                        ptr::null_mut(),
                        window_name.as_ptr() as _,
                        0,
                    );
                }
            }

            running_filter_closure.set(false);
            FilterResult::Forward
        });
    }

    #[test]
    fn message_loop_with_modal_dialog() {
        // The window name must be unique for each test because cargo runs tests
        // in parallel and we do not want to close the window of another test.
        let window_name = c"message_loop_with_modal_dialog";

        spawn_local(async {
            unsafe {
                MessageBoxA(
                    ptr::null_mut(),
                    ptr::null_mut(),
                    window_name.as_ptr() as _,
                    0,
                );
            }
        });

        spawn_local(async {
            // Check if modal window is actually open.
            assert!(!window_by_name(window_name).is_null());

            for i in 0..10 {
                post_thread_message(WM_USER + i);
                yield_now().await;
            }

            // Close modal window again.
            unsafe { SendMessageA(window_by_name(window_name), WM_CLOSE, 0, 0) };
        });

        let expected_msg = Cell::new(0);
        MessageLoop::run(|msg_loop, msg| {
            if msg.hwnd.is_null() && msg.message >= WM_USER {
                assert_eq!(msg.message, WM_USER + expected_msg.get());
                expected_msg.set(expected_msg.get() + 1);
                msg_loop.quit_when_idle();
                FilterResult::Drop
            } else {
                FilterResult::Forward
            }
        });
        assert_eq!(expected_msg.get(), 10);
    }

    #[test]
    fn reenter_filter_closure_quit_when_idle() {
        // The window name must be unique for each test because cargo runs tests
        // in parallel and we do not want to close the window of another test.
        let window_name = c"reenter_filter_closure_quit_when_idle";

        post_thread_message(WM_USER);

        let running_filter_closure = Cell::new(false);
        MessageLoop::run(|msg_loop, msg| {
            if running_filter_closure.replace(true) {
                msg_loop.quit_when_idle();
            }

            if msg.hwnd.is_null() && msg.message == WM_USER {
                unsafe {
                    MessageBoxA(
                        ptr::null_mut(),
                        ptr::null_mut(),
                        window_name.as_ptr() as _,
                        0,
                    );
                }
            }

            running_filter_closure.set(false);
            FilterResult::Forward
        });
    }
}