1use std::{
2 future::Future,
3 pin::Pin,
4 sync::{
5 Arc, Mutex,
6 atomic::{AtomicBool, Ordering},
7 },
8 task::{Context as TaskContext, Poll, Waker},
9};
10
11use crate::{
12 error::Result,
13 event::Event,
14 graph::ExecutableGraph,
15 stream::{Stream, StreamScope},
16};
17
/// Shared one-shot completion flag used by [`StreamFuture`].
///
/// One side calls [`CompletionState::signal`] when the work is done; the other
/// side polls `complete` and parks its waker in `waker` while it is still
/// `false`.
#[derive(Debug)]
struct CompletionState {
    /// Set to `true` by [`CompletionState::signal`]; nothing in this file resets it.
    complete: AtomicBool,
    /// Waker of the task that most recently polled while `complete` was `false`.
    waker: Mutex<Option<Waker>>,
}

impl CompletionState {
    /// Creates a state that is not complete and has no registered waker.
    fn create() -> Self {
        Self {
            complete: AtomicBool::new(false),
            waker: Mutex::new(None),
        }
    }

    /// Marks the state complete and wakes the registered task, if there is one.
    ///
    /// The flag is stored with `Release` ordering before the waker slot is
    /// inspected, pairing with the `Acquire` loads in `StreamFuture::poll`.
    ///
    /// # Panics
    ///
    /// Panics if the waker mutex is poisoned.
    fn signal(&self) {
        self.complete.store(true, Ordering::Release);

        // Take the waker in its own statement so the mutex guard is dropped
        // *before* `wake()` runs. Writing this as
        // `if let Some(w) = self.waker.lock()...take() { w.wake() }` keeps the
        // guard alive for the whole `if let` body, which means arbitrary waker
        // code would run under the lock: a waker that polls the future inline
        // would deadlock on this (non-reentrant) mutex, and concurrent pollers
        // would be blocked for the duration of the wake.
        let waker = self
            .waker
            .lock()
            .expect("completion waker poisoned")
            .take();

        if let Some(waker) = waker {
            waker.wake();
        }
    }
}
39
40#[derive(Debug)]
41struct CheckedCompletionState {
42 inner: Mutex<CheckedCompletionInner>,
43}
44
45#[derive(Debug)]
46struct CheckedCompletionInner {
47 result: Option<Result<()>>,
48 waker: Option<Waker>,
49}
50
51impl CheckedCompletionState {
52 fn create() -> Self {
53 Self {
54 inner: Mutex::new(CheckedCompletionInner {
55 result: None,
56 waker: None,
57 }),
58 }
59 }
60
61 fn signal(&self, result: Result<()>) {
62 let waker = {
63 let mut inner = self
64 .inner
65 .lock()
66 .expect("checked completion waker poisoned");
67 inner.result = Some(result);
68 inner.waker.take()
69 };
70
71 if let Some(waker) = waker {
72 waker.wake();
73 }
74 }
75}
76
77#[derive(Debug)]
78pub struct StreamFuture {
79 state: Arc<CompletionState>,
80 _stream: Stream,
81}
82
83impl Future for StreamFuture {
84 type Output = ();
85
86 fn poll(self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<Self::Output> {
87 if self.state.complete.load(Ordering::Acquire) {
88 return Poll::Ready(());
89 }
90
91 {
92 let mut waker = self.state.waker.lock().expect("completion waker poisoned");
93 *waker = Some(cx.waker().clone());
94 }
95
96 if self.state.complete.load(Ordering::Acquire) {
97 Poll::Ready(())
98 } else {
99 Poll::Pending
100 }
101 }
102}
103
104impl Unpin for StreamFuture {}
105
106#[derive(Debug)]
107pub struct CheckedStreamFuture {
108 state: Arc<CheckedCompletionState>,
109 _stream: Stream,
110}
111
112impl Future for CheckedStreamFuture {
113 type Output = Result<()>;
114
115 fn poll(self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<Self::Output> {
116 let mut inner = self
117 .state
118 .inner
119 .lock()
120 .expect("checked completion waker poisoned");
121
122 if let Some(result) = inner.result.take() {
123 return Poll::Ready(result);
124 }
125
126 inner.waker = Some(cx.waker().clone());
127 Poll::Pending
128 }
129}
130
131impl Unpin for CheckedStreamFuture {}
132
133#[derive(Debug)]
134pub struct CudaFuture<T> {
135 completion: CheckedStreamFuture,
136 output: Option<T>,
137}
138
139impl<T> Future for CudaFuture<T> {
140 type Output = Result<T>;
141
142 fn poll(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<Self::Output> {
143 match Pin::new(&mut self.completion).poll(cx) {
144 Poll::Ready(Ok(())) => {
145 let output = self
146 .output
147 .take()
148 .expect("cuda future output already consumed");
149 Poll::Ready(Ok(output))
150 }
151 Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
152 Poll::Pending => Poll::Pending,
153 }
154 }
155}
156
157impl<T> Unpin for CudaFuture<T> {}
158
159impl Stream {
160 pub fn completion_future(&self) -> Result<StreamFuture> {
166 self.ensure_not_capturing_for_future()?;
167
168 let state = Arc::new(CompletionState::create());
169 let callback_state = Arc::clone(&state);
170 self.launch_host_func(move || callback_state.signal())?;
171
172 Ok(StreamFuture {
173 state,
174 _stream: self.clone(),
175 })
176 }
177
178 pub fn checked_completion_future(&self) -> Result<CheckedStreamFuture> {
182 self.ensure_not_capturing_for_future()?;
183
184 let state = Arc::new(CheckedCompletionState::create());
185 let callback_state = Arc::clone(&state);
186 self.add_callback(move |result| callback_state.signal(result))?;
187
188 Ok(CheckedStreamFuture {
189 state,
190 _stream: self.clone(),
191 })
192 }
193
194 pub async fn synchronize_async(&self) -> Result<()> {
195 self.checked_completion_future()?.await
196 }
197
198 pub fn enqueue_async<T, F>(&self, f: F) -> Result<CudaFuture<T>>
199 where
200 F: FnOnce(&Stream) -> Result<T>,
201 {
202 let output = f(self)?;
203 let completion = self.checked_completion_future()?;
204 Ok(CudaFuture {
205 completion,
206 output: Some(output),
207 })
208 }
209}
210
211impl<'scope, 'env> StreamScope<'scope, 'env> {
212 pub fn completion_future(&self) -> Result<StreamFuture> {
213 self.stream().completion_future()
214 }
215
216 pub fn checked_completion_future(&self) -> Result<CheckedStreamFuture> {
217 self.stream().checked_completion_future()
218 }
219
220 pub async fn synchronize_async(&self) -> Result<()> {
221 self.stream().synchronize_async().await
222 }
223}
224
225impl Event {
226 pub fn completion_future_on(&self, stream: &Stream) -> Result<StreamFuture> {
227 stream.wait_event(self)?;
228 stream.completion_future()
229 }
230
231 pub fn checked_completion_future_on(&self, stream: &Stream) -> Result<CheckedStreamFuture> {
232 stream.wait_event(self)?;
233 stream.checked_completion_future()
234 }
235
236 pub async fn synchronize_async_on(&self, stream: &Stream) -> Result<()> {
237 self.checked_completion_future_on(stream)?.await
238 }
239}
240
241impl ExecutableGraph {
242 pub fn launch_async(&self, stream: &Stream) -> Result<CheckedStreamFuture> {
243 self.launch(stream)?;
244 stream.checked_completion_future()
245 }
246
247 pub async fn launch_and_wait(&self, stream: &Stream) -> Result<()> {
248 self.launch_async(stream)?.await
249 }
250
251 pub fn upload_async(&self, stream: &Stream) -> Result<CheckedStreamFuture> {
252 self.upload(stream)?;
253 stream.checked_completion_future()
254 }
255
256 pub async fn upload_and_wait(&self, stream: &Stream) -> Result<()> {
257 self.upload_async(stream)?.await
258 }
259}
260
#[cfg(all(test, feature = "testing"))]
mod tests {
    use crate::{
        error::{Error, Result, Status},
        event::EventRecordFlags,
        memory::DeviceMemory,
        stream::StreamCaptureMode,
        testing,
    };

    /// All three completion APIs resolve on a stream with no queued work.
    #[tokio::test]
    async fn stream_future_resolves_after_empty_stream() -> Result<()> {
        let (_guard, context) = testing::bootstrap()?;
        let idle = context.create_stream()?;

        idle.completion_future()?.await;
        idle.checked_completion_future()?.await?;
        idle.synchronize_async().await?;

        Ok(())
    }

    /// `synchronize_async` does not return before a queued memset is visible.
    #[tokio::test]
    async fn stream_synchronize_async_waits_for_memset() -> Result<()> {
        let (_guard, context) = testing::bootstrap()?;
        let stream = context.create_stream()?;
        let mut buffer = DeviceMemory::<u8>::zeroes(16)?;

        // NOTE(review): the safety contract of `set_value_async_unchecked` is
        // not visible from this file — confirm against its definition.
        unsafe {
            buffer.set_value_async_unchecked(7, &stream)?;
        }
        stream.synchronize_async().await?;

        let mut readback = vec![0; 16];
        buffer.copy_to_host(&mut readback)?;
        assert_eq!(readback, vec![7; 16]);

        Ok(())
    }

    /// Dropping an un-awaited future leaves the already-queued work intact.
    #[tokio::test]
    async fn dropping_pending_future_does_not_cancel_stream_work() -> Result<()> {
        let (_guard, context) = testing::bootstrap()?;
        let stream = context.create_stream()?;
        let mut buffer = DeviceMemory::<u8>::zeroes(8)?;

        // NOTE(review): the safety contract of `set_value_async_unchecked` is
        // not visible from this file — confirm against its definition.
        unsafe {
            buffer.set_value_async_unchecked(11, &stream)?;
        }
        let abandoned = stream.checked_completion_future()?;
        drop(abandoned);

        stream.synchronize()?;
        let mut readback = vec![0; 8];
        buffer.copy_to_host(&mut readback)?;
        assert_eq!(readback, vec![11; 8]);

        Ok(())
    }

    /// Registering a completion future while the stream is capturing fails
    /// with `StreamCaptureUnsupported`.
    #[tokio::test]
    async fn stream_future_registration_is_rejected_during_capture() -> Result<()> {
        let (_guard, context) = testing::bootstrap()?;
        let stream = context.create_stream()?;

        stream.begin_capture(StreamCaptureMode::Relaxed)?;
        let rejection = stream.completion_future().unwrap_err();
        // End the capture before asserting; its result is deliberately ignored.
        drop(stream.end_capture());

        assert!(matches!(
            rejection,
            Error::Cuda {
                code: Status::StreamCaptureUnsupported,
                ..
            }
        ));

        Ok(())
    }

    /// Awaiting an event on a second stream observes work queued before the
    /// event on the first stream.
    #[tokio::test]
    async fn event_future_orders_work_across_streams() -> Result<()> {
        let (_guard, context) = testing::bootstrap()?;
        let producer = context.create_stream()?;
        let waiter = context.create_stream()?;
        let marker = context.create_event()?;
        let mut buffer = DeviceMemory::<u8>::zeroes(4)?;

        // NOTE(review): the safety contract of `set_value_async_unchecked` is
        // not visible from this file — confirm against its definition.
        unsafe {
            buffer.set_value_async_unchecked(5, &producer)?;
        }
        marker.record(&producer, EventRecordFlags::DEFAULT)?;
        marker.synchronize_async_on(&waiter).await?;

        let mut readback = vec![0; 4];
        buffer.copy_to_host(&mut readback)?;
        assert_eq!(readback, vec![5; 4]);

        Ok(())
    }

    /// `launch_async` on a graph with a single empty node resolves successfully.
    #[tokio::test]
    async fn graph_launch_async_waits_for_launch_completion() -> Result<()> {
        let (_guard, context) = testing::bootstrap()?;
        let stream = context.create_stream()?;
        let mut graph = context.create_graph()?;

        graph.add_empty_node(&[])?;
        let runnable = graph.instantiate()?;
        runnable.launch_async(&stream).unwrap().await?;

        Ok(())
    }
}