Skip to main content

windows_erg/wait/
mod.rs

1//! Shared wait-object primitives used across modules.
2
3use std::borrow::Cow;
4use std::sync::Arc;
5use std::time::Duration;
6
7use windows::Win32::Foundation::{GetLastError, HANDLE, WAIT_FAILED, WAIT_OBJECT_0, WAIT_TIMEOUT};
8use windows::Win32::System::Threading::{
9    CreateEventW, ResetEvent, SetEvent, WaitForMultipleObjects, WaitForSingleObject,
10};
11use windows::core::PCWSTR;
12
13use crate::error::{InvalidParameterError, OtherError};
14use crate::utils::{OwnedHandle, to_utf16_nul};
15use crate::{Error, Result};
16
/// Largest number of handles a single multi-object wait may include.
///
/// Mirrors the Win32 `MAXIMUM_WAIT_OBJECTS` limit enforced by `WaitForMultipleObjects`.
const MAX_WAIT_OBJECTS: usize = 64;
18
19/// Wait object used to coordinate or interrupt long-running operations.
20#[derive(Debug, Clone)]
21pub struct Wait {
22    inner: Arc<OwnedHandle>,
23}
24
25impl Wait {
26    /// Create a new wait-handle-backed event object.
27    pub fn new(manual_reset: bool, initial_state: bool) -> Result<Self> {
28        Self::create_event(manual_reset, initial_state, PCWSTR::null(), "create")
29    }
30
31    /// Create a named wait event for inter-process synchronization.
32    pub fn named(name: &str, manual_reset: bool, initial_state: bool) -> Result<Self> {
33        let name_wide = to_utf16_nul(name);
34        Self::create_event(
35            manual_reset,
36            initial_state,
37            PCWSTR(name_wide.as_ptr()),
38            "create_named",
39        )
40    }
41
42    /// Create a manual-reset wait handle event.
43    pub fn manual_reset(initial_state: bool) -> Result<Self> {
44        Self::new(true, initial_state)
45    }
46
47    /// Create an auto-reset wait handle event.
48    pub fn auto_reset(initial_state: bool) -> Result<Self> {
49        Self::new(false, initial_state)
50    }
51
52    pub(crate) fn from_handle_borrowed(handle: HANDLE) -> Self {
53        Self {
54            inner: Arc::new(OwnedHandle::borrowed(handle)),
55        }
56    }
57
58    /// Return the underlying Win32 handle.
59    pub fn raw_handle(&self) -> HANDLE {
60        self.inner.raw()
61    }
62
63    /// Clone this wait handle with shared internal ownership.
64    ///
65    /// This does not duplicate the underlying OS handle.
66    pub fn try_clone(&self) -> Result<Self> {
67        Ok(self.clone())
68    }
69
70    /// Check if this wait handle is currently signaled without blocking.
71    pub fn is_signaled(&self) -> Result<bool> {
72        self.wait_timeout(Duration::ZERO)
73    }
74
75    /// Wait indefinitely until this handle is signaled.
76    pub fn wait(&self) -> Result<()> {
77        let wait_result = unsafe { WaitForSingleObject(self.inner.raw(), u32::MAX) };
78        if wait_result == WAIT_OBJECT_0 {
79            return Ok(());
80        }
81        Err(wait_error("wait"))
82    }
83
84    /// Wait until this handle is signaled or the timeout elapses.
85    ///
86    /// Returns `Ok(true)` if signaled, `Ok(false)` on timeout.
87    pub fn wait_timeout(&self, timeout: Duration) -> Result<bool> {
88        let wait_result =
89            unsafe { WaitForSingleObject(self.inner.raw(), duration_to_wait_ms(timeout)) };
90        if wait_result == WAIT_OBJECT_0 {
91            return Ok(true);
92        }
93        if wait_result == WAIT_TIMEOUT {
94            return Ok(false);
95        }
96        if wait_result == WAIT_FAILED {
97            return Err(wait_error("wait_timeout"));
98        }
99        Err(wait_error("wait_timeout"))
100    }
101
102    /// Wait until any handle in `handles` is signaled.
103    ///
104    /// Returns the index of the signaled handle.
105    pub fn wait_any(handles: &[&Self]) -> Result<usize> {
106        let raw_handles = collect_raw_handles(handles)?;
107        let wait_result = unsafe { WaitForMultipleObjects(&raw_handles, false, u32::MAX) };
108        decode_wait_any_result(wait_result, raw_handles.len(), "wait_any")?.ok_or_else(|| {
109            Error::Other(OtherError::new(Cow::Borrowed(
110                "wait handle operation 'wait_any' timed out unexpectedly",
111            )))
112        })
113    }
114
115    /// Wait until any handle in `handles` is signaled or timeout elapses.
116    ///
117    /// Returns `Ok(Some(index))` for a signaled handle, `Ok(None)` on timeout.
118    pub fn wait_any_timeout(handles: &[&Self], timeout: Duration) -> Result<Option<usize>> {
119        let raw_handles = collect_raw_handles(handles)?;
120        let wait_result =
121            unsafe { WaitForMultipleObjects(&raw_handles, false, duration_to_wait_ms(timeout)) };
122        decode_wait_any_result(wait_result, raw_handles.len(), "wait_any_timeout")
123    }
124
125    /// Wait until all handles in `handles` are signaled.
126    pub fn wait_all(handles: &[&Self]) -> Result<()> {
127        let raw_handles = collect_raw_handles(handles)?;
128        let wait_result = unsafe { WaitForMultipleObjects(&raw_handles, true, u32::MAX) };
129        if wait_result == WAIT_OBJECT_0 {
130            return Ok(());
131        }
132        Err(wait_error("wait_all"))
133    }
134
135    /// Wait until all handles in `handles` are signaled or timeout elapses.
136    ///
137    /// Returns `Ok(true)` when all are signaled, `Ok(false)` on timeout.
138    pub fn wait_all_timeout(handles: &[&Self], timeout: Duration) -> Result<bool> {
139        let raw_handles = collect_raw_handles(handles)?;
140        let wait_result =
141            unsafe { WaitForMultipleObjects(&raw_handles, true, duration_to_wait_ms(timeout)) };
142        if wait_result == WAIT_OBJECT_0 {
143            return Ok(true);
144        }
145        if wait_result == WAIT_TIMEOUT {
146            return Ok(false);
147        }
148        if wait_result == WAIT_FAILED {
149            return Err(wait_error("wait_all_timeout"));
150        }
151        Err(wait_error("wait_all_timeout"))
152    }
153
154    /// Signal the wait handle.
155    pub fn set(&self) -> Result<()> {
156        unsafe { SetEvent(self.inner.raw()) }.map_err(|_| wait_error("set"))?;
157        Ok(())
158    }
159
160    /// Reset the wait handle to unsignaled state.
161    pub fn reset(&self) -> Result<()> {
162        unsafe { ResetEvent(self.inner.raw()) }.map_err(|_| wait_error("reset"))?;
163        Ok(())
164    }
165
166    fn create_event(
167        manual_reset: bool,
168        initial_state: bool,
169        name: PCWSTR,
170        operation: &'static str,
171    ) -> Result<Self> {
172        let handle = unsafe { CreateEventW(None, manual_reset, initial_state, name) };
173
174        let handle = handle.map_err(|_| {
175            let code = unsafe { GetLastError().0 as i32 };
176            Error::Other(OtherError::new(Cow::Owned(format!(
177                "wait handle operation '{}' failed (error code: 0x{:08X})",
178                operation, code
179            ))))
180        })?;
181
182        Ok(Self {
183            inner: Arc::new(OwnedHandle::new(handle)),
184        })
185    }
186}
187
188fn collect_raw_handles(handles: &[&Wait]) -> Result<Vec<HANDLE>> {
189    if handles.is_empty() {
190        return Err(Error::InvalidParameter(InvalidParameterError::new(
191            "handles",
192            "at least one wait handle is required",
193        )));
194    }
195
196    if handles.len() > MAX_WAIT_OBJECTS {
197        return Err(Error::InvalidParameter(InvalidParameterError::new(
198            "handles",
199            "at most 64 wait handles are supported",
200        )));
201    }
202
203    Ok(handles.iter().map(|h| h.raw_handle()).collect())
204}
205
206fn decode_wait_any_result(
207    wait_result: windows::Win32::Foundation::WAIT_EVENT,
208    handle_count: usize,
209    operation: &'static str,
210) -> Result<Option<usize>> {
211    if wait_result == WAIT_TIMEOUT {
212        return Ok(None);
213    }
214    if wait_result == WAIT_FAILED {
215        return Err(wait_error(operation));
216    }
217
218    let result = wait_result.0;
219    let base = WAIT_OBJECT_0.0;
220    let end = base + handle_count as u32;
221    if result >= base && result < end {
222        return Ok(Some((result - base) as usize));
223    }
224
225    Err(wait_error(operation))
226}
227
/// Convert a [`Duration`] into the millisecond timeout used by the Win32 wait APIs.
///
/// Sub-millisecond remainders are rounded *up*, so a non-zero timeout never degrades
/// into a zero-length poll (`Duration::ZERO` still maps to `0`, i.e. a poll). Values
/// that do not fit in a `u32` saturate to `u32::MAX`, which Win32 interprets as
/// `INFINITE`.
fn duration_to_wait_ms(timeout: Duration) -> u32 {
    let mut millis = timeout.as_millis();
    // Truncating e.g. 500µs to 0ms would turn a short wait into a non-blocking poll.
    if timeout.subsec_nanos() % 1_000_000 != 0 {
        millis += 1;
    }
    millis.min(u128::from(u32::MAX)) as u32
}
231
232fn wait_error(operation: &'static str) -> Error {
233    let code = unsafe { GetLastError().0 as i32 };
234    Error::Other(OtherError::new(Cow::Owned(format!(
235        "wait handle operation '{}' failed (error code: 0x{:08X})",
236        operation, code
237    ))))
238}
239
#[cfg(test)]
mod tests {
    use super::Wait;
    use std::thread;
    use std::time::Duration;

