Skip to main content

singe_cuda/
library.rs

1use std::{ffi::CString, mem::ManuallyDrop, ptr, sync::Arc};
2
3use singe_cuda_sys::driver;
4
5use crate::{
6    context::Context,
7    error::{Error, Result},
8    graph::{ExecutableGraph, Graph, GraphNode},
9    kernel::{self, LibraryKernelHandle},
10    module::{KernelFunction, KernelLaunchArgs, LaunchConfig, Module},
11    try_ffi,
12    types::{DeviceFunction, FunctionAttribute, FunctionCache},
13};
14
15#[derive(Debug)]
16pub struct Library {
17    handle: driver::CUlibrary,
18    ctx: Arc<Context>,
19}
20
21#[derive(Debug, Clone, Copy)]
22pub struct LibraryGlobal<'a> {
23    ptr: *mut (),
24    size: usize,
25    _library: &'a Library,
26}
27
28#[derive(Debug, Clone, Copy)]
29pub struct LibraryKernel<'a> {
30    handle: driver::CUkernel,
31    library: &'a Library,
32}
33
/// Location of a single kernel parameter inside the device-side parameter
/// layout, as returned by [`LibraryKernel::param_info`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct KernelParamInfo {
    /// Offset of the parameter within the parameter layout.
    pub offset: usize,
    /// Size of the parameter.
    pub size: usize,
}
39
40impl Library {
41    pub unsafe fn from_raw(handle: driver::CUlibrary, ctx: Arc<Context>) -> Result<Self> {
42        if handle.is_null() {
43            return Err(Error::NullHandle);
44        }
45
46        Ok(Self { handle, ctx })
47    }
48
49    /// Returns the handle of the kernel with the given name located in this library.
50    /// If kernel handle is not found, the call returns [`crate::error::Status::NotFound`].
51    ///
52    /// # Errors
53    ///
54    /// Returns an error if `name` contains an interior NUL byte, if the CUDA
55    /// context cannot be bound, if CUDA Driver cannot find the kernel, or if it
56    /// returns a null handle.
57    pub fn kernel(&self, name: &str) -> Result<LibraryKernel<'_>> {
58        let c_name = CString::new(name)?;
59        let mut handle = ptr::null_mut();
60        self.ctx.bind()?;
61        unsafe {
62            try_ffi!(driver::cuLibraryGetKernel(
63                &raw mut handle,
64                self.handle,
65                c_name.as_ptr(),
66            ))?;
67        }
68        if handle.is_null() {
69            return Err(Error::NullHandle);
70        }
71        Ok(LibraryKernel {
72            handle,
73            library: self,
74        })
75    }
76
77    /// Returns the number of kernels in this library.
78    ///
79    /// # Errors
80    ///
81    /// Returns an error if the CUDA context cannot be bound or if CUDA Driver
82    /// cannot report the kernel count.
83    pub fn kernel_count(&self) -> Result<usize> {
84        let mut count = 0;
85        self.ctx.bind()?;
86        unsafe {
87            try_ffi!(driver::cuLibraryGetKernelCount(&raw mut count, self.handle))?;
88        }
89        Ok(count as usize)
90    }
91
92    /// Returns the module handle associated with the current context located in this library.
93    /// If module handle is not found, the call returns [`crate::error::Status::NotFound`].
94    ///
95    /// # Errors
96    ///
97    /// Returns an error if the CUDA context cannot be bound, if CUDA Driver
98    /// cannot find the module, or if it returns a null handle.
99    pub fn module(&self) -> Result<Module> {
100        let mut handle = ptr::null_mut();
101        self.ctx.bind()?;
102        unsafe {
103            try_ffi!(driver::cuLibraryGetModule(&raw mut handle, self.handle))?;
104        }
105        if handle.is_null() {
106            return Err(Error::NullHandle);
107        }
108        Ok(unsafe { Module::from_borrowed_raw(handle, Arc::clone(&self.ctx)) })
109    }
110
111    /// Returns the base pointer and size of the global with the given name for the requested library and the current context.
112    /// If no global for the requested name exists, the call returns [`crate::error::Status::NotFound`].
113    ///
114    /// # Errors
115    ///
116    /// Returns an error if `name` contains an interior NUL byte, if the CUDA
117    /// context cannot be bound, or if CUDA Driver cannot find the global.
118    pub fn global(&self, name: &str) -> Result<LibraryGlobal<'_>> {
119        let c_name = CString::new(name)?;
120        let mut ptr = 0;
121        let mut size = 0;
122        self.ctx.bind()?;
123        unsafe {
124            try_ffi!(driver::cuLibraryGetGlobal(
125                &raw mut ptr,
126                &raw mut size,
127                self.handle,
128                c_name.as_ptr(),
129            ))?;
130        }
131        Ok(LibraryGlobal {
132            ptr: ptr as *mut (),
133            size: size as usize,
134            _library: self,
135        })
136    }
137
138    /// Returns the base pointer and size of the managed memory with the given name for the requested library.
139    /// If no managed memory with the requested name exists, the call returns [`crate::error::Status::NotFound`].
140    /// Managed memory for the library is shared across devices and is registered when the library is loaded into at least one context.
141    ///
142    /// # Errors
143    ///
144    /// Returns an error if `name` contains an interior NUL byte, if the CUDA
145    /// context cannot be bound, or if CUDA Driver cannot find the managed
146    /// allocation.
147    pub fn managed(&self, name: &str) -> Result<LibraryGlobal<'_>> {
148        let c_name = CString::new(name)?;
149        let mut ptr = 0;
150        let mut size = 0;
151        self.ctx.bind()?;
152        unsafe {
153            try_ffi!(driver::cuLibraryGetManaged(
154                &raw mut ptr,
155                &raw mut size,
156                self.handle,
157                c_name.as_ptr(),
158            ))?;
159        }
160        Ok(LibraryGlobal {
161            ptr: ptr as *mut (),
162            size: size as usize,
163            _library: self,
164        })
165    }
166
167    /// Returns the pointer to the unified function named by `symbol`.
168    /// If no unified function with that name exists, the call returns [`crate::error::Status::NotFound`].
169    /// If no device in the system supports unified function pointers, the call may return [`crate::error::Status::NotFound`].
170    ///
171    /// # Errors
172    ///
173    /// Returns an error if `symbol` contains an interior NUL byte, if the CUDA
174    /// context cannot be bound, or if CUDA Driver cannot find the unified
175    /// function.
176    pub fn unified_function(&self, symbol: &str) -> Result<*mut ()> {
177        let c_symbol = CString::new(symbol)?;
178        let mut ptr = ptr::null_mut();
179        self.ctx.bind()?;
180        unsafe {
181            try_ffi!(driver::cuLibraryGetUnifiedFunction(
182                &raw mut ptr,
183                self.handle,
184                c_symbol.as_ptr(),
185            ))?;
186        }
187        if ptr.is_null() {
188            return Err(Error::NullHandle);
189        }
190        Ok(ptr.cast())
191    }
192
193    pub const fn as_raw(&self) -> driver::CUlibrary {
194        self.handle
195    }
196
197    /// Consumes the library and returns the raw CUDA library handle without
198    /// unloading it.
199    ///
200    /// The caller becomes responsible for eventually unloading the returned
201    /// handle with CUDA.
202    pub fn into_raw(self) -> driver::CUlibrary {
203        let library = ManuallyDrop::new(self);
204        library.handle
205    }
206}
207
208impl Drop for Library {
209    fn drop(&mut self) {
210        if let Err(err) = self.ctx.bind() {
211            #[cfg(debug_assertions)]
212            eprintln!("failed to bind context before unloading library: {err}");
213            return;
214        }
215
216        unsafe {
217            if let Err(err) = try_ffi!(driver::cuLibraryUnload(self.handle)) {
218                #[cfg(debug_assertions)]
219                eprintln!("failed to unload cuda library: {err}");
220            }
221        }
222    }
223}
224
225impl LibraryGlobal<'_> {
226    pub const fn as_ptr(&self) -> *mut () {
227        self.ptr
228    }
229
230    pub const fn byte_len(&self) -> usize {
231        self.size
232    }
233}
234
235impl LibraryKernel<'_> {
236    pub fn name(&self) -> Result<String> {
237        kernel::name::<LibraryKernelHandle>(self.library.ctx.as_ref(), self.handle)
238    }
239
240    /// Returns the device function handle for this kernel and the current context.
241    /// If the handle is not found, the call returns [`crate::error::Status::NotFound`].
242    ///
243    /// # Errors
244    ///
245    /// Returns an error if the CUDA context cannot be bound, if CUDA Driver
246    /// cannot find the function, or if it returns a null handle.
247    pub fn function(&self) -> Result<DeviceFunction> {
248        self.library.ctx.bind()?;
249        let mut handle = ptr::null_mut();
250        unsafe {
251            try_ffi!(driver::cuKernelGetFunction(&raw mut handle, self.handle))?;
252        }
253        if handle.is_null() {
254            return Err(Error::NullHandle);
255        }
256        Ok(unsafe { DeviceFunction::from_raw(handle) })
257    }
258
259    /// Adds this kernel to `graph` as a kernel node.
260    ///
261    /// # Safety
262    ///
263    /// The caller must ensure every pointer value passed through `params`
264    /// remains valid for every graph instantiation, update, and launch that can
265    /// execute the created node. Mutable pointer arguments must remain
266    /// exclusive for the work ordered by those launches.
267    pub unsafe fn add_to_graph<'a, P>(
268        &self,
269        graph: &mut Graph,
270        dependencies: &[GraphNode],
271        config: &LaunchConfig,
272        params: P,
273    ) -> Result<GraphNode>
274    where
275        P: KernelLaunchArgs<'a>,
276    {
277        let function = self.function()?;
278        let module = self.library.module()?;
279        let function = unsafe { KernelFunction::from_raw(function, &module) };
280        unsafe { function.add_to_graph(graph, dependencies, config, params) }
281    }
282
283    /// Updates this kernel's parameters in an executable graph node.
284    ///
285    /// # Safety
286    ///
287    /// The caller must ensure every pointer value passed through `params`
288    /// remains valid for every future launch that can execute `node`. Mutable
289    /// pointer arguments must remain exclusive for the work ordered by those
290    /// launches.
291    pub unsafe fn set_graph_node_params<'a, P>(
292        &self,
293        executable: &mut ExecutableGraph,
294        node: GraphNode,
295        config: &LaunchConfig,
296        params: P,
297    ) -> Result<()>
298    where
299        P: KernelLaunchArgs<'a>,
300    {
301        let function = self.function()?;
302        let module = self.library.module()?;
303        let function = unsafe { KernelFunction::from_raw(function, &module) };
304        unsafe { function.set_graph_node_params(executable, node, config, params) }
305    }
306
307    pub fn attribute(&self, attribute: FunctionAttribute) -> Result<i32> {
308        kernel::attribute::<LibraryKernelHandle>(self.library.ctx.as_ref(), self.handle, attribute)
309    }
310
311    pub fn set_attribute(&self, attribute: FunctionAttribute, value: i32) -> Result<()> {
312        kernel::set_attribute::<LibraryKernelHandle>(
313            self.library.ctx.as_ref(),
314            self.handle,
315            attribute,
316            value,
317        )
318    }
319
320    /// Sets the preferred cache configuration for this kernel on devices where L1 cache and shared memory use the same hardware resources.
321    /// This setting is only a preference.
322    /// The driver uses the requested configuration if possible, but it may choose a different configuration if required to execute the kernel.
323    /// This per-kernel setting overrides any context-wide preference set via [`sys::cuCtxSetCacheConfig`](singe_cuda_sys::driver::cuCtxSetCacheConfig).
324    ///
325    /// Attributes set using [`sys::cuFuncSetCacheConfig`](singe_cuda_sys::driver::cuFuncSetCacheConfig) override this preference regardless of call order.
326    ///
327    /// This setting does nothing on devices where the size of the L1 cache and shared memory are fixed.
328    ///
329    /// Launching a kernel with a different preference than the most recent preference setting may insert a device-side synchronization point.
330    ///
331    /// The supported cache configurations are:
332    ///
333    /// * [`FunctionCache::PreferNone`]: no preference for shared memory or L1 (default)
334    /// * [`FunctionCache::PreferShared`]: prefer larger shared memory and smaller L1 cache
335    /// * [`FunctionCache::PreferL1`]: prefer larger L1 cache and smaller shared memory
336    /// * [`FunctionCache::PreferEqual`]: prefer equal sized L1 cache and shared memory
337    ///
338    /// This has stricter locking requirements than its legacy counterpart [`sys::cuFuncSetCacheConfig`](singe_cuda_sys::driver::cuFuncSetCacheConfig) because the setting has device-wide semantics.
339    /// If multiple threads try to set a configuration on the same device simultaneously, the final cache configuration depends on OS scheduler interleaving and memory consistency.
340    ///
341    /// # Errors
342    ///
343    /// Returns an error if the CUDA context cannot be bound or if CUDA Driver
344    /// rejects the cache configuration.
345    pub fn set_cache_config(&self, config: FunctionCache) -> Result<()> {
346        self.library.ctx.bind()?;
347        unsafe {
348            try_ffi!(driver::cuKernelSetCacheConfig(
349                self.handle,
350                config.into(),
351                self.library.ctx.device().id() as _,
352            ))?;
353        }
354        Ok(())
355    }
356
357    /// Queries the kernel parameter at the given index, returning the offset and size where the parameter resides in the device-side parameter layout.
358    /// Use this information to update kernel node parameters from the device. The index must be less than the number of parameters that the kernel takes.
359    ///
360    /// # Errors
361    ///
362    /// Returns an error if the library context cannot be bound, `index` is not a valid kernel
363    /// parameter index, CUDA cannot query the parameter layout, or a previous asynchronous launch
364    /// reported an error.
365    pub fn param_info(&self, index: usize) -> Result<KernelParamInfo> {
366        self.library.ctx.bind()?;
367        let mut offset = 0;
368        let mut size = 0;
369        unsafe {
370            try_ffi!(driver::cuKernelGetParamInfo(
371                self.handle,
372                index as _,
373                &raw mut offset,
374                &raw mut size,
375            ))?;
376        }
377        Ok(KernelParamInfo {
378            offset: offset as usize,
379            size: size as usize,
380        })
381    }
382
383    pub const fn as_raw(&self) -> driver::CUkernel {
384        self.handle
385    }
386}