1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
use std::{
    cell::OnceCell, fmt::Debug, hash::Hash, io, mem::MaybeUninit, num::NonZeroU32, ops::Range, ptr,
};

use winapi::{
    shared::{
        minwindef::{DWORD, FALSE},
        windef::{HWINEVENTHOOK, HWND},
    },
    um::{
        winnt::LONG,
        winuser::{GetWindowThreadProcessId, SetWinEventHook, UnhookWinEvent},
    },
};

use crate::{
    message_loop::run_dummy_message_loop, raw_event, RawWindowEvent, RawWindowHandle, WindowEvent,
};

// Per-thread state read by `win_event_hook_callback`: the channel that accepted events are sent
// to, paired with the predicate that decides which events are accepted.
// The cell is filled once by the event loop thread spawned in `WindowEventHook::hook`, after the
// hook has been installed successfully; on any other thread it stays empty.
thread_local! {
    static HOOK_EVENT_TX: OnceCell<(tokio::sync::mpsc::UnboundedSender<WindowEvent>, EventPredicate)> = OnceCell::new();
}

/// Callback handed to `SetWinEventHook`.
///
/// Converts the raw callback arguments into a [`WindowEvent`] and forwards it to the channel
/// stored in `HOOK_EVENT_TX`, provided the stored predicate accepts the event.
/// Nothing happens when the thread-local state has not been initialised on the calling thread.
extern "system" fn win_event_hook_callback(
    hook: HWINEVENTHOOK,
    event: DWORD,
    h_wnd: HWND,
    id_object: LONG,
    id_child: LONG,
    event_thread: DWORD,
    event_time: DWORD,
) {
    HOOK_EVENT_TX.with(|state| {
        // No sender/predicate registered on this thread: there is nowhere to deliver to.
        let Some((sender, accepts)) = state.get() else {
            return;
        };

        let raw = RawWindowEvent {
            hook_handle: hook,
            event_id: event,
            window_handle: h_wnd,
            object_id: id_object,
            child_id: id_child,
            thread_id: event_thread,
            timestamp: event_time,
        };
        let window_event = WindowEvent::from_raw(raw);

        if !accepts(&window_event) {
            return;
        }
        // A send error is deliberately ignored; the callback has no way to report it.
        let _ = sender.send(window_event);
    });
}

#[derive(Debug)]
/// A hook with a message loop that listens for window events.
///
/// Created with [`WindowEventHook::hook`] and torn down with [`WindowEventHook::unhook`].
pub struct WindowEventHook {
    /// Signals the event loop thread to leave its message loop; consumed by `unhook`.
    abort_tx: tokio::sync::oneshot::Sender<()>,
    /// Handle of the event loop thread; the thread's result is the outcome of removing the hook.
    event_thread: async_thread::JoinHandle<Result<(), std::io::Error>>,
}

impl WindowEventHook {
    /// Creates a new [`WindowEventHook`] that listens to the events matching the given filter.
    /// An [`WindowEvent`] is sent to the given `event_tx` whenever an event matching the filter is received.
    /// A dedicated event loop thread is created to be able to receive events.
    pub async fn hook(
        filter: EventFilter,
        event_tx: tokio::sync::mpsc::UnboundedSender<WindowEvent>,
    ) -> Result<Self, io::Error> {
        let (handle_tx, handle_rx) = tokio::sync::oneshot::channel();
        let (abort_tx, abort_rx) = tokio::sync::oneshot::channel();

        let event_thread = async_thread::Builder::new()
            .name("WindowEventHook EventLoop".to_string())
            .spawn(move || {
                let mut flags = HookFlags::OUT_OF_CONTEXT;
                if filter.skip_own_process {
                    flags |= HookFlags::SKIP_OWN_PROCESS;
                }
                if filter.skip_own_thread {
                    flags |= HookFlags::SKIP_OWN_THREAD;
                }

                let hook = unsafe {
                    SetWinEventHook(
                        filter.events.0 as _,
                        filter.events.1 as _,
                        ptr::null_mut(),
                        Some(win_event_hook_callback),
                        filter.process_id,
                        filter.thread_id,
                        flags.bits(),
                    )
                };

                let hook_result = if !hook.is_null() {
                    HOOK_EVENT_TX.with(|tx| {
                        tx.set((event_tx, filter.predicate.get()))
                            .map_err(|_| ())
                            .unwrap();
                    });
                    Ok(())
                } else {
                    Err(io::Error::last_os_error())
                };

                handle_tx.send(hook_result).unwrap();

                run_dummy_message_loop(abort_rx).unwrap();

                let unhook_result = unsafe { UnhookWinEvent(hook) };
                if unhook_result == FALSE {
                    Err(io::Error::last_os_error())
                } else {
                    Ok(())
                }
            })
            .unwrap();

        handle_rx.await.unwrap().map(|_| Self {
            abort_tx,
            event_thread,
        })
    }

    /// Unhook this hook and stop the event loop.
    pub async fn unhook(self) -> Result<(), io::Error> {
        self.abort_tx.send(()).unwrap();
        self.event_thread.join().await.unwrap()
    }
}

/// A filter for window events.
pub type EventPredicate = fn(&WindowEvent) -> bool;

#[derive(Clone, Copy)]
struct EventPredicateHolder(Option<EventPredicate>);

impl EventPredicateHolder {
    fn new() -> Self {
        Self(None)
    }
    fn set(&mut self, predicate: EventPredicate) {
        self.0 = Some(predicate);
    }
    fn get(&self) -> EventPredicate {
        self.0.unwrap_or(|_| true)
    }
}

