winmsg_executor/
lib.rs

1#![doc = include_str!("../README.md")]
2
3pub mod util;
4
5use std::{
6    any::Any,
7    cell::Cell,
8    future::Future,
9    mem::{ManuallyDrop, MaybeUninit},
10    panic::{self, AssertUnwindSafe},
11    pin::{pin, Pin},
12    ptr::{self, NonNull},
13    task::{Context, Poll, RawWaker, RawWakerVTable, Waker},
14};
15
16use async_task::Runnable;
17use util::{Window, WindowType};
18use windows_sys::Win32::UI::WindowsAndMessaging::*;
19
20use crate::util::MsgFilterHook;
21
22const MSG_ID_WAKE: u32 = WM_USER;
23
24thread_local! {
25    static PANIC_PAYLOAD: Cell<Option<Box<dyn Any + Send + 'static>>> = const { Cell::new(None) };
26    static EXECUTOR_WINDOW: Window<()> = Window::new(WindowType::MessageOnly, (), |_, msg| {
27        if msg.msg == MSG_ID_WAKE {
28            let runnable = unsafe {
29                let runnable_ptr = NonNull::new_unchecked(msg.lparam as *mut _);
30                Runnable::<()>::from_raw(runnable_ptr)
31            };
32            if let Err(panic_payload) = panic::catch_unwind(|| runnable.run()) {
33                PANIC_PAYLOAD.set(Some(panic_payload));
34            }
35            Some(0)
36        } else {
37            None
38        }
39    })
40    .unwrap();
41}
42
43/// An owned permission to join on a task (await its termination).
44///
45/// If a `JoinHandle` is dropped, then its task continues running in the
46/// background and its return value is lost.
47pub struct JoinHandle<T> {
48    task: ManuallyDrop<async_task::Task<T>>,
49}
50
51// Keep the task running when dropped.
52impl<T> Drop for JoinHandle<T> {
53    fn drop(&mut self) {
54        let task = unsafe { ManuallyDrop::take(&mut self.task) };
55        task.detach();
56    }
57}
58
59impl<T> Future for JoinHandle<T> {
60    type Output = T;
61
62    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
63        pin!(&mut *self.task).poll(cx)
64    }
65}
66
67unsafe fn spawn_unchecked_lifetime<T>(future: impl Future<Output = T>) -> JoinHandle<T> {
68    let hwnd = EXECUTOR_WINDOW.with(|w| w.hwnd());
69
70    // SAFETY: The `future` does not need to be `Send` because the thread that
71    // receives the runnable is our own, meaning the runniable is also dropped
72    // on original thread.
73    let (runnable, task) = unsafe {
74        async_task::spawn_unchecked(future, move |runnable: Runnable| {
75            PostMessageA(hwnd, MSG_ID_WAKE, 0, runnable.into_raw().as_ptr() as _);
76        })
77    };
78
79    // Trigger initial poll.
80    runnable.schedule();
81
82    JoinHandle {
83        task: ManuallyDrop::new(task),
84    }
85}
86
87/// Spawns a new future on the current thread.
88///
89/// This function may be used to spawn tasks when the message loop is not
90/// running. The provided future will start running once the message loop
91/// is entered with [`block_on`] or [`MessageLoop::run`].
92pub fn spawn_local<T>(future: impl Future<Output = T> + 'static) -> JoinHandle<T> {
93    // SAFETY: future is `'static`
94    unsafe { spawn_unchecked_lifetime(future) }
95}
96
97/// Runs a future to completion on the calling threads message loop.
98///
99/// This runs the provided future on the current thread, blocking until it is
100/// complete. Also runs any tasks [`spawn`]ed from the same thread. Note that
101/// any spawned tasks will be suspended after `block_on` returns. Calling
102/// `block_on` again will resume previously spawned tasks.
103///
104/// # Panics
105///
106/// Panics when quitting out of the message loop without the future being
107/// ready. This can happen when calling when the future or any spawned task
108/// calls the `PostQuitMessage()` winapi function.
109pub fn block_on<'a, T: 'a>(future: impl Future<Output = T> + 'a) -> T {
110    let msg_loop = &MessageLoop::new();
111
112    // Wrap the future so it quits the message loop when finished.
113    // SAFETY: All borrowed variables outlive the task itself because we only
114    // return from this function after the task has finished.
115    let task = unsafe {
116        spawn_unchecked_lifetime(async move {
117            let result = future.await;
118            msg_loop.quit();
119            result
120        })
121    };
122
123    msg_loop.run_loop(|_| FilterResult::Forward);
124
125    poll_ready(task).expect("received unexpected quit message")
126}
127
/// Polls `future` exactly once, using a waker that ignores wake-ups.
///
/// Returns `Ok(output)` if the future completed on that single poll, and
/// `Err(())` if it was still pending.
fn poll_ready<T>(future: impl Future<Output = T>) -> Result<T, ()> {
    // TODO: Replace this hand-written no-op waker with `Waker::noop()`
    // (https://github.com/rust-lang/rust/issues/98286, stable since Rust
    // 1.85) once the minimum supported Rust version allows it.
    fn clone_noop(_: *const ()) -> RawWaker {
        RawWaker::new(ptr::null(), &NOOP_VTABLE)
    }
    fn ignore(_: *const ()) {}
    static NOOP_VTABLE: RawWakerVTable = RawWakerVTable::new(clone_noop, ignore, ignore, ignore);

    // SAFETY: No vtable function dereferences the data pointer, so a null
    // pointer upholds the `RawWaker` contract trivially.
    let waker = unsafe { Waker::from_raw(RawWaker::new(ptr::null(), &NOOP_VTABLE)) };
    let mut context = Context::from_waker(&waker);

    let pinned = pin!(future);
    match pinned.poll(&mut context) {
        Poll::Ready(output) => Ok(output),
        Poll::Pending => Err(()),
    }
}
144
/// Decision returned by the filter closure of [`MessageLoop::run`] for each
/// message it inspects.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FilterResult {
    /// Let the message through so it is processed as usual and reaches its
    /// window procedure.
    Forward,

    /// Swallow the message so it never reaches a window procedure.
    Drop,
}
154
155/// Abstract representation of a message loop.
156///
157/// Not directly constructible, use [`MessageLoop::run`] to create a message
158/// loop. The message loop struct is used to control the message loop behavior
159/// by passing it as an argument to the filter closure of [`MessageLoop::run`].
160pub struct MessageLoop {
161    quit: Cell<bool>,
162}
163
164impl MessageLoop {
165    fn new() -> Self {
166        Self {
167            quit: Cell::new(false),
168        }
169    }
170
171    fn run_loop(&self, filter: impl Fn(&MSG) -> FilterResult) {
172        while !self.quit.get() {
173            unsafe {
174                let mut msg = MaybeUninit::uninit();
175                if GetMessageA(msg.as_mut_ptr(), ptr::null_mut(), 0, 0) == 0 {
176                    return;
177                }
178                let msg = msg.assume_init();
179
180                if filter(&msg) == FilterResult::Forward {
181                    TranslateMessage(&msg);
182                    DispatchMessageA(&msg);
183                }
184                if let Some(panic_payload) = PANIC_PAYLOAD.take() {
185                    panic::resume_unwind(panic_payload)
186                }
187            }
188        }
189    }
190
191    /// Runs the message loop with a filter closure to inspect and drop messages
192    /// before they are dispatched to their respective window procedure.
193    ///
194    /// Use the [`FilterResult`] return value to control how the message is
195    /// handled. The first argument to the filter closure is the [`MessageLoop`]
196    /// struct itself, which can be used to quit out of the message loop.
197    ///
198    /// Like [`block_on`] this function runs any tasks [`spawn`]ed from the same
199    /// thread. Any spawned tasks will be suspended when the `run_message_loop`
200    /// returns. Be careful not to drop messages not belonging to a window you
201    /// control or you might risk suspending a task indefinitely when dropping
202    /// its wake message.
203    ///
204    /// `run_message_loop` installs a [`WH_MSGFILTER`] hook to allow inspections
205    /// of messages while modal windows are open.
206    ///
207    /// # Panics and Reentrancy
208    ///
209    /// Panics when called from within another `run_message_loop` filter closure.
210    /// A call to [`block_on()`] from within the filter closure creates a nested
211    /// message loop which causes the filter closure to be reentered when a modal
212    /// window is open.
213    ///
214    /// [`WH_MSGFILTER`]: (https://learn.microsoft.com/en-us/windows/win32/winmsg/about-hooks#wh_msgfilter-and-wh_sysmsgfilter)
215    pub fn run(filter: impl Fn(&MessageLoop, &MSG) -> FilterResult) {
216        let msg_loop = MessageLoop::new();
217
218        // Any modal window (i.e. a right-click menu) blocks the main message loop
219        // and dispatches messages internally. To keep the executor running use a
220        // hook to get access to modal windows' internal message loop.
221        // SAFETY: The Drop implementation of MsgFilterHook unregisters the hook,
222        // ensuring that dispatchers will not be called after the end of the scope.
223        let _hook = unsafe {
224            MsgFilterHook::register(|msg| {
225                panic::catch_unwind(AssertUnwindSafe(|| {
226                    let filter_result = filter(&msg_loop, msg);
227                    // When quit() was called it has no real effect because we
228                    // are running in a modal loop. Post a quit message to exit
229                    // the message loop that is not under our control ASAP.
230                    if msg_loop.quit.get() {
231                        PostMessageA(msg.hwnd, WM_QUIT, 0, 0);
232                    }
233                    filter_result == FilterResult::Drop
234                }))
235                .unwrap_or_else(|payload| {
236                    PANIC_PAYLOAD.with(|panic_payload| {
237                        panic_payload.set(Some(payload));
238                    });
239                    // Also exit the modal loop ASAP when a panic occurs.
240                    PostMessageA(msg.hwnd, WM_QUIT, 0, 0);
241                    false
242                })
243            })
244        };
245        msg_loop.run_loop(|msg| filter(&msg_loop, msg));
246    }
247
248    /// Quits the message loop as soon as possible.
249    pub fn quit(&self) {
250        self.quit.set(true);
251    }
252
253    /// Quits the message loop when there are no more messages to process.
254    pub fn quit_when_idle(&self) {
255        unsafe { PostQuitMessage(0) };
256    }
257}
258
#[cfg(test)]
mod test {
    use std::{ffi::CStr, future::poll_fn};

