1#![doc = include_str!("../README.md")]
2
3pub mod util;
4
5use std::{
6 any::Any,
7 cell::Cell,
8 future::Future,
9 mem::{ManuallyDrop, MaybeUninit},
10 panic::{self, AssertUnwindSafe},
11 pin::{pin, Pin},
12 ptr::NonNull,
13 task::{Context, Poll, Waker},
14};
15
16use async_task::Runnable;
17use util::{Window, WindowType};
18use windows::Win32::{
19 Foundation::{LPARAM, LRESULT, WPARAM},
20 UI::WindowsAndMessaging::*,
21};
22
23use crate::util::MsgFilterHook;
24
/// Message id posted to the executor window to request that a scheduled task
/// be run. The message's `LPARAM` carries the raw `Runnable` pointer.
const MSG_ID_WAKE: u32 = WM_USER;
26
27thread_local! {
28 static PANIC_PAYLOAD: Cell<Option<Box<dyn Any + Send + 'static>>> = const { Cell::new(None) };
29 static EXECUTOR_WINDOW: Window<()> = Window::new(WindowType::MessageOnly, (), |_, msg| {
30 if msg.msg == MSG_ID_WAKE {
31 let runnable = unsafe {
32 let runnable_ptr = NonNull::new_unchecked(msg.lparam.0 as *mut _);
33 Runnable::<()>::from_raw(runnable_ptr)
34 };
35 if let Err(panic_payload) = panic::catch_unwind(|| runnable.run()) {
36 PANIC_PAYLOAD.set(Some(panic_payload));
37 }
38 Some(LRESULT(0))
39 } else {
40 None
41 }
42 })
43 .unwrap();
44}
45
/// Handle to a task spawned on the calling thread's executor.
///
/// Awaiting the handle yields the task's output. Dropping the handle detaches
/// the task (see the `Drop` implementation).
pub struct JoinHandle<T> {
    // Wrapped in `ManuallyDrop` so that `Drop` can move the task out by value
    // and call `detach` on it.
    task: ManuallyDrop<async_task::Task<T>>,
}
53
54impl<T> Drop for JoinHandle<T> {
56 fn drop(&mut self) {
57 let task = unsafe { ManuallyDrop::take(&mut self.task) };
58 task.detach();
59 }
60}
61
62impl<T> Future for JoinHandle<T> {
63 type Output = T;
64
65 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
66 pin!(&mut *self.task).poll(cx)
67 }
68}
69
70unsafe fn spawn_unchecked_lifetime<T>(future: impl Future<Output = T>) -> JoinHandle<T> {
71 let hwnd = EXECUTOR_WINDOW.with(|w| w.hwnd());
72
73 let (runnable, task) = unsafe {
77 async_task::spawn_unchecked(future, move |runnable: Runnable| {
78 let _ = PostMessageW(
79 Some(hwnd),
80 MSG_ID_WAKE,
81 WPARAM(0),
82 LPARAM(runnable.into_raw().as_ptr() as _),
83 );
84 })
85 };
86
87 runnable.schedule();
89
90 JoinHandle {
91 task: ManuallyDrop::new(task),
92 }
93}
94
/// Spawns a `'static` future on the calling thread's executor and returns a
/// [`JoinHandle`] for its output.
///
/// The task is driven by wake messages posted to the thread's executor window,
/// so it only makes progress while those messages are being dispatched, for
/// example from within [`block_on`] or [`MessageLoop::run`].
pub fn spawn_local<T: 'static>(future: impl Future<Output = T> + 'static) -> JoinHandle<T> {
    // SAFETY: both the future and its output are `'static`, so the task cannot
    // outlive anything it borrows.
    unsafe { spawn_unchecked_lifetime(future) }
}
116
117pub fn block_on<'a, T: 'a>(future: impl Future<Output = T> + 'a) -> T {
130 let msg_loop = &MessageLoop::new();
131
132 let task = unsafe {
136 spawn_unchecked_lifetime(async move {
137 let result = future.await;
138 msg_loop.quit();
139 result
140 })
141 };
142
143 msg_loop.run_loop(|_| FilterResult::Forward);
144
145 poll_ready(task).expect("received unexpected quit message")
146}
147
/// Polls `future` exactly once with a no-op waker.
///
/// Returns `Ok` with the output if the future is already complete and
/// `Err(())` if it is still pending.
fn poll_ready<T>(future: impl Future<Output = T>) -> Result<T, ()> {
    let mut cx = Context::from_waker(Waker::noop());
    let pinned = pin!(future);
    if let Poll::Ready(output) = pinned.poll(&mut cx) {
        Ok(output)
    } else {
        Err(())
    }
}
155
/// Decision returned by a message filter for each message it inspects.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FilterResult {
    /// Let the message loop translate and dispatch the message.
    Forward,

    /// Discard the message without dispatching it.
    Drop,
}
165
/// State of one message loop running on the current thread.
///
/// A reference is handed to the filter closure of [`MessageLoop::run`] so the
/// filter can request that the loop stops.
pub struct MessageLoop {
    // Set by `quit`; the loop checks it before fetching the next message.
    quit: Cell<bool>,
}
174
175impl MessageLoop {
176 fn new() -> Self {
177 Self {
178 quit: Cell::new(false),
179 }
180 }
181
182 fn run_loop(&self, filter: impl Fn(&MSG) -> FilterResult) {
183 let executor_hwnd = EXECUTOR_WINDOW.with(|ew| ew.hwnd());
184
185 while !self.quit.get() {
186 unsafe {
187 let mut msg = MaybeUninit::uninit();
188 if GetMessageW(msg.as_mut_ptr(), None, 0, 0).0 == 0 {
189 return;
190 }
191 let msg = msg.assume_init();
192
193 let is_wake_message = msg.hwnd == executor_hwnd && msg.message == MSG_ID_WAKE;
195 if is_wake_message || filter(&msg) == FilterResult::Forward {
196 let _ = TranslateMessage(&msg);
197 DispatchMessageW(&msg);
198 }
199
200 if let Some(panic_payload) = PANIC_PAYLOAD.take() {
201 panic::resume_unwind(panic_payload)
202 }
203 }
204 }
205 }
206
207 pub fn run(filter: impl Fn(&MessageLoop, &MSG) -> FilterResult) {
229 let msg_loop = MessageLoop::new();
230
231 let _hook = unsafe {
237 MsgFilterHook::register(|msg| {
238 panic::catch_unwind(AssertUnwindSafe(|| {
239 let filter_result = filter(&msg_loop, msg);
240 if msg_loop.quit.get() {
244 let _ = PostMessageW(Some(msg.hwnd), WM_QUIT, WPARAM(0), LPARAM(0));
245 }
246 filter_result == FilterResult::Drop
247 }))
248 .unwrap_or_else(|payload| {
249 PANIC_PAYLOAD.with(|panic_payload| {
250 panic_payload.set(Some(payload));
251 });
252 let _ = PostMessageW(Some(msg.hwnd), WM_QUIT, WPARAM(0), LPARAM(0));
254 false
255 })
256 })
257 };
258 msg_loop.run_loop(|msg| filter(&msg_loop, msg));
259 }
260
261 pub fn quit(&self) {
263 self.quit.set(true);
264 }
265
266 pub fn quit_when_idle(&self) {
268 unsafe { PostQuitMessage(0) };
269 }
270}
271
#[cfg(test)]
mod test {
    use std::future::poll_fn;

