1#![doc = include_str!("../README.md")]
2
3pub mod util;
4
use std::{
    any::Any,
    cell::{Cell, RefCell},
    future::{poll_fn, Future},
    mem::{ManuallyDrop, MaybeUninit},
    panic::{self, AssertUnwindSafe},
    pin::{pin, Pin},
    ptr::{self, NonNull},
    rc::Rc,
    task::{Context, Poll, Waker},
};

use async_task::Runnable;
use util::{Window, WindowType};
use windows_sys::Win32::UI::WindowsAndMessaging::*;

use crate::util::MsgFilterHook;
21
22const MSG_ID_WAKE: u32 = WM_USER;
23
24thread_local! {
25 static PANIC_PAYLOAD: Cell<Option<Box<dyn Any + Send + 'static>>> = const { Cell::new(None) };
26 static EXECUTOR_WINDOW: Window<()> = Window::new(WindowType::MessageOnly, (), |_, msg| {
27 if msg.msg == MSG_ID_WAKE {
28 let runnable = unsafe {
29 let runnable_ptr = NonNull::new_unchecked(msg.lparam as *mut _);
30 Runnable::<()>::from_raw(runnable_ptr)
31 };
32 if let Err(panic_payload) = panic::catch_unwind(|| runnable.run()) {
33 PANIC_PAYLOAD.set(Some(panic_payload));
34 }
35 Some(0)
36 } else {
37 None
38 }
39 })
40 .unwrap();
41}
42
43pub struct JoinHandle<T> {
48 task: ManuallyDrop<async_task::Task<T>>,
49}
50
51impl<T> Drop for JoinHandle<T> {
53 fn drop(&mut self) {
54 let task = unsafe { ManuallyDrop::take(&mut self.task) };
55 task.detach();
56 }
57}
58
59impl<T> Future for JoinHandle<T> {
60 type Output = T;
61
62 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
63 pin!(&mut *self.task).poll(cx)
64 }
65}
66
67unsafe fn spawn_unchecked_lifetime<T>(future: impl Future<Output = T>) -> JoinHandle<T> {
68 let hwnd = EXECUTOR_WINDOW.with(|w| w.hwnd());
69
70 let (runnable, task) = unsafe {
74 async_task::spawn_unchecked(future, move |runnable: Runnable| {
75 PostMessageA(hwnd, MSG_ID_WAKE, 0, runnable.into_raw().as_ptr() as _);
76 })
77 };
78
79 runnable.schedule();
81
82 JoinHandle {
83 task: ManuallyDrop::new(task),
84 }
85}
86
87pub fn spawn_local<T: 'static>(future: impl Future<Output = T> + 'static) -> JoinHandle<T> {
105 unsafe { spawn_unchecked_lifetime(future) }
107}
108
109pub fn block_on<'a, T: 'a>(future: impl Future<Output = T> + 'a) -> T {
122 let msg_loop = &MessageLoop::new();
123
124 let task = unsafe {
128 spawn_unchecked_lifetime(async move {
129 let result = future.await;
130 msg_loop.quit();
131 result
132 })
133 };
134
135 msg_loop.run_loop(|_| FilterResult::Forward);
136
137 poll_ready(task).expect("received unexpected quit message")
138}
139
/// Polls `future` exactly once using a no-op waker.
///
/// Returns `Ok(output)` if the future is already complete, or `Err(())` if it
/// is still pending.
fn poll_ready<T>(future: impl Future<Output = T>) -> Result<T, ()> {
    let mut cx = Context::from_waker(Waker::noop());
    match pin!(future).poll(&mut cx) {
        Poll::Pending => Err(()),
        Poll::Ready(output) => Ok(output),
    }
}
147
/// Verdict a [`MessageLoop`] filter returns for each message it inspects.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FilterResult {
    /// Let the message through: the loop translates and dispatches it normally.
    Forward,

    /// Swallow the message: `MessageLoop`'s own loop neither translates nor
    /// dispatches it.
    Drop,
}
157
/// A Win32 message loop on the current thread that passes messages through a
/// caller-supplied filter; see [`MessageLoop::run`].
pub struct MessageLoop {
    /// Set by [`MessageLoop::quit`]. The loop stops once it observes this flag.
    quit: Cell<bool>,
}
166
167impl MessageLoop {
168 fn new() -> Self {
169 Self {
170 quit: Cell::new(false),
171 }
172 }
173
174 fn run_loop(&self, filter: impl Fn(&MSG) -> FilterResult) {
175 let executor_hwnd = EXECUTOR_WINDOW.with(|ew| ew.hwnd());
176
177 while !self.quit.get() {
178 unsafe {
179 let mut msg = MaybeUninit::uninit();
180 if GetMessageA(msg.as_mut_ptr(), ptr::null_mut(), 0, 0) == 0 {
181 return;
182 }
183 let msg = msg.assume_init();
184
185 let is_wake_message = msg.hwnd == executor_hwnd && msg.message == MSG_ID_WAKE;
187 if is_wake_message || filter(&msg) == FilterResult::Forward {
188 TranslateMessage(&msg);
189 DispatchMessageA(&msg);
190 }
191
192 if let Some(panic_payload) = PANIC_PAYLOAD.take() {
193 panic::resume_unwind(panic_payload)
194 }
195 }
196 }
197 }
198
199 pub fn run(filter: impl Fn(&MessageLoop, &MSG) -> FilterResult) {
221 let msg_loop = MessageLoop::new();
222
223 let _hook = unsafe {
229 MsgFilterHook::register(|msg| {
230 panic::catch_unwind(AssertUnwindSafe(|| {
231 let filter_result = filter(&msg_loop, msg);
232 if msg_loop.quit.get() {
236 PostMessageA(msg.hwnd, WM_QUIT, 0, 0);
237 }
238 filter_result == FilterResult::Drop
239 }))
240 .unwrap_or_else(|payload| {
241 PANIC_PAYLOAD.with(|panic_payload| {
242 panic_payload.set(Some(payload));
243 });
244 PostMessageA(msg.hwnd, WM_QUIT, 0, 0);
246 false
247 })
248 })
249 };
250 msg_loop.run_loop(|msg| filter(&msg_loop, msg));
251 }
252
253 pub fn quit(&self) {
255 self.quit.set(true);
256 }
257
258 pub fn quit_when_idle(&self) {
260 unsafe { PostQuitMessage(0) };
261 }
262}
263
#[cfg(test)]
mod test {
    use std::{ffi::CStr, future::poll_fn};

