// wintf_winmsg_executor/lib.rs

1#![doc = include_str!("../README.md")]
2
3pub mod util;
4
5use std::{
6    any::Any,
7    cell::Cell,
8    future::Future,
9    mem::{ManuallyDrop, MaybeUninit},
10    panic::{self, AssertUnwindSafe},
11    pin::{pin, Pin},
12    ptr::NonNull,
13    task::{Context, Poll, Waker},
14};
15
16use async_task::Runnable;
17use util::{Window, WindowType};
18use windows::Win32::{
19    Foundation::{LPARAM, LRESULT, WPARAM},
20    UI::WindowsAndMessaging::*,
21};
22
23use crate::util::MsgFilterHook;
24
25const MSG_ID_WAKE: u32 = WM_USER;
26
27thread_local! {
28    static PANIC_PAYLOAD: Cell<Option<Box<dyn Any + Send + 'static>>> = const { Cell::new(None) };
29    static EXECUTOR_WINDOW: Window<()> = Window::new(WindowType::MessageOnly, (), |_, msg| {
30        if msg.msg == MSG_ID_WAKE {
31            let runnable = unsafe {
32                let runnable_ptr = NonNull::new_unchecked(msg.lparam.0 as *mut _);
33                Runnable::<()>::from_raw(runnable_ptr)
34            };
35            if let Err(panic_payload) = panic::catch_unwind(|| runnable.run()) {
36                PANIC_PAYLOAD.set(Some(panic_payload));
37            }
38            Some(LRESULT(0))
39        } else {
40            None
41        }
42    })
43    .unwrap();
44}
45
46/// An owned permission to join on a task (await its termination).
47///
48/// If a `JoinHandle` is dropped, the task continues running in the
49/// background and its return value is lost.
50pub struct JoinHandle<T> {
51    task: ManuallyDrop<async_task::Task<T>>,
52}
53
54// Keep the task running when dropped.
55impl<T> Drop for JoinHandle<T> {
56    fn drop(&mut self) {
57        let task = unsafe { ManuallyDrop::take(&mut self.task) };
58        task.detach();
59    }
60}
61
62impl<T> Future for JoinHandle<T> {
63    type Output = T;
64
65    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
66        pin!(&mut *self.task).poll(cx)
67    }
68}
69
70unsafe fn spawn_unchecked_lifetime<T>(future: impl Future<Output = T>) -> JoinHandle<T> {
71    let hwnd = EXECUTOR_WINDOW.with(|w| w.hwnd());
72
73    // SAFETY: The `future` does not need to be `Send` because the thread that
74    // receives the runnable is our own, meaning the runnable is also dropped
75    // on the original thread.
76    let (runnable, task) = unsafe {
77        async_task::spawn_unchecked(future, move |runnable: Runnable| {
78            let _ = PostMessageW(
79                Some(hwnd),
80                MSG_ID_WAKE,
81                WPARAM(0),
82                LPARAM(runnable.into_raw().as_ptr() as _),
83            );
84        })
85    };
86
87    // Trigger initial poll.
88    runnable.schedule();
89
90    JoinHandle {
91        task: ManuallyDrop::new(task),
92    }
93}
94
95/// Spawns a new future on the current thread.
96///
97/// This function may be used to spawn tasks when the message loop is not running.
98/// The provided future starts running once the message loop is entered with
99/// [`block_on()`] or [`MessageLoop::run()`].
100///
101/// # Examples
102///
103/// This example is a compile-time test to ensure that we only accept `'static`
104/// return types to prevent <https://github.com/rust-lang/rust/issues/84366>.
105///
106/// ```compile_fail
107/// fn test_fn<'a>() {
108///     let closure = || -> &'a str { "" };
109///     wintf_winmsg_executor::spawn_local(async move { closure() });
110/// }
111/// ```
112pub fn spawn_local<T: 'static>(future: impl Future<Output = T> + 'static) -> JoinHandle<T> {
113    // SAFETY: future is `'static`
114    unsafe { spawn_unchecked_lifetime(future) }
115}
116
117/// Runs a future to completion on the calling thread's message loop.
118///
119/// Runs the provided future on the current thread, blocking until it completes.
120/// Any tasks spawned from the same thread using [`spawn_local()`] also run concurrently.
121/// Note that all spawned tasks are suspended after [`block_on()`] returns.
122/// Calling [`block_on()`] again resumes the spawned tasks.
123///
124/// # Panics
125///
126/// Panics if the message loop quits before the future completes.
127/// This can happen when the future or any spawned task calls the
128/// `PostQuitMessage()` WinAPI function.
129pub fn block_on<'a, T: 'a>(future: impl Future<Output = T> + 'a) -> T {
130    let msg_loop = &MessageLoop::new();
131
132    // Wrap the future so it quits the message loop when finished.
133    // SAFETY: All borrowed variables outlive the task itself because we only
134    // return from this function after the task has finished.
135    let task = unsafe {
136        spawn_unchecked_lifetime(async move {
137            let result = future.await;
138            msg_loop.quit();
139            result
140        })
141    };
142
143    msg_loop.run_loop(|_| FilterResult::Forward);
144
145    poll_ready(task).expect("received unexpected quit message")
146}
147
/// Polls `future` exactly once with a no-op waker.
///
/// Returns `Ok` with the output if the future is ready and `Err(())` if it is
/// still pending. The future is dropped in either case.
fn poll_ready<T>(future: impl Future<Output = T>) -> Result<T, ()> {
    let future = pin!(future);
    match future.poll(&mut Context::from_waker(Waker::noop())) {
        Poll::Ready(result) => Ok(result),
        Poll::Pending => Err(()),
    }
}
155
/// Return value of the filter closure passed to [`MessageLoop::run`].
///
/// Decides what happens to a message after the filter closure inspected it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FilterResult {
    /// The message is forwarded to the window procedure.
    Forward,

    /// The message is dropped and not forwarded to the window procedure.
    Drop,
}
165
/// Abstract representation of a message loop.
///
/// Not directly constructible. Use [`MessageLoop::run`] to create a message
/// loop. The message loop struct is used to control message loop behavior
/// by passing it as an argument to the filter closure of [`MessageLoop::run`].
pub struct MessageLoop {
    // Set by `quit()`; checked by the loop before each message is retrieved.
    quit: Cell<bool>,
}
174
175impl MessageLoop {
176    fn new() -> Self {
177        Self {
178            quit: Cell::new(false),
179        }
180    }
181
182    fn run_loop(&self, filter: impl Fn(&MSG) -> FilterResult) {
183        let executor_hwnd = EXECUTOR_WINDOW.with(|ew| ew.hwnd());
184
185        while !self.quit.get() {
186            unsafe {
187                let mut msg = MaybeUninit::uninit();
188                if GetMessageW(msg.as_mut_ptr(), None, 0, 0).0 == 0 {
189                    return;
190                }
191                let msg = msg.assume_init();
192
193                // Do not allow the filter to drop our wake messages.
194                let is_wake_message = msg.hwnd == executor_hwnd && msg.message == MSG_ID_WAKE;
195                if is_wake_message || filter(&msg) == FilterResult::Forward {
196                    let _ = TranslateMessage(&msg);
197                    DispatchMessageW(&msg);
198                }
199
200                if let Some(panic_payload) = PANIC_PAYLOAD.take() {
201                    panic::resume_unwind(panic_payload)
202                }
203            }
204        }
205    }
206
207    /// Runs the message loop with a filter closure to inspect and drop messages before
208    /// they are dispatched to their respective window procedures.
209    ///
210    /// Use the [`FilterResult`] return value to control how the message is handled.
211    /// The first argument to the filter closure is the [`MessageLoop`] struct itself,
212    /// which can be used to quit the message loop.
213    ///
214    /// Like [`block_on()`], this function runs any tasks spawned from the same thread
215    /// using [`spawn_local()`]. All tasks are suspended when this function returns.
216    ///
217    /// Installs a [`WH_MSGFILTER`] hook to allow inspection of messages while modal
218    /// windows are open.
219    ///
220    /// # Panics and Reentrancy
221    ///
222    /// Panics if called from within another [`MessageLoop::run()`] filter closure.
223    ///
224    /// A call to [`block_on()`] from within the filter closure creates a nested message
225    /// loop, which causes the filter closure to be re-entered when a modal window is open.
226    ///
227    /// [`WH_MSGFILTER`]: (https://learn.microsoft.com/en-us/windows/win32/winmsg/about-hooks#wh_msgfilter-and-wh_sysmsgfilter)
228    pub fn run(filter: impl Fn(&MessageLoop, &MSG) -> FilterResult) {
229        let msg_loop = MessageLoop::new();
230
231        // Any modal window (i.e. a right-click menu) blocks the main message loop
232        // and dispatches messages internally. To keep the executor running use a
233        // hook to get access to modal windows' internal message loop.
234        // SAFETY: The Drop implementation of MsgFilterHook unregisters the hook,
235        // ensuring that dispatchers will not be called after the end of the scope.
236        let _hook = unsafe {
237            MsgFilterHook::register(|msg| {
238                panic::catch_unwind(AssertUnwindSafe(|| {
239                    let filter_result = filter(&msg_loop, msg);
240                    // When `quit()` is called, it has no real effect because we
241                    // are running in a modal loop. Post a quit message to exit
242                    // the modal message loop to store the panic payload.
243                    if msg_loop.quit.get() {
244                        let _ = PostMessageW(Some(msg.hwnd), WM_QUIT, WPARAM(0), LPARAM(0));
245                    }
246                    filter_result == FilterResult::Drop
247                }))
248                .unwrap_or_else(|payload| {
249                    PANIC_PAYLOAD.with(|panic_payload| {
250                        panic_payload.set(Some(payload));
251                    });
252                    // Also exit the modal loop ASAP when a panic occurs.
253                    let _ = PostMessageW(Some(msg.hwnd), WM_QUIT, WPARAM(0), LPARAM(0));
254                    false
255                })
256            })
257        };
258        msg_loop.run_loop(|msg| filter(&msg_loop, msg));
259    }
260
261    /// Quits the message loop as soon as possible.
262    pub fn quit(&self) {
263        self.quit.set(true);
264    }
265
266    /// Quits the message loop when there are no more messages to process.
267    pub fn quit_when_idle(&self) {
268        unsafe { PostQuitMessage(0) };
269    }
270}
271
#[cfg(test)]
mod test {
    use std::future::poll_fn;

