1use std::borrow::Cow;
4use std::sync::Arc;
5use std::time::Duration;
6
7use windows::Win32::Foundation::{GetLastError, HANDLE, WAIT_FAILED, WAIT_OBJECT_0, WAIT_TIMEOUT};
8use windows::Win32::System::Threading::{
9 CreateEventW, ResetEvent, SetEvent, WaitForMultipleObjects, WaitForSingleObject,
10};
11use windows::core::PCWSTR;
12
13use crate::error::{InvalidParameterError, OtherError};
14use crate::utils::{OwnedHandle, to_utf16_nul};
15use crate::{Error, Result};
16
17const MAX_WAIT_OBJECTS: usize = 64;
18
19#[derive(Debug, Clone)]
21pub struct Wait {
22 inner: Arc<OwnedHandle>,
23}
24
25impl Wait {
26 pub fn new(manual_reset: bool, initial_state: bool) -> Result<Self> {
28 Self::create_event(manual_reset, initial_state, PCWSTR::null(), "create")
29 }
30
31 pub fn named(name: &str, manual_reset: bool, initial_state: bool) -> Result<Self> {
33 let name_wide = to_utf16_nul(name);
34 Self::create_event(
35 manual_reset,
36 initial_state,
37 PCWSTR(name_wide.as_ptr()),
38 "create_named",
39 )
40 }
41
42 pub fn manual_reset(initial_state: bool) -> Result<Self> {
44 Self::new(true, initial_state)
45 }
46
47 pub fn auto_reset(initial_state: bool) -> Result<Self> {
49 Self::new(false, initial_state)
50 }
51
52 pub(crate) fn from_handle_borrowed(handle: HANDLE) -> Self {
53 Self {
54 inner: Arc::new(OwnedHandle::borrowed(handle)),
55 }
56 }
57
58 pub fn raw_handle(&self) -> HANDLE {
60 self.inner.raw()
61 }
62
63 pub fn try_clone(&self) -> Result<Self> {
67 Ok(self.clone())
68 }
69
70 pub fn is_signaled(&self) -> Result<bool> {
72 self.wait_timeout(Duration::ZERO)
73 }
74
75 pub fn wait(&self) -> Result<()> {
77 let wait_result = unsafe { WaitForSingleObject(self.inner.raw(), u32::MAX) };
78 if wait_result == WAIT_OBJECT_0 {
79 return Ok(());
80 }
81 Err(wait_error("wait"))
82 }
83
84 pub fn wait_timeout(&self, timeout: Duration) -> Result<bool> {
88 let wait_result =
89 unsafe { WaitForSingleObject(self.inner.raw(), duration_to_wait_ms(timeout)) };
90 if wait_result == WAIT_OBJECT_0 {
91 return Ok(true);
92 }
93 if wait_result == WAIT_TIMEOUT {
94 return Ok(false);
95 }
96 if wait_result == WAIT_FAILED {
97 return Err(wait_error("wait_timeout"));
98 }
99 Err(wait_error("wait_timeout"))
100 }
101
102 pub fn wait_any(handles: &[&Self]) -> Result<usize> {
106 let raw_handles = collect_raw_handles(handles)?;
107 let wait_result = unsafe { WaitForMultipleObjects(&raw_handles, false, u32::MAX) };
108 decode_wait_any_result(wait_result, raw_handles.len(), "wait_any")?.ok_or_else(|| {
109 Error::Other(OtherError::new(Cow::Borrowed(
110 "wait handle operation 'wait_any' timed out unexpectedly",
111 )))
112 })
113 }
114
115 pub fn wait_any_timeout(handles: &[&Self], timeout: Duration) -> Result<Option<usize>> {
119 let raw_handles = collect_raw_handles(handles)?;
120 let wait_result =
121 unsafe { WaitForMultipleObjects(&raw_handles, false, duration_to_wait_ms(timeout)) };
122 decode_wait_any_result(wait_result, raw_handles.len(), "wait_any_timeout")
123 }
124
125 pub fn wait_all(handles: &[&Self]) -> Result<()> {
127 let raw_handles = collect_raw_handles(handles)?;
128 let wait_result = unsafe { WaitForMultipleObjects(&raw_handles, true, u32::MAX) };
129 if wait_result == WAIT_OBJECT_0 {
130 return Ok(());
131 }
132 Err(wait_error("wait_all"))
133 }
134
135 pub fn wait_all_timeout(handles: &[&Self], timeout: Duration) -> Result<bool> {
139 let raw_handles = collect_raw_handles(handles)?;
140 let wait_result =
141 unsafe { WaitForMultipleObjects(&raw_handles, true, duration_to_wait_ms(timeout)) };
142 if wait_result == WAIT_OBJECT_0 {
143 return Ok(true);
144 }
145 if wait_result == WAIT_TIMEOUT {
146 return Ok(false);
147 }
148 if wait_result == WAIT_FAILED {
149 return Err(wait_error("wait_all_timeout"));
150 }
151 Err(wait_error("wait_all_timeout"))
152 }
153
154 pub fn set(&self) -> Result<()> {
156 unsafe { SetEvent(self.inner.raw()) }.map_err(|_| wait_error("set"))?;
157 Ok(())
158 }
159
160 pub fn reset(&self) -> Result<()> {
162 unsafe { ResetEvent(self.inner.raw()) }.map_err(|_| wait_error("reset"))?;
163 Ok(())
164 }
165
166 fn create_event(
167 manual_reset: bool,
168 initial_state: bool,
169 name: PCWSTR,
170 operation: &'static str,
171 ) -> Result<Self> {
172 let handle = unsafe { CreateEventW(None, manual_reset, initial_state, name) };
173
174 let handle = handle.map_err(|_| {
175 let code = unsafe { GetLastError().0 as i32 };
176 Error::Other(OtherError::new(Cow::Owned(format!(
177 "wait handle operation '{}' failed (error code: 0x{:08X})",
178 operation, code
179 ))))
180 })?;
181
182 Ok(Self {
183 inner: Arc::new(OwnedHandle::new(handle)),
184 })
185 }
186}
187
188fn collect_raw_handles(handles: &[&Wait]) -> Result<Vec<HANDLE>> {
189 if handles.is_empty() {
190 return Err(Error::InvalidParameter(InvalidParameterError::new(
191 "handles",
192 "at least one wait handle is required",
193 )));
194 }
195
196 if handles.len() > MAX_WAIT_OBJECTS {
197 return Err(Error::InvalidParameter(InvalidParameterError::new(
198 "handles",
199 "at most 64 wait handles are supported",
200 )));
201 }
202
203 Ok(handles.iter().map(|h| h.raw_handle()).collect())
204}
205
206fn decode_wait_any_result(
207 wait_result: windows::Win32::Foundation::WAIT_EVENT,
208 handle_count: usize,
209 operation: &'static str,
210) -> Result<Option<usize>> {
211 if wait_result == WAIT_TIMEOUT {
212 return Ok(None);
213 }
214 if wait_result == WAIT_FAILED {
215 return Err(wait_error(operation));
216 }
217
218 let result = wait_result.0;
219 let base = WAIT_OBJECT_0.0;
220 let end = base + handle_count as u32;
221 if result >= base && result < end {
222 return Ok(Some((result - base) as usize));
223 }
224
225 Err(wait_error(operation))
226}
227
/// Converts a `Duration` into a Win32 wait timeout in milliseconds.
///
/// Fractional milliseconds are rounded up, so a non-zero timeout never degrades into a
/// zero-timeout poll; `Duration::ZERO` still maps to `0`. Values that do not fit in `u32`
/// saturate to `u32::MAX`, which Win32 treats as INFINITE. This matches how the Rust
/// standard library converts timeouts on Windows.
fn duration_to_wait_ms(timeout: Duration) -> u32 {
    let whole_millis = timeout.as_millis();
    // Any leftover sub-millisecond part means the caller asked to wait a bit longer.
    let rounded_up = if timeout.subsec_nanos() % 1_000_000 != 0 {
        whole_millis + 1
    } else {
        whole_millis
    };
    u32::try_from(rounded_up).unwrap_or(u32::MAX)
}
231
232fn wait_error(operation: &'static str) -> Error {
233 let code = unsafe { GetLastError().0 as i32 };
234 Error::Other(OtherError::new(Cow::Owned(format!(
235 "wait handle operation '{}' failed (error code: 0x{:08X})",
236 operation, code
237 ))))
238}
239
#[cfg(test)]
mod tests {
    use super::Wait;
    use std::thread;
    use std::time::Duration;