impl Debug for EventPredicateHolder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:?}", self.0.as_ref().map(|_| "some predicate"))
    }
}

#[derive(Debug, Clone, Copy)]
/// A filter for window events to be received.
pub struct EventFilter {
    events: (i32, i32), // Range<i32> is not Copy
    skip_own_thread: bool,
    skip_own_process: bool,
    thread_id: u32,
    process_id: u32,
    predicate: EventPredicateHolder,
}

impl Default for EventFilter {
    fn default() -> Self {
        Self {
            events: (raw_event::MIN, raw_event::MAX),
            skip_own_thread: true,
            skip_own_process: false,
            thread_id: 0,
            process_id: 0,
            predicate: EventPredicateHolder::new(),
        }
    }
}

impl EventFilter {
    /// Set the event that should be received by the hook.
    #[must_use]
    pub fn event(self, event: i32) -> Self {
        self.events(Range {
            start: event,
            end: event,
        })
    }

    /// Set the range of events that should be received by the hook.
    #[must_use]
    pub fn events(mut self, events: Range<i32>) -> Self {
        self.events = (events.start, events.end);
        self
    }

    /// If true, the event will be skipped if it is generated by the current process.
    #[must_use]
    pub fn skip_own_process(mut self, value: bool) -> Self {
        self.skip_own_process = value;
        self
    }

    /// If true, the event will be skipped if it is generated by the event loop thread.
    #[must_use]
    pub fn skip_own_thread(mut self, value: bool) -> Self {
        self.skip_own_thread = value;
        self
    }

    /// Only include events from the process with the given id.
    #[must_use]
    pub fn process(mut self, id: NonZeroU32) -> Self {
        self.process_id = id.get();
        self
    }

    /// Only include events from the thread with the given id.
    #[must_use]
    pub fn thread(mut self, id: NonZeroU32) -> Self {
        self.thread_id = id.get();
        self
    }

    /// Only include events from the thread and process of the given window.
    #[must_use]
    pub fn window(self, window: RawWindowHandle) -> Self {
        let window_info = WindowThreadProcess::from_handle(window);
        self.thread(window_info.thread).process(window_info.process)
    }

    /// Include events from all processes.
    /// To receive events from the current process `skip_own_process` must be set to false.
    #[must_use]
    pub fn all_processes(mut self) -> Self {
        self.process_id = 0;
        self
    }

    /// Include events from all threads.
    /// To receive events from the event loop thread `skip_own_thread` must be set to false.
    #[must_use]
    pub fn all_threads(mut self) -> Self {
        self.thread_id = 0;
        self
    }

    /// Set the predicate that determines whether an event should be accepted.
    #[must_use]
    pub fn predicate(mut self, predicate: EventPredicate) -> EventFilter {
        self.predicate.set(predicate);
        self
    }
}

/// Identifiers of the process and thread associated with a window, as resolved by
/// `WindowThreadProcess::from_handle`.
struct WindowThreadProcess {
    /// Id of the process reported for the window.
    process: NonZeroU32,
    /// Id of the thread reported for the window.
    thread: NonZeroU32,
}

impl WindowThreadProcess {
    /// Bundles an already known process id and thread id.
    fn new(process: NonZeroU32, thread: NonZeroU32) -> Self {
        Self { process, thread }
    }

    /// Resolves the thread and process ids of `window` via `GetWindowThreadProcessId`.
    ///
    /// # Panics
    /// Panics if the lookup fails (the call returns a thread id of zero, e.g. for an invalid
    /// window handle) or if the reported process id is zero.
    fn from_handle(window: RawWindowHandle) -> Self {
        let mut process_id = MaybeUninit::uninit();
        // SAFETY: `process_id` provides valid, writable storage for the output parameter.
        let thread_id = unsafe { GetWindowThreadProcessId(window, process_id.as_mut_ptr()) };

        // The return value must be checked before touching the output: on failure the function
        // returns zero and leaves the output untouched, i.e. still uninitialised.
        let thread = NonZeroU32::new(thread_id)
            .expect("GetWindowThreadProcessId failed: the window handle is not valid");

        // SAFETY: a non-zero thread id means the call succeeded and wrote the process id.
        let process_id = unsafe { process_id.assume_init() };
        let process = NonZeroU32::new(process_id)
            .expect("GetWindowThreadProcessId reported a process id of zero");

        Self::new(process, thread)
    }
}

bitflags::bitflags! {
    /// Flags passed to `SetWinEventHook` that control how and from where events are delivered.
    struct HookFlags: u32 {
        /// The callback function is not mapped into the address space of the process that generates the event.
        /// Because the hook function is called across process boundaries, the system must queue events.
        /// Although this method is asynchronous, events are guaranteed to be in sequential order.
        // The value is zero, so this flag contributes no bits when combined with the others.
        const OUT_OF_CONTEXT = 0x0000;

        /// Prevents this instance of the hook from receiving the events that are generated by the thread that is registering this hook.
        const SKIP_OWN_THREAD = 0x0001;

        /// Prevents this instance of the hook from receiving the events that are generated by threads in this process.
        /// This flag does not prevent threads from generating events.
        const SKIP_OWN_PROCESS = 0x0002;

        /// The DLL that contains the callback function is mapped into the address space of the process that generates the event.
        /// With this flag, the system sends event notifications to the callback function as they occur.
        /// The hook function must be in a DLL when this flag is specified.
        /// This flag has no effect when both the calling process and the generating process are not 32-bit or 64-bit processes, or when the generating process is a console application.
        const IN_CONTEXT = 0x0004;
    }
}

impl Default for HookFlags {
    fn default() -> Self {
        Self::OUT_OF_CONTEXT
    }
}