Skip to main content

singe_cuda/
view.rs

1//! Borrowed typed views over CUDA-accessible memory.
2
3use std::{
4    marker::PhantomData,
5    mem::size_of,
6    ops::{Bound, RangeBounds},
7    ptr::NonNull,
8};
9
10use crate::{
11    error::{Error, Result},
12    memory::DeviceMemory,
13    types::{Complex32, Complex64, bf16, f4e2m1, f6e2m3, f6e3m2, f8e4m3, f8e5m2, f8ue8m0, f16},
14};
15
/// Marker for element types that can be stored as plain CUDA device memory.
///
/// # Safety
///
/// An implementing type must have a stable bit representation in device
/// memory. Its values have to survive byte-for-byte copies between host and
/// device memory, without any Rust destructor being run and without relying
/// on pointers that are only valid on the host.
pub unsafe trait DeviceRepr: 'static + Copy {}
24
25/// A [`DeviceRepr`] whose all-zero byte pattern is a valid value.
26///
27/// # Safety
28///
29/// Implementors must accept an all-zero byte pattern as a valid instance.
30pub unsafe trait ZeroableDeviceRepr: DeviceRepr {}
31
32macro_rules! impl_device_repr {
33    ($($ty:ty),* $(,)?) => {
34        $(
35            unsafe impl DeviceRepr for $ty {}
36            unsafe impl ZeroableDeviceRepr for $ty {}
37        )*
38    };
39}
40
41impl_device_repr!(
42    bool, u8, u16, u32, u64, u128, usize, i8, i16, i32, i64, i128, isize, f32, f64, f16, bf16,
43    Complex32, Complex64, f8e4m3, f8e5m2, f8ue8m0, f6e2m3, f6e3m2, f4e2m1,
44);
45
46/// A typed contiguous range of CUDA-accessible device memory.
47pub trait DeviceSlice<T: DeviceRepr> {
48    fn as_device_ptr(&self) -> *const T;
49
50    fn len(&self) -> usize;
51
52    fn is_empty(&self) -> bool {
53        self.len() == 0
54    }
55
56    fn byte_len(&self) -> Result<usize> {
57        self.len()
58            .checked_mul(size_of::<T>())
59            .ok_or(Error::InvalidMemoryAllocationRequest)
60    }
61}
62
63/// A mutable typed contiguous range of CUDA-accessible device memory.
64pub trait DeviceSliceMut<T: DeviceRepr>: DeviceSlice<T> {
65    fn as_device_mut_ptr(&mut self) -> *mut T;
66}
67
68/// Shared abstraction for read-only typed device buffers.
69pub trait DeviceBuffer<T: DeviceRepr>: DeviceSlice<T> {}
70
71impl<T, B> DeviceBuffer<T> for B
72where
73    T: DeviceRepr,
74    B: DeviceSlice<T> + ?Sized,
75{
76}
77
78/// Shared abstraction for mutable typed device buffers.
79pub trait DeviceBufferMut<T: DeviceRepr>: DeviceBuffer<T> + DeviceSliceMut<T> {}
80
81impl<T, B> DeviceBufferMut<T> for B
82where
83    T: DeviceRepr,
84    B: DeviceBuffer<T> + DeviceSliceMut<T> + ?Sized,
85{
86}
87
88/// A typed contiguous host-memory range that can be copied to CUDA memory.
89pub trait HostSlice<T: DeviceRepr> {
90    fn as_host_ptr(&self) -> *const T;
91
92    fn len(&self) -> usize;
93
94    fn is_empty(&self) -> bool {
95        self.len() == 0
96    }
97}
98
99/// A mutable typed contiguous host-memory range that can be copied from CUDA memory.
100pub trait HostSliceMut<T: DeviceRepr>: HostSlice<T> {
101    fn as_host_mut_ptr(&mut self) -> *mut T;
102}
103
104/// Shared abstraction for read-only host buffers.
105pub trait HostBuffer<T: DeviceRepr>: HostSlice<T> {}
106
107impl<T, B> HostBuffer<T> for B
108where
109    T: DeviceRepr,
110    B: HostSlice<T> + ?Sized,
111{
112}
113
114/// Shared abstraction for mutable host buffers.
115pub trait HostBufferMut<T: DeviceRepr>: HostBuffer<T> + HostSliceMut<T> {}
116
117impl<T, B> HostBufferMut<T> for B
118where
119    T: DeviceRepr,
120    B: HostBuffer<T> + HostSliceMut<T> + ?Sized,
121{
122}
123
/// Read access to an untyped buffer described by a byte pointer and a byte
/// count.
pub trait ByteBuffer {
    /// Returns the pointer to the first byte of the buffer.
    fn as_byte_ptr(&self) -> *const u8;

