// windows_future/async_spawn.rs

1use super::*;
2use std::ffi::c_void;
3use std::sync::Mutex;
4
5struct State<T: Async> {
6    result: Option<Result<T::Output>>,
7    completed: Option<T::CompletedHandler>,
8    completed_assigned: bool,
9}
10
11impl<T: Async> State<T> {
12    fn status(&self) -> AsyncStatus {
13        match &self.result {
14            None => AsyncStatus::Started,
15            Some(Ok(_)) => AsyncStatus::Completed,
16            Some(Err(_)) => AsyncStatus::Error,
17        }
18    }
19
20    fn error_code(&self) -> HRESULT {
21        match &self.result {
22            Some(Err(error)) => error.code(),
23            _ => HRESULT(0),
24        }
25    }
26
27    fn get_results(&self) -> Result<T::Output> {
28        match &self.result {
29            Some(result) => result.clone(),
30            None => Err(Error::from_hresult(HRESULT(0x8000000Eu32 as i32))), // E_ILLEGAL_METHOD_CALL
31        }
32    }
33}
34
35struct SyncState<T: Async>(Mutex<State<T>>);
36
37impl<T: Async> SyncState<T> {
38    fn new() -> Self {
39        Self(Mutex::new(State {
40            result: None,
41            completed: None,
42            completed_assigned: false,
43        }))
44    }
45
46    fn status(&self) -> AsyncStatus {
47        self.0.lock().unwrap().status()
48    }
49
50    fn error_code(&self) -> HRESULT {
51        self.0.lock().unwrap().error_code()
52    }
53
54    fn get_results(&self) -> Result<T::Output> {
55        self.0.lock().unwrap().get_results()
56    }
57
58    fn set_completed(&self, sender: &T, handler: Ref<'_, T::CompletedHandler>) -> Result<()> {
59        let mut guard = self.0.lock().unwrap();
60
61        if guard.completed_assigned {
62            Err(Error::from_hresult(HRESULT(0x80000018u32 as i32))) // E_ILLEGAL_DELEGATE_ASSIGNMENT
63        } else {
64            guard.completed_assigned = true;
65            let status = guard.status();
66            let handler = handler.ok()?;
67
68            if status == AsyncStatus::Started {
69                guard.completed = Some(handler.clone());
70            } else {
71                drop(guard);
72                sender.invoke_completed(handler, status);
73            }
74
75            Ok(())
76        }
77    }
78
79    fn spawn<F>(&self, sender: &T, f: F)
80    where
81        F: FnOnce() -> Result<T::Output> + Send + 'static,
82    {
83        let result = f();
84        let mut guard = self.0.lock().unwrap();
85        debug_assert!(guard.result.is_none());
86        guard.result = Some(result);
87        let status = guard.status();
88        let completed = guard.completed.take();
89
90        drop(guard);
91
92        if let Some(completed) = completed {
93            sender.invoke_completed(&completed, status);
94        }
95    }
96}
97
98unsafe impl<T: Async> Send for SyncState<T> {}
99
// COM object created by `IAsyncAction::spawn`; all state lives in the wrapped `SyncState`.
#[implement(IAsyncAction, IAsyncInfo)]
struct Action(SyncState<IAsyncAction>);

// COM object created by `IAsyncOperation::<T>::spawn`.
// NOTE(review): bounds are written as a `where` clause, possibly because `#[implement]`
// requires that form; confirm before inlining them.
#[implement(IAsyncOperation<T>, IAsyncInfo)]
struct Operation<T>(SyncState<IAsyncOperation<T>>)
where
    T: RuntimeType + 'static;

// COM object created by `IAsyncActionWithProgress::<P>::spawn`; progress is never reported.
#[implement(IAsyncActionWithProgress<P>, IAsyncInfo)]
struct ActionWithProgress<P>(SyncState<IAsyncActionWithProgress<P>>)
where
    P: RuntimeType + 'static;

// COM object created by `IAsyncOperationWithProgress::<T, P>::spawn`; progress is never reported.
#[implement(IAsyncOperationWithProgress<T, P>, IAsyncInfo)]
struct OperationWithProgress<T, P>(SyncState<IAsyncOperationWithProgress<T, P>>)
where
    T: RuntimeType + 'static,
    P: RuntimeType + 'static;