    use windows::core::{w, PCWSTR};
    use windows::Win32::Foundation::HWND;

    use super::*;

    /// Posts `msg` with zeroed parameters and no target window, i.e. as a
    /// thread message. A failed post is ignored.
    fn post_thread_message(msg: u32) {
        let _ = unsafe { PostMessageW(None, msg, WPARAM(0), LPARAM(0)) };
    }

    /// A panic raised by the filter closure must propagate out of
    /// `MessageLoop::run`.
    #[test]
    #[should_panic]
    fn panic_in_dispatcher() {
        post_thread_message(WM_USER);
        MessageLoop::run(|_, _| panic!());
    }

    /// `quit()` must stop the loop right after the current message.
    #[test]
    fn message_loop_quit() {
        for i in 0..10 {
            post_thread_message(WM_USER + i);
        }
        MessageLoop::run(|msg_loop, msg| {
            // Ten messages are queued, but only the first one may ever reach
            // the filter because the loop quits immediately afterwards.
            assert_eq!(msg.message, WM_USER);
            msg_loop.quit();
            FilterResult::Drop
        });
    }

    /// `quit_when_idle()` must let all queued messages through, in order,
    /// before the loop ends.
    #[test]
    fn message_loop_quit_when_idle() {
        for i in 0..10 {
            post_thread_message(WM_USER + i);
        }
        let expected_msg = Cell::new(0);
        MessageLoop::run(|msg_loop, msg| {
            assert_eq!(msg.message, WM_USER + expected_msg.get());
            expected_msg.set(expected_msg.get() + 1);
            // Requested on every message; the loop must still drain the queue.
            msg_loop.quit_when_idle();
            FilterResult::Drop
        });
        assert_eq!(expected_msg.get(), 10);
    }