    /// Returns the buffer size in bytes.
    fn byte_len(&self) -> usize;

    /// Returns `true` when the buffer holds zero bytes.
    fn is_empty(&self) -> bool {
        matches!(self.byte_len(), 0)
    }
}
134
135/// Shared abstraction for mutable byte buffers.
136pub trait ByteBufferMut: ByteBuffer {
137    fn as_byte_mut_ptr(&mut self) -> *mut u8;
138}
139
140impl<B> ByteBuffer for B
141where
142    B: DeviceSlice<u8> + ?Sized,
143{
144    fn as_byte_ptr(&self) -> *const u8 {
145        self.as_device_ptr()
146    }
147
148    fn byte_len(&self) -> usize {
149        self.len()
150    }
151}
152
153impl<B> ByteBufferMut for B
154where
155    B: DeviceSliceMut<u8> + ?Sized,
156{
157    fn as_byte_mut_ptr(&mut self) -> *mut u8 {
158        self.as_device_mut_ptr()
159    }
160}
161
162#[derive(Debug, Clone, Copy)]
163/// Non-owning immutable view over CUDA-accessible device memory.
164///
165/// This type is `Copy` because it models a shared immutable borrow: duplicating
166/// the view duplicates only the pointer/length pair and does not create or free
167/// device memory. The lifetime ties the view to the allocation or owner that
168/// created it, but CUDA kernels may still observe mutations performed through
169/// other aliases according to CUDA stream ordering.
170pub struct DeviceView<'a, T: DeviceRepr> {
171    ptr: *const T,
172    length: usize,
173    _t: PhantomData<&'a T>,
174}
175
176#[derive(Debug)]
177/// Non-owning mutable view over CUDA-accessible device memory.
178///
179/// This type is intentionally not `Clone` or `Copy` because it models a unique
180/// mutable borrow of a device-memory range for the lifetime `'a`.
181pub struct DeviceViewMut<'a, T: DeviceRepr> {
182    ptr: *mut T,
183    length: usize,
184    _t: PhantomData<&'a mut T>,
185}
186
187impl<'a, T: DeviceRepr> DeviceView<'a, T> {
188    /// Creates a borrowed immutable device view from a raw pointer and length.
189    ///
190    /// # Safety
191    ///
192    /// `ptr` must be valid for `length` contiguous elements of `T` for the
193    /// returned lifetime. If `length` is zero, `ptr` may be null; the stored
194    /// view pointer is normalized to `NonNull::dangling()` because safe
195    /// borrowed views should not expose null unless a vendor API explicitly
196    /// requires it. The memory must remain alive and CUDA-accessible while the
197    /// view is used.
198    pub const unsafe fn from_raw_parts(ptr: *const T, length: usize) -> Self {
199        let ptr = if length == 0 {
200            NonNull::<T>::dangling().as_ptr() as *const T
201        } else {
202            ptr
203        };
204        Self {
205            ptr,
206            length,
207            _t: PhantomData,
208        }
209    }
210
211    pub fn from_memory(memory: &'a DeviceMemory<T>) -> Self {
212        Self {
213            ptr: memory.as_ptr(),
214            length: memory.len(),
215            _t: PhantomData,
216        }
217    }
218
219    pub const fn as_ptr(&self) -> *const T {
220        self.ptr
221    }
222
223    pub const fn len(&self) -> usize {
224        self.length
225    }
226
227    pub const fn is_empty(&self) -> bool {
228        self.length == 0
229    }
230
231    pub fn slice<R: RangeBounds<usize>>(self, range: R) -> Result<Self> {
232        let (start, end) = bounds_to_range(range, self.length)?;
233        // Empty device allocations are represented by null pointers. Use
234        // wrapping_add so slicing an empty view with a zero offset stays defined.
235        let ptr = self.ptr.wrapping_add(start);
236        Ok(Self {
237            ptr,
238            length: end - start,
239            _t: PhantomData,
240        })
241    }
242}
243
244impl<'a, T: DeviceRepr> DeviceViewMut<'a, T> {
245    /// Creates a borrowed mutable device view from a raw pointer and length.
246    ///
247    /// # Safety
248    ///
249    /// `ptr` must be valid for `length` contiguous mutable elements of `T` for
250    /// the returned lifetime. If `length` is zero, `ptr` may be null; the
251    /// stored view pointer is normalized to `NonNull::dangling()` because safe
252    /// borrowed views should not expose null unless a vendor API explicitly
253    /// requires it. The caller must guarantee unique access to the memory
254    /// represented by the view.
255    pub const unsafe fn from_raw_parts(ptr: *mut T, length: usize) -> Self {
256        let ptr = if length == 0 {
257            NonNull::<T>::dangling().as_ptr()
258        } else {
259            ptr
260        };
261        Self {
262            ptr,
263            length,
264            _t: PhantomData,
265        }
266    }
267
268    pub fn from_memory(memory: &'a mut DeviceMemory<T>) -> Self {
269        Self {
270            ptr: memory.as_mut_ptr(),
271            length: memory.len(),
272            _t: PhantomData,
273        }
274    }
275
276    pub const fn as_ptr(&self) -> *const T {
277        self.ptr
278    }
279
280    pub const fn as_mut_ptr(&mut self) -> *mut T {
281        self.ptr
282    }
283
284    pub const fn len(&self) -> usize {
285        self.length
286    }
287
288    pub const fn is_empty(&self) -> bool {
289        self.length == 0
290    }
291
292    pub fn as_view(&self) -> DeviceView<'_, T> {
293        DeviceView {
294            ptr: self.ptr,
295            length: self.length,
296            _t: PhantomData,
297        }
298    }
299
300    pub fn slice<R: RangeBounds<usize>>(&self, range: R) -> Result<DeviceView<'_, T>> {
301        self.as_view().slice(range)
302    }
303
304    pub fn slice_mut<R: RangeBounds<usize>>(&mut self, range: R) -> Result<DeviceViewMut<'_, T>> {
305        let (start, end) = bounds_to_range(range, self.length)?;
306        // Empty device allocations are represented by null pointers. Use
307        // wrapping_add so slicing an empty view with a zero offset stays defined.
308        let ptr = self.ptr.wrapping_add(start);
309        Ok(DeviceViewMut {
310            ptr,
311            length: end - start,
312            _t: PhantomData,
313        })
314    }
315
316    pub fn split_at_mut(
317        &mut self,
318        mid: usize,
319    ) -> Result<(DeviceViewMut<'_, T>, DeviceViewMut<'_, T>)> {
320        if mid > self.length {
321            return Err(Error::InvalidMemoryAccess);
322        }
323
324        // See slice_mut: split_at_mut(0) must be valid for empty null views.
325        let right = self.ptr.wrapping_add(mid);
326        Ok((
327            DeviceViewMut {
328                ptr: self.ptr,
329                length: mid,
330                _t: PhantomData,
331            },
332            DeviceViewMut {
333                ptr: right,
334                length: self.length - mid,
335                _t: PhantomData,
336            },
337        ))
338    }
339}
340
341impl<T: DeviceRepr> DeviceMemory<T> {
342    pub fn view(&self) -> DeviceView<'_, T> {
343        DeviceView::from_memory(self)
344    }
345
346    pub fn view_mut(&mut self) -> DeviceViewMut<'_, T> {
347        DeviceViewMut::from_memory(self)
348    }
349}
350
351impl<T: DeviceRepr> DeviceSlice<T> for DeviceMemory<T> {
352    fn as_device_ptr(&self) -> *const T {
353        self.as_ptr()
354    }
355
356    fn len(&self) -> usize {
357        self.len()
358    }
359}
360
361impl<T: DeviceRepr> DeviceSliceMut<T> for DeviceMemory<T> {
362    fn as_device_mut_ptr(&mut self) -> *mut T {
363        self.as_mut_ptr()
364    }
365}
366
367impl<T: DeviceRepr> DeviceSlice<T> for DeviceView<'_, T> {
368    fn as_device_ptr(&self) -> *const T {
369        self.ptr
370    }
371
372    fn len(&self) -> usize {
373        self.length
374    }
375}
376
377impl<T: DeviceRepr> DeviceSlice<T> for DeviceViewMut<'_, T> {
378    fn as_device_ptr(&self) -> *const T {
379        self.ptr
380    }
381
382    fn len(&self) -> usize {
383        self.length
384    }
385}
386
387impl<T: DeviceRepr> DeviceSliceMut<T> for DeviceViewMut<'_, T> {
388    fn as_device_mut_ptr(&mut self) -> *mut T {
389        self.ptr
390    }
391}
392
393impl<T: DeviceRepr> HostSlice<T> for [T] {
394    fn as_host_ptr(&self) -> *const T {
395        self.as_ptr()
396    }
397
398    fn len(&self) -> usize {
399        self.len()
400    }
401}
402
403impl<T: DeviceRepr> HostSliceMut<T> for [T] {
404    fn as_host_mut_ptr(&mut self) -> *mut T {
405        self.as_mut_ptr()
406    }
407}
408
409impl<T: DeviceRepr, const N: usize> HostSlice<T> for [T; N] {
410    fn as_host_ptr(&self) -> *const T {
411        self.as_ptr()
412    }
413
414    fn len(&self) -> usize {
415        N
416    }
417}
418
419impl<T: DeviceRepr, const N: usize> HostSliceMut<T> for [T; N] {
420    fn as_host_mut_ptr(&mut self) -> *mut T {
421        self.as_mut_ptr()
422    }
423}
424
425impl<T: DeviceRepr> HostSlice<T> for Vec<T> {
426    fn as_host_ptr(&self) -> *const T {
427        self.as_ptr()
428    }
429
430    fn len(&self) -> usize {
431        self.len()
432    }
433}
434
435impl<T: DeviceRepr> HostSliceMut<T> for Vec<T> {
436    fn as_host_mut_ptr(&mut self) -> *mut T {
437        self.as_mut_ptr()
438    }
439}
440
441fn bounds_to_range<R: RangeBounds<usize>>(range: R, length: usize) -> Result<(usize, usize)> {
442    let start = match range.start_bound() {
443        Bound::Included(&value) => value,
444        Bound::Excluded(&value) => value.checked_add(1).ok_or(Error::InvalidMemoryAccess)?,
445        Bound::Unbounded => 0,
446    };
447    let end = match range.end_bound() {
448        Bound::Included(&value) => value.checked_add(1).ok_or(Error::InvalidMemoryAccess)?,
449        Bound::Excluded(&value) => value,
450        Bound::Unbounded => length,
451    };
452
453    if start > end || end > length {
454        return Err(Error::InvalidMemoryAccess);
455    }
456
457    Ok((start, end))
458}