    use windows::core::{w, PCWSTR};
    use windows::Win32::Foundation::HWND;

    use super::*;

    /// Posts a message without parameters to the calling thread's queue.
    fn post_thread_message(msg: u32) {
        let _ = unsafe { PostMessageW(None, msg, WPARAM(0), LPARAM(0)) };
    }

    #[test]
    #[should_panic]
    fn panic_in_dispatcher() {
        post_thread_message(WM_USER);
        MessageLoop::run(|_, _| panic!());
    }

    #[test]
    fn message_loop_quit() {
        for i in 0..10 {
            post_thread_message(WM_USER + i);
        }
        MessageLoop::run(|msg_loop, msg| {
            // This is the only message we observe because we quit the
            // loop right after it is received.
            assert_eq!(msg.message, WM_USER);
            msg_loop.quit();
            FilterResult::Drop
        });
    }

    #[test]
    fn message_loop_quit_when_idle() {
        for i in 0..10 {
            post_thread_message(WM_USER + i);
        }
        let expected_msg = Cell::new(0);
        MessageLoop::run(|msg_loop, msg| {
            assert_eq!(msg.message, WM_USER + expected_msg.get());
            expected_msg.set(expected_msg.get() + 1);
            msg_loop.quit_when_idle();
            FilterResult::Drop
        });
        assert_eq!(expected_msg.get(), 10);
    }

