1use std::ptr::NonNull;
4use std::sync::{
5 atomic::{AtomicBool, Ordering},
6 Arc,
7};
8
9use crossbeam_queue::ArrayQueue;
10use cudarc::driver::{
11 sys::{CUevent, CUevent_flags, CUresult, CUstream, CUstream_flags, CUstream_st},
12 CudaContext,
13};
14use vyre_driver::{backend::private, BackendError, PendingDispatch};
15
16use crate::backend::telemetry::CudaTelemetry;
17use crate::backend::{cuda_check, DispatchAllocations, HostTransferAllocations, ResidentUseGuard};
18
19#[derive(Debug)]
21pub(crate) struct CudaStream {
22 raw: CUstream,
23}
24
25unsafe impl Send for CudaStream {}
26unsafe impl Sync for CudaStream {}
27
28impl CudaStream {
29 pub(crate) fn non_blocking() -> Result<Self, BackendError> {
31 let raw = create_non_blocking_raw_stream("cuStreamCreate")?;
32 Ok(Self { raw: raw.as_ptr() })
33 }
34
35 #[must_use]
37 pub(crate) fn raw(&self) -> CUstream {
38 self.raw
39 }
40
41 pub(crate) fn synchronize(&self) -> Result<(), BackendError> {
43 synchronize_raw_stream(self.raw, "cuStreamSynchronize")
44 }
45}
46
47pub(crate) fn create_non_blocking_raw_stream(
50 label: &'static str,
51) -> Result<NonNull<CUstream_st>, BackendError> {
52 let mut raw = std::ptr::null_mut();
53 unsafe {
56 cuda_check(
57 cudarc::driver::sys::cuStreamCreate(
58 &mut raw,
59 CUstream_flags::CU_STREAM_NON_BLOCKING as u32,
60 ),
61 label,
62 )?;
63 }
64 NonNull::new(raw).ok_or_else(|| BackendError::DispatchFailed {
65 code: None,
66 message: format!(
67 "{label} returned a null stream after reporting success. Fix: update the CUDA driver or disable the CUDA path using this stream."
68 ),
69 })
70}
71
72pub(crate) fn destroy_raw_stream(stream: CUstream, label: &'static str) {
73 if stream.is_null() {
74 return;
75 }
76 unsafe {
79 let result = cudarc::driver::sys::cuStreamDestroy_v2(stream);
80 if result != CUresult::CUDA_SUCCESS {
81 tracing::error!(
82 "Fix: {label} failed during CUDA stream drop with {result:?}; ensure pending work is synchronized before dropping dispatch resources."
83 );
84 }
85 }
86}
87
88pub(crate) fn query_raw_stream_ready(
91 stream: CUstream,
92 label: &'static str,
93) -> Result<bool, BackendError> {
94 if stream.is_null() {
95 return Err(BackendError::InvalidProgram {
96 fix: format!(
97 "Fix: {label} received a null CUDA stream; use a backend-owned non-blocking stream instead of querying CUDA's legacy default stream."
98 ),
99 });
100 }
101 let result = unsafe { cudarc::driver::sys::cuStreamQuery(stream) };
104 match result {
105 CUresult::CUDA_SUCCESS => Ok(true),
106 CUresult::CUDA_ERROR_NOT_READY => Ok(false),
107 other => cuda_check(other, label).map(|()| true),
108 }
109}
110
111pub(crate) fn synchronize_raw_stream(
114 stream: CUstream,
115 label: &'static str,
116) -> Result<(), BackendError> {
117 if stream.is_null() {
118 return Err(BackendError::InvalidProgram {
119 fix: format!(
120 "Fix: {label} received a null CUDA stream; use a backend-owned non-blocking stream instead of the legacy default stream."
121 ),
122 });
123 }
124 unsafe { cuda_check(cudarc::driver::sys::cuStreamSynchronize(stream), label) }
127}
128
129impl Drop for CudaStream {
130 fn drop(&mut self) {
131 destroy_raw_stream(self.raw, "cuStreamDestroy_v2");
132 }
133}
134
135#[derive(Debug)]
137pub(crate) struct CudaEvent {
138 raw: CUevent,
139}
140
141unsafe impl Send for CudaEvent {}
142unsafe impl Sync for CudaEvent {}
143
144impl CudaEvent {
145 pub(crate) fn completion() -> Result<Self, BackendError> {
147 let raw = create_raw_event(
148 CUevent_flags::CU_EVENT_DISABLE_TIMING as u32,
149 "cuEventCreate",
150 )?;
151 Ok(Self { raw })
152 }
153
154 pub(crate) fn timing() -> Result<Self, BackendError> {
156 let raw = create_raw_event(0, "cuEventCreate")?;
157 Ok(Self { raw })
158 }
159
160 pub(crate) fn record(&self, stream: CUstream) -> Result<(), BackendError> {
162 if self.raw.is_null() {
163 return Err(BackendError::InvalidProgram {
164 fix: "Fix: cuEventRecord received a null CUDA event; acquire a backend-owned event before recording completion.".to_string(),
165 });
166 }
167 if stream.is_null() {
168 return Err(BackendError::InvalidProgram {
169 fix: "Fix: cuEventRecord received a null CUDA stream; record events on a backend-owned non-blocking stream instead of CUDA's legacy default stream.".to_string(),
170 });
171 }
172 unsafe {
175 cuda_check(
176 cudarc::driver::sys::cuEventRecord(self.raw, stream),
177 "cuEventRecord",
178 )
179 }
180 }
181
182 pub(crate) fn query_ready(&self) -> Result<bool, BackendError> {
184 if self.raw.is_null() {
185 return Err(BackendError::InvalidProgram {
186 fix: "Fix: cuEventQuery received a null CUDA event; pending dispatches must own a recorded completion event before readiness polling.".to_string(),
187 });
188 }
189 let result = unsafe { cudarc::driver::sys::cuEventQuery(self.raw) };
192 match result {
193 CUresult::CUDA_SUCCESS => Ok(true),
194 CUresult::CUDA_ERROR_NOT_READY => Ok(false),
195 other => cuda_check(other, "cuEventQuery").map(|()| true),
196 }
197 }
198
199 pub(crate) fn synchronize(&self) -> Result<(), BackendError> {
201 if self.raw.is_null() {
202 return Err(BackendError::InvalidProgram {
203 fix: "Fix: cuEventSynchronize received a null CUDA event; pending dispatches must own a recorded completion event before synchronization.".to_string(),
204 });
205 }
206 unsafe {
209 cuda_check(
210 cudarc::driver::sys::cuEventSynchronize(self.raw),
211 "cuEventSynchronize",
212 )
213 }
214 }
215
216 pub(crate) fn elapsed_time_ns(&self, end: &CudaEvent) -> Result<u64, BackendError> {
218 if self.raw.is_null() || end.raw.is_null() {
219 return Err(BackendError::InvalidProgram {
220 fix: "Fix: cuEventElapsedTime received a null CUDA timing event; record both timing events before reading elapsed time.".to_string(),
221 });
222 }
223 let mut elapsed_ms = 0.0f32;
224 unsafe {
227 cuda_check(
228 cudarc::driver::sys::cuEventElapsedTime(
229 (&mut elapsed_ms) as *mut f32,
230 self.raw,
231 end.raw,
232 ),
233 "cuEventElapsedTime",
234 )?;
235 }
236 let elapsed_ns = f64::from(elapsed_ms) * 1_000_000.0;
237 if !elapsed_ns.is_finite() || elapsed_ns < 0.0 || elapsed_ns > u64::MAX as f64 {
238 return Err(BackendError::InvalidProgram {
239 fix: format!(
240 "Fix: CUDA event elapsed time {elapsed_ms} ms cannot fit u64 nanoseconds; inspect CUDA event timing and split the dispatch before telemetry overflows."
241 ),
242 });
243 }
244 crate::numeric::CUDA_NUMERIC.rounded_f64_to_u64(elapsed_ns, "event elapsed nanoseconds")
245 }
246}
247
248impl Drop for CudaEvent {
249 fn drop(&mut self) {
250 destroy_raw_event(self.raw, "cuEventDestroy_v2");
251 }
252}
253
254fn create_raw_event(flags: u32, label: &'static str) -> Result<CUevent, BackendError> {
255 let mut raw = std::ptr::null_mut();
256 unsafe {
259 cuda_check(cudarc::driver::sys::cuEventCreate(&mut raw, flags), label)?;
260 }
261 if raw.is_null() {
262 return Err(BackendError::DispatchFailed {
263 code: None,
264 message: format!(
265 "{label} returned a null event after reporting success. Fix: update the CUDA driver or disable event-backed CUDA dispatch for this device."
266 ),
267 });
268 }
269 Ok(raw)
270}
271
272fn destroy_raw_event(event: CUevent, label: &'static str) {
273 if event.is_null() {
274 return;
275 }
276 unsafe {
279 let result = cudarc::driver::sys::cuEventDestroy_v2(event);
280 if result != CUresult::CUDA_SUCCESS {
281 tracing::error!(
282 "Fix: {label} failed during CUDA event drop with {result:?}; ensure pending work is synchronized before dropping dispatch resources."
283 );
284 }
285 }
286}
287
/// Bounded, lock-free cache of idle launch resources that can be reused across dispatches.
///
/// Each queue's capacity is fixed at construction (`new`, minimum 1). When a queue is full,
/// the `release_*` methods drop the surplus resource, which destroys its driver handle.
#[derive(Debug)]
pub(crate) struct CudaLaunchResourcePool {
    // Idle non-blocking streams (created via `CudaStream::non_blocking`).
    streams: ArrayQueue<CudaStream>,
    // Idle completion-only events (created via `CudaEvent::completion`).
    events: ArrayQueue<CudaEvent>,
    // Idle timing-capable events (created via `CudaEvent::timing`).
    timing_events: ArrayQueue<CudaEvent>,
}
295
/// Point-in-time snapshot of how many idle resources a launch-resource pool is caching.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct CudaLaunchResourceCounts {
    /// Number of idle non-blocking streams currently cached.
    pub streams: usize,
    /// Number of idle completion-only (timing-disabled) events currently cached.
    pub completion_events: usize,
    /// Number of idle timing-capable events currently cached.
    pub timing_events: usize,
}
306
/// Scoped ownership of a pooled stream and, optionally, a start/end timing-event pair.
///
/// Whatever is still held when the lease drops is returned to `pool`. `into_parts` moves the
/// resources out so they are not returned.
#[derive(Debug)]
pub(crate) struct CudaLaunchResourceLease {
    // Pool the resources are returned to on drop.
    pool: Arc<CudaLaunchResourcePool>,
    // `Some` for the lease's whole lifetime; only `into_parts` takes it.
    stream: Option<CudaStream>,
    // `(start, end)` timing events, present only when timing capture was requested.
    timing_events: Option<(CudaEvent, CudaEvent)>,
}
314
/// Scoped ownership of a pooled start/end timing-event pair, returned to `pool` on drop.
#[derive(Debug)]
pub(crate) struct CudaTimingEventPairLease {
    // Pool the timing events are returned to on drop.
    pool: Arc<CudaLaunchResourcePool>,
    // `(start, end)` events; `Option` so `Drop` can move them out.
    timing_events: Option<(CudaEvent, CudaEvent)>,
}
321
322impl CudaTimingEventPairLease {
323 pub(crate) fn acquire(pool: Arc<CudaLaunchResourcePool>) -> Result<Self, BackendError> {
324 let timing_events = pool.acquire_timing_event_pair()?;
325 Ok(Self {
326 pool,
327 timing_events: Some(timing_events),
328 })
329 }
330
331 pub(crate) fn events(&self) -> Result<&(CudaEvent, CudaEvent), BackendError> {
332 self.timing_events
333 .as_ref()
334 .ok_or_else(|| BackendError::InvalidProgram {
335 fix: "Fix: CUDA timing event pair lease was already consumed; acquire a fresh timing lease before recording graph replay events.".to_string(),
336 })
337 }
338}
339
340impl Drop for CudaTimingEventPairLease {
341 fn drop(&mut self) {
342 if let Some((start, end)) = self.timing_events.take() {
343 self.pool.release_timing_event(start);
344 self.pool.release_timing_event(end);
345 }
346 }
347}
348
349impl CudaLaunchResourceLease {
350 pub(crate) fn acquire(
351 pool: Arc<CudaLaunchResourcePool>,
352 capture_timing: bool,
353 ) -> Result<Self, BackendError> {
354 let stream = pool.acquire_stream()?;
355 let timing_events = if capture_timing {
356 match pool.acquire_timing_event_pair() {
357 Ok(pair) => Some(pair),
358 Err(error) => {
359 pool.release_stream(stream);
360 return Err(error);
361 }
362 }
363 } else {
364 None
365 };
366 Ok(Self {
367 pool,
368 stream: Some(stream),
369 timing_events,
370 })
371 }
372
373 pub(crate) fn stream_raw(&self) -> Result<CUstream, BackendError> {
374 self.stream
375 .as_ref()
376 .map(CudaStream::raw)
377 .ok_or_else(|| BackendError::InvalidProgram {
378 fix: "Fix: CUDA launch resource lease stream was already consumed; acquire a fresh launch-resource lease before enqueueing CUDA work.".to_string(),
379 })
380 }
381
382 pub(crate) fn timing_events(&self) -> Result<Option<&(CudaEvent, CudaEvent)>, BackendError> {
383 if self.stream.is_none() {
384 return Err(BackendError::InvalidProgram {
385 fix: "Fix: CUDA launch resource lease timing events were queried after the stream was consumed; query timing events before transferring the lease into a pending dispatch.".to_string(),
386 });
387 }
388 Ok(self.timing_events.as_ref())
389 }
390
391 pub(crate) fn into_parts(
392 mut self,
393 ) -> Result<(CudaStream, Option<(CudaEvent, CudaEvent)>), BackendError> {
394 let stream = self.stream.take().ok_or_else(|| BackendError::InvalidProgram {
395 fix: "Fix: CUDA launch resource lease stream was already consumed; pending dispatch ownership cannot be built twice from the same lease.".to_string(),
396 })?;
397 let timing_events = self.timing_events.take();
398 Ok((stream, timing_events))
399 }
400}
401
402impl Drop for CudaLaunchResourceLease {
403 fn drop(&mut self) {
404 if let Some((start, end)) = self.timing_events.take() {
405 self.pool.release_timing_event(start);
406 self.pool.release_timing_event(end);
407 }
408 if let Some(stream) = self.stream.take() {
409 self.pool.release_stream(stream);
410 }
411 }
412}
413
414impl CudaLaunchResourcePool {
415 pub(crate) fn new(max_cached: usize) -> Self {
416 let max_cached = max_cached.max(1);
417 Self {
418 streams: ArrayQueue::new(max_cached),
419 events: ArrayQueue::new(max_cached),
420 timing_events: ArrayQueue::new(max_cached),
421 }
422 }
423
424 pub(crate) fn acquire_stream(&self) -> Result<CudaStream, BackendError> {
425 if let Some(stream) = self.streams.pop() {
426 return Ok(stream);
427 }
428 CudaStream::non_blocking()
429 }
430
431 pub(crate) fn acquire_event(&self) -> Result<CudaEvent, BackendError> {
432 if let Some(event) = self.events.pop() {
433 return Ok(event);
434 }
435 CudaEvent::completion()
436 }
437
438 pub(crate) fn acquire_timing_event(&self) -> Result<CudaEvent, BackendError> {
439 if let Some(event) = self.timing_events.pop() {
440 return Ok(event);
441 }
442 CudaEvent::timing()
443 }
444
445 pub(crate) fn acquire_timing_event_pair(&self) -> Result<(CudaEvent, CudaEvent), BackendError> {
446 let start = self.acquire_timing_event()?;
447 match self.acquire_timing_event() {
448 Ok(end) => Ok((start, end)),
449 Err(error) => {
450 self.release_timing_event(start);
451 Err(error)
452 }
453 }
454 }
455
456 pub(crate) fn release_stream(&self, stream: CudaStream) {
457 if let Err(stream) = self.streams.push(stream) {
458 drop(stream);
459 }
460 }
461
462 pub(crate) fn release_event(&self, event: CudaEvent) {
463 if let Err(event) = self.events.push(event) {
464 drop(event);
465 }
466 }
467
468 pub(crate) fn release_timing_event(&self, event: CudaEvent) {
469 if let Err(event) = self.timing_events.push(event) {
470 drop(event);
471 }
472 }
473
474 pub(crate) fn cached_counts(&self) -> Result<(usize, usize), BackendError> {
475 Ok((self.streams.len(), self.events.len()))
476 }
477
478 pub(crate) fn cached_counts_detailed(&self) -> Result<CudaLaunchResourceCounts, BackendError> {
479 Ok(CudaLaunchResourceCounts {
480 streams: self.streams.len(),
481 completion_events: self.events.len(),
482 timing_events: self.timing_events.len(),
483 })
484 }
485
486 pub(crate) fn clear(&self) -> Result<(), BackendError> {
487 while self.streams.pop().is_some() {}
488 while self.events.pop().is_some() {}
489 while self.timing_events.pop().is_some() {}
490 Ok(())
491 }
492}
493
/// A dispatch whose device work may still be in flight, plus everything that must stay
/// alive until it completes.
///
/// NOTE: fields are dropped in declaration order after `Drop::drop`, and `Drop::drop` takes
/// the launch resources, allocations, resident guard and host transfers explicitly first.
#[derive(Debug)]
pub(crate) struct CudaPendingDispatch {
    // Context bound to the current thread before driver calls.
    ctx: Arc<CudaContext>,
    // Pool that receives the event, stream and timing events back once work completes.
    pool: Arc<CudaLaunchResourcePool>,
    // Completion event that `synchronize` and `is_ready` wait on or poll; `None` for ready
    // dispatches.
    event: Option<CudaEvent>,
    // Stream the work was issued on; `Drop` synchronizes it if completion was never observed.
    stream: Option<CudaStream>,
    // Allocations held until completion (type defined elsewhere; presumably device buffers
    // used by the launch — TODO confirm).
    allocations: Option<DispatchAllocations>,
    // Guard released after completion (type defined elsewhere; semantics not visible here).
    resident_use: Option<ResidentUseGuard>,
    // Host-transfer state that fills `outputs` when results are collected, if present.
    host_transfers: Option<HostTransferAllocations>,
    // Output buffers handed back to the caller on await.
    outputs: Vec<Vec<u8>>,
    // Optional start/end events used to compute device time in `await_timed_result`.
    timing_start: Option<CudaEvent>,
    timing_end: Option<CudaEvent>,
    // Device time supplied up front by `new_ready_timed`; takes precedence over the events.
    ready_device_ns: Option<u64>,
    // Receives a sync-point record whenever this dispatch synchronizes.
    telemetry: Arc<CudaTelemetry>,
    // Set once completion has been observed (or forced in `Drop`); read with Acquire and
    // written with Release.
    completed: AtomicBool,
}
511
512
513impl CudaPendingDispatch {
514 pub(crate) fn new_ready(
516 ctx: Arc<CudaContext>,
517 pool: Arc<CudaLaunchResourcePool>,
518 outputs: Vec<Vec<u8>>,
519 telemetry: Arc<CudaTelemetry>,
520 ) -> Self {
521 Self {
522 ctx,
523 pool,
524 event: None,
525 stream: None,
526 allocations: None,
527 resident_use: None,
528 host_transfers: None,
529 outputs,
530 timing_start: None,
531 timing_end: None,
532 ready_device_ns: None,
533 telemetry,
534 completed: AtomicBool::new(true),
535 }
536 }
537
538 pub(crate) fn new_ready_timed(
540 ctx: Arc<CudaContext>,
541 pool: Arc<CudaLaunchResourcePool>,
542 outputs: Vec<Vec<u8>>,
543 device_ns: Option<u64>,
544 telemetry: Arc<CudaTelemetry>,
545 ) -> Self {
546 Self {
547 ctx,
548 pool,
549 event: None,
550 stream: None,
551 allocations: None,
552 resident_use: None,
553 host_transfers: None,
554 outputs,
555 timing_start: None,
556 timing_end: None,
557 ready_device_ns: device_ns,
558 telemetry,
559 completed: AtomicBool::new(true),
560 }
561 }
562
563 #[allow(clippy::too_many_arguments)]
568 pub(crate) fn new_resident_batch_pending(
569 ctx: Arc<CudaContext>,
570 pool: Arc<CudaLaunchResourcePool>,
571 event: CudaEvent,
572 stream: CudaStream,
573 allocations: DispatchAllocations,
574 resident_use: ResidentUseGuard,
575 host_transfers: HostTransferAllocations,
576 telemetry: Arc<CudaTelemetry>,
577 ) -> Self {
578 Self::new(
579 ctx,
580 pool,
581 event,
582 stream,
583 allocations,
584 Some(resident_use),
585 Some(host_transfers),
586 Vec::new(),
587 telemetry,
588 )
589 }
590
591 #[allow(clippy::too_many_arguments)]
593 pub(crate) fn new(
594 ctx: Arc<CudaContext>,
595 pool: Arc<CudaLaunchResourcePool>,
596 event: CudaEvent,
597 stream: CudaStream,
598 allocations: DispatchAllocations,
599 resident_use: Option<ResidentUseGuard>,
600 host_transfers: Option<HostTransferAllocations>,
601 outputs: Vec<Vec<u8>>,
602 telemetry: Arc<CudaTelemetry>,
603 ) -> Self {
604 Self {
605 ctx,
606 pool,
607 event: Some(event),
608 stream: Some(stream),
609 allocations: Some(allocations),
610 resident_use,
611 host_transfers,
612 outputs,
613 timing_start: None,
614 timing_end: None,
615 ready_device_ns: None,
616 telemetry,
617 completed: AtomicBool::new(false),
618 }
619 }
620
621 #[allow(clippy::too_many_arguments)]
623 pub(crate) fn new_with_timing(
624 ctx: Arc<CudaContext>,
625 pool: Arc<CudaLaunchResourcePool>,
626 event: CudaEvent,
627 stream: CudaStream,
628 allocations: DispatchAllocations,
629 resident_use: Option<ResidentUseGuard>,
630 host_transfers: Option<HostTransferAllocations>,
631 outputs: Vec<Vec<u8>>,
632 timing_start: CudaEvent,
633 timing_end: CudaEvent,
634 telemetry: Arc<CudaTelemetry>,
635 ) -> Self {
636 Self {
637 ctx,
638 pool,
639 event: Some(event),
640 stream: Some(stream),
641 allocations: Some(allocations),
642 resident_use,
643 host_transfers,
644 outputs,
645 timing_start: Some(timing_start),
646 timing_end: Some(timing_end),
647 ready_device_ns: None,
648 telemetry,
649 completed: AtomicBool::new(false),
650 }
651 }
652
653 fn bind_context(&self) -> Result<(), BackendError> {
654 self.ctx
655 .bind_to_thread()
656 .map_err(|e| BackendError::DispatchFailed {
657 code: None,
658 message: format!("CUDA context bind failed: {e}"),
659 })
660 }
661
662 fn synchronize(&self) -> Result<(), BackendError> {
663 if self.completed.load(Ordering::Acquire) {
664 return Ok(());
665 }
666 self.bind_context()?;
667 let event = self
668 .event
669 .as_ref()
670 .ok_or_else(|| BackendError::DispatchFailed {
671 code: None,
672 message: "CUDA pending dispatch completion event was already released".to_string(),
673 })?;
674 event.synchronize()?;
675 self.telemetry.record_sync_point();
676 self.completed.store(true, Ordering::Release);
677 Ok(())
678 }
679
680 fn release_launch_resources(&mut self) {
681 if let Some(event) = self.event.take() {
682 self.pool.release_event(event);
683 }
684 if let Some(event) = self.timing_start.take() {
685 self.pool.release_timing_event(event);
686 }
687 if let Some(event) = self.timing_end.take() {
688 self.pool.release_timing_event(event);
689 }
690 if let Some(stream) = self.stream.take() {
691 self.pool.release_stream(stream);
692 }
693 }
694
695 pub(crate) fn await_timed_result(
697 mut self,
698 ) -> Result<(Vec<Vec<u8>>, Option<u64>), BackendError> {
699 self.synchronize()?;
700 let device_ns = match self.ready_device_ns.take() {
701 Some(device_ns) => Some(device_ns),
702 None => match (self.timing_start.as_ref(), self.timing_end.as_ref()) {
703 (Some(start), Some(end)) => Some(start.elapsed_time_ns(end)?),
704 _ => None,
705 },
706 };
707 self.release_launch_resources();
708 self.allocations.take();
709 self.resident_use.take();
710 let outputs = self.collect_outputs()?;
711 self.host_transfers.take();
712 Ok((outputs, device_ns))
713 }
714
715 fn collect_outputs(&mut self) -> Result<Vec<Vec<u8>>, BackendError> {
716 if let Some(transfers) = self.host_transfers.as_ref() {
717 let mut outputs = std::mem::take(&mut self.outputs);
718 transfers.collect_outputs_into(&mut outputs)?;
719 Ok(outputs)
720 } else {
721 Ok(std::mem::take(&mut self.outputs))
722 }
723 }
724
725 fn collect_outputs_into(&mut self, outputs: &mut Vec<Vec<u8>>) -> Result<(), BackendError> {
726 if let Some(transfers) = self.host_transfers.as_ref() {
727 transfers.collect_outputs_into(outputs)?;
728 } else {
729 vyre_driver::replace_output_buffers_preserving_slots(
730 std::mem::take(&mut self.outputs),
731 outputs,
732 );
733 }
734 Ok(())
735 }
736}
737
738impl private::Sealed for CudaPendingDispatch {}
739
740impl PendingDispatch for CudaPendingDispatch {
741 fn is_ready(&self) -> bool {
742 if self.completed.load(Ordering::Acquire) {
743 return true;
744 }
745 if self.bind_context().is_err() {
746 return false;
747 }
748 let Some(event) = self.event.as_ref() else {
749 return true;
750 };
751 let ready = match event.query_ready() {
752 Ok(ready) => ready,
753 Err(error) => {
754 tracing::error!(
755 "Fix: CUDA pending dispatch readiness query failed: {error}. Await the dispatch to surface synchronization failure details."
756 );
757 false
758 }
759 };
760 if ready {
761 self.completed.store(true, Ordering::Release);
762 }
763 ready
764 }
765
766 fn await_result(mut self: Box<Self>) -> Result<Vec<Vec<u8>>, BackendError> {
767 self.synchronize()?;
768 self.release_launch_resources();
769 self.allocations.take();
770 self.resident_use.take();
771 let outputs = self.collect_outputs()?;
772 self.host_transfers.take();
773 Ok(outputs)
774 }
775
776 fn await_result_into(
777 mut self: Box<Self>,
778 outputs: &mut Vec<Vec<u8>>,
779 ) -> Result<(), BackendError> {
780 self.synchronize()?;
781 self.release_launch_resources();
782 self.allocations.take();
783 self.resident_use.take();
784 self.collect_outputs_into(outputs)?;
785 self.host_transfers.take();
786 Ok(())
787 }
788}
789
// Unit tests for this module. Several tests scan this file's own source via `include_str!`,
// so search needles are split with `concat!` to keep the tests from matching themselves.
#[cfg(test)]
mod tests {
    use super::{query_raw_stream_ready, synchronize_raw_stream, CudaLaunchResourcePool};

