windows_future/
async_spawn.rs1use super::*;
2use std::ffi::c_void;
3use std::sync::Mutex;
4
5struct State<T: Async> {
6 result: Option<Result<T::Output>>,
7 completed: Option<T::CompletedHandler>,
8 completed_assigned: bool,
9}
10
11impl<T: Async> State<T> {
12 fn status(&self) -> AsyncStatus {
13 match &self.result {
14 None => AsyncStatus::Started,
15 Some(Ok(_)) => AsyncStatus::Completed,
16 Some(Err(_)) => AsyncStatus::Error,
17 }
18 }
19
20 fn error_code(&self) -> HRESULT {
21 match &self.result {
22 Some(Err(error)) => error.code(),
23 _ => HRESULT(0),
24 }
25 }
26
27 fn get_results(&self) -> Result<T::Output> {
28 match &self.result {
29 Some(result) => result.clone(),
30 None => Err(Error::from_hresult(HRESULT(0x8000000Eu32 as i32))), }
32 }
33}
34
35struct SyncState<T: Async>(Mutex<State<T>>);
36
37impl<T: Async> SyncState<T> {
38 fn new() -> Self {
39 Self(Mutex::new(State {
40 result: None,
41 completed: None,
42 completed_assigned: false,
43 }))
44 }
45
46 fn status(&self) -> AsyncStatus {
47 self.0.lock().unwrap().status()
48 }
49
50 fn error_code(&self) -> HRESULT {
51 self.0.lock().unwrap().error_code()
52 }
53
54 fn get_results(&self) -> Result<T::Output> {
55 self.0.lock().unwrap().get_results()
56 }
57
58 fn set_completed(&self, sender: &T, handler: Ref<'_, T::CompletedHandler>) -> Result<()> {
59 let mut guard = self.0.lock().unwrap();
60
61 if guard.completed_assigned {
62 Err(Error::from_hresult(HRESULT(0x80000018u32 as i32))) } else {
64 guard.completed_assigned = true;
65 let status = guard.status();
66 let handler = handler.ok()?;
67
68 if status == AsyncStatus::Started {
69 guard.completed = Some(handler.clone());
70 } else {
71 drop(guard);
72 sender.invoke_completed(handler, status);
73 }
74
75 Ok(())
76 }
77 }
78
79 fn spawn<F>(&self, sender: &T, f: F)
80 where
81 F: FnOnce() -> Result<T::Output> + Send + 'static,
82 {
83 let result = f();
84 let mut guard = self.0.lock().unwrap();
85 debug_assert!(guard.result.is_none());
86 guard.result = Some(result);
87 let status = guard.status();
88 let completed = guard.completed.take();
89
90 drop(guard);
91
92 if let Some(completed) = completed {
93 sender.invoke_completed(&completed, status);
94 }
95 }
96}
97
98unsafe impl<T: Async> Send for SyncState<T> {}
99
100#[implement(IAsyncAction, IAsyncInfo)]
101struct Action(SyncState<IAsyncAction>);
102
103#[implement(IAsyncOperation<T>, IAsyncInfo)]
104struct Operation<T>(SyncState<IAsyncOperation<T>>)
105where
106 T: RuntimeType + 'static;
107
108#[implement(IAsyncActionWithProgress<P>, IAsyncInfo)]
109struct ActionWithProgress<P>(SyncState<IAsyncActionWithProgress<P>>)
110where
111 P: RuntimeType + 'static;
112
113#[implement(IAsyncOperationWithProgress<T, P>, IAsyncInfo)]
114struct OperationWithProgress<T, P>(SyncState<IAsyncOperationWithProgress<T, P>>)
115where
116 T: RuntimeType + 'static,
117 P: RuntimeType + 'static;
118
119impl IAsyncInfo_Impl for Action_Impl {
120 fn Id(&self) -> Result<u32> {
121 Ok(1)
122 }
123 fn Status(&self) -> Result<AsyncStatus> {
124 Ok(self.0.status())
125 }
126 fn ErrorCode(&self) -> Result<HRESULT> {
127 Ok(self.0.error_code())
128 }
129 fn Cancel(&self) -> Result<()> {
130 Ok(())
131 }
132 fn Close(&self) -> Result<()> {
133 Ok(())
134 }
135}
136
137impl<T: RuntimeType> IAsyncInfo_Impl for Operation_Impl<T> {
138 fn Id(&self) -> Result<u32> {
139 Ok(1)
140 }
141 fn Status(&self) -> Result<AsyncStatus> {
142 Ok(self.0.status())
143 }
144 fn ErrorCode(&self) -> Result<HRESULT> {
145 Ok(self.0.error_code())
146 }
147 fn Cancel(&self) -> Result<()> {
148 Ok(())
149 }
150 fn Close(&self) -> Result<()> {
151 Ok(())
152 }
153}
154
155impl<P: RuntimeType> IAsyncInfo_Impl for ActionWithProgress_Impl<P> {
156 fn Id(&self) -> Result<u32> {
157 Ok(1)
158 }
159 fn Status(&self) -> Result<AsyncStatus> {
160 Ok(self.0.status())
161 }
162 fn ErrorCode(&self) -> Result<HRESULT> {
163 Ok(self.0.error_code())
164 }
165 fn Cancel(&self) -> Result<()> {
166 Ok(())
167 }
168 fn Close(&self) -> Result<()> {
169 Ok(())
170 }
171}
172
173impl<T: RuntimeType, P: RuntimeType> IAsyncInfo_Impl for OperationWithProgress_Impl<T, P> {
174 fn Id(&self) -> Result<u32> {
175 Ok(1)
176 }
177 fn Status(&self) -> Result<AsyncStatus> {
178 Ok(self.0.status())
179 }
180 fn ErrorCode(&self) -> Result<HRESULT> {
181 Ok(self.0.error_code())
182 }
183 fn Cancel(&self) -> Result<()> {
184 Ok(())
185 }
186 fn Close(&self) -> Result<()> {
187 Ok(())
188 }
189}
190
191impl IAsyncAction_Impl for Action_Impl {
192 fn SetCompleted(&self, handler: Ref<'_, AsyncActionCompletedHandler>) -> Result<()> {
193 self.0.set_completed(&self.as_interface(), handler)
194 }
195 fn Completed(&self) -> Result<AsyncActionCompletedHandler> {
196 Err(Error::empty())
197 }
198 fn GetResults(&self) -> Result<()> {
199 self.0.get_results()
200 }
201}
202
203impl<T: RuntimeType> IAsyncOperation_Impl<T> for Operation_Impl<T> {
204 fn SetCompleted(&self, handler: Ref<'_, AsyncOperationCompletedHandler<T>>) -> Result<()> {
205 self.0.set_completed(&self.as_interface(), handler)
206 }
207 fn Completed(&self) -> Result<AsyncOperationCompletedHandler<T>> {
208 Err(Error::empty())
209 }
210 fn GetResults(&self) -> Result<T> {
211 self.0.get_results()
212 }
213}
214
215impl<P: RuntimeType> IAsyncActionWithProgress_Impl<P> for ActionWithProgress_Impl<P> {
216 fn SetCompleted(
217 &self,
218 handler: Ref<'_, AsyncActionWithProgressCompletedHandler<P>>,
219 ) -> Result<()> {
220 self.0.set_completed(&self.as_interface(), handler)
221 }
222 fn Completed(&self) -> Result<AsyncActionWithProgressCompletedHandler<P>> {
223 Err(Error::empty())
224 }
225 fn GetResults(&self) -> Result<()> {
226 self.0.get_results()
227 }
228 fn SetProgress(&self, _: Ref<'_, AsyncActionProgressHandler<P>>) -> Result<()> {
229 Ok(())
230 }
231 fn Progress(&self) -> Result<AsyncActionProgressHandler<P>> {
232 Err(Error::empty())
233 }
234}
235
236impl<T: RuntimeType, P: RuntimeType> IAsyncOperationWithProgress_Impl<T, P>
237 for OperationWithProgress_Impl<T, P>
238{
239 fn SetCompleted(
240 &self,
241 handler: Ref<'_, AsyncOperationWithProgressCompletedHandler<T, P>>,
242 ) -> Result<()> {
243 self.0.set_completed(&self.as_interface(), handler)
244 }
245 fn Completed(&self) -> Result<AsyncOperationWithProgressCompletedHandler<T, P>> {
246 Err(Error::empty())
247 }
248 fn GetResults(&self) -> Result<T> {
249 self.0.get_results()
250 }
251 fn SetProgress(&self, _: Ref<'_, AsyncOperationProgressHandler<T, P>>) -> Result<()> {
252 Ok(())
253 }
254 fn Progress(&self) -> Result<AsyncOperationProgressHandler<T, P>> {
255 Err(Error::empty())
256 }
257}
258
259impl IAsyncAction {
260 pub fn spawn<F>(f: F) -> Self
262 where
263 F: FnOnce() -> Result<()> + Send + 'static,
264 {
265 let object = ComObject::new(Action(SyncState::new()));
266 let interface = object.to_interface();
267
268 spawn(move || {
269 object.0.spawn(&object.as_interface(), f);
270 });
271
272 interface
273 }
274}
275
276impl<T: RuntimeType> IAsyncOperation<T> {
277 pub fn spawn<F>(f: F) -> Self
279 where
280 F: FnOnce() -> Result<T> + Send + 'static,
281 {
282 let object = ComObject::new(Operation(SyncState::new()));
283 let interface = object.to_interface();
284
285 spawn(move || {
286 object.0.spawn(&object.as_interface(), f);
287 });
288
289 interface
290 }
291}
292
293impl<P: RuntimeType> IAsyncActionWithProgress<P> {
294 pub fn spawn<F>(f: F) -> Self
296 where
297 F: FnOnce() -> Result<()> + Send + 'static,
298 {
299 let object = ComObject::new(ActionWithProgress(SyncState::new()));
300 let interface = object.to_interface();
301
302 spawn(move || {
303 object.0.spawn(&object.as_interface(), f);
304 });
305
306 interface
307 }
308}
309
310impl<T: RuntimeType, P: RuntimeType> IAsyncOperationWithProgress<T, P> {
311 pub fn spawn<F>(f: F) -> Self
313 where
314 F: FnOnce() -> Result<T> + Send + 'static,
315 {
316 let object = ComObject::new(OperationWithProgress(SyncState::new()));
317 let interface = object.to_interface();
318
319 spawn(move || {
320 object.0.spawn(&object.as_interface(), f);
321 });
322
323 interface
324 }
325}
326
327fn spawn<F: FnOnce() + Send + 'static>(f: F) {
328 type PTP_SIMPLE_CALLBACK =
329 unsafe extern "system" fn(instance: *const c_void, context: *const c_void);
330 windows_link::link!("kernel32.dll" "system" fn TrySubmitThreadpoolCallback(callback: PTP_SIMPLE_CALLBACK, context: *const c_void, environment: *const c_void) -> i32);
331
332 unsafe extern "system" fn callback<F: FnOnce() + Send + 'static>(
333 _: *const c_void,
334 callback: *const c_void,
335 ) {
336 unsafe {
337 Box::from_raw(callback as *mut F)();
338 }
339 }
340
341 unsafe {
342 if TrySubmitThreadpoolCallback(
343 callback::<F>,
344 Box::into_raw(Box::new(f)) as _,
345 core::ptr::null(),
346 ) == 0
347 {
348 panic!("allocation failed");
349 }
350 }
351}