Skip to main content

wintf_winmsg_executor/
lib.rs

1#![doc = include_str!("../README.md")]
2
3pub mod util;
4
5use std::{
6    any::Any,
7    cell::Cell,
8    future::Future,
9    mem::{ManuallyDrop, MaybeUninit},
10    panic::{self, AssertUnwindSafe},
11    pin::{pin, Pin},
12    ptr::{self, NonNull},
13    task::{Context, Poll, Waker},
14};
15
16use async_task::Runnable;
17use util::{Window, WindowType};
18use windows_sys::Win32::UI::WindowsAndMessaging::*;
19
20use crate::util::MsgFilterHook;
21
22const MSG_ID_WAKE: u32 = WM_USER;
23
24thread_local! {
25    static PANIC_PAYLOAD: Cell<Option<Box<dyn Any + Send + 'static>>> = const { Cell::new(None) };
26    static EXECUTOR_WINDOW: Window<()> = Window::new(WindowType::MessageOnly, (), |_, msg| {
27        if msg.msg == MSG_ID_WAKE {
28            let runnable = unsafe {
29                let runnable_ptr = NonNull::new_unchecked(msg.lparam as *mut _);
30                Runnable::<()>::from_raw(runnable_ptr)
31            };
32            if let Err(panic_payload) = panic::catch_unwind(|| runnable.run()) {
33                PANIC_PAYLOAD.set(Some(panic_payload));
34            }
35            Some(0)
36        } else {
37            None
38        }
39    })
40    .unwrap();
41}
42
43/// An owned permission to join on a task (await its termination).
44///
45/// If a `JoinHandle` is dropped, the task continues running in the
46/// background and its return value is lost.
47pub struct JoinHandle<T> {
48    task: ManuallyDrop<async_task::Task<T>>,
49}
50
51// Keep the task running when dropped.
52impl<T> Drop for JoinHandle<T> {
53    fn drop(&mut self) {
54        let task = unsafe { ManuallyDrop::take(&mut self.task) };
55        task.detach();
56    }
57}
58
59impl<T> Future for JoinHandle<T> {
60    type Output = T;
61
62    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
63        pin!(&mut *self.task).poll(cx)
64    }
65}
66
67unsafe fn spawn_unchecked_lifetime<T>(future: impl Future<Output = T>) -> JoinHandle<T> {
68    let hwnd = EXECUTOR_WINDOW.with(|w| w.hwnd());
69
70    // SAFETY: The `future` does not need to be `Send` because the thread that
71    // receives the runnable is our own, meaning the runnable is also dropped
72    // on the original thread.
73    let (runnable, task) = unsafe {
74        async_task::spawn_unchecked(future, move |runnable: Runnable| {
75            PostMessageA(hwnd, MSG_ID_WAKE, 0, runnable.into_raw().as_ptr() as _);
76        })
77    };
78
79    // Trigger initial poll.
80    runnable.schedule();
81
82    JoinHandle {
83        task: ManuallyDrop::new(task),
84    }
85}
86
87/// Spawns a new future on the current thread.
88///
89/// This function may be used to spawn tasks when the message loop is not running.
90/// The provided future starts running once the message loop is entered with
91/// [`block_on()`] or [`MessageLoop::run()`].
92///
93/// # Examples
94///
95/// This example is a compile-time test to ensure that we only accept `'static`
96/// return types to prevent <https://github.com/rust-lang/rust/issues/84366>.
97///
98/// ```compile_fail
99/// fn test_fn<'a>() {
100///     let closure = || -> &'a str { "" };
101///     wintf_winmsg_executor::spawn_local(async move { closure() });
102/// }
103/// ```
104pub fn spawn_local<T: 'static>(future: impl Future<Output = T> + 'static) -> JoinHandle<T> {
105    // SAFETY: future is `'static`
106    unsafe { spawn_unchecked_lifetime(future) }
107}
108
109/// Runs a future to completion on the calling thread's message loop.
110///
111/// Runs the provided future on the current thread, blocking until it completes.
112/// Any tasks spawned from the same thread using [`spawn_local()`] also run concurrently.
113/// Note that all spawned tasks are suspended after [`block_on()`] returns.
114/// Calling [`block_on()`] again resumes the spawned tasks.
115///
116/// # Panics
117///
118/// Panics if the message loop quits before the future completes.
119/// This can happen when the future or any spawned task calls the
120/// `PostQuitMessage()` WinAPI function.
121pub fn block_on<'a, T: 'a>(future: impl Future<Output = T> + 'a) -> T {
122    let msg_loop = &MessageLoop::new();
123
124    // Wrap the future so it quits the message loop when finished.
125    // SAFETY: All borrowed variables outlive the task itself because we only
126    // return from this function after the task has finished.
127    let task = unsafe {
128        spawn_unchecked_lifetime(async move {
129            let result = future.await;
130            msg_loop.quit();
131            result
132        })
133    };
134
135    msg_loop.run_loop(|_| FilterResult::Forward);
136
137    poll_ready(task).expect("received unexpected quit message")
138}
139
/// Polls `future` exactly once with a no-op waker.
///
/// Returns `Ok` with the output if the future is ready and `Err(())` if it
/// is still pending. The future is dropped in both cases.
fn poll_ready<T>(future: impl Future<Output = T>) -> Result<T, ()> {
    let mut cx = Context::from_waker(Waker::noop());
    let pinned = pin!(future);
    if let Poll::Ready(output) = pinned.poll(&mut cx) {
        Ok(output)
    } else {
        Err(())
    }
}
147
/// Decision returned by the filter closure passed to [`MessageLoop::run`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum FilterResult {
    /// Hand the message on to the window procedure.
    Forward,

    /// Discard the message; the window procedure never sees it.
    Drop,
}
157
158/// Abstract representation of a message loop.
159///
160/// Not directly constructible. Use [`MessageLoop::run`] to create a message
161/// loop. The message loop struct is used to control message loop behavior
162/// by passing it as an argument to the filter closure of [`MessageLoop::run`].
163pub struct MessageLoop {
164    quit: Cell<bool>,
165}
166
167impl MessageLoop {
168    fn new() -> Self {
169        Self {
170            quit: Cell::new(false),
171        }
172    }
173
174    fn run_loop(&self, filter: impl Fn(&MSG) -> FilterResult) {
175        let executor_hwnd = EXECUTOR_WINDOW.with(|ew| ew.hwnd());
176
177        while !self.quit.get() {
178            unsafe {
179                let mut msg = MaybeUninit::uninit();
180                if GetMessageA(msg.as_mut_ptr(), ptr::null_mut(), 0, 0) == 0 {
181                    return;
182                }
183                let msg = msg.assume_init();
184
185                // Do not allow the filter to drop our wake messages.
186                let is_wake_message = msg.hwnd == executor_hwnd && msg.message == MSG_ID_WAKE;
187                if is_wake_message || filter(&msg) == FilterResult::Forward {
188                    TranslateMessage(&msg);
189                    DispatchMessageA(&msg);
190                }
191
192                if let Some(panic_payload) = PANIC_PAYLOAD.take() {
193                    panic::resume_unwind(panic_payload)
194                }
195            }
196        }
197    }
198
199    /// Runs the message loop with a filter closure to inspect and drop messages before
200    /// they are dispatched to their respective window procedures.
201    ///
202    /// Use the [`FilterResult`] return value to control how the message is handled.
203    /// The first argument to the filter closure is the [`MessageLoop`] struct itself,
204    /// which can be used to quit the message loop.
205    ///
206    /// Like [`block_on()`], this function runs any tasks spawned from the same thread
207    /// using [`spawn_local()`]. All tasks are suspended when this function returns.
208    ///
209    /// Installs a [`WH_MSGFILTER`] hook to allow inspection of messages while modal
210    /// windows are open.
211    ///
212    /// # Panics and Reentrancy
213    ///
214    /// Panics if called from within another [`MessageLoop::run()`] filter closure.
215    ///
216    /// A call to [`block_on()`] from within the filter closure creates a nested message
217    /// loop, which causes the filter closure to be re-entered when a modal window is open.
218    ///
219    /// [`WH_MSGFILTER`]: (https://learn.microsoft.com/en-us/windows/win32/winmsg/about-hooks#wh_msgfilter-and-wh_sysmsgfilter)
220    pub fn run(filter: impl Fn(&MessageLoop, &MSG) -> FilterResult) {
221        let msg_loop = MessageLoop::new();
222
223        // Any modal window (i.e. a right-click menu) blocks the main message loop
224        // and dispatches messages internally. To keep the executor running use a
225        // hook to get access to modal windows' internal message loop.
226        // SAFETY: The Drop implementation of MsgFilterHook unregisters the hook,
227        // ensuring that dispatchers will not be called after the end of the scope.
228        let _hook = unsafe {
229            MsgFilterHook::register(|msg| {
230                panic::catch_unwind(AssertUnwindSafe(|| {
231                    let filter_result = filter(&msg_loop, msg);
232                    // When `quit()` is called, it has no real effect because we
233                    // are running in a modal loop. Post a quit message to exit
234                    // the modal message loop to store the panic payload.
235                    if msg_loop.quit.get() {
236                        PostMessageA(msg.hwnd, WM_QUIT, 0, 0);
237                    }
238                    filter_result == FilterResult::Drop
239                }))
240                .unwrap_or_else(|payload| {
241                    PANIC_PAYLOAD.with(|panic_payload| {
242                        panic_payload.set(Some(payload));
243                    });
244                    // Also exit the modal loop ASAP when a panic occurs.
245                    PostMessageA(msg.hwnd, WM_QUIT, 0, 0);
246                    false
247                })
248            })
249        };
250        msg_loop.run_loop(|msg| filter(&msg_loop, msg));
251    }
252
253    /// Quits the message loop as soon as possible.
254    pub fn quit(&self) {
255        self.quit.set(true);
256    }
257
258    /// Quits the message loop when there are no more messages to process.
259    pub fn quit_when_idle(&self) {
260        unsafe { PostQuitMessage(0) };
261    }
262}
263
#[cfg(test)]
mod test {
    use std::{ffi::CStr, future::poll_fn};