    // Source guard: consuming a lease twice must produce a typed error, never a panic from
    // an expect call carrying the lease-consumed message.
    #[test]
    fn launch_resource_leases_do_not_panic_on_consumed_state() {
        let source = include_str!("stream.rs");
        assert!(
            !source.contains(concat!(".expect", "(\"Fix: CUDA launch resource lease stream was already consumed")),
            "Fix: CUDA launch resource leases must return typed backend errors when consumed twice, not panic."
        );
        assert!(
            !source.contains(concat!(".expect", "(\"Fix: CUDA timing event pair lease was already consumed")),
            "Fix: CUDA graph replay timing leases must return typed backend errors when consumed twice, not panic."
        );
    }

    // A fresh pool reports zero cached resources in every category, and the counts type
    // keeps exposing timing-event pressure.
    #[test]
    fn launch_resource_counts_include_timing_events() {
        let pool = CudaLaunchResourcePool::new(8);
        let counts = pool
            .cached_counts_detailed()
            .expect("Fix: empty launch resource pool counts should be readable");

        assert_eq!(counts.streams, 0);
        assert_eq!(counts.completion_events, 0);
        assert_eq!(counts.timing_events, 0);

        let source = include_str!("stream.rs");
        assert!(
            source.contains("pub struct CudaLaunchResourceCounts")
                && source.contains("pub timing_events: usize")
                && source.contains("cached_counts_detailed"),
            "Fix: CUDA launch-resource telemetry must expose timing-event cache pressure, not just streams and completion events."
        );
    }