    use windows_sys::Win32::Foundation::HWND;

    use super::*;

    /// Posts `msg` to the current thread's message queue, without a target
    /// window.
    fn post_thread_message(msg: u32) {
        unsafe { PostMessageA(ptr::null_mut(), msg, 0, 0) };
    }

    #[test]
    #[should_panic]
    fn panic_in_dispatcher() {
        post_thread_message(WM_USER);
        MessageLoop::run(|_, _| panic!());
    }

    #[test]
    fn message_loop_quit() {
        for i in 0..10 {
            post_thread_message(WM_USER + i);
        }
        MessageLoop::run(|msg_loop, msg| {
            // This is the only message we ever observe because we quit the
            // loop right after receiving it.
            assert_eq!(msg.message, WM_USER);
            msg_loop.quit();
            FilterResult::Drop
        });
    }

    #[test]
    fn message_loop_quit_when_idle() {
        for i in 0..10 {
            post_thread_message(WM_USER + i);
        }
        let expected_msg = Cell::new(0);
        MessageLoop::run(|msg_loop, msg| {
            assert_eq!(msg.message, WM_USER + expected_msg.get());
            expected_msg.set(expected_msg.get() + 1);
            msg_loop.quit_when_idle();
            FilterResult::Drop
        });
        assert_eq!(expected_msg.get(), 10);
    }