    /// Short timeout used wherever a test expects a wait to time out.
    const SHORT: Duration = Duration::from_millis(10);

    #[test]
    fn wait_timeout_reports_timeout_then_signal() {
        let wait = Wait::manual_reset(false).expect("wait handle create");
        let timed_out = wait
            .wait_timeout(SHORT)
            .expect("wait timeout should not fail");
        assert!(!timed_out);

        wait.set().expect("set should succeed");
        let signaled = wait
            .wait_timeout(SHORT)
            .expect("wait timeout should not fail");
        assert!(signaled);
    }

    #[test]
    fn cloned_wait_handle_synchronizes_threads() {
        let wait = Wait::manual_reset(false).expect("wait handle create");
        let signaler = wait.try_clone().expect("clone should succeed");

        let worker = thread::spawn(move || {
            thread::sleep(Duration::from_millis(20));
            signaler.set().expect("set should succeed");
        });

        wait.wait().expect("wait should succeed after signal");
        worker.join().expect("worker should not panic");
    }

    #[test]
    fn wait_any_reports_signaled_index() {
        let wait_a = Wait::manual_reset(false).expect("wait handle create");
        let wait_b = Wait::manual_reset(false).expect("wait handle create");
        wait_b.set().expect("set should succeed");

        let index = Wait::wait_any(&[&wait_a, &wait_b]).expect("wait_any should succeed");
        assert_eq!(index, 1);
    }

