1#![doc = include_str!("../README.md")]
2
3pub mod util;
4
5use std::{
6 any::Any,
7 cell::Cell,
8 future::Future,
9 mem::{ManuallyDrop, MaybeUninit},
10 panic::{self, AssertUnwindSafe},
11 pin::{pin, Pin},
12 ptr::{self, NonNull},
13 task::{Context, Poll, RawWaker, RawWakerVTable, Waker},
14};
15
16use async_task::Runnable;
17use util::{Window, WindowType};
18use windows_sys::Win32::UI::WindowsAndMessaging::*;
19
20use crate::util::MsgFilterHook;
21
22const MSG_ID_WAKE: u32 = WM_USER;
23
24thread_local! {
25 static PANIC_PAYLOAD: Cell<Option<Box<dyn Any + Send + 'static>>> = const { Cell::new(None) };
26 static EXECUTOR_WINDOW: Window<()> = Window::new(WindowType::MessageOnly, (), |_, msg| {
27 if msg.msg == MSG_ID_WAKE {
28 let runnable = unsafe {
29 let runnable_ptr = NonNull::new_unchecked(msg.lparam as *mut _);
30 Runnable::<()>::from_raw(runnable_ptr)
31 };
32 if let Err(panic_payload) = panic::catch_unwind(|| runnable.run()) {
33 PANIC_PAYLOAD.set(Some(panic_payload));
34 }
35 Some(0)
36 } else {
37 None
38 }
39 })
40 .unwrap();
41}
42
43pub struct JoinHandle<T> {
48 task: ManuallyDrop<async_task::Task<T>>,
49}
50
51impl<T> Drop for JoinHandle<T> {
53 fn drop(&mut self) {
54 let task = unsafe { ManuallyDrop::take(&mut self.task) };
55 task.detach();
56 }
57}
58
59impl<T> Future for JoinHandle<T> {
60 type Output = T;
61
62 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
63 pin!(&mut *self.task).poll(cx)
64 }
65}
66
67unsafe fn spawn_unchecked_lifetime<T>(future: impl Future<Output = T>) -> JoinHandle<T> {
68 let hwnd = EXECUTOR_WINDOW.with(|w| w.hwnd());
69
70 let (runnable, task) = unsafe {
74 async_task::spawn_unchecked(future, move |runnable: Runnable| {
75 PostMessageA(hwnd, MSG_ID_WAKE, 0, runnable.into_raw().as_ptr() as _);
76 })
77 };
78
79 runnable.schedule();
81
82 JoinHandle {
83 task: ManuallyDrop::new(task),
84 }
85}
86
87pub fn spawn_local<T>(future: impl Future<Output = T> + 'static) -> JoinHandle<T> {
93 unsafe { spawn_unchecked_lifetime(future) }
95}
96
/// Runs `future` to completion on the current thread and returns its output.
///
/// The future is spawned onto this thread's executor window, and a fresh message loop is
/// pumped until the future finishes. Other tasks and window messages on this thread keep
/// being processed meanwhile. Unlike [`spawn_local`], `future` may borrow from the caller.
/// Calls may be nested, including from inside a [`MessageLoop::run`] filter.
///
/// # Panics
///
/// - With "received unexpected quit message" if this loop retrieves `WM_QUIT` (e.g. one
///   posted by [`MessageLoop::quit_when_idle`]) before the future has finished.
/// - Re-raises any panic payload stashed on this thread while the loop runs, such as one
///   from a panicking task.
///
/// NOTE(review): on both panic paths the unfinished task is dropped through
/// `JoinHandle::drop`, which *detaches* it rather than cancelling it. The task still
/// borrows `msg_loop` and `future` from this frame. If it is polled again after the unwind
/// (e.g. a queued wake message dispatched by an outer loop), that would be a
/// use-after-free. Consider aborting, or otherwise guaranteeing the task cannot outlive
/// this call.
pub fn block_on<'a, T: 'a>(future: impl Future<Output = T> + 'a) -> T {
    // Lives on this stack frame; the task below borrows it to stop the loop.
    let msg_loop = &MessageLoop::new();

    let task = unsafe {
        // SAFETY (intended): the task borrows `future` and `msg_loop`, which are only valid
        // for this call, so it must be finished before this function returns. The loop below
        // only exits normally after the task has called `msg_loop.quit()`.
        spawn_unchecked_lifetime(async move {
            let result = future.await;
            // Makes `run_loop` below exit after the current message.
            msg_loop.quit();
            result
        })
    };

    // Task wake-ups and all other messages are forwarded to their window procedures.
    msg_loop.run_loop(|_| FilterResult::Forward);

    // Normally the task has completed here. It is still pending only if `run_loop` returned
    // because `WM_QUIT` was retrieved.
    poll_ready(task).expect("received unexpected quit message")
}
127
/// Polls `future` exactly once, using a waker that does nothing.
///
/// Returns `Ok(output)` if the future is already complete, and `Err(())` if it is still
/// pending. `block_on` uses this to take the output of a task that should have finished.
fn poll_ready<T>(future: impl Future<Output = T>) -> Result<T, ()> {
    // Every vtable entry ignores the (null) data pointer; cloning hands out another no-op
    // waker backed by the same vtable.
    unsafe fn noop_clone(_: *const ()) -> RawWaker {
        RawWaker::new(ptr::null(), &NOOP_VTABLE)
    }
    unsafe fn noop(_: *const ()) {}
    const NOOP_VTABLE: RawWakerVTable = RawWakerVTable::new(noop_clone, noop, noop, noop);

    // SAFETY: none of the vtable functions dereference the data pointer.
    let waker = unsafe { Waker::from_raw(RawWaker::new(ptr::null(), &NOOP_VTABLE)) };
    let mut cx = Context::from_waker(&waker);

    let pinned = pin!(future);
    match pinned.poll(&mut cx) {
        Poll::Ready(output) => Ok(output),
        Poll::Pending => Err(()),
    }
}
144
/// What a [`MessageLoop`] filter wants done with a message it has inspected.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum FilterResult {
    /// Pass the message on so it is translated and dispatched as usual.
    Forward,

    /// Swallow the message so the loop does not dispatch it.
    Drop,
}
154
/// State of a running message loop started by [`MessageLoop::run`] or [`block_on`].
///
/// `run` hands a reference to its filter so the filter can stop the loop.
pub struct MessageLoop {
    /// Set by [`MessageLoop::quit`]; checked before each message is retrieved.
    quit: Cell<bool>,
}
163
164impl MessageLoop {
165 fn new() -> Self {
166 Self {
167 quit: Cell::new(false),
168 }
169 }
170
171 fn run_loop(&self, filter: impl Fn(&MSG) -> FilterResult) {
172 while !self.quit.get() {
173 unsafe {
174 let mut msg = MaybeUninit::uninit();
175 if GetMessageA(msg.as_mut_ptr(), ptr::null_mut(), 0, 0) == 0 {
176 return;
177 }
178 let msg = msg.assume_init();
179
180 if filter(&msg) == FilterResult::Forward {
181 TranslateMessage(&msg);
182 DispatchMessageA(&msg);
183 }
184 if let Some(panic_payload) = PANIC_PAYLOAD.take() {
185 panic::resume_unwind(panic_payload)
186 }
187 }
188 }
189 }
190
191 pub fn run(filter: impl Fn(&MessageLoop, &MSG) -> FilterResult) {
218 let msg_loop = MessageLoop::new();
219
220 let _hook = unsafe {
226 MsgFilterHook::register(|msg| {
227 panic::catch_unwind(AssertUnwindSafe(|| {
228 let filter_result = filter(&msg_loop, msg);
229 if msg_loop.quit.get() {
233 PostMessageA(msg.hwnd, WM_QUIT, 0, 0);
234 }
235 filter_result == FilterResult::Drop
236 }))
237 .unwrap_or_else(|payload| {
238 PANIC_PAYLOAD.with(|panic_payload| {
239 panic_payload.set(Some(payload));
240 });
241 PostMessageA(msg.hwnd, WM_QUIT, 0, 0);
243 false
244 })
245 })
246 };
247 msg_loop.run_loop(|msg| filter(&msg_loop, msg));
248 }
249
250 pub fn quit(&self) {
252 self.quit.set(true);
253 }
254
255 pub fn quit_when_idle(&self) {
257 unsafe { PostQuitMessage(0) };
258 }
259}
260
#[cfg(test)]
mod test {
    use std::{ffi::CStr, future::poll_fn};