    #[test]
    fn nested_block_on() {
        let count: Cell<usize> = Cell::new(0);

        block_on(async {
            assert_eq!(count.get(), 0);
            count.set(count.get() + 1);

            block_on(async {
                assert_eq!(count.get(), 1);
                count.set(count.get() + 1);
            });

            assert_eq!(count.get(), 2);
            count.set(count.get() + 1);
        });

        assert_eq!(count.get(), 3);
    }

    #[test]
    #[should_panic]
    fn nested_message_loop() {
        post_thread_message(WM_USER);
        MessageLoop::run(|_, _| {
            MessageLoop::run(|_, _| FilterResult::Drop);
            FilterResult::Drop
        });
    }

    /// Returns `Pending` once (after waking itself) and completes on the next poll.
    async fn yield_now() {
        let mut yielded = false;
        poll_fn(|cx| {
            if yielded {
                Poll::Ready(())
            } else {
                yielded = true;
                cx.waker().wake_by_ref();
                Poll::Pending
            }
        })
        .await;
    }

    #[test]
    fn nested_message_loop_block_on() {
        let inner_executed = Cell::new(false);

        post_thread_message(WM_USER);
        MessageLoop::run(|msg_loop, _| {
            block_on(async {
                inner_executed.set(true);
            });
            msg_loop.quit();
            FilterResult::Forward
        });

        assert!(inner_executed.get());
    }

