1#![doc = include_str!("../README.md")]
2
3pub mod util;
4
5use std::{
6 any::Any,
7 cell::Cell,
8 future::Future,
9 mem::{ManuallyDrop, MaybeUninit},
10 panic::{self, AssertUnwindSafe},
11 pin::{pin, Pin},
12 ptr::{self, NonNull},
13 task::{Context, Poll, RawWaker, RawWakerVTable, Waker},
14};
15
16use async_task::Runnable;
17use util::{Window, WindowType};
18use windows_sys::Win32::UI::WindowsAndMessaging::*;
19
20use crate::util::MsgFilterHook;
21
22const MSG_ID_WAKE: u32 = WM_USER;
23
24thread_local! {
25 static PANIC_PAYLOAD: Cell<Option<Box<dyn Any + Send + 'static>>> = const { Cell::new(None) };
26 static EXECUTOR_WINDOW: Window<()> = Window::new(WindowType::MessageOnly, (), |_, msg| {
27 if msg.msg == MSG_ID_WAKE {
28 let runnable = unsafe {
29 let runnable_ptr = NonNull::new_unchecked(msg.lparam as *mut _);
30 Runnable::<()>::from_raw(runnable_ptr)
31 };
32 if let Err(panic_payload) = panic::catch_unwind(|| runnable.run()) {
33 PANIC_PAYLOAD.set(Some(panic_payload));
34 }
35 Some(0)
36 } else {
37 None
38 }
39 })
40 .unwrap();
41}
42
43pub struct JoinHandle<T> {
48 task: ManuallyDrop<async_task::Task<T>>,
49}
50
51impl<T> Drop for JoinHandle<T> {
53 fn drop(&mut self) {
54 let task = unsafe { ManuallyDrop::take(&mut self.task) };
55 task.detach();
56 }
57}
58
59impl<T> Future for JoinHandle<T> {
60 type Output = T;
61
62 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
63 pin!(&mut *self.task).poll(cx)
64 }
65}
66
67unsafe fn spawn_unchecked_lifetime<T>(future: impl Future<Output = T>) -> JoinHandle<T> {
68 let hwnd = EXECUTOR_WINDOW.with(|w| w.hwnd());
69
70 let (runnable, task) = unsafe {
74 async_task::spawn_unchecked(future, move |runnable: Runnable| {
75 PostMessageA(hwnd, MSG_ID_WAKE, 0, runnable.into_raw().as_ptr() as _);
76 })
77 };
78
79 runnable.schedule();
81
82 JoinHandle {
83 task: ManuallyDrop::new(task),
84 }
85}
86
87pub fn spawn_local<T>(future: impl Future<Output = T> + 'static) -> JoinHandle<T> {
93 unsafe { spawn_unchecked_lifetime(future) }
95}
96
97pub fn block_on<'a, T: 'a>(future: impl Future<Output = T> + 'a) -> T {
110 let msg_loop = &MessageLoop::new();
111
112 let task = unsafe {
116 spawn_unchecked_lifetime(async move {
117 let result = future.await;
118 msg_loop.quit();
119 result
120 })
121 };
122
123 msg_loop.run_loop(|_| FilterResult::Forward);
124
125 poll_ready(task).expect("received unexpected quit message")
126}
127
/// Polls `future` exactly once, using a waker that does nothing.
///
/// Returns `Ok(output)` if the future is already complete, or `Err(())` if it reports
/// `Pending`.
fn poll_ready<T>(future: impl Future<Output = T>) -> Result<T, ()> {
    // Cloning hands out another no-op waker; wake, wake_by_ref and drop do nothing.
    const NOOP_WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new(
        |_| RawWaker::new(ptr::null(), &NOOP_WAKER_VTABLE),
        |_| (),
        |_| (),
        |_| (),
    );

    // SAFETY: none of the vtable functions dereference the (null) data pointer.
    let waker = unsafe { Waker::from_raw(RawWaker::new(ptr::null(), &NOOP_WAKER_VTABLE)) };
    let mut cx = Context::from_waker(&waker);

    let pinned = pin!(future);
    match pinned.poll(&mut cx) {
        Poll::Ready(output) => Ok(output),
        Poll::Pending => Err(()),
    }
}
144
/// What a [`MessageLoop`] filter wants done with the message it just inspected.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FilterResult {
    /// Let the message continue through normal processing (translation and dispatch).
    Forward,

    /// Discard the message instead of translating and dispatching it.
    Drop,
}
154
155pub struct MessageLoop {
161 quit: Cell<bool>,
162}
163
164impl MessageLoop {
165 fn new() -> Self {
166 Self {
167 quit: Cell::new(false),
168 }
169 }
170
171 fn run_loop(&self, filter: impl Fn(&MSG) -> FilterResult) {
172 while !self.quit.get() {
173 unsafe {
174 let mut msg = MaybeUninit::uninit();
175 if GetMessageA(msg.as_mut_ptr(), ptr::null_mut(), 0, 0) == 0 {
176 return;
177 }
178 let msg = msg.assume_init();
179
180 if filter(&msg) == FilterResult::Forward {
181 TranslateMessage(&msg);
182 DispatchMessageA(&msg);
183 }
184 if let Some(panic_payload) = PANIC_PAYLOAD.take() {
185 panic::resume_unwind(panic_payload)
186 }
187 }
188 }
189 }
190
191 pub fn run(filter: impl Fn(&MessageLoop, &MSG) -> FilterResult) {
216 let msg_loop = MessageLoop::new();
217
218 let _hook = unsafe {
224 MsgFilterHook::register(|msg| {
225 panic::catch_unwind(AssertUnwindSafe(|| {
226 let filter_result = filter(&msg_loop, msg);
227 if msg_loop.quit.get() {
231 PostMessageA(msg.hwnd, WM_QUIT, 0, 0);
232 }
233 filter_result == FilterResult::Drop
234 }))
235 .unwrap_or_else(|payload| {
236 PANIC_PAYLOAD.with(|panic_payload| {
237 panic_payload.set(Some(payload));
238 });
239 PostMessageA(msg.hwnd, WM_QUIT, 0, 0);
241 false
242 })
243 })
244 };
245 msg_loop.run_loop(|msg| filter(&msg_loop, msg));
246 }
247
248 pub fn quit(&self) {
250 self.quit.set(true);
251 }
252
253 pub fn quit_when_idle(&self) {
255 unsafe { PostQuitMessage(0) };
256 }
257}
258
#[cfg(test)]
mod test {
    use std::{ffi::CStr, future::poll_fn};