    use windows_sys::Win32::Foundation::HWND;

    use super::*;

    /// Posts `msg` to the calling thread's message queue. A null `HWND` makes `PostMessageA`
    /// post a thread message to the current thread.
    fn post_thread_message(msg: u32) {
        unsafe { PostMessageA(ptr::null_mut(), msg, 0, 0) };
    }

    /// A panic raised by the filter must propagate out of `MessageLoop::run`.
    #[test]
    #[should_panic]
    fn panic_in_dispatcher() {
        post_thread_message(WM_USER);
        MessageLoop::run(|_, _| panic!());
    }

    /// `quit()` stops the loop right after the current message, so only the first of the ten
    /// posted messages reaches the filter.
    #[test]
    fn message_loop_quit() {
        for i in 0..10 {
            post_thread_message(WM_USER + i);
        }
        MessageLoop::run(|msg_loop, msg| {
            assert_eq!(msg.message, WM_USER);
            // Quitting on the first message means the filter is not called again, so the
            // assertion above only ever sees `WM_USER`.
            msg_loop.quit();
            FilterResult::Drop
        });
    }

    /// `quit_when_idle()` still lets every already-queued message through, in order, before
    /// the loop exits.
    #[test]
    fn message_loop_quit_when_idle() {
        for i in 0..10 {
            post_thread_message(WM_USER + i);
        }
        let expected_msg = Cell::new(0);
        MessageLoop::run(|msg_loop, msg| {
            assert_eq!(msg.message, WM_USER + expected_msg.get());
            expected_msg.set(expected_msg.get() + 1);
            msg_loop.quit_when_idle();
            FilterResult::Drop
        });
        assert_eq!(expected_msg.get(), 10);
    }

