1#![doc = include_str!("../README.md")]
2
3pub mod util;
4
5use std::{
6 any::Any,
7 cell::Cell,
8 future::Future,
9 mem::{ManuallyDrop, MaybeUninit},
10 panic::{self, AssertUnwindSafe},
11 pin::{pin, Pin},
12 ptr::{self, NonNull},
13 task::{Context, Poll, Waker},
14};
15
16use async_task::Runnable;
17use util::{Window, WindowType};
18use windows_sys::Win32::UI::WindowsAndMessaging::*;
19
20use crate::util::MsgFilterHook;
21
22const MSG_ID_WAKE: u32 = WM_USER;
23
24thread_local! {
25 static PANIC_PAYLOAD: Cell<Option<Box<dyn Any + Send + 'static>>> = const { Cell::new(None) };
26 static EXECUTOR_WINDOW: Window<()> = Window::new(WindowType::MessageOnly, (), |_, msg| {
27 if msg.msg == MSG_ID_WAKE {
28 let runnable = unsafe {
29 let runnable_ptr = NonNull::new_unchecked(msg.lparam as *mut _);
30 Runnable::<()>::from_raw(runnable_ptr)
31 };
32 if let Err(panic_payload) = panic::catch_unwind(|| runnable.run()) {
33 PANIC_PAYLOAD.set(Some(panic_payload));
34 }
35 Some(0)
36 } else {
37 None
38 }
39 })
40 .unwrap();
41}
42
43pub struct JoinHandle<T> {
48 task: ManuallyDrop<async_task::Task<T>>,
49}
50
51impl<T> Drop for JoinHandle<T> {
53 fn drop(&mut self) {
54 let task = unsafe { ManuallyDrop::take(&mut self.task) };
55 task.detach();
56 }
57}
58
59impl<T> Future for JoinHandle<T> {
60 type Output = T;
61
62 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
63 pin!(&mut *self.task).poll(cx)
64 }
65}
66
67unsafe fn spawn_unchecked_lifetime<T>(future: impl Future<Output = T>) -> JoinHandle<T> {
68 let hwnd = EXECUTOR_WINDOW.with(|w| w.hwnd());
69
70 let (runnable, task) = unsafe {
74 async_task::spawn_unchecked(future, move |runnable: Runnable| {
75 PostMessageA(hwnd, MSG_ID_WAKE, 0, runnable.into_raw().as_ptr() as _);
76 })
77 };
78
79 runnable.schedule();
81
82 JoinHandle {
83 task: ManuallyDrop::new(task),
84 }
85}
86
87pub fn spawn_local<T>(future: impl Future<Output = T> + 'static) -> JoinHandle<T> {
93 unsafe { spawn_unchecked_lifetime(future) }
95}
96
97pub fn block_on<'a, T: 'a>(future: impl Future<Output = T> + 'a) -> T {
110 let msg_loop = &MessageLoop::new();
111
112 let task = unsafe {
116 spawn_unchecked_lifetime(async move {
117 let result = future.await;
118 msg_loop.quit();
119 result
120 })
121 };
122
123 msg_loop.run_loop(|_| FilterResult::Forward);
124
125 poll_ready(task).expect("received unexpected quit message")
126}
127
/// Polls `future` exactly once with a no-op waker.
///
/// Returns `Ok(output)` if the future was already complete, or `Err(())` if it is still
/// pending.
fn poll_ready<T>(future: impl Future<Output = T>) -> Result<T, ()> {
    let mut cx = Context::from_waker(Waker::noop());
    let pinned = pin!(future);
    if let Poll::Ready(output) = pinned.poll(&mut cx) {
        Ok(output)
    } else {
        Err(())
    }
}
135
/// What a message filter decides to do with a message it has inspected.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FilterResult {
    /// Let the message continue to normal processing (translate and dispatch).
    Forward,

    /// Swallow the message instead of processing it normally.
    Drop,
}
145
/// State of a running message loop, passed to filters so that they can stop it.
pub struct MessageLoop {
    /// Set by `MessageLoop::quit()`; the loop checks it before fetching each message.
    quit: Cell<bool>,
}
154
155impl MessageLoop {
156 fn new() -> Self {
157 Self {
158 quit: Cell::new(false),
159 }
160 }
161
162 fn run_loop(&self, filter: impl Fn(&MSG) -> FilterResult) {
163 let executor_hwnd = EXECUTOR_WINDOW.with(|ew| ew.hwnd());
164
165 while !self.quit.get() {
166 unsafe {
167 let mut msg = MaybeUninit::uninit();
168 if GetMessageA(msg.as_mut_ptr(), ptr::null_mut(), 0, 0) == 0 {
169 return;
170 }
171 let msg = msg.assume_init();
172
173 let is_wake_message = msg.hwnd == executor_hwnd && msg.message == MSG_ID_WAKE;
175 if is_wake_message || filter(&msg) == FilterResult::Forward {
176 TranslateMessage(&msg);
177 DispatchMessageA(&msg);
178 }
179
180 if let Some(panic_payload) = PANIC_PAYLOAD.take() {
181 panic::resume_unwind(panic_payload)
182 }
183 }
184 }
185 }
186
187 pub fn run(filter: impl Fn(&MessageLoop, &MSG) -> FilterResult) {
214 let msg_loop = MessageLoop::new();
215
216 let _hook = unsafe {
222 MsgFilterHook::register(|msg| {
223 panic::catch_unwind(AssertUnwindSafe(|| {
224 let filter_result = filter(&msg_loop, msg);
225 if msg_loop.quit.get() {
229 PostMessageA(msg.hwnd, WM_QUIT, 0, 0);
230 }
231 filter_result == FilterResult::Drop
232 }))
233 .unwrap_or_else(|payload| {
234 PANIC_PAYLOAD.with(|panic_payload| {
235 panic_payload.set(Some(payload));
236 });
237 PostMessageA(msg.hwnd, WM_QUIT, 0, 0);
239 false
240 })
241 })
242 };
243 msg_loop.run_loop(|msg| filter(&msg_loop, msg));
244 }
245
246 pub fn quit(&self) {
248 self.quit.set(true);
249 }
250
251 pub fn quit_when_idle(&self) {
253 unsafe { PostQuitMessage(0) };
254 }
255}
256
// Tests for the message loop and executor. Each test posts messages to, and runs loops on,
// its own test thread's queue.
#[cfg(test)]
mod test {
    use std::{ffi::CStr, future::poll_fn};