    #[test]
    fn nested_message_loop_block_on_quit() {
        post_thread_message(WM_USER);
        MessageLoop::run(|msg_loop, _| {
            block_on(async {
                msg_loop.quit();
            });
            FilterResult::Forward
        });
    }

    /// Looks up a top-level window by its title; returns a null handle if none is found.
    fn window_by_name(name: PCWSTR) -> HWND {
        unsafe { FindWindowW(None, name) }.unwrap_or_default()
    }

    #[test]
    fn running_spawned_with_modal_dialog() {
        // The window name must be unique for each test because cargo runs tests
        // in parallel and we do not want to close the window of another test.
        let window_name = w!("running_spawned_with_modal_dialog");

        let task = spawn_local(async move {
            // Wait for modal window to be open.
            while window_by_name(window_name).0.is_null() {
                yield_now().await;
            }

            // Do some async work with modal dialog open.
            for _ in 0..10 {
                yield_now().await;
            }

            // Close the modal window.
            unsafe {
                SendMessageW(
                    window_by_name(window_name),
                    WM_CLOSE,
                    Some(WPARAM(0)),
                    Some(LPARAM(0)),
                );
            }
        });

        block_on(async {
            unsafe {
                MessageBoxW(
                    None,
                    PCWSTR::null(),
                    window_name,
                    MESSAGEBOX_STYLE(0),
                );
            }
            task.await;
        });
    }

    // This test does not actually expect the library to panic.
    // The panic is rather a convenient way to signal if the filter closure is
    // reentered (which is the expected behaviour).
    #[test]
    #[should_panic]
    fn reenter_filter_closure_panic() {
        // The window name must be unique for each test because cargo runs tests
        // in parallel and we do not want to close the window of another test.
        let window_name = w!("reenter_filter_closure_panic");

        post_thread_message(WM_USER);

        let running_filter_closure = Cell::new(false);
        MessageLoop::run(|_, msg| {
            assert!(
                !running_filter_closure.replace(true),
                "Filter closure reentered"
            );

            if msg.hwnd.0.is_null() && msg.message == WM_USER {
                unsafe {
                    MessageBoxW(
                        None,
                        PCWSTR::null(),
                        window_name,
                        MESSAGEBOX_STYLE(0),
                    );
                }
            }

            running_filter_closure.set(false);
            FilterResult::Forward
        });
    }

