Skip to main content

singe_cuda/
future.rs

1use std::{
2    future::Future,
3    pin::Pin,
4    sync::{
5        Arc, Mutex,
6        atomic::{AtomicBool, Ordering},
7    },
8    task::{Context as TaskContext, Poll, Waker},
9};
10
11use crate::{
12    error::Result,
13    event::Event,
14    graph::ExecutableGraph,
15    stream::{Stream, StreamScope},
16};
17
/// Completion flag shared between a [`StreamFuture`] and the host function that
/// [`Stream::completion_future`] enqueues on the stream.
#[derive(Debug)]
struct CompletionState {
    /// Set with `Release` ordering once the host function has run.
    complete: AtomicBool,
    /// Waker of the task that most recently polled the pending future, if any.
    waker: Mutex<Option<Waker>>,
}

impl CompletionState {
    /// Creates a state that is not yet complete and has no registered waker.
    fn create() -> Self {
        Self {
            complete: AtomicBool::new(false),
            waker: Mutex::new(None),
        }
    }

    /// Marks the stream point as reached and wakes the parked task, if any.
    ///
    /// The flag is published before the waker slot is inspected. `StreamFuture::poll` stores
    /// its waker and then re-reads the flag, so whichever side runs second observes the other
    /// and no wakeup is lost.
    fn signal(&self) {
        self.complete.store(true, Ordering::Release);
        // Take the waker in its own statement so the guard is dropped before `wake` runs.
        // Inlining this into an `if let` scrutinee would keep the mutex locked for the
        // entire wake call.
        let waker = self
            .waker
            .lock()
            .expect("completion waker poisoned")
            .take();
        if let Some(waker) = waker {
            waker.wake();
        }
    }
}
39
40#[derive(Debug)]
41struct CheckedCompletionState {
42    inner: Mutex<CheckedCompletionInner>,
43}
44
45#[derive(Debug)]
46struct CheckedCompletionInner {
47    result: Option<Result<()>>,
48    waker: Option<Waker>,
49}
50
51impl CheckedCompletionState {
52    fn create() -> Self {
53        Self {
54            inner: Mutex::new(CheckedCompletionInner {
55                result: None,
56                waker: None,
57            }),
58        }
59    }
60
61    fn signal(&self, result: Result<()>) {
62        let waker = {
63            let mut inner = self
64                .inner
65                .lock()
66                .expect("checked completion waker poisoned");
67            inner.result = Some(result);
68            inner.waker.take()
69        };
70
71        if let Some(waker) = waker {
72            waker.wake();
73        }
74    }
75}
76
77#[derive(Debug)]
78pub struct StreamFuture {
79    state: Arc<CompletionState>,
80    _stream: Stream,
81}
82
83impl Future for StreamFuture {
84    type Output = ();
85
86    fn poll(self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<Self::Output> {
87        if self.state.complete.load(Ordering::Acquire) {
88            return Poll::Ready(());
89        }
90
91        {
92            let mut waker = self.state.waker.lock().expect("completion waker poisoned");
93            *waker = Some(cx.waker().clone());
94        }
95
96        if self.state.complete.load(Ordering::Acquire) {
97            Poll::Ready(())
98        } else {
99            Poll::Pending
100        }
101    }
102}
103
104impl Unpin for StreamFuture {}
105
106#[derive(Debug)]
107pub struct CheckedStreamFuture {
108    state: Arc<CheckedCompletionState>,
109    _stream: Stream,
110}
111
112impl Future for CheckedStreamFuture {
113    type Output = Result<()>;
114
115    fn poll(self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<Self::Output> {
116        let mut inner = self
117            .state
118            .inner
119            .lock()
120            .expect("checked completion waker poisoned");
121
122        if let Some(result) = inner.result.take() {
123            return Poll::Ready(result);
124        }
125
126        inner.waker = Some(cx.waker().clone());
127        Poll::Pending
128    }
129}
130
131impl Unpin for CheckedStreamFuture {}
132
133#[derive(Debug)]
134pub struct CudaFuture<T> {
135    completion: CheckedStreamFuture,
136    output: Option<T>,
137}
138
139impl<T> Future for CudaFuture<T> {
140    type Output = Result<T>;
141
142    fn poll(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<Self::Output> {
143        match Pin::new(&mut self.completion).poll(cx) {
144            Poll::Ready(Ok(())) => {
145                let output = self
146                    .output
147                    .take()
148                    .expect("cuda future output already consumed");
149                Poll::Ready(Ok(output))
150            }
151            Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
152            Poll::Pending => Poll::Pending,
153        }
154    }
155}
156
157impl<T> Unpin for CudaFuture<T> {}
158
159impl Stream {
160    /// Returns a future that resolves when a host function enqueued at the current end of this stream runs.
161    ///
162    /// This is a notification primitive.
163    /// It does not report asynchronous CUDA errors after registration.
164    /// Use [`Stream::checked_completion_future`] or [`Stream::synchronize_async`] when the result must include CUDA status.
165    pub fn completion_future(&self) -> Result<StreamFuture> {
166        self.ensure_not_capturing_for_future()?;
167
168        let state = Arc::new(CompletionState::create());
169        let callback_state = Arc::clone(&state);
170        self.launch_host_func(move || callback_state.signal())?;
171
172        Ok(StreamFuture {
173            state,
174            _stream: self.clone(),
175        })
176    }
177
178    /// Returns a future that resolves with CUDA's asynchronous stream status.
179    ///
180    /// This uses CUDA's stream callback status path and is therefore rejected while stream capture is active.
181    pub fn checked_completion_future(&self) -> Result<CheckedStreamFuture> {
182        self.ensure_not_capturing_for_future()?;
183
184        let state = Arc::new(CheckedCompletionState::create());
185        let callback_state = Arc::clone(&state);
186        self.add_callback(move |result| callback_state.signal(result))?;
187
188        Ok(CheckedStreamFuture {
189            state,
190            _stream: self.clone(),
191        })
192    }
193
194    pub async fn synchronize_async(&self) -> Result<()> {
195        self.checked_completion_future()?.await
196    }
197
198    pub fn enqueue_async<T, F>(&self, f: F) -> Result<CudaFuture<T>>
199    where
200        F: FnOnce(&Stream) -> Result<T>,
201    {
202        let output = f(self)?;
203        let completion = self.checked_completion_future()?;
204        Ok(CudaFuture {
205            completion,
206            output: Some(output),
207        })
208    }
209}
210
211impl<'scope, 'env> StreamScope<'scope, 'env> {
212    pub fn completion_future(&self) -> Result<StreamFuture> {
213        self.stream().completion_future()
214    }
215
216    pub fn checked_completion_future(&self) -> Result<CheckedStreamFuture> {
217        self.stream().checked_completion_future()
218    }
219
220    pub async fn synchronize_async(&self) -> Result<()> {
221        self.stream().synchronize_async().await
222    }
223}
224
225impl Event {
226    pub fn completion_future_on(&self, stream: &Stream) -> Result<StreamFuture> {
227        stream.wait_event(self)?;
228        stream.completion_future()
229    }
230
231    pub fn checked_completion_future_on(&self, stream: &Stream) -> Result<CheckedStreamFuture> {
232        stream.wait_event(self)?;
233        stream.checked_completion_future()
234    }
235
236    pub async fn synchronize_async_on(&self, stream: &Stream) -> Result<()> {
237        self.checked_completion_future_on(stream)?.await
238    }
239}
240
241impl ExecutableGraph {
242    pub fn launch_async(&self, stream: &Stream) -> Result<CheckedStreamFuture> {
243        self.launch(stream)?;
244        stream.checked_completion_future()
245    }
246
247    pub async fn launch_and_wait(&self, stream: &Stream) -> Result<()> {
248        self.launch_async(stream)?.await
249    }
250
251    pub fn upload_async(&self, stream: &Stream) -> Result<CheckedStreamFuture> {
252        self.upload(stream)?;
253        stream.checked_completion_future()
254    }
255
256    pub async fn upload_and_wait(&self, stream: &Stream) -> Result<()> {
257        self.upload_async(stream)?.await
258    }
259}
260
#[cfg(all(test, feature = "testing"))]
mod tests {
    use crate::{
        error::{Error, Result, Status},
        event::EventRecordFlags,
        memory::DeviceMemory,
        stream::StreamCaptureMode,
        testing,
    };

    /// Every completion flavour resolves on a stream with no pending work.
    #[tokio::test]
    async fn stream_future_resolves_after_empty_stream() -> Result<()> {
        let (_lock, ctx) = testing::bootstrap()?;
        let stream = ctx.create_stream()?;

        stream.completion_future()?.await;
        stream.checked_completion_future()?.await?;
        stream.synchronize_async().await?;

        Ok(())
    }

    /// `synchronize_async` resolves only after a previously enqueued memset is visible.
    #[tokio::test]
    async fn stream_synchronize_async_waits_for_memset() -> Result<()> {
        let (_lock, ctx) = testing::bootstrap()?;
        let stream = ctx.create_stream()?;
        let mut device = DeviceMemory::<u8>::zeroes(16)?;

        unsafe {
            device.set_value_async_unchecked(7, &stream)?;
        }
        stream.synchronize_async().await?;

        let mut host = vec![0; 16];
        device.copy_to_host(&mut host)?;
        assert_eq!(host, vec![7; 16]);

        Ok(())
    }

    /// Dropping a future before it resolves leaves already-enqueued work intact.
    #[tokio::test]
    async fn dropping_pending_future_does_not_cancel_stream_work() -> Result<()> {
        let (_lock, ctx) = testing::bootstrap()?;
        let stream = ctx.create_stream()?;
        let mut device = DeviceMemory::<u8>::zeroes(8)?;

        unsafe {
            device.set_value_async_unchecked(11, &stream)?;
        }
        let future = stream.checked_completion_future()?;
        drop(future);

        stream.synchronize()?;
        let mut host = vec![0; 8];
        device.copy_to_host(&mut host)?;
        assert_eq!(host, vec![11; 8]);

        Ok(())
    }

    /// Registering a completion future is refused with `StreamCaptureUnsupported` during capture.
    #[tokio::test]
    async fn stream_future_registration_is_rejected_during_capture() -> Result<()> {
        let (_lock, ctx) = testing::bootstrap()?;
        let stream = ctx.create_stream()?;

        stream.begin_capture(StreamCaptureMode::Relaxed)?;
        let error = stream.completion_future().unwrap_err();
        // End the capture unconditionally; its outcome is irrelevant to this test.
        drop(stream.end_capture());

        assert!(matches!(
            error,
            Error::Cuda {
                code: Status::StreamCaptureUnsupported,
                ..
            }
        ));

        Ok(())
    }

    /// Waiting on an event from another stream orders the work recorded before that event.
    #[tokio::test]
    async fn event_future_orders_work_across_streams() -> Result<()> {
        let (_lock, ctx) = testing::bootstrap()?;
        let stream_a = ctx.create_stream()?;
        let stream_b = ctx.create_stream()?;
        let event = ctx.create_event()?;
        let mut device = DeviceMemory::<u8>::zeroes(4)?;

        unsafe {
            device.set_value_async_unchecked(5, &stream_a)?;
        }
        event.record(&stream_a, EventRecordFlags::DEFAULT)?;
        event.synchronize_async_on(&stream_b).await?;

        let mut host = vec![0; 4];
        device.copy_to_host(&mut host)?;
        assert_eq!(host, vec![5; 4]);

        Ok(())
    }

    /// `launch_async` on a trivial graph yields a future that completes successfully.
    #[tokio::test]
    async fn graph_launch_async_waits_for_launch_completion() -> Result<()> {
        let (_lock, ctx) = testing::bootstrap()?;
        let stream = ctx.create_stream()?;
        let mut graph = ctx.create_graph()?;

        graph.add_empty_node(&[])?;
        let executable = graph.instantiate()?;
        // Propagate instead of panicking, like every other fallible call in these tests.
        executable.launch_async(&stream)?.await?;

        Ok(())
    }
}