Skip to main content

singe_cuda/
external_memory.rs

1//! Safe ownership wrappers for CUDA external memory imports.
2
3use std::{
4    ffi::c_void,
5    marker::PhantomData,
6    mem::{self, size_of},
7    ptr,
8};
9
10use num_enum::{IntoPrimitive, TryFromPrimitive};
11use singe_core::impl_enum_conversion;
12use singe_cuda_sys::driver;
13
14use crate::{
15    error::{Error, Result},
16    module::{KernelParameters, PushKernelArg},
17    try_ffi,
18    view::{DeviceRepr, DeviceSlice, DeviceSliceMut},
19};
20
/// Kinds of OS handle that can be imported as CUDA external memory.
///
/// Each discriminant is taken directly from the matching
/// `driver::CUexternalMemoryHandleType_enum` constant, so the `u32`
/// representation can be converted to and from the driver's value via the
/// derived `TryFromPrimitive` / `IntoPrimitive` implementations.
///
/// The enum is `#[non_exhaustive]`: only a subset of handle types is listed
/// here and further variants may be added later.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TryFromPrimitive, IntoPrimitive)]
#[repr(u32)]
#[non_exhaustive]
pub enum ExternalMemoryHandleType {
    /// Handle is an opaque file descriptor.
    OpaqueFileDescriptor =
        driver::CUexternalMemoryHandleType_enum::CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_FD as _,
    /// Handle is an opaque shared NT handle.
    OpaqueWin32 =
        driver::CUexternalMemoryHandleType_enum::CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32 as _,
    /// Handle is an opaque, globally shared handle.
    OpaqueWin32Kmt =
        driver::CUexternalMemoryHandleType_enum::CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32_KMT
            as _,
    /// Handle is a dma_buf file descriptor.
    DmaBufferFileDescriptor =
        driver::CUexternalMemoryHandleType_enum::CU_EXTERNAL_MEMORY_HANDLE_TYPE_DMABUF_FD as _,
}