    use windows_sys::Win32::Foundation::HWND;

    use super::*;

    /// Posts `msg` with a null HWND, i.e. as a thread message to the current thread's queue.
    fn post_thread_message(msg: u32) {
        unsafe { PostMessageA(ptr::null_mut(), msg, 0, 0) };
    }

    // A panic raised by the filter must propagate out of `MessageLoop::run`.
    #[test]
    #[should_panic]
    fn panic_in_dispatcher() {
        post_thread_message(WM_USER);
        MessageLoop::run(|_, _| panic!());
    }

    // `quit` stops the loop right after the current message, so the filter only ever sees
    // the first of the ten queued messages.
    #[test]
    fn message_loop_quit() {
        for i in 0..10 {
            post_thread_message(WM_USER + i);
        }
        MessageLoop::run(|msg_loop, msg| {
            assert_eq!(msg.message, WM_USER);
            // Stop before the remaining nine messages are retrieved.
            msg_loop.quit();
            FilterResult::Drop
        });
    }

    // `quit_when_idle` lets every already-queued message through, in order, before the loop
    // exits.
    #[test]
    fn message_loop_quit_when_idle() {
        for i in 0..10 {
            post_thread_message(WM_USER + i);
        }
        let expected_msg = Cell::new(0);
        MessageLoop::run(|msg_loop, msg| {
            assert_eq!(msg.message, WM_USER + expected_msg.get());
            expected_msg.set(expected_msg.get() + 1);
            msg_loop.quit_when_idle();
            FilterResult::Drop
        });
        assert_eq!(expected_msg.get(), 10);
    }

    // A `block_on` nested inside another runs to completion before the outer future resumes.
    #[test]
    fn nested_block_on() {
        let count: Cell<usize> = Cell::new(0);

        block_on(async {
            assert_eq!(count.get(), 0);
            count.set(count.get() + 1);

            block_on(async {
                assert_eq!(count.get(), 1);
                count.set(count.get() + 1);
            });

            assert_eq!(count.get(), 2);
            count.set(count.get() + 1);
        });

        assert_eq!(count.get(), 3);
    }

