Skip to main content

singe_cuda/
context.rs

1use std::{ffi::CString, mem, ptr, sync::Arc};
2
3use singe_cuda_sys::driver;
4
5use crate::{
6    device::Device,
7    error::{Error, Result},
8    graph::Graph,
9    jit::JitOptions,
10    library::Library,
11    module::{Module, ModuleImage},
12    nvrtc::{self, CompilationArtifact, OutputKind},
13    try_ffi,
14    types::Limit,
15};
16
17bitflags::bitflags! {
18    /// Context creation flags.
19    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
20    pub struct ContextFlags: u32 {
21        const SCHEDULE_AUTO = driver::CUctx_flags::CU_CTX_SCHED_AUTO as _;
22        const SCHEDULE_SPIN = driver::CUctx_flags::CU_CTX_SCHED_SPIN as _;
23        const SCHEDULE_YIELD = driver::CUctx_flags::CU_CTX_SCHED_YIELD as _;
24        const SCHEDULE_BLOCKING_SYNC = driver::CUctx_flags::CU_CTX_SCHED_BLOCKING_SYNC as _;
25        const MAP_HOST = driver::CUctx_flags::CU_CTX_MAP_HOST as _;
26        const LOCAL_MEMORY_RESIZE_TO_MAX = driver::CUctx_flags::CU_CTX_LMEM_RESIZE_TO_MAX as _;
27        const COREDUMP_ENABLE = driver::CUctx_flags::CU_CTX_COREDUMP_ENABLE as _;
28        const USER_COREDUMP_ENABLE = driver::CUctx_flags::CU_CTX_USER_COREDUMP_ENABLE as _;
29        const SYNC_MEMORY_OPERATIONS = driver::CUctx_flags::CU_CTX_SYNC_MEMOPS as _;
30    }
31}
32
33/// A shared CUDA driver context.
34///
35/// Unlike cuBLAS, cuDNN, cuFFT, and similar library handles, a CUDA context is
36/// the underlying execution environment for a device. It is intended to be
37/// shared by streams, modules, libraries, events, allocations, and higher-level
38/// library wrappers.
39///
40/// This type is therefore reference-counted by returning [`Arc<Self>`] from the
41/// constructors, and it remains `Send + Sync`. Shared references do not mutate
42/// Rust-visible state on the [`Context`] object itself; methods such as `bind`
43/// update the calling thread's current CUDA context in the driver.
44///
45/// Prefer one long-lived context per device and share it across dependent CUDA
46/// objects instead of creating many short-lived contexts.
47#[derive(Debug)]
48pub struct Context {
49    handle: driver::CUcontext,
50    device: Device,
51    ownership: ContextOwnership,
52}
53
/// How a [`Context`] releases its driver handle when dropped.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ContextOwnership {
    /// The handle is destroyed with `cuCtxDestroy_v2`.
    Created,
    /// The device's primary context is released with
    /// `cuDevicePrimaryCtxRelease_v2`.
    Primary,
}
59
/// Public counterpart of the internal ownership mode, exchanged through
/// [`Context::from_raw`] and [`Context::into_raw_parts`].
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RawContextOwnership {
    /// The context must be destroyed with `cuCtxDestroy`.
    Created,
    /// The context is a retained primary context and must be released with
    /// `cuDevicePrimaryCtxRelease`.
    Primary,
}
66
67impl From<RawContextOwnership> for ContextOwnership {
68    fn from(value: RawContextOwnership) -> Self {
69        match value {
70            RawContextOwnership::Created => Self::Created,
71            RawContextOwnership::Primary => Self::Primary,
72        }
73    }
74}
75
76impl From<ContextOwnership> for RawContextOwnership {
77    fn from(value: ContextOwnership) -> Self {
78        match value {
79            ContextOwnership::Created => Self::Created,
80            ContextOwnership::Primary => Self::Primary,
81        }
82    }
83}
84
85impl Context {
86    pub fn create() -> Result<Arc<Self>> {
87        Self::create_with_flags(ContextFlags::empty())
88    }
89
90    pub fn create_with_flags(flags: ContextFlags) -> Result<Arc<Self>> {
91        let device = Device::current()?;
92        Self::create_for_device_with_flags(device, flags)
93    }
94
95    pub fn create_for_device(device: Device) -> Result<Arc<Self>> {
96        Self::create_for_device_with_flags(device, ContextFlags::empty())
97    }
98
99    pub fn create_for_device_with_flags(device: Device, flags: ContextFlags) -> Result<Arc<Self>> {
100        unsafe {
101            try_ffi!(driver::cuInit(0))?;
102
103            let mut handle = ptr::null_mut();
104            try_ffi!(driver::cuCtxCreate_v4(
105                &raw mut handle,
106                ptr::null_mut(), // CUctxCreateParams
107                flags.bits(),
108                device.id() as _,
109            ))?;
110
111            if handle.is_null() {
112                return Err(Error::NullHandle);
113            }
114
115            Ok(Arc::new(Self {
116                handle,
117                device,
118                ownership: ContextOwnership::Created,
119            }))
120        }
121    }
122
123    pub fn retain_primary_for_device(device: Device) -> Result<Arc<Self>> {
124        unsafe {
125            try_ffi!(driver::cuInit(0))?;
126
127            let mut handle = ptr::null_mut();
128            try_ffi!(driver::cuDevicePrimaryCtxRetain(
129                &raw mut handle,
130                device.id() as _,
131            ))?;
132
133            if handle.is_null() {
134                return Err(Error::NullHandle);
135            }
136
137            try_ffi!(driver::cuCtxSetCurrent(handle))?;
138
139            Ok(Arc::new(Self {
140                handle,
141                device,
142                ownership: ContextOwnership::Primary,
143            }))
144        }
145    }
146
147    /// Binds this CUDA context to the calling CPU thread.
148    ///
149    /// The "current context" is thread-local driver state. Calling this method
150    /// does not mutate the Rust [`Context`] value itself; it makes this context
151    /// current for subsequent CUDA driver and interoperating runtime calls on
152    /// the current host thread.
153    ///
154    /// # Errors
155    ///
156    /// Returns an error if CUDA Driver cannot query or set the current context.
157    pub fn bind(&self) -> Result<()> {
158        unsafe {
159            let mut current_ctx = ptr::null_mut();
160            try_ffi!(driver::cuCtxGetCurrent(&raw mut current_ctx))?;
161            if current_ctx == self.as_raw() {
162                return Ok(());
163            }
164            try_ffi!(driver::cuCtxSetCurrent(self.as_raw()))?;
165        }
166        Ok(())
167    }
168
169    /// Loads the corresponding module from the given image into the current context.
170    /// The image may be a cubin or fatbin as output by **nvcc**, or a NUL-terminated PTX string, either as output by **nvcc** or hand-written, or Tile IR data.
171    ///
172    /// # Errors
173    ///
174    /// Returns an error if the context cannot be bound, CUDA cannot load the module, or a
175    /// previous asynchronous launch reported an error.
176    pub fn load_module(self: &Arc<Self>, image: &ModuleImage<'_>) -> Result<Module> {
177        self.bind()?;
178
179        unsafe {
180            let mut module_handle = ptr::null_mut();
181            try_ffi!(driver::cuModuleLoadData(
182                &raw mut module_handle,
183                image.as_ptr() as _,
184            ))?;
185            if module_handle.is_null() {
186                return Err(Error::NullHandle);
187            }
188            Module::from_raw(module_handle, Arc::clone(self))
189        }
190    }
191
192    /// Creates an empty CUDA graph associated with this context.
193    ///
194    /// Prefer this over [`RawGraph::create`](crate::graph::RawGraph::create)
195    /// for ordinary Singe code. The returned graph carries its context
196    /// association into instantiated executable graphs, allowing launches and
197    /// uploads to reject streams from another context before calling CUDA.
198    ///
199    /// # Errors
200    ///
201    /// Returns an error if the context cannot be bound or CUDA cannot create the graph.
202    pub fn create_graph(self: &Arc<Self>) -> Result<Graph> {
203        Graph::create_in_context(Arc::clone(self))
204    }
205
206    pub fn unload_module(self: &Arc<Self>, module: Module) -> Result<()> {
207        drop(module);
208        Ok(())
209    }
210
211    /// Loads the corresponding module from the given image into the current context.
212    /// The image may be a cubin or fatbin as output by **nvcc**, or a NUL-terminated PTX string, either as output by **nvcc** or hand-written, or Tile IR data.
213    ///
214    /// # Errors
215    ///
216    /// Returns an error if the context cannot be bound, CUDA cannot load the module, JIT options
217    /// are rejected, or a previous asynchronous launch reported an error.
218    pub fn load_module_with_options(
219        self: &Arc<Self>,
220        image: &ModuleImage<'_>,
221        mut jit_options: JitOptions<'_>,
222    ) -> Result<Module> {
223        self.bind()?;
224
225        let mut jit_options = jit_options.build();
226        unsafe {
227            let mut module_handle = ptr::null_mut();
228            try_ffi!(driver::cuModuleLoadDataEx(
229                &raw mut module_handle,
230                image.as_ptr() as _,
231                jit_options.names.len() as _,
232                jit_options.names.as_mut_ptr() as _,
233                jit_options.values.as_mut_ptr() as _,
234            ))?;
235            if module_handle.is_null() {
236                return Err(Error::NullHandle);
237            }
238            Module::from_raw(module_handle, Arc::clone(self))
239        }
240    }
241
242    pub fn load_nvrtc_module(
243        self: &Arc<Self>,
244        program: &nvrtc::Program,
245        output: OutputKind,
246    ) -> Result<Module> {
247        self.load_nvrtc_module_with_options(program, output, JitOptions::default())
248    }
249
250    pub fn load_nvrtc_module_with_options(
251        self: &Arc<Self>,
252        program: &nvrtc::Program,
253        output: OutputKind,
254        jit_options: JitOptions<'_>,
255    ) -> Result<Module> {
256        let image = module_loadable_image(program.artifact(output)?)?;
257        self.load_module_with_options(&image, jit_options)
258    }
259
260    pub fn load_library(self: &Arc<Self>, image: &ModuleImage<'_>) -> Result<Library> {
261        self.load_library_with_options(image, JitOptions::default())
262    }
263
264    /// Loads the corresponding library from the given image based on the application defined library loading mode:
265    ///
266    /// * If module loading is set to EAGER by the environment variables described in "Module loading", the library is loaded eagerly into all contexts at the time of the call and future contexts at the time of creation until the library
267    ///   is unloaded with [`sys::cuLibraryUnload`](singe_cuda_sys::driver::cuLibraryUnload).
268    /// * If the environment variables are set to LAZY, the library is not immediately loaded into existing contexts and is loaded only when a function is needed for that context,
269    ///   such as a kernel launch.
270    ///
271    /// These environment variables are described in the CUDA programming guide under the "CUDA environment variables" section.
272    ///
273    /// The code may be a cubin or fatbin emitted by **nvcc**, a NUL-terminated PTX string emitted by **nvcc** or written by hand, or Tile IR data.
274    /// A fatbin must also contain relocatable code when doing separate compilation.
275    ///
276    /// If the library contains managed variables and no device in the system supports them, this call returns [`crate::error::Status::NotSupported`].
277    pub fn load_library_with_options(
278        self: &Arc<Self>,
279        image: &ModuleImage<'_>,
280        mut jit_options: JitOptions<'_>,
281    ) -> Result<Library> {
282        self.bind()?;
283
284        let mut jit_options = jit_options.build();
285        let mut handle = ptr::null_mut();
286        unsafe {
287            try_ffi!(driver::cuLibraryLoadData(
288                &raw mut handle,
289                image.as_ptr() as _,
290                jit_options.names.as_mut_ptr() as _,
291                jit_options.values.as_mut_ptr() as _,
292                jit_options.names.len() as _,
293                ptr::null_mut(),
294                ptr::null_mut(),
295                0,
296            ))?;
297        }
298        if handle.is_null() {
299            return Err(Error::NullHandle);
300        }
301        unsafe { Library::from_raw(handle, Arc::clone(self)) }
302    }
303
304    pub fn load_nvrtc_library(
305        self: &Arc<Self>,
306        program: &nvrtc::Program,
307        output: OutputKind,
308    ) -> Result<Library> {
309        self.load_nvrtc_library_with_options(program, output, JitOptions::default())
310    }
311
312    pub fn load_nvrtc_library_with_options(
313        self: &Arc<Self>,
314        program: &nvrtc::Program,
315        output: OutputKind,
316        jit_options: JitOptions<'_>,
317    ) -> Result<Library> {
318        let image = library_loadable_image(program.artifact(output)?)?;
319        self.load_library_with_options(&image, jit_options)
320    }
321
322    /// Loads the corresponding library from the given file based on the application defined library loading mode:
323    ///
324    /// * If module loading is set to EAGER by the environment variables described in "Module loading", the library is loaded eagerly into all contexts at the time of the call and future contexts at the time of creation until the library
325    ///   is unloaded with [`sys::cuLibraryUnload`](singe_cuda_sys::driver::cuLibraryUnload).
326    /// * If the environment variables are set to LAZY, the library is not immediately loaded into existing contexts and is loaded only when a function is needed for that context,
327    ///   such as a kernel launch.
328    ///
329    /// These environment variables are described in the CUDA programming guide under the "CUDA environment variables" section.
330    ///
331    /// The file must be a cubin emitted by **nvcc**, a PTX file emitted by **nvcc** or written by hand, a fatbin emitted by **nvcc** or written by hand, or a Tile IR file.
332    /// A fatbin must also contain relocatable code when doing separate compilation.
333    ///
334    /// If the library contains managed variables and no device in the system supports them, this call returns [`crate::error::Status::NotSupported`].
335    ///
336    /// # Errors
337    ///
338    /// Returns an error if this context cannot be bound, if `path` contains an
339    /// interior NUL byte, or if CUDA Driver cannot load the library.
340    pub fn load_library_from_file(self: &Arc<Self>, path: &str) -> Result<Library> {
341        self.bind()?;
342        let path = CString::new(path)?;
343        let mut handle = ptr::null_mut();
344        unsafe {
345            try_ffi!(driver::cuLibraryLoadFromFile(
346                &raw mut handle,
347                path.as_ptr(),
348                ptr::null_mut(),
349                ptr::null_mut(),
350                0,
351                ptr::null_mut(),
352                ptr::null_mut(),
353                0,
354            ))?;
355        }
356        if handle.is_null() {
357            return Err(Error::NullHandle);
358        }
359        unsafe { Library::from_raw(handle, Arc::clone(self)) }
360    }
361
362    /// Blocks until the current context has completed all preceding requested tasks.
363    /// If the current context is the primary context, child contexts that have been created are also synchronized.
364    /// [`Context::synchronize`] returns an error if one of the preceding tasks failed.
365    /// If the context was created with [`ContextFlags::SCHEDULE_BLOCKING_SYNC`], the CPU thread blocks until the GPU context has finished its work.
366    ///
367    /// # Errors
368    ///
369    /// Returns an error if the context cannot be bound, a preceding task failed, or a previous
370    /// asynchronous launch reported an error.
371    pub fn synchronize(&self) -> Result<()> {
372        self.bind()?;
373        unsafe {
374            try_ffi!(driver::cuCtxSynchronize())?;
375        }
376        Ok(())
377    }
378
379    /// Returns the flags of the current context.
380    /// See [`ContextFlags`] for flag values.
381    ///
382    /// # Errors
383    ///
384    /// Returns an error if the context cannot be bound, CUDA cannot query the flags, or a
385    /// previous asynchronous launch reported an error.
386    pub fn flags(&self) -> Result<ContextFlags> {
387        self.bind()?;
388        unsafe {
389            let mut flags = 0;
390            try_ffi!(driver::cuCtxGetFlags(&raw mut flags))?;
391            Ok(ContextFlags::from_bits_truncate(flags))
392        }
393    }
394
395    /// Returns the current size of limit.
396    /// The supported [`Limit`] values are:
397    ///
398    /// * [`Limit::StackSize`]: stack size in bytes of each GPU thread.
399    /// * [`Limit::PrintfFifoSize`]: size in bytes of the FIFO used by the `printf()` device system call.
400    /// * [`Limit::MallocHeapSize`]: size in bytes of the heap used by the `malloc()` and `free()` device system calls.
401    /// * [`Limit::DevRuntimeSyncDepth`]: maximum grid depth at which a thread can issue the device runtime call [`Device::synchronize`] to wait on child grid launches to complete.
402    /// * [`Limit::DevRuntimePendingLaunchCount`]: maximum number of outstanding device runtime launches that can be made from this context.
403    /// * [`Limit::MaxL2FetchGranularity`]: L2 cache fetch granularity.
404    /// * [`Limit::PersistingL2CacheSize`]: persisting L2 cache size in bytes.
405    ///
406    /// # Errors
407    ///
408    /// Returns an error if the context cannot be bound, `limit` is unsupported, CUDA cannot query
409    /// the limit, or a previous asynchronous launch reported an error.
410    pub fn limit(&self, limit: Limit) -> Result<usize> {
411        self.bind()?;
412        unsafe {
413            let mut value = 0;
414            try_ffi!(driver::cuCtxGetLimit(&raw mut value, limit.into()))?;
415            Ok(value as usize)
416        }
417    }
418
419    /// Setting limit to value is a request by the application to update the current limit maintained by the context.
420    /// The driver may modify the requested value to meet hardware requirements, such as clamping to minimum or maximum values or rounding up to the nearest element size.
421    /// Use [`Context::limit`] to query the effective value.
422    ///
423    /// Setting each [`Limit`] has its own restrictions.
424    ///
425    /// * [`Limit::StackSize`] controls the stack size in bytes of each GPU thread.
426    ///   The driver automatically increases the per-thread stack size for each
427    ///   kernel launch as needed.
428    ///   This size is not reset back to the original value after each launch.
429    ///   Setting this value will take
430    ///   effect immediately, and if necessary, the device will block until all preceding requested tasks are complete.
431    ///
432    /// * [`Limit::PrintfFifoSize`] controls the size in bytes of the FIFO used by the `printf()` device system call.
433    ///   Configure [`Limit::PrintfFifoSize`] before launching any kernel that uses the `printf()` device system call; otherwise [`crate::error::Status::InvalidValue`] is returned.
434    ///
435    /// * [`Limit::MallocHeapSize`] controls the size in bytes of the heap used by the `malloc()` and `free()` device system calls.
436    ///   Configure [`Limit::MallocHeapSize`] before launching any kernel that uses the `malloc()` or `free()` device system calls; otherwise [`crate::error::Status::InvalidValue`] is returned.
437    ///
438    /// * [`Limit::DevRuntimeSyncDepth`] controls the maximum nesting depth of a grid at which a thread can safely call [`Device::synchronize`].
439    ///   Setting this limit must be performed before any launch of a kernel that uses the device runtime and calls [`Device::synchronize`] above the default sync depth, two levels of grids.
440    ///   Calls to [`Device::synchronize`] fail if this limit is violated.
441    ///   This limit can be set smaller than the default or up to the maximum launch depth of 24.
442    ///   Additional sync-depth levels require the driver to reserve large amounts of device memory that can no longer be used for application allocations.
443    ///   If these reservations of device memory fail, [`Context::set_limit`] returns [`crate::error::Status::OutOfMemory`], and the limit can be reset to a lower value.
444    ///   This limit is only applicable to devices of compute capability &lt; 9.0.
445    ///   Setting this limit on devices of other compute capability versions returns [`crate::error::Status::UnsupportedLimit`].
446    ///
447    /// * [`Limit::DevRuntimePendingLaunchCount`] controls the maximum number of outstanding device runtime launches that can be made from the current context.
448    ///   A grid is outstanding from launch until it is known to have completed.
449    ///   Device runtime launches that violate this limit fail.
450    ///   If a module using the device runtime needs more pending launches than the default 2048 launches, this limit can be increased.
451    ///   Sustaining additional pending launches requires the driver to reserve larger amounts of device memory up front, which can no longer be used for allocations.
452    ///   If these reservations fail, [`Context::set_limit`] returns [`crate::error::Status::OutOfMemory`], and the limit can be reset to a lower value.
453    ///   This limit is only applicable to devices of compute capability 3.5 and higher.
454    ///   Attempting to set this limit on devices of compute capability less than 3.5 returns [`crate::error::Status::UnsupportedLimit`].
455    ///
456    /// * [`Limit::MaxL2FetchGranularity`] controls the L2 cache fetch granularity.
457    ///   Values can range from 0B to 128B.
458    ///   Performance hint that may be ignored or clamped depending on the platform.
459    ///
460    /// * [`Limit::PersistingL2CacheSize`] controls size in bytes available for persisting L2 cache.
461    ///   Performance hint that may be ignored or clamped depending on the platform.
462    ///
463    /// # Errors
464    ///
465    /// Returns an error if the context cannot be bound, `limit` is unsupported, CUDA rejects the
466    /// requested value, or a previous asynchronous launch reported an error.
467    pub fn set_limit(&self, limit: Limit, value: usize) -> Result<()> {
468        self.bind()?;
469        unsafe {
470            try_ffi!(driver::cuCtxSetLimit(limit.into(), value as _))?;
471        }
472        Ok(())
473    }
474
475    pub const fn device(&self) -> Device {
476        self.device
477    }
478
479    pub const fn as_raw(&self) -> driver::CUcontext {
480        self.handle
481    }
482
483    /// Takes ownership of a raw CUDA context.
484    ///
485    /// # Safety
486    ///
487    /// `handle` must be a valid CUDA context for `device`, and no other Rust
488    /// wrapper may own the same release responsibility. `ownership` must match
489    /// how the context should be released: created contexts are destroyed with
490    /// `cuCtxDestroy`, while primary contexts are released with
491    /// `cuDevicePrimaryCtxRelease`.
492    pub unsafe fn from_raw(
493        handle: driver::CUcontext,
494        device: Device,
495        ownership: RawContextOwnership,
496    ) -> Result<Arc<Self>> {
497        if handle.is_null() {
498            return Err(Error::NullHandle);
499        }
500
501        Ok(Arc::new(Self {
502            handle,
503            device,
504            ownership: ownership.into(),
505        }))
506    }
507
508    /// Transfers ownership of the raw CUDA context to the caller.
509    ///
510    /// The caller becomes responsible for releasing the returned context
511    /// according to the returned ownership mode.
512    pub fn into_raw_parts(self) -> (driver::CUcontext, Device, RawContextOwnership) {
513        let raw = (self.handle, self.device, self.ownership.into());
514        mem::forget(self);
515        raw
516    }
517}
518
519// CUDA driver contexts are shared execution environments, not per-thread
520// library handles. The Rust wrapper only stores the raw context pointer and the
521// owning device, while current-context selection is maintained by CUDA as
522// thread-local driver state.
523unsafe impl Send for Context {}
524unsafe impl Sync for Context {}
525
526impl Drop for Context {
527    fn drop(&mut self) {
528        unsafe {
529            let result = match self.ownership {
530                ContextOwnership::Created => try_ffi!(driver::cuCtxDestroy_v2(self.handle)),
531                ContextOwnership::Primary => {
532                    try_ffi!(driver::cuDevicePrimaryCtxRelease_v2(self.device.id() as _))
533                }
534            };
535
536            if let Err(err) = result {
537                #[cfg(debug_assertions)]
538                eprintln!("failed to destroy CUDA context wrapper: {err}");
539            }
540        }
541    }
542}
543
544impl PartialEq for Context {
545    fn eq(&self, other: &Self) -> bool {
546        self.as_raw() == other.as_raw()
547    }
548}
549
550impl Eq for Context {}
551
552fn module_loadable_image(artifact: CompilationArtifact) -> Result<ModuleImage<'static>> {
553    match artifact {
554        CompilationArtifact::Ptx(image) | CompilationArtifact::Cubin(image) => Ok(image),
555        CompilationArtifact::LtoIr(_) | CompilationArtifact::OptixIr(_) => Err(Error::InvalidValue),
556    }
557}
558
559fn library_loadable_image(artifact: CompilationArtifact) -> Result<ModuleImage<'static>> {
560    match artifact {
561        CompilationArtifact::Ptx(image) | CompilationArtifact::Cubin(image) => Ok(image),
562        CompilationArtifact::LtoIr(_) | CompilationArtifact::OptixIr(_) => Err(Error::InvalidValue),
563    }
564}