    use windows_sys::Win32::Foundation::HWND;

    use super::*;

    /// Posts `msg` to the message queue of the calling thread.
    fn post_thread_message(msg: u32) {
        unsafe { PostMessageA(ptr::null_mut(), msg, 0, 0) };
    }

    #[test]
    #[should_panic]
    fn panic_in_dispatcher() {
        post_thread_message(WM_USER);
        MessageLoop::run(|_, _| panic!());
    }

    #[test]
    fn message_loop_quit() {
        for i in 0..10 {
            post_thread_message(WM_USER + i);
        }
        MessageLoop::run(|msg_loop, msg| {
            // This is the only message we observe because we quit the
            // loop right after it is received.
            assert_eq!(msg.message, WM_USER);
            msg_loop.quit();
            FilterResult::Drop
        });
    }

    #[test]
    fn message_loop_quit_when_idle() {
        for i in 0..10 {
            post_thread_message(WM_USER + i);
        }
        let expected_msg = Cell::new(0);
        MessageLoop::run(|msg_loop, msg| {
            assert_eq!(msg.message, WM_USER + expected_msg.get());
            expected_msg.set(expected_msg.get() + 1);
            msg_loop.quit_when_idle();
            FilterResult::Drop
        });
        assert_eq!(expected_msg.get(), 10);
    }