    // The null handle is CUDA's legacy default stream; it must be rejected before any
    // driver call.
    #[test]
    fn raw_stream_sync_rejects_null_default_stream() {
        let err = synchronize_raw_stream(std::ptr::null_mut(), "unit sync")
            .expect_err("Fix: raw stream sync must reject the legacy null stream");
        assert!(
            err.to_string().contains("null CUDA stream"),
            "raw sync diagnostic must explain the default-stream hazard: {err}"
        );
    }

    #[test]
    fn raw_stream_query_rejects_null_default_stream() {
        let err = query_raw_stream_ready(std::ptr::null_mut(), "unit query")
            .expect_err("Fix: raw stream query must reject the legacy null stream");
        assert!(
            err.to_string().contains("null CUDA stream"),
            "raw query diagnostic must explain the default-stream hazard: {err}"
        );
    }

    // Both handles are null here; the null-event check runs first, so its diagnostic wins.
    // Dropping a null event is safe because the destroy helper skips null handles.
    #[test]
    fn event_record_rejects_null_event_before_ffi() {
        let event = super::CudaEvent {
            raw: std::ptr::null_mut(),
        };
        let err = event
            .record(std::ptr::null_mut())
            .expect_err("Fix: event recording must reject invalid event handles before FFI");
        assert!(
            err.to_string().contains("null CUDA event"),
            "event record diagnostic must explain the null-event hazard: {err}"
        );
    }

