Skip to main content

vyre_driver_cuda/
stream.rs

1//! CUDA stream/event ownership and pending-dispatch handles.
2
3use std::ptr::NonNull;
4use std::sync::{
5    atomic::{AtomicBool, Ordering},
6    Arc,
7};
8
9use crossbeam_queue::ArrayQueue;
10use cudarc::driver::{
11    sys::{CUevent, CUevent_flags, CUresult, CUstream, CUstream_flags, CUstream_st},
12    CudaContext,
13};
14use vyre_driver::{backend::private, BackendError, PendingDispatch};
15
16use crate::backend::telemetry::CudaTelemetry;
17use crate::backend::{cuda_check, DispatchAllocations, HostTransferAllocations, ResidentUseGuard};
18
19/// RAII owner for a CUDA stream.
20#[derive(Debug)]
21pub(crate) struct CudaStream {
22    raw: CUstream,
23}
24
25unsafe impl Send for CudaStream {}
26unsafe impl Sync for CudaStream {}
27
28impl CudaStream {
29    /// Create a non-blocking CUDA stream.
30    pub(crate) fn non_blocking() -> Result<Self, BackendError> {
31        let raw = create_non_blocking_raw_stream("cuStreamCreate")?;
32        Ok(Self { raw: raw.as_ptr() })
33    }
34
35    /// Raw CUDA stream handle.
36    #[must_use]
37    pub(crate) fn raw(&self) -> CUstream {
38        self.raw
39    }
40
41    /// Block until stream work has completed.
42    pub(crate) fn synchronize(&self) -> Result<(), BackendError> {
43        synchronize_raw_stream(self.raw, "cuStreamSynchronize")
44    }
45}
46
47/// Create a non-blocking raw CUDA stream and reject impossible null-success
48/// driver responses before callers can accidentally fall back to stream 0.
49pub(crate) fn create_non_blocking_raw_stream(
50    label: &'static str,
51) -> Result<NonNull<CUstream_st>, BackendError> {
52    let mut raw = std::ptr::null_mut();
53    // SAFETY: raw is a valid CUDA stream out-pointer; cuda_check converts
54    // non-success CUresult values into BackendError.
55    unsafe {
56        cuda_check(
57            cudarc::driver::sys::cuStreamCreate(
58                &mut raw,
59                CUstream_flags::CU_STREAM_NON_BLOCKING as u32,
60            ),
61            label,
62        )?;
63    }
64    NonNull::new(raw).ok_or_else(|| BackendError::DispatchFailed {
65        code: None,
66        message: format!(
67            "{label} returned a null stream after reporting success. Fix: update the CUDA driver or disable the CUDA path using this stream."
68        ),
69    })
70}
71
72pub(crate) fn destroy_raw_stream(stream: CUstream, label: &'static str) {
73    if stream.is_null() {
74        return;
75    }
76    // SAFETY: stream is a CUDA stream handle owned by the caller; destroy is
77    // best-effort because this function is used from Drop paths.
78    unsafe {
79        let result = cudarc::driver::sys::cuStreamDestroy_v2(stream);
80        if result != CUresult::CUDA_SUCCESS {
81            tracing::error!(
82                "Fix: {label} failed during CUDA stream drop with {result:?}; ensure pending work is synchronized before dropping dispatch resources."
83            );
84        }
85    }
86}
87
88/// Query a raw CUDA stream without falling back to CUDA's legacy null-stream
89/// semantics.
90pub(crate) fn query_raw_stream_ready(
91    stream: CUstream,
92    label: &'static str,
93) -> Result<bool, BackendError> {
94    if stream.is_null() {
95        return Err(BackendError::InvalidProgram {
96            fix: format!(
97                "Fix: {label} received a null CUDA stream; use a backend-owned non-blocking stream instead of querying CUDA's legacy default stream."
98            ),
99        });
100    }
101    // SAFETY: CUDA validates the opaque stream handle and reports readiness
102    // through CUresult.
103    let result = unsafe { cudarc::driver::sys::cuStreamQuery(stream) };
104    match result {
105        CUresult::CUDA_SUCCESS => Ok(true),
106        CUresult::CUDA_ERROR_NOT_READY => Ok(false),
107        other => cuda_check(other, label).map(|()| true),
108    }
109}
110
111/// Synchronize a raw CUDA stream without ever falling through to the legacy
112/// null-stream global fence.
113pub(crate) fn synchronize_raw_stream(
114    stream: CUstream,
115    label: &'static str,
116) -> Result<(), BackendError> {
117    if stream.is_null() {
118        return Err(BackendError::InvalidProgram {
119            fix: format!(
120                "Fix: {label} received a null CUDA stream; use a backend-owned non-blocking stream instead of the legacy default stream."
121            ),
122        });
123    }
124    // SAFETY: CUDA validates the opaque stream handle and returns a CUresult;
125    // `cuda_check` converts non-success into a typed backend error.
126    unsafe { cuda_check(cudarc::driver::sys::cuStreamSynchronize(stream), label) }
127}
128
129impl Drop for CudaStream {
130    fn drop(&mut self) {
131        destroy_raw_stream(self.raw, "cuStreamDestroy_v2");
132    }
133}
134
135/// RAII owner for a CUDA event used as the completion fence.
136#[derive(Debug)]
137pub(crate) struct CudaEvent {
138    raw: CUevent,
139}
140
141unsafe impl Send for CudaEvent {}
142unsafe impl Sync for CudaEvent {}
143
144impl CudaEvent {
145    /// Create a timing-disabled CUDA event.
146    pub(crate) fn completion() -> Result<Self, BackendError> {
147        let raw = create_raw_event(
148            CUevent_flags::CU_EVENT_DISABLE_TIMING as u32,
149            "cuEventCreate",
150        )?;
151        Ok(Self { raw })
152    }
153
154    /// Create a CUDA event with timing enabled.
155    pub(crate) fn timing() -> Result<Self, BackendError> {
156        let raw = create_raw_event(0, "cuEventCreate")?;
157        Ok(Self { raw })
158    }
159
160    /// Record this event onto a stream.
161    pub(crate) fn record(&self, stream: CUstream) -> Result<(), BackendError> {
162        if self.raw.is_null() {
163            return Err(BackendError::InvalidProgram {
164                fix: "Fix: cuEventRecord received a null CUDA event; acquire a backend-owned event before recording completion.".to_string(),
165            });
166        }
167        if stream.is_null() {
168            return Err(BackendError::InvalidProgram {
169                fix: "Fix: cuEventRecord received a null CUDA stream; record events on a backend-owned non-blocking stream instead of CUDA's legacy default stream.".to_string(),
170            });
171        }
172        // SAFETY: stream / event handles are owned by &self; cuStream*/cuEvent* calls
173        // operate on those owned handles and the result is checked via cuda_check.
174        unsafe {
175            cuda_check(
176                cudarc::driver::sys::cuEventRecord(self.raw, stream),
177                "cuEventRecord",
178            )
179        }
180    }
181
182    /// Return whether all prior work in the stream has completed.
183    pub(crate) fn query_ready(&self) -> Result<bool, BackendError> {
184        if self.raw.is_null() {
185            return Err(BackendError::InvalidProgram {
186                fix: "Fix: cuEventQuery received a null CUDA event; pending dispatches must own a recorded completion event before readiness polling.".to_string(),
187            });
188        }
189        // SAFETY: event handle is owned by &self and non-null. CUDA reports
190        // readiness or a typed driver error via CUresult.
191        let result = unsafe { cudarc::driver::sys::cuEventQuery(self.raw) };
192        match result {
193            CUresult::CUDA_SUCCESS => Ok(true),
194            CUresult::CUDA_ERROR_NOT_READY => Ok(false),
195            other => cuda_check(other, "cuEventQuery").map(|()| true),
196        }
197    }
198
199    /// Block until the event completes.
200    pub(crate) fn synchronize(&self) -> Result<(), BackendError> {
201        if self.raw.is_null() {
202            return Err(BackendError::InvalidProgram {
203                fix: "Fix: cuEventSynchronize received a null CUDA event; pending dispatches must own a recorded completion event before synchronization.".to_string(),
204            });
205        }
206        // SAFETY: stream / event handles are owned by &self; cuStream*/cuEvent* calls
207        // operate on those owned handles and the result is checked via cuda_check.
208        unsafe {
209            cuda_check(
210                cudarc::driver::sys::cuEventSynchronize(self.raw),
211                "cuEventSynchronize",
212            )
213        }
214    }
215
216    /// Elapsed time between two timing-enabled events, in nanoseconds.
217    pub(crate) fn elapsed_time_ns(&self, end: &CudaEvent) -> Result<u64, BackendError> {
218        if self.raw.is_null() || end.raw.is_null() {
219            return Err(BackendError::InvalidProgram {
220                fix: "Fix: cuEventElapsedTime received a null CUDA timing event; record both timing events before reading elapsed time.".to_string(),
221            });
222        }
223        let mut elapsed_ms = 0.0f32;
224        // SAFETY: both events are owned, valid CUDA event handles. CUDA returns an
225        // error if either event was not recorded or timing was disabled.
226        unsafe {
227            cuda_check(
228                cudarc::driver::sys::cuEventElapsedTime(
229                    (&mut elapsed_ms) as *mut f32,
230                    self.raw,
231                    end.raw,
232                ),
233                "cuEventElapsedTime",
234            )?;
235        }
236        let elapsed_ns = f64::from(elapsed_ms) * 1_000_000.0;
237        if !elapsed_ns.is_finite() || elapsed_ns < 0.0 || elapsed_ns > u64::MAX as f64 {
238            return Err(BackendError::InvalidProgram {
239                fix: format!(
240                    "Fix: CUDA event elapsed time {elapsed_ms} ms cannot fit u64 nanoseconds; inspect CUDA event timing and split the dispatch before telemetry overflows."
241                ),
242            });
243        }
244        crate::numeric::CUDA_NUMERIC.rounded_f64_to_u64(elapsed_ns, "event elapsed nanoseconds")
245    }
246}
247
248impl Drop for CudaEvent {
249    fn drop(&mut self) {
250        destroy_raw_event(self.raw, "cuEventDestroy_v2");
251    }
252}
253
254fn create_raw_event(flags: u32, label: &'static str) -> Result<CUevent, BackendError> {
255    let mut raw = std::ptr::null_mut();
256    // SAFETY: raw is a valid CUDA event out-pointer; cuda_check converts
257    // non-success CUresult values into BackendError.
258    unsafe {
259        cuda_check(cudarc::driver::sys::cuEventCreate(&mut raw, flags), label)?;
260    }
261    if raw.is_null() {
262        return Err(BackendError::DispatchFailed {
263            code: None,
264            message: format!(
265                "{label} returned a null event after reporting success. Fix: update the CUDA driver or disable event-backed CUDA dispatch for this device."
266            ),
267        });
268    }
269    Ok(raw)
270}
271
272fn destroy_raw_event(event: CUevent, label: &'static str) {
273    if event.is_null() {
274        return;
275    }
276    // SAFETY: event is a CUDA event handle owned by the caller; destroy is
277    // best-effort because this function is used from Drop paths.
278    unsafe {
279        let result = cudarc::driver::sys::cuEventDestroy_v2(event);
280        if result != CUresult::CUDA_SUCCESS {
281            tracing::error!(
282                "Fix: {label} failed during CUDA event drop with {result:?}; ensure pending work is synchronized before dropping dispatch resources."
283            );
284        }
285    }
286}
287
288/// Cached CUDA launch resources for repeated dispatches.
289#[derive(Debug)]
290pub(crate) struct CudaLaunchResourcePool {
291    streams: ArrayQueue<CudaStream>,
292    events: ArrayQueue<CudaEvent>,
293    timing_events: ArrayQueue<CudaEvent>,
294}
295
/// Snapshot of how many launch resources a CUDA launch-resource pool currently
/// holds for reuse.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CudaLaunchResourceCounts {
    /// Idle non-blocking CUDA streams.
    pub streams: usize,
    /// Idle completion-fence (timing-disabled) CUDA events.
    pub completion_events: usize,
    /// Idle timing-enabled CUDA events used by graph replay telemetry.
    pub timing_events: usize,
}
306
307/// Owned lease for launch resources before they are transferred into a pending dispatch.
308#[derive(Debug)]
309pub(crate) struct CudaLaunchResourceLease {
310    pool: Arc<CudaLaunchResourcePool>,
311    stream: Option<CudaStream>,
312    timing_events: Option<(CudaEvent, CudaEvent)>,
313}
314
315/// Owned lease for a timing-event pair used outside normal launch-resource ownership.
316#[derive(Debug)]
317pub(crate) struct CudaTimingEventPairLease {
318    pool: Arc<CudaLaunchResourcePool>,
319    timing_events: Option<(CudaEvent, CudaEvent)>,
320}
321
322impl CudaTimingEventPairLease {
323    pub(crate) fn acquire(pool: Arc<CudaLaunchResourcePool>) -> Result<Self, BackendError> {
324        let timing_events = pool.acquire_timing_event_pair()?;
325        Ok(Self {
326            pool,
327            timing_events: Some(timing_events),
328        })
329    }
330
331    pub(crate) fn events(&self) -> Result<&(CudaEvent, CudaEvent), BackendError> {
332        self.timing_events
333            .as_ref()
334            .ok_or_else(|| BackendError::InvalidProgram {
335                fix: "Fix: CUDA timing event pair lease was already consumed; acquire a fresh timing lease before recording graph replay events.".to_string(),
336            })
337    }
338}
339
340impl Drop for CudaTimingEventPairLease {
341    fn drop(&mut self) {
342        if let Some((start, end)) = self.timing_events.take() {
343            self.pool.release_timing_event(start);
344            self.pool.release_timing_event(end);
345        }
346    }
347}
348
349impl CudaLaunchResourceLease {
350    pub(crate) fn acquire(
351        pool: Arc<CudaLaunchResourcePool>,
352        capture_timing: bool,
353    ) -> Result<Self, BackendError> {
354        let stream = pool.acquire_stream()?;
355        let timing_events = if capture_timing {
356            match pool.acquire_timing_event_pair() {
357                Ok(pair) => Some(pair),
358                Err(error) => {
359                    pool.release_stream(stream);
360                    return Err(error);
361                }
362            }
363        } else {
364            None
365        };
366        Ok(Self {
367            pool,
368            stream: Some(stream),
369            timing_events,
370        })
371    }
372
373    pub(crate) fn stream_raw(&self) -> Result<CUstream, BackendError> {
374        self.stream
375            .as_ref()
376            .map(CudaStream::raw)
377            .ok_or_else(|| BackendError::InvalidProgram {
378                fix: "Fix: CUDA launch resource lease stream was already consumed; acquire a fresh launch-resource lease before enqueueing CUDA work.".to_string(),
379            })
380    }
381
382    pub(crate) fn timing_events(&self) -> Result<Option<&(CudaEvent, CudaEvent)>, BackendError> {
383        if self.stream.is_none() {
384            return Err(BackendError::InvalidProgram {
385                fix: "Fix: CUDA launch resource lease timing events were queried after the stream was consumed; query timing events before transferring the lease into a pending dispatch.".to_string(),
386            });
387        }
388        Ok(self.timing_events.as_ref())
389    }
390
391    pub(crate) fn into_parts(
392        mut self,
393    ) -> Result<(CudaStream, Option<(CudaEvent, CudaEvent)>), BackendError> {
394        let stream = self.stream.take().ok_or_else(|| BackendError::InvalidProgram {
395            fix: "Fix: CUDA launch resource lease stream was already consumed; pending dispatch ownership cannot be built twice from the same lease.".to_string(),
396        })?;
397        let timing_events = self.timing_events.take();
398        Ok((stream, timing_events))
399    }
400}
401
402impl Drop for CudaLaunchResourceLease {
403    fn drop(&mut self) {
404        if let Some((start, end)) = self.timing_events.take() {
405            self.pool.release_timing_event(start);
406            self.pool.release_timing_event(end);
407        }
408        if let Some(stream) = self.stream.take() {
409            self.pool.release_stream(stream);
410        }
411    }
412}
413
414impl CudaLaunchResourcePool {
415    pub(crate) fn new(max_cached: usize) -> Self {
416        let max_cached = max_cached.max(1);
417        Self {
418            streams: ArrayQueue::new(max_cached),
419            events: ArrayQueue::new(max_cached),
420            timing_events: ArrayQueue::new(max_cached),
421        }
422    }
423
424    pub(crate) fn acquire_stream(&self) -> Result<CudaStream, BackendError> {
425        if let Some(stream) = self.streams.pop() {
426            return Ok(stream);
427        }
428        CudaStream::non_blocking()
429    }
430
431    pub(crate) fn acquire_event(&self) -> Result<CudaEvent, BackendError> {
432        if let Some(event) = self.events.pop() {
433            return Ok(event);
434        }
435        CudaEvent::completion()
436    }
437
438    pub(crate) fn acquire_timing_event(&self) -> Result<CudaEvent, BackendError> {
439        if let Some(event) = self.timing_events.pop() {
440            return Ok(event);
441        }
442        CudaEvent::timing()
443    }
444
445    pub(crate) fn acquire_timing_event_pair(&self) -> Result<(CudaEvent, CudaEvent), BackendError> {
446        let start = self.acquire_timing_event()?;
447        match self.acquire_timing_event() {
448            Ok(end) => Ok((start, end)),
449            Err(error) => {
450                self.release_timing_event(start);
451                Err(error)
452            }
453        }
454    }
455
456    pub(crate) fn release_stream(&self, stream: CudaStream) {
457        if let Err(stream) = self.streams.push(stream) {
458            drop(stream);
459        }
460    }
461
462    pub(crate) fn release_event(&self, event: CudaEvent) {
463        if let Err(event) = self.events.push(event) {
464            drop(event);
465        }
466    }
467
468    pub(crate) fn release_timing_event(&self, event: CudaEvent) {
469        if let Err(event) = self.timing_events.push(event) {
470            drop(event);
471        }
472    }
473
474    pub(crate) fn cached_counts(&self) -> Result<(usize, usize), BackendError> {
475        Ok((self.streams.len(), self.events.len()))
476    }
477
478    pub(crate) fn cached_counts_detailed(&self) -> Result<CudaLaunchResourceCounts, BackendError> {
479        Ok(CudaLaunchResourceCounts {
480            streams: self.streams.len(),
481            completion_events: self.events.len(),
482            timing_events: self.timing_events.len(),
483        })
484    }
485
486    pub(crate) fn clear(&self) -> Result<(), BackendError> {
487        while self.streams.pop().is_some() {}
488        while self.events.pop().is_some() {}
489        while self.timing_events.pop().is_some() {}
490        Ok(())
491    }
492}
493
494/// CUDA-backed pending dispatch whose result is fenced by a CUDA event.
495#[derive(Debug)]
496pub(crate) struct CudaPendingDispatch {
497    ctx: Arc<CudaContext>,
498    pool: Arc<CudaLaunchResourcePool>,
499    event: Option<CudaEvent>,
500    stream: Option<CudaStream>,
501    allocations: Option<DispatchAllocations>,
502    resident_use: Option<ResidentUseGuard>,
503    host_transfers: Option<HostTransferAllocations>,
504    outputs: Vec<Vec<u8>>,
505    timing_start: Option<CudaEvent>,
506    timing_end: Option<CudaEvent>,
507    ready_device_ns: Option<u64>,
508    telemetry: Arc<CudaTelemetry>,
509    completed: AtomicBool,
510}
511
512
513impl CudaPendingDispatch {
514    /// Build an already-completed pending dispatch.
515    pub(crate) fn new_ready(
516        ctx: Arc<CudaContext>,
517        pool: Arc<CudaLaunchResourcePool>,
518        outputs: Vec<Vec<u8>>,
519        telemetry: Arc<CudaTelemetry>,
520    ) -> Self {
521        Self {
522            ctx,
523            pool,
524            event: None,
525            stream: None,
526            allocations: None,
527            resident_use: None,
528            host_transfers: None,
529            outputs,
530            timing_start: None,
531            timing_end: None,
532            ready_device_ns: None,
533            telemetry,
534            completed: AtomicBool::new(true),
535        }
536    }
537
538    /// Build an already-completed pending dispatch with measured device time.
539    pub(crate) fn new_ready_timed(
540        ctx: Arc<CudaContext>,
541        pool: Arc<CudaLaunchResourcePool>,
542        outputs: Vec<Vec<u8>>,
543        device_ns: Option<u64>,
544        telemetry: Arc<CudaTelemetry>,
545    ) -> Self {
546        Self {
547            ctx,
548            pool,
549            event: None,
550            stream: None,
551            allocations: None,
552            resident_use: None,
553            host_transfers: None,
554            outputs,
555            timing_start: None,
556            timing_end: None,
557            ready_device_ns: device_ns,
558            telemetry,
559            completed: AtomicBool::new(true),
560        }
561    }
562
563    /// Build a pending resident batch dispatch with no host output slots.
564    ///
565    /// Resident batch readback uses caller-owned resident handles; the pending
566    /// dispatch only fences parameter uploads and kernel launches.
567    #[allow(clippy::too_many_arguments)]
568    pub(crate) fn new_resident_batch_pending(
569        ctx: Arc<CudaContext>,
570        pool: Arc<CudaLaunchResourcePool>,
571        event: CudaEvent,
572        stream: CudaStream,
573        allocations: DispatchAllocations,
574        resident_use: ResidentUseGuard,
575        host_transfers: HostTransferAllocations,
576        telemetry: Arc<CudaTelemetry>,
577    ) -> Self {
578        Self::new(
579            ctx,
580            pool,
581            event,
582            stream,
583            allocations,
584            Some(resident_use),
585            Some(host_transfers),
586            Vec::new(),
587            telemetry,
588        )
589    }
590
591    /// Build a pending dispatch after all GPU work has been enqueued.
592    #[allow(clippy::too_many_arguments)]
593    pub(crate) fn new(
594        ctx: Arc<CudaContext>,
595        pool: Arc<CudaLaunchResourcePool>,
596        event: CudaEvent,
597        stream: CudaStream,
598        allocations: DispatchAllocations,
599        resident_use: Option<ResidentUseGuard>,
600        host_transfers: Option<HostTransferAllocations>,
601        outputs: Vec<Vec<u8>>,
602        telemetry: Arc<CudaTelemetry>,
603    ) -> Self {
604        Self {
605            ctx,
606            pool,
607            event: Some(event),
608            stream: Some(stream),
609            allocations: Some(allocations),
610            resident_use,
611            host_transfers,
612            outputs,
613            timing_start: None,
614            timing_end: None,
615            ready_device_ns: None,
616            telemetry,
617            completed: AtomicBool::new(false),
618        }
619    }
620
621    /// Build a pending dispatch with timing-enabled start/end events.
622    #[allow(clippy::too_many_arguments)]
623    pub(crate) fn new_with_timing(
624        ctx: Arc<CudaContext>,
625        pool: Arc<CudaLaunchResourcePool>,
626        event: CudaEvent,
627        stream: CudaStream,
628        allocations: DispatchAllocations,
629        resident_use: Option<ResidentUseGuard>,
630        host_transfers: Option<HostTransferAllocations>,
631        outputs: Vec<Vec<u8>>,
632        timing_start: CudaEvent,
633        timing_end: CudaEvent,
634        telemetry: Arc<CudaTelemetry>,
635    ) -> Self {
636        Self {
637            ctx,
638            pool,
639            event: Some(event),
640            stream: Some(stream),
641            allocations: Some(allocations),
642            resident_use,
643            host_transfers,
644            outputs,
645            timing_start: Some(timing_start),
646            timing_end: Some(timing_end),
647            ready_device_ns: None,
648            telemetry,
649            completed: AtomicBool::new(false),
650        }
651    }
652
653    fn bind_context(&self) -> Result<(), BackendError> {
654        self.ctx
655            .bind_to_thread()
656            .map_err(|e| BackendError::DispatchFailed {
657                code: None,
658                message: format!("CUDA context bind failed: {e}"),
659            })
660    }
661
662    fn synchronize(&self) -> Result<(), BackendError> {
663        if self.completed.load(Ordering::Acquire) {
664            return Ok(());
665        }
666        self.bind_context()?;
667        let event = self
668            .event
669            .as_ref()
670            .ok_or_else(|| BackendError::DispatchFailed {
671                code: None,
672                message: "CUDA pending dispatch completion event was already released".to_string(),
673            })?;
674        event.synchronize()?;
675        self.telemetry.record_sync_point();
676        self.completed.store(true, Ordering::Release);
677        Ok(())
678    }
679
680    fn release_launch_resources(&mut self) {
681        if let Some(event) = self.event.take() {
682            self.pool.release_event(event);
683        }
684        if let Some(event) = self.timing_start.take() {
685            self.pool.release_timing_event(event);
686        }
687        if let Some(event) = self.timing_end.take() {
688            self.pool.release_timing_event(event);
689        }
690        if let Some(stream) = self.stream.take() {
691            self.pool.release_stream(stream);
692        }
693    }
694
695    /// Await completion and return output buffers plus device elapsed time.
696    pub(crate) fn await_timed_result(
697        mut self,
698    ) -> Result<(Vec<Vec<u8>>, Option<u64>), BackendError> {
699        self.synchronize()?;
700        let device_ns = match self.ready_device_ns.take() {
701            Some(device_ns) => Some(device_ns),
702            None => match (self.timing_start.as_ref(), self.timing_end.as_ref()) {
703                (Some(start), Some(end)) => Some(start.elapsed_time_ns(end)?),
704                _ => None,
705            },
706        };
707        self.release_launch_resources();
708        self.allocations.take();
709        self.resident_use.take();
710        let outputs = self.collect_outputs()?;
711        self.host_transfers.take();
712        Ok((outputs, device_ns))
713    }
714
715    fn collect_outputs(&mut self) -> Result<Vec<Vec<u8>>, BackendError> {
716        if let Some(transfers) = self.host_transfers.as_ref() {
717            let mut outputs = std::mem::take(&mut self.outputs);
718            transfers.collect_outputs_into(&mut outputs)?;
719            Ok(outputs)
720        } else {
721            Ok(std::mem::take(&mut self.outputs))
722        }
723    }
724
725    fn collect_outputs_into(&mut self, outputs: &mut Vec<Vec<u8>>) -> Result<(), BackendError> {
726        if let Some(transfers) = self.host_transfers.as_ref() {
727            transfers.collect_outputs_into(outputs)?;
728        } else {
729            vyre_driver::replace_output_buffers_preserving_slots(
730                std::mem::take(&mut self.outputs),
731                outputs,
732            );
733        }
734        Ok(())
735    }
736}
737
738impl private::Sealed for CudaPendingDispatch {}
739
740impl PendingDispatch for CudaPendingDispatch {
741    fn is_ready(&self) -> bool {
742        if self.completed.load(Ordering::Acquire) {
743            return true;
744        }
745        if self.bind_context().is_err() {
746            return false;
747        }
748        let Some(event) = self.event.as_ref() else {
749            return true;
750        };
751        let ready = match event.query_ready() {
752            Ok(ready) => ready,
753            Err(error) => {
754                tracing::error!(
755                    "Fix: CUDA pending dispatch readiness query failed: {error}. Await the dispatch to surface synchronization failure details."
756                );
757                false
758            }
759        };
760        if ready {
761            self.completed.store(true, Ordering::Release);
762        }
763        ready
764    }
765
766    fn await_result(mut self: Box<Self>) -> Result<Vec<Vec<u8>>, BackendError> {
767        self.synchronize()?;
768        self.release_launch_resources();
769        self.allocations.take();
770        self.resident_use.take();
771        let outputs = self.collect_outputs()?;
772        self.host_transfers.take();
773        Ok(outputs)
774    }
775
776    fn await_result_into(
777        mut self: Box<Self>,
778        outputs: &mut Vec<Vec<u8>>,
779    ) -> Result<(), BackendError> {
780        self.synchronize()?;
781        self.release_launch_resources();
782        self.allocations.take();
783        self.resident_use.take();
784        self.collect_outputs_into(outputs)?;
785        self.host_transfers.take();
786        Ok(())
787    }
788}
789
790#[cfg(test)]
791mod tests {
792    use super::{query_raw_stream_ready, synchronize_raw_stream, CudaLaunchResourcePool};
793
794    #[test]
795    fn launch_resource_leases_do_not_panic_on_consumed_state() {
796        let source = include_str!("stream.rs");
797        assert!(
798            !source.contains(concat!(".expect", "(\"Fix: CUDA launch resource lease stream was already consumed")),
799            "Fix: CUDA launch resource leases must return typed backend errors when consumed twice, not panic."
800        );
801        assert!(
802            !source.contains(concat!(".expect", "(\"Fix: CUDA timing event pair lease was already consumed")),
803            "Fix: CUDA graph replay timing leases must return typed backend errors when consumed twice, not panic."
804        );
805    }
806
807    #[test]
808    fn launch_resource_counts_include_timing_events() {
809        let pool = CudaLaunchResourcePool::new(8);
810        let counts = pool
811            .cached_counts_detailed()
812            .expect("Fix: empty launch resource pool counts should be readable");
813
814        assert_eq!(counts.streams, 0);
815        assert_eq!(counts.completion_events, 0);
816        assert_eq!(counts.timing_events, 0);
817
818        let source = include_str!("stream.rs");
819        assert!(
820            source.contains("pub struct CudaLaunchResourceCounts")
821                && source.contains("pub timing_events: usize")
822                && source.contains("cached_counts_detailed"),
823            "Fix: CUDA launch-resource telemetry must expose timing-event cache pressure, not just streams and completion events."
824        );
825    }
826
827    #[test]
828    fn raw_stream_sync_rejects_null_default_stream() {
829        let err = synchronize_raw_stream(std::ptr::null_mut(), "unit sync")
830            .expect_err("Fix: raw stream sync must reject the legacy null stream");
831        assert!(
832            err.to_string().contains("null CUDA stream"),
833            "raw sync diagnostic must explain the default-stream hazard: {err}"
834        );
835    }
836
837    #[test]
838    fn raw_stream_query_rejects_null_default_stream() {
839        let err = query_raw_stream_ready(std::ptr::null_mut(), "unit query")
840            .expect_err("Fix: raw stream query must reject the legacy null stream");
841        assert!(
842            err.to_string().contains("null CUDA stream"),
843            "raw query diagnostic must explain the default-stream hazard: {err}"
844        );
845    }
846
847    #[test]
848    fn event_record_rejects_null_event_before_ffi() {
849        let event = super::CudaEvent {
850            raw: std::ptr::null_mut(),
851        };
852        let err = event
853            .record(std::ptr::null_mut())
854            .expect_err("Fix: event recording must reject invalid event handles before FFI");
855        assert!(
856            err.to_string().contains("null CUDA event"),
857            "event record diagnostic must explain the null-event hazard: {err}"
858        );
859    }
860
861    #[test]
862    fn event_record_rejects_null_default_stream_before_ffi() {
863        let event = std::mem::ManuallyDrop::new(super::CudaEvent {
864            raw: std::ptr::NonNull::<cudarc::driver::sys::CUevent_st>::dangling().as_ptr(),
865        });
866        let err = event
867            .record(std::ptr::null_mut())
868            .expect_err("Fix: event recording must reject CUDA's legacy null stream before FFI");
869        assert!(
870            err.to_string().contains("null CUDA stream"),
871            "event record diagnostic must explain the default-stream hazard: {err}"
872        );
873    }
874
875    #[test]
876    fn event_query_and_sync_reject_null_event_before_ffi() {
877        let event = super::CudaEvent {
878            raw: std::ptr::null_mut(),
879        };
880        let query_err = event
881            .query_ready()
882            .expect_err("Fix: event readiness query must reject null events before FFI");
883        assert!(
884            query_err.to_string().contains("null CUDA event"),
885            "event query diagnostic must explain the null-event hazard: {query_err}"
886        );
887
888        let sync_err = event
889            .synchronize()
890            .expect_err("Fix: event synchronize must reject null events before FFI");
891        assert!(
892            sync_err.to_string().contains("null CUDA event"),
893            "event sync diagnostic must explain the null-event hazard: {sync_err}"
894        );
895    }
896
897    #[test]
898    fn event_elapsed_time_rejects_null_timing_event_before_ffi() {
899        let event = super::CudaEvent {
900            raw: std::ptr::null_mut(),
901        };
902        let err = event
903            .elapsed_time_ns(&event)
904            .expect_err("Fix: elapsed timing must reject null events before FFI");
905        assert!(
906            err.to_string().contains("null CUDA timing event"),
907            "event elapsed diagnostic must explain the null-event hazard: {err}"
908        );
909    }
910
911    #[test]
912    fn stream_lifecycle_ffi_is_single_sourced_for_graph_capture() {
913        let stream = include_str!("stream.rs");
914        let cuda_graph = include_str!("backend/cuda_graph.rs");
915        let create_ffi = concat!("cudarc::driver::sys::", "cuStreamCreate(");
916        let destroy_ffi = concat!("cudarc::driver::sys::", "cuStreamDestroy_v2(");
917
918        assert_eq!(
919            stream.matches(create_ffi).count(),
920            1,
921            "Fix: raw CUDA stream creation must stay behind create_non_blocking_raw_stream."
922        );
923        assert_eq!(
924            stream.matches(destroy_ffi).count(),
925            1,
926            "Fix: raw CUDA stream destruction must stay behind destroy_raw_stream."
927        );
928        assert_eq!(
929            cuda_graph.matches(create_ffi).count() + cuda_graph.matches(destroy_ffi).count(),
930            0,
931            "Fix: cudaGraph capture must use the shared stream lifecycle helpers instead of direct stream FFI."
932        );
933        assert!(
934            stream.contains("fn create_non_blocking_raw_stream(")
935                && stream.contains("returned a null stream after reporting success")
936                && cuda_graph.contains("create_non_blocking_raw_stream"),
937            "Fix: shared CUDA stream creation must reject null-success handles and be used by cudaGraph."
938        );
939    }
940
941    #[test]
942    fn event_lifecycle_ffi_is_single_sourced() {
943        let stream = include_str!("stream.rs");
944        let create_ffi = concat!("cudarc::driver::sys::", "cuEventCreate(");
945        let destroy_ffi = concat!("cudarc::driver::sys::", "cuEventDestroy_v2(");
946
947        assert_eq!(
948            stream.matches(create_ffi).count(),
949            1,
950            "Fix: raw CUDA event creation must stay behind create_raw_event."
951        );
952        assert_eq!(
953            stream.matches(destroy_ffi).count(),
954            1,
955            "Fix: raw CUDA event destruction must stay behind destroy_raw_event."
956        );
957        assert!(
958            stream.contains("fn create_raw_event(")
959                && stream.contains("returned a null event after reporting success")
960                && stream.contains("fn destroy_raw_event(")
961                && stream.contains("CudaEvent::completion")
962                && stream.contains("CudaEvent::timing"),
963            "Fix: CUDA event lifecycle must use shared create/destroy helpers with null-success validation."
964        );
965    }
966
967    #[test]
968    fn graph_replay_uses_shared_stream_query_helper() {
969        let stream = include_str!("stream.rs");
970        let graph_replay = include_str!("backend/cuda_graph_replay.rs");
971        let query_ffi = concat!("cudarc::driver::sys::", "cuStreamQuery(");
972
973        assert_eq!(
974            stream.matches(query_ffi).count(),
975            1,
976            "Fix: raw CUDA stream query must stay behind query_raw_stream_ready."
977        );
978        assert_eq!(
979            graph_replay.matches(query_ffi).count(),
980            0,
981            "Fix: CUDA graph replay must use query_raw_stream_ready instead of raw cuStreamQuery."
982        );
983        assert!(
984            graph_replay.contains("query_raw_stream_ready")
985                && stream.contains("fn query_raw_stream_ready("),
986            "Fix: graph replay polling must use the shared stream query helper."
987        );
988    }
989}
990
991
992impl Drop for CudaPendingDispatch {
993    fn drop(&mut self) {
994        if !self.completed.load(Ordering::Acquire) {
995            if let Err(error) = self.ctx.bind_to_thread() {
996                tracing::error!(
997                    "Fix: failed to bind CUDA context while dropping pending dispatch: {error}. Dispatch completion could not be forced."
998                );
999            }
1000            if let Some(stream) = self.stream.as_ref() {
1001                if let Err(error) = stream.synchronize() {
1002                    tracing::error!(
1003                        "Fix: failed to synchronize CUDA stream while dropping pending dispatch: {error}. Dispatch completion state may be stale."
1004                    );
1005                } else {
1006                    self.telemetry.record_sync_point();
1007                }
1008            }
1009            self.completed.store(true, Ordering::Release);
1010        }
1011        self.release_launch_resources();
1012        self.allocations.take();
1013        self.resident_use.take();
1014        self.host_transfers.take();
1015    }
1016}
1017