// NOTE(review): presumably generates the conversions between the driver's
// `CUexternalMemoryHandleType` and this enum — confirm against the macro
// definition in `singe_core`.
impl_enum_conversion!(driver::CUexternalMemoryHandleType, ExternalMemoryHandleType);
41
bitflags::bitflags! {
    /// Flags placed in the `flags` field of the import descriptor when an
    /// external memory object is imported.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
    pub struct ExternalMemoryFlags: u32 {
        /// Mirrors the driver's `CUDA_EXTERNAL_MEMORY_DEDICATED` flag.
        const DEDICATED = driver::CUDA_EXTERNAL_MEMORY_DEDICATED;
    }
}
48
/// Owning wrapper around an imported `CUexternalMemory` handle.
///
/// The handle is destroyed with `cuDestroyExternalMemory` when the wrapper is
/// dropped, unless ownership was released through `into_raw`.
#[derive(Debug)]
pub struct ExternalMemory {
    // Raw driver handle; non-null for every value built by the constructors
    // in this module.
    handle: driver::CUexternalMemory,
    // Size in bytes supplied at import time; used to bounds-check the ranges
    // requested through `map_buffer`.
    size: usize,
}
54
/// Typed device range mapped out of an [`ExternalMemory`] object.
///
/// The mapping is released with `cuMemFree_v2` when the value is dropped.
#[derive(Debug)]
pub struct MappedBuffer<'a, T: DeviceRepr> {
    // Device address returned by `cuExternalMemoryGetMappedBuffer`.
    ptr: *mut T,
    // Number of `T` elements (not bytes) covered by the mapping.
    length: usize,
    // The mapped buffer must not outlive the imported external-memory object it
    // was derived from. CUDA still requires the mapped pointer to be freed
    // separately with cuMemFree.
    _memory: PhantomData<&'a ExternalMemory>,
}
64
65impl ExternalMemory {
66    /// Imports an opaque file descriptor as CUDA external memory.
67    ///
68    /// # Safety
69    ///
70    /// `fd` must reference a valid external memory object of `size` bytes.
71    /// CUDA takes ownership of the descriptor on a successful import.
72    #[cfg(unix)]
73    pub unsafe fn import_opaque_file_descriptor(
74        fd: std::os::fd::RawFd,
75        size: usize,
76        flags: ExternalMemoryFlags,
77    ) -> Result<Self> {
78        let mut desc = handle_desc(ExternalMemoryHandleType::OpaqueFileDescriptor, size, flags)?;
79        desc.handle.fd = fd;
80        unsafe { Self::import(&desc, size) }
81    }
82
83    /// Imports an opaque Win32 shared handle as CUDA external memory.
84    ///
85    /// # Safety
86    ///
87    /// `handle` must be a valid shared handle for a memory object of `size`
88    /// bytes and must remain valid according to CUDA's external-memory import
89    /// rules.
90    pub unsafe fn import_opaque_win32_handle(
91        handle: *mut c_void,
92        size: usize,
93        flags: ExternalMemoryFlags,
94    ) -> Result<Self> {
95        if handle.is_null() {
96            return Err(Error::NullHandle);
97        }
98
99        let mut desc = handle_desc(ExternalMemoryHandleType::OpaqueWin32, size, flags)?;
100        desc.handle.win32.handle = handle;
101        desc.handle.win32.name = ptr::null();
102        unsafe { Self::import(&desc, size) }
103    }
104
105    /// Imports an opaque named Win32 shared object as CUDA external memory.
106    ///
107    /// # Safety
108    ///
109    /// `name` must point to a valid null-terminated Win32 object name for a
110    /// memory object of `size` bytes and remain valid for the import call.
111    pub unsafe fn import_opaque_win32_name(
112        name: *const c_void,
113        size: usize,
114        flags: ExternalMemoryFlags,
115    ) -> Result<Self> {
116        if name.is_null() {
117            return Err(Error::NullHandle);
118        }
119
120        let mut desc = handle_desc(ExternalMemoryHandleType::OpaqueWin32, size, flags)?;
121        desc.handle.win32.handle = ptr::null_mut();
122        desc.handle.win32.name = name;
123        unsafe { Self::import(&desc, size) }
124    }
125
126    unsafe fn import(desc: &driver::CUDA_EXTERNAL_MEMORY_HANDLE_DESC, size: usize) -> Result<Self> {
127        let mut handle = ptr::null_mut();
128        unsafe {
129            try_ffi!(driver::cuImportExternalMemory(
130                &raw mut handle,
131                desc as *const _,
132            ))?;
133        }
134        if handle.is_null() {
135            return Err(Error::NullHandle);
136        }
137
138        Ok(Self { handle, size })
139    }
140
141    pub fn map_buffer<T: DeviceRepr>(
142        &self,
143        offset_bytes: usize,
144        length: usize,
145    ) -> Result<MappedBuffer<'_, T>> {
146        let bytes = checked_bytes::<T>(length)?;
147        if bytes == 0 {
148            return Err(Error::InvalidMemoryAllocationRequest);
149        }
150        let end = offset_bytes
151            .checked_add(bytes)
152            .ok_or(Error::InvalidMemoryAllocationRequest)?;
153        if end > self.size {
154            return Err(Error::InvalidMemoryAccess);
155        }
156
157        // CUDA returns a device pointer for this mapped range. That pointer is
158        // a distinct CUDA allocation handle and must be freed by MappedBuffer.
159        let desc = driver::CUDA_EXTERNAL_MEMORY_BUFFER_DESC {
160            offset: offset_bytes as _,
161            size: bytes as _,
162            flags: 0,
163            reserved: [0; 16],
164        };
165        let mut ptr = 0;
166        unsafe {
167            try_ffi!(driver::cuExternalMemoryGetMappedBuffer(
168                &raw mut ptr,
169                self.handle,
170                &raw const desc,
171            ))?;
172        }
173        if ptr == 0 {
174            return Err(Error::NullHandle);
175        }
176
177        Ok(MappedBuffer {
178            ptr: ptr as *mut T,
179            length,
180            _memory: PhantomData,
181        })
182    }
183
    /// Returns the size in bytes that was recorded when this object was
    /// imported or adopted via `from_raw`.
    pub const fn byte_len(&self) -> usize {
        self.size
    }
187
    /// Returns the raw driver handle without giving up ownership; `self`
    /// still destroys the handle when it is dropped.
    pub const fn as_raw(&self) -> driver::CUexternalMemory {
        self.handle
    }