    /// A `block_on` nested inside another runs its future to completion before the outer
    /// future resumes.
    #[test]
    fn nested_block_on() {
        let count: Cell<usize> = Cell::new(0);

        block_on(async {
            assert_eq!(count.get(), 0);
            count.set(count.get() + 1);

            block_on(async {
                assert_eq!(count.get(), 1);
                count.set(count.get() + 1);
            });

            assert_eq!(count.get(), 2);
            count.set(count.get() + 1);
        });

        assert_eq!(count.get(), 3);
    }

    /// Starting a `MessageLoop::run` from inside another one's filter is expected to panic.
    /// NOTE(review): the panic presumably comes from `MsgFilterHook::register` rejecting a
    /// second hook on the same thread — confirm in `util`.
    #[test]
    #[should_panic]
    fn nested_message_loop() {
        post_thread_message(WM_USER);
        MessageLoop::run(|_, _| {
            MessageLoop::run(|_, _| FilterResult::Drop);
            FilterResult::Drop
        });
    }

    /// Returns `Pending` once (after waking itself) and `Ready` on the next poll, forcing the
    /// current task to be rescheduled through the executor.
    async fn yield_now() {
        let mut yielded = false;
        poll_fn(|cx| {
            if yielded {
                Poll::Ready(())
            } else {
                yielded = true;
                cx.waker().wake_by_ref();
                Poll::Pending
            }
        })
        .await;
    }

    /// `block_on` can be used from inside a `MessageLoop::run` filter.
    #[test]
    fn nested_message_loop_block_on() {
        let inner_executed = Cell::new(false);

        post_thread_message(WM_USER);
        MessageLoop::run(|msg_loop, _| {
            block_on(async {
                inner_executed.set(true);
            });
            msg_loop.quit();
            FilterResult::Forward
        });

        assert!(inner_executed.get());
    }

    /// Calling the outer loop's `quit()` from inside a nested `block_on` still stops the
    /// outer loop once the filter returns.
    #[test]
    fn nested_message_loop_block_on_quit() {
        post_thread_message(WM_USER);
        MessageLoop::run(|msg_loop, _| {
            block_on(async {
                msg_loop.quit();
            });
            FilterResult::Forward
        });
    }

    /// Finds a top-level window by its title (any window class); null if there is none.
    fn window_by_name(name: &CStr) -> HWND {
        unsafe { FindWindowA(ptr::null_mut(), name.as_ptr() as _) }
    }

    /// A task spawned with `spawn_local` keeps making progress while a modal message box
    /// (captioned `window_name`) runs its own message loop; the task then closes the box.
    #[test]
    fn running_spawned_with_modal_dialog() {
        let window_name = c"running_spawned_with_modal_dialog";

        let task = spawn_local(async {
            // Wait until the message box below has been created.
            while window_by_name(window_name).is_null() {
                yield_now().await;
            }

            // Keep being scheduled a few more times while the box is open.
            for _ in 0..10 {
                yield_now().await;
            }

            unsafe {
                // Close the message box so that `MessageBoxA` below returns.
                SendMessageA(window_by_name(window_name), WM_CLOSE, 0, 0);
            }
        });

        block_on(async {
            unsafe {
                MessageBoxA(
                    ptr::null_mut(),
                    ptr::null_mut(),
                    window_name.as_ptr() as _,
                    0,
                );
            }
            task.await;
        });
    }