    #[test]
    fn nested_block_on() {
        let count: Cell<usize> = Cell::new(0);

        block_on(async {
            assert_eq!(count.get(), 0);
            count.set(count.get() + 1);

            block_on(async {
                assert_eq!(count.get(), 1);
                count.set(count.get() + 1);
            });

            assert_eq!(count.get(), 2);
            count.set(count.get() + 1);
        });

        assert_eq!(count.get(), 3);
    }

    #[test]
    #[should_panic]
    fn nested_message_loop() {
        post_thread_message(WM_USER);
        MessageLoop::run(|_, _| {
            MessageLoop::run(|_, _| FilterResult::Drop);
            FilterResult::Drop
        });
    }

    /// Returns `Pending` once after waking itself, then completes.
    async fn yield_now() {
        let mut yielded = false;
        poll_fn(|cx| {
            if yielded {
                Poll::Ready(())
            } else {
                yielded = true;
                cx.waker().wake_by_ref();
                Poll::Pending
            }
        })
        .await;
    }

    #[test]
    fn nested_message_loop_block_on() {
        let inner_executed = Cell::new(false);

        post_thread_message(WM_USER);
        MessageLoop::run(|msg_loop, _| {
            block_on(async {
                inner_executed.set(true);
            });
            msg_loop.quit();
            FilterResult::Forward
        });

        assert!(inner_executed.get());
    }

    #[test]
    fn nested_message_loop_block_on_quit() {
        post_thread_message(WM_USER);
        MessageLoop::run(|msg_loop, _| {
            block_on(async {
                msg_loop.quit();
            });
            FilterResult::Forward
        });
    }

    /// Looks up a top-level window by its title; null if there is none.
    fn window_by_name(name: &CStr) -> HWND {
        unsafe { FindWindowA(ptr::null_mut(), name.as_ptr() as _) }
    }

    #[test]
    fn running_spawned_with_modal_dialog() {
        // The window name must be unique for each test because cargo runs tests
        // in parallel and we do not want to close the window of another test.
        let window_name = c"running_spawned_with_modal_dialog";

        let task = spawn_local(async {
            // Wait for modal window to be open.
            while window_by_name(window_name).is_null() {
                yield_now().await;
            }

            // Do some async work with modal dialog open.
            for _ in 0..10 {
                yield_now().await;
            }

            // Close the modal window.
            unsafe {
                SendMessageA(window_by_name(window_name), WM_CLOSE, 0, 0);
            }
        });

        block_on(async {
            unsafe {
                MessageBoxA(
                    ptr::null_mut(),
                    ptr::null_mut(),
                    window_name.as_ptr() as _,
                    0,
                );
            }
            task.await;
        });
    }