    // Uses a dangling (non-null) event so the null-stream check is reached. `ManuallyDrop`
    // keeps `Drop` from passing the dangling pointer to the driver's destroy call.
    #[test]
    fn event_record_rejects_null_default_stream_before_ffi() {
        let event = std::mem::ManuallyDrop::new(super::CudaEvent {
            raw: std::ptr::NonNull::<cudarc::driver::sys::CUevent_st>::dangling().as_ptr(),
        });
        let err = event
            .record(std::ptr::null_mut())
            .expect_err("Fix: event recording must reject CUDA's legacy null stream before FFI");
        assert!(
            err.to_string().contains("null CUDA stream"),
            "event record diagnostic must explain the default-stream hazard: {err}"
        );
    }

    #[test]
    fn event_query_and_sync_reject_null_event_before_ffi() {
        let event = super::CudaEvent {
            raw: std::ptr::null_mut(),
        };
        let query_err = event
            .query_ready()
            .expect_err("Fix: event readiness query must reject null events before FFI");
        assert!(
            query_err.to_string().contains("null CUDA event"),
            "event query diagnostic must explain the null-event hazard: {query_err}"
        );

        let sync_err = event
            .synchronize()
            .expect_err("Fix: event synchronize must reject null events before FFI");
        assert!(
            sync_err.to_string().contains("null CUDA event"),
            "event sync diagnostic must explain the null-event hazard: {sync_err}"
        );
    }