    /// `block_on` may be called from inside a future that is itself driven by
    /// `block_on`; the inner call completes before the outer future resumes.
    #[test]
    fn nested_block_on() {
        let count: Cell<usize> = Cell::new(0);

        block_on(async {
            assert_eq!(count.get(), 0);
            count.set(count.get() + 1);

            block_on(async {
                assert_eq!(count.get(), 1);
                count.set(count.get() + 1);
            });

            assert_eq!(count.get(), 2);
            count.set(count.get() + 1);
        });

        assert_eq!(count.get(), 3);
    }

    /// Starting a second `MessageLoop::run` from within a filter closure is
    /// expected to panic.
    #[test]
    #[should_panic]
    fn nested_message_loop() {
        post_thread_message(WM_USER);
        MessageLoop::run(|_, _| {
            MessageLoop::run(|_, _| FilterResult::Drop);
            FilterResult::Drop
        });
    }

    /// Returns `Pending` exactly once, waking itself first so the task gets
    /// scheduled again.
    async fn yield_now() {
        let mut yielded = false;
        poll_fn(|cx| {
            if yielded {
                Poll::Ready(())
            } else {
                yielded = true;
                cx.waker().wake_by_ref();
                Poll::Pending
            }
        })
        .await;
    }

    /// `block_on` may be used from within the filter closure of a running
    /// message loop.
    #[test]
    fn nested_message_loop_block_on() {
        let inner_executed = Cell::new(false);

        post_thread_message(WM_USER);
        MessageLoop::run(|msg_loop, _| {
            block_on(async {
                inner_executed.set(true);
            });
            msg_loop.quit();
            FilterResult::Forward
        });

        assert!(inner_executed.get());
    }

    /// The outer loop can be asked to quit from a future that is driven by a
    /// nested `block_on`.
    #[test]
    fn nested_message_loop_block_on_quit() {
        post_thread_message(WM_USER);
        MessageLoop::run(|msg_loop, _| {
            block_on(async {
                msg_loop.quit();
            });
            FilterResult::Forward
        });
    }

    /// Looks up a window by its title; yields the default (null) handle when
    /// the lookup fails.
    fn window_by_name(name: PCWSTR) -> HWND {
        unsafe { FindWindowW(None, name) }.unwrap_or_default()
    }

    /// A previously spawned task must keep running while `block_on` is stuck
    /// inside a modal message box, otherwise nothing would close the dialog.
    #[test]
    fn running_spawned_with_modal_dialog() {
        let window_name = w!("running_spawned_with_modal_dialog");

        let task = spawn_local(async move {
            // Wait until the message box window exists.
            while window_by_name(window_name).0.is_null() {
                yield_now().await;
            }

            // Keep yielding for a while with the dialog open.
            for _ in 0..10 {
                yield_now().await;
            }

            unsafe {
                // Close the dialog so `MessageBoxW` returns.
                SendMessageW(window_by_name(window_name), WM_CLOSE, Some(WPARAM(0)), Some(LPARAM(0)));
            }
        });

        block_on(async {
            unsafe {
                MessageBoxW(
                    None,
                    PCWSTR::null(),
                    window_name,
                    MESSAGEBOX_STYLE(0),
                );
            }
            task.await;
        });
    }

    /// Opening a modal dialog from within the filter closure re-enters the
    /// closure; the resulting assertion failure must surface as a panic of
    /// `MessageLoop::run`.
    #[test]
    #[should_panic]
    fn reenter_filter_closure_panic() {
        let window_name = w!("reenter_filter_closure");

        post_thread_message(WM_USER);

        let running_filter_closure = Cell::new(false);
        MessageLoop::run(|_, msg| {
            assert!(
                !running_filter_closure.replace(true),
                "Filter closure reentered"
            );

            // Only the initial thread message opens the dialog.
            if msg.hwnd.0.is_null() && msg.message == WM_USER {
                unsafe {
                    MessageBoxW(
                        None,
                        PCWSTR::null(),
                        window_name,
                        MESSAGEBOX_STYLE(0),
                    );
                }
            }

            running_filter_closure.set(false);
            FilterResult::Forward
        });
    }