    // This test does not actually expect the library to panic.
    // The panic is rather a convenient way to signal if the filter closure is
    // reentered (which is the expected behaviour).
    #[test]
    #[should_panic]
    fn reenter_filter_closure_panic() {
        // The window name must be unique for each test because cargo runs tests
        // in parallel and we do not want to close the window of another test.
        let window_name = c"reenter_filter_closure_panic";

        post_thread_message(WM_USER);

        let running_filter_closure = Cell::new(false);
        MessageLoop::run(|_, msg| {
            assert!(
                !running_filter_closure.replace(true),
                "Filter closure reentered"
            );

            if msg.hwnd.is_null() && msg.message == WM_USER {
                unsafe {
                    MessageBoxA(
                        ptr::null_mut(),
                        ptr::null_mut(),
                        window_name.as_ptr() as _,
                        0,
                    );
                }
            }

            running_filter_closure.set(false);
            FilterResult::Forward
        });
    }

    #[test]
    fn reenter_filter_closure_quit() {
        // The window name must be unique for each test because cargo runs tests
        // in parallel and we do not want to close the window of another test.
        let window_name = c"reenter_filter_closure_quit";

        post_thread_message(WM_USER);

        let running_filter_closure = Cell::new(false);
        MessageLoop::run(|msg_loop, msg| {
            if running_filter_closure.replace(true) {
                msg_loop.quit();
            }

            if msg.hwnd.is_null() && msg.message == WM_USER {
                unsafe {
                    MessageBoxA(
                        ptr::null_mut(),
                        ptr::null_mut(),
                        window_name.as_ptr() as _,
                        0,
                    );
                }
            }

            running_filter_closure.set(false);
            FilterResult::Forward
        });
    }

    #[test]
    fn message_loop_with_modal_dialog() {
        // The window name must be unique for each test because cargo runs tests
        // in parallel and we do not want to close the window of another test.
        let window_name = c"message_loop_with_modal_dialog";

        spawn_local(async {
            unsafe {
                MessageBoxA(
                    ptr::null_mut(),
                    ptr::null_mut(),
                    window_name.as_ptr() as _,
                    0,
                );
            }
        });

        spawn_local(async {
            // Check if modal window is actually open.
            assert!(!window_by_name(window_name).is_null());

            for i in 0..10 {
                post_thread_message(WM_USER + i);
                yield_now().await;
            }

            // Close modal window again.
            unsafe { SendMessageA(window_by_name(window_name), WM_CLOSE, 0, 0) };
        });

        let expected_msg = Cell::new(0);
        MessageLoop::run(|msg_loop, msg| {
            if msg.hwnd.is_null() && msg.message >= WM_USER {
                assert_eq!(msg.message, WM_USER + expected_msg.get());
                expected_msg.set(expected_msg.get() + 1);
                msg_loop.quit_when_idle();
                FilterResult::Drop
            } else {
                FilterResult::Forward
            }
        });
        assert_eq!(expected_msg.get(), 10);
    }

    #[test]
    fn reenter_filter_closure_quit_when_idle() {
        // The window name must be unique for each test because cargo runs tests
        // in parallel and we do not want to close the window of another test.
        let window_name = c"reenter_filter_closure_quit_when_idle";

        post_thread_message(WM_USER);

        let running_filter_closure = Cell::new(false);
        MessageLoop::run(|msg_loop, msg| {
            if running_filter_closure.replace(true) {
                msg_loop.quit_when_idle();
            }

            if msg.hwnd.is_null() && msg.message == WM_USER {
                unsafe {
                    MessageBoxA(
                        ptr::null_mut(),
                        ptr::null_mut(),
                        window_name.as_ptr() as _,
                        0,
                    );
                }
            }

            running_filter_closure.set(false);
            FilterResult::Forward
        });
    }

    #[test]
    fn disallow_wake_message_filtering() {
        let msg_loop = MessageLoop::new();
        // Leaked to obtain a `'static` reference for the spawned task below.
        let msg_loop = Box::leak(Box::new(msg_loop));

        // The `MSG_ID_WAKE` message for the custom window should be filtered
        // by the run loop filter below.
        let custom_wnd = Window::new(WindowType::MessageOnly, (), |_, msg| {
            assert_ne!(msg.msg, MSG_ID_WAKE);
            None
        })
        .unwrap();
        unsafe {
            PostMessageA(custom_wnd.hwnd(), MSG_ID_WAKE, 0, 0);
        }

        // Spawn a task to ensure that the executor window also has a wake message,
        // which must not be filtered.
        spawn_local(async {
            yield_now().await;
            yield_now().await;
            yield_now().await;
            msg_loop.quit();
        });

        msg_loop.run_loop(|msg| {
            // This test is to ensure that this callback is not even called for internal wake messages.
            if msg.message == MSG_ID_WAKE {
                assert_ne!(msg.hwnd, EXECUTOR_WINDOW.with(|ew| ew.hwnd()));
                FilterResult::Drop
            } else {
                FilterResult::Forward
            }
        });
    }
}
596}