    #[test]
    fn event_elapsed_time_rejects_null_timing_event_before_ffi() {
        let event = super::CudaEvent {
            raw: std::ptr::null_mut(),
        };
        let err = event
            .elapsed_time_ns(&event)
            .expect_err("Fix: elapsed timing must reject null events before FFI");
        assert!(
            err.to_string().contains("null CUDA timing event"),
            "event elapsed diagnostic must explain the null-event hazard: {err}"
        );
    }

    // Source guard: raw stream create/destroy FFI appears exactly once in this file, and
    // never in the CUDA graph module, which must use the shared helpers instead.
    #[test]
    fn stream_lifecycle_ffi_is_single_sourced_for_graph_capture() {
        let stream = include_str!("stream.rs");
        let cuda_graph = include_str!("backend/cuda_graph.rs");
        let create_ffi = concat!("cudarc::driver::sys::", "cuStreamCreate(");
        let destroy_ffi = concat!("cudarc::driver::sys::", "cuStreamDestroy_v2(");

        assert_eq!(
            stream.matches(create_ffi).count(),
            1,
            "Fix: raw CUDA stream creation must stay behind create_non_blocking_raw_stream."
        );
        assert_eq!(
            stream.matches(destroy_ffi).count(),
            1,
            "Fix: raw CUDA stream destruction must stay behind destroy_raw_stream."
        );
        assert_eq!(
            cuda_graph.matches(create_ffi).count() + cuda_graph.matches(destroy_ffi).count(),
            0,
            "Fix: cudaGraph capture must use the shared stream lifecycle helpers instead of direct stream FFI."
        );
        assert!(
            stream.contains("fn create_non_blocking_raw_stream(")
                && stream.contains("returned a null stream after reporting success")
                && cuda_graph.contains("create_non_blocking_raw_stream"),
            "Fix: shared CUDA stream creation must reject null-success handles and be used by cudaGraph."
        );
    }