    #[test]
    fn nested_block_on() {
        let count: Cell<usize> = Cell::new(0);

        block_on(async {
            assert_eq!(count.get(), 0);
            count.set(count.get() + 1);

            block_on(async {
                assert_eq!(count.get(), 1);
                count.set(count.get() + 1);
            });

            assert_eq!(count.get(), 2);
            count.set(count.get() + 1);
        });

        assert_eq!(count.get(), 3);
    }

    #[test]
    #[should_panic]
    fn nested_message_loop() {
        post_thread_message(WM_USER);
        MessageLoop::run(|_, _| {
            MessageLoop::run(|_, _| FilterResult::Drop);
            FilterResult::Drop
        });
    }

    /// Yields to the executor once by waking itself and returning `Pending`.
    async fn yield_now() {
        let mut yielded = false;
        poll_fn(|cx| {
            if yielded {
                Poll::Ready(())
            } else {
                yielded = true;
                cx.waker().wake_by_ref();
                Poll::Pending
            }
        })
        .await;
    }

    #[test]
    fn nested_message_loop_block_on() {
        let inner_executed = Cell::new(false);

        post_thread_message(WM_USER);
        MessageLoop::run(|msg_loop, _| {
            block_on(async {
                inner_executed.set(true);
            });
            msg_loop.quit();
            FilterResult::Forward
        });

        assert!(inner_executed.get());
    }

    #[test]
    fn nested_message_loop_block_on_quit() {
        post_thread_message(WM_USER);
        MessageLoop::run(|msg_loop, _| {
            block_on(async {
                msg_loop.quit();
            });
            FilterResult::Forward
        });
    }

    /// Finds a top-level window by its title; returns null if there is none.
    fn window_by_name(name: &CStr) -> HWND {
        unsafe { FindWindowA(ptr::null_mut(), name.as_ptr() as _) }
    }