    /// Calling `quit()` from a re-entered filter closure must end both the
    /// modal dialog and the message loop.
    #[test]
    fn reenter_filter_closure_quit() {
        let window_name = w!("reenter_filter_closure");

        post_thread_message(WM_USER);

        let running_filter_closure = Cell::new(false);
        MessageLoop::run(|msg_loop, msg| {
            // `replace` returns the previous value: `true` means re-entered.
            if running_filter_closure.replace(true) {
                msg_loop.quit();
            }

            if msg.hwnd.0.is_null() && msg.message == WM_USER {
                unsafe {
                    MessageBoxW(
                        None,
                        PCWSTR::null(),
                        window_name,
                        MESSAGEBOX_STYLE(0),
                    );
                }
            }

            running_filter_closure.set(false);
            FilterResult::Forward
        });
    }

    /// Thread messages posted while a modal dialog is open must still reach
    /// the filter closure, in order.
    #[test]
    fn message_loop_with_modal_dialog() {
        let window_name = w!("message_loop_with_modal_dialog");

        // First task: opens the modal dialog and blocks inside it.
        spawn_local(async move {
            unsafe {
                MessageBoxW(
                    None,
                    PCWSTR::null(),
                    window_name,
                    MESSAGEBOX_STYLE(0),
                );
            }
        });

        // Second task: runs while the dialog is open.
        spawn_local(async move {
            assert!(!window_by_name(window_name).0.is_null());

            for i in 0..10 {
                post_thread_message(WM_USER + i);
                yield_now().await;
            }

            unsafe { SendMessageW(window_by_name(window_name), WM_CLOSE, Some(WPARAM(0)), Some(LPARAM(0))) };
        });

        let expected_msg = Cell::new(0);
        MessageLoop::run(|msg_loop, msg| {
            // Only look at the thread messages posted by the second task.
            if msg.hwnd.0.is_null() && msg.message >= WM_USER {
                assert_eq!(msg.message, WM_USER + expected_msg.get());
                expected_msg.set(expected_msg.get() + 1);
                msg_loop.quit_when_idle();
                FilterResult::Drop
            } else {
                FilterResult::Forward
            }
        });
        assert_eq!(expected_msg.get(), 10);
    }

    /// Calling `quit_when_idle()` from a re-entered filter closure must end
    /// both the modal dialog and the message loop.
    #[test]
    fn reenter_filter_closure_quit_when_idle() {
        let window_name = w!("reenter_filter_closure");

        post_thread_message(WM_USER);

        let running_filter_closure = Cell::new(false);
        MessageLoop::run(|msg_loop, msg| {
            // `replace` returns the previous value: `true` means re-entered.
            if running_filter_closure.replace(true) {
                msg_loop.quit_when_idle();
            }

            if msg.hwnd.0.is_null() && msg.message == WM_USER {
                unsafe {
                    MessageBoxW(
                        None,
                        PCWSTR::null(),
                        window_name,
                        MESSAGEBOX_STYLE(0),
                    );
                }
            }

            running_filter_closure.set(false);
            FilterResult::Forward
        });
    }

    /// Wake messages for the executor window must bypass the filter, while a
    /// message with the same id sent to another window can still be dropped.
    #[test]
    fn disallow_wake_message_filtering() {
        let msg_loop = MessageLoop::new();
        // Leaked to obtain a `'static` reference usable inside `spawn_local`.
        let msg_loop = Box::leak(Box::new(msg_loop));

        let custom_wnd = Window::new(WindowType::MessageOnly, (), |_, msg| {
            // The filter below drops this message before it gets here.
            assert_ne!(msg.msg, MSG_ID_WAKE);
            None
        })
        .unwrap();
        unsafe {
            let _ = PostMessageW(Some(custom_wnd.hwnd()), MSG_ID_WAKE, WPARAM(0), LPARAM(0));
        }

        spawn_local(async {
            // Each yield produces a wake message for the executor window.
            yield_now().await;
            yield_now().await;
            yield_now().await;
            msg_loop.quit();
        });

        msg_loop.run_loop(|msg| {
            if msg.message == MSG_ID_WAKE {
                // Executor wake messages must never be offered to the filter.
                assert_ne!(msg.hwnd, EXECUTOR_WINDOW.with(|ew| ew.hwnd()));
                FilterResult::Drop
            } else {
                FilterResult::Forward
            }
        });
    }
}