    use windows_sys::Win32::Foundation::HWND;

    use super::*;

    /// Posts `msg` with a null hwnd, i.e. as a thread message on the calling thread's queue.
    fn post_thread_message(msg: u32) {
        unsafe { PostMessageA(ptr::null_mut(), msg, 0, 0) };
    }

    // A panic raised by the filter must propagate out of `MessageLoop::run`.
    #[test]
    #[should_panic]
    fn panic_in_dispatcher() {
        post_thread_message(WM_USER);
        MessageLoop::run(|_, _| panic!());
    }

    // After `quit`, the filter must see no further messages: only the first of the ten
    // queued messages may be observed.
    #[test]
    fn message_loop_quit() {
        for i in 0..10 {
            post_thread_message(WM_USER + i);
        }
        MessageLoop::run(|msg_loop, msg| {
            assert_eq!(msg.message, WM_USER);
            msg_loop.quit();
            FilterResult::Drop
        });
    }

    // `quit_when_idle` must let every already-queued message reach the filter, in order,
    // before the loop stops.
    #[test]
    fn message_loop_quit_when_idle() {
        for i in 0..10 {
            post_thread_message(WM_USER + i);
        }
        let expected_msg = Cell::new(0);
        MessageLoop::run(|msg_loop, msg| {
            assert_eq!(msg.message, WM_USER + expected_msg.get());
            expected_msg.set(expected_msg.get() + 1);
            msg_loop.quit_when_idle();
            FilterResult::Drop
        });
        assert_eq!(expected_msg.get(), 10);
    }

    // An inner `block_on` runs to completion before the outer future continues.
    #[test]
    fn nested_block_on() {
        let count: Cell<usize> = Cell::new(0);

        block_on(async {
            assert_eq!(count.get(), 0);
            count.set(count.get() + 1);

            block_on(async {
                assert_eq!(count.get(), 1);
                count.set(count.get() + 1);
            });

            assert_eq!(count.get(), 2);
            count.set(count.get() + 1);
        });

        assert_eq!(count.get(), 3);
    }

    // Calling `MessageLoop::run` from inside a filter is expected to panic.
    #[test]
    #[should_panic]
    fn nested_message_loop() {
        post_thread_message(WM_USER);
        MessageLoop::run(|_, _| {
            MessageLoop::run(|_, _| FilterResult::Drop);
            FilterResult::Drop
        });
    }

    /// Wakes itself and returns `Pending` once, so the calling task is rescheduled through
    /// the executor before it resumes.
    async fn yield_now() {
        let mut yielded = false;
        poll_fn(|cx| {
            if yielded {
                Poll::Ready(())
            } else {
                yielded = true;
                cx.waker().wake_by_ref();
                Poll::Pending
            }
        })
        .await;
    }

    // `block_on` can be used from inside a `MessageLoop::run` filter.
    #[test]
    fn nested_message_loop_block_on() {
        let inner_executed = Cell::new(false);

        post_thread_message(WM_USER);
        MessageLoop::run(|msg_loop, _| {
            block_on(async {
                inner_executed.set(true);
            });
            msg_loop.quit();
            FilterResult::Forward
        });

        assert!(inner_executed.get());
    }