    #[test]
    fn wait_timeout_reports_timeout_then_signal() {
        let wait = Wait::manual_reset(false).expect("wait handle create");
        let timed_out = wait
            .wait_timeout(Duration::from_millis(10))
            .expect("wait timeout should not fail");
        assert!(!timed_out);

        wait.set().expect("set should succeed");
        let signaled = wait
            .wait_timeout(Duration::from_millis(10))
            .expect("wait timeout should not fail");
        assert!(signaled);
    }

    #[test]
    fn cloned_wait_handle_synchronizes_threads() {
        let wait = Wait::manual_reset(false).expect("wait handle create");
        let signaler = wait.try_clone().expect("clone should succeed");

        let worker = thread::spawn(move || {
            thread::sleep(Duration::from_millis(20));
            signaler.set().expect("set should succeed");
        });

        wait.wait().expect("wait should succeed after signal");
        worker.join().expect("worker should not panic");
    }

    #[test]
    fn wait_any_reports_signaled_index() {
        let wait_a = Wait::manual_reset(false).expect("wait handle create");
        let wait_b = Wait::manual_reset(false).expect("wait handle create");
        wait_b.set().expect("set should succeed");

        let index = Wait::wait_any(&[&wait_a, &wait_b]).expect("wait_any should succeed");
        assert_eq!(index, 1);
    }