    // Source guard: raw event create/destroy FFI appears exactly once each in this file.
    #[test]
    fn event_lifecycle_ffi_is_single_sourced() {
        let stream = include_str!("stream.rs");
        let create_ffi = concat!("cudarc::driver::sys::", "cuEventCreate(");
        let destroy_ffi = concat!("cudarc::driver::sys::", "cuEventDestroy_v2(");

        assert_eq!(
            stream.matches(create_ffi).count(),
            1,
            "Fix: raw CUDA event creation must stay behind create_raw_event."
        );
        assert_eq!(
            stream.matches(destroy_ffi).count(),
            1,
            "Fix: raw CUDA event destruction must stay behind destroy_raw_event."
        );
        assert!(
            stream.contains("fn create_raw_event(")
                && stream.contains("returned a null event after reporting success")
                && stream.contains("fn destroy_raw_event(")
                && stream.contains("CudaEvent::completion")
                && stream.contains("CudaEvent::timing"),
            "Fix: CUDA event lifecycle must use shared create/destroy helpers with null-success validation."
        );
    }

    // Source guard: raw stream-query FFI lives only in this file's shared helper, and graph
    // replay polls through that helper.
    #[test]
    fn graph_replay_uses_shared_stream_query_helper() {
        let stream = include_str!("stream.rs");
        let graph_replay = include_str!("backend/cuda_graph_replay.rs");
        let query_ffi = concat!("cudarc::driver::sys::", "cuStreamQuery(");

        assert_eq!(
            stream.matches(query_ffi).count(),
            1,
            "Fix: raw CUDA stream query must stay behind query_raw_stream_ready."
        );
        assert_eq!(
            graph_replay.matches(query_ffi).count(),
            0,
            "Fix: CUDA graph replay must use query_raw_stream_ready instead of raw cuStreamQuery."
        );
        assert!(
            graph_replay.contains("query_raw_stream_ready")
                && stream.contains("fn query_raw_stream_ready("),
            "Fix: graph replay polling must use the shared stream query helper."
        );
    }
}
990
991
992impl Drop for CudaPendingDispatch {
993 fn drop(&mut self) {
994 if !self.completed.load(Ordering::Acquire) {
995 if let Err(error) = self.ctx.bind_to_thread() {
996 tracing::error!(
997 "Fix: failed to bind CUDA context while dropping pending dispatch: {error}. Dispatch completion could not be forced."
998 );
999 }
1000 if let Some(stream) = self.stream.as_ref() {
1001 if let Err(error) = stream.synchronize() {
1002 tracing::error!(
1003 "Fix: failed to synchronize CUDA stream while dropping pending dispatch: {error}. Dispatch completion state may be stale."
1004 );
1005 } else {
1006 self.telemetry.record_sync_point();
1007 }
1008 }
1009 self.completed.store(true, Ordering::Release);
1010 }
1011 self.release_launch_resources();
1012 self.allocations.take();
1013 self.resident_use.take();
1014 self.host_transfers.take();
1015 }
1016}
1017