    #[test]
    fn running_spawned_with_modal_dialog() {
        // The window name must be unique for each test because cargo runs tests
        // in parallel and we do not want to close the window of another test.
        let window_name = c"running_spawned_with_modal_dialog";

        let task = spawn_local(async {
            // Wait for modal window to be open.
            while window_by_name(window_name).is_null() {
                yield_now().await;
            }

            // Do some async work with modal dialog open.
            for _ in 0..10 {
                yield_now().await;
            }

            // Close the modal window.
            unsafe {
                SendMessageA(window_by_name(window_name), WM_CLOSE, 0, 0);
            }
        });

        block_on(async {
            unsafe {
                MessageBoxA(
                    ptr::null_mut(),
                    ptr::null_mut(),
                    window_name.as_ptr() as _,
                    0,
                );
            }
            task.await;
        });
    }

    #[test]
    #[should_panic]
    fn reenter_filter_closure_panic() {
        // The window name must be unique for each test because cargo runs tests
        // in parallel and we do not want to close the window of another test.
        let window_name = c"reenter_filter_closure_panic";

        post_thread_message(WM_USER);

        let running_filter_closure = Cell::new(false);
        MessageLoop::run(|_, msg| {
            assert!(
                !running_filter_closure.replace(true),
                "Filter closure reentered"
            );

            if msg.hwnd.is_null() && msg.message == WM_USER {
                unsafe {
                    MessageBoxA(
                        ptr::null_mut(),
                        ptr::null_mut(),
                        window_name.as_ptr() as _,
                        0,
                    );
                }
            }

            running_filter_closure.set(false);
            FilterResult::Forward
        });
    }

    #[test]
    fn reenter_filter_closure_quit() {
        // The window name must be unique for each test because cargo runs tests
        // in parallel and we do not want to close the window of another test.
        let window_name = c"reenter_filter_closure_quit";

        post_thread_message(WM_USER);

        let running_filter_closure = Cell::new(false);
        MessageLoop::run(|msg_loop, msg| {
            if running_filter_closure.replace(true) {
                msg_loop.quit();
            }

            if msg.hwnd.is_null() && msg.message == WM_USER {
                unsafe {
                    MessageBoxA(
                        ptr::null_mut(),
                        ptr::null_mut(),
                        window_name.as_ptr() as _,
                        0,
                    );
                }
            }

            running_filter_closure.set(false);
            FilterResult::Forward
        });
    }

    #[test]
    fn message_loop_with_modal_dialog() {
        // The window name must be unique for each test because cargo runs tests
        // in parallel and we do not want to close the window of another test.
        let window_name = c"message_loop_with_modal_dialog";

        spawn_local(async {
            unsafe {
                MessageBoxA(
                    ptr::null_mut(),
                    ptr::null_mut(),
                    window_name.as_ptr() as _,
                    0,
                );
            }
        });

        spawn_local(async {
            // Check if modal window is actually open.
            assert!(!window_by_name(window_name).is_null());

            for i in 0..10 {
                post_thread_message(WM_USER + i);
                yield_now().await;
            }

            // Close modal window again.
            unsafe { SendMessageA(window_by_name(window_name), WM_CLOSE, 0, 0) };
        });

        let expected_msg = Cell::new(0);
        MessageLoop::run(|msg_loop, msg| {
            if msg.hwnd.is_null() && msg.message >= WM_USER {
                assert_eq!(msg.message, WM_USER + expected_msg.get());
                expected_msg.set(expected_msg.get() + 1);
                msg_loop.quit_when_idle();
                FilterResult::Drop
            } else {
                FilterResult::Forward
            }
        });
        assert_eq!(expected_msg.get(), 10);
    }

    #[test]
    fn reenter_filter_closure_quit_when_idle() {
        // The window name must be unique for each test because cargo runs tests
        // in parallel and we do not want to close the window of another test.
        let window_name = c"reenter_filter_closure_quit_when_idle";

        post_thread_message(WM_USER);

        let running_filter_closure = Cell::new(false);
        MessageLoop::run(|msg_loop, msg| {
            if running_filter_closure.replace(true) {
                msg_loop.quit_when_idle();
            }

            if msg.hwnd.is_null() && msg.message == WM_USER {
                unsafe {
                    MessageBoxA(
                        ptr::null_mut(),
                        ptr::null_mut(),
                        window_name.as_ptr() as _,
                        0,
                    );
                }
            }

            running_filter_closure.set(false);
            FilterResult::Forward
        });
    }
}