    #[test]
    fn wait_any_timeout_returns_none_when_nothing_signaled() {
        let wait_a = Wait::manual_reset(false).expect("wait handle create");
        let wait_b = Wait::manual_reset(false).expect("wait handle create");

        let index = Wait::wait_any_timeout(&[&wait_a, &wait_b], Duration::from_millis(10))
            .expect("wait_any_timeout should not fail");
        assert_eq!(index, None);
    }

    #[test]
    fn wait_all_requires_every_handle() {
        let wait_a = Wait::manual_reset(true).expect("wait handle create");
        let wait_b = Wait::manual_reset(false).expect("wait handle create");

        let all = Wait::wait_all_timeout(&[&wait_a, &wait_b], Duration::from_millis(10))
            .expect("wait_all_timeout should not fail");
        assert!(!all);

        wait_b.set().expect("set should succeed");
        let all = Wait::wait_all_timeout(&[&wait_a, &wait_b], Duration::from_millis(10))
            .expect("wait_all_timeout should not fail");
        assert!(all);

        Wait::wait_all(&[&wait_a, &wait_b]).expect("wait_all should succeed once all are set");
    }

    #[test]
    fn reset_clears_manual_reset_event() {
        let wait = Wait::manual_reset(true).expect("wait handle create");
        assert!(wait.is_signaled().expect("poll should not fail"));

        wait.reset().expect("reset should succeed");
        assert!(!wait.is_signaled().expect("poll should not fail"));
    }

    #[test]
    fn auto_reset_signal_is_consumed_by_successful_poll() {
        let wait = Wait::auto_reset(true).expect("wait handle create");
        assert!(wait.is_signaled().expect("poll should not fail"));
        assert!(!wait.is_signaled().expect("poll should not fail"));
    }

    #[test]
    fn multi_wait_rejects_empty_and_oversized_sets() {
        assert!(Wait::wait_any_timeout(&[], Duration::ZERO).is_err());

        let wait = Wait::manual_reset(true).expect("wait handle create");
        let too_many: Vec<&Wait> = std::iter::repeat(&wait).take(65).collect();
        assert!(Wait::wait_all_timeout(&too_many, Duration::ZERO).is_err());
        assert!(Wait::wait_any_timeout(&too_many, Duration::ZERO).is_err());
    }
}