    #[test]
    fn wait_any_timeout_returns_none_when_nothing_is_signaled() {
        let wait_a = Wait::manual_reset(false).expect("wait handle create");
        let wait_b = Wait::manual_reset(false).expect("wait handle create");

        let result = Wait::wait_any_timeout(&[&wait_a, &wait_b], SHORT)
            .expect("wait_any_timeout should not fail");
        assert_eq!(result, None);
    }

    #[test]
    fn wait_all_requires_every_handle() {
        let wait_a = Wait::manual_reset(false).expect("wait handle create");
        let wait_b = Wait::manual_reset(false).expect("wait handle create");

        wait_a.set().expect("set should succeed");
        let all_signaled = Wait::wait_all_timeout(&[&wait_a, &wait_b], SHORT)
            .expect("wait_all_timeout should not fail");
        assert!(!all_signaled);

        wait_b.set().expect("set should succeed");
        let all_signaled = Wait::wait_all_timeout(&[&wait_a, &wait_b], SHORT)
            .expect("wait_all_timeout should not fail");
        assert!(all_signaled);

        Wait::wait_all(&[&wait_a, &wait_b]).expect("wait_all should succeed");
    }

    #[test]
    fn manual_reset_stays_signaled_until_reset() {
        let wait = Wait::manual_reset(true).expect("wait handle create");
        assert!(wait.is_signaled().expect("is_signaled should not fail"));
        assert!(wait.is_signaled().expect("is_signaled should not fail"));

        wait.reset().expect("reset should succeed");
        assert!(!wait.is_signaled().expect("is_signaled should not fail"));
    }

    #[test]
    fn auto_reset_signal_is_consumed_by_a_successful_wait() {
        let wait = Wait::auto_reset(true).expect("wait handle create");
        assert!(wait.is_signaled().expect("is_signaled should not fail"));
        // The zero-timeout wait above reset the auto-reset event.
        assert!(!wait.is_signaled().expect("is_signaled should not fail"));
    }

    #[test]
    fn empty_and_oversized_handle_sets_are_rejected() {
        let none: [&Wait; 0] = [];
        assert!(Wait::wait_any(&none).is_err());
        assert!(Wait::wait_all_timeout(&none, Duration::ZERO).is_err());

        let waits: Vec<Wait> = (0..65)
            .map(|_| Wait::manual_reset(false).expect("wait handle create"))
            .collect();
        let refs: Vec<&Wait> = waits.iter().collect();
        assert!(Wait::wait_any_timeout(&refs, Duration::ZERO).is_err());
    }
}