118
119impl IAsyncInfo_Impl for Action_Impl {
120    fn Id(&self) -> Result<u32> {
121        Ok(1)
122    }
123    fn Status(&self) -> Result<AsyncStatus> {
124        Ok(self.0.status())
125    }
126    fn ErrorCode(&self) -> Result<HRESULT> {
127        Ok(self.0.error_code())
128    }
129    fn Cancel(&self) -> Result<()> {
130        Ok(())
131    }
132    fn Close(&self) -> Result<()> {
133        Ok(())
134    }
135}
136
137impl<T: RuntimeType> IAsyncInfo_Impl for Operation_Impl<T> {
138    fn Id(&self) -> Result<u32> {
139        Ok(1)
140    }
141    fn Status(&self) -> Result<AsyncStatus> {
142        Ok(self.0.status())
143    }
144    fn ErrorCode(&self) -> Result<HRESULT> {
145        Ok(self.0.error_code())
146    }
147    fn Cancel(&self) -> Result<()> {
148        Ok(())
149    }
150    fn Close(&self) -> Result<()> {
151        Ok(())
152    }
153}
154
155impl<P: RuntimeType> IAsyncInfo_Impl for ActionWithProgress_Impl<P> {
156    fn Id(&self) -> Result<u32> {
157        Ok(1)
158    }
159    fn Status(&self) -> Result<AsyncStatus> {
160        Ok(self.0.status())
161    }
162    fn ErrorCode(&self) -> Result<HRESULT> {
163        Ok(self.0.error_code())
164    }
165    fn Cancel(&self) -> Result<()> {
166        Ok(())
167    }
168    fn Close(&self) -> Result<()> {
169        Ok(())
170    }
171}
172
173impl<T: RuntimeType, P: RuntimeType> IAsyncInfo_Impl for OperationWithProgress_Impl<T, P> {
174    fn Id(&self) -> Result<u32> {
175        Ok(1)
176    }
177    fn Status(&self) -> Result<AsyncStatus> {
178        Ok(self.0.status())
179    }
180    fn ErrorCode(&self) -> Result<HRESULT> {
181        Ok(self.0.error_code())
182    }
183    fn Cancel(&self) -> Result<()> {
184        Ok(())
185    }
186    fn Close(&self) -> Result<()> {
187        Ok(())
188    }
189}
190
191impl IAsyncAction_Impl for Action_Impl {
192    fn SetCompleted(&self, handler: Ref<'_, AsyncActionCompletedHandler>) -> Result<()> {
193        self.0.set_completed(&self.as_interface(), handler)
194    }
195    fn Completed(&self) -> Result<AsyncActionCompletedHandler> {
196        Err(Error::empty())
197    }
198    fn GetResults(&self) -> Result<()> {
199        self.0.get_results()
200    }
201}
202
203impl<T: RuntimeType> IAsyncOperation_Impl<T> for Operation_Impl<T> {
204    fn SetCompleted(&self, handler: Ref<'_, AsyncOperationCompletedHandler<T>>) -> Result<()> {
205        self.0.set_completed(&self.as_interface(), handler)
206    }
207    fn Completed(&self) -> Result<AsyncOperationCompletedHandler<T>> {
208        Err(Error::empty())
209    }
210    fn GetResults(&self) -> Result<T> {
211        self.0.get_results()
212    }
213}
214
215impl<P: RuntimeType> IAsyncActionWithProgress_Impl<P> for ActionWithProgress_Impl<P> {
216    fn SetCompleted(
217        &self,
218        handler: Ref<'_, AsyncActionWithProgressCompletedHandler<P>>,
219    ) -> Result<()> {
220        self.0.set_completed(&self.as_interface(), handler)
221    }
222    fn Completed(&self) -> Result<AsyncActionWithProgressCompletedHandler<P>> {
223        Err(Error::empty())
224    }
225    fn GetResults(&self) -> Result<()> {
226        self.0.get_results()
227    }
228    fn SetProgress(&self, _: Ref<'_, AsyncActionProgressHandler<P>>) -> Result<()> {
229        Ok(())
230    }
231    fn Progress(&self) -> Result<AsyncActionProgressHandler<P>> {
232        Err(Error::empty())
233    }
234}
235
236impl<T: RuntimeType, P: RuntimeType> IAsyncOperationWithProgress_Impl<T, P>
237    for OperationWithProgress_Impl<T, P>
238{
239    fn SetCompleted(
240        &self,
241        handler: Ref<'_, AsyncOperationWithProgressCompletedHandler<T, P>>,
242    ) -> Result<()> {
243        self.0.set_completed(&self.as_interface(), handler)
244    }
245    fn Completed(&self) -> Result<AsyncOperationWithProgressCompletedHandler<T, P>> {
246        Err(Error::empty())
247    }
248    fn GetResults(&self) -> Result<T> {
249        self.0.get_results()
250    }
251    fn SetProgress(&self, _: Ref<'_, AsyncOperationProgressHandler<T, P>>) -> Result<()> {
252        Ok(())
253    }
254    fn Progress(&self) -> Result<AsyncOperationProgressHandler<T, P>> {
255        Err(Error::empty())
256    }
257}
258
259impl IAsyncAction {
260    /// Creates an `IAsyncAction` that waits for the closure to execute on the Windows thread pool.
261    pub fn spawn<F>(f: F) -> Self
262    where
263        F: FnOnce() -> Result<()> + Send + 'static,
264    {
265        let object = ComObject::new(Action(SyncState::new()));
266        let interface = object.to_interface();
267
268        spawn(move || {
269            object.0.spawn(&object.as_interface(), f);
270        });
271
272        interface
273    }
274}
275
276impl<T: RuntimeType> IAsyncOperation<T> {
277    /// Creates an `IAsyncOperation<T>` that waits for the closure to execute on the Windows thread pool.
278    pub fn spawn<F>(f: F) -> Self
279    where
280        F: FnOnce() -> Result<T> + Send + 'static,
281    {
282        let object = ComObject::new(Operation(SyncState::new()));
283        let interface = object.to_interface();
284
285        spawn(move || {
286            object.0.spawn(&object.as_interface(), f);
287        });
288
289        interface
290    }
291}
292
293impl<P: RuntimeType> IAsyncActionWithProgress<P> {
294    /// Creates an `IAsyncActionWithProgress<P>` that waits for the closure to execute on the Windows thread pool.
295    pub fn spawn<F>(f: F) -> Self
296    where
297        F: FnOnce() -> Result<()> + Send + 'static,
298    {
299        let object = ComObject::new(ActionWithProgress(SyncState::new()));
300        let interface = object.to_interface();
301
302        spawn(move || {
303            object.0.spawn(&object.as_interface(), f);
304        });
305
306        interface
307    }
308}
309
310impl<T: RuntimeType, P: RuntimeType> IAsyncOperationWithProgress<T, P> {
311    /// Creates an `IAsyncOperationWithProgress<T, P>` that waits for the closure to execute on the Windows thread pool.
312    pub fn spawn<F>(f: F) -> Self
313    where
314        F: FnOnce() -> Result<T> + Send + 'static,
315    {
316        let object = ComObject::new(OperationWithProgress(SyncState::new()));
317        let interface = object.to_interface();
318
319        spawn(move || {
320            object.0.spawn(&object.as_interface(), f);
321        });
322
323        interface
324    }
325}
326
327fn spawn<F: FnOnce() + Send + 'static>(f: F) {
328    type PTP_SIMPLE_CALLBACK =
329        unsafe extern "system" fn(instance: *const c_void, context: *const c_void);
330    windows_link::link!("kernel32.dll" "system" fn TrySubmitThreadpoolCallback(callback: PTP_SIMPLE_CALLBACK, context: *const c_void, environment: *const c_void) -> i32);
331
332    unsafe extern "system" fn callback<F: FnOnce() + Send + 'static>(
333        _: *const c_void,
334        callback: *const c_void,
335    ) {
336        unsafe {
337            Box::from_raw(callback as *mut F)();
338        }
339    }
340
341    unsafe {
342        if TrySubmitThreadpoolCallback(
343            callback::<F>,
344            Box::into_raw(Box::new(f)) as _,
345            core::ptr::null(),
346        ) == 0
347        {
348            panic!("allocation failed");
349        }
350    }
351}