191
192    /// Takes ownership of a raw CUDA external-memory handle.
193    ///
194    /// # Safety
195    ///
196    /// `handle` must be a valid `CUexternalMemory` that is not owned by any
197    /// other wrapper, and `size` must match the imported memory object's byte
198    /// size so mapped ranges can be bounds-checked correctly.
199    pub unsafe fn from_raw(handle: driver::CUexternalMemory, size: usize) -> Result<Self> {
200        if handle.is_null() {
201            return Err(Error::NullHandle);
202        }
203
204        Ok(Self { handle, size })
205    }
206
207    /// Transfers ownership of the raw CUDA external-memory handle to the
208    /// caller without destroying it.
209    ///
210    /// The caller becomes responsible for eventually destroying the returned
211    /// handle with `cuDestroyExternalMemory`.
212    pub fn into_raw(self) -> driver::CUexternalMemory {
213        let handle = self.handle;
214        mem::forget(self);
215        handle
216    }
217}
218
219impl Drop for ExternalMemory {
220    fn drop(&mut self) {
221        if self.handle.is_null() {
222            return;
223        }
224
225        unsafe {
226            if let Err(error) = try_ffi!(driver::cuDestroyExternalMemory(self.handle)) {
227                #[cfg(debug_assertions)]
228                eprintln!("failed to destroy cuda external memory: {error}");
229            }
230        }
231        self.handle = ptr::null_mut();
232    }
233}
234
// SAFETY rationale: External memory handles identify imported allocations.
// Mapping operations require &self, and lifetime-sensitive mapped buffers
// borrow the owner.
// NOTE(review): this also assumes the driver permits using and destroying the
// handle from a thread other than the importing one — confirm against the
// CUDA driver API threading rules.
unsafe impl Send for ExternalMemory {}
unsafe impl Sync for ExternalMemory {}
239
impl<T: DeviceRepr> MappedBuffer<'_, T> {
    /// Returns the mapped device pointer as a `*const T`.
    pub const fn as_ptr(&self) -> *const T {
        self.ptr
    }

    /// Returns the mapped device pointer as a `*mut T`.
    pub const fn as_mut_ptr(&mut self) -> *mut T {
        self.ptr
    }

    /// Returns the number of `T` elements covered by the mapping.
    pub const fn len(&self) -> usize {
        self.length
    }

    /// Returns `true` when the mapping covers zero elements.
    pub const fn is_empty(&self) -> bool {
        self.length == 0
    }
}
257
258impl<T: DeviceRepr> Drop for MappedBuffer<'_, T> {
259    fn drop(&mut self) {
260        if self.ptr.is_null() {
261            return;
262        }
263
264        unsafe {
265            if let Err(error) = try_ffi!(driver::cuMemFree_v2(self.ptr as driver::CUdeviceptr)) {
266                #[cfg(debug_assertions)]
267                eprintln!("failed to free mapped external memory buffer: {error}");
268            }
269        }
270        self.ptr = ptr::null_mut();
271        self.length = 0;
272    }
273}
274
// The raw `*mut T` field suppresses the auto traits, so they are restored
// here under the same bounds a `T`-owning container would use.
// NOTE(review): assumes the mapped device pointer may be used and freed from
// any thread — confirm against the CUDA driver API threading rules.
unsafe impl<T: DeviceRepr + Send> Send for MappedBuffer<'_, T> {}
unsafe impl<T: DeviceRepr + Sync> Sync for MappedBuffer<'_, T> {}
277
/// Exposes the mapped range through the crate's read-only device-slice
/// abstraction.
impl<T: DeviceRepr> DeviceSlice<T> for MappedBuffer<'_, T> {
    /// Returns the mapped device pointer.
    fn as_device_ptr(&self) -> *const T {
        self.ptr
    }

    /// Returns the number of `T` elements in the mapping.
    fn len(&self) -> usize {
        self.length
    }
}
287
/// Exposes the mapped range through the crate's mutable device-slice
/// abstraction.
impl<T: DeviceRepr> DeviceSliceMut<T> for MappedBuffer<'_, T> {
    /// Returns the mapped device pointer for mutable access.
    fn as_device_mut_ptr(&mut self) -> *mut T {
        self.ptr
    }
}
293
/// Lets a shared reference to a mapped buffer be passed as a kernel argument.
impl<T: DeviceRepr> PushKernelArg for &MappedBuffer<'_, T> {
    /// Forwards the buffer to `KernelParameters::device_slice`.
    fn push_to<'a>(self, params: &mut KernelParameters<'a>) {
        params.device_slice(self);
    }
}
299
/// Lets an exclusive reference to a mapped buffer be passed as a kernel
/// argument.
impl<T: DeviceRepr> PushKernelArg for &mut MappedBuffer<'_, T> {
    /// Forwards the buffer to `KernelParameters::device_slice_mut`.
    fn push_to<'a>(self, params: &mut KernelParameters<'a>) {
        params.device_slice_mut(self);
    }
}
305
306fn handle_desc(
307    handle_type: ExternalMemoryHandleType,
308    size: usize,
309    flags: ExternalMemoryFlags,
310) -> Result<driver::CUDA_EXTERNAL_MEMORY_HANDLE_DESC> {
311    if size == 0 {
312        return Err(Error::InvalidMemoryAllocationRequest);
313    }
314
315    Ok(driver::CUDA_EXTERNAL_MEMORY_HANDLE_DESC {
316        type_: handle_type.into(),
317        handle: Default::default(),
318        size: size as _,
319        flags: flags.bits(),
320        reserved: [0; 16],
321    })
322}
323
324fn checked_bytes<T>(length: usize) -> Result<usize> {
325    length
326        .checked_mul(size_of::<T>())
327        .ok_or(Error::InvalidMemoryAllocationRequest)
328}