    use windows_sys::Win32::Foundation::HWND;

    use super::*;

    /// Posts `msg` with a null window handle and zero parameters. The filters
    /// below recognize these messages by `msg.hwnd.is_null()`.
    fn post_thread_message(msg: u32) {
        unsafe { PostMessageA(ptr::null_mut(), msg, 0, 0) };
    }

    // A panic raised by the filter closure must propagate out of `MessageLoop::run`.
    #[test]
    #[should_panic]
    fn panic_in_dispatcher() {
        post_thread_message(WM_USER);
        MessageLoop::run(|_, _| panic!());
    }

    // `quit` stops the loop right after the current message. Only the first of
    // the ten queued messages may ever reach the filter.
    #[test]
    fn message_loop_quit() {
        for i in 0..10 {
            post_thread_message(WM_USER + i);
        }
        MessageLoop::run(|msg_loop, msg| {
            assert_eq!(msg.message, WM_USER);
            msg_loop.quit();
            FilterResult::Drop
        });
    }

    // `quit_when_idle` lets the already-queued messages drain first. All ten
    // reach the filter, in posting order.
    #[test]
    fn message_loop_quit_when_idle() {
        for i in 0..10 {
            post_thread_message(WM_USER + i);
        }
        let expected_msg = Cell::new(0);
        MessageLoop::run(|msg_loop, msg| {
            assert_eq!(msg.message, WM_USER + expected_msg.get());
            expected_msg.set(expected_msg.get() + 1);
            msg_loop.quit_when_idle();
            FilterResult::Drop
        });
        assert_eq!(expected_msg.get(), 10);
    }