    // Starting a second `MessageLoop::run` while one is active is expected to panic.
    // NOTE(review): presumably the panic comes from `MsgFilterHook::register` refusing a
    // second hook; confirm in `util`.
    #[test]
    #[should_panic]
    fn nested_message_loop() {
        post_thread_message(WM_USER);
        MessageLoop::run(|_, _| {
            MessageLoop::run(|_, _| FilterResult::Drop);
            FilterResult::Drop
        });
    }

    /// Returns `Pending` once after waking itself, forcing the task to be rescheduled through
    /// the executor window before it completes.
    async fn yield_now() {
        let mut yielded = false;
        poll_fn(|cx| {
            if yielded {
                Poll::Ready(())
            } else {
                yielded = true;
                cx.waker().wake_by_ref();
                Poll::Pending
            }
        })
        .await;
    }

    // `block_on` can be called from inside a `MessageLoop::run` filter and still runs its
    // future.
    #[test]
    fn nested_message_loop_block_on() {
        let inner_executed = Cell::new(false);

        post_thread_message(WM_USER);
        MessageLoop::run(|msg_loop, _| {
            block_on(async {
                inner_executed.set(true);
            });
            msg_loop.quit();
            FilterResult::Forward
        });

        assert!(inner_executed.get());
    }

    // Quitting the outer loop from a future driven by a nested `block_on` still ends
    // `MessageLoop::run`; the test passes by terminating.
    #[test]
    fn nested_message_loop_block_on_quit() {
        post_thread_message(WM_USER);
        MessageLoop::run(|msg_loop, _| {
            block_on(async {
                msg_loop.quit();
            });
            FilterResult::Forward
        });
    }

    /// Looks up a top-level window by its title; returns a null handle if there is none.
    fn window_by_name(name: &CStr) -> HWND {
        unsafe { FindWindowA(ptr::null_mut(), name.as_ptr() as _) }
    }

    // A task spawned before a modal message box keeps running while the box blocks the
    // `block_on` future. The task waits for the box, yields a few times, then closes it.
    #[test]
    fn running_spawned_with_modal_dialog() {
        let window_name = c"running_spawned_with_modal_dialog";

        let task = spawn_local(async {
            // Poll until the message box window exists.
            while window_by_name(window_name).is_null() {
                yield_now().await;
            }

            // Keep being scheduled while the modal loop is running.
            for _ in 0..10 {
                yield_now().await;
            }

            unsafe {
                // Closing the box makes `MessageBoxA` below return.
                SendMessageA(window_by_name(window_name), WM_CLOSE, 0, 0);
            }
        });

        block_on(async {
            unsafe {
                MessageBoxA(
                    ptr::null_mut(),
                    ptr::null_mut(),
                    window_name.as_ptr() as _,
                    0,
                );
            }
            task.await;
        });
    }

    // The filter opens a message box. Messages seen by that modal loop call the filter again
    // through the hook while the outer call is still running. The reentrancy assertion
    // panics, and that panic must surface from `MessageLoop::run`.
    #[test]
    #[should_panic]
    fn reenter_filter_closure_panic() {
        let window_name = c"reenter_filter_closure";

        post_thread_message(WM_USER);

        let running_filter_closure = Cell::new(false);
        MessageLoop::run(|_, msg| {
            assert!(
                !running_filter_closure.replace(true),
                "Filter closure reentered"
            );

            if msg.hwnd.is_null() && msg.message == WM_USER {
                unsafe {
                    MessageBoxA(
                        ptr::null_mut(),
                        ptr::null_mut(),
                        window_name.as_ptr() as _,
                        0,
                    );
                }
            }

            running_filter_closure.set(false);
            FilterResult::Forward
        });
    }