    /// While a modal message box opened by the filter is showing, the filter is re-entered
    /// through the message-filter hook. The reentrancy assertion then panics, and the panic
    /// must break out of the modal loop and propagate out of `MessageLoop::run`.
    #[test]
    #[should_panic]
    fn reenter_filter_closure_panic() {
        let window_name = c"reenter_filter_closure";

        post_thread_message(WM_USER);

        let running_filter_closure = Cell::new(false);
        MessageLoop::run(|_, msg| {
            assert!(
                !running_filter_closure.replace(true),
                "Filter closure reentered"
            );

            // Only the thread message posted above opens the (modal) message box.
            if msg.hwnd.is_null() && msg.message == WM_USER {
                unsafe {
                    MessageBoxA(
                        ptr::null_mut(),
                        ptr::null_mut(),
                        window_name.as_ptr() as _,
                        0,
                    );
                }
            }

            running_filter_closure.set(false);
            FilterResult::Forward
        });
    }

    /// Calling `quit()` from a re-entrant filter invocation (made while the modal message box
    /// is showing) breaks out of the modal loop and then stops `MessageLoop::run`.
    #[test]
    fn reenter_filter_closure_quit() {
        let window_name = c"reenter_filter_closure";

        post_thread_message(WM_USER);

        let running_filter_closure = Cell::new(false);
        MessageLoop::run(|msg_loop, msg| {
            // Re-entered from inside the modal loop: ask the outer loop to stop.
            if running_filter_closure.replace(true) {
                msg_loop.quit();
            }

            if msg.hwnd.is_null() && msg.message == WM_USER {
                unsafe {
                    MessageBoxA(
                        ptr::null_mut(),
                        ptr::null_mut(),
                        window_name.as_ptr() as _,
                        0,
                    );
                }
            }

            running_filter_closure.set(false);
            FilterResult::Forward
        });
    }

    /// Thread messages posted while a modal message box (opened from a task) is showing still
    /// reach the filter in order; the filter sees all ten even though it calls
    /// `quit_when_idle` after each one.
    #[test]
    fn message_loop_with_modal_dialog() {
        let window_name = c"message_loop_with_modal_dialog";

        // Opens a modal message box from inside a task; its modal loop keeps dispatching the
        // executor's wake-up messages.
        spawn_local(async {
            unsafe {
                MessageBoxA(
                    ptr::null_mut(),
                    ptr::null_mut(),
                    window_name.as_ptr() as _,
                    0,
                );
            }
        });

        spawn_local(async {
            // Runs while the first task is still blocked in `MessageBoxA`.
            assert!(!window_by_name(window_name).is_null());

            for i in 0..10 {
                post_thread_message(WM_USER + i);
                yield_now().await;
            }

            // Close the box so the first task (and its modal loop) can finish.
            unsafe { SendMessageA(window_by_name(window_name), WM_CLOSE, 0, 0) };
        });

        let expected_msg = Cell::new(0);
        MessageLoop::run(|msg_loop, msg| {
            if msg.hwnd.is_null() && msg.message >= WM_USER {
                assert_eq!(msg.message, WM_USER + expected_msg.get());
                expected_msg.set(expected_msg.get() + 1);
                msg_loop.quit_when_idle();
                FilterResult::Drop
            } else {
                FilterResult::Forward
            }
        });
        assert_eq!(expected_msg.get(), 10);
    }

    /// Calling `quit_when_idle()` from a re-entrant filter invocation (made while the modal
    /// message box is showing) ends both the modal loop and `MessageLoop::run`.
    #[test]
    fn reenter_filter_closure_quit_when_idle() {
        let window_name = c"reenter_filter_closure";

        post_thread_message(WM_USER);

        let running_filter_closure = Cell::new(false);
        MessageLoop::run(|msg_loop, msg| {
            // Re-entered from inside the modal loop: post a thread-wide WM_QUIT for when idle.
            if running_filter_closure.replace(true) {
                msg_loop.quit_when_idle();
            }

            if msg.hwnd.is_null() && msg.message == WM_USER {
                unsafe {
                    MessageBoxA(
                        ptr::null_mut(),
                        ptr::null_mut(),
                        window_name.as_ptr() as _,
                        0,
                    );
                }
            }

            running_filter_closure.set(false);
            FilterResult::Forward
        });
    }
}