    // A `block_on` nested inside a future driven by `block_on` runs to
    // completion before the outer future resumes.
    #[test]
    fn nested_block_on() {
        let count: Cell<usize> = Cell::new(0);

        block_on(async {
            assert_eq!(count.get(), 0);
            count.set(count.get() + 1);

            block_on(async {
                assert_eq!(count.get(), 1);
                count.set(count.get() + 1);
            });

            assert_eq!(count.get(), 2);
            count.set(count.get() + 1);
        });

        assert_eq!(count.get(), 3);
    }

    // Calling `MessageLoop::run` from inside another `run`'s filter is expected
    // to panic.
    #[test]
    #[should_panic]
    fn nested_message_loop() {
        post_thread_message(WM_USER);
        MessageLoop::run(|_, _| {
            MessageLoop::run(|_, _| FilterResult::Drop);
            FilterResult::Drop
        });
    }

    /// Returns `Pending` once after waking itself, so the task is scheduled
    /// again and other queued messages get a chance to run first.
    async fn yield_now() {
        let mut yielded = false;
        poll_fn(|cx| {
            if yielded {
                Poll::Ready(())
            } else {
                yielded = true;
                cx.waker().wake_by_ref();
                Poll::Pending
            }
        })
        .await;
    }

    // `block_on` can be used from inside a `MessageLoop::run` filter.
    #[test]
    fn nested_message_loop_block_on() {
        let inner_executed = Cell::new(false);

        post_thread_message(WM_USER);
        MessageLoop::run(|msg_loop, _| {
            block_on(async {
                inner_executed.set(true);
            });
            msg_loop.quit();
            FilterResult::Forward
        });

        assert!(inner_executed.get());
    }

    // Quitting the outer loop from inside a nested `block_on` future ends the
    // outer loop (the test would otherwise block in `GetMessageA`).
    #[test]
    fn nested_message_loop_block_on_quit() {
        post_thread_message(WM_USER);
        MessageLoop::run(|msg_loop, _| {
            block_on(async {
                msg_loop.quit();
            });
            FilterResult::Forward
        });
    }

    /// Looks up a window by its title with `FindWindowA` (any class).
    fn window_by_name(name: &CStr) -> HWND {
        unsafe { FindWindowA(ptr::null_mut(), name.as_ptr() as _) }
    }

    // Spawned tasks keep running while a modal `MessageBoxA` blocks the
    // `block_on` future. The task waits for the box, yields a few times, then
    // closes it, which lets `block_on` finish.
    #[test]
    fn running_spawned_with_modal_dialog() {
        let window_name = c"running_spawned_with_modal_dialog";

        let task = spawn_local(async {
            while window_by_name(window_name).is_null() {
                yield_now().await;
            }

            for _ in 0..10 {
                yield_now().await;
            }

            unsafe {
                SendMessageA(window_by_name(window_name), WM_CLOSE, 0, 0);
            }
        });

        block_on(async {
            unsafe {
                MessageBoxA(
                    ptr::null_mut(),
                    ptr::null_mut(),
                    window_name.as_ptr() as _,
                    0,
                );
            }
            task.await;
        });
    }

    // Opening a modal box inside the filter makes the filter run again for the
    // modal loop's messages (reentrancy). Here that reentry trips the assert,
    // and the panic must propagate out of `run`.
    #[test]
    #[should_panic]
    fn reenter_filter_closure_panic() {
        let window_name = c"reenter_filter_closure";

        post_thread_message(WM_USER);

        let running_filter_closure = Cell::new(false);
        MessageLoop::run(|_, msg| {
            assert!(
                !running_filter_closure.replace(true),
                "Filter closure reentered"
            );

            if msg.hwnd.is_null() && msg.message == WM_USER {
                unsafe {
                    MessageBoxA(
                        ptr::null_mut(),
                        ptr::null_mut(),
                        window_name.as_ptr() as _,
                        0,
                    );
                }
            }

            running_filter_closure.set(false);
            FilterResult::Forward
        });
    }