    // The outer loop can be quit from a future driven by a nested `block_on`.
    #[test]
    fn nested_message_loop_block_on_quit() {
        post_thread_message(WM_USER);
        MessageLoop::run(|msg_loop, _| {
            block_on(async {
                msg_loop.quit();
            });
            FilterResult::Forward
        });
    }

    /// Finds a window whose title is `name` (any window class); null if none exists.
    fn window_by_name(name: &CStr) -> HWND {
        unsafe { FindWindowA(ptr::null_mut(), name.as_ptr() as _) }
    }

    // A spawned task keeps running while a message box (titled `window_name`) is open
    // inside `block_on`. The task waits for the box to appear, yields a few more times,
    // and then closes it.
    #[test]
    fn running_spawned_with_modal_dialog() {
        let window_name = c"running_spawned_with_modal_dialog";

        let task = spawn_local(async {
            while window_by_name(window_name).is_null() {
                yield_now().await;
            }

            for _ in 0..10 {
                yield_now().await;
            }

            unsafe {
                SendMessageA(window_by_name(window_name), WM_CLOSE, 0, 0);
            }
        });

        block_on(async {
            unsafe {
                MessageBoxA(
                    ptr::null_mut(),
                    ptr::null_mut(),
                    window_name.as_ptr() as _,
                    0,
                );
            }
            task.await;
        });
    }

    // The filter is re-entered through the hook while it is blocked in `MessageBoxA`.
    // The re-entrancy assertion's panic must propagate out of `MessageLoop::run`.
    #[test]
    #[should_panic]
    fn reenter_filter_closure_panic() {
        let window_name = c"reenter_filter_closure";

        post_thread_message(WM_USER);

        let running_filter_closure = Cell::new(false);
        MessageLoop::run(|_, msg| {
            assert!(
                !running_filter_closure.replace(true),
                "Filter closure reentered"
            );

            if msg.hwnd.is_null() && msg.message == WM_USER {
                unsafe {
                    MessageBoxA(
                        ptr::null_mut(),
                        ptr::null_mut(),
                        window_name.as_ptr() as _,
                        0,
                    );
                }
            }

            running_filter_closure.set(false);
            FilterResult::Forward
        });
    }

    // On re-entry the filter calls `quit`. The test only finishes if that ends both the
    // message box's modal loop and the outer loop.
    #[test]
    fn reenter_filter_closure_quit() {
        let window_name = c"reenter_filter_closure";

        post_thread_message(WM_USER);

        let running_filter_closure = Cell::new(false);
        MessageLoop::run(|msg_loop, msg| {
            if running_filter_closure.replace(true) {
                msg_loop.quit();
            }

            if msg.hwnd.is_null() && msg.message == WM_USER {
                unsafe {
                    MessageBoxA(
                        ptr::null_mut(),
                        ptr::null_mut(),
                        window_name.as_ptr() as _,
                        0,
                    );
                }
            }

            running_filter_closure.set(false);
            FilterResult::Forward
        });
    }

    // One spawned task opens a message box. A second task, running while the box is open,
    // posts ten thread messages and then closes the box. The filter must see all ten,
    // in order.
    #[test]
    fn message_loop_with_modal_dialog() {
        let window_name = c"message_loop_with_modal_dialog";

        spawn_local(async {
            unsafe {
                MessageBoxA(
                    ptr::null_mut(),
                    ptr::null_mut(),
                    window_name.as_ptr() as _,
                    0,
                );
            }
        });

        spawn_local(async {
            assert!(!window_by_name(window_name).is_null());

            for i in 0..10 {
                post_thread_message(WM_USER + i);
                yield_now().await;
            }

            unsafe { SendMessageA(window_by_name(window_name), WM_CLOSE, 0, 0) };
        });

        let expected_msg = Cell::new(0);
        MessageLoop::run(|msg_loop, msg| {
            if msg.hwnd.is_null() && msg.message >= WM_USER {
                assert_eq!(msg.message, WM_USER + expected_msg.get());
                expected_msg.set(expected_msg.get() + 1);
                msg_loop.quit_when_idle();
                FilterResult::Drop
            } else {
                FilterResult::Forward
            }
        });
        assert_eq!(expected_msg.get(), 10);
    }

    // Same as `reenter_filter_closure_quit`, but ends the loops with `quit_when_idle`.
    #[test]
    fn reenter_filter_closure_quit_when_idle() {
        let window_name = c"reenter_filter_closure";

        post_thread_message(WM_USER);

        let running_filter_closure = Cell::new(false);
        MessageLoop::run(|msg_loop, msg| {
            if running_filter_closure.replace(true) {
                msg_loop.quit_when_idle();
            }

            if msg.hwnd.is_null() && msg.message == WM_USER {
                unsafe {
                    MessageBoxA(
                        ptr::null_mut(),
                        ptr::null_mut(),
                        window_name.as_ptr() as _,
                        0,
                    );
                }
            }

            running_filter_closure.set(false);
            FilterResult::Forward
        });
    }
}