    #[test]
    fn reenter_filter_closure_quit() {
        // The window name must be unique for each test because cargo runs tests
        // in parallel and we do not want to close the window of another test.
        let window_name = w!("reenter_filter_closure_quit");

        post_thread_message(WM_USER);

        let running_filter_closure = Cell::new(false);
        MessageLoop::run(|msg_loop, msg| {
            if running_filter_closure.replace(true) {
                msg_loop.quit();
            }

            if msg.hwnd.0.is_null() && msg.message == WM_USER {
                unsafe {
                    MessageBoxW(
                        None,
                        PCWSTR::null(),
                        window_name,
                        MESSAGEBOX_STYLE(0),
                    );
                }
            }

            running_filter_closure.set(false);
            FilterResult::Forward
        });
    }

    #[test]
    fn message_loop_with_modal_dialog() {
        // The window name must be unique for each test because cargo runs tests
        // in parallel and we do not want to close the window of another test.
        let window_name = w!("message_loop_with_modal_dialog");

        spawn_local(async move {
            unsafe {
                MessageBoxW(
                    None,
                    PCWSTR::null(),
                    window_name,
                    MESSAGEBOX_STYLE(0),
                );
            }
        });

        spawn_local(async move {
            // Check if modal window is actually open.
            assert!(!window_by_name(window_name).0.is_null());

            for i in 0..10 {
                post_thread_message(WM_USER + i);
                yield_now().await;
            }

            // Close modal window again.
            unsafe {
                SendMessageW(
                    window_by_name(window_name),
                    WM_CLOSE,
                    Some(WPARAM(0)),
                    Some(LPARAM(0)),
                )
            };
        });

        let expected_msg = Cell::new(0);
        MessageLoop::run(|msg_loop, msg| {
            if msg.hwnd.0.is_null() && msg.message >= WM_USER {
                assert_eq!(msg.message, WM_USER + expected_msg.get());
                expected_msg.set(expected_msg.get() + 1);
                msg_loop.quit_when_idle();
                FilterResult::Drop
            } else {
                FilterResult::Forward
            }
        });
        assert_eq!(expected_msg.get(), 10);
    }

    #[test]
    fn reenter_filter_closure_quit_when_idle() {
        // The window name must be unique for each test because cargo runs tests
        // in parallel and we do not want to close the window of another test.
        let window_name = w!("reenter_filter_closure_quit_when_idle");

        post_thread_message(WM_USER);

        let running_filter_closure = Cell::new(false);
        MessageLoop::run(|msg_loop, msg| {
            if running_filter_closure.replace(true) {
                msg_loop.quit_when_idle();
            }

            if msg.hwnd.0.is_null() && msg.message == WM_USER {
                unsafe {
                    MessageBoxW(
                        None,
                        PCWSTR::null(),
                        window_name,
                        MESSAGEBOX_STYLE(0),
                    );
                }
            }

            running_filter_closure.set(false);
            FilterResult::Forward
        });
    }

    #[test]
    fn disallow_wake_message_filtering() {
        // Leaked so the spawned `'static` task can hold a reference to the loop.
        let msg_loop = MessageLoop::new();
        let msg_loop = Box::leak(Box::new(msg_loop));

        // The `MSG_ID_WAKE` message for the custom window should be filtered by
        // the run loop filter below.
        let custom_wnd = Window::new(WindowType::MessageOnly, (), |_, msg| {
            assert_ne!(msg.msg, MSG_ID_WAKE);
            None
        })
        .unwrap();
        unsafe {
            let _ = PostMessageW(Some(custom_wnd.hwnd()), MSG_ID_WAKE, WPARAM(0), LPARAM(0));
        }

        // Spawn a task to ensure that the executor window also has a wake message,
        // which must not be filtered.
        spawn_local(async {
            yield_now().await;
            yield_now().await;
            yield_now().await;
            msg_loop.quit();
        });

        msg_loop.run_loop(|msg| {
            // This test is to ensure that this callback is not even called for internal wake messages.
            if msg.message == MSG_ID_WAKE {
                assert_ne!(msg.hwnd, EXECUTOR_WINDOW.with(|ew| ew.hwnd()));
                FilterResult::Drop
            } else {
                FilterResult::Forward
            }
        });
    }
}