    // Calling `quit` from a reentrant filter invocation must close the modal
    // box's loop and then end `run`.
    #[test]
    fn reenter_filter_closure_quit() {
        let window_name = c"reenter_filter_closure";

        post_thread_message(WM_USER);

        let running_filter_closure = Cell::new(false);
        MessageLoop::run(|msg_loop, msg| {
            if running_filter_closure.replace(true) {
                msg_loop.quit();
            }

            if msg.hwnd.is_null() && msg.message == WM_USER {
                unsafe {
                    MessageBoxA(
                        ptr::null_mut(),
                        ptr::null_mut(),
                        window_name.as_ptr() as _,
                        0,
                    );
                }
            }

            running_filter_closure.set(false);
            FilterResult::Forward
        });
    }

    // Two tasks interact while one of them is blocked in a modal box:
    // - The first task blocks inside `MessageBoxA`.
    // - The second runs from that modal loop, posts ten thread messages, and
    //   then closes the box.
    // All ten messages must still reach `run`'s filter, in order.
    #[test]
    fn message_loop_with_modal_dialog() {
        let window_name = c"message_loop_with_modal_dialog";

        spawn_local(async {
            unsafe {
                MessageBoxA(
                    ptr::null_mut(),
                    ptr::null_mut(),
                    window_name.as_ptr() as _,
                    0,
                );
            }
        });

        spawn_local(async {
            assert!(!window_by_name(window_name).is_null());

            for i in 0..10 {
                post_thread_message(WM_USER + i);
                yield_now().await;
            }

            unsafe { SendMessageA(window_by_name(window_name), WM_CLOSE, 0, 0) };
        });

        let expected_msg = Cell::new(0);
        MessageLoop::run(|msg_loop, msg| {
            if msg.hwnd.is_null() && msg.message >= WM_USER {
                assert_eq!(msg.message, WM_USER + expected_msg.get());
                expected_msg.set(expected_msg.get() + 1);
                msg_loop.quit_when_idle();
                FilterResult::Drop
            } else {
                FilterResult::Forward
            }
        });
        assert_eq!(expected_msg.get(), 10);
    }

    // Like `reenter_filter_closure_quit`, but the reentrant call uses
    // `quit_when_idle` instead of `quit`.
    #[test]
    fn reenter_filter_closure_quit_when_idle() {
        let window_name = c"reenter_filter_closure";

        post_thread_message(WM_USER);

        let running_filter_closure = Cell::new(false);
        MessageLoop::run(|msg_loop, msg| {
            if running_filter_closure.replace(true) {
                msg_loop.quit_when_idle();
            }

            if msg.hwnd.is_null() && msg.message == WM_USER {
                unsafe {
                    MessageBoxA(
                        ptr::null_mut(),
                        ptr::null_mut(),
                        window_name.as_ptr() as _,
                        0,
                    );
                }
            }

            running_filter_closure.set(false);
            FilterResult::Forward
        });
    }

    // The filter never sees the executor's own wake messages: the filter
    // asserts any `MSG_ID_WAKE` it sees is not for the executor window, yet the
    // spawned task still runs and quits the loop. The same message ID sent to a
    // different window does reach the filter, which drops it, so the custom
    // window's procedure never receives it.
    #[test]
    fn disallow_wake_message_filtering() {
        let msg_loop = MessageLoop::new();
        let msg_loop = Box::leak(Box::new(msg_loop));

        let custom_wnd = Window::new(WindowType::MessageOnly, (), |_, msg| {
            assert_ne!(msg.msg, MSG_ID_WAKE);
            None
        })
        .unwrap();
        unsafe {
            PostMessageA(custom_wnd.hwnd(), MSG_ID_WAKE, 0, 0);
        }

        spawn_local(async {
            yield_now().await;
            yield_now().await;
            yield_now().await;
            msg_loop.quit();
        });

        msg_loop.run_loop(|msg| {
            if msg.message == MSG_ID_WAKE {
                assert_ne!(msg.hwnd, EXECUTOR_WINDOW.with(|ew| ew.hwnd()));
                FilterResult::Drop
            } else {
                FilterResult::Forward
            }
        });
    }
}