    // Calling `quit` from the reentrant filter call (inside the message box's modal loop)
    // must end both the modal loop and `MessageLoop::run`; the test passes by terminating.
    #[test]
    fn reenter_filter_closure_quit() {
        let window_name = c"reenter_filter_closure";

        post_thread_message(WM_USER);

        let running_filter_closure = Cell::new(false);
        MessageLoop::run(|msg_loop, msg| {
            // `true` here means we are being called reentrantly from the modal loop.
            if running_filter_closure.replace(true) {
                msg_loop.quit();
            }

            if msg.hwnd.is_null() && msg.message == WM_USER {
                unsafe {
                    MessageBoxA(
                        ptr::null_mut(),
                        ptr::null_mut(),
                        window_name.as_ptr() as _,
                        0,
                    );
                }
            }

            running_filter_closure.set(false);
            FilterResult::Forward
        });
    }

    // The first task opens a message box. The second task then runs from inside that modal
    // loop and posts ten thread messages, which must reach the filter in order through the
    // hook before the loop exits.
    #[test]
    fn message_loop_with_modal_dialog() {
        let window_name = c"message_loop_with_modal_dialog";

        spawn_local(async {
            unsafe {
                MessageBoxA(
                    ptr::null_mut(),
                    ptr::null_mut(),
                    window_name.as_ptr() as _,
                    0,
                );
            }
        });

        spawn_local(async {
            // Expected to run after the first task has opened the box, since its wake
            // message is queued after the first task's.
            assert!(!window_by_name(window_name).is_null());

            for i in 0..10 {
                post_thread_message(WM_USER + i);
                yield_now().await;
            }

            unsafe { SendMessageA(window_by_name(window_name), WM_CLOSE, 0, 0) };
        });

        let expected_msg = Cell::new(0);
        MessageLoop::run(|msg_loop, msg| {
            // Only the thread messages posted above are counted; everything else (e.g. the
            // message box's own window messages) is forwarded untouched.
            if msg.hwnd.is_null() && msg.message >= WM_USER {
                assert_eq!(msg.message, WM_USER + expected_msg.get());
                expected_msg.set(expected_msg.get() + 1);
                msg_loop.quit_when_idle();
                FilterResult::Drop
            } else {
                FilterResult::Forward
            }
        });
        assert_eq!(expected_msg.get(), 10);
    }

    // Like `reenter_filter_closure_quit`, but uses `quit_when_idle` from the reentrant call.
    // NOTE(review): relies on the modal loop exiting on WM_QUIT and re-posting it so that the
    // outer loop also ends; the test passes by terminating.
    #[test]
    fn reenter_filter_closure_quit_when_idle() {
        let window_name = c"reenter_filter_closure";

        post_thread_message(WM_USER);

        let running_filter_closure = Cell::new(false);
        MessageLoop::run(|msg_loop, msg| {
            if running_filter_closure.replace(true) {
                msg_loop.quit_when_idle();
            }

            if msg.hwnd.is_null() && msg.message == WM_USER {
                unsafe {
                    MessageBoxA(
                        ptr::null_mut(),
                        ptr::null_mut(),
                        window_name.as_ptr() as _,
                        0,
                    );
                }
            }

            running_filter_closure.set(false);
            FilterResult::Forward
        });
    }

    // Wake messages addressed to the executor window never reach the filter. A
    // `MSG_ID_WAKE` posted to some other window is an ordinary message: the filter sees it
    // and can drop it.
    #[test]
    fn disallow_wake_message_filtering() {
        let msg_loop = MessageLoop::new();
        // Leaked so the spawned (`'static`) task can borrow it.
        let msg_loop = Box::leak(Box::new(msg_loop));

        let custom_wnd = Window::new(WindowType::MessageOnly, (), |_, msg| {
            // The filter drops this window's `MSG_ID_WAKE`, so it must never be dispatched here.
            assert_ne!(msg.msg, MSG_ID_WAKE);
            None
        })
        .unwrap();
        unsafe {
            PostMessageA(custom_wnd.hwnd(), MSG_ID_WAKE, 0, 0);
        }

        spawn_local(async {
            // Each yield produces another executor wake message for the filter to skip.
            yield_now().await;
            yield_now().await;
            yield_now().await;
            msg_loop.quit();
        });

        msg_loop.run_loop(|msg| {
            if msg.message == MSG_ID_WAKE {
                // Only the custom window's wake message may get here.
                assert_ne!(msg.hwnd, EXECUTOR_WINDOW.with(|ew| ew.hwnd()));
                FilterResult::Drop
            } else {
                FilterResult::Forward
            }
        });
    }
}