1use std::{
2 alloc::Layout,
3 ffi::c_void,
4 future::{Future, IntoFuture},
5 mem::MaybeUninit,
6 ops::Deref,
7 pin::Pin,
8 ptr::{self, NonNull},
9 sync::{
10 atomic::{AtomicUsize, Ordering},
11 Arc, OnceLock, Weak,
12 },
13 task::{Context, Poll},
14 time::Duration,
15};
16
17use futures::{future::MapOkOrElse, TryFutureExt};
18use pin_project::pin_project;
19use slop_alloc::{
20 mem::{CopyDirection, CopyError, DeviceMemory},
21 AllocError, Allocator, Backend, Buffer, Slice,
22};
23use slop_futures::queue::{AcquireWorkerError, TryAcquireWorkerError, Worker, WorkerQueue};
24use sp1_gpu_sys::runtime::{
25 cuda_device_get_default_mem_pool, cuda_mem_pool_set_release_threshold, CudaDevice, CudaMemPool,
26 CudaStreamHandle, Dim3, KernelPtr,
27};
28use thiserror::Error;
29use tokio::{sync::oneshot, task::JoinHandle};
30
31use crate::{DeviceCopy, ToDevice};
32
33use super::{
34 stream::{StreamRef, INTERVAL_MS},
35 sync::CudaSend,
36 CudaError, CudaEvent, CudaStream, IntoDevice, StreamCallbackFuture,
37};
38
39const DEFAULT_NUM_TASKS: usize = 64;
40
41static GLOBAL_TASK_POOL: OnceLock<Arc<TaskPool>> = OnceLock::new();
42
43static POOL_ID: AtomicUsize = AtomicUsize::new(0);
44
45pub struct TaskPoolBuilder {
46 device: CudaDevice,
47 mem_release_threshold: u64,
48 capacity: Option<usize>,
49}
50
51pub(crate) fn global_task_pool() -> &'static Arc<TaskPool> {
52 GLOBAL_TASK_POOL.get_or_init(|| Arc::new(TaskPoolBuilder::new().build().unwrap()))
53}
54
55pub struct SpawnHandle<T> {
56 handle: JoinHandle<Result<T, CudaError>>,
57}
58
59impl<T> SpawnHandle<T> {
60 pub fn abort(&self) {
61 self.handle.abort();
62 }
63}
64
65#[derive(Debug, Error)]
66pub enum SpawnError {
67 #[error("join handle panicked with error: {0}")]
68 JoinError(#[from] tokio::task::JoinError),
69 #[error("cuda error: {0}")]
70 CudaError(#[from] CudaError),
71 #[error("failed to acquire a task from the pool")]
72 TaskSpawnError(#[from] TaskSpawnError),
73}
74
75fn map_ok_value<T>(e: Result<T, CudaError>) -> Result<T, SpawnError> {
76 e.map_err(SpawnError::CudaError)
77}
78
79fn map_err_value<T>(e: tokio::task::JoinError) -> Result<T, SpawnError> {
80 Err(SpawnError::JoinError(e))
81}
82
83impl<T> IntoFuture for SpawnHandle<T> {
84 type Output = Result<T, SpawnError>;
85
86 type IntoFuture = MapOkOrElse<
87 JoinHandle<Result<T, CudaError>>,
88 fn(Result<T, CudaError>) -> Result<T, SpawnError>,
89 fn(tokio::task::JoinError) -> Result<T, SpawnError>,
90 >;
91
92 fn into_future(self) -> Self::IntoFuture {
93 self.handle.map_ok_or_else(map_err_value, map_ok_value)
94 }
95}
96
97pub fn spawn<F, Fut>(f: F) -> SpawnHandle<Fut::Output>
98where
99 F: FnOnce(TaskScope) -> Fut + Send + 'static,
100 Fut: Future + Send + 'static,
101 Fut::Output: Send + 'static,
102{
103 let pool = global_task_pool();
104 pool.spawn(f)
105}
106
107pub async fn run_in_place<F, Fut, R>(f: F) -> TaskHandle<R>
111where
112 F: FnOnce(TaskScope) -> Fut,
113 Fut: Future<Output = R>,
114{
115 let pool = global_task_pool();
116 pool.run(f).await
117}
118
119pub fn run_sync_in_place<F, R>(f: F) -> Result<R, CudaError>
123where
124 F: FnOnce(TaskScope) -> R,
125{
126 let pool = global_task_pool();
127 pool.run_sync(f)
128}
129
130#[derive(Debug, Clone, Error)]
131pub enum TaskPoolBuildError {
132 #[error("failed to create CUDA stream: {0}")]
133 StreamCreationFailed(CudaError),
134
135 #[error("failed to create CUDA event: {0}")]
136 EventCreationFailed(CudaError),
137
138 #[error("failed to push task back into pool")]
139 PushTaskFailed,
140}
141
142#[derive(Debug, Clone, Error)]
143pub enum GlobalTaskPoolBuildError {
144 #[error("failed to build global task pool")]
145 BuildFailed(#[from] TaskPoolBuildError),
146 #[error("global task pool already initialized")]
147 AlreadyInitialized,
148}
149
150impl TaskPoolBuilder {
151 pub fn new() -> Self {
152 Self { capacity: None, device: CudaDevice(0), mem_release_threshold: u64::MAX }
153 }
154
155 pub fn num_tasks(mut self, num_tasks: usize) -> Self {
156 self.capacity = Some(num_tasks);
157 self
158 }
159
160 pub fn device(mut self, device: CudaDevice) -> Self {
161 assert!(device.0 == 0, "only device 0 is supported at the moment");
162 self.device = device;
163 self
164 }
165
166 pub fn mem_release_threshold(mut self, threshold: u64) -> Self {
172 self.mem_release_threshold = threshold;
173 self
174 }
175
176 fn allocate_new_id(&self) -> usize {
177 let id = POOL_ID.fetch_add(1, Ordering::Relaxed);
178 if id > usize::MAX / 2 {
179 std::process::abort();
180 }
181 id
182 }
183
184 pub fn build(self) -> Result<TaskPool, TaskPoolBuildError> {
185 let id = self.allocate_new_id();
186 let num_tasks = self.capacity.unwrap_or(DEFAULT_NUM_TASKS);
187
188 unsafe {
190 let mut mem_pool = CudaMemPool(ptr::null_mut());
191 CudaError::result_from_ffi(cuda_device_get_default_mem_pool(
192 &mut mem_pool,
193 self.device,
194 ))
195 .unwrap();
196 CudaError::result_from_ffi(cuda_mem_pool_set_release_threshold(
197 mem_pool,
198 self.mem_release_threshold,
199 ))
200 .unwrap();
201 };
202
203 let mut tasks = Vec::with_capacity(num_tasks);
204 for (i, _) in (0..num_tasks).enumerate() {
205 let stream = CudaStream::create().map_err(TaskPoolBuildError::StreamCreationFailed)?;
206 let end_event = CudaEvent::create().map_err(TaskPoolBuildError::EventCreationFailed)?;
207 tasks.push(Task { owner_id: id, id: i, stream, end_event });
208 }
209 let inner = Arc::new(WorkerQueue::new(tasks));
210
211 Ok(TaskPool { inner })
212 }
213
214 pub fn build_global(self) -> Result<(), GlobalTaskPoolBuildError> {
215 let pool = self.build()?;
216 GLOBAL_TASK_POOL
217 .set(Arc::new(pool))
218 .map_err(|_| GlobalTaskPoolBuildError::AlreadyInitialized)
219 }
220}
221
222impl Default for TaskPoolBuilder {
223 fn default() -> Self {
224 Self::new()
225 }
226}
227
228#[derive(Debug, Clone)]
229pub struct TaskPool {
230 inner: Arc<WorkerQueue<Task>>,
231}
232
233struct OwnedTask {
234 inner: Worker<Task>,
235}
236
237impl std::fmt::Debug for OwnedTask {
238 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
239 write!(f, "OwnedTask {{ inner: {:?} }}", self.inner.deref())
240 }
241}
242
243#[derive(Debug, Error)]
244#[error("failed to acquire a task from the pool")]
245pub enum TaskSpawnError {
246 AcquireError(#[from] AcquireWorkerError),
247}
248
249#[derive(Debug, Error)]
250#[error("failed to acquire a task from the pool")]
251pub enum TrySpawnError {
252 TryAcquireError(#[from] TryAcquireWorkerError),
253}
254
255impl TaskPool {
256 async fn task(inner: Arc<WorkerQueue<Task>>) -> Result<OwnedTask, TaskSpawnError> {
257 let worker = inner.clone().pop().await.map_err(TaskSpawnError::AcquireError)?;
258 Ok(OwnedTask { inner: worker })
259 }
260
261 fn try_task(inner: Arc<WorkerQueue<Task>>) -> Result<OwnedTask, TrySpawnError> {
262 let worker = inner.clone().try_pop().map_err(TrySpawnError::TryAcquireError)?;
263 Ok(OwnedTask { inner: worker })
264 }
265
266 pub fn spawn<F, Fut>(&self, f: F) -> SpawnHandle<Fut::Output>
270 where
271 F: FnOnce(TaskScope) -> Fut + Send + 'static,
272 Fut: Future + Send + 'static,
273 Fut::Output: Send + 'static,
274 {
275 let queue = self.inner.clone();
276 let handle = tokio::spawn(async move {
277 let task = TaskPool::task(queue).await.expect("failed to acquire a task from the pool");
278 task.run(f).await.await
279 });
280 SpawnHandle { handle }
281 }
282
283 pub fn spawn_blocking<F, R>(&self, f: F) -> SpawnHandle<R>
284 where
285 F: FnOnce(TaskScope) -> R + Send + 'static,
286 R: Send + 'static,
287 {
288 let queue = self.inner.clone();
289 let handle = tokio::task::spawn_blocking(move || {
290 let task = TaskPool::try_task(queue).expect("failed to acquire a task from the pool");
291 let task = Arc::new(task);
292 task.run_sync(f)
293 });
294 SpawnHandle { handle }
295 }
296
297 pub async fn run<F, Fut, R>(&self, f: F) -> TaskHandle<R>
301 where
302 F: FnOnce(TaskScope) -> Fut,
303 Fut: Future<Output = R>,
304 {
305 let queue = self.inner.clone();
306 let task = TaskPool::task(queue).await.expect("failed to acquire a task from the pool");
307 task.run(f).await
308 }
309
310 pub fn run_sync<F, R>(&self, f: F) -> Result<R, CudaError>
311 where
312 F: FnOnce(TaskScope) -> R,
313 {
314 let queue = self.inner.clone();
315 let task = TaskPool::try_task(queue).expect("failed to acquire a task from the pool");
316 let task = Arc::new(task);
317 task.run_sync(f)
318 }
319}
320
321#[derive(Debug)]
322pub struct TaskScope(Weak<OwnedTask>);
323
324impl Clone for TaskScope {
325 fn clone(&self) -> Self {
326 TaskScope(self.0.clone())
327 }
328}
329
330impl Deref for TaskScope {
331 type Target = Task;
332
333 #[inline]
334 fn deref(&self) -> &Self::Target {
335 unsafe { &(*self.0.as_ptr()).inner }
336 }
337}
338
339unsafe impl Backend for TaskScope {}
340
/// Host callback that sleeps the calling thread.
///
/// # Safety
///
/// `ptr` must come from `Box::into_raw` on a `Box<Duration>` and must not be used afterwards;
/// the box is reclaimed and freed here.
unsafe extern "C" fn sleep(ptr: *mut c_void) {
    let boxed = unsafe { Box::from_raw(ptr.cast::<Duration>()) };
    let pause: Duration = *boxed;
    std::thread::sleep(pause);
}
345
346unsafe extern "C" fn sync_host(ptr: *mut c_void) {
347 let tx = unsafe { Box::from_raw(ptr as *mut oneshot::Sender<bool>) };
348 tx.send(true).unwrap();
349}
350
351impl TaskScope {
352 #[inline]
362 pub fn alloc<T>(&self, capacity: usize) -> Buffer<T, Self> {
363 Buffer::with_capacity_in(capacity, self.clone())
364 }
365
366 #[inline]
368 pub fn try_alloc<T>(
369 &self,
370 capacity: usize,
371 ) -> Result<Buffer<T, Self>, slop_alloc::TryReserveError> {
372 Buffer::try_with_capacity_in(capacity, self.clone())
373 }
374
375 #[inline]
383 pub unsafe fn launch_host_fn(
384 &self,
385 host_fn: unsafe extern "C" fn(*mut c_void),
386 data: *mut c_void,
387 ) -> Result<(), CudaError> {
388 self.launch_host_fn_uncheked(Some(host_fn), data)
389 }
390
391 pub unsafe fn launch_kernel(
399 &self,
400 kernel: KernelPtr,
401 grid_dim: impl Into<Dim3>,
402 block_dim: impl Into<Dim3>,
403 args: &[*mut c_void],
404 shared_mem: usize,
405 ) -> Result<(), CudaError> {
406 self.stream().launch_kernel(kernel, grid_dim, block_dim, args, shared_mem)
407 }
408
409 pub fn sleep(&self, time: Duration) {
414 let time_ptr = Box::into_raw(Box::new(time));
415 unsafe {
416 self.launch_host_fn(sleep, time_ptr as *mut c_void).unwrap();
417 }
418 }
419
420 pub unsafe fn copy<T: DeviceCopy>(
425 &self,
426 dst: &mut Slice<T, Self>,
427 src: &Slice<T, Self>,
428 ) -> Result<(), CopyError> {
429 dst.copy_from_slice(src, self)
430 }
431
432 pub async fn synchronize(&self) -> Result<(), CudaError> {
437 let (tx, mut rx) = oneshot::channel::<bool>();
438 let mut interval = tokio::time::interval(Duration::from_millis(INTERVAL_MS));
439
440 let tx = Box::new(tx);
442 let tx_ptr = Box::into_raw(tx);
443 unsafe {
444 self.launch_host_fn(sync_host, tx_ptr as *mut c_void)?;
445 }
446
447 loop {
450 tokio::select! {
451 _ = interval.tick() => {
452 match unsafe { self.stream().query() } {
453 Ok(()) => {break;}
454 Err(CudaError::NotReady) => {}
455 Err(e) => {
456 return Err(e);
457 }
458
459 }
460 }
461 _ = &mut rx => {
462 break;
463 }
464 }
465 }
466
467 Ok(())
468 }
469
470 #[inline]
474 unsafe fn join(self, parent: &TaskScope) -> Result<(), CudaError> {
475 parent.stream.wait_unchecked(&self.end_event)
476 }
477
478 #[inline]
480 pub fn into_device<T: IntoDevice>(&self, data: T) -> Result<T::Output, CopyError> {
481 T::into_device_in(data, self)
482 }
483
484 #[inline]
485 pub fn to_device<T: ToDevice>(&self, data: &T) -> Result<T::Output, CopyError> {
486 T::to_device_in(data, self)
487 }
488
489 #[inline]
494 pub fn synchronize_blocking(&self) -> Result<(), CudaError> {
495 unsafe { self.stream_synchronize() }
497 }
498
499 pub unsafe fn handle(&self) -> CudaStreamHandle {
501 self.stream.0
502 }
503
504 pub fn owner(&self) -> TaskPool {
505 TaskPool { inner: self.0.upgrade().unwrap().inner.owner().clone() }
506 }
507
508 fn owner_queue(&self) -> Arc<WorkerQueue<Task>> {
509 self.0.upgrade().unwrap().inner.owner().clone()
510 }
511
512 pub fn spawn<F, Fut>(&self, f: F) -> SpawnHandle<Fut::Output>
517 where
518 F: FnOnce(TaskScope) -> Fut + Send + 'static,
519 Fut: Future + Send + 'static,
520 Fut::Output: CudaSend + 'static,
521 {
522 let parent = self.clone();
523 let handle = tokio::spawn(async move { parent.run_in_place(f).await });
524 SpawnHandle { handle }
525 }
526
527 pub async fn run_in_place<F, Fut>(&self, f: F) -> Result<Fut::Output, CudaError>
532 where
533 F: FnOnce(TaskScope) -> Fut,
534 Fut: Future,
535 Fut::Output: CudaSend,
536 {
537 let parent = self.clone();
538 let task = TaskPool::task(parent.owner_queue()).await.unwrap();
539 unsafe {
540 parent.stream.record_unchecked(&task.inner.end_event)?;
544 task.inner.stream.wait_unchecked(&task.inner.end_event)?
545 };
546 let handle = task.run(f).await;
547 handle.join(&parent)
548 }
549}
550
551impl StreamRef for TaskScope {
552 #[inline]
553 unsafe fn stream(&self) -> &CudaStream {
554 &self.stream
555 }
556}
557
558#[derive(Debug)]
559pub struct Task {
560 pub(crate) owner_id: usize,
561 pub(crate) id: usize,
562 pub(crate) stream: CudaStream,
563 end_event: CudaEvent,
564}
565
566impl PartialEq for Task {
567 fn eq(&self, other: &Self) -> bool {
568 self.owner_id == other.owner_id && self.id == other.id
569 }
570}
571
572impl Eq for Task {}
573
574impl StreamRef for Task {
575 #[inline]
576 unsafe fn stream(&self) -> &CudaStream {
577 &self.stream
578 }
579}
580
581impl Drop for Task {
582 fn drop(&mut self) {
583 unsafe {
584 self.end_event.query().expect("attempting to drop a task that did not finish");
585 self.stream.query().expect("attempting to drop a task that did not finish");
586 }
587 }
588}
589
590impl IntoFuture for Task {
591 type Output = Result<(), CudaError>;
592 type IntoFuture = StreamCallbackFuture<Self>;
593
594 fn into_future(self) -> Self::IntoFuture {
595 StreamCallbackFuture::new(self)
596 }
597}
598
599unsafe impl Allocator for TaskScope {
600 #[inline]
601 unsafe fn allocate(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
602 self.stream.allocate(layout)
603 }
604
605 #[inline]
606 unsafe fn deallocate(&self, ptr: NonNull<u8>, layout: Layout) {
607 self.stream.deallocate(ptr, layout)
609 }
610}
611
612impl DeviceMemory for TaskScope {
613 #[inline]
614 unsafe fn copy_nonoverlapping(
615 &self,
616 src: *const u8,
617 dst: *mut u8,
618 size: usize,
619 direction: CopyDirection,
620 ) -> Result<(), CopyError> {
621 self.stream.copy_nonoverlapping(src, dst, size, direction)
622 }
623
624 #[inline]
625 unsafe fn write_bytes(&self, dst: *mut u8, value: u8, size: usize) -> Result<(), CopyError> {
626 self.stream.write_bytes(dst, value, size)
627 }
628}
629
630impl OwnedTask {
673 fn is_finished(&self) -> Result<bool, CudaError> {
674 self.inner.end_event.query().map(|()| true).or_else(|e| match e {
675 CudaError::NotReady => Ok(false),
676 e => Err(e),
677 })
678 }
679
680 async fn run<F, Fut, R>(self, f: F) -> TaskHandle<R>
681 where
682 F: FnOnce(TaskScope) -> Fut,
683 Fut: Future<Output = R>,
684 {
685 let strong_ptr = Arc::new(self);
686 let scope = TaskScope(Arc::downgrade(&strong_ptr));
687 let value = f(scope.clone()).await;
688 unsafe { scope.stream.record_unchecked(&scope.end_event).unwrap() };
689 TaskHandle { task: strong_ptr, scope, value }
690 }
691
692 fn run_sync<F, R>(self: Arc<Self>, f: F) -> Result<R, CudaError>
693 where
694 F: FnOnce(TaskScope) -> R,
695 {
696 let scope = TaskScope(Arc::downgrade(&self));
697 let output = f(scope.clone());
698 unsafe {
699 scope.stream.record_unchecked(&scope.end_event)?;
700 scope.end_event.synchronize()?;
701 };
702 Ok(output)
703 }
704}
705
706impl StreamRef for OwnedTask {
707 #[inline]
708 unsafe fn stream(&self) -> &CudaStream {
709 self.inner.stream()
710 }
711}
712
713impl IntoFuture for TaskScope {
714 type Output = Result<(), CudaError>;
715 type IntoFuture = StreamCallbackFuture<Self>;
716
717 fn into_future(self) -> Self::IntoFuture {
718 StreamCallbackFuture::new(self)
719 }
720}
721
722pub struct TaskHandle<T> {
723 task: Arc<OwnedTask>,
724 scope: TaskScope,
725 value: T,
726}
727
728impl<T> TaskHandle<T> {
729 pub fn join(self, parent: &TaskScope) -> Result<T, CudaError>
730 where
731 T: CudaSend,
732 {
733 unsafe {
736 self.scope.join(parent)?;
737 let value = self.value.send_to_scope(parent);
738 Ok(value)
740 }
741 }
742
743 pub fn is_finished(&self) -> Result<bool, CudaError> {
744 self.task.is_finished()
745 }
746}
747
748#[pin_project]
749pub struct StreamHandleFuture<T> {
750 #[pin]
751 callback: StreamCallbackFuture<Arc<OwnedTask>>,
752 value: MaybeUninit<T>,
753}
754
755impl<T> Future for StreamHandleFuture<T> {
756 type Output = Result<T, CudaError>;
757
758 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
759 let this = self.project();
760 this.callback.poll(cx).map(|res| {
761 res.map(|_| {
762 let uinit = MaybeUninit::uninit();
763 let ret = std::mem::replace(this.value, uinit);
764 unsafe { ret.assume_init() }
767 })
768 })
769 }
770}
771
772impl<T> IntoFuture for TaskHandle<T> {
773 type Output = Result<T, CudaError>;
774 type IntoFuture = StreamHandleFuture<T>;
775
776 #[inline]
777 fn into_future(self) -> Self::IntoFuture {
778 StreamHandleFuture {
779 callback: StreamCallbackFuture::new(self.task),
780 value: MaybeUninit::new(self.value),
781 }
782 }
783}
784
#[cfg(test)]
mod tests {

    use crate::TaskPoolBuilder;

    /// Spawning a no-op on the global pool completes successfully.
    #[tokio::test]
    async fn test_global_task_pool() {
        let handle = crate::spawn(|_| async {});
        handle.await.unwrap();
    }

    /// More callers than tasks: every caller must eventually run and report back.
    #[tokio::test]
    async fn test_local_pool() {
        let num_workers = 10;
        let num_callers = 100;
        let pool = TaskPoolBuilder::new().num_tasks(num_workers).build().unwrap();

        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
        let handles: Vec<_> = (0..num_callers)
            .map(|_| {
                let pool = pool.clone();
                let tx = tx.clone();
                pool.spawn(|_| async move {
                    tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
                    tx.send(true).unwrap();
                })
            })
            .collect();
        // Drop the original sender so the receive loop ends once every caller has finished.
        drop(tx);

        let mut count = 0;
        while let Some(flag) = rx.recv().await {
            assert!(flag);
            count += 1;
        }

        for handle in handles {
            handle.await.unwrap();
        }

        assert_eq!(count, num_callers);
    }
}
827}