1use super::{CudaError, CudaEvent};
2use slop_alloc::mem::{CopyDirection, CopyError, DeviceMemory};
3use slop_alloc::{AllocError, Allocator};
4use sp1_gpu_sys::runtime::{
5 cuda_event_record, cuda_free_async, cuda_launch_host_function, cuda_malloc_async,
6 cuda_mem_copy_device_to_device_async, cuda_mem_copy_device_to_host_async,
7 cuda_mem_copy_host_to_device_async, cuda_mem_set_async, cuda_stream_create,
8 cuda_stream_destroy, cuda_stream_query, cuda_stream_synchronize, cuda_stream_wait_event,
9 CudaStreamHandle, Dim3, KernelPtr, DEFAULT_STREAM,
10};
11use std::{
12 alloc::Layout,
13 ffi::c_void,
14 future::{Future, IntoFuture},
15 ops::Deref,
16 pin::Pin,
17 ptr::{self, NonNull},
18 sync::{Arc, Mutex},
19 task::{Context, Poll, Waker},
20 time::Duration,
21};
22use tokio::time::Interval;
23
/// Period, in milliseconds, of the timer that `StreamCallbackFuture::poll` uses to re-poll a
/// still-pending future, as a fallback alongside the wake-up from the stream's host callback.
pub(crate) const INTERVAL_MS: u64 = 2_000;
25
26#[derive(Debug, PartialEq, Eq, Hash)]
27#[repr(transparent)]
28pub struct CudaStream(pub(crate) CudaStreamHandle);
29
30unsafe impl Send for CudaStream {}
31unsafe impl Sync for CudaStream {}
32
33impl Drop for CudaStream {
34 fn drop(&mut self) {
35 if self.0 != unsafe { DEFAULT_STREAM } {
36 CudaError::result_from_ffi(unsafe { cuda_stream_destroy(self.0) }).unwrap();
38 }
39 }
40}
41
42impl CudaStream {
43 #[inline]
44 pub(crate) fn create() -> Result<Self, CudaError> {
45 let mut ptr = CudaStreamHandle(ptr::null_mut());
46 CudaError::result_from_ffi(unsafe {
47 cuda_stream_create(&mut ptr as *mut CudaStreamHandle)
48 })?;
49 Ok(Self(ptr))
50 }
51
52 #[inline]
56 unsafe fn launch_host_fn(
57 &self,
58 host_fn: Option<unsafe extern "C" fn(*mut c_void)>,
59 data: *const c_void,
60 ) -> Result<(), CudaError> {
61 CudaError::result_from_ffi(unsafe { cuda_launch_host_function(self.0, host_fn, data) })
62 }
63
64 #[inline]
70 pub unsafe fn launch_kernel(
71 &self,
72 kernel: KernelPtr,
73 grid_dim: impl Into<Dim3>,
74 block_dim: impl Into<Dim3>,
75 args: &[*mut c_void],
76 shared_mem: usize,
77 ) -> Result<(), CudaError> {
78 CudaError::result_from_ffi(sp1_gpu_sys::runtime::cuda_launch_kernel(
79 kernel,
80 grid_dim.into(),
81 block_dim.into(),
82 args.as_ptr() as *mut *mut c_void,
83 shared_mem,
84 self.0,
85 ))
86 }
87
88 #[inline]
89 fn query(&self) -> Result<(), CudaError> {
90 CudaError::result_from_ffi(unsafe { cuda_stream_query(self.0) })
91 }
92
93 #[inline]
94 fn record(&self, event: &CudaEvent) -> Result<(), CudaError> {
95 CudaError::result_from_ffi(unsafe { cuda_event_record(event.0, self.0) })
96 }
97
98 #[inline]
103 unsafe fn wait(&self, event: &CudaEvent) -> Result<(), CudaError> {
104 CudaError::result_from_ffi(cuda_stream_wait_event(self.0, event.0))
105 }
106
107 #[inline]
108 fn synchronize(&self) -> Result<(), CudaError> {
109 CudaError::result_from_ffi(unsafe { cuda_stream_synchronize(self.0) })
110 }
111}
112
113impl Default for CudaStream {
114 fn default() -> Self {
115 Self(unsafe { DEFAULT_STREAM })
116 }
117}
118
/// State shared between a `StreamCallbackFuture` and the host callback it enqueues on the
/// stream.
struct CallbackState<S> {
    /// The stream owner; `None` only until `StreamCallbackFuture::new` stores it.
    task: Option<S>,
    /// Set once completion, or a launch/query error, has been observed.
    done: bool,
    /// Outcome handed back by `poll` once `done` is set.
    result: Result<(), CudaError>,
    /// Waker of the task that most recently polled the future while it was pending.
    waker: Option<Waker>,
}
127
/// Future that resolves once the stream reports completion, detected either by the host
/// callback enqueued in `new` or by querying the stream from `poll`.
pub struct StreamCallbackFuture<S> {
    /// State shared with the enqueued host callback.
    shared: Arc<Mutex<CallbackState<S>>>,
    /// Periodic timer that re-wakes the task while the stream is still busy.
    interval: Pin<Box<Interval>>,
}
137
138pub trait StreamRef {
148 unsafe fn stream(&self) -> &CudaStream;
149
150 #[inline]
154 unsafe fn launch_host_fn_uncheked(
155 &self,
156 host_fn: Option<unsafe extern "C" fn(*mut c_void)>,
157 data: *const c_void,
158 ) -> Result<(), CudaError> {
159 self.stream().launch_host_fn(host_fn, data)
160 }
161
162 #[inline]
163 unsafe fn query(&self) -> Result<(), CudaError> {
164 self.stream().query()
165 }
166
167 #[inline]
168 unsafe fn record_unchecked(&self, event: &CudaEvent) -> Result<(), CudaError> {
169 self.stream().record(event)
170 }
171
172 #[inline]
177 unsafe fn wait_unchecked(&self, event: &CudaEvent) -> Result<(), CudaError> {
178 self.stream().wait(event)
179 }
180
181 #[inline]
182 unsafe fn stream_synchronize(&self) -> Result<(), CudaError> {
183 self.stream().synchronize()
184 }
185}
186
187impl StreamRef for CudaStream {
188 #[inline]
189 unsafe fn stream(&self) -> &CudaStream {
190 self
191 }
192}
193
194impl<S> StreamRef for Arc<S>
195where
196 S: StreamRef + ?Sized,
197{
198 #[inline]
199 unsafe fn stream(&self) -> &CudaStream {
200 self.as_ref().stream()
201 }
202}
203
204impl<S> StreamCallbackFuture<S> {
205 pub fn new(task: S) -> Self
208 where
209 S: StreamRef,
210 {
211 let shared = Arc::new(Mutex::new(CallbackState {
213 task: None,
214 done: false,
215 result: Ok(()),
216 waker: None,
217 }));
218
219 let ptr = Arc::into_raw(shared.clone()) as *mut c_void;
222
223 let launch_result = unsafe { task.stream().launch_host_fn(Some(waker_callback::<S>), ptr) };
227
228 shared.lock().unwrap().task = Some(task);
229
230 if let Err(e) = launch_result {
231 let mut state = shared.lock().unwrap();
232 state.result = Err(e);
233 state.done = true;
234 }
235
236 let interval = Box::pin(tokio::time::interval(Duration::from_millis(INTERVAL_MS)));
237
238 Self { shared, interval }
239 }
240}
241
242unsafe extern "C" fn waker_callback<S>(user_data: *mut c_void)
243where
244 S: StreamRef,
245{
246 let shared = Arc::<Mutex<CallbackState<S>>>::from_raw(user_data as *const _);
248 let mut state = shared.lock().unwrap();
249
250 state.done = true;
252
253 if let Some(ref waker) = state.waker {
255 waker.wake_by_ref();
256 }
257}
258
259impl<S> Future for StreamCallbackFuture<S>
260where
261 S: StreamRef,
262{
263 type Output = Result<(), CudaError>;
264
265 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
266 let mut state = self.shared.lock().unwrap();
267
268 if state.done {
270 return Poll::Ready(state.result);
272 }
273
274 match unsafe { state.task.as_ref().unwrap().stream().query() } {
276 Ok(()) => {
277 state.done = true;
278 state.result = Ok(());
279 return Poll::Ready(Ok(()));
280 }
281 Err(CudaError::NotReady) => {
282 }
284 Err(e) => {
285 state.done = true;
287 state.result = Err(e);
288 return Poll::Ready(Err(e));
289 }
290 }
291
292 state.waker = Some(cx.waker().clone());
294 drop(state);
295
296 match self.interval.as_mut().poll_tick(cx) {
298 Poll::Ready(_) => {
299 cx.waker().wake_by_ref();
301 Poll::Pending
302 }
303 Poll::Pending => {
304 Poll::Pending
306 }
307 }
308 }
309}
310
311impl IntoFuture for CudaStream {
312 type Output = Result<(), CudaError>;
313 type IntoFuture = StreamCallbackFuture<Self>;
314
315 fn into_future(self) -> Self::IntoFuture {
316 StreamCallbackFuture::new(self)
317 }
318}
319
320unsafe impl Allocator for CudaStream {
321 #[inline]
322 unsafe fn allocate(&self, layout: Layout) -> Result<ptr::NonNull<[u8]>, AllocError> {
323 let mut ptr: *mut c_void = ptr::null_mut();
324 unsafe {
325 CudaError::result_from_ffi(cuda_malloc_async(
326 &mut ptr as *mut *mut c_void,
327 layout.size(),
328 self.0,
329 ))
330 .map_err(|_| AllocError)?;
331 };
332 let ptr = ptr as *mut u8;
333 Ok(NonNull::slice_from_raw_parts(NonNull::new_unchecked(ptr), layout.size()))
334 }
335
336 #[inline]
337 unsafe fn deallocate(&self, ptr: NonNull<u8>, _layout: Layout) {
338 unsafe {
339 CudaError::result_from_ffi(cuda_free_async(ptr.as_ptr() as *mut c_void, self.0))
340 .unwrap()
341 }
342 }
343}
344
345impl DeviceMemory for CudaStream {
346 #[inline]
347 unsafe fn copy_nonoverlapping(
348 &self,
349 src: *const u8,
350 dst: *mut u8,
351 size: usize,
352 direction: CopyDirection,
353 ) -> Result<(), CopyError> {
354 let maybe_err = match direction {
355 CopyDirection::HostToDevice => cuda_mem_copy_host_to_device_async(
356 dst as *mut c_void,
357 src as *const c_void,
358 size,
359 self.0,
360 ),
361 CopyDirection::DeviceToHost => cuda_mem_copy_device_to_host_async(
362 dst as *mut c_void,
363 src as *const c_void,
364 size,
365 self.0,
366 ),
367 CopyDirection::DeviceToDevice => cuda_mem_copy_device_to_device_async(
368 dst as *mut c_void,
369 src as *const c_void,
370 size,
371 self.0,
372 ),
373 };
374 CudaError::result_from_ffi(maybe_err).map_err(|_| CopyError)
375 }
376
377 #[inline]
378 unsafe fn write_bytes(&self, dst: *mut u8, value: u8, size: usize) -> Result<(), CopyError> {
379 unsafe {
380 CudaError::result_from_ffi(cuda_mem_set_async(dst as *mut c_void, value, size, self.0))
381 .map_err(|_| CopyError)
382 }
383 }
384}
385
386#[derive(Debug, PartialEq, Eq, Hash)]
387pub struct UnsafeCudaStream(CudaStream);
388
389impl UnsafeCudaStream {
390 #[allow(dead_code)]
391 pub fn create() -> Result<Self, CudaError> {
392 Ok(Self(CudaStream::create()?))
393 }
394}
395
396impl Deref for UnsafeCudaStream {
397 type Target = CudaStream;
398
399 fn deref(&self) -> &Self::Target {
400 &self.0
401 }
402}
403
404impl StreamRef for UnsafeCudaStream {
405 #[inline]
406 unsafe fn stream(&self) -> &CudaStream {
407 &self.0
408 }
409}
410
411impl IntoFuture for UnsafeCudaStream {
412 type Output = Result<(), CudaError>;
413 type IntoFuture = StreamCallbackFuture<Self>;
414
415 fn into_future(self) -> Self::IntoFuture {
416 StreamCallbackFuture::new(self)
417 }
418}
419
420unsafe impl Allocator for UnsafeCudaStream {
421 #[inline]
422 unsafe fn allocate(&self, layout: Layout) -> Result<ptr::NonNull<[u8]>, AllocError> {
423 self.0.allocate(layout)
424 }
425
426 #[inline]
427 unsafe fn deallocate(&self, ptr: NonNull<u8>, layout: Layout) {
428 self.0.deallocate(ptr, layout)
429 }
430}
431
432impl DeviceMemory for UnsafeCudaStream {
433 #[inline]
434 unsafe fn copy_nonoverlapping(
435 &self,
436 src: *const u8,
437 dst: *mut u8,
438 size: usize,
439 direction: CopyDirection,
440 ) -> Result<(), CopyError> {
441 self.0.copy_nonoverlapping(src, dst, size, direction)
442 }
443
444 #[inline]
445 unsafe fn write_bytes(&self, dst: *mut u8, value: u8, size: usize) -> Result<(), CopyError> {
446 self.0.write_bytes(dst, value, size)
447 }
448}